import os
import time
from pathlib import Path
import sys
import subprocess
from multiprocessing import Pool
from argparse import ArgumentParser
import random
import string
import shutil
import dill

from common import TARGET_CMD

# Characters a mutant identifier may contain: ASCII letters, then digits.
alphanumeric_chars = "".join((string.ascii_letters, string.digits))
# Random 10-character tag for this run. It names the scratch subdirectory and
# the pickled log file, so the two can be correlated with the printed results.
mutant_id = "".join(random.choice(alphanumeric_chars) for _unused in range(10))

# Script entry point: run TARGET_CMD on every ``*.mlir`` file in input_dir and
# compare each result with the reference outputs stored in ref_dir.
#
# A test case fails when exactly one of (reference, actual) return codes is
# zero, when two non-zero return codes differ, or when both are zero but stdout
# differs. After each case the log dict is pickled with dill to
# ``<log_path>/<mutant_id>.log.pkl``. The first failure raises ValueError
# (a timeout re-raises subprocess.TimeoutExpired). The scratch directory
# ``<temp_dir>/<mutant_id>`` is removed on every exit path.
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("input_dir", type=Path)
    parser.add_argument("ref_dir", type=Path)
    # Both directories are joined with mutant_id below. Left unset they would be
    # None and the script would die with an opaque TypeError, so let argparse
    # report the missing option instead.
    parser.add_argument("--temp-dir", type=Path, required=True)
    parser.add_argument("--log-path", type=Path, required=True)
    parser.add_argument("--timeout", default=5, type=float)

    args = parser.parse_args()
    print(f"Got args: {args}")

    temp_dir = args.temp_dir / mutant_id

    # print a unique identifier so that we can match the log files with the results
    print(f"Mutant ID: {mutant_id}")

    def write_failure_log(log):
        """Pickle ``log`` with dill to ``<log_path>/<mutant_id>.log.pkl``.

        The file is overwritten on every call, so it always holds the log of
        the most recently evaluated test case.
        """
        with open((args.log_path / mutant_id).with_suffix(".log.pkl"), "wb") as f:
            dill.dump(log, f)

    # Create the scratch directory before entering the try block, so the
    # cleanup below never tries to delete a directory that was never created.
    os.makedirs(temp_dir, exist_ok=True)
    try:
        # Close the directory handle promptly. Visit test cases in a fixed order
        # so that the first failure reported for a mutant is reproducible.
        with os.scandir(args.input_dir) as entries:
            test_cases = sorted(
                (entry for entry in entries if entry.name.endswith(".mlir")),
                key=lambda entry: entry.name,
            )
        for dirent in test_cases:
            print(dirent.path)
            output_prefix = os.path.splitext(dirent.name)[0]
            cmd = [
                item.replace("%inputpath", dirent.path).replace("%inputname", dirent.name)
                for item in TARGET_CMD
            ]
            # if compilation succeeds, we compare the output to the reference output
            # if compilation fails, we compare the return code
            try:
                # Add 5 s to --timeout without truncating its fractional part
                # (the previous int() conversion dropped it).
                status = subprocess.run(
                    cmd,
                    capture_output=True,
                    timeout=args.timeout + 5,
                )
            except subprocess.TimeoutExpired as exc:
                # Record the timeout. Otherwise the log file would still hold
                # the (passing) log of the previous test case.
                # NOTE(review): no return code exists here, so None is stored;
                # confirm log consumers accept that.
                write_failure_log(
                    {
                        "failing_test_case": dirent.path,
                        "return_code": None,
                        "stdout": exc.stdout,
                        "stderr": exc.stderr,
                        "reason": "timed out",
                    }
                )
                raise
            ref_stem = args.ref_dir / output_prefix
            # NOTE(review): with_suffix replaces an existing suffix, so a test
            # named "a.b.mlir" reads "a.returncode", not "a.b.returncode".
            # Confirm this matches how the reference files were generated.
            with open(ref_stem.with_suffix(".returncode"), "r") as f:
                ref_retcode = int(f.read().strip())
            with open(ref_stem.with_suffix(".stdout"), "rb") as f:
                ref_stdout = f.read()
            # stderr is captured for the log but never compared, so the
            # reference ``.stderr`` file is not read here.
            failure_log = {
                "failing_test_case": dirent.path,
                "return_code": status.returncode,
                "stdout": status.stdout,
                "stderr": status.stderr,
            }
            if ref_retcode == 0 and status.returncode != 0:
                failure_log["reason"] = "returned non-zero, expected zero"
            elif ref_retcode != 0 and status.returncode == 0:
                failure_log["reason"] = "returned zero, expected non-zero"
            elif ref_retcode != 0 and status.returncode != ref_retcode:
                failure_log["reason"] = "non-zero return code mismatch"
            elif (
                ref_retcode == 0
                and status.returncode == 0
                and ref_stdout != status.stdout
            ):
                failure_log["reason"] = "output mismatch"
            write_failure_log(failure_log)
            if "reason" in failure_log:
                raise ValueError(failure_log["reason"])
    finally:
        # Runs on success, on a mismatch, and on KeyboardInterrupt alike.
        shutil.rmtree(temp_dir)
