from instr_rev import Equation, Constraint, Solver, FixedInteger

# abstract classes
from abc import ABC, abstractmethod

# type annotations
from typing import List, Dict, Tuple, Any
from numbers import Number
from enum import Enum

# random number generation
import uuid
import random

# z3 constraint solver
import z3

import stats

#try:
#    # progress bar
#    from IPython.display import display
#    from ipywidgets import IntProgress
#except:
#    print("[Warning] IPyhton imports failed, progress bars will not work and crash the program!")

# TODO: add bitwidth into expressions!
# This includes usages, but this should be fine (we always should know the bitwidth of a variable when we create a usage for it)

class Z3Expression(ABC):
    """
    Abstract node of an expression tree that can be lowered to z3 terms.

    Subclasses implement ``construct_constraints``, ``to_str`` and ``visit``.
    Every other helper in this class is built on top of ``visit``.
    """

    @abstractmethod
    def construct_constraints(self, inputs: Dict[str, Tuple[int, int|None]]):
        """
        Build the z3 representation of this expression for the given inputs.
        """
        pass

    @abstractmethod
    def to_str(self, inputs: Dict[str, Tuple[int, int|None]]):
        """
        Render the expression as text for the given inputs.
        """
        pass

    def __str__(self):
        # Plain str() renders the expression without any inputs.
        no_inputs = {}
        return self.to_str(no_inputs)

    def __repr__(self):
        return str(self)

    @abstractmethod
    def visit(self, visitor) -> "Z3Expression": # TODO: add type annotations
        """
        Run ``visitor`` over this expression and all of its sub-expressions.
        The visitor returns the expression to use in place of the one it was given
        (or the very same object when nothing has to be replaced).
        """
        pass

    @staticmethod
    def conditional_replace_visitor(predicate, replacement):
        """
        Build a visitor that swaps every expression accepted by ``predicate`` for ``replacement``.
        """
        def visitor_internal(expr):
            if predicate(expr):
                return replacement
            return expr

        return visitor_internal

    @staticmethod
    def replace_visitor(target, replacement):
        """
        Build a visitor that swaps exactly the object ``target`` (identity comparison) for ``replacement``.
        """
        def is_target(expr):
            return expr is target

        return Z3Expression.conditional_replace_visitor(is_target, replacement)

    @staticmethod
    def accumulate_visitor(predicate):
        """
        Build a visitor that collects every expression accepted by ``predicate``.
        Returns the collecting list together with the visitor.
        The list is shared by all uses of the returned visitor.
        """
        accumulation = []

        def visitor_internal(expr):
            if predicate(expr):
                accumulation.append(expr)
            # Accumulation never replaces anything.
            return expr

        return accumulation, visitor_internal

    def get_constants_by_value(self, value):
        """
        Collect all Z3ConstantExpression nodes whose ``value`` equals the given one.
        """
        def is_matching_constant(expr):
            return isinstance(expr, Z3ConstantExpression) and expr.value == value

        found, visitor = Z3Expression.accumulate_visitor(is_matching_constant)
        self.visit(visitor)
        return found

    def replace(self, subexpression, replacement):
        """
        Convenience wrapper: visit with a visitor that swaps ``subexpression`` for ``replacement``
        and return the resulting top expression.
        """
        visitor = Z3Expression.replace_visitor(subexpression, replacement)
        return self.visit(visitor)

    def conditional_replace(self, predicate, replacement):
        """
        Convenience wrapper: visit with a visitor that swaps every sub-expression accepted by
        ``predicate`` for ``replacement`` and return the resulting top expression.
        """
        visitor = Z3Expression.conditional_replace_visitor(predicate, replacement)
        return self.visit(visitor)

    def accumulate(self, predicate):
        """
        Convenience wrapper: collect all sub-expressions accepted by ``predicate``.
        """
        found, visitor = Z3Expression.accumulate_visitor(predicate)
        self.visit(visitor)
        return found

    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int:
        """
        Evaluate the expression by building its z3 term, simplifying it and reading it back as an integer.
        """
        simplified = z3.simplify(self.construct_constraints(inputs))
        return simplified.as_long()
    

class Z3UsageExpression(Z3Expression):
    """
    Leaf expression that refers to a named input with a given bitwidth.
    """

    def __init__(self, bitwidth, name):
        self.name = name
        self.bitwidth = bitwidth

    def visit(self, visitor):
        # Leaf node: only the node itself is offered to the visitor.
        return visitor(self)

    def construct_constraints(self, inputs: Dict[str, Tuple[int, int|None]]):
        """
        Return a z3 bit-vector constant when the input has a value, otherwise a
        symbolic bit-vector named after this usage. An input missing from
        ``inputs`` is reported and treated as symbolic with this usage's own bitwidth.
        """
        if self.name in inputs:
            bitwidth, value = inputs[self.name]
        else:
            print(f"[Warning] variable {self.name} not in inputs!")
            bitwidth, value = self.bitwidth, None

        if value is not None:
            return z3.BitVecVal(value, bitwidth)
        return z3.BitVec(self.name, bitwidth)

    def to_str(self, inputs: Dict[str, Tuple[int, int|None]]): # TODO: maybe also push bitwidth into here
        """
        Render as ``name:value`` when a value is known for this usage, otherwise as the bare name.
        """
        known_value = inputs[self.name][1] if self.name in inputs else None
        if known_value is None:
            return self.name
        return f"{self.name}:{known_value}"

    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int:
        """
        Look up the value of this usage and mask it to ``bitwidth`` bits.
        """
        raw = inputs[self.name][1]
        mask = (1 << self.bitwidth) - 1
        return raw & mask

    def get_bitwidth(self) -> int:
        return self.bitwidth

class Z3ConstantExpression(Z3Expression):
    """
    Leaf expression holding an integer constant of a given bitwidth.
    """

    def __init__(self, bitwidth, value):
        # Only integer values are accepted as constants.
        if not isinstance(value, int):
            raise Exception("not an int")
        self.value = value
        self.bitwidth = bitwidth

    def visit(self, visitor):
        # Leaf node: only the node itself is offered to the visitor.
        return visitor(self)

    def construct_constraints(self, inputs: Dict[str, Tuple[int, int|None]]):
        """
        Return the constant as a z3 bit-vector value; ``inputs`` is not consulted.
        """
        return z3.BitVecVal(self.value, self.bitwidth)

    def to_str(self, inputs: Dict[str, Tuple[int, int|None]]):
        """
        Render the constant in hexadecimal with a ``0x`` prefix.
        """
        return "0x" + format(self.value, "x")

    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int:
        """
        Return the constant masked to ``bitwidth`` bits.
        """
        mask = (1 << self.bitwidth) - 1
        return self.value & mask

    def get_bitwidth(self) -> int:
        return self.bitwidth

class Z3NotExpression(Z3Expression):
    """
    Bitwise complement of a single target expression.
    """

    def __init__(self, target):
        self.target = target

    def visit(self, visitor):
        """
        Visit the target first; rebuild this node only when the target was replaced,
        then offer the (possibly rebuilt) node to the visitor.
        """
        visited_target = self.target.visit(visitor)
        node = self if visited_target is self.target else Z3NotExpression(visited_target)
        return visitor(node)

    def construct_constraints(self, inputs: Dict[str, Tuple[int, int|None]]):
        # The complement is expressed as XOR with an all-ones vector of the target's width.
        all_ones = z3.BitVecVal(-1, self.target.get_bitwidth())
        return all_ones ^ self.target.construct_constraints(inputs)

    def to_str(self, inputs: Dict[str, Tuple[int, int|None]]):
        inner = self.target.to_str(inputs)
        return f"(~{inner})"

    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int:
        """
        Evaluate the target and flip the low ``get_bitwidth()`` bits.
        """
        target_value = self.target.evaluate(inputs)
        all_ones = (1 << self.get_bitwidth()) - 1
        return target_value ^ all_ones

    def get_bitwidth(self) -> int:
        # The complement has the same width as its target.
        return self.target.get_bitwidth()


class Z3ExtendExpression(Z3Expression):
    """
    Widens ``value`` to ``bitwidth`` bits, either by sign extension or by zero extension.
    """

    def __init__(self, bitwidth, value, sign_extend):
        self.bitwidth = bitwidth
        self.value = value
        self.sign_extend = sign_extend

    def visit(self, visitor):
        """
        Visit the wrapped value first; rebuild this node only when the value was replaced,
        then offer the (possibly rebuilt) node to the visitor.
        """
        visited_value = self.value.visit(visitor)
        if visited_value is not self.value:
            return visitor(Z3ExtendExpression(self.bitwidth, visited_value, self.sign_extend))
        return visitor(self)

    def construct_constraints(self, inputs: Dict[str, Tuple[int, int|None]]):
        extend = z3.SignExt if self.sign_extend else z3.ZeroExt
        # z3 extension functions take the number of bits to add, not the target width.
        added_bits = self.bitwidth - self.value.get_bitwidth()
        return extend(added_bits, self.value.construct_constraints(inputs))

    def to_str(self, inputs: Dict[str, Tuple[int, int|None]]):
        label = 'SignExtend' if self.sign_extend else 'ZeroExtend'
        source_bits = self.value.get_bitwidth()
        inner = self.value.to_str(inputs)
        return f"{label}({source_bits} to {self.bitwidth}, {inner})"

    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int:
        """
        Evaluate the wrapped value, mask it to its own width and, for sign extension
        of a value with its top bit set, fill the added high bits with ones.
        """
        source_bits = self.value.get_bitwidth()
        result = self.value.evaluate(inputs) & ((1 << source_bits) - 1)
        if not self.sign_extend:
            return result
        if result & (1 << (source_bits - 1)):
            high_ones = (-1 << source_bits) & ((1 << self.bitwidth) - 1)
            result |= high_ones
        return result

    def get_bitwidth(self) -> int:
        return self.bitwidth

class Z3BinaryExpression(Z3Expression):
    """
    Base class for two-operand expressions.

    Concrete subclasses (created by ``constructZ3BinaryExpression``) set
    ``sign_extend``, ``operation`` and ``symbol`` on the instance; this base class
    only stores the operands and implements the shared behaviour. ``operation``
    is applied both to z3 terms (``construct_constraints``) and to plain integers
    (``evaluate``).
    """

    def __init__(self, left, right):
        # TODO: maybe put sign_extend argument also here, otherwise it is a bit weird
        self.left = left
        self.right = right

    def visit(self, visitor):
        """
        Visit both operands first; rebuild this node (same concrete class, same
        ``sign_extend``) only when an operand was replaced, then offer the
        (possibly rebuilt) node to the visitor.
        """
        new_left = self.left.visit(visitor)
        new_right = self.right.visit(visitor)
        if new_left is self.left and new_right is self.right:
            return visitor(self)
        return visitor(self.__class__(self.sign_extend, new_left, new_right))

    def construct_constraints(self, inputs: Dict[str, Tuple[int, int|None]]):
        """
        Build the z3 term: each operand narrower than the result width is extended
        (sign or zero extension, depending on ``sign_extend``) before ``operation`` is applied.
        """
        l = self.left.construct_constraints(inputs)
        r = self.right.construct_constraints(inputs)
        l_bitwidth = self.left.get_bitwidth()
        r_bitwidth = self.right.get_bitwidth()
        bw = self.get_bitwidth()
        extend = z3.SignExt if self.sign_extend else z3.ZeroExt
        if l_bitwidth < bw:
            l = extend(bw - l_bitwidth, l)
        if r_bitwidth < bw:
            r = extend(bw - r_bitwidth, r)
        return self.operation(l, r)

    def to_str(self, inputs: Dict[str, Tuple[int, int|None]]):
        """
        Render as ``(left symbol right)``; the narrower operand is shown wrapped in
        the extension that brings it to the width of the other operand.
        """
        l = self.left.to_str(inputs)
        r = self.right.to_str(inputs)

        try:
            l_bitwidth = self.left.get_bitwidth()
            r_bitwidth = self.right.get_bitwidth()

            extend = 'SignExtend' if self.sign_extend else 'ZeroExtend'
            if l_bitwidth < r_bitwidth:
                l = f"{extend}({l_bitwidth} to {r_bitwidth}, {l})"
            elif r_bitwidth < l_bitwidth:
                # The right operand is the narrow one here, so it is the one that gets wrapped.
                r = f"{extend}({r_bitwidth} to {l_bitwidth}, {r})"
        except Exception:
            # Bit information is best-effort for printing; fall back to the plain operands.
            print("[WARNING] used str() on expression with inputs, cannot provide bit-information!")

        return f"({l} {self.symbol} {r})"

    @staticmethod
    def _sign_extend_int(value, from_bitwidth, to_bitwidth):
        """
        Sign-extend the non-negative integer ``value`` from ``from_bitwidth`` to ``to_bitwidth`` bits.
        Values that are not narrower than the target, or whose top bit is clear, are returned unchanged.
        """
        if from_bitwidth < to_bitwidth and value & (1 << (from_bitwidth - 1)):
            return (value | (-1 << from_bitwidth)) & ((1 << to_bitwidth) - 1)
        return value

    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int:
        """
        Evaluate both operands on plain integers, mirroring ``construct_constraints``:
        each operand is masked to its own width, sign-extended to the result width
        when ``sign_extend`` is set, and the result of ``operation`` is masked to the result width.
        """
        l_bitwidth = self.left.get_bitwidth()
        r_bitwidth = self.right.get_bitwidth()
        l = self.left.evaluate(inputs) & ((1 << l_bitwidth) - 1)
        r = self.right.evaluate(inputs) & ((1 << r_bitwidth) - 1)
        bw = self.get_bitwidth()
        if self.sign_extend:
            # Both operands are handled independently: the result may be wider than
            # both of them (get_bitwidth can be overridden), exactly as in construct_constraints.
            l = self._sign_extend_int(l, l_bitwidth, bw)
            r = self._sign_extend_int(r, r_bitwidth, bw)
        return self.operation(l, r) & ((1 << bw) - 1)

    def get_bitwidth(self) -> int:
        # Default result width: the wider of the two operands.
        return max(self.left.get_bitwidth(), self.right.get_bitwidth())


# Registry of every binary expression class created by constructZ3BinaryExpression, in creation order.
Z3_BINARY_EXPRESSIONS = []

def constructZ3BinaryExpression(name, mnemonic, symbol, operation):
    """
    Create, register and return a new Z3BinaryExpression subclass.

    Instances of the new class carry ``operation``, ``mnemonic`` and ``symbol`` as
    instance attributes, together with the ``sign_extend`` flag passed to the constructor.
    The class is appended to ``Z3_BINARY_EXPRESSIONS`` and its ``__qualname__`` is set to
    ``Z3{name}Expression``.
    """

    class Z3BinaryExpressionImpl(Z3BinaryExpression):
        def __init__(self, sign_extend, left, right):
            super().__init__(left, right)
            self.sign_extend = sign_extend
            self.mnemonic = mnemonic
            self.symbol = symbol
            # Stored on the instance so that self.operation(l, r) is called without binding self.
            self.operation = operation

    # Interesting that Python allows me to do this, but since it does ...
    Z3BinaryExpressionImpl.__qualname__ = "Z3" + name + "Expression"
    Z3_BINARY_EXPRESSIONS.append(Z3BinaryExpressionImpl)

    return Z3BinaryExpressionImpl

def lshift_hotfix(a, b): # to not run out of memory
    """
    Left shift that clamps the shift amount to the range 0..64 when both operands are plain ints.
    Any other operand types are shifted with ``<<`` unchanged.
    """
    both_ints = isinstance(a, int) and isinstance(b, int)
    if not both_ints:
        return a << b
    clamped = 0 if b < 0 else (64 if b > 64 else b)
    return a << clamped

def rshift_hotfix(a, b):
    """
    Right shift used for both plain ints and z3 terms.

    For two ints: a shift by 0 returns ``a``; any other amount is raised to at least 1
    before shifting. Any other operand types are passed to ``z3.LShR``.
    """
    if not (isinstance(a, int) and isinstance(b, int)):
        return z3.LShR(a, b)
    return a if b == 0 else a >> max(1, b)

def mod_hotfix(a, b): # to not divide by 0
    """
    Modulo that yields 0 instead of raising when two plain ints are given and the divisor is 0.
    Every other case is delegated to the ``%`` operator.
    """
    concrete = isinstance(a, int) and isinstance(b, int)
    if concrete and b == 0:
        return 0
    return a % b

def div_hotfix(a, b): # to not divide by 0
    """
    Division used for both plain ints and other operand types.

    For two ints: floor division, with a divisor of 0 yielding 0 instead of raising.
    Any other operand types are divided with the ``/`` operator.
    """
    if not (isinstance(a, int) and isinstance(b, int)):
        return a / b
    return a // b if b != 0 else 0

# Concrete binary expression classes. Each call registers the new class in Z3_BINARY_EXPRESSIONS,
# and iterate_Z3Expressions walks that list in registration order.
# Shifts should be above multiplication and division, so we find them if applicable instead!
Z3ShiftRightArithmeticExpression = constructZ3BinaryExpression("ShiftRightArithmetic", "BitShiftRight", ">>", rshift_hotfix) # TODO: This is actually a logical right shift (rshift_hotfix uses z3.LShR)
Z3ShiftLeftExpression = constructZ3BinaryExpression("ShiftLeft", "BitShiftLeft", "<<", lshift_hotfix)
Z3MulExpression = constructZ3BinaryExpression("Mul", "Times", "*", lambda a,b: a * b)
# Multiplication overrides the default result width: twice the width of the wider operand.
Z3MulExpression.get_bitwidth = lambda x: max(x.left.get_bitwidth(), x.right.get_bitwidth()) * 2
Z3AddExpression = constructZ3BinaryExpression("Add", "Plus", "+", lambda a,b: a + b)
Z3SubExpression = constructZ3BinaryExpression("Sub", "Subtract", "-", lambda a,b: a - b)
Z3AndExpression = constructZ3BinaryExpression("And", "BitAnd", "&", lambda a,b: a & b)
Z3XorExpression = constructZ3BinaryExpression("Xor", "BitXor", "^", lambda a,b: a ^ b)
Z3OrExpression = constructZ3BinaryExpression("Or", "BitOr", "|", lambda a,b: a | b)
# Z3ShiftRightExpression = constructZ3BinaryExpression("ShiftRight", ">>", z3.LShR)
# Z3ModExpression = constructZ3BinaryExpression("Mod", "%", z3.URem)
Z3SignedModExpression = constructZ3BinaryExpression("SignedMod", "Mod", "%", mod_hotfix)
# Z3DivExpression = constructZ3BinaryExpression("Div", "Div", "/", z3.UDiv)
Z3SignedDivExpression = constructZ3BinaryExpression("SignedDiv", "Quotient", "/", div_hotfix)


class Z3Equation(Equation):
    """
    Equation backed by a single Z3Expression tree, stored as ``top_expression``.
    """

    def __init__(self, top_expression):
        self.top_expression = top_expression

    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int:
        """
        Evaluate the top expression for the given inputs.
        """
        return self.top_expression.evaluate(inputs)

    def replace(self, target, replacement):
        """
        Return an equation in which the object ``target`` is swapped for ``replacement``.
        When the top expression comes back unchanged, this equation itself is returned.
        """
        replaced = self.top_expression.visit(Z3Expression.replace_visitor(target, replacement))
        return self if replaced is self.top_expression else Z3Equation(replaced)

    def get_bitwidth(self) -> int:
        return self.top_expression.get_bitwidth()

    def get_constants(self) -> List[Z3ConstantExpression]:
        """
        Collect every Z3ConstantExpression in the tree.
        """
        def is_constant(expr):
            return isinstance(expr, Z3ConstantExpression)

        return self.top_expression.accumulate(is_constant)

    def get_constants_by_value(self, value: int) -> List[Z3ConstantExpression]:
        return self.top_expression.get_constants_by_value(value)

    def get_usages(self, name: str) -> List[Z3UsageExpression]:
        """
        Collect every Z3UsageExpression in the tree that carries the given name.
        """
        def is_named_usage(expr):
            return isinstance(expr, Z3UsageExpression) and expr.name == name

        return self.top_expression.accumulate(is_named_usage)

    def get_usage_names(self) -> List[str]:
        """
        Return the distinct names of all usages in the tree (as a set).
        """
        usages = self.top_expression.accumulate(lambda x: isinstance(x, Z3UsageExpression))
        return {usage.name for usage in usages}

    def create_constant(self, bitwidth: int, value: int) -> Z3ConstantExpression:
        return Z3ConstantExpression(bitwidth, value)

    def create_usage(self, bitwidth: int, name: str) -> Z3UsageExpression:
        return Z3UsageExpression(bitwidth, name)

    def to_str(self, inputs):
        return self.top_expression.to_str(inputs)

    def __str__(self):
        return str(self.top_expression)

    def __repr__(self):
        return repr(self.top_expression)


class Z3TruncateExpression(Z3Expression):
    """
    Keeps only the low ``bitwidth`` bits of ``value``.
    """

    def __init__(self, bitwidth, value):
        self.bitwidth = bitwidth
        self.value = value

    def visit(self, visitor):
        """
        Visit the wrapped value first; rebuild this node only when the value was replaced,
        then offer the (possibly rebuilt) node to the visitor.
        """
        new_value = self.value.visit(visitor)
        if new_value is self.value:
            return visitor(self)
        return visitor(Z3TruncateExpression(self.bitwidth, new_value))

    def construct_constraints(self, inputs: Dict[str, Tuple[int, int|None]]):
        # Extract takes the inclusive high and low bit indices.
        return z3.Extract(self.bitwidth - 1, 0, self.value.construct_constraints(inputs))

    def to_str(self, inputs: Dict[str, Tuple[int, int|None]]):
        return f"Truncate({self.value.get_bitwidth()} to {self.bitwidth}, {self.value.to_str(inputs)})"

    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int:
        """
        Evaluate the wrapped value and mask it to ``bitwidth`` bits.
        """
        return self.value.evaluate(inputs) & ((1 << self.bitwidth) - 1)

    def get_bitwidth(self) -> int:
        return self.bitwidth

class Z3NonDeterministicExpression(Z3Expression):
    """
    Leaf expression standing in for a value that is not a deterministic function of the inputs.

    ``kind`` selects how observed values are judged by ``check``:
    1 = counter, 2 = few values (``context`` is the exclusive upper bound on the
    number of values), 3 = random. ``info`` is only used for printing.
    """

    def __init__(self, bitwidth, info, kind, context = None):
        self.bitwidth = bitwidth
        self.info = info
        self.kind = kind
        self.context = context

    def visit(self, visitor):
        # Leaf node: only the node itself is offered to the visitor.
        return visitor(self)

    def _placeholder_value(self) -> int:
        """
        Fixed stand-in value derived from ``kind`` and masked to ``bitwidth`` bits.
        """
        return (0x13374242 + self.kind) & ((1 << self.bitwidth) - 1) # TODO: this is hacky

    def construct_constraints(self, inputs: Dict[str, Tuple[int, int|None]]):
        """
        Return the placeholder value as a z3 bit-vector constant.
        """
        # z3py's constructor for bit-vector constants is BitVecVal.
        return z3.BitVecVal(self._placeholder_value(), self.bitwidth)

    def to_str(self, inputs: Dict[str, FixedInteger]):
        return f"NonDeterministic({self.info})"

    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int:
        """
        Return the same placeholder value that ``construct_constraints`` encodes.
        """
        return self._placeholder_value()

    def get_bitwidth(self) -> int:
        return self.bitwidth

    def check(self, values):
        """
        Judge a list of observed values against this expression's ``kind``.

        Raises:
            Exception: if ``kind`` is not 1, 2 or 3.
        """
        if self.kind == 1: # counter
            return sorted(values) == values
        elif self.kind == 2: # few values
            return len(values) < self.context
        elif self.kind == 3: # random
            return len(values) > 1 # TODO: do something better
        raise Exception(f"invalid non-deterministic kind: {self.kind}")
        

def iterate_Z3Expressions(depth, relevant_inputs, bitwidth):
    """
    Generate candidate expressions of the given depth.

    - depth <= 0: nothing is generated.
    - depth == 1: one usage per entry of ``relevant_inputs`` (using the first element of the
      entry's value as bitwidth), followed by one usage with a fresh ``const_<uuid>`` name
      of width ``bitwidth`` (a constant that can be solved by the solver).
    - otherwise: for every expression ``a`` of depth - 1, its complement, its truncations to
      8, 16, 32, ... bits (while narrower than ``a``), and every registered binary expression
      over ``a`` and every expression ``b`` of depth - 1. The sign-extending variant is only
      generated when the operand widths differ and comes before the zero-extending one.
    """
    if depth <= 0:
        return

    if depth == 1:
        for name, value in relevant_inputs.items():
            yield Z3UsageExpression(value[0], name)
        yield Z3UsageExpression(bitwidth, f"const_{uuid.uuid4().hex}")
        return

    for a in iterate_Z3Expressions(depth - 1, relevant_inputs, bitwidth):
        yield Z3NotExpression(a)

        # Truncate to power of two bytes
        byte_count = 1
        a_bits = a.get_bitwidth()
        while byte_count * 8 < a_bits:
            yield Z3TruncateExpression(byte_count * 8, a)
            byte_count *= 2

        for b in iterate_Z3Expressions(depth - 1, relevant_inputs, bitwidth):
            widths_differ = a.get_bitwidth() != b.get_bitwidth()
            for bin_expr in Z3_BINARY_EXPRESSIONS:
                if widths_differ:
                    # TODO: maybe also implement truncation?
                    yield bin_expr(True, a, b)
                yield bin_expr(False, a, b)

def iterate_Z3ExpressionsAtDepth(depth, relevant_inputs, bitwidth):
    """
    Generate the expressions of ``iterate_Z3Expressions`` adjusted to exactly ``bitwidth`` bits:
    narrower ones are emitted sign-extended and then zero-extended, wider ones truncated,
    matching ones unchanged.
    """
    for candidate in iterate_Z3Expressions(depth, relevant_inputs, bitwidth):
        if candidate.get_bitwidth() < bitwidth:
            for signed in (True, False):
                yield Z3ExtendExpression(bitwidth, candidate, signed)
            continue
        if candidate.get_bitwidth() > bitwidth:
            yield Z3TruncateExpression(bitwidth, candidate)
            continue
        yield candidate

class Z3SIMDExpression(Z3Expression):
    """
    Applies ``expr`` lane by lane and concatenates the lane results.

    The number of lanes is ``total_bitwidth // operation_bitwidth``. Inputs whose bitwidth
    equals ``reg_bitwidth`` are sliced per lane; all other inputs are passed to every lane unchanged.
    """

    def __init__(self, expr, operation_bitwidth, total_bitwidth, reg_bitwidth):
        self.expr = expr
        self.operation_bitwidth = operation_bitwidth
        self.total_bitwidth = total_bitwidth
        self.reg_bitwidth = reg_bitwidth
        # Names of all usages inside expr; these are the inputs that have to be prepared per lane.
        self.usages = {u.name for u in expr.accumulate(lambda x: isinstance(x, Z3UsageExpression))}

    def visit(self, visitor):
        """
        Visit the wrapped expression first; rebuild this node only when it was replaced,
        then offer the (possibly rebuilt) node to the visitor.
        """
        new_expr = self.expr.visit(visitor)
        if new_expr is self.expr:
            return visitor(self)
        return visitor(Z3SIMDExpression(new_expr, self.operation_bitwidth, self.total_bitwidth, self.reg_bitwidth))

    def _lane_indices(self):
        """Lane indices from the most significant lane down to lane 0."""
        return range(self.total_bitwidth // self.operation_bitwidth - 1, -1, -1)

    def _lane_inputs(self, inputs, lane):
        """Build the inputs for one lane: register-wide inputs are narrowed to that lane's slice."""
        lane_mask = (1 << self.operation_bitwidth) - 1
        opr_inputs = dict()
        for usage in self.usages:
            bitwidth, val = inputs[usage]
            if bitwidth == self.reg_bitwidth:
                bitwidth = self.operation_bitwidth
                # A value of None (symbolic input) is passed through untouched.
                if val is not None:
                    val = (val >> (self.operation_bitwidth * lane)) & lane_mask
            opr_inputs[usage] = (bitwidth, val)
        return opr_inputs

    def construct_constraints(self, inputs: Dict[str, Tuple[int, int|None]]):
        """
        Build the z3 term of every lane (most significant lane first) and concatenate them.
        """
        out_vals = [
            self.expr.construct_constraints(self._lane_inputs(inputs, lane))
            for lane in self._lane_indices()
        ]
        # z3.Concat expects at least two operands; a single lane is already the full result.
        if len(out_vals) == 1:
            return out_vals[0]
        return z3.Concat(out_vals)

    def to_str(self, inputs: Dict[str, Tuple[int, int|None]]):
        # TODO: simd split the inputs?
        return f"SIMD({self.total_bitwidth} -> {self.reg_bitwidth}, {self.operation_bitwidth}, {self.expr.to_str(inputs)})"

    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int:
        """
        Evaluate every lane on plain integers (most significant lane first) and pack the
        lane results, each masked to ``operation_bitwidth`` bits, into one integer.
        """
        lane_mask = (1 << self.operation_bitwidth) - 1
        out_val = 0
        for lane in self._lane_indices():
            out_val <<= self.operation_bitwidth
            out_val |= self.expr.evaluate(self._lane_inputs(inputs, lane)) & lane_mask
        return out_val

    def get_bitwidth(self) -> int:
        return self.total_bitwidth


class Z3Constraint(Constraint):
    """
    Abstract boolean constraint over Z3Expression trees.

    Subclasses implement ``evaluate``, ``construct_constraints``, ``visit`` and ``__str__``;
    the query and replace helpers below are all built on ``visit``.
    """

    @abstractmethod
    def evaluate(self, inputs: Dict[str, Tuple[int, Number]]) -> bool:
        pass

    @abstractmethod
    def construct_constraints(self, inputs: Dict[str, Tuple[int, Number]]): # inputs: name --> bitwidth, value or None
        pass

    @abstractmethod
    def visit(self, visitor):
        """
        Run ``visitor`` over the parts of this constraint.
        The visitor returns the object to use in place of the one it was given
        (or the very same object when nothing has to be replaced).
        """
        pass

    def replace(self, target, replacement):
        """
        Return the result of visiting with a visitor that swaps the object ``target`` for ``replacement``.
        """
        return self.visit(Z3Expression.replace_visitor(target, replacement))

    def conditional_replace(self, predicate, replacement):
        """
        Return the result of visiting with a visitor that swaps everything accepted by
        ``predicate`` for ``replacement``.
        """
        visitor = Z3Expression.conditional_replace_visitor(predicate, replacement)
        return self.visit(visitor)

    def accumulate(self, predicate):
        """
        Collect everything the visitor encounters that is accepted by ``predicate``.
        """
        found, visitor = Z3Expression.accumulate_visitor(predicate)
        self.visit(visitor)
        return found

    def get_constants(self) -> List[Z3ConstantExpression]:
        """
        Collect every Z3ConstantExpression reached by ``visit``.
        """
        def is_constant(expr):
            return isinstance(expr, Z3ConstantExpression)

        found, visitor = Z3Expression.accumulate_visitor(is_constant)
        self.visit(visitor)
        return found

    def get_constants_by_value(self, value) -> List[Z3ConstantExpression]:
        """
        Collect every Z3ConstantExpression reached by ``visit`` whose ``value`` equals the given one.
        """
        def is_matching_constant(expr):
            return isinstance(expr, Z3ConstantExpression) and expr.value == value

        found, visitor = Z3Expression.accumulate_visitor(is_matching_constant)
        self.visit(visitor)
        return found

    def get_usages(self, name: str) -> List[Z3UsageExpression]:
        """
        Collect every Z3UsageExpression reached by ``visit`` that carries the given name.
        """
        def is_named_usage(expr):
            return isinstance(expr, Z3UsageExpression) and expr.name == name

        found, visitor = Z3Expression.accumulate_visitor(is_named_usage)
        self.visit(visitor)
        return found

    def get_usage_names(self) -> List[str]: # TODO: add to prototype
        """
        Return the distinct names of all usages reached by ``visit`` (as a set).
        """
        found, visitor = Z3Expression.accumulate_visitor(lambda x: isinstance(x, Z3UsageExpression))
        self.visit(visitor)
        return {usage.name for usage in found}

    def create_constant(self, value: Number, bitwidth: int) -> Z3ConstantExpression:
        # NOTE(review): the arguments are forwarded positionally as (value, bitwidth), while
        # Z3ConstantExpression.__init__ in this file takes (bitwidth, value) - confirm which
        # order callers rely on before changing anything here.
        return Z3ConstantExpression(value, bitwidth)

    def create_usage(self, bitwidth: int, name: str) -> Z3UsageExpression:
        return Z3UsageExpression(bitwidth, name)

    def is_simple(self):
        # Default; subclasses that count as simple override this.
        return False

    @abstractmethod
    def __str__(self):
        pass

    def __repr__(self):
        return str(self)

class Z3AndConstraint(Z3Constraint):
    """
    Conjunction of two constraints ``a`` and ``b``.
    """

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __str__(self):
        return f"({self.a} && {self.b})"

    def evaluate(self, inputs: Dict[str, Tuple[int, Number]]) -> bool:
        """
        Evaluate ``a`` and, only when it is truthy, ``b``.
        """
        first = self.a.evaluate(inputs)
        if not first:
            return first
        return self.b.evaluate(inputs)

    def construct_constraints(self, inputs: Dict[str, Tuple[int, Number]]):
        lhs = self.a.construct_constraints(inputs)
        rhs = self.b.construct_constraints(inputs)
        return z3.And(lhs, rhs)

    def visit(self, visitor):
        """
        Visit both parts first; rebuild this node only when one of them was replaced,
        then offer the (possibly rebuilt) node to the visitor.
        """
        visited_a = self.a.visit(visitor)
        visited_b = self.b.visit(visitor)
        unchanged = visited_a is self.a and visited_b is self.b
        return visitor(self if unchanged else Z3AndConstraint(visited_a, visited_b))

    def is_simple(self):
        return False


class Z3OrConstraint(Z3Constraint):
    """
    Disjunction of two constraints ``a`` and ``b``.
    """

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __str__(self):
        return f"({self.a} || {self.b})"

    def evaluate(self, inputs: Dict[str, Tuple[int, Number]]) -> bool:
        """
        Evaluate ``a`` and, only when it is falsy, ``b``.
        """
        first = self.a.evaluate(inputs)
        if first:
            return first
        return self.b.evaluate(inputs)

    def construct_constraints(self, inputs: Dict[str, Tuple[int, Number]]):
        lhs = self.a.construct_constraints(inputs)
        rhs = self.b.construct_constraints(inputs)
        return z3.Or(lhs, rhs)

    def visit(self, visitor):
        """
        Visit both parts first; rebuild this node only when one of them was replaced,
        then offer the (possibly rebuilt) node to the visitor.
        """
        visited_a = self.a.visit(visitor)
        visited_b = self.b.visit(visitor)
        unchanged = visited_a is self.a and visited_b is self.b
        return visitor(self if unchanged else Z3OrConstraint(visited_a, visited_b))

    def is_simple(self):
        return False

class Z3SimpleConstraint(Z3Constraint):
    """
    Range-and-alignment constraint on a single usage:
    ``minimum <= target <= maximum`` and ``target % alignment == 0``.

    ``target`` may be a Z3UsageExpression or a name; ``alignment``, ``maximum`` and
    ``minimum`` may be Z3ConstantExpression objects or plain ints. Plain values are
    wrapped using ``bitwidth``.
    """

    def __init__(self, target, bitwidth, alignment, maximum, minimum=0): # maximum is included here (otherwise we get problems with bitwidth!)
        self.target = target if isinstance(target, Z3UsageExpression) else Z3UsageExpression(bitwidth, target)
        self.bitwidth = bitwidth
        self.alignment = alignment if isinstance(alignment, Z3ConstantExpression) else Z3ConstantExpression(bitwidth, alignment)
        self.maximum = maximum if isinstance(maximum, Z3ConstantExpression) else Z3ConstantExpression(bitwidth, maximum)
        self.minimum = minimum if isinstance(minimum, Z3ConstantExpression) else Z3ConstantExpression(bitwidth, minimum)

    def __str__(self):
        return f"({self.target} <= {self.maximum} && {self.target} % {self.alignment} == 0 && {self.target} >= {self.minimum})"

    def evaluate(self, inputs: Dict[str, Tuple[int, Number]]) -> bool:
        """
        Check the constraint on the concrete value of the target.
        An alignment of 0 is reported and makes the constraint evaluate to False.
        """
        value = self.target.evaluate(inputs)
        if self.alignment.value == 0:
            print(f"[WARNING] simple constraint with mod 0")
            return False
        return (value % self.alignment.value) == 0 and (value <= self.maximum.value) and (value >= self.minimum.value)

    def construct_constraints(self, inputs: Dict[str, Tuple[int, Number]]): # inputs: name --> bitwidth, value or None
        """
        Build the z3 constraint. With a concrete target value the result is just the
        boolean outcome of ``evaluate``; otherwise the target is symbolic and all three
        conditions are expressed with unsigned bit-vector operations.
        """
        bitwidth, value = inputs[self.target.name]
        if value is not None:
            return z3.BoolVal(self.evaluate(inputs)) # easy to simplify if value is assigned
        value = z3.BitVec(self.target.name, bitwidth)
        # URem keeps the alignment check unsigned, in line with the ULE bounds and with
        # evaluate(); the % operator on z3 bit-vectors is a signed modulo.
        aligned = z3.URem(value, self.alignment.construct_constraints(inputs)) == z3.BitVecVal(0, bitwidth)
        return z3.And(
            z3.ULE(value, self.maximum.construct_constraints(inputs)),
            aligned,
            z3.ULE(self.minimum.construct_constraints(inputs), value),
        )

    def visit(self, visitor):
        """
        Visit target, alignment, maximum and minimum. Returns this constraint itself when none
        of them was replaced, otherwise a new Z3SimpleConstraint built from the visited parts.
        The visitor is not applied to the constraint object itself.
        """
        new_target = self.target.visit(visitor)
        new_alignment = self.alignment.visit(visitor)
        new_maximum = self.maximum.visit(visitor)
        new_minimum = self.minimum.visit(visitor)

        if new_target is self.target and new_alignment is self.alignment and new_maximum is self.maximum and new_minimum is self.minimum:
            return self

        return Z3SimpleConstraint(new_target, self.bitwidth, new_alignment, new_maximum, minimum=new_minimum)

    def is_simple(self):
        return True

    def get_vals(self):
        """Return ``(alignment, maximum, minimum)`` as plain values."""
        return self.alignment.value, self.maximum.value, self.minimum.value

    def combine(self, alignment, maximum, minimum):
        """
        Combine this constraint's values with the given ones: the larger alignment,
        the smaller maximum and the larger minimum, returned as a tuple in that order.
        """
        return max(self.alignment.value, alignment), min(self.maximum.value, maximum), max(self.minimum.value, minimum)

class Z3SimpleEqualsConstraint(Z3Constraint):
    """
    Equality constraint ``target == value``.

    ``target`` may be a Z3Expression or a name; ``value`` may be a Z3Expression or a plain
    int. Plain arguments are wrapped using ``bitwidth``.
    """

    def __init__(self, target, bitwidth, value):
        if not isinstance(target, Z3Expression):
            target = Z3UsageExpression(bitwidth, target)
        if not isinstance(value, Z3Expression):
            value = Z3ConstantExpression(bitwidth, value)
        self.target = target
        self.bitwidth = bitwidth
        self.value = value

    def __str__(self):
        return f"({self.target} == {self.value})"

    def evaluate(self, inputs: Dict[str, Tuple[int, Number]]) -> bool:
        """
        Evaluate both sides on concrete inputs and compare them.
        """
        target_value = self.target.evaluate(inputs)
        expected = self.value.evaluate(inputs)
        return expected == target_value

    def construct_constraints(self, inputs: Dict[str, Tuple[int, Number]]): # inputs: name --> bitwidth, value or None
        """
        Build the z3 constraint. With a concrete target value the result is just the
        boolean outcome of ``evaluate``; otherwise a symbolic bit-vector named after the
        target is compared with the z3 term of ``value``.
        """
        # NOTE(review): this reads self.target.name, so it presumably expects the target
        # to be a usage even though __init__ accepts any Z3Expression - confirm.
        bitwidth, concrete = inputs[self.target.name]
        if concrete is not None:
            return z3.BoolVal(self.evaluate(inputs)) # easy to simplify if value is assigned
        symbolic_target = z3.BitVec(self.target.name, bitwidth)
        return symbolic_target == self.value.construct_constraints(inputs)

    def visit(self, visitor):
        """
        Visit target and value. Returns this constraint itself when neither was replaced,
        otherwise a new Z3SimpleEqualsConstraint built from the visited parts.
        The visitor is not applied to the constraint object itself.
        """
        visited_target = self.target.visit(visitor)
        visited_value = self.value.visit(visitor)

        unchanged = visited_target is self.target and visited_value is self.value
        if unchanged:
            return self
        return Z3SimpleEqualsConstraint(visited_target, self.bitwidth, visited_value)

    def is_simple(self):
        return True

    def get_vals(self):
        """Return ``(1, value, value)``: alignment 1 with maximum and minimum both equal to the value."""
        exact = self.value.value
        return 1, exact, exact

    def combine(self, alignment, maximum, minimum):
        """
        Combine with the given values: the alignment is passed through, the maximum is
        capped by this constraint's value and the minimum is raised to it.
        """
        capped_maximum = min(maximum, self.value.value)
        raised_minimum = max(minimum, self.value.value)
        return alignment, capped_maximum, raised_minimum

class Z3NotConstraint(Z3Constraint):
    """
    Negation of a single target constraint.
    """

    def __init__(self, target):
        self.target = target

    def visit(self, visitor):
        """
        Visit the target. Returns this constraint itself when the target was not replaced;
        otherwise a rebuilt negation is offered to the visitor.
        """
        visited_target = self.target.visit(visitor)
        if visited_target is not self.target:
            return visitor(Z3NotConstraint(visited_target))
        return self

    def evaluate(self, inputs: Dict[str, Tuple[int, Number]]) -> bool:
        inner = self.target.evaluate(inputs)
        return not inner

    def construct_constraints(self, inputs: Dict[str, Tuple[int, Number]]):
        inner = self.target.construct_constraints(inputs)
        return z3.Not(inner)

    def __str__(self):
        return f"(not {self.target})"

class Z3BinaryConstraint(Z3Constraint):
    """
    Base class for comparisons between two expressions.

    Concrete subclasses (created by ``constructZ3BinaryConstraint``) set ``operation`` and
    ``symbol`` on the instance; ``operation`` is applied both to evaluated values and to z3 terms.
    """

    def __init__(self, left, right):
        self.left = left
        self.right = right

    def visit(self, visitor):
        """
        Visit both sides. Returns this constraint itself when neither was replaced;
        otherwise a rebuilt constraint of the same concrete class is offered to the visitor.
        """
        visited_left = self.left.visit(visitor)
        visited_right = self.right.visit(visitor)
        unchanged = visited_left is self.left and visited_right is self.right
        if unchanged:
            return self
        return visitor(self.__class__(visited_left, visited_right))

    def evaluate(self, inputs: Dict[str, Tuple[int, Number]]) -> bool:
        """
        Apply ``operation`` to the evaluated left and right sides.
        """
        left_value = self.left.evaluate(inputs)
        right_value = self.right.evaluate(inputs)
        return self.operation(left_value, right_value)

    def construct_constraints(self, inputs: Dict[str, Tuple[int, Number]]):
        # TODO: make sure bitlens match (and if they don't try to extend / truncate)
        left_term = self.left.construct_constraints(inputs)
        right_term = self.right.construct_constraints(inputs)
        return self.operation(left_term, right_term)

    def __str__(self):
        return f"({self.left} {self.symbol} {self.right})"

Z3_BINARY_CONSTRAINTS = list()

def constructZ3BinaryConstraint(name, symbol, operation):
    """
    Creates a new subclass of Z3BinaryConstraint, appends it to Z3_BINARY_CONSTRAINTS and returns it.

    name: used to build the qualified name "Z3<name>Constraint" of the new class.
    symbol: stored on every instance as `symbol` (used when printing the constraint).
    operation: stored on every instance as `operation` (callable applied to both operands).
    """

    class Z3BinaryConstraintImpl(Z3BinaryConstraint):
        def __init__(self, left, right):
            super().__init__(left, right)
            # stored on the instance, so looking up `operation` yields the plain callable
            self.symbol = symbol
            self.operation = operation

    constraint_class = Z3BinaryConstraintImpl
    # Idk why python allows me to do this, but since it does ...
    constraint_class.__qualname__ = f"Z3{name}Constraint"
    # the list is only mutated (never rebound), so no `global` statement is needed
    Z3_BINARY_CONSTRAINTS.append(constraint_class)

    return constraint_class

# Concrete binary constraints. Every call appends the new class to Z3_BINARY_CONSTRAINTS,
# so the order of these statements is also the order in which constraint classes are enumerated.
Z3EqualsConstraint = constructZ3BinaryConstraint("Equals", "==", lambda lhs, rhs: lhs == rhs)
Z3NotEqualsConstraint = constructZ3BinaryConstraint("NotEquals", "!=", lambda lhs, rhs: lhs != rhs)
Z3LessConstraint = constructZ3BinaryConstraint("Less", "<", lambda lhs, rhs: z3.ULT(lhs, rhs))


def _evaluate_less(constraint, inputs):
    """
    Concrete evaluation for Z3LessConstraint: compares the evaluated operands with `<`
    instead of going through the z3 operation.
    """
    # TODO: fix for negative values
    return constraint.left.evaluate(inputs) < constraint.right.evaluate(inputs)


Z3LessConstraint.evaluate = _evaluate_less
Z3SignedLessConstraint = constructZ3BinaryConstraint("SignedLess", "<s", lambda lhs, rhs: lhs < rhs)
def to_signed(bitwidth, x):
    """
    Interprets x as a two's complement number with the given bitwidth.

    x is first truncated to `bitwidth` bits, so values outside of the representable range
    (bits set above the sign bit, or negative values that are too small) wrap around
    instead of being returned with the wrong sign. Values that are already in range,
    including negative ones, are returned unchanged in value.

    bitwidth: number of bits of the two's complement representation.
    x: integer to interpret (signed or unsigned representation).
    Returns the signed value in the range [-(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1].
    """
    modulus = 1 << bitwidth
    # Python's `&` on negative integers already yields the two's complement bit pattern
    x &= modulus - 1
    if x & (modulus >> 1):
        # sign bit is set
        return x - modulus
    return x
def _evaluate_signed_less(constraint, inputs):
    """
    Concrete evaluation for Z3SignedLessConstraint: both evaluated operands are converted
    with to_signed (using the bitwidth of the respective operand) and then compared with `<`.
    """
    left_value = to_signed(constraint.left.get_bitwidth(), constraint.left.evaluate(inputs))
    right_value = to_signed(constraint.right.get_bitwidth(), constraint.right.evaluate(inputs))
    return left_value < right_value


Z3SignedLessConstraint.evaluate = _evaluate_signed_less


# TODO: this function is not used!
def iterate_Z3ConstraintsToDepth(max_depth, relevant_inputs, bitwidth):
    """
    Yields one constraint per registered binary constraint class for every ordered pair of
    expressions produced by iterate_Z3ExpressionsToDepth (called with the given arguments).
    """
    for lhs in iterate_Z3ExpressionsToDepth(max_depth, relevant_inputs, bitwidth):
        for rhs in iterate_Z3ExpressionsToDepth(max_depth, relevant_inputs, bitwidth):
            yield from (constraint_class(lhs, rhs) for constraint_class in Z3_BINARY_CONSTRAINTS)

def iterate_Z3ConstraintsAtDepth(depth, relevant_inputs, bitwidth): # TODO: bitwidth should not be used!
    """
    Yields candidate constraints built from pairs of expressions of depth `depth - 1`.

    For every ordered pair of expressions from iterate_Z3Expressions and every registered
    binary constraint class, the constraint and its negation are yielded. If both
    expressions differ in bitwidth, the narrower one is wrapped in a Z3ExtendExpression
    to the width of the wider one (once for each value of its boolean flag) and is used
    as the left operand.
    """
    for a in iterate_Z3Expressions(depth - 1, relevant_inputs, bitwidth):
        for b in iterate_Z3Expressions(depth - 1, relevant_inputs, bitwidth):
            if a.get_bitwidth() == b.get_bitwidth():
                for c in Z3_BINARY_CONSTRAINTS:
                    yield c(a, b)
                    yield Z3NotConstraint(c(a, b))
                continue

            # Order the pair by bitwidth using separate names. Swapping `a` and `b` in place
            # would overwrite the outer loop variable and corrupt all following inner iterations.
            # NOTE(review): the narrower operand always ends up on the left, so (a, b) and
            # (b, a) produce the same constraints -- confirm that this is intended.
            if a.get_bitwidth() > b.get_bitwidth():
                narrow, wide = b, a
            else:
                narrow, wide = a, b
            wide_bitwidth = wide.get_bitwidth()
            for c in Z3_BINARY_CONSTRAINTS:
                yield c(Z3ExtendExpression(wide_bitwidth, narrow, True), wide)
                yield Z3NotConstraint(c(Z3ExtendExpression(wide_bitwidth, narrow, True), wide))
                yield c(Z3ExtendExpression(wide_bitwidth, narrow, False), wide)
                yield Z3NotConstraint(c(Z3ExtendExpression(wide_bitwidth, narrow, False), wide))

class Z3Solver(Solver):
    """
    Solver that searches for an expression matching given samples with the help of z3.

    max_depth: maximum depth of the expressions that are tried.
    preferred_bitwidth: bitwidth used for the outputs and for the unknown constants.
    debug: print additional information while searching.
    display_progress: print the current depth and show a progress bar.
    """

    def __init__(self, max_depth=3, preferred_bitwidth=64, debug=False, display_progress=False):
        super().__init__()
        self.max_depth = max_depth
        self.preferred_bitwidth = preferred_bitwidth
        self.debug = debug
        self.display_progress = display_progress

    # TODO: also put this into the prototype!
    def equation_usage(self, bitwidth: int, name: str) -> Z3Equation:
        """
        Creates an equation that consists of a usage of the given name with the given bitwidth.
        """
        return Z3Equation(Z3UsageExpression(bitwidth, name))

    # TODO: also put this into the prototype!
    def equation_constant(self, bitwidth: int, value: int) -> Z3Equation:
        """
        Creates an equation that consists of a constant with the given bitwidth and value.
        """
        return Z3Equation(Z3ConstantExpression(bitwidth, value))


    def find_equation(self, input_samples: List[Dict[str, FixedInteger]], output_samples: List[int]) -> Z3Equation|None:
        """
        Searches for an expression that maps every input sample to the corresponding output sample.

        Candidate expressions are enumerated depth by depth (up to max_depth). For every
        candidate, z3 is asked (with a timeout of 100 ms) for values of the usages named
        "const_*" such that all samples are matched. The first satisfiable candidate is
        returned with those usages replaced by constants.

        Returns the equation, or None if no candidate could be satisfied.
        """
        # TODO: also use output bitwidth here!!
        relevant_inputs = input_samples[0]
        for depth in range(1, self.max_depth + 1):

            if self.debug or self.display_progress:
                print(f"depth {depth}/{self.max_depth}")

            if self.display_progress:
                # NOTE(review): IntProgress and display come from the IPython / ipywidgets imports
                # that are commented out at the top of the file -- confirm they are available.
                amount = get_amount_of_expressions(depth, len(relevant_inputs))
                f = IntProgress(min=0, max=amount)
                display(f)

            for possible_expr in iterate_Z3ExpressionsAtDepth(depth, relevant_inputs, self.preferred_bitwidth):
                solver = z3.Solver()
                solver.set("timeout", 100) # limit SAT solving time to 100 milliseconds. for each individual possible expression
                constants = dict()
                for const in possible_expr.accumulate(lambda expr: isinstance(expr, Z3UsageExpression) and expr.name.startswith("const_")):
                    constants[const.name] = (self.preferred_bitwidth, None)
                for input_sample, output_sample in zip(input_samples, output_samples):
                    solver.add(z3.BitVecVal(output_sample, self.preferred_bitwidth) == possible_expr.construct_constraints(input_sample | constants))
                result = solver.check()
                if result == z3.unknown:
                    print(f"unknown for {possible_expr}")

                if result == z3.sat:
                    model = solver.model()
                    for constant, x in constants.items():
                        bitwidth, _ = x
                        model_value = model[z3.BitVec(constant, bitwidth)]
                        # The model has no entry for a constant that is not restricted by any
                        # sample. Every value works in that case, so 0 is used.
                        value = 0 if model_value is None else model_value.as_long()
                        replacement = Z3ConstantExpression(bitwidth, value)
                        possible_expr = possible_expr.conditional_replace(lambda expr: isinstance(expr, Z3UsageExpression) and expr.name == constant, replacement)
                    if self.debug:
                        print(f"found expression: {possible_expr}")
                    return Z3Equation(possible_expr)
                if self.display_progress:
                    f.value += 1
                    # if f.value % 500 == 0:
                    #     print(f"tested {f.value}")
        if self.debug:
            print("no working expression found1")
        # nothing was found
        return None


# TODO: implement this more cleanly
# TODO: add stuff for:
#   * constrained sampling (kind of done)
#   * creating constraints
#   * creating equations
#   * 
class Z3Solver(Z3Solver):

    def __init__(self, max_depth=3, preferred_bitwidth=64, debug=False, display_progress=False):
        """
        Forwards all arguments to the parent class and registers the check strategies.

        Both dictionaries map a strategy name to a method with the signature
        (possible_expr, possible_constants, input_samples, output_samples, bitwidth).
        The strategies are tried in insertion order.
        """
        super().__init__(max_depth, preferred_bitwidth, debug, display_progress)

        # strategies used to search for equations
        # (currently disabled: "z3" -> self._z3_check, "double bruteforce" -> self._double_bruteforce_check)
        equation_strategies = {}
        equation_strategies["trivial"] = self._trivial_check
        equation_strategies["single bruteforce"] = self._single_bruteforce_check
        self.strategies = equation_strategies

        # strategies used to search for constraints
        # (currently disabled: "z3" -> self._z3_check, "double bruteforce" -> self._double_bruteforce_check)
        constraint_strategies = {}
        constraint_strategies["trivial"] = self._trivial_check # TODO: make sure this really also works with constraints
        constraint_strategies["single bruteforce"] = self._single_bruteforce_check
        self.constraint_strategies = constraint_strategies
    
    def _get_constants(self, possible_expr):
        """
        Collects all usages inside the expression whose name starts with "const_".

        Returns a dictionary that maps the name of each such usage to (bitwidth, None).
        """
        def is_unknown_constant(expr):
            return isinstance(expr, Z3UsageExpression) and expr.name.startswith("const_")

        return {
            usage.name: (usage.bitwidth, None)
            for usage in possible_expr.accumulate(is_unknown_constant)
        }
    
    def _trivial_check(self, possible_expr, possible_constants, input_samples: List[Dict[str, FixedInteger]], output_samples: List[int]|List[bool], bw):
        """
        Checks the expression with all of its unknown constants set to 0, then to 1, then to -1.

        Returns the constant environment (name -> (bitwidth, value)) of the first assignment
        for which the expression evaluates to the expected output on every sample.
        Returns False if the expression has no unknown constants and does not match the samples.
        Returns None if it has unknown constants and none of the three assignments matches.
        possible_constants and bw are not used by this strategy.
        """
        constants = self._get_constants(possible_expr)
        for cval in (0, 1, -1):
            # TODO: match bitwidth to constant
            cenv = {name: (info[0], cval) for name, info in constants.items()}
            mismatch = any(
                possible_expr.evaluate(input_sample | cenv) != output_sample
                for input_sample, output_sample in zip(input_samples, output_samples)
            )
            if not mismatch:
                return cenv
            if not constants:
                # without unknown constants the other values cannot change the result
                return False
        return None

    def _single_bruteforce_check(self, possible_expr, possible_constants, input_samples: List[Dict[str, FixedInteger]], output_samples: List[int]|List[bool], bw):
        """
        Bruteforces a single unknown constant over possible_constants while all other
        unknown constants are fixed to 0, 1, or -1.

        Returns the constant environment (name -> (bitwidth, value)) of the first assignment
        for which the expression evaluates to the expected output on every sample.
        Returns False if the expression has exactly one unknown constant and no possible
        constant works for it. Returns None otherwise (this includes expressions without
        unknown constants). bw is not used by this strategy.
        """
        constants = self._get_constants(possible_expr)
        has_single_constant = len(constants) == 1
        for name in constants:
            for other_val in (0, 1, -1):
                for cval in possible_constants:
                    cenv = {
                        other_name: (other_x[0], cval if other_name == name else other_val)
                        for other_name, other_x in constants.items()
                    }
                    mismatch = any(
                        possible_expr.evaluate(input_sample | cenv) != output_sample
                        for input_sample, output_sample in zip(input_samples, output_samples)
                    )
                    if not mismatch:
                        return cenv
                if has_single_constant:
                    # there are no other constants, so other_val has no influence
                    return False
        return None

    def _z3_check(self, possible_expr, possible_constants, input_samples: List[Dict[str, FixedInteger]], output_samples: List[int]|List[bool], bw) -> Dict[str, Tuple[int, int]]|bool|None:
        """
        Asks z3 for values of the unknown constants such that the expression matches all samples.

        Every unknown constant is restricted to the values in possible_constants. Boolean
        output samples are compared as z3 booleans (solving for a constraint), all other
        output samples as bit vectors of width bw (solving for an expression).

        Returns the constant environment (name -> (bitwidth, value)) if z3 finds a solution,
        False if there is none, and None if z3 cannot decide within the timeout of 100 ms.
        """
        constants = self._get_constants(possible_expr)
        solver = z3.Solver()
        solver.set("timeout", 100) # limit SAT solving time to 100 milliseconds. for each individual possible expression
        # constraint all constants to possible constant values
        for name, x in constants.items():
            variable = z3.BitVec(name, x[0])
            # This has to be a z3 disjunction: Python's `or` would short-circuit on the
            # Python level and never restrict the constant.
            options = [variable == z3.BitVecVal(const, x[0]) for const in possible_constants]
            # without any possible constant, no value is allowed
            solver.add(z3.Or(options) if options else z3.BoolVal(False))
        for input_sample, output_sample in zip(input_samples, output_samples):
            if isinstance(output_sample, bool):
                # solving for constraint
                solver.add(z3.BoolVal(output_sample) == possible_expr.construct_constraints(input_sample | constants))
            else:
                # solving for expression
                solver.add(z3.BitVecVal(output_sample, bw) == possible_expr.construct_constraints(input_sample | constants))
        result = solver.check()
        if result == z3.unknown:
            return None
        if result == z3.sat:
            env = dict()
            model = solver.model()
            for constant, x in constants.items():
                bitwidth, _ = x
                model_value = model[z3.BitVec(constant, bitwidth)]
                if model_value is None:
                    # The model does not assign the constant, so any allowed value works.
                    # The allowed values are non-empty here, otherwise the result is not sat.
                    env[constant] = (bitwidth, next(iter(possible_constants)))
                else:
                    env[constant] = (bitwidth, model_value.as_long())
            return env
        return False

    def _double_bruteforce_check(self, possible_expr, possible_constants, input_samples: List[Dict[str, FixedInteger]], output_samples: List[int], bitwidth) -> Z3Equation|None:
        """
        Bruteforces all pairs of possible constants for expressions with exactly two unknown constants.

        Returns None if the expression does not have exactly two unknown constants.
        Otherwise returns the constant environment (name -> (bitwidth, value)) of the first
        pair for which the expression evaluates to the expected output on every sample,
        or False if no pair works. bitwidth is not used by this strategy.
        """
        # NOTE(review): the return annotation does not match the values returned here.
        constants = self._get_constants(possible_expr)
        if len(constants) != 2:
            return None
        (name_a, info_a), (name_b, info_b) = constants.items()
        for cval_a in possible_constants:
            for cval_b in possible_constants:
                cenv = {
                    name_a: (info_a[0], cval_a),
                    name_b: (info_b[0], cval_b),
                }
                mismatch = any(
                    possible_expr.evaluate(input_sample | cenv) != output_sample
                    for input_sample, output_sample in zip(input_samples, output_samples)
                )
                if not mismatch:
                    return cenv
        return False
    
    def find_constraint_with_constants(self, possible_constants, input_samples: List[Dict[str, FixedInteger]], output_samples: List[bool], bitwidth=64) -> Z3Constraint|None:
        """
        Searches for a constraint that evaluates to the given boolean output on every input sample.

        For every depth up to max_depth, all candidate constraints of that depth are run
        through the constraint strategies in order. A strategy result of None keeps the
        candidate for the next strategy, False drops it, and anything else is treated as
        the constant environment of a solution.

        Returns the first matching constraint with its unknown constants replaced, or None.
        """
        stats.calls_to_solve_constraint += 1
        relevant_inputs = input_samples[0]
        for depth in range(1, self.max_depth + 1): # TODO: maybe use a different max_depth here?
            candidates = list(iterate_Z3ConstraintsAtDepth(depth, relevant_inputs, bitwidth))

            for strategy in self.constraint_strategies.values():
                if not candidates:
                    break
                undecided = []
                for candidate in candidates:
                    outcome = strategy(candidate, possible_constants, input_samples, output_samples, bitwidth)
                    if outcome is None:
                        # the strategy cannot decide, so leave the candidate to the next one
                        undecided.append(candidate)
                    elif outcome is not False:
                        for constant, x in outcome.items():
                            replacement = Z3ConstantExpression(*x)
                            # TODO: make sure we can do this on constraints (and not just expressions)
                            candidate = candidate.conditional_replace(lambda expr: isinstance(expr, Z3UsageExpression) and expr.name == constant, replacement)
                        return candidate # TODO: make sure we don't use a wrapper type here
                candidates = undecided
        return None
                    
    def find_equation_with_constants(self, possible_constants, input_samples: List[Dict[str, FixedInteger]], output_samples: List[int], bitwidth=64) -> Z3Equation|None:
        """
        Searches for an expression that evaluates to the given output on every input sample.

        For every depth up to max_depth, all candidate expressions of that depth are run
        through the strategies in order. A strategy result of None keeps the candidate for
        the next strategy, False drops it, and anything else is treated as the constant
        environment of a solution.

        Returns the first matching expression (wrapped in a Z3Equation) with its unknown
        constants replaced, or None.
        """
        stats.calls_to_solve_equation += 1
        if self.debug:
            print(f"possible constants: {len(possible_constants)}")
        # TODO: make solver outputs a list of FixedInteger instead of int, then we don't have to rely on hacky preferred_bitwidth!
        relevant_inputs = input_samples[0]
        for depth in range(1, self.max_depth + 1):
            if self.debug or self.display_progress:
                print(f"depth {depth}/{self.max_depth}")
            # TODO: what should we do on greater depth, when this becomes BIG?
            candidates = list(iterate_Z3ExpressionsAtDepth(depth, relevant_inputs, bitwidth))
            for strategy_name, strategy in self.strategies.items():
                if not candidates:
                    break
                if self.debug or self.display_progress:
                    print(f" -> strategy {strategy_name} ({len(candidates)} expressions)")
                total = len(candidates)
                undecided = []
                rejected = 0
                for position, candidate in enumerate(candidates, start=1):
                    if self.display_progress:
                        # only print when the percentage changes
                        percent = 100 * position // total
                        if percent != (100 * (position - 1) // total):
                            print(f"progress: {percent}%")
                    outcome = strategy(candidate, possible_constants, input_samples, output_samples, bitwidth)
                    if outcome is None:
                        # the strategy cannot decide, so leave the candidate to the next one
                        undecided.append(candidate)
                        continue
                    if outcome is False:
                        rejected += 1
                        continue
                    for constant, x in outcome.items():
                        replacement = Z3ConstantExpression(*x)
                        candidate = candidate.conditional_replace(lambda expr: isinstance(expr, Z3UsageExpression) and expr.name == constant, replacement)
                    if self.debug:
                        print(f"found expression {candidate}")
                    return Z3Equation(candidate)
                if self.debug:
                    print(f" -> checked {rejected} expressions")
                candidates = undecided
        if self.debug:
            print("no working expression found!")
        return None

    def solve_partial_input(self, constraints: List[Constraint], partial_input: Dict[str, Tuple[int, Number|None]]) -> Dict[str, Tuple[int, Number]]|None:
        """
        Uses z3 (with a timeout of 100 ms) to fill in the missing values of a partial input.

        constraints: constraints that the completed input has to fulfill.
        partial_input: maps names to (bitlen, value), a value of None marks a missing value.

        Returns a dictionary with the same names in which every missing value is replaced by
        the value from the z3 model, or by a random value below 2 ** bitlen if the model does
        not assign one. Returns None if z3 does not report the constraints as satisfiable.
        """
        solver = z3.Solver()
        solver.set("timeout", 100)
        for constraint in constraints:
            solver.add(constraint.construct_constraints(partial_input))

        if solver.check() != z3.sat:
            return None

        model = solver.model()
        completed = dict()
        for name, entry in partial_input.items():
            bitlen, value = entry
            if value is not None:
                # given values are kept as they are
                completed[name] = entry
                continue
            model_value = model[z3.BitVec(name, bitlen)]
            if model_value is not None:
                completed[name] = (bitlen, model_value.as_long())
            else:
                # TODO: add a mode for debug prints
                # print(f"[WARNING] no constraint on {name}") # TODO: use provided choices!
                completed[name] = (bitlen, random.randrange(1 << bitlen))
        return completed
    
    
    def constraint_not(self, a: Constraint) -> Constraint:
        """
        Creates the negation of the given constraint.
        """
        negated = Z3NotConstraint(a)
        return negated
    
    def create_alignment_constraint(self, register: str, bitwidth: int, alignment: int) -> Z3SimpleConstraint:
        """
        Creates a simple constraint for the register with the given alignment.
        The largest value representable with bitwidth bits is passed as fourth argument
        (presumably the maximum -- confirm against Z3SimpleConstraint).
        """
        all_ones = (1 << bitwidth) - 1
        return Z3SimpleConstraint(register, bitwidth, alignment, all_ones)
    
    def create_less_constraint(self, register: str, bitwidth: int, value: int) -> Z3SimpleConstraint:
        """
        Creates a simple constraint for the register with alignment 1 and `value - 1`
        as fourth argument (presumably the maximum -- confirm against Z3SimpleConstraint).
        """
        largest_allowed = value - 1
        return Z3SimpleConstraint(register, bitwidth, 1, largest_allowed)
    
    def create_equal_constraint(self, register: str, bitwidth, value: int) -> Z3SimpleEqualsConstraint:
        """
        Creates a simple equals constraint from the register, its bitwidth and the given value.
        """
        constraint = Z3SimpleEqualsConstraint(register, bitwidth, value)
        return constraint
    
    def create_greatereq_constraint(self, register: str, bitwidth: int, value: int) -> Z3SimpleConstraint:
        """
        Creates a simple constraint for the register with alignment 1 and the given value as minimum.
        The largest value representable with bitwidth bits is passed as fourth argument
        (presumably the maximum -- confirm against Z3SimpleConstraint).
        """
        all_ones = (1 << bitwidth) - 1
        return Z3SimpleConstraint(register, bitwidth, 1, all_ones, minimum=value)
    
    def constraint_equal(self, a: Equation, b: Equation) -> Constraint:
        """
        Creates a Z3EqualsConstraint between the top expressions of both equations.
        """
        left, right = a.top_expression, b.top_expression
        return Z3EqualsConstraint(left, right)
    
    def constraint_less(self, a: Equation, b: Equation) -> Constraint:
        """
        Creates a Z3LessConstraint with the top expression of a as left and of b as right operand.
        """
        left, right = a.top_expression, b.top_expression
        return Z3LessConstraint(left, right)
    
    def equation_and(self, a: Equation, b: Equation) -> Equation:
        """
        Creates an equation from a Z3AndExpression over the top expressions of both
        equations (the first argument of the expression is False).
        """
        # TODO also with sign extend?
        combined = Z3AndExpression(False, a.top_expression, b.top_expression)
        return Z3Equation(combined)
    
    def equation_usage(self, bitwidth: int, name: str) -> Equation:
        """
        Creates an equation that consists of a usage of the given name with the given bitwidth.
        """
        usage = Z3UsageExpression(bitwidth, name)
        return Z3Equation(usage)
    
    def equation_constant(self, bitwidth: int, value: int) -> Equation:
        """
        Creates an equation that consists of a constant with the given bitwidth and value.
        """
        constant = Z3ConstantExpression(bitwidth, value)
        return Z3Equation(constant)
    
    def create_random_sample(self, possible_constants, state: "ArchitecturalState", constraints: List[Constraint], register_constraints: Dict[str, List[Constraint]],  partial_sample: Dict[str, Tuple[int, Number]]) -> bool:
        """
        Samples register values that fulfill all constraints and writes them into the state.

        possible_constants: values that are mixed into the pool of candidate values (may be empty).
        state: provides register_prototypes, register_values and set_register.
        constraints: general constraints that the sample has to fulfill.
        register_constraints: additional constraints per register name (not modified).
        partial_sample: values (name -> (bitwidth, value)) that are fixed in advance.

        Registers whose constraints are all simple are sampled directly from the combined
        (alignment, maximum, minimum). All remaining registers get a value from the pool.
        If the result violates a constraint, z3 is used while more and more values are
        relaxed: first fake registers with constraints, then every single register that is
        not fixed by partial_sample, then randomly chosen registers cumulatively.

        Returns True if the state was updated with a valid sample, False otherwise.
        Raises an Exception if the simple constraints of a register cannot be fulfilled.
        """

        def is_provided(name):
            # provided values should not be overwritten
            return name in partial_sample and partial_sample[name] is not None

        def commit(values):
            # write all values into the state and report success
            for name, value in values.items():
                state.set_register(name, value[1])
            return True

        sample = dict()
        # copy, so that the lists of the caller are not extended below
        register_constraints = {
            name: list(reg_constraints) for name, reg_constraints in register_constraints.items()
        }

        constant_pool = list(possible_constants) # TODO: inefficient

        choices = [random.randrange(1 << 512) for _ in range(3)]
        if constant_pool:
            # random.choice raises an IndexError on an empty sequence
            choices += [random.choice(constant_pool) for _ in range(6)]
        choices += [0, 1, (1 << self.preferred_bitwidth) - 1, (1 << (self.preferred_bitwidth - 1)) - 1]

        all_constraints = constraints + [c for x in register_constraints.values() for c in x ] + [c for reg in state.register_prototypes.values() for c in reg.constraints]

        # TODO: ensure no mappings overlap
        # TODO: when taking a choice, maybe add random (small-ish) offset to improve solving of comparisons with immediates / stuff with offsets?

        # sample registers according to simple constraints (whenever possible)
        for reg_name, reg in state.register_prototypes.items():
            if reg.constraints:
                if reg_name in register_constraints and register_constraints[reg_name] is not None:
                    register_constraints[reg_name] += reg.constraints
                else:
                    register_constraints[reg_name] = list(reg.constraints)
        for name, reg_constraints in register_constraints.items():
            if is_provided(name):
                # TODO: maybe check if provided value already violates a constraint -> early abort?
                continue
            if any(not constraint.is_simple() for constraint in reg_constraints):
                # some constraint is not simple to sample, so we need to do it the "hard" way below
                print(f"[WARNING] Register constraints for {name} are not all simple: {reg_constraints}")
                continue

            if len(reg_constraints) == 0:
                # no constraints, so nothing clever to do
                continue

            alignment, maximum, minimum = reg_constraints[0].get_vals()

            for reg_constraint in reg_constraints[1:]:
                alignment, maximum, minimum = reg_constraint.combine(alignment, maximum, minimum)

            if minimum > maximum or (alignment > 1 and alignment > maximum and (minimum != 0)):
                raise Exception(f"impossible simple constraints for register {name}: {reg_constraints}")

            choice = (random.choice(choices) % (maximum + 1))

            # TODO: this is a dirty fix. Make this better!
            if name == "pc" and minimum != maximum:
                choice = random.randrange(1 << 40) * 4 + 0x10008

            # round down to the alignment, but do not go below the first aligned value >= minimum
            choice -= choice % alignment
            choice = max(minimum + ((alignment - (minimum % alignment)) % alignment), choice)
            sample[name] = (reg_constraints[0].bitwidth, choice)

        # assign everything else randomly
        for reg_name, x in state.register_values.items():
            if reg_name in sample or is_provided(reg_name):
                # provided values and values sampled according to simple constraints should not be overwritten
                continue
            bitwidth, value = x
            choice = random.choice(choices) % (1 << bitwidth)
            sample[reg_name] = (bitwidth, choice)

        # try to solve, if unable relax values until solvable (or return False if partial sample cannot be solved)
        env = partial_sample | sample

        if all(constraint.evaluate(env) for constraint in all_constraints):
            return commit(env)

        # really first try to relax fake registers with constraints
        for reg in state.register_prototypes.values():
            if reg.is_fake and reg.name not in partial_sample and reg.constraints:
                env[reg.name] = (reg.bitwidth, None)

        result = self.solve_partial_input(all_constraints, env)
        if result is not None:
            return commit(result)

        # First, try to relax only a single variable and fulfill constraints.
        for reg_name in state.register_values.keys():
            if is_provided(reg_name):
                # relaxing a provided value would let the solver overwrite it
                continue
            previous = env[reg_name]
            env[reg_name] = (previous[0], None)
            result = self.solve_partial_input(all_constraints, env)
            if result is not None:
                return commit(result)
            env[reg_name] = previous

        # If not possible, randomly relax variables until constraints can be fulfilled
        available = list(state.register_values.keys() - partial_sample.keys())
        while available:
            reg_name = random.choice(available)
            available.remove(reg_name)

            env[reg_name] = (env[reg_name][0], None)
            result = self.solve_partial_input(all_constraints, env)
            if result is not None:
                return commit(result)

        # Unable to fulfill constraints even after trying a lot
        return False
