import instr_rev
import instr_rev_z3 # TODO: get rid of this !!!
import math
import abc

import time
import stats
import constants

from typing import List, Dict, Tuple, Any
from numbers import Number

# Gate for the non-determinism probe: _recursive_reverse only calls
# _check_nondeterminism for a changing output when this flag is True;
# otherwise every output is treated as deterministic.
ALLOW_NONDETERMINISM = False

# TODO: we could sometimes minimize semantic trees a bit: if the condition is a not-expression and both ifTrue and ifFalse are present, we could remove the "not" and swap ifTrue and ifFalse
# TODO: solving of nested ifs seems broken (even if both cases are correctly recovered in the inner if, it is not seen as valid)

class FakeConstraint(instr_rev.Constraint):
    """Placeholder constraint that only carries a printable ``content``.

    The only working operations are ``str()`` and ``repr()``, which both
    return ``content``. Every other method of the constraint interface
    raises, so a FakeConstraint must never be evaluated, rewritten or
    queried for constants/usages.
    """

    def __init__(self, content):
        # Returned as-is by __str__/__repr__.
        self.content = content

    # NOTE: evaluate() and replace() used to be defined twice in this class;
    # the later definitions shadowed the earlier ones, so only those are kept.

    def evaluate(self, inputs: Dict[str, instr_rev.FixedInteger]) -> int|bool:
        """Unsupported; always raises."""
        raise Exception("FakeConstraint cannot do this")

    def replace(self, target: Any, replacement: Any) -> instr_rev.Function:
        """Unsupported; always raises."""
        raise Exception("FakeConstraint cannot do this")

    def get_constants(self) -> List[instr_rev.Expression]:
        """Unsupported; always raises."""
        raise Exception("FakeConstraint cannot do this")

    def get_usages(self, name: str) -> List[instr_rev.Expression]:
        """Unsupported; always raises."""
        raise Exception("FakeConstraint cannot do this")

    def create_constant(self, bitwidth: int, value: int) -> instr_rev.Expression:
        """Unsupported; always raises."""
        raise Exception("FakeConstraint cannot do this")

    def create_usage(self, bitwidth: int, name: str) -> instr_rev.Expression:
        """Unsupported; always raises."""
        raise Exception("FakeConstraint cannot do this")

    def __str__(self):
        return self.content

    # repr and str are intentionally identical.
    __repr__ = __str__

# constant replacement: int -> SemanticsTree (returns a new semantics tree with the constant replaced)
# register replacement: str -> SemanticsTree (returns a new semantics tree with the register replaced)

    

class SemanticsTree:
    """Tree of equations and conditions describing recovered semantics.

    Every node holds ``equations`` (output name -> equation object) and may
    hold a ``condition`` selecting between the ``ifTrue`` and ``ifFalse``
    children. A node with a child in ``ifTrue`` but no condition simply
    chains to that child. ``parent`` points to the enclosing node.
    """

    def __init__(self, parent):
        self.ifTrue = None
        self.ifFalse = None
        self.condition = None
        # None until set; create_leaf() installs an empty dict.
        self.equations = None
        self.parent = parent
        # pseudo nodes are only to make things "look" good, but their
        # conditions and equations are ignored (condition is always assumed
        # to be true).
        self.pseudo = False
        # TODO: this is a dirty hack
        self.reg_renames = dict()

    def copy(self):
        """Return a structural copy of this subtree.

        Nodes, the ``equations`` dicts and ``reg_renames`` dicts are
        duplicated recursively; the equation and condition objects
        themselves are shared with the original. The copy of this node keeps
        the original's parent, the copied children point at their new parent.
        """
        duplicate = SemanticsTree(self.parent)
        duplicate.condition = self.condition
        duplicate.equations = self.equations.copy()
        duplicate.pseudo = self.pseudo
        duplicate.reg_renames = self.reg_renames.copy()
        if self.ifTrue is not None:
            duplicate.ifTrue = self.ifTrue.copy()
            duplicate.ifTrue.parent = duplicate
        if self.ifFalse is not None:
            duplicate.ifFalse = self.ifFalse.copy()
            duplicate.ifFalse.parent = duplicate
        return duplicate

    @staticmethod
    def create_leaf(parent=None) -> "SemanticsTree":
        """Create a childless node with an empty equation dict."""
        leaf = SemanticsTree(parent)
        leaf.equations = dict()
        return leaf

    def split(self, condition: "instr_rev.Constraint"):
        """Attach ``condition`` and two fresh leaves; return (ifTrue, ifFalse).

        Raises:
            Exception: if this node already has an ``ifTrue`` child.
        """
        if self.ifTrue is not None:
            raise Exception("Already split!")
        self.condition = condition
        self.ifTrue = SemanticsTree.create_leaf(parent=self)
        self.ifFalse = SemanticsTree.create_leaf(parent=self)
        return self.ifTrue, self.ifFalse

    def get_output_registers(self, accu=None):
        """Collect the equation keys of this node and all descendants.

        ``accu`` is the set the names are added to (a new set when omitted);
        the same set is returned.
        """
        if accu is None:
            accu = set()

        # Must mutate in place: ``accu |= dict.keys()`` rebinds the local name
        # to a new set (set.__ior__ rejects dict views), which silently
        # dropped everything the children collected.
        accu.update(self.equations)

        for child in (self.ifTrue, self.ifFalse):
            if child is not None:
                child.get_output_registers(accu=accu)

        return accu

    def add_equation(self, output: str, equation: "instr_rev.Equation"):
        """Store ``equation`` for ``output`` on this node (overwrites)."""
        self.equations[output] = equation

    def get_equation(self, inputs, output):
        """Find the equation for ``output`` on the path selected by ``inputs``.

        Returns None when no node on the selected path defines ``output``
        (including when the selected branch does not exist).
        """
        if output in self.equations:
            return self.equations[output]
        if self.condition is not None:
            branch = self.ifTrue if self.condition.evaluate(inputs) else self.ifFalse
            # A missing branch means "nothing known", same as _evaluate().
            return None if branch is None else branch.get_equation(inputs, output)
        if self.ifTrue is not None:
            return self.ifTrue.get_equation(inputs, output)
        return None

    def _evaluate(self, env, results, fakes):
        """Evaluate this subtree into ``results``; may extend ``env``.

        Outputs whose name is in ``fakes`` are written back into ``env`` as
        ``(bitwidth, value)`` so that later equations/conditions see them.
        """
        for name, equation in self.equations.items():
            results[name] = equation.evaluate(env)
            if name in fakes:
                # TODO: only replace fake registers? idk
                env[name] = (equation.get_bitwidth(), results[name])

        if self.condition is not None:
            branch = self.ifTrue if self.condition.evaluate(env) else self.ifFalse
            if branch is not None:
                branch._evaluate(env, results, fakes)
        elif self.ifTrue is not None:
            # unconditional chaining
            self.ifTrue._evaluate(env, results, fakes)

    def evaluate(self, env, fakes):
        """Evaluate the tree for ``env`` and return {output name: value}."""
        results = dict()
        self._evaluate(env, results, fakes)
        return results

    def get_registers(self):
        """Return all names written by or used in this subtree."""
        regs = set()

        for name, equation in self.equations.items():
            regs.add(name)
            regs |= equation.get_usage_names()

        if self.condition is not None:
            regs |= self.condition.get_usage_names()

        for child in (self.ifTrue, self.ifFalse):
            if child is not None:
                regs |= child.get_registers()

        return regs

    def get_constants(self):
        """Return the set of constant values used anywhere in this subtree."""
        values = set()

        for equation in self.equations.values():
            values.update(x.value for x in equation.get_constants())

        if self.condition is not None:
            values.update(x.value for x in self.condition.get_constants())

        for child in (self.ifTrue, self.ifFalse):
            if child is not None:
                values |= child.get_constants()

        return values

    @staticmethod
    def _lift_replacer(branch, replacer):
        """Wrap a child-level replacer so it applies to the parent node.

        ``branch`` is the attribute name of the child ("ifTrue"/"ifFalse").
        The returned function copies the tree it is given and swaps in the
        child produced by ``replacer``.
        """
        def replace_in_child(tree, new_value):
            t = tree.copy()
            setattr(t, branch, replacer(getattr(t, branch), new_value))
            return t
        return replace_in_child

    def get_register_instances(self, register):
        """Yield one replacer per occurrence of ``register`` in this subtree.

        Each replacer is a function ``(tree, new_name) -> new tree`` that
        returns a modified copy with exactly that occurrence renamed; the
        tree passed in is not modified.
        """
        if register in self.equations:
            def replace_result(tree, new_usage):
                t = tree.copy()
                t.reg_renames[register] = new_usage
                # pop-then-assign keeps the equation even when the register
                # is "renamed" to its own name.
                t.equations[new_usage] = t.equations.pop(register)
                return t
            yield replace_result

        for name, equation in self.equations.items():
            for usage in equation.get_usages(register):
                def wrapper(name, usage, equation):
                    def replace_usage(tree, new_usage):
                        t = tree.copy()
                        # the output itself may have been renamed already
                        key = t.reg_renames.get(name, name)
                        # TODO: allow to replace with different bitwidths?
                        t.equations[key] = t.equations[key].replace(
                            usage, equation.create_usage(usage.get_bitwidth(), new_usage)
                        )
                        return t
                    return replace_usage
                yield wrapper(name, usage, equation)

        if self.condition is not None:
            for usage in self.condition.get_usages(register):
                def wrapper(usage):
                    def replace_usage(tree, new_usage):
                        t = tree.copy()
                        # TODO: allow to replace with different bitwidths?
                        t.condition = tree.condition.replace(
                            usage, tree.condition.create_usage(usage.get_bitwidth(), new_usage)
                        )
                        return t
                    return replace_usage
                yield wrapper(usage)

        for branch in ("ifTrue", "ifFalse"):
            child = getattr(self, branch)
            if child is not None:
                for replacer in child.get_register_instances(register):
                    yield SemanticsTree._lift_replacer(branch, replacer)

    def get_constant_instances(self, value):
        """Yield one replacer per occurrence of the constant ``value``.

        Each replacer is a function ``(tree, new_value) -> new tree``. An
        ``int`` replacement becomes a constant, anything else is treated as a
        name and becomes a usage; the bitwidth of the replaced constant is
        kept.
        """
        for name, equation in self.equations.items():
            for constant in equation.get_constants_by_value(value):
                def wrapper(name, equation, constant):
                    def replace_constant(tree, new_value):
                        t = tree.copy()
                        if isinstance(new_value, int):
                            repl = equation.create_constant(constant.get_bitwidth(), new_value)
                        else:
                            repl = equation.create_usage(constant.get_bitwidth(), new_value)
                        # TODO: allow to replace with different bitwidths?
                        t.equations[name] = t.equations[name].replace(constant, repl)
                        return t
                    return replace_constant
                yield wrapper(name, equation, constant)

        if self.condition is not None:
            for constant in self.condition.get_constants_by_value(value):
                def wrapper(constant):
                    def replace_constant(tree, new_value):
                        t = tree.copy()
                        # Build the replacement from the condition itself.
                        # This used to read the leaked loop variable
                        # ``equation`` from the loop above, which is unbound
                        # for nodes without equations.
                        factory = tree.condition
                        if isinstance(new_value, int):
                            repl = factory.create_constant(constant.get_bitwidth(), new_value)
                        else:
                            repl = factory.create_usage(constant.get_bitwidth(), new_value)
                        # TODO: allow to replace with different bitwidths?
                        t.condition = tree.condition.replace(constant, repl)
                        return t
                    return replace_constant
                yield wrapper(constant)

        for branch in ("ifTrue", "ifFalse"):
            child = getattr(self, branch)
            if child is not None:
                for replacer in child.get_constant_instances(value):
                    yield SemanticsTree._lift_replacer(branch, replacer)

    def clean(self, parent=None) -> "SemanticsTree":
        """Return a copy of this tree with all pseudo nodes removed.

        A pseudo node is replaced by the cleaned version of its ``ifTrue``
        child (its own equations and condition are dropped). ``parent``
        becomes the parent of the returned node. ``reg_renames`` are not
        carried over.
        """
        if self.pseudo:
            if self.ifTrue is None:
                # pseudo leaf: nothing real below it, stand in an empty leaf
                return SemanticsTree.create_leaf(parent=parent)
            # hand the parent through so the replacement is linked correctly
            return self.ifTrue.clean(parent=parent)

        cleaned = SemanticsTree.create_leaf(parent=parent)
        if self.condition is not None:
            cleaned.condition = self.condition
        if self.ifTrue is not None:
            cleaned.ifTrue = self.ifTrue.clean(parent=cleaned)
        if self.ifFalse is not None:
            cleaned.ifFalse = self.ifFalse.clean(parent=cleaned)
        cleaned.equations = dict(self.equations)
        return cleaned

    def create_pseudo(self) -> "SemanticsTree":
        """Attach and return a pseudo leaf as ``ifTrue``.

        Raises:
            Exception: if this node already has an ``ifTrue`` child.
        """
        if self.ifTrue is not None:
            raise Exception("Already split!")
        leaf = SemanticsTree.create_leaf(self)
        leaf.pseudo = True
        self.ifTrue = leaf
        return leaf

    def to_str(self, indent=0):
        """Render the subtree as indented pseudo code (two spaces per level).

        Empty branches are omitted; if only the false branch has content it
        is printed as ``if not <condition>:``. A warning is printed when a
        condition exists but both branches render empty.
        """
        space = " " * indent * 2
        pseudo_str = " # (pseudo)" if self.pseudo else ""
        res = "\n".join(
            f"{space}{name} = {equation}{pseudo_str}" for name, equation in self.equations.items()
        )

        if self.condition is not None:
            # a missing branch renders like an empty one
            ifTrueStr = self.ifTrue.to_str(indent=indent + 1) if self.ifTrue is not None else ""
            ifFalseStr = self.ifFalse.to_str(indent=indent + 1) if self.ifFalse is not None else ""

            if ifTrueStr:
                if res:
                    res += "\n"
                res += f"{space}if {self.condition}:{pseudo_str}\n"
                res += ifTrueStr
                if ifFalseStr:
                    res += "\n"
                    res += f"{space}else:\n"
                    res += ifFalseStr
            elif ifFalseStr:
                if res:
                    res += "\n"
                res += f"{space}if not {self.condition}:\n"
                res += ifFalseStr
            else:
                print(f"[WARNING] condition {self.condition}, but neither true or false has equations")
        elif self.ifTrue is not None:
            # unconditional child is printed at the same indentation
            ifTrueStr = self.ifTrue.to_str(indent=indent)
            if ifTrueStr:
                if res:
                    res += "\n"
                res += ifTrueStr
        return res

    

class InstructionReverser:
        
    def __init__(self, solver, runner, empty_state_f, constants_enum_f, samples=300):
        """Store the collaborators used while reversing instructions.

        solver           -- object used to create samples and find equations/constraints
        runner           -- object whose run_instructions() executes an encoded instruction
        empty_state_f    -- zero-argument callable returning a fresh state object
        constants_enum_f -- callable taking the instruction, result is handed to the solver
        samples          -- number of samples collected per solving step
        """
        self.samples = samples
        self.empty_state_f, self.constants_enum_f = empty_state_f, constants_enum_f
        self.solver, self.runner = solver, runner
    
    def _check_nondeterminism(self, state_before, instruction, amount_test, amount_sample, constraints, register_constraints, output):
        """Probe whether ``output`` differs between runs on identical input.

        For each of ``amount_test`` solver-generated input states the
        instruction is run 10 times on the same ``state_before``. As soon as
        one state produces more than one distinct value for ``output``, the
        instruction is run ``amount_sample`` more times on that state and the
        observed values are returned in run order.

        Returns:
            list of observed values if non-determinism was found, else None.

        Raises:
            Exception: if the solver cannot create a sample or a run fails.

        The elapsed time is added to ``stats.time_in_sample_solve`` on every
        exit path (previously the early return skipped the accounting).
        """
        sample_start = time.time()
        try:
            state_after = self.empty_state_f()
            c = self.constants_enum_f(instruction) # TODO: we should probably only evaluate this once

            # The encoding does not change between runs, so build it once.
            encoded = instruction.to_bytes((constants.INSTR_BITWIDTH + 7) // 8, "little", signed=False)

            for _ in range(amount_test):
                if not self.solver.create_random_sample(c, state_before, constraints, register_constraints, dict()):
                    print(constraints, register_constraints)
                    raise Exception("failed to create sample!") # TODO: log more information!

                seen = set()

                for _ in range(10):
                    if not self.runner.run_instructions(encoded, state_before, state_after):
                        raise Exception("failed to run!")
                    seen.add(state_after.get_register(output))
                    time.sleep(1) # for counters (presumably lets time-based values advance between runs)

                if len(seen) > 1:
                    print(f"output {output} is non-deterministic!")
                    # found non-determinism: record a long series on the same input
                    series = list()
                    for _ in range(amount_sample):
                        if not self.runner.run_instructions(encoded, state_before, state_after):
                            raise Exception("failed to run!")
                        series.append(state_after.get_register(output))
                    return series

            return None
        finally:
            stats.time_in_sample_solve += time.time() - sample_start
        
    
    def _collect_samples(self, state_before, instruction, amount, constraints, register_constraints, inputs, outputs, repeat=1):
        """Run the instruction on random inputs and record register values.

        For each sample the solver fills ``state_before`` (subject to
        ``constraints`` / ``register_constraints``), the registers named in
        ``inputs`` are read from it, the instruction is executed, and the
        registers named in ``outputs`` are read from the resulting state.
        Runs that fail are skipped and counted; at most ``amount * 5``
        failures are tolerated.

        ``repeat`` is currently unused and only kept for interface
        compatibility.

        Returns:
            (input_values, output_values): two parallel lists with ``amount``
            entries, each entry being the list of register values of one run.

        Raises:
            Exception: if the solver cannot create a sample, or if fewer than
            ``amount`` samples could be collected within the failure budget.
        """
        sample_start = time.time()
        input_values = list()
        output_values = list()

        state_after = self.empty_state_f()

        c = self.constants_enum_f(instruction)

        # The encoding is the same for every run, so build it once.
        encoded = instruction.to_bytes((constants.INSTR_BITWIDTH + 7) // 8, "little", signed=False)

        max_failures = amount * 5
        failed_attempts = 0

        while len(input_values) < amount and failed_attempts < max_failures:
            if not self.solver.create_random_sample(c, state_before, constraints, register_constraints, dict()):
                print(constraints, register_constraints)
                raise Exception("failed to create sample!") # TODO: log more information!

            # read the inputs before running; the run only writes state_after
            ins = [state_before.get_register(x) for x in inputs]

            if not self.runner.run_instructions(encoded, state_before, state_after): # TODO: follow by known SIGILL/SIGTRAP instruction -> runner for the platform should do this now!
                failed_attempts += 1
                continue

            outs = [state_after.get_register(x) for x in outputs]

            input_values.append(ins)
            output_values.append(outs)

        # Judge by what was collected rather than by the failure counter:
        # for amount == 0 the old check (0 >= 0) reported a failure although
        # nothing had failed. For amount > 0 both checks are equivalent.
        if len(input_values) < amount:
            raise Exception("too many sample collections failed!")

        stats.time_in_sample_solve += time.time() - sample_start

        return input_values, output_values
    
    # inputs: list of list of integers, outputs: list of list of integers (since we have multiple changing outputs and splits might be easier to detect on a single one!)
    def _find_split(self, inputs, outputs):
        """Yield candidate boolean partitions of the samples.

        ``inputs`` and ``outputs`` are parallel lists with one entry per
        sample; each entry is a list of register values. The second heuristic
        indexes those values with ``[1]``, i.e. it expects
        ``(bitwidth, value)``-style pairs.

        Every yielded item is a list with one bool per sample. Candidates:

        1. an output column that takes exactly two distinct values
           (True where the sample has one of them);
        2. per (input, output) pair: samples where the output value equals
           the input value, and samples where it equals the input value plus
           ``constants.INSTR_BITWIDTH // 8``. Each is only yielded if between
           2% and 98% of the samples match.

        Nothing is yielded for empty sample lists.
        """
        if not inputs or not outputs:
            # no samples at all; indexing [0] below would raise
            return

        # simple split: only two outputs exist
        for o in range(len(outputs[0])):
            column = [row[o] for row in outputs]
            distinct = set(column)
            if len(distinct) == 2:
                reference = next(iter(distinct))
                yield [entry == reference for entry in column]

        if not inputs[0]:
            # no input columns to compare against
            return

        instr_len = constants.INSTR_BITWIDTH // 8
        lower = 0.02 * len(inputs)
        upper = 0.98 * len(inputs)

        # simple split by output being the same as a specific input vs. not
        for i in range(len(inputs[0])):
            in_vals = [row[i][1] for row in inputs]
            for o in range(len(outputs[0])):
                out_vals = [row[o][1] for row in outputs]

                # Count on the same predicate that is yielded. The count used
                # to compare the whole entries (including the bitwidth) while
                # the mask compared only the values.
                same = [iv == ov for iv, ov in zip(in_vals, out_vals)]
                if lower <= sum(same) <= upper:
                    yield same

                # output equals input advanced by one instruction length
                advanced = [iv + instr_len == ov for iv, ov in zip(in_vals, out_vals)]
                if lower <= sum(advanced) <= upper:
                    yield advanced
        
        # TODO: more advanced splitting algorithms
        
    
    # TODO: maybe move splitting also to address stuff!
    def _recursive_reverse(self, instruction, s, constraints, tree, max_depth=2, known_outputs=None, c_outputs=None, error_info=False):
        
        if max_depth == 0:
            raise Exception("maximum recursive solving depth exceeded!")
    
        if known_outputs is None:
            known_outputs = list() # instanciate here so we get a new list on each invokation
    
        if c_outputs is None:
            c_outputs = list()
    
        print(f"recursive solving, remaining depth: {max_depth}")
        
        current_constraints = list(constraints)
        
        constraint_usages = [
            c.get_usage_names() for c in current_constraints
        ]
        
        # 1. determine all outputs that change
        inputs  = list(s.register_prototypes.keys()) # TODO: ensure that keys and items is both defined as insertion order!
        outputs = [f"{x[0]}_out" if x[1].is_fake else x[0] for x in s.register_prototypes.items() if "error" not in x[0]]
        
        if error_info:
            inputs += ["error_code", "error_info"]
            outputs += ["error_code", "error_info"]
        
        ins, outs = self._collect_samples(s, instruction, self.samples, current_constraints, dict(), inputs, outputs)
        
        # TODO: for changing outputs we might want to include the signal number as well?
        
        changing_outputs = list()
        for i,o in zip(ins, outs):
            for idx, x in enumerate(zip(i, o)):
                if x[0] != x[1]:
                    if outputs[idx] not in changing_outputs and outputs[idx] != "mem_addr_out":
                        changing_outputs.append(outputs[idx])
                        print(f"new output {outputs[idx]}: {x[0]} -> {x[1]}")
        
        print(f"changing outputs: {changing_outputs}")
        
        for c in c_outputs:
            if c not in changing_outputs and c not in known_outputs:
                known_outputs.append(c)
        
        remembered_relevant_inputs = dict()
        
        
        
        # for each changing output: ...
        for output in changing_outputs:
            
            equation = None
            
            # 1.1 if output is already known (on a higher level in the tree), don't try to infer it again
            if output in known_outputs: # already known
                continue
            
            # 1.2 detect non-determinism in output. 
            # TODO: make numbers configurable
            non_determinism = self._check_nondeterminism(s, instruction, 3, 1000000, current_constraints, dict(), output) if ALLOW_NONDETERMINISM else None
            if non_determinism:
                relevant_inputs = list()
                # TODO: output register prototype might not exist in input so this can fail
                bitwidth = s.register_prototypes[output].bitwidth
            
                unique = set(non_determinism)
            
                # are they non-decreasing?
                # TODO: do some more fancy math
                if sorted(non_determinism) == non_determinism and len(unique) > 10:
                    equation = instr_rev_z3.Z3Equation(instr_rev_z3.Z3NonDeterministicExpression(bitwidth, "Counter", 1))
                # is it only a few values?
                # TODO: make this configurable. This is also not that nice since we no longer have the input state if we want to do weird tests
                # TODO: on the other hand, if this is needed it can easily be added
                elif len(unique) < 256:
                    equation = instr_rev_z3.Z3Equation(instr_rev_z3.Z3NonDeterministicExpression(bitwidth, str(sorted(unique)), 2, context=len(unique)))
                # TODO: do some more fancy math
                else:
                    equation = instr_rev_z3.Z3Equation(instr_rev_z3.Z3NonDeterministicExpression(bitwidth, f"Random (at least {math.log2(len(unique))} bits)", 3, context=math.log2(len(unique))))
                    
                
            if equation is None:
                # 2. ... find input registers that influence the output
                relevant_inputs = list()
                output_index = outputs.index(output)
                for good_sample_idx in range(30): # limiting to 10 samples should be enough
                    for changing_input in inputs:
                        if changing_input in relevant_inputs: # input is already known to be relevant, so it can be skipped
                            continue
                            
                        reg_constraints = dict()
                        
                        # TODO: detect over-constraint inputs that always have the same value
                        for i in range(len(inputs)):
                            prototype = s.register_prototypes[inputs[i]]
                            # if there is some sort of constraint on the changing input and this input, we allow both to change.
                            # For instance, if the changing input is equal to some other input but we fix the other input, the changing input could never change either and will be found irrelevant even if it is relevant.
                            # TODO: only do this if sampling fails (since this costs performance)!
                            # TODO: we can also better pre-compute the dependencies
                            if any(map(lambda x: inputs[i] in x and changing_input in x, constraint_usages)):
                                continue
                            if inputs[i] != changing_input: # memory address is forced by other things anyways, so it should not be forced to be fixed!
                                reg_constraints[inputs[i]] = [self.solver.create_equal_constraint(inputs[i], prototype.bitwidth, ins[good_sample_idx][i][1])]
                        
                        _, new_outs = self._collect_samples(s, instruction, 4, current_constraints, reg_constraints, inputs, [output]) # 10 samples should be more than enough here!
                        
                        
                        if any(map(lambda x: x[0][1] != outs[good_sample_idx][output_index][1], new_outs)):
                            print(f"found relevant input for {output}: {changing_input}")
                            relevant_inputs.append(changing_input)
                            break
                
                # TODO: also find which bits of the input register(s) are used, so we can easily detect e.g. 32-bit operations with 64-bit registers WITHOUT splitting registers themselves which would probably be very expensive if done too much!
                
                # 3. ... try to infer function using relevant inputs
                # 3.1. collect samples (do NOT overwrite ins and outs since they are used in further loop iterations!)
                new_ins, new_outs = self._collect_samples(s, instruction, self.samples, current_constraints, dict(), relevant_inputs, [output])
                # 3.2. try to solve for an equation
                solver_in = [
                    {relevant_inputs[ii] : i[ii] for ii in range(len(relevant_inputs))} for i in new_ins
                ]
                solver_out = [
                    o[0][1] for o in new_outs
                ]
                print(f"relevant inputs for {output}: {relevant_inputs}")
                equation = self.solver.find_equation_with_constants(self.constants_enum_f(instruction), solver_in, solver_out, s.register_prototypes[output.replace("_out", "")].bitwidth)
            if equation is None:
                # 3.2.1 simple solving did not work, try to find bitwidth of inputs and outputs and solve more cleverly
                
                # find bitwidth of output
                out_bitwidth = max(min(x.bit_length(), (((1 << new_outs[0][0][0]) - 1) - x).bit_length() + 1) for x in solver_out)
                if out_bitwidth != new_outs[0][0][0]:
                    print(f"output bitwidth is actually {out_bitwidth}")
                
                # TODO: actually find bitwidth of inputs and not just use the same, this is a stupid, non-generic hack!
                new_solver_in = [
                    {relevant_inputs[ii] : (out_bitwidth, i[ii][1] % (1 << out_bitwidth)) for ii in range(len(relevant_inputs))} for i in new_ins
                ]
                new_solver_out = [
                    o[0][1] for o in new_outs
                ]
                equation = self.solver.find_equation_with_constants(self.constants_enum_f(instruction), new_solver_in, new_solver_out, out_bitwidth)
                
                if equation is None:
                 
                    # 3.2.2 if more clever solving still isn't enough, try to use SIMD
                    
                    # TODO: try to use SIMD?
                    print(f"could not solve for output {output}, trying conditional split later, but trying SIMD now!")
                    
                    bits = 8
                    
                    while bits < out_bitwidth:
                        if out_bitwidth % bits == 0:
                            print(f"trying SIMD {bits}")
                            split_count = out_bitwidth // bits
                            
                            simd_solver_in = list()
                            simd_solver_out = list()
                            
                            for i in range(len(new_solver_in)):
                                prev_in = new_solver_in[i]
                                for split in range(split_count):
                                    cur_in = dict()
                                    for name, val in prev_in.items():
                                        if val[0] != out_bitwidth:
                                            # bitwidth differs from output, just treat this value as not simd
                                            cur_in[name] = val
                                        else:
                                            # bitwidth is same as output, so we can split!
                                            cur_in[name] = (bits, (val[1] >> (split * bits)) & ((1 << bits) - 1))
                                    simd_solver_in.append(cur_in)
                            
                            for i in range(len(new_solver_out)):
                                for split in range(split_count):
                                    simd_solver_out.append((new_solver_out[i] >> (split * bits)) & ((1 << bits) - 1))
                            
                            
                            
                            simd_equation = self.solver.find_equation_with_constants(self.constants_enum_f(instruction), simd_solver_in, simd_solver_out, bits)
                            
                            if simd_equation is not None:
                                print(f"SIMD worked ({bits})!")
                                # TODO: add SIMD expression so stuff does not break later when clustering!
                                equation = instr_rev_z3.Z3Equation(instr_rev_z3.Z3SIMDExpression(simd_equation.top_expression, bits, out_bitwidth, new_outs[0][0][0]))
                                break
                        
                        bits += 8
                        
                    
                    remembered_relevant_inputs[output] = relevant_inputs
                
            if equation is not None: 
                # 3.3 sucessfully found an equation for this output, so add it to the tree
                known_outputs.append(output)    
                tree.add_equation(output, equation)
                undef_inputs = { # inputs with correct bitwidth, but all have no assigned value
                    name: (s.register_prototypes[name].bitwidth, None) for name in relevant_inputs
                }
                print(f"found equation: {output} = {equation.to_str(undef_inputs)}")
        
        
        # 4. for outputs that could not be reversed, we try to conditionally split and solve recursively under an added constraint
        for output in changing_outputs:
            if output in known_outputs or max_depth == 1:
                continue
            
            relevant_inputs = remembered_relevant_inputs[output]
            
            # 4.1 find possible ways to split
            new_ins, new_outs = self._collect_samples(s, instruction, self.samples, current_constraints, dict(), relevant_inputs, changing_outputs)
            
            for split in self._find_split(new_ins, new_outs):
                # 4.2 for conditional split, try to find condition that produces this split
                solver_in = [
                    {relevant_inputs[ii] : i[ii] for ii in range(len(relevant_inputs))} for i in new_ins
                ]
                
                
                # print(relevant_inputs)
                
                print(f"found split, trying to solve for constraint . . .")
                solve_result = self.solver.find_constraint_with_constants(self.constants_enum_f(instruction), solver_in, split)
                
                
                
                if solve_result is not None:
                    print(f"found constraint: {solve_result}")
                    
                    copied_tree = tree.copy()
                    
                    treeTrue, treeFalse = copied_tree.split(
                        solve_result
                    )
                    
                    known_true = list(known_outputs)
                    known_false = list(known_outputs)
                    
                    try:
                        self._recursive_reverse(instruction, s, constraints + [solve_result], treeTrue, max_depth=max_depth-1, known_outputs=known_true, c_outputs=changing_outputs, error_info=error_info)
                        if len(known_true) == len(changing_outputs):
                            # early abort if first part already failed
                            self._recursive_reverse(instruction, s, constraints + [self.solver.constraint_not(solve_result)], treeFalse, max_depth=max_depth-1, known_outputs=known_false, c_outputs=changing_outputs, error_info=error_info)
                    except:
                        import traceback
                        print(traceback.format_exc())
                        # failed to reverse here
                        continue
                    if len(known_true) == len(known_false) and len(known_true) == len(changing_outputs):
                        # TODO: make this less shitty
                        tree.ifFalse = copied_tree.ifFalse
                        tree.ifTrue = copied_tree.ifTrue
                        tree.condition = copied_tree.condition
                        tree.equations = copied_tree.equations
                        tree.parent = copied_tree.parent
                        tree.pseudo = copied_tree.pseudo
                        if tree.ifFalse is not None:
                            tree.ifFalse.parent = tree
                        if tree.ifTrue is not None:
                            tree.ifTrue.parent = tree
                        return True, tree
                else:
                    print("no constraint found for split")
                    
            
            else:
                print(f"no split found for {output}!")
                
        return len(known_outputs) == len(changing_outputs), tree
        # TODO: if this fails for some outputs, try to split outputs based on condition and recursively solve!
        # TODO: for solving, maybe first generate trees without constants (or mixed trees only up to a very low depth), so we can use very fast solving for instructions that don't have constants
        # TODO: then, maybe only use one constant since this is still easier to solve and only if this fails try with more constants.
        
        # TODO: do simd stuff
        # TODO: actually incorporate conditions
        # TODO: cluster
        
        
    
    def reverse_instruction(self, instruction, error_info=False):
        """Recover a semantics tree for ``instruction`` by sampling its behaviour.

        The instruction is first sampled on an empty state. If any sample
        reports a segfault (``error_code == 11``), the instruction is treated
        as a memory instruction and the following is recovered before the
        register semantics are reversed:

        1. the registers the faulting address depends on,
        2. an equation for that address (exposed as fake register ``mem_addr``),
        3. the page protection the access needs,
        4. the size (in bits) of the accessed memory (fake register ``mem_val``).

        The findings are recorded as (pseudo) splits in the semantics tree and
        as extra constraints/registers/mappings on the sampling state, and
        ``_recursive_reverse`` then runs on the remaining subtree.

        Args:
            instruction: the instruction to reverse; only forwarded to
                ``_collect_samples``, ``constants_enum_f`` and
                ``_recursive_reverse``.
            error_info: forwarded unchanged to ``_recursive_reverse``.

        Returns:
            A tuple ``(success, root, s)``: the success flag returned by
            ``_recursive_reverse``, the root of the semantics tree, and the
            sampling state (including any fake registers and the memory
            mapping added here).

        Raises:
            Exception: if the segfault address cannot be pinpointed for any
                input bit limit, if no equation for the memory address is
                found, or if the detected read size exceeds 1024 bytes.
        """
        # error_code value that the sampler reports for a segfault
        SEGV_ERROR_CODE = 11

        def any_segfault(sample_outs):
            # error_code is always the first requested output in the calls below
            return any(out[0][1] == SEGV_ERROR_CODE for out in sample_outs)

        def good_sample_indices(sample_ins, sample_outs):
            # indices of samples that produce a segfault at address != 0
            return [
                idx for idx in range(len(sample_ins))
                if sample_outs[idx][0][1] == SEGV_ERROR_CODE and sample_outs[idx][1][1] != 0
            ]

        root = SemanticsTree.create_leaf()
        tree = root

        s = self.empty_state_f()

        current_constraints = [
            # self.solver.constraint_less(self.solver.equation_constant(64, 0x10007), self.solver.equation_usage(64, "pc")) # TODO: hardcoding this here is total bs
        ]

        _, outs = self._collect_samples(s, instruction, self.samples, current_constraints, dict(), [], ["error_code", "error_info"])

        # TODO: what about e.g., conditional jumps (which use memory!). We need to also put our condition splitting stuff somewhere in.
        # Probably, the easiest way to get it (kind of) working would be to sample normally and only split based on different signals.

        # 1. find out whether instruction uses memory
        if any_segfault(outs):
            print("Instruction uses memory!")
            inputs = list(s.register_prototypes.keys())

            # 1.1 constrain registers until we get enough samples with non-null addresses
            max_bits = 46 # TODO: put back to 64
            while max_bits > 5: # less than 5-bit inputs would be very strange, so let's limit it there!
                print(f"trying to collect samples with {max_bits} bit inputs ...")

                # only limit registers that are wider than the current bit limit
                reg_constraints = {
                    name: [self.solver.create_less_constraint(name, s.register_prototypes[name].bitwidth, 1 << max_bits)]
                    for name in inputs
                    if s.register_prototypes[name].bitwidth > max_bits
                }

                ins, outs = self._collect_samples(s, instruction, self.samples, current_constraints, reg_constraints, inputs, ["error_code", "error_info"])

                good_samples = good_sample_indices(ins, outs)
                segfault_count = sum(1 for out in outs if out[0][1] == SEGV_ERROR_CODE)

                # NOTE(review): if no sample segfaults at all, this is trivially true (0 >= 0)
                # and we continue with zero good samples - confirm whether that is intended.
                if len(good_samples) >= segfault_count * 0.7: # TODO: this is an arbitrary threshold, maybe needs tweaking
                    print(f"limiting inputs to {max_bits} bits seems to work (got {len(good_samples)} good samples)")
                    break

                max_bits -= 1
            else:
                print("failed to reverse instruction: could not pinpoint segfault address")
                raise Exception("failed to reverse instruction: could not pinpoint segfault address")

            # 1.2 we have a limited bitlength as well as a few good samples.
            # For each of these good samples, we try to modify each register with a few random values and see whether the segfault address changes.
            # Using this, we can find all relevant inputs into the address function (and then reverse it in the next step)
            relevant_inputs = set()
            for good_sample_idx in good_samples[:10]: # limiting to 10 should be fine here
                print(f"checking sample {good_sample_idx}")
                reference_address = outs[good_sample_idx][1][1]
                for changing_input in inputs:
                    if changing_input in relevant_inputs: # input is already known to be relevant, so it can be skipped
                        continue

                    # TODO: detect over-constraint inputs that always have the same value
                    # TODO: maybe do not limit input to max_bits?

                    # pin every register to the sample's value, except the one under test
                    reg_constraints = dict()
                    for position, name in enumerate(inputs):
                        prototype = s.register_prototypes[name]
                        if name == changing_input:
                            reg_constraints[name] = [
                                self.solver.create_less_constraint(name, prototype.bitwidth, 1 << max_bits),
                            ]
                        else:
                            reg_constraints[name] = [self.solver.create_equal_constraint(name, prototype.bitwidth, ins[good_sample_idx][position][1])]

                    _, new_outs = self._collect_samples(s, instruction, 10, current_constraints, reg_constraints, inputs, ["error_code", "error_info"]) # 10 samples should be more than enough here!

                    if any(out[0][1] != SEGV_ERROR_CODE or out[1][1] != reference_address for out in new_outs):
                        # iterate over the samples we actually got: the sampler may return fewer than 10
                        print(f"found relevant input: {changing_input} {[hex(out[1][1]) for out in new_outs]} vs {hex(reference_address)}")
                        relevant_inputs.add(changing_input)
                        # stop using this sample; the remaining inputs are probed with the next good sample
                        break

            print(f"relevant inputs for memory address: {relevant_inputs}")

            # 1.3 now we know all relevant input registers, we can collect some samples
            # keep register-prototype order so the result does not depend on set/hash ordering
            relevant_inputs = [name for name in inputs if name in relevant_inputs]
            reg_constraints = {
                name: [self.solver.create_less_constraint(name, s.register_prototypes[name].bitwidth, 1 << max_bits)]
                for name in inputs
            }
            ins, outs = self._collect_samples(s, instruction, 2 * self.samples, current_constraints, reg_constraints, relevant_inputs, ["error_code", "error_info"])
            good_samples = good_sample_indices(ins, outs)
            # if len(good_samples) < self.samples:
            #     raise Exception(f"failed to collect enough samples (got {len(good_samples)}, required {self.samples})")
            solver_in = list()
            solver_out = list()
            for sample_idx in good_samples:
                solver_in.append({
                    name: (s.register_prototypes[name].bitwidth, ins[sample_idx][position][1])
                    for position, name in enumerate(relevant_inputs)
                })
                solver_out.append(outs[sample_idx][1][1]) # error_info, i.e. the faulting address

            # 1.4 using the collected samples, we can try to use the solver to find an equation that works for all samples
            memory_equation = self.solver.find_equation_with_constants(self.constants_enum_f(instruction), solver_in, solver_out, 64)
            if memory_equation is None:
                raise Exception("could not reverse equation for memory address")
            undef_inputs = { # inputs with correct bitwidth, but all have no assigned value
                name: (s.register_prototypes[name].bitwidth, None) for name in relevant_inputs
            }
            print(f"equation for memory address: {memory_equation.to_str(undef_inputs)} (alignment: TODO)")

            # TODO: what if pc and mem_addr hit the same memory range?

            # 1.5 with the equation, we can now add a fake register for the address
            s.add_register(
                instr_rev.Register(
                    "mem_addr", 64, [
                        self.solver.create_less_constraint("mem_addr", 64, 1 << 46),
                        self.solver.create_greatereq_constraint("mem_addr", 64, 0x40000), # try to not hit the first page as this can be a problem on some architectures + without root. (16KB since some architectures use 16KB pages)
                        # self.solver.create_alignment_constraint("mem_addr", 64, 4) # TODO: find alignment and do not do this
                    ], [], is_fake = True # TODO: 64 bitwidth hardcoding may not be a good idea!
                )
            )

            tree.add_equation("mem_addr", memory_equation) # add this to top tree since we use it in the condition and it looks nicer if it is "defined" before

            # TODO: maybe add back?
            tree = tree.create_pseudo()
            treeTrue, treeFalse = tree.split(self.solver.create_less_constraint("mem_addr", 64, 1 << max_bits))
            # addresses outside the sampled range: recorded as a segfault with error_info 0
            treeFalse.add_equation(
                "error_code",
                self.solver.equation_constant(
                    64, SEGV_ERROR_CODE
                )
            )
            treeFalse.add_equation(
                "error_info",
                self.solver.equation_constant(
                    64, 0
                )
            )
            tree = treeTrue

            # TODO: solve for required alignment somehow!
            # TODO: add constraint for alignment requirements!
            # TODO: split tree for required alignment!

            # ... and add a constraint to ensure that the fake register value is correct
            current_constraints.append(self.solver.constraint_equal(self.solver.equation_usage(64, "mem_addr"), memory_equation))
            # ... as well as one that forces a canonical address. # TODO: do this one as a simple constraint
            # TODO: in the z3 reverser, relax used variables first (and prefer those that don't have register constraints). This would allow us to sample values more sensibly!
            # current_constraints.append(self.solver.constraint_less(self.solver.equation_constant(64, 0xffff), self.solver.equation_usage(64, "mem_addr"))) # TODO: do not hardcode what canonical means!
            # ... and a fake register for the value. Initially, we start with 1KB, which should be plenty!
            s.add_register(
                instr_rev.Register("mem_val", 8*1024, [], [], is_fake = True)
            )
            # ... and finally a memory mapping! we start with RWX as sanity check
            s.add_mapping(
                instr_rev.MemoryMapping(
                    "mem_addr", instr_rev.PROT_RWX, "mem_val"
                )
            )

            def mem_val_crash():
                # constraint that memory value always "crashes" TODO: null might not crash on all CPUs
                # evaluated lazily because the mem_val register is replaced (resized) further down
                return {"mem_val": [self.solver.create_equal_constraint("mem_val", s.get_register("mem_val")[0], 0)]}

            # 1.5.1 as a sanity check, we should no longer be getting segfaults now:
            ins, outs = self._collect_samples(s, instruction, self.samples, current_constraints, mem_val_crash(), ["pc", "mem_addr"], ["error_code", "error_info"])

            if any_segfault(outs):
                print([
                    f"pc = 0x{sample_in[0][1]:x}, addr = 0x{sample_in[1][1]:x}, segv = 0x{sample_out[1][1]:x}"
                    for sample_in, sample_out in zip(ins, outs)
                    if sample_out[0][1] == SEGV_ERROR_CODE
                ])
                print("We are still getting segfaults, mapping may be broken!")

            # 1.6 now we can find the required protection
            # TODO: weird thing: on loongarch64 everything (X and W) implies R, so mapping something as only X will actually map it as RX.
            # Thus, we need this weird code here, but it's still not nice
            required_prot = 0
            # since X sometimes implies read, we check whether execute is required first

            s.memory_mappings[0].protection = instr_rev.PROT_RW
            _, outs = self._collect_samples(s, instruction, self.samples, current_constraints, mem_val_crash(), [], ["error_code"])
            if any_segfault(outs):
                required_prot = instr_rev.PROT_R | instr_rev.PROT_X
            else:
                # if mapping does not need to be executable, we just need to check whether read and write is required
                for candidate_prot in [instr_rev.PROT_R, instr_rev.PROT_W]:
                    s.memory_mappings[0].protection = candidate_prot
                    _, outs = self._collect_samples(s, instruction, self.samples, current_constraints, dict(), [], ["error_code"])
                    if not any_segfault(outs):
                        required_prot = candidate_prot
                        break
                if required_prot == 0:
                    # neither R nor W alone was enough
                    required_prot = instr_rev.PROT_RW
            # TODO: we only check RW, R, and W (and kind of RX) but not some weird combinations
            print(f"required protection: {required_prot}")
            s.memory_mappings[0].protection = required_prot

            # 1.6.1 as a sanity check, we should not be getting segfaults with the required protection
            _, outs = self._collect_samples(s, instruction, self.samples, current_constraints, mem_val_crash(), [], ["error_code"])
            if any_segfault(outs):
                print("We are still getting segfaults, mapping protection may be broken!")

            # 1.7 now that we know the correct protection, we should try to find the size of the mapping
            required_size = 0
            if (required_prot & instr_rev.PROT_RX) == instr_rev.PROT_R:
                # find read size (only for readable, non-executable mappings)

                # first, create samples with completely random data
                output_regs = [
                    f"{reg.name}_out" if reg.is_fake and "error" not in reg.name else reg.name for reg in s.register_prototypes.values() if reg.name not in ["mem_val"]
                ]
                inputs = list(s.register_prototypes.keys())
                ins, outs = self._collect_samples(s, instruction, self.samples, current_constraints, dict(), inputs, output_regs)

                # now for each sample, only use first n bits and check if output changes.
                # gradually increase until output changes no more.
                while True:
                    if required_size > 1024 * 8:
                        raise Exception("memory read size exceeds 1024 bytes. Something is probably broken!")
                    reg_constraints = dict()

                    print(f"Trying memory size {required_size} (bits)")
                    size_mask = (1 << required_size) - 1
                    for good_sample_idx in range(len(ins)):
                        # replay the sample exactly, but truncate mem_val to the low required_size bits
                        for position, name in enumerate(inputs):
                            prototype = s.register_prototypes[name]
                            value = ins[good_sample_idx][position][1]
                            if name == "mem_val":
                                value = value & size_mask
                            reg_constraints[name] = [self.solver.create_equal_constraint(name, prototype.bitwidth, value)]
                        _, new_outs = self._collect_samples(s, instruction, 1, current_constraints, reg_constraints, [], output_regs) # 1 sample since we fix all values anyways ...
                        if new_outs[0] != outs[good_sample_idx]:
                            # there are still changes, try with bigger value
                            break
                    else:
                        # no sample changed its outputs: the truncated value is sufficient
                        print(f"Read size {required_size}")
                        break

                    required_size += 8
            if required_prot & instr_rev.PROT_W:
                # first, create samples with completely random data and collect previous value and value after
                ins, outs = self._collect_samples(s, instruction, self.samples, current_constraints, dict(), ["mem_val"], ["mem_val_out"])

                # now check how many output bits change, the maximum amount of changing output bits is the written size (probably)
                write_size = max(
                    int.bit_length(before[0][1] ^ after[0][1]) for before, after in zip(ins, outs)
                )
                print(f"write size: {write_size}")

                # since there might also be a read size, set the required size to the maximum of read and write size. TODO: maybe actually separate them in case of RW?
                required_size = max(required_size, write_size)

            if required_prot & instr_rev.PROT_X:
                required_size = max(required_size, 4 * 8) # TODO: hardcoding 4 bytes here is not nice. We should instead just map an always illegal instruction!

            # pseudo constraints for required mapping protection and size
            tree = tree.create_pseudo()
            # indexed by the masked protection bit (1, 2 or 4); a cleared bit selects the empty string at index 0
            prot_chars = ["", "r", "w", "", "x"]
            treeTrue, treeFalse = tree.split(FakeConstraint(f"is_mapped(mem_addr to mem_addr + {required_size} as {prot_chars[required_prot & 1]}{prot_chars[required_prot & 2]}{prot_chars[required_prot & 4]})"))
            # not mapped as required: segfault with the accessed address as error_info
            treeFalse.add_equation(
                "error_code",
                self.solver.equation_constant(
                    64, SEGV_ERROR_CODE
                )
            )
            treeFalse.add_equation(
                "error_info",
                self.solver.equation_usage(
                    64, "mem_addr"
                )
            )
            tree = treeTrue

            print(f"Memory mapping size: {required_size}")
            # add register overwrites existing registers, so we can use this to modify the size!
            s.add_register(
                instr_rev.Register("mem_val", required_size, [], [], is_fake = True)
            )
            if required_prot & instr_rev.PROT_X:
                s.register_prototypes["mem_val"].constraints.extend(mem_val_crash()["mem_val"])
                # s.register_prototypes["mem_addr"].constraints.append(self.solver.create_alignment_constraint("mem_addr", 64, 4)) # TODO: remove if not necessary

            # We should be done with memory stuff (except if conditions are a thing ... we need to add that!)

        else:
            print("Instruction does not use memory!")

        # 2. reverse the remaining (register) semantics below the current subtree
        success, _ = self._recursive_reverse(instruction, s, current_constraints, tree, error_info=error_info)
        return success, root, s
        

