import subprocess
import time
import sys

def try_reverse(instruction, depth=2, samples=500):
    """Run the instruction reverser in Docker and parse its "variable bits".

    The tool is started via ``docker run`` with its stdout discarded; its
    stderr is captured and searched for a line containing
    ``"variable bits: "``. The text after that marker (up to the end of the
    line) is evaluated as a Python expression.

    Args:
        instruction: Instruction word as an integer; passed to the tool in
            ``hex()`` form.
        depth: Forwarded to the tool as its depth argument.
        samples: Forwarded to the tool as its samples argument.

    Returns:
        ``(variable_bits, stderr_text)`` on success, or ``(None, None)`` when
        the command could not be built or run, or its output could not be
        parsed.
    """
    try:
        cmd = [
            "docker", "run", "--rm", "-i",
            "-v", "./instr_reverser:/instr_reverser",
            "artifact_instrsem",
            "python3", "main.py", "riscv64-socket",
            hex(instruction), "0", "32", str(depth), str(samples),
        ]
        # Echo the exact command so a failing case can be re-run by hand.
        print(" ".join(cmd), file=sys.stderr)
        proc = subprocess.run(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE)
        x = proc.stderr.decode()
        variable_bits = x.split("variable bits: ")[1].split("\n")[0]
        # NOTE(review): eval() executes whatever the container printed. If the
        # tool only ever prints a literal list/set, ast.literal_eval would be
        # a safer replacement -- confirm the output format before switching.
        return eval(variable_bits), x
    except Exception:
        # Best effort: any ordinary failure (docker missing, marker absent,
        # unparsable text) is reported as "no result". Deliberately NOT a bare
        # ``except:`` so that Ctrl-C (KeyboardInterrupt) and SystemExit still
        # abort the whole run instead of being counted as one failed case.
        return None, None

def check_reverse(s, instruction, variable_bits=(), depth=2, samples=500):
    """Reverse one instruction and compare the result with the expected bits.

    Calls ``try_reverse`` and logs a ``--- {s} ---`` header followed by the
    tool output (or ``"Failed!"``) to stderr.

    Args:
        s: Label for the case; used in the log header and returned unchanged.
        instruction: Instruction word forwarded to ``try_reverse``.
        variable_bits: Expected variable bit positions. Order and duplicates
            are ignored for the comparison.
        depth: Forwarded to ``try_reverse``.
        samples: Forwarded to ``try_reverse``.

    Returns:
        A tuple ``(s, works, elapsed_seconds, len(variable_bits), depth,
        samples)`` whose field order matches the parameters of ``print_row``.
    """
    start = time.time()
    result, output = try_reverse(instruction, depth=depth, samples=samples)
    # Only ``None`` signals failure. An empty result is a legitimate answer
    # ("no variable bits") and must still be compared with the expectation;
    # a plain truthiness test would misreport it as a failure.
    failed = result is None
    if failed:
        output = "Failed!"
    print(f"--- {s} ---", file=sys.stderr)
    print(output, file=sys.stderr)
    works = (not failed) and set(result) == set(variable_bits)
    return s, works, time.time() - start, len(variable_bits), depth, samples

def print_row(enc, works, time, bits, depth, samples):
    """Print one Markdown table row for a single reversed instruction.

    Args:
        enc: Text for the first column.
        works: Whether the case passed; selects the check or cross mark.
        time: Elapsed seconds. Shown with two decimals only when ``works`` is
            true; otherwise the column is left empty.
        bits: Value for the "Bits" column.
        depth: Value for the "Depth" column.
        samples: Value for the "Samples" column.
    """
    if works:
        mark = "✓"
        elapsed = f"{time:.2f} s"
    else:
        mark = "✗"
        elapsed = ""
    cells = (enc, mark, elapsed, bits, depth, samples)
    print("| " + " | ".join(str(cell) for cell in cells) + " |")

if __name__ == "__main__":
    # Markdown table header; every print_row() call below adds one row.
    print("| Encoding | Works | Time | Bits | Depth | Samples |")
    print("| :--- | :---: | ---: | ---: | ---: | ---: |")

    # Expected variable bit positions, one list per encoding group.
    R_TYPE = [*range(7, 12), *range(15, 25)]
    I_TYPE = [*range(7, 12), *range(15, 32)]
    B_TYPE = list(I_TYPE)
    U_TYPE = list(range(7, 32))
    J_TYPE = list(U_TYPE)
    S_TYPE = list(I_TYPE)
    # jalr is missing one bit (bit 20) but if we use two different immediates
    # we get the full semantics, hence the two jalr rows below.
    jalr_bits = [*range(7, 12), *range(15, 20), *range(21, 32)]

    # (assembly text, instruction word, expected bits, check_reverse overrides)
    cases = [
        ("add s7, ra, t2", 0x00708bb3, R_TYPE, {}),
        ("addi s10, ra, 0x2a", 0x02a08d13, I_TYPE, {}),
        ("and ra, tp, s1", 0x009270b3, R_TYPE, {}),
        ("andi gp, ra, 0x2a", 0x02a0f193, I_TYPE, {}),
        ("auipc tp, 0x2a", 0x0002a217, U_TYPE, {}),
        ("beq s9, ra, 0x2a", 0x021c8563, B_TYPE, {}),
        ("bge ra, sp, 0x2a", 0x0220d563, B_TYPE, {}),
        ("bgeu sp, tp, 0x2a", 0x02417563, B_TYPE, {}),
        ("blt ra, sp, 0x2a", 0x0220c563, B_TYPE, {}),
        ("bltu gp, sp, 0x2a", 0x0221e563, B_TYPE, {}),
        ("bne tp, ra, 0x2a", 0x02121563, B_TYPE, {}),
        # ebreak, ecall, and fences don't work
        ("ebreak", 0x00100073, [], {}),
        ("ecall", 0x00000073, [], {}),
        ("fence", 0x0000100f, I_TYPE, {}),
        ("jal tp, 0x2a000", 0x0002a26f, J_TYPE, {}),
        ("jalr sp, ra, 0x29", 0x02908167, jalr_bits, {"depth": 3}),
        ("jalr sp, ra, 0x2a", 0x02a08167, jalr_bits, {"depth": 3}),
        ("lb ra, 0x2a(sp)", 0x02a10083, I_TYPE, {}),
        ("lbu ra, 0x2a(tp)", 0x02a24083, I_TYPE, {}),
        ("lh sp, 0x2a(gp)", 0x02a19103, I_TYPE, {}),
        ("lhu ra, 0x2a(gp)", 0x02a1d083, I_TYPE, {}),
        ("lui gp, 0x2a", 0x0002a1b7, U_TYPE, {}),
        ("lw tp, 0x2a(ra)", 0x02a0a203, I_TYPE, {}),
        ("or tp, gp, s7", 0x0171e233, R_TYPE, {}),
        ("ori tp, a0, 0x2a", 0x02a56213, I_TYPE, {}),
        # TODO: "pause" is not covered yet (invalid mnemonic).
        ("sb sp, 0x2a(ra)", 0x02208523, S_TYPE, {}),
        ("sh gp, 0x2a(ra)", 0x02309523, S_TYPE, {}),
        ("sll ra, sp, s3", 0x013110b3, R_TYPE, {"depth": 3}),
        ("slt tp, a7, a3", 0x00d8a233, R_TYPE, {}),
        ("slti sp, gp, 0x2a", 0x02a1a113, I_TYPE, {"samples": 25000}),
        ("sltiu ra, gp, 0x2a", 0x02a1b093, I_TYPE, {"samples": 25000}),
        ("sltu gp, ra, sp", 0x0020b1b3, R_TYPE, {}),
        ("sra tp, sp, a6", 0x41015233, R_TYPE, {}),
        ("srl gp, ra, tp", 0x0040d1b3, R_TYPE, {"depth": 3}),
        ("sub sp, s8, ra", 0x401c0133, R_TYPE, {}),
        ("sw sp, 0x2a(ra)", 0x0220a523, S_TYPE, {}),
        ("xor ra, gp, tp", 0x0041c0b3, R_TYPE, {}),
        ("xori sp, ra, 0x2a", 0x02a0c113, I_TYPE, {}),
    ]

    # Run the cases in order; each one emits its row as soon as it finishes.
    for asm, word, expected_bits, overrides in cases:
        print_row(*check_reverse(asm, word, variable_bits=expected_bits, **overrides))
