import instr_rev
import stats
import subprocess
import socket

# To keep the implementation simple, only the socket-based runner is supported for now.

class GenericArchitecturalState(instr_rev.ArchitecturalState):
    """Architectural state with a flat binary encoding used by the socket runner.

    ``to_bytes`` emits, in order (every header integer is 8 bytes wide and uses
    the ``endianess`` byte order):

    1. the value of every non-fake register, in ``register_prototypes`` order;
    2. the number of memory mappings;
    3. one (virtual address, size in bytes, protection) header per mapping;
    4. the data bytes of every mapping, in the same order as the headers.

    ``from_bytes`` parses the same layout, except that 16 bytes of crash info
    (signed error code, unsigned error info) sit between the registers and the
    mapping count.
    """

    # Largest mapping count accepted by from_bytes; anything above is treated
    # as a corrupt stream.
    MAX_MEMORY_MAPPINGS = 64

    def __init__(self, page_size, register_prototypes, endianess = "little"):
        """Create a state without memory mappings.

        Args:
            page_size: page size of the target, in bytes.
            register_prototypes: register definitions, forwarded to
                ``instr_rev.ArchitecturalState``.
            endianess: byte order (``"little"``/``"big"``) of the serialized
                header integers and of mapping data decoded by ``from_bytes``.
        """
        super().__init__(register_prototypes)
        self.page_size = page_size
        self.endianess = endianess

        # Fake (never serialized) registers for the error code (signal) and
        # additional error info (e.g., the address of a SIGSEGV).
        self._add_fake_register("error_code", 64)
        self._add_fake_register("error_info", 64)

    def _add_fake_register(self, name, bitwidth, value=None):
        """Add a fake (non-serialized) register and optionally set its value."""
        self.add_register(instr_rev.Register(name, bitwidth, [], [], is_fake=True))
        if value is not None:
            self.set_register(name, value)

    def _encode_u64(self, value):
        """Encode *value* as an unsigned 8-byte header field."""
        return value.to_bytes(8, self.endianess, signed=False)

    def _decode_u64(self, data, offset):
        """Decode the unsigned 8-byte header field of *data* at *offset*."""
        return int.from_bytes(data[offset:offset + 8], self.endianess, signed=False)

    @staticmethod
    def _require_bytes(data, needed, what):
        """Raise ValueError if *data* is shorter than *needed* bytes."""
        if len(data) < needed:
            raise ValueError(
                f"truncated state data: {what} need {needed} bytes, got {len(data)}"
            )

    def to_bytes(self) -> bytearray:
        """Serialize registers and memory mappings; crash info is not included.

        Raises:
            ValueError: if a mapping's address register does not encode to
                exactly 8 bytes. Writing a different length into the fixed-size
                header field would resize the buffer and shift every later field.
        """
        registers = [reg for reg in self.register_prototypes.values() if not reg.is_fake]
        data_regs = [self.register_prototypes[mapping.data] for mapping in self.memory_mappings]

        result = bytearray(
            sum(reg.get_bytelen() for reg in registers)  # registers
            + 8  # number of memory mappings as uint64
            + 24 * len(self.memory_mappings)  # headers: address, size, protection
            + sum(reg.get_bytelen() for reg in data_regs)  # memory mapping data
        )

        # registers ("fake" registers are never serialized)
        offset = 0
        for reg in registers:
            byteval = reg.get_bytevalue(self.get_register(reg.name)[1])
            result[offset:offset + len(byteval)] = byteval
            offset += reg.get_bytelen()

        # crash info is not serialized, so no 16 additional bytes here

        result[offset:offset + 8] = self._encode_u64(len(self.memory_mappings))
        offset += 8

        # memory mapping headers (address, size, protection)
        for mapping, data_reg in zip(self.memory_mappings, data_regs):
            address_reg = self.register_prototypes[mapping.address]
            address = address_reg.get_bytevalue(self.get_register(mapping.address)[1])
            if len(address) != 8:
                raise ValueError(
                    f"address register {mapping.address!r} encodes to {len(address)} bytes, expected 8"
                )
            result[offset:offset + 8] = address
            result[offset + 8:offset + 16] = self._encode_u64(data_reg.get_bytelen())
            result[offset + 16:offset + 24] = self._encode_u64(mapping.protection)
            offset += 24

        # memory mapping data
        for data_reg in data_regs:
            size = data_reg.get_bytelen()
            result[offset:offset + size] = data_reg.get_bytevalue(self.get_register(data_reg.name)[1])
            offset += size

        return result

    def from_bytes(self, data: bytes|bytearray, mapping_names=None):
        """Load registers, crash info and memory mappings from *data*.

        Args:
            data: serialized state (see the class docstring for the layout).
            mapping_names: optional sequence of ``(address_reg, data_reg)``
                name pairs, applied to the parsed mappings by position;
                mappings without an entry get ``mem_addr_<i>``/``mem_val_<i>``.

        Raises:
            Exception: if the mapping count is outside
                ``1..MAX_MEMORY_MAPPINGS``.
            ValueError: if *data* is shorter than the layout it announces.
        """
        if mapping_names is None:
            mapping_names = ()
        self.memory_mappings = []

        registers = [reg for reg in self.register_prototypes.values() if not reg.is_fake]
        # The registers are followed by 16 bytes of crash info and the 8-byte mapping count.
        self._require_bytes(
            data,
            sum(reg.get_bytelen() for reg in registers) + 24,
            "registers, crash info and mapping count",
        )

        # registers ("fake" registers are not part of the stream)
        offset = 0
        for reg in registers:
            width = reg.get_bytelen()
            self.register_values[reg.name] = (reg.bitwidth, reg.get_intvalue(data[offset:offset + width]))
            offset += width

        # crash info
        error_code = int.from_bytes(data[offset:offset + 8], self.endianess, signed=True)
        error_info = self._decode_u64(data, offset + 8)
        self._add_fake_register("error_code", 64, error_code)
        self._add_fake_register("error_info", 64, error_info)
        offset += 16

        # amount of memory mappings (zero is rejected as well)
        mem_mappings_count = self._decode_u64(data, offset)
        if mem_mappings_count > self.MAX_MEMORY_MAPPINGS or mem_mappings_count < 1:
            raise Exception("Invalid number of mappings")
        offset += 8

        # Parse and bounds-check all headers before creating any mapping register.
        data_offset = offset + 24 * mem_mappings_count
        self._require_bytes(data, data_offset, "memory mapping headers")
        headers = []
        for i in range(mem_mappings_count):
            base = offset + 24 * i
            address = self._decode_u64(data, base)
            size = self._decode_u64(data, base + 8)
            protection = self._decode_u64(data, base + 16)
            headers.append((address, size, protection, data_offset))
            data_offset += size
        self._require_bytes(data, data_offset, "memory mapping data")

        for i, (address, size, protection, start) in enumerate(headers):
            if i < len(mapping_names):
                addr_name, val_name = mapping_names[i]
            else:
                addr_name, val_name = f"mem_addr_{i}", f"mem_val_{i}"

            # TODO: constructing a register like this destroys the endianess!
            self._add_fake_register(addr_name, 64, address)
            value = int.from_bytes(data[start:start + size], self.endianess, signed=False)
            self._add_fake_register(val_name, size * 8, value)

            self.memory_mappings.append(
                instr_rev.MemoryMapping(addr_name, protection, val_name)
            )


class GenericSocketRunner(instr_rev.Runner):
    """Execute instruction sequences in a backend process over a TCP socket.

    The runner listens on ``(ip, port)``; port 0 lets the OS choose a free
    port. Unless ``is_remote`` is set, it spawns the backend itself via
    ``get_commandline()``. In remote mode it only waits for a backend to
    connect. States are sent as ``before.to_bytes()`` and the reply is decoded
    with ``after.from_bytes()``.
    """

    # Signal numbers reported in the "error_code" register (Linux numbering).
    _SIGILL = 4
    _SIGSEGV = 11

    # Seconds to wait for a killed backend process to be reaped.
    _KILL_WAIT_TIMEOUT = 5

    def __init__(self, command, pc_reg_name = "pc", prefix_instructions = b"", trapping_instruction = b"\x00\x00\x00\x00", is_remote = False, ip = "0.0.0.0", port = 0, max_retries = 1):
        """Open the listening socket, start the backend and wait for it to connect.

        Args:
            command: argv list of the backend; ``get_commandline`` appends
                ``"0"`` and the listening port.
            pc_reg_name: name of the program-counter register.
            prefix_instructions: code placed in front of the instructions under
                test; pc is moved back by its length.
            trapping_instruction: code appended after the instructions so that
                execution stops (unless the instructions jump away).
            is_remote: if True, no process is spawned; an external backend
                must connect to the printed port.
            ip: address to listen on.
            port: port to listen on (0 picks a free one).
            max_retries: how often ``run`` restarts the backend and retries
                after a failure.
        """
        self.command = command
        self.pc_reg_name = pc_reg_name
        self.prefix_instructions = prefix_instructions
        self.trapping_instruction = trapping_instruction
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.socket.bind((ip, port))
        self.socket.listen()
        self.port = self.socket.getsockname()[1]
        self.runner_proc = None
        self.connection = None
        self.is_remote = is_remote
        self.max_retries = max_retries
        self.start_runner()
        self.connect_runner()

    def kill_runner(self):
        """Kill the locally spawned backend (if any) and reap it; best effort."""
        if not self.runner_proc:
            return
        try:
            self.runner_proc.kill()
            # Reap the process so repeated restarts do not accumulate zombies.
            self.runner_proc.wait(timeout=self._KILL_WAIT_TIMEOUT)
        except (OSError, subprocess.TimeoutExpired):
            pass

    def get_commandline(self):
        """Return the backend argv; overwrite this if the backend binary works differently."""
        return self.command + ["0", str(self.port)]

    def start_runner(self):
        """Kill any previous backend and, unless remote, spawn a new one."""
        self.kill_runner()
        if not self.is_remote:
            self.runner_proc = subprocess.Popen(self.get_commandline())

    def connect_runner(self):
        """Drop any previous connection and block until a backend connects."""
        if self.connection is not None:
            # Close the stale connection so restarts do not leak file descriptors.
            try:
                self.connection.close()
            except OSError:
                pass
            self.connection = None
        if self.is_remote:
            print("Waiting for connection on port:", self.port)
        self.connection = self.socket.accept()[0]
        if self.is_remote:
            print("Accepted connection from:", self.connection)
        self.connection.settimeout(3)

    def _recv_exact(self, size):
        """Receive exactly *size* bytes from the backend.

        A socket with a timeout is non-blocking internally, so a single
        ``recv`` (even with ``MSG_WAITALL``) may return a partial reply.

        Raises:
            ConnectionError: if the backend closes the connection early.
            socket.timeout: if no data arrives within the socket timeout.
        """
        buffer = bytearray(size)
        view = memoryview(buffer)
        received = 0
        while received < size:
            chunk = self.connection.recv_into(view[received:])
            if chunk == 0:
                raise ConnectionError(
                    f"backend closed the connection after {received} of {size} bytes"
                )
            received += chunk
        return buffer

    def run(self, before: instr_rev.ArchitecturalState, after: instr_rev.ArchitecturalState, retry = 0):
        """Send *before* to the backend and load the resulting state into *after*.

        On any failure, the backend is restarted and the call is retried, up to
        ``max_retries`` times.

        Returns:
            True on success, False once all retries failed.
        """
        stats.calls_to_run += 1

        bytes_before = before.to_bytes()

        try:
            mapping_names = [(f"{x.address}_out", f"{x.data}_out") for x in before.memory_mappings]
            self.connection.sendall(bytes_before)
            # TODO: size is a dirty hack! The reply is assumed to be the request
            # plus 16 bytes of crash info, i.e. the same registers and mapping sizes.
            reply = self._recv_exact(len(bytes_before) + 16)
            after.from_bytes(reply, mapping_names = mapping_names)
        except Exception:
            if retry >= self.max_retries:
                return False
            try:
                self.start_runner()
                self.connect_runner()
            except Exception:
                # Best effort; if the restart failed, the retried run fails too
                # and eventually returns False.
                pass
            return self.run(before, after, retry = retry + 1)

        return True

    # TODO: add page_size to instr_rev.ArchitecturalState

    @staticmethod
    def _mapping_bounds(state, mapping):
        """Return the half-open byte range [start, end) covered by *mapping* in *state*."""
        start = state.get_register(mapping.address)[1]
        size = (state.get_register(mapping.data)[0] + 7) // 8
        return start, start + size

    def run_instructions(self, instructions: bytes, before: instr_rev.ArchitecturalState, after: instr_rev.ArchitecturalState):
        """Execute *instructions* starting from *before* and store the result in *after*.

        The prefix and trapping instructions are placed around *instructions*
        in a temporary executable mapping at pc, with pc moved back by the
        prefix length. *before* is restored afterwards (pc, mapping and
        register), even on early exit or exceptions.

        Returns:
            False if the code region overlaps an existing mapping or running
            failed, True otherwise.
        """
        # prepend required code snippet (if necessary, otherwise this is all a nop)
        original_pc = before.get_register(self.pc_reg_name)[1]
        # append a trapping instruction so execution stops after the instructions (unless there is a jump)
        code = self.prefix_instructions + instructions + self.trapping_instruction
        before.set_register(self.pc_reg_name, original_pc - len(self.prefix_instructions))

        try:
            # make sure the pc mapping does not overlap with any other mapping
            pc_start = before.get_register(self.pc_reg_name)[1]
            pc_end = pc_start + len(code)
            for mapping in before.memory_mappings:
                start, end = self._mapping_bounds(before, mapping)
                if start < pc_end and end > pc_start:
                    return False  # overlapping mapping

            # temporary memory mapping holding the code (at the address in pc)
            before.add_register(
                instr_rev.Register("instructions", len(code) * 8, [], [], is_fake=True)
            )
            before.set_register("instructions", int.from_bytes(code, "little", signed=False))
            before.add_mapping(
                instr_rev.MemoryMapping(
                    self.pc_reg_name,
                    instr_rev.PROT_R | instr_rev.PROT_X,
                    "instructions"
                )
            )

            try:
                success = self.run(before, after)
                # the fix-up is meaningless if running failed, so only do it on success
                if success:
                    self._fake_segv_on_instruction_page(before, after, len(code))
            finally:
                # remove the instruction mapping (pc as address register) from both states;
                # in the after state it was decoded under the "<pc>_out" name (see run)
                before.remove_mapping(self.pc_reg_name)
                after.remove_mapping(self.pc_reg_name)
                after.remove_mapping(f"{self.pc_reg_name}_out")
                # also remove the temporary register holding the executed code
                before.remove_register("instructions")
                after.remove_register("instructions_out")

            return success
        finally:
            # reset pc of the before state (it was moved back for the prefix)
            before.set_register(self.pc_reg_name, original_pc)

    def _fake_segv_on_instruction_page(self, before, after, code_len):
        """Report SIGSEGV instead of SIGILL for faults on the code page outside the code.

        Dirty hack: a SIGILL whose address lies on the page(s) holding the
        temporary code mapping, but outside the prefix + instructions +
        trapping bytes and outside every other executable mapping of *after*,
        is rewritten to SIGSEGV.
        """
        if after.get_register("error_code")[1] != self._SIGILL:
            return

        err_addr = after.get_register("error_info")[1]
        pc_start = before.get_register(self.pc_reg_name)[1]
        code_end = pc_start + code_len
        # NOTE: page_size is only defined on GenericArchitecturalState (see TODO above)
        page_size = before.page_size
        page_start = pc_start - (pc_start % page_size)
        page_end = ((code_end + page_size - 1) // page_size) * page_size  # round up

        on_code_page = page_start <= err_addr < page_end
        inside_code = pc_start <= err_addr < code_end
        if not on_code_page or inside_code:
            return

        for mapping in after.memory_mappings:
            if mapping.protection & instr_rev.PROT_X:
                start, end = self._mapping_bounds(after, mapping)
                if start <= err_addr < end:
                    return  # legitimately executable elsewhere; keep SIGILL

        after.set_register("error_code", self._SIGSEGV)
    
