import instr_rev
import backend.generic_backend

class Aarch64ArchitecturalState(backend.generic_backend.GenericArchitecturalState):
    """Register file description of an AArch64 core.

    Always provides ``pc``, ``x0``-``x31`` and ``pstate``. Optional register
    files are appended depending on ``feature_set``:

    * ``"VECTOR"``: ``fpsr`` (64 bits) and ``v0``-``v31`` (128 bits each).
    * ``"AMX"``: ``amxx0``-``amxx7``, ``amxy0``-``amxy7`` and
      ``amxz0``-``amxz63`` (512 bits each).
    """

    def __init__(self, solver, feature_set=None, page_size=0x1000):
        """Build the register list and hand it to the generic state.

        Args:
            solver: object providing the ``create_alignment_constraint``,
                ``create_less_constraint`` and ``create_greatereq_constraint``
                factories used to restrict ``pc``.
            feature_set: optional collection of feature names (``"VECTOR"``,
                ``"AMX"``) selecting extra register files. ``None`` (the
                default) means no optional features. Only membership is tested.
            page_size: page size passed to the generic architectural state;
                defaults to 4 KiB (0x1000).
        """
        # A None sentinel avoids a single mutable ``set()`` default object
        # being shared across every call.
        if feature_set is None:
            feature_set = frozenset()

        regs = [
            # pc: 64 bits, no operand encoding. The solver constraints
            # presumably restrict it to 4-byte aligned values in
            # [0x100000, 1 << 39) — exact semantics are defined by the solver.
            instr_rev.Register(
                "pc",
                64,
                [
                    solver.create_alignment_constraint("pc", 64, 4),
                    solver.create_less_constraint("pc", 64, 1 << 39),
                    solver.create_greatereq_constraint("pc", 64, 0x100000),
                ],
                [],
                encoding_group="pc",
            )
        ]

        # x0 - x31: 64 bits each, no constraints, encoded using 5 bits as the
        # numbers 0-31. x31 is put in its own "special" encoding group
        # (presumably because encoding 31 means SP/XZR in A64 — confirm).
        regs += [
            instr_rev.Register(
                f"x{i}", 64, [], [(5, i)],
                encoding_group=("gpr" if i != 31 else "special"),
            )
            for i in range(32)
        ]

        # pstate
        # TODO: probably remove?
        regs.append(instr_rev.Register("pstate", 64, [], [], encoding_group="pstate"))

        if "VECTOR" in feature_set:
            regs.append(instr_rev.Register("fpsr", 64, [], [], encoding_group="fpsr"))
            # v0 - v31: 128-bit (16 byte) vector registers, 5-bit encoded 0-31.
            regs += [
                instr_rev.Register(f"v{i}", 16 * 8, [], [(5, i)], encoding_group="vec")
                for i in range(32)
            ]

        if "AMX" in feature_set:
            # AMX: 8 X, 8 Y and 64 Z registers of 512 bits (64 bytes) each,
            # all with an empty encoding list.
            for prefix, count in (("amxx", 8), ("amxy", 8), ("amxz", 64)):
                regs += [
                    instr_rev.Register(f"{prefix}{i}", 64 * 8, [], [], encoding_group="amx")
                    for i in range(count)
                ]

        # TODO: some things hardcode 16KB pages, so a page_size other than the
        # 4KB default might cause issues.
        super().__init__(page_size, regs)
        

class Aarch64SocketRunner(backend.generic_backend.GenericSocketRunner):
    """Socket-based runner configured for AArch64 code.

    Supplies the generic runner with the AArch64 program counter name, a fixed
    instruction prefix and a trapping instruction, and after every successful
    run forces the ``pstate`` register of the resulting state to zero.
    """

    def __init__(self, command, is_remote=False, ip="0.0.0.0", port=0, max_retries=1):
        """Create the runner.

        Args:
            command: forwarded unchanged to the generic runner.
            is_remote, ip, port, max_retries: forwarded unchanged to the
                generic runner as keyword arguments.
        """
        # Two little-endian A64 words: 0xf9400020 (ldr x0, [x1]) followed by
        # 0xf9400421 (ldr x1, [x1, #8]). How the prefix is used is up to the
        # generic runner — presumably to load initial register values.
        prefix = b"\x20\x00\x40\xf9\x21\x04\x40\xf9"
        # 0x00000000 is the A64 encoding of UDF #0 (permanently undefined).
        trap = b"\x00\x00\x00\x00"

        super().__init__(
            command,
            pc_reg_name="pc",
            prefix_instructions=prefix,
            trapping_instruction=trap,
            is_remote=is_remote,
            ip=ip,
            port=port,
            max_retries=max_retries,
        )

    def run(self, before: instr_rev.ArchitecturalState, after: instr_rev.ArchitecturalState, retry = 0):
        """Run via the generic runner, then normalise ``pstate``.

        Returns True when the generic runner reports success; in that case the
        ``pstate`` register of ``after`` is overwritten with 0. Otherwise
        returns False and leaves ``after`` as the generic runner left it.
        """
        succeeded = super().run(before, after, retry = retry)
        if not succeeded:
            return False
        # TODO: dirty fix because pstate is annoying: discard whatever value
        # was produced and just set it to 0.
        after.set_register("pstate", 0)
        return True
    
    
