# abstract classes
from abc import ABC, abstractmethod

# type annotations
from typing import List, Dict, Tuple, Any
from numbers import Number
from enum import Enum

# Handle for an expression object (constant, input usage, ...) created by a
# Function/Solver implementation. Its concrete representation is
# implementation-defined, hence Any.
Expression = Any

# A fixed-size integer as a (bitwidth, value) pair, e.g. (64, 0x1000); this is
# the shape stored in ArchitecturalState.register_values.
FixedInteger = Tuple[int,Number]

class Function(ABC):
    """
    Abstract base for mathematical operations over named fixed-size integer inputs.

    Evaluation yields an integer or a boolean; the Equation and Constraint
    subclasses narrow the result to int and bool respectively.
    """

    @abstractmethod
    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int|bool:
        """
        Computes the result of this function for the given input variables.

        inputs maps each input variable name to its (bitwidth, value) pair.
        """

    @abstractmethod
    def replace(self, target: Any, replacement: Any) -> "Function":
        """
        Returns a copy of this function in which the target expression has
        been substituted by the replacement expression.
        """

    @abstractmethod
    def get_constants(self) -> List[Expression]:
        """
        Lists every constant expression occurring in this function.
        """

    @abstractmethod
    def get_usages(self, name: str) -> List[Expression]:
        """
        Lists every usage of the input variable called ``name`` in this function.
        """

    @abstractmethod
    def create_constant(self, bitwidth: int, value: int) -> Expression:
        """
        Builds a constant expression with the given bitwidth and value.

        The resulting expression is only valid together with this function instance.
        """

    @abstractmethod
    def create_usage(self, bitwidth: int, name: str) -> Expression:
        """
        Builds an expression that refers to the input variable ``name``.

        The resulting expression is only valid together with this function instance.
        """


class Equation(Function):
    """
    A Function whose evaluation produces an integer.
    """

    @abstractmethod
    def evaluate(self, inputs: Dict[str, FixedInteger]) -> int:
        """
        Computes the integer value of this equation for the given input variables.
        """

    @abstractmethod
    def replace(self, target: Expression, replacement: Expression) -> "Equation":
        """
        Returns a copy of this equation in which the target expression has been
        substituted by the replacement expression.
        """


class Constraint(Function):
    """
    A Function that can only result in True or False.
    """

    @abstractmethod
    def evaluate(self, inputs: Dict[str, FixedInteger]) -> bool:
        """
        Evaluates the constraint on the given input variables.

        inputs maps each input variable name to its (bitwidth, value) pair.
        Returns True if the constraint is fulfilled for these inputs, False otherwise.
        """
        pass

    @abstractmethod
    def replace(self, target: Expression, replacement: Expression) -> "Constraint":
        """
        Replaces a target expression with a replacement expression in a returned
        copy of this constraint.
        """
        pass


class Solver(ABC):
    """
    Interface to a solver backend.

    A solver synthesizes equations from input/output samples, builds equations
    and constraints, and samples random architectural states that fulfill
    constraints. Some methods (e.g. solve_partial_input) only accept constraints
    created by the same solver.
    """

    @abstractmethod
    def find_equation(self, input_samples: List[Dict[str, Tuple[int, Number]]], output_samples: List[Number]) -> Equation|None:
        """
        Finds an equation that maps each input sample to its corresponding output sample.

        input_samples[i] maps input variable names to (bitwidth, value) pairs and
        output_samples[i] is the expected output for that sample.
        With a bit of "luck", etc., there is a good chance that this equation actually represents the instruction semantics.
        Returns None if the solver is not successful.
        """
        pass

    @abstractmethod
    def find_equation_with_constants(self, possible_constants, input_samples: List[Dict[str, Tuple[int, Number|None]]], output_samples: List[int]) -> Equation|None:
        """
        Like find_equation, but restricted to a given collection of constants.

        The returned equation must map all input samples to their corresponding output
        samples correctly and must not use any constant that is not contained in
        possible_constants.
        NOTE(review): the element type of possible_constants is not specified here
        (plain ints or solver Expressions, e.g. from Function.get_constants?) - confirm
        against the implementations.
        With a bit of "luck", etc., there is a good chance that this equation actually represents the instruction semantics.
        Returns None if the solver is not successful.
        """
        pass

    # TODO: this method is deprecated and should not be used!
    @abstractmethod
    def solve_partial_input(self, constraints: List[Constraint], partial_input: Dict[str, Tuple[int, Number|None]]) -> Dict[str, Tuple[int, Number]]|None:
        """
        Deprecated: should not be used.

        Given a list of constraints and a partial input, tries to find an input that
        fulfills all constraints.
        Only constraints created with this solver are allowed.
        Note that the partial input is applied to all constraints, meaning the same input
        variable name must be assigned the same value across all constraints.
        NOTE(review): a value of None in partial_input presumably marks a variable the
        solver is free to choose - confirm.
        Returns the completed input, or None if no such input could be found.
        """
        pass

    @abstractmethod
    def constraint_equal(self, a: Equation, b: Equation) -> Constraint:
        """
        Given two equations a and b, returns a constraint that is only fulfilled if both
        equations evaluate to the same value.
        """
        pass

    @abstractmethod
    def constraint_less(self, a: Equation, b: Equation) -> Constraint:
        """
        Given two equations a and b, returns a constraint that is only fulfilled if a
        evaluates to a value smaller than the value b evaluates to.
        NOTE(review): whether the comparison is signed or unsigned is not specified here.
        """
        pass

    @abstractmethod
    def equation_usage(self, bitwidth: int, name: str) -> Equation:
        """
        Given a bitwidth and a variable name, returns an equation that represents the
        usage of this variable.
        """
        pass

    @abstractmethod
    def equation_constant(self, bitwidth: int, value: int) -> Equation:
        """
        Given a bitwidth and a constant value, returns an equation that represents this
        constant.
        """
        pass

    @abstractmethod
    def create_alignment_constraint(self, register: str, bitwidth: int, alignment: int) -> Constraint:
        """
        Given a register name and alignment, creates a constraint that is only fulfilled
        if the register is aligned correctly.
        Providing such specific constraints instead of a generic interface allows for
        optimizations when sampling registers and also abstracts the underlying
        implementation. Still, it should be sufficient to build the required constraints.
        """
        pass

    @abstractmethod
    def create_less_constraint(self, register: str, bitwidth: int, value: Number) -> Constraint:
        """
        Given a register name and value, creates a constraint that is only fulfilled if
        the register value is smaller than the provided value.
        """
        pass

    @abstractmethod
    def create_equal_constraint(self, register: str, bitwidth: int, value: Number) -> Constraint:
        """
        Given a register name and value, creates a constraint that is only fulfilled if
        the register has exactly the provided value.
        """
        pass

    @abstractmethod
    def create_random_sample(self, state: "ArchitecturalState", constraints: List[Constraint], register_constraints: Dict[str, List[Constraint]], partial_sample: Dict[str, Tuple[int, Number]]) -> bool:
        """
        Fills the given architectural state with random register and memory mapping
        values such that all constraints are fulfilled.

        constraints: global constraints, usually inferred during instruction reversing.
        register_constraints: maps a register name to the constraints on that register.
            Constraints on registers should be supplied here (rather than as global
            constraints) to enable further optimizations when sampling random values;
            otherwise the quality of sampled values and the sampling speed might be
            drastically decreased.
        partial_sample: register values that are already fixed, as (bitwidth, value) pairs.
        NOTE(review): the returned bool presumably signals whether sampling succeeded -
        confirm against the implementations.
        """
        pass

class Register:
    """
    Prototype description of a register (name, width, constraints, encodings).

    Besides real registers, "fake" registers (is_fake=True) can hold values that are
    not stored in a real register, e.g. the address or data of a memory mapping.
    Values are (de)serialized in little-endian byte order.
    """

    def __init__(self, name: str, bitwidth: int, constraints: List[Constraint], possible_encodings: List[FixedInteger], encoding_group: str="fake", default_value=0, is_fake=False):
        """
        Args:
            name: register name, used as key in ArchitecturalState.
            bitwidth: width of the register value in bits.
            constraints: constraints on this register's value (stored by reference).
            possible_encodings: (bitwidth, value) pairs; NOTE(review): presumably the
                ways this register can be encoded in an instruction - confirm.
            encoding_group: NOTE(review): semantics not visible here; defaults to "fake".
            default_value: value the register gets when an ArchitecturalState is created.
            is_fake: whether this is a helper register; only fake registers can be
                added to / removed from an ArchitecturalState after construction.
        """
        self.name = name
        self.bitwidth = bitwidth
        self.constraints = constraints
        self.possible_encodings = possible_encodings
        self.encoding_group = encoding_group
        self.default_value = default_value
        self.is_fake = is_fake

    def _mask(self) -> int:
        """Returns a bit mask with exactly the lowest ``bitwidth`` bits set."""
        return (1 << self.bitwidth) - 1

    def get_bytelen(self) -> int:
        """Returns the number of bytes needed to store ``bitwidth`` bits (rounded up)."""
        return (self.bitwidth + 7) // 8

    def get_bytevalue(self, value: int) -> bytes:
        """
        Serializes value to get_bytelen() little-endian bytes.

        The value is truncated to ``bitwidth`` bits first, so negative values are
        stored in two's complement.
        """
        return (value & self._mask()).to_bytes(self.get_bytelen(), "little") # TODO: allow big-endian byte order!

    def get_intvalue(self, bytes) -> int:
        """
        Deserializes little-endian bytes into an unsigned register value.

        The result is truncated to ``bitwidth`` bits, so unused padding bits of the
        last byte (bitwidth not a multiple of 8) or surplus bytes cannot leak into the
        value; this makes it the inverse of get_bytevalue.
        """
        # The parameter name shadows the builtin but is kept for keyword callers.
        return int.from_bytes(bytes, "little", signed=False) & self._mask() # TODO: allow big-endian byte order!

    def __repr__(self) -> str:
        return f"Register(name={self.name!r}, bitwidth={self.bitwidth}, is_fake={self.is_fake})"
        
# Memory protection bit flags (read / write / execute); combinations are bitwise ORs.
PROT_R, PROT_W, PROT_X = 1, 2, 4
PROT_RX = PROT_X | PROT_R
PROT_RW = PROT_W | PROT_R
PROT_RWX = PROT_X | PROT_W | PROT_R

class MemoryMapping: # TODO: maybe allow virtual and physical address? but then we would need PTEditor to run on weird devices :/
    """
    A memory mapping whose address and data are given as register names.

    address and data are not integers but names of registers. Keeping everything as
    registers unifies solving and sampling. If an address is not directly held in a
    register, a "fake" register can be added to store it (including constraints, etc.).
    """

    def __init__(self, address: str, protection, data: str):
        self.address, self.protection, self.data = address, protection, data

    def __str__(self):
        fields = f"address={self.address}, protection={self.protection}, data={self.data}"
        return "MemoryMapping(" + fields + ")"

    def __repr__(self):
        return self.__str__()

class ArchitecturalState:
    """
    Register values and memory mappings describing an architectural state.

    Register values are stored as (bitwidth, value) pairs keyed by register name.
    Only "fake" registers may be added or removed after construction.
    Serialization (to_bytes / from_bytes) has to be provided by subclasses.
    """

    # TODO: maybe at some point model access to portions of registers (e.g., eax for 32-bits of rax) instead of adding constraints to registers?

    def __init__(self, register_prototypes: List[Register]):
        """
        Creates a state where every register holds its prototype's default value
        and no memory mappings exist.
        """
        # register name -> Register prototype
        self.register_prototypes = {reg.name: reg for reg in register_prototypes}
        # register name -> (bitwidth, value)
        self.register_values = {
            reg.name: (reg.bitwidth, reg.default_value) for reg in register_prototypes
        }
        self.memory_mappings: List[MemoryMapping] = []
        # self.exception = (0, 0) # kind, additional information (dependent on type). For instance, segfault + virtual address

    def add_register(self, reg: Register):
        """
        Adds a fake register initialized to its default value.

        An existing register with the same name is overwritten.

        Raises:
            Exception: if reg is not a fake register.
        """
        if not reg.is_fake:
            raise Exception("cannot add non-fake register outside constructor!")
        self.register_prototypes[reg.name] = reg
        self.register_values[reg.name] = (reg.bitwidth, reg.default_value)

    def remove_register(self, name: str) -> bool:
        """
        Removes a fake register.

        Returns False if no register with this name exists, True after removal.

        Raises:
            AssertionError: if the register is not a fake register.
        """
        if name not in self.register_prototypes:
            return False
        if not self.register_prototypes[name].is_fake:
            # Explicit raise instead of `assert` so the check survives `python -O`;
            # AssertionError is kept for callers that already expect it.
            raise AssertionError("cannot remove non-fake register!")
        del self.register_prototypes[name]
        del self.register_values[name]
        return True

    def set_register(self, name: str, value: int):
        """
        Sets the value of an existing register, keeping its bitwidth.

        Raises:
            KeyError: if no register with this name exists.
        """
        bitwidth, _ = self.register_values[name]
        self.register_values[name] = (bitwidth, value)

    def get_register(self, name: str) -> FixedInteger:
        """Returns the (bitwidth, value) pair of a register; raises KeyError if unknown."""
        return self.register_values[name]

    def add_mapping(self, mapping: MemoryMapping):
        """Appends a memory mapping (duplicate addresses are not checked)."""
        self.memory_mappings.append(mapping)

    def remove_mapping(self, reg: str) -> bool:
        """
        Removes the first memory mapping whose address register is ``reg``.

        Returns True if a mapping was removed, False otherwise.
        """
        for index, mapping in enumerate(self.memory_mappings):
            if mapping.address == reg:
                # Safe: iteration stops right after the deletion.
                del self.memory_mappings[index]
                return True
        return False

    def get_mapping(self, reg: str) -> MemoryMapping|None:
        """Returns the first memory mapping whose address register is ``reg``, or None."""
        for mapping in self.memory_mappings:
            if mapping.address == reg:
                return mapping
        return None

    def check_constraints(self):
        """Not implemented yet: currently a no-op."""
        # TODO: ensure all constraints (on registers) are fulfilled!
        pass

    # ArchitecturalState does not derive from ABC, so @abstractmethod is not enforced
    # at instantiation; the base implementations therefore raise instead of silently
    # returning None.
    @abstractmethod
    def to_bytes(self) -> bytearray:
        """Serializes this state; must be implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} does not implement to_bytes()")

    @abstractmethod
    def from_bytes(self, data: bytes|bytearray):
        """Loads this state from serialized data; must be implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} does not implement from_bytes()")

    def __str__(self):
        registers = ", ".join(
            f"{name} = ({bitwidth}, 0x{value:x})"
            for name, (bitwidth, value) in self.register_values.items()
        )
        return f"Registers: [{registers}]; Mappings: {self.memory_mappings}"

class Runner(ABC):
    """
    Abstract interface operating on a ``before`` and an ``after`` ArchitecturalState.
    """

    @abstractmethod
    def run(self, before: ArchitecturalState, after: ArchitecturalState) -> bool:
        """
        NOTE(review): presumably executes the code under test starting from
        ``before`` and records the resulting state in ``after``; the meaning of
        the returned bool (likely success) is implementation-defined - confirm.
        """
