import subprocess
import time
import sys

def try_reverse(instruction, depth=2, samples=500, target="x86_64-socket"):
    """Run the instruction reverser in Docker and return the variable bits it reports.

    The tool is invoked as
    ``python3 main.py <target> <encoding> 0 <bit-length> <depth> <samples>``
    inside the ``artifact_instrsem`` image, with ``./instr_reverser`` mounted
    at ``/instr_reverser``.

    Args:
        instruction: Raw instruction bytes. They are converted to an integer
            with the first byte least significant, passed as hex, and followed
            by their length in bits.
        depth: Forwarded verbatim to ``main.py``.
        samples: Forwarded verbatim to ``main.py``.
        target: Backend name passed as the first ``main.py`` argument;
            defaults to the x86-64 socket backend this script was written for.

    Returns:
        ``(variable_bits, stderr_text)``, where ``variable_bits`` is the
        evaluated text following ``"variable bits: "`` (up to the end of that
        line) on the tool's stderr. Returns ``(None, None)`` if the command
        could not be built or launched, or its output could not be parsed.
        Interrupts such as Ctrl-C are not swallowed.
    """
    try:
        cmd = [
            "docker", "run", "--rm", "-i",
            "-v", "./instr_reverser:/instr_reverser",
            "artifact_instrsem",
            "python3", "main.py", target,
            hex(int.from_bytes(instruction, "little")),
            "0",
            str(len(instruction) * 8),
            str(depth),
            str(samples),
        ]
        print(" ".join(cmd), file=sys.stderr)
        # The result is parsed from stderr; stdout is discarded.
        completed = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
        # Replace undecodable bytes instead of discarding the whole result.
        output = completed.stderr.decode(errors="replace")
        variable_bits = output.split("variable bits: ")[1].split("\n")[0]
        # NOTE(review): eval() executes whatever the container prints. If the
        # tool always emits a plain list/set literal, ast.literal_eval would be
        # safer — confirm the output format before switching.
        return eval(variable_bits), output
    except Exception as exc:
        # Best effort per instruction: report why, then signal failure.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate, so the whole benchmark can still be aborted.
        print(f"try_reverse failed: {exc!r}", file=sys.stderr)
        return None, None

def check_reverse(s, instruction, variable_bits=None, depth=2, samples=500):
    """Reverse one instruction and compare the reported variable bits to an expectation.

    The tool's stderr (or ``"Failed!"``) is echoed to this process's stderr
    under a ``--- <s> ---`` header.

    Args:
        s: Label for the instruction (e.g. its assembly text). It is used in
            the log header and returned unchanged.
        instruction: Raw instruction bytes forwarded to ``try_reverse``.
        variable_bits: Expected variable bit positions. Order and duplicates
            are ignored. ``None`` (the default) means no expected bits.
        depth: Forwarded to ``try_reverse``.
        samples: Forwarded to ``try_reverse``.

    Returns:
        ``(s, matched, elapsed_seconds, len(variable_bits), depth, samples)``,
        in the positional order ``print_row`` takes.
    """
    # None instead of a shared mutable [] default.
    if variable_bits is None:
        variable_bits = []
    # perf_counter is monotonic, so wall-clock adjustments cannot skew timings.
    start = time.perf_counter()
    result, output = try_reverse(instruction, depth=depth, samples=samples)
    if result is None:
        # NOTE(review): this sentinel equals the [""] expectation used for
        # several instructions in this file, so a failed run compares as a
        # match for those rows — confirm that is intended.
        result = [""]
        output = "Failed!"
    print(f"--- {s} ---", file=sys.stderr)
    print(output, file=sys.stderr)
    elapsed = time.perf_counter() - start
    return s, set(result) == set(variable_bits), elapsed, len(variable_bits), depth, samples

def print_row(enc, works, time, bits, depth, samples):
    """Print one Markdown table row summarising a single reverse-engineering run.

    The elapsed-time cell is shown as seconds with two decimals, and only for
    runs marked as working. Failed rows get a ✗ and an empty time cell.
    """
    # ``time`` shadows the time module in this scope. The parameter name is
    # kept so keyword callers keep working.
    status = "✓" if works else "✗"
    elapsed = ""
    if works:
        elapsed = f"{time:.2f} s"
    cells = (enc, status, elapsed, bits, depth, samples)
    print("| " + " | ".join(format(cell) for cell in cells) + " |")
    
if __name__ == "__main__":
    # Each entry: (assembly text, encoding as hex bytes, expected variable bit positions).
    # NOTE(review): check_reverse substitutes [""] for a failed run, so rows
    # expecting [""] report ✓ when the tool fails — confirm that is intended.
    benchmark = [
        ("mov rbp, rsp", "48 89 e5", [16, 17, 18, 19, 20, 21]),
        ("lea rdx, [r8 + 9]", "49 8d 51 09", [19, 20, 21, 24, 25, 26, 27, 28, 29, 30, 31]),
        ("cmp rax, 52", "48 83 f8 34", [""]),
        ("je something", "0f 84 00 25 00 00", [""]),
        ("test rax, rcx", "48 85 c8", [""]),
        ("call something", "e8 00 25 00 00", []),
        ("jne something", "0f 85 00 25 00 00", []),
        ("sub rax, rdi", "48 29 f8", [16, 17, 18, 19, 20, 21]),
        ("add rax, 42", "48 83 c0 2a", [16, 17, 18, 24, 25, 26, 27, 28, 29, 30, 31]),
        ("pop rbp", "5d", [0, 1, 2]),
        ("xor rax, rcx", "48 31 c8", [16, 17, 18, 19, 20, 21]),
        ("push rbp", "55", [0, 1, 2]),
        ("nop DWORD PTR [rax+0x0]", "0f 1f 40 00", []),
        ("ret", "c3", [""]),
        ("and eax, 42", "83 e0 2a", [8, 9, 10, 16, 17, 18, 19, 20, 21, 22]),
        ("movzb eax, bl", "0f b6 c3", [19, 20, 21]),
        ("endbr64", "f3 0f 1e fa", []),
        ("movb [rax], 0x82", "c6 00 82", [16, 17, 18, 19, 20, 21, 22, 23]),
        ("jle 0xa", "7e 08", [0, 1, 2, 3, 4, 5, 6]),
        ("movq [rax], 0x42", "48 c7 00 42 00 00 00", list(range(24, 56))),
    ]

    print("| Encoding | Works | Time | Bits | Depth | Samples |")
    print("| :--- | :---: | ---: | ---: | ---: | ---: |")

    for label, hex_encoding, expected_bits in benchmark:
        row = check_reverse(label, bytes.fromhex(hex_encoding), variable_bits=expected_bits)
        print_row(*row)
