# --- Macro Emulator ---
import struct

class MacroEmulator:
    """
    Emulator for Logitech G500 macro language using 64-bit register bitfields.

    Registers (64-bit each):
      REG_POS       Bits 0-15:  X (signed 16-bit)
                    Bits 16-31: Y (signed 16-bit)
      REG_CTRL      Bits 0-7:   DPI (unsigned 8-bit)
                    Bits 8-11:  Profile (unsigned 4-bit)
      REG_FLAGS     Bit 0:      G-Shift
                    Bits 1-2:   Tilt (00=none,01=left,10=right)
                    Bit 3:      Hyperscroll
                    Bit 4:      DPI-Shift
      REG_WHEEL     signed 8-bit wheel delta in bits 0-7
      REG_MODIFIERS Bits 0-15:  modifier bitmask
      REG_BUTTONS   Bits 0-15:  button bitmask

    Registers are plain Python ints; field writes go through the bitfield
    helpers below, which mask each value to its field width.

    Log buffer: simple list of events for text output.

    Control flow depends only on the macro bytes (no opcode branches on
    register state), so execution stops as soon as the program counter
    revisits an address: from there the same instructions would repeat forever.
    """

    def __init__(self, init_regs, macro_bytes):
        """Create an emulator positioned at pc 0.

        Args:
            init_regs: mapping with optional initial values for 'POS', 'CTRL',
                'FLAGS', 'WHEEL', 'MODIFIERS' and 'BUTTONS' (each defaults to
                0), plus an optional 'DPI_LEVELS' list used by the DPI up/down
                special commands.
            macro_bytes: the macro program; indexed byte-wise and decoded with
                ``struct.unpack_from``.
        """
        # 64-bit registers
        self.reg = {
            'POS': init_regs.get('POS', 0),
            'CTRL': init_regs.get('CTRL', 0),
            'FLAGS': init_regs.get('FLAGS', 0),
            'WHEEL': init_regs.get('WHEEL', 0),
            'MODIFIERS': init_regs.get('MODIFIERS', 0),
            'BUTTONS': init_regs.get('BUTTONS', 0),
        }
        self.log = []  # events log
        self.dpi_levels = init_regs.get('DPI_LEVELS', [])
        self.data = macro_bytes
        self.pc = 0

    def run(self):
        """Execute from ``self.pc`` until END (0xFF), end of data, or a loop.

        Unknown opcodes are skipped as 1-byte instructions.

        Returns:
            (reg, log): the emulator's own register dict and event list
            (the same objects, not copies).
        """
        data = self.data
        # Addresses already executed. Every handler's effect on pc depends only
        # on the macro bytes, so landing on one of these again means the
        # program would cycle forever.
        visited = set()
        while self.pc < len(data):
            opcode = data[self.pc]
            if opcode == 0xFF:  # END
                break
            visited.add(self.pc)

            # dispatch
            handler = getattr(self, f'_op_{opcode:02X}', None)
            if handler:
                handler()
            else:
                # skip unknown (1 byte)
                self.pc += 1
            if self.pc in visited:
                break  # stop infinite loop
        return self.reg, self.log

    # Helpers for bitfields
    def _get_bits(self, reg, high, low):
        """Return bits ``low``..``high`` (inclusive) of register ``reg``, unsigned."""
        mask = ((1 << (high - low + 1)) - 1) << low
        return (self.reg[reg] & mask) >> low

    def _set_bits(self, reg, high, low, value):
        """Replace bits ``low``..``high`` of ``reg`` with the low bits of ``value``."""
        mask = ((1 << (high - low + 1)) - 1) << low
        self.reg[reg] = (self.reg[reg] & ~mask) | ((value << low) & mask)

    def _inc_bits(self, reg, high, low, delta, signed=False):
        """Add ``delta`` to a bitfield, wrapping modulo the field width.

        When ``signed`` is true the current field value is read as
        two's-complement before adding (the stored result is the same bit
        pattern either way).
        """
        val = self._get_bits(reg, high, low)
        if signed:
            # interpret val as signed
            maxv = 1 << (high - low + 1)
            if val & (1 << (high - low)):
                val -= maxv
        val += delta
        # wrap/truncate
        mask = (1 << (high - low + 1)) - 1
        self._set_bits(reg, high, low, val & mask)

    # 0x02,0x03 no-op repeats
    def _op_02(self): self.pc += 1
    def _op_03(self): self.pc += 1

    # 2X scroll: 0x20 up,0x21 down (1-byte unsigned amount)
    def _op_20(self):
        d = self.data[self.pc+1]
        self._inc_bits('WHEEL', 7, 0,  d, signed=True)
        self.pc += 2
    def _op_21(self):
        d = self.data[self.pc+1]
        self._inc_bits('WHEEL', 7, 0, -d, signed=True)
        self.pc += 2

    # 4X commands
    def _op_40(self):  # delay
        # skip 4-byte TTTT
        self.pc += 5
    def _op_41(self):  # mouse press
        # The 16-bit operand is used as a bit index into BUTTONS.
        # NOTE(review): operands > 15 set bits beyond the documented 16-bit mask.
        btn = struct.unpack_from('<H', self.data, self.pc+1)[0]
        self._set_bits('BUTTONS', btn, btn, 1)
        self.log.append(f'MOUSE_PRESS {btn}')
        self.pc += 3
    def _op_42(self):  # mouse release
        btn = struct.unpack_from('<H', self.data, self.pc+1)[0]
        self._set_bits('BUTTONS', btn, btn, 0)
        self.log.append(f'MOUSE_RELEASE {btn}')
        self.pc += 3
    def _op_43(self):  # key press: OR modifier byte into MODIFIERS, log usage
        mm, kk = self.data[self.pc+1], self.data[self.pc+2]
        self._set_bits('MODIFIERS', 15, 0, self.reg['MODIFIERS'] | mm)
        self.log.append(f'KEY_PRESS mm=0x{mm:02X} usage=0x{kk:02X}')
        self.pc += 3
    def _op_44(self):  # key release: clear modifier bits, log usage
        mm, kk = self.data[self.pc+1], self.data[self.pc+2]
        newmod = self.reg['MODIFIERS'] & ~mm
        self._set_bits('MODIFIERS', 15, 0, newmod)
        self.log.append(f'KEY_RELEASE mm=0x{mm:02X} usage=0x{kk:02X}')
        self.pc += 3
    def _op_45(self): self.pc += 5
    def _op_46(self): self.pc += 5

    # 6X commands
    def _op_60(self):  # jump: target = 2 * byte at pc+2 (byte at pc+1 is ignored)
        offset = self.data[self.pc+2] * 2
        self.pc = offset
    def _op_61(self):  # mouse move: little-endian signed 16-bit dx, dy
        dx = struct.unpack_from('<h', self.data, self.pc+1)[0]
        dy = struct.unpack_from('<h', self.data, self.pc+3)[0]
        self._inc_bits('POS', 15, 0, dx, signed=True)
        self._inc_bits('POS', 31,16,dy, signed=True)
        self.pc += 5

    # 8X extended: sub-type byte selects click / keystroke / media (log only)
    def _op_80(self):
        t = self.data[self.pc+1]
        if t == 1:
            btn = struct.unpack_from('<H', self.data, self.pc+2)[0]
            # click = press+release
            self.log.append(f'CLICK {btn}')
            self.pc += 4
        elif t == 2:
            mm, kk = self.data[self.pc+2], self.data[self.pc+3]
            self.log.append(f'KEYSTROKE mm=0x{mm:02X} usage=0x{kk:02X}')
            self.pc += 4
        elif t == 3:
            # media code operand is big-endian
            code = struct.unpack_from('>H', self.data, self.pc+2)[0]
            self.log.append(f'MEDIA {code}')
            self.pc += 4
        else:
            self.pc += 2

    # 9X special
    def _op_90(self):
        c = self.data[self.pc+1]
        if c in (0x00,0x0B): self._set_bits('FLAGS', 0,0,1)   # G-Shift on
        elif c == 1: self._set_bits('FLAGS', 2,1,1)           # tilt left
        elif c == 2: self._set_bits('FLAGS', 2,1,2)           # tilt right
        elif c == 4:
            # Toggle Hyperscroll (bit 3). The field value must be the 1-bit
            # toggled flag, not the whole FLAGS register.
            self._set_bits('FLAGS', 3, 3, self._get_bits('FLAGS', 3, 3) ^ 1)
        elif c == 5:  # DPI up: next entry of dpi_levels, or +100 without a table
            # NOTE(review): the DPI field is 8 bits wide, so table values > 255
            # are truncated when stored and will not be found on the next lookup.
            dpi = self._get_bits('CTRL',7,0)
            if self.dpi_levels:
                idx = self.dpi_levels.index(dpi) if dpi in self.dpi_levels else -1
                dpi = self.dpi_levels[(idx+1)%len(self.dpi_levels)]
            else: dpi += 100
            self._set_bits('CTRL',7,0,dpi)
        elif c == 6:  # DPI down: previous entry of dpi_levels, or -100 (min 0)
            dpi = self._get_bits('CTRL',7,0)
            if self.dpi_levels:
                idx = self.dpi_levels.index(dpi) if dpi in self.dpi_levels else 0
                dpi = self.dpi_levels[idx-1]
            else: dpi = max(0,dpi-100)
            self._set_bits('CTRL',7,0,dpi)
        elif c == 7: self._set_bits('FLAGS',4,4,1)            # DPI-Shift on
        self.pc += 2

# Example
if __name__ == '__main__':
    # Demo: move the cursor, scroll, press a mouse button and a key, then
    # issue the "DPI up" special command.
    example_regs = dict(
        POS=0,
        CTRL=(800 & 0xFF) | ((1 & 0xF) << 8),  # DPI field (low 8 bits) + profile 1
        FLAGS=0,
        WHEEL=0,
        MODIFIERS=0,
        BUTTONS=0,
        DPI_LEVELS=[400, 800, 1200, 1600],
    )
    example_macro = bytes((
        0x61, 0x01, 0x00, 0xFF, 0xFF,  # mouse move dx=+1, dy=-1
        0x20, 0x05,                    # scroll up by 5
        0x41, 0x01, 0x00,              # mouse press, btn=1
        0x43, 0x02, 0x04,              # key press, modifiers 0x02, usage 0x04
        0x90, 0x05,                    # special: DPI up
        0xFF,                          # END
    ))
    final_regs, events = MacroEmulator(example_regs, example_macro).run()
    print(final_regs)
    print(events)








## --- Wrapper below ---
import instr_rev
import stats
class LogitechMacrosArchitecturalState(instr_rev.ArchitecturalState):
    """instr_rev architectural state for the G500 macro emulator.

    Each bit field of the MacroEmulator registers (X/Y, DPI/PROFILE and the
    individual flags) is exposed as its own instr_rev register. The macro code
    itself is supplied through the single memory mapping (see ``to_emu``).
    """

    def __init__(self, solver, endianess="little", feature_set=None):
        """Declare the pc and the per-field registers.

        ``solver``, ``endianess`` and ``feature_set`` are accepted but not
        used by this class; endianness is always "little". ``feature_set``
        defaults to None instead of a shared mutable list.
        """
        super().__init__([
            instr_rev.Register("pc", 64, [], [], encoding_group="pc"),
            instr_rev.Register("X", 16, [], [], encoding_group="X"),
            instr_rev.Register("Y", 16, [], [], encoding_group="Y"),
            instr_rev.Register("DPI", 8, [], [], encoding_group="DPI"),
            instr_rev.Register("PROFILE", 4, [], [], encoding_group="PROFILE"),
            instr_rev.Register("G-Shift", 1, [], [], encoding_group="G-Shift"),
            instr_rev.Register("Tilt", 2, [], [], encoding_group="Tilt"),
            instr_rev.Register("Hyperscroll", 1, [], [], encoding_group="Hyperscroll"),
            instr_rev.Register("DPI-Shift", 1, [], [], encoding_group="DPI-Shift"),
            instr_rev.Register("Wheel", 8, [], [], encoding_group="Wheel"),
            instr_rev.Register("Buttons", 16, [], [], encoding_group="Buttons"),
            instr_rev.Register("Modifiers", 16, [], [], encoding_group="Modifiers")])
        self.page_size = 512 # does not matter
        self.endianess = "little" # does not matter
        self.add_register(instr_rev.Register("error_code", 64, [], [], is_fake=True))
        self.add_register(instr_rev.Register("error_info", 64, [], [], is_fake=True))

    def to_bytes(self) -> bytearray:
        """Not supported: this state has no byte serialization."""
        raise Exception(f"cannot serialize {self.__class__.__name__}")

    def from_bytes(self, bytes):
        """Not supported: this state has no byte deserialization."""
        raise Exception(f"cannot de-serialize {self.__class__.__name__}")

    def to_emu(self):
        """Build a MacroEmulator from this state.

        The macro code is the (little-endian) contents of the register named
        by the single memory mapping. The per-field registers are packed back
        into the emulator's POS/CTRL/FLAGS layout.

        Raises:
            Exception: if there is not exactly one memory mapping.
        """
        if len(self.memory_mappings) != 1:
            raise Exception(f"only works with exactly one mapping for code! (was: {self.memory_mappings})")
        # get_register returns (bitwidth, value)
        memory_bitwidth, memory_data = self.get_register(self.memory_mappings[0].data)
        macro_bytes = memory_data.to_bytes((memory_bitwidth + 7) // 8, "little")
        return MacroEmulator({
                "POS": self.get_register("X")[1] | (self.get_register("Y")[1] << 16),
                "CTRL": self.get_register("DPI")[1] | (self.get_register("PROFILE")[1] << 8),
                "FLAGS": self.get_register("G-Shift")[1] | (self.get_register("Tilt")[1] << 1) | (self.get_register("Hyperscroll")[1] << 3) | (self.get_register("DPI-Shift")[1] << 4),
                "WHEEL": self.get_register("Wheel")[1],
                "MODIFIERS": self.get_register("Modifiers")[1],
                "BUTTONS": self.get_register("Buttons")[1],
                "DPI_LEVELS": [400,800,1200,1600]
            }, macro_bytes)

    def from_emu(self, emu):
        """Copy the emulator's registers and pc back into this state.

        Every field is masked to its declared register width; the emulator
        can set BUTTONS bits above 15 because a press operand is used as a
        bit index. ``error_code`` is always set to 4 and ``error_info`` to the
        final pc.
        NOTE(review): the meaning of error code 4 is defined by instr_rev — confirm.
        """
        self.set_register("X", emu.reg["POS"] & 0xffff)
        self.set_register("Y", (emu.reg["POS"] >> 16) & 0xffff)
        self.set_register("DPI", emu.reg["CTRL"] & 0xff)
        self.set_register("PROFILE", (emu.reg["CTRL"] >> 8) & 0xf)
        self.set_register("G-Shift", emu.reg["FLAGS"] & 1)
        self.set_register("Tilt", (emu.reg["FLAGS"] >> 1) & 3)
        self.set_register("Hyperscroll", (emu.reg["FLAGS"] >> 3) & 1)
        self.set_register("DPI-Shift", (emu.reg["FLAGS"] >> 4) & 1)
        self.set_register("Wheel", emu.reg["WHEEL"] & 0xff)
        self.set_register("Buttons", emu.reg["BUTTONS"] & 0xffff)
        self.set_register("Modifiers", emu.reg["MODIFIERS"] & 0xffff)
        self.set_register("pc", emu.pc)
        self.set_register("error_code", 4)
        self.set_register("error_info", emu.pc)
class LogitechMacrosRunner(instr_rev.Runner):
    """instr_rev runner that executes macros in-process with MacroEmulator."""

    def __init__(self, command, is_remote = False, ip = "0.0.0.0", port = 0, max_retries = 1):
        # Execution is local (see run()), so these parameters are accepted but
        # ignored. NOTE(review): the base-class __init__ is not called — confirm
        # this is intended.
        pass

    def run(self, before: instr_rev.ArchitecturalState, after: instr_rev.ArchitecturalState, retry = 0):
        """Emulate the program described by ``before`` and store the result in ``after``.

        ``retry`` is accepted but unused. Always returns True; exceptions
        raised by the emulator propagate to the caller.
        """
        emu = before.to_emu()
        emu.run()
        after.from_emu(emu)
        return True

    def run_instructions(self, instructions: bytes, before: instr_rev.ArchitecturalState, after: instr_rev.ArchitecturalState):
        """Run a raw instruction sequence against ``before``, recording into ``after``.

        An END opcode (0xFF) is appended, and the bytes are exposed to the
        emulator through a temporary fake "instructions" register mapped at
        "pc". The register and mapping are removed from ``before`` afterwards,
        even if execution raises, so ``before`` can be reused.
        """
        # Build a new bytes object: ``+=`` would extend a caller's bytearray in place.
        program = bytes(instructions) + b"\xff"
        before.add_register(instr_rev.Register(
                "instructions", len(program) * 8, [], [], is_fake=True))
        try:
            before.set_register("instructions", int.from_bytes(program, "little", signed=False))
            before.add_mapping(instr_rev.MemoryMapping("pc",
                    instr_rev.PROT_R | instr_rev.PROT_X, "instructions"))
            try:
                return self.run(before, after)
            finally:
                before.remove_mapping("pc")
        finally:
            before.remove_register("instructions")
