import time

import random
from instr_rev import PROT_RWX, MemoryMapping
import instr_rev_z3
from reverser import InstructionReverser
import clusterer
import constants
import sys

# Fixed seed so that sampling during reversing is reproducible between runs.
random.seed(42)

# Defaults for the optional trailing command-line arguments.
DEPTH = 2
SAMPLES = 500

if len(sys.argv) < 5:
    print(f"usage: python3 {sys.argv[0]} <target> <instruction> <port> <instruction-bitwidth> [DEPTH] [SAMPLES]")
    sys.exit(-1)


# Required positional arguments. The numeric ones are parsed with base 0,
# so prefixed literals such as 0x.., 0o.. and 0b.. are accepted.
target = sys.argv[1]
instruction, port, bitwidth = (int(arg, 0) for arg in sys.argv[2:5])

# Optional overrides: [DEPTH] [SAMPLES]; anything past these is ignored.
optional_args = sys.argv[5:7]
if optional_args:
    DEPTH = int(optional_args[0], 0)
if len(optional_args) > 1:
    SAMPLES = int(optional_args[1], 0)

# Architectural feature flags; starts empty and may be extended per target.
feature_set = set()

# Z3-backed solver, bounded by DEPTH, with debug output and a progress
# display switched on.
solver = instr_rev_z3.Z3Solver(max_depth=DEPTH, debug=True, display_progress=True)

# Publish the instruction width on the shared constants module.
constants.INSTR_BITWIDTH = bitwidth

# Default constant iterator; individual targets may replace it.
const_iter_class = clusterer.SimpleInstructionConstantIterator

match target:

    case "logitech":
        constants.CRASH_VALUE = 0xff
        from backend.logitech_macros import LogitechMacrosArchitecturalState, LogitechMacrosRunner
        runner = LogitechMacrosRunner([])
        state_class = LogitechMacrosArchitecturalState

    case "x86_64-socket":
        constants.CRASH_VALUE = 0xffffffff
        from backend.x86_64 import X86_64ArchitecturalState, X86_64SocketRunner
        runner = X86_64SocketRunner(["backend/runner_socket_x86_64"], port=port)
        state_class = X86_64ArchitecturalState
    
    case "riscv64-socket":
        from backend.riscv64 import Riscv64ArchitecturalState, Riscv64SocketRunner, Riscv64ConstantIterator
        const_iter_class = Riscv64ConstantIterator
        runner = Riscv64SocketRunner(["backend/qemu-user-riscv64-static", "backend/runner_socket_riscv64"], port=port)
        state_class = Riscv64ArchitecturalState

    case "riscv64-socket-remote":
        from backend.riscv64 import Riscv64ArchitecturalState, Riscv64SocketRunner, Riscv64ConstantIterator
        const_iter_class = Riscv64ConstantIterator
        runner = Riscv64SocketRunner(["backend/qemu-user-riscv64-static", "backend/runner_socket_riscv64"], is_remote=True, port=port)
        state_class = Riscv64ArchitecturalState

    case "aarch64-socket":
        from backend.aarch64 import Aarch64ArchitecturalState, Aarch64SocketRunner
        runner = Aarch64SocketRunner(["backend/qemu-user-aarch64-static", "backend/runner_socket_aarch64"], port=port)
        state_class = Aarch64ArchitecturalState

    case "aarch64-vector-socket":
        from backend.aarch64 import Aarch64ArchitecturalState, Aarch64SocketRunner
        feature_set.add("VECTOR")
        runner = Aarch64SocketRunner(["backend/qemu-user-aarch64-static", "backend/runner_socket_aarch64-vector"], port=port)
        state_class = Aarch64ArchitecturalState

    case "aarch64-vector-socket-remote":
        from backend.aarch64 import Aarch64ArchitecturalState, Aarch64SocketRunner
        feature_set.add("VECTOR")
        runner = Aarch64SocketRunner(["backend/qemu-user-aarch64-static", "backend/runner_socket_aarch64-vector"], is_remote=True, port=port)
        state_class = Aarch64ArchitecturalState

    case "loongarch64-socket":
        from backend.loongarch64 import Loongarch64ArchitecturalState, Loongarch64SocketRunner, Loongarch64ConstantIterator
        runner = Loongarch64SocketRunner(["backend/qemu-user-loongarch64-static", "backend/runner_socket_loongarch64"], port=port)
        state_class = Loongarch64ArchitecturalState
        const_iter_class = Loongarch64ConstantIterator

# NOTE(review): these two states are not referenced again in this file;
# they are kept in case constructing them has side effects on `solver` —
# confirm before removing.
state_before = state_class(solver, feature_set=feature_set)
state_after = state_class(solver, feature_set=feature_set)


def constant_iter(insn):
    """Build the selected constant-iterator class for instruction *insn*."""
    return const_iter_class(
        insn,
        constants.INSTR_BITWIDTH,
        shifts=[0, 1, 2, 12, 18, 32],
        offsets=[-1, 0, 1, constants.INSTR_BITWIDTH // 8],
    )


def find_constants(insn):
    """Return the result of ``iterate_constants()`` for instruction *insn*."""
    iterator = constant_iter(insn)
    return iterator.iterate_constants()


def make_state():
    """Create a fresh architectural state bound to the shared solver."""
    return state_class(solver, feature_set=feature_set)


reverser = InstructionReverser(solver, runner, make_state, find_constants, samples=SAMPLES)

# Reverse the instruction and, on success, cluster the recovered semantics.
# The try/finally guarantees kill_runner() is called even when reversing or
# clustering raises; previously an exception skipped the cleanup entirely.
try:
    print(f"Trying to reverse {hex(instruction)}")
    success, tree, state = reverser.reverse_instruction(instruction)
    print(f"reversing success: {success}")
    print(tree.to_str())

    if success:
        semantics = tree.clean()
        cltr = clusterer.Clusterer(solver, runner, constant_iter(instruction), lambda: state_class(solver, feature_set=feature_set))
        cltr.cluster(instruction, semantics, state)
finally:
    # Best-effort cleanup: a failing kill_runner() must not mask the outcome
    # above. Unlike a bare `except:`, this lets KeyboardInterrupt/SystemExit
    # through and reports the failure instead of hiding it.
    try:
        runner.kill_runner()
    except Exception as err:
        print(f"warning: failed to kill runner: {err}", file=sys.stderr)

