import instr_rev
import backend.generic_backend
import clusterer

# ABI names of the 32 integer registers. The list index is the register number
# (x0..x31); Riscv64ArchitecturalState uses that index as the register's 5-bit
# encoding value.
REG_NAMES = [
    # x0 - x7
    "zero", "ra", "sp", "gp", "tp", "t0", "t1", "t2",
    # x8 - x15 (x8 is listed under its frame-pointer alias "fp")
    "fp", "s1", "a0", "a1", "a2", "a3", "a4", "a5",
    # x16 - x23
    "a6", "a7", "s2", "s3", "s4", "s5", "s6", "s7",
    # x24 - x31
    "s8", "s9", "s10", "s11", "t3", "t4", "t5", "t6",
]

class Riscv64ArchitecturalState(backend.generic_backend.GenericArchitecturalState):
    """Architectural state description for 64-bit RISC-V.

    Declares a constrained ``pc`` register plus the 32 general purpose
    registers named in ``REG_NAMES``. The ``zero`` register is additionally
    pinned to the value 0 and moved into its own encoding group.
    """

    def __init__(self, solver, feature_set=None):
        """Create the register prototypes and hand them to the base class.

        Args:
            solver: Object providing the ``create_*_constraint`` factory
                methods used to build the register constraints.
            feature_set: Currently unused by this class; accepted for
                interface compatibility. Defaults to ``None`` rather than a
                shared mutable ``{}``.
        """
        # pc must be 4-byte aligned and lie in [0x4008, 1 << 46).
        # The constraints are created in the same order as before in case the
        # solver is stateful.
        pc_constraints = [
            solver.create_alignment_constraint("pc", 64, 4),
            solver.create_less_constraint("pc", 64, 1 << 46),
            solver.create_greatereq_constraint("pc", 64, 0x4008),
        ]
        pc_register = instr_rev.Register("pc", 64, pc_constraints, [], encoding_group="pc")

        # Every general purpose register is 64 bits wide, starts without
        # constraints and is encoded in 5 bits as its index in REG_NAMES.
        general_purpose = [
            instr_rev.Register(name, 64, [], [(5, number)], encoding_group="gpr")
            for number, name in enumerate(REG_NAMES)
        ]

        # 0x1000 is forwarded unchanged as the base class's first argument;
        # its meaning is defined by GenericArchitecturalState.
        super().__init__(0x1000, [pc_register] + general_purpose)

        # Register zero must always read as 0 ...
        zero = self.register_prototypes["zero"]
        zero.constraints.append(solver.create_equal_constraint("zero", 64, 0))
        # ... and gets its own encoding group so the clusterer does not treat
        # it like an ordinary general purpose register.
        zero.encoding_group = "zero"

    def to_bytes(self) -> bytearray:
        """Serialize the state via the base class, hiding the zero register.

        Workaround kept from the original implementation: the zero register is
        not part of the serialized layout on the C side, so it is flagged as
        fake for the duration of the base-class call. The flag is always reset
        to ``False`` afterwards, even if serialization raises.
        """
        zero = self.register_prototypes["zero"]
        try:
            zero.is_fake = True
            return super().to_bytes()
        finally:
            zero.is_fake = False

    def from_bytes(self, data: bytes|bytearray, mapping_names=None):
        """Load the state from ``data`` via the base class, then force zero to 0.

        Args:
            data: Serialized state as accepted by the base class.
            mapping_names: Forwarded to the base class. ``None`` (the default)
                is replaced by a fresh empty list on every call, so no list
                object is shared between calls.
        """
        if mapping_names is None:
            mapping_names = []
        zero = self.register_prototypes["zero"]
        # Same workaround as in to_bytes: the zero register is not part of the
        # serialized layout, so it is hidden while the base class parses.
        try:
            zero.is_fake = True
            super().from_bytes(data, mapping_names=mapping_names)
        finally:
            zero.is_fake = False
        # The zero register was skipped during parsing; give it its fixed value.
        self.set_register("zero", 0)
        

class Riscv64SocketRunner(backend.generic_backend.GenericSocketRunner):
    """Socket runner preconfigured with the RISC-V 64 specific settings."""

    def __init__(self, command, is_remote=False, ip="0.0.0.0", port=0, max_retries=1):
        """Forward the connection options together with the fixed RISC-V values.

        Args:
            command: Passed through as the base class's first argument.
            is_remote: Passed through to the base class.
            ip: Passed through to the base class.
            port: Passed through to the base class.
            max_retries: Passed through to the base class.
        """
        # Architecture-specific values that callers cannot override.
        riscv_settings = {
            "pc_reg_name": "pc",
            # Four raw bytes handed to the base class as prefix instructions.
            "prefix_instructions": b"\xa8\x67\xbc\x7b",
            # All-zero word; presumably relied on to trap as an illegal
            # instruction -- confirm against the runner implementation.
            "trapping_instruction": b"\x00\x00\x00\x00",
        }
        super().__init__(
            command,
            **riscv_settings,
            is_remote=is_remote,
            ip=ip,
            port=port,
            max_retries=max_retries,
        )

class Riscv64ConstantIterator(clusterer.SimpleInstructionConstantIterator):
    """Constant iterator that also understands RISC-V's scattered immediates.

    On top of whatever the base class finds, this class extracts the
    immediates whose bits are not contiguous in the 32-bit instruction word
    (the S-, J- and B-type layouts) and reports them as candidate constants,
    for every configured shift and offset, in both unsigned and
    sign-extended interpretation.
    """

    def __init__(self, instr, bitwidth, shifts=None, offsets=None):
        """Initialize the iterator.

        Args:
            instr: Instruction word, forwarded to the base class.
            bitwidth: Forwarded to the base class.
            shifts: Forwarded to the base class unchanged.
            offsets: Offsets added to each extracted immediate. ``None`` (the
                default) selects a fresh ``[0, -1, 1]`` list, so the default
                is never shared between instances.
        """
        if offsets is None:
            offsets = [0, -1, 1]
        super().__init__(instr, bitwidth, shifts=shifts, offsets=offsets)

    def _riscv_immediate_fields(self):
        """Return ``(value, bitlen, bits)`` for each scattered immediate.

        ``value`` is the immediate assembled from ``self.instruction``,
        ``bitlen`` its width, and ``bits`` the instruction bit positions that
        make it up, listed from the immediate's least significant bit upward.
        For the J- and B-type layouts the implicit always-zero low bit of the
        branch/jump offset is not included.
        """
        x = self.instruction
        # S-type: imm[4:0] in bits 11:7, imm[11:5] in bits 31:25.
        s_imm = ((x >> 7) & 0b11111) | (((x >> 25) & 0b1111111) << 5)
        # J-type: imm[10:1] in bits 30:21, imm[11] in bit 20,
        # imm[19:12] in bits 19:12, imm[20] in bit 31.
        j_imm = (
            (((x >> 31) & 1) << 19)
            | ((x >> 21) & 0b1111111111)
            | (((x >> 20) & 1) << 10)
            | (((x >> 12) & 0b11111111) << 11)
        )
        # B-type: imm[4:1] in bits 11:8, imm[10:5] in bits 30:25,
        # imm[11] in bit 7, imm[12] in bit 31.
        b_imm = (
            (((x >> 31) & 1) << 11)
            | (((x >> 25) & 0b111111) << 4)
            | ((x >> 8) & 0b1111)
            | (((x >> 7) & 1) << 10)
        )
        return [
            (s_imm, 12, [7, 8, 9, 10, 11, 25, 26, 27, 28, 29, 30, 31]),
            (j_imm, 20, [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 20, 12, 13, 14, 15, 16, 17, 18, 19, 31]),
            (b_imm, 12, [8, 9, 10, 11, 25, 26, 27, 28, 29, 30, 7, 31]),
        ]

    def _build_iter_cache(self):
        """Extend the base-class cache with the scattered-immediate constants."""
        super()._build_iter_cache()
        # The field table does not depend on the shift, so build it once.
        fields = self._riscv_immediate_fields()
        for shift in self.shifts:
            for value, bitlen, _bits in fields:
                for offset in self.offsets:
                    candidate = value + offset
                    self.iter_cache.add(candidate << shift)
                    # Sign bit set: also offer the sign-extended reading.
                    if candidate & (1 << (bitlen - 1)):
                        self.iter_cache.add((candidate - (1 << bitlen)) << shift)

    def get_constant_encodings(self, constant):
        """Yield every encoding under which ``constant`` appears.

        Base-class encodings come first. After that, each scattered immediate
        is tried for every shift and offset, and a ``ConstantCluster`` is
        yielded for each interpretation whose value equals ``constant``.
        """
        yield from super().get_constant_encodings(constant)
        fields = self._riscv_immediate_fields()
        for shift in self.shifts:
            for value, bitlen, bits in fields:
                for offset in self.offsets:
                    candidate = value + offset
                    unsigned_const = candidate << shift
                    if unsigned_const == constant:
                        # Each encoding gets its own copy of the bit list so
                        # that encodings never alias one another.
                        yield clusterer.ConstantCluster(
                            unsigned_const,
                            clusterer.SimpleEncoding(
                                list(bits),
                                False,
                                implicit_offset=value - candidate,
                                implicit_shift=shift,
                            ),
                        )
                    # With the sign bit clear, the signed and unsigned readings
                    # coincide, so a match is reported under both encodings
                    # (behavior kept from the original implementation).
                    signed_const = unsigned_const
                    if candidate & (1 << (bitlen - 1)):
                        signed_const = (candidate - (1 << bitlen)) << shift
                    if signed_const == constant:
                        yield clusterer.ConstantCluster(
                            signed_const,
                            clusterer.SimpleEncoding(
                                list(bits),
                                True,
                                implicit_offset=value - candidate,
                                implicit_shift=shift,
                            ),
                        )
