import instr_rev
import backend.generic_backend
import clusterer

# LoongArch64 general-purpose register names, indexed by register number:
# entry i is the name used for register r<i> (32 entries, r0 - r31).
REGISTER_NAMES = [
    "$zero", "$ra", "$tp", "$sp",      # r0 - r3
    *[f"$a{n}" for n in range(8)],     # r4 - r11
    *[f"$t{n}" for n in range(9)],     # r12 - r20
    "$u0", "$fp",                      # r21 - r22
    *[f"$s{n}" for n in range(9)],     # r23 - r31
]

class Loongarch64ArchitecturalState(backend.generic_backend.GenericArchitecturalState):
    """Architectural state description for LoongArch64.

    Declares a 16 KB page size and the following registers:

    * 32 general-purpose registers named by ``REGISTER_NAMES``, 64 bits each.
    * 32 vector registers ``x0`` - ``x31``, 256 bits each.
    * A constrained 64-bit ``pc``.

    ``$zero`` is additionally pinned to 0 and moved to its own encoding group.
    """

    def __init__(self, solver, **kwargs):
        """Build the register prototypes, using *solver* to create constraints.

        Extra keyword arguments are accepted and ignored.
        """
        page_size = 0x4000  # 16 KB

        # General-purpose registers r0 - r31: 64 bits, no constraints here
        # (r0 gets one below), encoded in a 5-bit field as their number.
        gprs = [
            instr_rev.Register(REGISTER_NAMES[number], 64, [], [(5, number)], encoding_group="gpr")
            for number in range(32)
        ]

        # Vector registers x0 - x31: 256 bits, no constraints, 5-bit encoding.
        vector_regs = [
            instr_rev.Register(f"x{number}", 256, [], [(5, number)], encoding_group="vector")
            for number in range(32)
        ]

        # pc constraints:
        # - aligned to 4 bytes;
        # - below 2 ** 46 (to be canonical);
        # - not on the first virtual page.
        # NOTE(review): the lower bound is page_size + 8 (0x4008), not
        # page_size. This is presumably to keep the 8 bytes in front of pc
        # (e.g. an 8-byte instruction prefix) off the first page. Confirm.
        pc_constraints = [
            solver.create_alignment_constraint("pc", 64, 4),
            solver.create_less_constraint("pc", 64, 1 << 46),
            solver.create_greatereq_constraint("pc", 64, page_size + 8),
        ]
        pc = instr_rev.Register("pc", 64, pc_constraints, [], encoding_group="pc")

        super().__init__(page_size, gprs + vector_regs + [pc])

        zero = self.register_prototypes["$zero"]
        # Register r0 must always read as 0.
        zero.constraints.append(solver.create_equal_constraint("$zero", 64, 0))
        # Give $zero its own encoding group. Otherwise the clusterer would try
        # to cluster it with the ordinary gprs, and fail.
        zero.encoding_group = "zero"

class Loongarch64SocketRunner(backend.generic_backend.GenericSocketRunner):
    """``GenericSocketRunner`` preconfigured for LoongArch64.

    Fixes three settings:

    * the pc register name;
    * the prefix instruction bytes;
    * the trapping instruction bytes.

    All connection-related options are forwarded unchanged.
    """

    def __init__(self, command, is_remote=False, ip="0.0.0.0", port=0, max_retries=1):
        # Two little-endian instruction words, 0x28c0a085 and 0x28c08084.
        # NOTE(review): these appear to decode as `ld.d $a1, $a0, 40`
        # followed by `ld.d $a0, $a0, 32`. Confirm against the harness
        # that consumes them.
        prefix = b"\x85\xa0\xc0\x28\x84\x80\xc0\x28"
        # One all-zero 32-bit instruction word.
        trap = b"\x00\x00\x00\x00"

        connection = dict(
            is_remote=is_remote,
            ip=ip,
            port=port,
            max_retries=max_retries,
        )
        super().__init__(
            command,
            pc_reg_name="pc",
            prefix_instructions=prefix,
            trapping_instruction=trap,
            **connection,
        )
        
class Loongarch64ConstantIterator(clusterer.SimpleInstructionConstantIterator):
    """Constant iterator that also reassembles LoongArch split immediates.

    It adds to what ``clusterer.SimpleInstructionConstantIterator`` provides.
    Two immediate fields are decoded from the instruction word. Both keep
    their low 16 bits in instruction bits 25..10 and their high bits in the
    lowest bits of the word:

    * a 21-bit field, whose bits 20..16 come from instruction bits 4..0;
    * a 26-bit field, whose bits 25..16 come from instruction bits 9..0.

    These are presumably the LoongArch offs21 (BEQZ/BNEZ) and offs26 (B/BL)
    layouts.

    Each decoded field is perturbed by every value in ``offsets`` and scaled
    by every shift in ``shifts``. A candidate whose top field bit is set is
    also considered as its two's-complement (negative) value.
    """

    def __init__(self, instr, bitwidth, shifts=None, offsets=None):
        if offsets is None:
            # Build a fresh list per instance. A list literal as the default
            # value would be one object shared by every instance constructed
            # without ``offsets``.
            offsets = [0, -1, 1]
        super().__init__(instr, bitwidth, shifts=shifts, offsets=offsets)

    @staticmethod
    def _split_immediates(x):
        """Decode the split immediate fields of instruction word *x*.

        Returns a list of ``(value, width, bits)`` tuples. ``value`` is the
        reassembled unsigned field and ``width`` is its size in bits.
        ``bits`` lists the instruction bit that holds each field bit,
        starting from field bit 0. It is a freshly built list on every call.
        """
        # Field bits 15..0 live in instruction bits 25..10 for both fields.
        low16 = (x >> 10) & 0xffff
        low_positions = list(range(10, 26))
        return [
            (low16 | ((x & 0b11111) << 16), 21, low_positions + list(range(0, 5))),
            (low16 | ((x & 0x3ff) << 16), 26, low_positions + list(range(0, 10))),
        ]

    def _build_iter_cache(self):
        """Extend the base iteration cache with split-immediate candidates."""
        super()._build_iter_cache()
        for shift in self.shifts:
            for field, width, _bits in self._split_immediates(self.instruction):
                sign_bit = 1 << (width - 1)
                for o in self.offsets:
                    d = field + o
                    self.iter_cache.add(d << shift)
                    if d & sign_bit:
                        # Also the two's-complement interpretation of the field.
                        self.iter_cache.add((d - (1 << width)) << shift)

    def get_constant_encodings(self, constant):
        """Yield ``clusterer.ConstantCluster`` candidates that explain *constant*.

        First yields everything the base class yields. Then, for each split
        immediate field, shift and offset, yields an unsigned encoding
        (``False``) and/or a signed encoding (``True``) whose value equals
        *constant*.
        """
        yield from super().get_constant_encodings(constant)
        for shift in self.shifts:
            for field, width, bits in self._split_immediates(self.instruction):
                sign_bit = 1 << (width - 1)
                for o in self.offsets:
                    d = field + o
                    value = d << shift
                    if value == constant:
                        yield clusterer.ConstantCluster(
                            value,
                            clusterer.SimpleEncoding(
                                bits, False, implicit_offset=field - d, implicit_shift=shift
                            ),
                        )
                    if d & sign_bit:
                        value = (d - (1 << width)) << shift
                    # This check is intentionally not nested under the
                    # sign-bit test. When the sign bit is clear, ``value`` is
                    # unchanged, so a match found above is yielded again with a
                    # signed encoding.
                    # NOTE(review): confirm that duplicate is intended.
                    if value == constant:
                        yield clusterer.ConstantCluster(
                            value,
                            clusterer.SimpleEncoding(
                                bits, True, implicit_offset=field - d, implicit_shift=shift
                            ),
                        )
