import abc
import random
import sys

import time
import stats

import instr_rev_z3
import constants


class Encoding(abc.ABC):
    """Abstract description of how a value is laid out in the bits of an instruction.

    Concrete subclasses implement every abstract method below; ``__repr__`` is
    shared and simply falls back to ``__str__``.
    """

    def __init__(self, bitwidth, signed):
        # bitwidth: width of the encoded field in bits
        # signed: whether the field is interpreted as a signed value
        self.bitwidth, self.signed = bitwidth, signed

    @abc.abstractmethod
    def replace(self, instr, new_value):
        """Return ``instr`` re-encoded so that this field represents ``new_value``."""

    @abc.abstractmethod
    def __str__(self):
        """Human-readable description of the encoding."""

    @abc.abstractmethod
    def bits_count(self):
        """Number of instruction bits used by this encoding."""

    @abc.abstractmethod
    def get_bits(self):
        """Instruction bit positions used by this encoding."""

    @abc.abstractmethod
    def get_bit_idx(self, bit):
        """Index of instruction bit ``bit`` inside the field, or -1 if it is not part of it."""

    @abc.abstractmethod
    def iterate_instructions(self, instr):
        """Yield the instructions obtained from ``instr`` by varying only this field."""

    def __repr__(self):
        text = str(self)
        return text

class SimpleEncoding(Encoding):
    """Encoding whose field occupies an explicit list of instruction bit positions.

    ``bits[k]`` is the instruction bit that stores bit ``k`` of the raw field
    (least significant first). The value represented by a raw field is
    ``(field + implicit_offset) << implicit_shift``; when ``signed`` is true the
    raw field is read as two's complement (see ``get_random_value``).
    """

    def __init__(self, bits, signed, implicit_offset=0, implicit_shift=0):
        # Encoding.__init__ stores both the bitwidth and ``signed``.
        super().__init__(len(bits), signed)
        self.bits = bits
        self.implicit_offset = implicit_offset
        self.implicit_shift = implicit_shift

    def replace(self, instr, new_value):
        """Return ``instr`` with this field rewritten so that it represents ``new_value``.

        This inverts the value mapping: the raw field becomes
        ``(new_value >> implicit_shift) - implicit_offset``. Only its lowest
        ``len(bits)`` bits are written (negative fields end up in two's
        complement), and bits of ``new_value`` below ``implicit_shift`` are dropped.
        """
        field = (new_value >> self.implicit_shift) - self.implicit_offset
        new_instr = instr
        for k, b in enumerate(self.bits):
            # Flip instruction bit b only where it differs from the wanted field bit.
            if ((field >> k) & 1) != ((instr >> b) & 1):
                new_instr ^= 1 << b
        return new_instr

    def get_random_value(self):
        """Return a random value representable by this encoding.

        The raw field is drawn uniformly from all ``2 ** len(bits)`` patterns,
        sign-extended if ``signed``, then offset and shifted.
        """
        width = len(self.bits)
        val = random.randrange(1 << width)
        # ``width`` guard: a zero-width field has no sign bit (avoids a negative shift).
        if self.signed and width and val & (1 << (width - 1)):
            val -= 1 << width
        val += self.implicit_offset
        val <<= self.implicit_shift
        return val

    def __str__(self):
        # TODO: combine bits that are next to each other
        s = ""
        if self.signed:
            s += "signed "
        s += "|".join(map(str, self.bits))
        if self.implicit_offset:
            s += f" + {self.implicit_offset}"
        if self.implicit_shift:
            s += f" << {self.implicit_shift}"
        return s

    def bits_count(self):
        """Number of instruction bits in the field."""
        return len(self.bits)

    def get_bits(self):
        """The field's instruction bit positions (LSB of the field first)."""
        return self.bits

    def get_bit_idx(self, bit):
        """Return the field position of instruction bit ``bit``, or -1 if unused."""
        try:
            return self.bits.index(bit)
        except ValueError:
            return -1

    def iterate_instructions(self, instr):
        """Yield ``instr`` with the field set to every possible raw pattern.

        The field bits are first cleared; then each of the ``2 ** len(bits)``
        raw values ``0, 1, 2, ...`` is scattered onto ``bits`` in that order.
        """
        base = instr
        for b in self.bits:
            base &= ~(1 << b)
        # Bug fix: the original iterated over the int ``1 << len(bits)`` (TypeError).
        for pattern in range(1 << len(self.bits)):
            scattered = 0
            for k, b in enumerate(self.bits):
                if pattern & (1 << k):
                    scattered |= 1 << b
            yield base | scattered
    
class ConstantCluster:
    """A constant value together with the instruction-bit encoding that produces it."""

    def __init__(self, constant, encoding):
        # constant: the concrete value found in the semantics
        # encoding: presumably an Encoding describing where the value lives in the instruction
        self.constant, self.encoding = constant, encoding

    def __str__(self):
        return "{}: {}".format(self.constant, self.encoding)

class RegisterCluster:
    """A register (by name) together with the instruction-bit encoding that selects it."""

    def __init__(self, register, encoding):
        # register: identifier of the register (callers in this file pass its name)
        # encoding: presumably an Encoding locating the register field in the instruction
        self.register, self.encoding = register, encoding

    def __str__(self):
        return "{}: {}".format(self.register, self.encoding)

class InstructionConstantIterator(abc.ABC):
    """Interface for enumerating candidate constants/encodings of a single instruction."""

    def __init__(self, instruction):
        # instruction: the instruction word being analysed
        self.instruction = instruction

    @abc.abstractmethod
    def iterate_constants(self):
        """Return an iterable of ints that are plausible constants of the instruction."""

    @abc.abstractmethod
    def get_constant_encodings(self, constant):
        """Return an iterable of candidate encodings producing ``constant``."""

    @abc.abstractmethod
    def get_register_encodings(self, register):
        """Return an iterable of candidate encodings selecting ``register``."""


class SimpleInstructionConstantIterator(InstructionConstantIterator):
    """Enumerates constants/encodings formed by contiguous bit fields of an instruction.

    A field spans instruction bits ``a..b`` (inclusive). Its raw value is taken
    unsigned and, when bit ``b`` is set, also sign-extended; the resulting
    candidate constant is ``(raw + offset) << shift`` for every configured
    ``offset`` and ``shift``.
    """

    def __init__(self, instruction, bitwidth, shifts=None, offsets=None):
        """
        Args:
            instruction: the instruction word (int) being analysed.
            bitwidth: number of low instruction bits considered.
            shifts: implicit left shifts to try; defaults to ``0 .. bitwidth-1``.
            offsets: implicit offsets to try; defaults to ``[0, -1, 1]``.
        """
        super().__init__(instruction)
        self.bitwidth = bitwidth
        self.shifts = shifts if shifts is not None else list(range(bitwidth))
        # A fresh default per instance: the former list default was shared by all instances.
        self.offsets = offsets if offsets is not None else [0, -1, 1]
        self.iter_cache = None  # built lazily by iterate_constants()

    def _field(self, a, b):
        """Return ``(raw, sign_extended_or_None)`` for instruction bits ``a..b``."""
        bw = b - a
        raw = (self.instruction >> a) & ((2 << bw) - 1)
        if self.instruction & (1 << b):
            return raw, raw - (2 << bw)
        return raw, None

    def _build_iter_cache(self):
        """Collect all candidate constants of fields at least 4 bits wide."""
        self.iter_cache = set()
        for shift in self.shifts:
            for a in range(self.bitwidth):
                # b - a >= 3: only fields of 4+ bits are cached
                for b in range(a + 3, self.bitwidth):
                    raw, extended = self._field(a, b)
                    for offset in self.offsets:
                        self.iter_cache.add((raw + offset) << shift)
                        if extended is not None:
                            self.iter_cache.add((extended + offset) << shift)

    def iterate_constants(self, ignored=None):
        """Return the (cached) set of candidate constants; ``ignored`` is unused."""
        if self.iter_cache is None:
            self._build_iter_cache()
        return self.iter_cache

    def get_constant_encodings(self, constant):
        """Yield a ConstantCluster for every field/offset/shift that produces ``constant``.

        Fields must be at least 2 bits wide. For each match on the unsigned
        interpretation an unsigned encoding is yielded. A signed encoding is
        yielded when the sign-extended value matches (or, if the top bit is
        clear, when the unsigned value matches, since both interpretations agree).
        """
        for shift in self.shifts:
            for a in range(self.bitwidth):
                # b - a >= 1: only fields of 2+ bits are considered
                for b in range(a + 1, self.bitwidth):
                    raw, extended = self._field(a, b)
                    for offset in self.offsets:
                        unsigned_value = (raw + offset) << shift
                        if unsigned_value == constant:
                            yield ConstantCluster(
                                unsigned_value,
                                SimpleEncoding(list(range(a, b + 1)), False, implicit_offset=offset, implicit_shift=shift),
                            )

                        if extended is not None:
                            signed_value = (extended + offset) << shift
                        else:
                            signed_value = unsigned_value
                        if signed_value == constant:
                            yield ConstantCluster(
                                signed_value,
                                SimpleEncoding(list(range(a, b + 1)), True, implicit_offset=offset, implicit_shift=shift),
                            )

    def get_register_encodings(self, register):
        """Yield a RegisterCluster for every position where ``register``'s encoding appears.

        ``register.possible_encodings`` must hold exactly one ``(bitwidth, value)``
        pair; otherwise an Exception is raised on first iteration.
        """
        if len(register.possible_encodings) != 1:
            raise Exception(f"TODO: implement registers with multiple encodings ({register})")

        enc_bitwidth, reg_encoding = register.possible_encodings[0]
        mask = (1 << enc_bitwidth) - 1
        for start in range(self.bitwidth - enc_bitwidth + 1):
            if (self.instruction >> start) & mask == reg_encoding:
                yield RegisterCluster(
                    register.name,  # TODO: maybe use register here not its name?
                    SimpleEncoding(list(range(start, start + enc_bitwidth)), False),
                )
    
class Clusterer:
    """Discovers which instruction bits encode the constants and registers of a semantics.

    Candidate encodings come from ``constant_iterator``. A candidate is accepted
    when rewriting the instruction through it (and rewriting the semantics the
    same way) keeps the semantics consistent with the outputs observed by
    executing the rewritten instruction via ``runner`` on solver-generated inputs.
    """

    def __init__(self, solver, runner, constant_iterator, empty_state_f, samples=25):
        # solver: provides create_random_sample(), used to fill the input state
        # runner: provides run_instructions(), which executes raw instruction bytes
        # constant_iterator: source of candidate constant/register encodings
        # empty_state_f: factory returning a fresh state that receives execution results
        # samples: number of input/output samples collected per tested candidate
        self.solver = solver
        self.runner = runner
        self.constant_iterator = constant_iterator
        self.empty_state_f = empty_state_f
        self.samples = samples

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _fake_register_names(initial_state):
        """Return the names of all register prototypes flagged as fake."""
        return {r.name for r in initial_state.register_prototypes.values() if r.is_fake}

    def _fake_output_constraints(self, semantics, fakes):
        """Return one solver constraint per fake output register not named ``*_out*``."""
        # TODO: hotfix: non-output fake registers actually must have the value before execution -> need constraints
        constraints = []
        for reg in semantics.get_output_registers():
            if "_out" in reg or reg not in fakes:
                continue
            constraints.append(self._compute_constraints(semantics, reg))
        return constraints

    @staticmethod
    def _first_mismatch(semantics, fakes, ins, outs):
        """Compare ``semantics`` with the collected samples.

        Returns None if every sample agrees, otherwise a tuple
        ``(kind, sample_idx, name, expected, actual)`` for the first mismatch:
          * ``"nondet"``  - a non-deterministic output failed its ``check`` over all samples
          * ``"wrong"``   - an emulated output differs from the observed one
          * ``"changed"`` - a register not written by the semantics changed value
            (registers whose name contains ``"error"`` are ignored)
        """
        for j, inp in enumerate(ins):
            emulated_outs = semantics.evaluate(inp.copy(), fakes)
            real_outs = outs[j]
            for name, emu_value in emulated_outs.items():
                top = semantics.get_equation(inp.copy(), name).top_expression
                if isinstance(top, instr_rev_z3.Z3NonDeterministicExpression):
                    if not top.check([o[name] for o in outs]):
                        return ("nondet", j, name, None, None)
                elif name in real_outs and emu_value != real_outs[name]:
                    return ("wrong", j, name, emu_value, real_outs[name])
            # TODO: ignore fake registers!
            for name, in_value in inp.items():
                if (name in real_outs and real_outs[name] != in_value[1]
                        and name not in emulated_outs and "error" not in name):
                    return ("changed", j, name, in_value[1], real_outs[name])
        return None

    # ------------------------------------------------------------------ sampling

    def _collect_samples(self, state_before, instruction, amount, constraints):
        """Execute ``instruction`` on ``amount`` random inputs and record register values.

        Returns ``(input_values, output_values)``: ``input_values[i]`` is a copy of
        ``state_before.register_values`` for sample ``i``; ``output_values[i]`` maps
        each register name to the second element of its ``state_after`` entry.

        Raises Exception if the solver cannot create a sample, or if ``5 * amount``
        executions fail before ``amount`` samples were collected.
        """
        sample_start = time.time()
        input_values = []
        output_values = []
        max_failures = amount * 5

        state_after = self.empty_state_f()
        interesting_constants = self.constant_iterator.iterate_constants()
        # The encoded instruction does not change between samples.
        instr_bytes = instruction.to_bytes((constants.INSTR_BITWIDTH + 7) // 8, "little", signed=False)

        failed_attempts = 0
        while len(input_values) < amount and failed_attempts < max_failures:
            if not self.solver.create_random_sample(interesting_constants, state_before, constraints, dict(), dict()):
                print(constraints)
                raise Exception("failed to create sample!")  # TODO: log more information!

            ins = state_before.register_values.copy()

            if not self.runner.run_instructions(instr_bytes, state_before, state_after):
                failed_attempts += 1
                continue

            input_values.append(ins)
            output_values.append({name: entry[1] for name, entry in state_after.register_values.items()})

        # Fail only when samples are actually missing (so amount == 0 is valid).
        if len(input_values) < amount:
            raise Exception("too many sample collections failed!")

        stats.time_in_sample_cluster += time.time() - sample_start

        return input_values, output_values

    # ------------------------------------------------------------------ constants

    def _constant_encoding_holds(self, instruction, semantics, initial_state, fakes, candidate, active_usages, random_vals):
        """Return True if ``candidate`` survives substituting every value of ``random_vals``."""
        for other_value in random_vals:
            new_semantics = semantics
            for usage in active_usages:
                new_semantics = usage(new_semantics, other_value)  # TODO: check if this is a fixed int or just an int (probably int is correct though)
            replaced_instruction = candidate.encoding.replace(instruction, other_value)
            constraints = self._fake_output_constraints(new_semantics, fakes)
            try:
                ins, outs = self._collect_samples(initial_state, replaced_instruction, self.samples, constraints)
            except Exception:
                # impossible constraint
                print("impossible constraint / sample collection failed")
                return False

            mismatch = self._first_mismatch(new_semantics, fakes, ins, outs)
            if mismatch is not None:
                kind, j, name, expected, actual = mismatch
                if kind == "wrong":
                    print(f"wrong output: {name}: 0x{expected:x} -> 0x{actual:x}  ({j}) ({candidate.encoding})")
                elif kind == "changed":
                    print(f"unexpected changing output: {name}: 0x{expected:x} -> 0x{actual:x}  ({j}) ({candidate.encoding})")
                return False
        return True

    def _cluster_constants(self, instruction, semantics, initial_state):
        """Find the instruction bit field encoding each constant of ``semantics``.

        Every non-empty subset of a constant's usages is combined with every
        candidate encoding; the accepted candidates with the most bits win.
        Returns a list of ``(value, candidate, active_usages)`` tuples, at most one
        per constant. Constants with a bit length below 4 are skipped.
        """
        encodings = []
        constant_values = semantics.get_constants()
        fakes = self._fake_register_names(initial_state)

        for value in constant_values:
            # TODO: maybe we do want to cluster small constants or at least give the option to?
            if value.bit_length() < 4:
                continue

            best_encodings = None
            usages = list(semantics.get_constant_instances(value))
            encoding_possibilities = list(self.constant_iterator.get_constant_encodings(value))

            # (2^n - 1) non-empty usage subsets times the candidate count
            n_iterations = ((1 << len(usages)) - 1) * len(encoding_possibilities)
            print(f"clustering {value}: found {len(usages)} usages and {len(encoding_possibilities)} encodings (-> {n_iterations} iterations)")

            if len(encoding_possibilities) == 0:
                continue

            for subset in range(1, 1 << len(usages)):
                active_usages = [u for k, u in enumerate(usages) if subset & (1 << k)]
                for candidate in encoding_possibilities:
                    random_vals = [candidate.encoding.get_random_value() for _ in range(16)]

                    # TODO: make sure random_vals do not break any constraints!!!
                    if len(set(random_vals + [value])) == 1:
                        continue  # constant can only be replaced with itself. This is a stupid encoding!

                    if not self._constant_encoding_holds(instruction, semantics, initial_state, fakes,
                                                         candidate, active_usages, random_vals):
                        continue

                    print(f"found new encoding {candidate} for {value}")
                    # keep only the accepted encodings with the largest bit count
                    bits = candidate.encoding.bits_count()
                    best_bits = None if best_encodings is None else best_encodings[0][1].encoding.bits_count()
                    if best_bits is None or best_bits < bits:
                        best_encodings = [(value, candidate, active_usages)]
                    elif best_bits == bits:
                        best_encodings.append((value, candidate, active_usages))

            if best_encodings is not None:
                if len(best_encodings) != 1:
                    print(f"[WARNING] found multiple possible encodings for constant {value}")
                    print("just using first one!")
                    best_encodings = [best_encodings[0]]
                encodings.extend(best_encodings)
        # TODO: filter encodings that were found -> use the biggest one!
        return encodings

    def _compute_constraints(self, semantics, reg):
        """Build a constraint stating that ``reg`` holds the value ``semantics`` assigns it.

        Assignments inside a conditional branch are guarded by the condition
        (or its negation for the false branch); branch alternatives are OR-ed.
        Returns None when ``reg`` is not assigned. Raises Exception when ``reg``
        is assigned both on a level and in its true branch / nested level.
        """
        constraint = None

        if reg in semantics.equations:
            equation = semantics.equations[reg]
            constraint = instr_rev_z3.Z3EqualsConstraint(instr_rev_z3.Z3UsageExpression(equation.get_bitwidth(), reg), equation.top_expression)  # TODO: ensure this is the correct bitwidth!

        if semantics.condition is not None:

            if semantics.ifTrue is not None:
                c = self._compute_constraints(semantics.ifTrue, reg)
                if c is not None:
                    c = instr_rev_z3.Z3AndConstraint(c, semantics.condition)
                if constraint is None:
                    constraint = c
                elif c is not None:
                    raise Exception(f"duplicate output {reg} on level and parent level!")

            # Bug fix: this previously tested ``ifTrue`` before recursing into ``ifFalse``.
            if semantics.ifFalse is not None:
                c = self._compute_constraints(semantics.ifFalse, reg)
                if c is not None:
                    c = instr_rev_z3.Z3AndConstraint(c, instr_rev_z3.Z3NotConstraint(semantics.condition))
                if constraint is None:
                    constraint = c
                elif c is not None:
                    constraint = instr_rev_z3.Z3OrConstraint(constraint, c)

        elif semantics.ifTrue is not None:
            c = self._compute_constraints(semantics.ifTrue, reg)
            if constraint is None:
                return c
            elif c is not None:
                raise Exception(f"duplicate output {reg} on same level!")

        return constraint

    # ------------------------------------------------------------------ registers

    def _register_encoding_holds(self, instruction, semantics, initial_state, fakes, candidate, active_usages, replacements):
        """Return True if ``candidate`` survives substituting every register in ``replacements``."""
        for other_reg, other_encodings in replacements:
            new_semantics = semantics
            for usage in active_usages:
                new_semantics = usage(new_semantics, other_reg)

            if len(other_encodings) != 1:
                raise Exception(f"TODO: implement multiple encodings for a single register!")
            replaced_instruction = candidate.encoding.replace(instruction, other_encodings[0][1])

            constraints = self._fake_output_constraints(new_semantics, fakes)
            try:
                ins, outs = self._collect_samples(initial_state, replaced_instruction, self.samples, constraints)
            except Exception:
                # impossible constraints
                return False

            mismatch = self._first_mismatch(new_semantics, fakes, ins, outs)
            if mismatch is not None:
                kind, j, name, expected, actual = mismatch
                if kind == "wrong":
                    print(f"wrong value for {name}: 0x{expected:x} vs. 0x{actual:x} ({j}) ({candidate.encoding})")
                elif kind == "changed":
                    print(f"unexpected changing {name}: 0x{expected:x} vs. 0x{actual:x}")
                return False
        return True

    def _cluster_registers(self, instruction, semantics, initial_state):
        """Find the instruction bit fields that select the registers of ``semantics``.

        A candidate is tested by swapping in up to 8 random registers of the same
        encoding group. Returns ``(reg, candidate, active_usages, encoding_group)``
        tuples for every accepted candidate.
        """
        encodings = []
        registers = semantics.get_registers()
        prototypes = initial_state.register_prototypes
        fakes = self._fake_register_names(initial_state)

        for reg in registers:
            usages = list(semantics.get_register_instances(reg))
            if reg not in prototypes:
                continue  # error_info and error_code do not have prototypes? we could just add them?
            reg_prototype = prototypes[reg]
            if len(reg_prototype.possible_encodings) == 0:
                continue  # register cannot be encoded --> it is some fake thingy
            possible_replacements = {
                r.name: r.possible_encodings
                for r in prototypes.values()
                if r.encoding_group == reg_prototype.encoding_group and r is not reg_prototype
            }

            if len(possible_replacements) == 0:
                # no other registers that could replace the register ... cannot cluster this one
                continue
            encoding_possibilities = list(self.constant_iterator.get_register_encodings(reg_prototype))

            # (2^n - 1) non-empty usage subsets times the candidate count
            n_iterations = ((1 << len(usages)) - 1) * len(encoding_possibilities)
            print(f"clustering {reg}: found {len(usages)} usages and {len(encoding_possibilities)} encodings (-> {n_iterations} iterations)")

            for subset in range(1, 1 << len(usages)):
                active_usages = [u for k, u in enumerate(usages) if subset & (1 << k)]
                for candidate in encoding_possibilities:
                    # we always try with maximum of 8 different other registers
                    replacements = random.sample(list(possible_replacements.items()), min(len(possible_replacements), 8))
                    if self._register_encoding_holds(instruction, semantics, initial_state, fakes,
                                                     candidate, active_usages, replacements):
                        print(f"found new encoding {candidate} for {reg}")
                        encodings.append(
                            (reg, candidate, active_usages, reg_prototype.encoding_group)
                        )
                    # TODO: if necessary, also try with encodings at other positions. However, unless two operands refer to the same register, this is not needed.
                    # (and if two operands refer to the same register, we can also just fail clustering because at some point we will hit the same encoding but with not the same reg multiple times)
        return encodings

    # ------------------------------------------------------------------ aliases

    def _cluster_alias(self, instruction, semantics, initial_state):
        """Find instruction bits that can be flipped without changing observable behavior.

        Each bit is flipped individually; it is "flippy" when the flipped
        instruction still matches ``semantics`` on all samples. Then up to 500
        random combinations of flippy bits are checked together. The list of
        flippy bit positions is returned either way; the combination check only
        affects what is printed.
        """
        # TODO: maybe also vary other encodings (e.g., registers) that we found before since this may otherwise find invalid stuff in some cases!
        # TODO: return value of this method is not that nice
        flippy_bits = []
        fakes = self._fake_register_names(initial_state)

        for bit in range(constants.INSTR_BITWIDTH):
            replaced_instruction = instruction ^ (1 << bit)
            constraints = self._fake_output_constraints(semantics, fakes)
            try:
                ins, outs = self._collect_samples(initial_state, replaced_instruction, self.samples, constraints)
            except Exception:
                # Impossible constraint: this bit is not flippy, but keep scanning
                # the remaining bits (previously a `break` aborted the whole scan).
                print("impossible constraint / sample collection failed")
                continue
            if self._first_mismatch(semantics, fakes, ins, outs) is None:
                flippy_bits.append(bit)
                print("found flippy bit:", bit)

        # initial check: are flippy bits all ignored?
        for _ in range(500):
            random_flips = random.randrange(1 << len(flippy_bits))
            flip_mask = sum(1 << b for k, b in enumerate(flippy_bits) if random_flips & (1 << k))
            replaced_instruction = instruction ^ flip_mask
            constraints = self._fake_output_constraints(semantics, fakes)
            try:
                ins, outs = self._collect_samples(initial_state, replaced_instruction, self.samples, constraints)
            except Exception:
                # impossible constraint
                break
            if self._first_mismatch(semantics, fakes, ins, outs) is not None:
                break
        else:
            print("all flippy :)", flippy_bits)
            return flippy_bits

        # TODO: more sophisticated check...
        print("not really all flippy :(", flippy_bits)
        return flippy_bits  # TODO

    # ------------------------------------------------------------------ entry point

    def cluster(self, instruction, semantics, initial_state):
        """Cluster constants and registers of ``semantics`` and report the result on stderr.

        Encoded usages are renamed (``const_<letter>`` for constants,
        ``<encoding_group>_<letter>`` for registers) and a bit map of the
        instruction showing which bits belong to which encoding is printed.
        Returns None.
        """
        enc_semantics = semantics
        encoding_names = "abcdefghijklmnopqrstuvwxyz"

        constant_encodings = self._cluster_constants(instruction, semantics, initial_state)
        register_encodings = self._cluster_registers(instruction, semantics, initial_state)
        # NOTE: _cluster_alias is currently not invoked here.

        all_encodings = constant_encodings + register_encodings
        enc_bits = []
        for enc in all_encodings:
            enc_bits += enc[1].encoding.get_bits()

        # TODO: renaming relies on equations doing a lazy copy (only nodes closer to the root
        # than the replaced node are copied); there might be cases where this is broken, and
        # other implementations should not be assumed to copy lazily.
        for i, (value, encoding, active_usages) in enumerate(constant_encodings):
            enc_name = f"const_{encoding_names[i]}"
            for usage in active_usages:
                enc_semantics = usage(enc_semantics, enc_name)
            print(f"{value} -> {enc_name}: {encoding}", file=sys.stderr)

        for i, (reg_name, encoding, active_usages, encoding_group) in enumerate(register_encodings):
            enc_name = f"{encoding_group}_{encoding_names[i]}"
            for usage in active_usages:
                enc_semantics = usage(enc_semantics, enc_name)
            print(f"{reg_name} -> {enc_name}: {encoding}", file=sys.stderr)

        print("", file=sys.stderr)
        print("variable bits:", enc_bits, file=sys.stderr)

        # Bit table, most significant bit first; every column is 4 characters wide.
        bit_positions = range(constants.INSTR_BITWIDTH - 1, -1, -1)
        print("     ", " ".join(f"{x:3}" for x in bit_positions), file=sys.stderr)

        print("     ", end="", file=sys.stderr)
        for x in bit_positions:
            if any(enc[1].encoding.get_bit_idx(x) != -1 for enc in all_encodings):
                print("    ", end="", file=sys.stderr)
            else:
                print(f" {(instruction >> x) & 1:3}", end="", file=sys.stderr)
        print("", file=sys.stderr)

        def bit_row(encoding):
            # field position of each instruction bit, blank where the bit is unused
            cells = []
            for x in bit_positions:
                idx = encoding.get_bit_idx(x)
                label = str(idx) if idx != -1 else ""
                cells.append(f"{label:>3}")
            return " ".join(cells)

        for j, cenc in enumerate(constant_encodings):
            print(f"val_{encoding_names[j]}", bit_row(cenc[1].encoding), file=sys.stderr)
        for j, renc in enumerate(register_encodings):
            print(f"reg_{encoding_names[j]}", bit_row(renc[1].encoding), file=sys.stderr)

        print(enc_semantics.to_str(), file=sys.stderr)
        
