import os
import random
import sys
import subprocess
import time
import zipfile

def try_reverse(instruction, depth=2, samples=300, timeout=600, bitmap=None,
                arch="loongarch64-socket"):
    """Run the instruction reverser on one encoding inside Docker.

    The reverser (``main.py`` in the ``artifact_instrsem`` image, with
    ``./instr_reverser`` mounted at ``/instr_reverser``) is started under
    ``timeout -s 9`` so a run exceeding *timeout* seconds is killed.

    Args:
        instruction: Integer encoding to analyse; passed to the reverser in hex.
        depth: Forwarded verbatim to ``main.py``.
        samples: Forwarded verbatim to ``main.py``.
        timeout: Seconds before the in-container ``timeout`` kills the run.
        bitmap: Optional object with a ``save()`` method; it is saved before
            the process exits when the user presses Ctrl+C.
        arch: Target argument passed to ``main.py``; the default keeps the
            original LoongArch64 behaviour.

    Returns:
        ``(variable_bits, stderr_text)``, where *variable_bits* is the value
        printed after ``"variable bits: "`` on the reverser's stderr, or
        ``(None, None)`` if the run failed or printed no such line.
    """
    try:
        cmd = [
            "docker", "run", "--rm", "-i",
            "-v", "./instr_reverser:/instr_reverser",
            "artifact_instrsem",
            "timeout", "-s", "9", str(timeout),
            "python3", "main.py", arch, hex(instruction), "0", "32",
            str(depth), str(samples),
        ]
        print(" ".join(cmd), file=sys.stderr)
        # TODO: change cwd
        # errors="replace": one undecodable byte in the reverser's log must
        # not discard an otherwise valid result.
        x = subprocess.run(
            cmd, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE
        ).stderr.decode(errors="replace")
        _, found, rest = x.partition("variable bits: ")
        if not found:
            return None, None
        variable_bits = rest.split("\n")[0]
        # NOTE(review): eval() executes whatever the reverser prints here.
        # ast.literal_eval would be safer if the value is always a plain
        # literal (e.g. a list of ints) -- confirm the output format first.
        return eval(variable_bits), x
    except KeyboardInterrupt:
        if bitmap is not None:
            print("updating bitmap")
            bitmap.save()
        # sys.exit rather than quit(): quit is a site-module convenience and
        # is not guaranteed to exist (e.g. under ``python -S``).
        sys.exit()
    except Exception:
        # Best-effort: any failure (docker error, missing marker, parse error)
        # is reported to the caller as "no result".
        return None, None

class Bitmap:
    """A mutable bit set persisted as a raw byte file.

    Bit ``n`` lives in byte ``n // 8`` at position ``n % 8`` (LSB first).
    The whole file is loaded into memory on construction; changes reach the
    file only when :meth:`save` is called.
    """

    def __init__(self, file):
        """Load the bitmap stored at path *file*."""
        self.file = file
        with open(file, "rb") as f:
            # Read straight into a preallocated buffer so a large bitmap is
            # not held twice (bytes copy + bytearray) during loading.
            self.data = bytearray(os.fstat(f.fileno()).st_size)
            n = f.readinto(self.data)
        del self.data[n:]  # guard against a short read

    def get(self, bit):
        """Return True if *bit* is set."""
        return bool(self.data[bit // 8] & (1 << (bit % 8)))

    def clear(self, bit):
        """Clear *bit*; return True if it was previously set."""
        if not self.get(bit):
            return False
        self.data[bit // 8] &= ~(1 << (bit % 8))
        return True

    def set(self, bit):
        """Set *bit*; return True if it was previously clear."""
        if self.get(bit):
            return False
        self.data[bit // 8] |= (1 << (bit % 8))
        return True

    def save(self):
        """Write the bitmap back to its file atomically.

        The data is written to a sibling ``.tmp`` file, which then replaces
        the original. An interruption mid-write therefore leaves the previous
        bitmap intact instead of a truncated one.
        """
        tmp = f"{self.file}.tmp"
        with open(tmp, "wb") as f:
            f.write(self.data)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp, self.file)

if __name__ == "__main__":
    # Repeatedly pick a random 32-bit encoding still set in the bitmap, run
    # the reverser on it, and clear every encoding reachable by flipping any
    # combination of the reported variable bits. Runs until interrupted.
    encodings_found = 0
    os.makedirs("out", exist_ok=True)
    if not os.path.exists("bitmap_qemu.bin"):
        with zipfile.ZipFile("bitmap_qemu.zip") as archive:
            archive.extractall()
    bitmap = Bitmap("bitmap_qemu.bin")
    start = time.monotonic()
    try:
        while True:
            r = random.randrange(0x100000000)
            if not bitmap.get(r):
                continue
            variable, out = try_reverse(r, bitmap=bitmap)
            if variable is None:
                continue
            print(f" --- {r:08x} --- ")
            print(out)
            with open(f"out/{r:08x}.txt", "w", encoding="utf-8") as f:
                f.write(out)
            print("found", len(variable), "bits!")
            # Enumerate all 2**len(variable) flip masks over the variable bits.
            for i in range(1 << len(variable)):
                cur_instr = r
                for j, b in enumerate(variable):
                    if i & (1 << j):
                        cur_instr ^= (1 << b)
                encodings_found += bitmap.clear(cur_instr)
            print(f"{int(time.monotonic() - start)}: {encodings_found} total encodings")
            print("")
    except KeyboardInterrupt:
        # try_reverse saves on a Ctrl+C received while it runs; this handler
        # covers interrupts that land elsewhere in the loop (e.g. while
        # clearing a large batch), which previously lost all progress.
        print("updating bitmap")
        bitmap.save()
