#!/usr/bin/python3

"""
Evaluate a set of benchmark runs for both ROSA and the baseline (fuzzer).

This script evaluates a set of benchmark runs and stores the results in an output
directory as CSV files. These CSV files are named according to the following format:

    results-<RUN ID>-<DEDUPLICATION VARIANT>.csv

Where:
    <RUN ID>: the zero-padded (at least two digits) ID of the run.
    <DEDUPLICATION VARIANT>: either "baseline" (raw fuzzer inputs), "native" (no
        ROSA deduplication applied; not currently produced by this script) or
        "dedup" (native ROSA deduplication is applied).

The ROSA statistics file of each run is also copied to the output directory as
stats-<RUN ID>.csv.
"""

from __future__ import annotations

import argparse
import functools
import json
import multiprocessing
import os
import re
import shutil
import subprocess
from dataclasses import dataclass
from typing import Optional, Union

# AFL++'s QEMU-mode tracer, through which every ground-truth executable is run.
AFL_QEMU_TRACE = os.path.join(
    "/root",
    "rosa",
    "fuzzers",
    "aflpp",
    "aflpp",
    "afl-qemu-trace",
)
# Default string signalling that the backdoor was triggered (overridable via -g).
DEFAULT_GROUND_TRUTH_MARKER = "***BACKDOOR TRIGGERED***"
# JSON file describing every available target.
TARGET_FILE_PATH = os.path.join("/root", "artifact", "targets.json")
# Hard wall-clock limit for one ground-truth run (enforced with `timeout -s KILL`).
EVALUATION_TIMEOUT_SECONDS = 5
# NOTE(review): not referenced by the code in this script.
EIGHT_HOURS_IN_SECONDS = 8 * 60 * 60


@dataclass
class GroundTruthConfig:
    """Declare a configuration for the ground-truth variant of a target.

    Attributes:
        executable: The path to the ground-truth executable.
        arguments: Any arguments to pass to the ground-truth executable.
        environment: Any environment variables to set for the ground-truth
            executable.
    """

    executable: str
    arguments: list[str]
    environment: dict[str, str]


def evaluate_test_input(
    test_input_path: str,
    ground_truth_config: GroundTruthConfig,
    ground_truth_marker: str,
) -> bool:
    """Run one baseline input through the ground-truth program and check the result.

    The input file is fed on stdin to the ground-truth executable, which runs under
    AFL_QEMU_TRACE wrapped in `timeout --signal=KILL` (EVALUATION_TIMEOUT_SECONDS).
    The ground-truth environment variables are layered over the current process
    environment.

    Args:
        test_input_path: Path of the input file to feed on stdin.
        ground_truth_config: How to invoke the ground-truth executable.
        ground_truth_marker: String whose presence in the output signals detection.

    Returns:
        True if the marker occurs in the (UTF-8, errors ignored) stdout or stderr.
    """
    env = {**os.environ, **ground_truth_config.environment}
    command = [
        "timeout",
        "--signal=KILL",
        f"{EVALUATION_TIMEOUT_SECONDS}s",
        AFL_QEMU_TRACE,
        "--",
        ground_truth_config.executable,
        *ground_truth_config.arguments,
    ]

    with open(test_input_path, "rb") as stdin_file:
        completed = subprocess.run(
            command, stdin=stdin_file, env=env, capture_output=True
        )

    return any(
        ground_truth_marker in stream.decode(encoding="utf-8", errors="ignore")
        for stream in (completed.stdout, completed.stderr)
    )


def _collect_baseline_inputs(baseline_inputs_dir: str) -> list[tuple[str, int]]:
    """List the fuzzer queue inputs as (path, timestamp in milliseconds) pairs."""
    collected: list[tuple[str, int]] = []
    # Entries without a `time:` field inherit the most recent timestamp seen (in
    # sorted order), or 0 if none has been seen yet.
    last_timestamp_milliseconds = 0

    for element in sorted(os.listdir(baseline_inputs_dir)):
        # Only keep relevant fuzzer input files.
        if not element.startswith("id:"):
            continue

        timestamp_match = re.search(r"time:(\d+)", element)
        if timestamp_match is not None:
            last_timestamp_milliseconds = int(timestamp_match.group(1))

        collected.append(
            (os.path.join(baseline_inputs_dir, element), last_timestamp_milliseconds)
        )

    return collected


def evaluate_baseline(
    baseline_inputs_dir: str,
    ground_truth_config: GroundTruthConfig,
    ground_truth_marker: str,
    output_dir: str,
    output_prefix: str,
    time_limit_seconds: Optional[int],
) -> None:
    """Evaluate baseline (fuzzer) results.

    Run a very simple evaluation on the baseline/raw inputs generated by the
    fuzzer, by running them (in parallel) through the ground-truth program. Every
    input that makes the ground-truth program emit the marker counts as a true
    positive, every other input as a false positive; true/false negatives are
    always reported as 0. The summary is written to
    ``<output_dir>/<output_prefix>-baseline.csv``.

    Args:
        baseline_inputs_dir: Fuzzer queue directory; only entries whose name starts
            with ``id:`` are considered.
        ground_truth_config: How to invoke the ground-truth program.
        ground_truth_marker: String signalling that the backdoor was triggered.
        output_dir: Directory receiving the CSV file.
        output_prefix: File-name prefix of the CSV file.
        time_limit_seconds: If set, inputs whose timestamp (in whole seconds)
            exceeds this limit are ignored and not executed at all.
    """
    baseline_inputs = _collect_baseline_inputs(baseline_inputs_dir)

    # Apply the time limit before executing anything: inputs past the limit never
    # contribute to the counts, so running them through the ground truth (up to
    # EVALUATION_TIMEOUT_SECONDS each) would only waste time.
    if time_limit_seconds is not None:
        baseline_inputs = [
            (path, milliseconds)
            for path, milliseconds in baseline_inputs
            if milliseconds // 1000 <= time_limit_seconds
        ]

    # Run actual evaluation.
    with multiprocessing.Pool(multiprocessing.cpu_count()) as process_pool:
        results = process_pool.map(
            functools.partial(
                evaluate_test_input,
                ground_truth_config=ground_truth_config,
                ground_truth_marker=ground_truth_marker,
            ),
            [path for path, _ in baseline_inputs],
        )

    true_positives = 0
    false_positives = 0
    # Timestamp of the first (in queue order) input that triggered the marker.
    seconds_to_first_backdoor: Optional[int] = None
    for (_, milliseconds), triggered in zip(baseline_inputs, results):
        if triggered:
            true_positives += 1
            if seconds_to_first_backdoor is None:
                seconds_to_first_backdoor = milliseconds // 1000
        else:
            false_positives += 1

    # Write the results to the output file.
    output_header = (
        "true_positives,false_positives,true_negatives,false_negatives,"
        "seconds_to_first_backdoor"
    )
    first_backdoor = (
        "N/A" if seconds_to_first_backdoor is None else f"{seconds_to_first_backdoor}"
    )
    output_summary = f"{true_positives},{false_positives},0,0,{first_backdoor}"

    output_base = os.path.join(output_dir, output_prefix)
    with open(f"{output_base}-baseline.csv", "w") as output_file:
        output_file.write(f"{output_header}\n")
        output_file.write(f"{output_summary}\n")


def evaluate(
    rosa_dir: str,
    ground_truth_config: GroundTruthConfig,
    output_dir: str,
    output_prefix: str,
    time_limit_seconds: Optional[int],
    verbose: bool,
) -> None:
    """Run `rosa-evaluate` on a ROSA output directory and save its summary as CSV.

    The ground-truth program (under AFL_QEMU_TRACE and a hard `timeout`) is handed
    to `rosa-evaluate` via `--target-program`; the tool's stdout becomes
    ``<output_dir>/<output_prefix>-dedup.csv``.

    Args:
        rosa_dir: ROSA output directory to evaluate.
        ground_truth_config: How to invoke the ground-truth program.
        output_dir: Directory receiving the CSV file.
        output_prefix: File-name prefix of the CSV file.
        time_limit_seconds: Forwarded as `--time-limit` when set.
        verbose: When True, `rosa-evaluate`'s stderr is shown instead of discarded.
    """
    # The target program is passed as one space-joined string.
    # NOTE(review): arguments containing whitespace are not quoted here.
    # NOTE(review): unlike evaluate_test_input, ground_truth_config.environment is
    # not applied to this subprocess — confirm that is intended.
    target_program = " ".join(
        [
            "timeout",
            "--signal=KILL",
            f"{EVALUATION_TIMEOUT_SECONDS}s",
            AFL_QEMU_TRACE,
            "--",
            ground_truth_config.executable,
            *ground_truth_config.arguments,
        ]
    )

    command = ["rosa-evaluate"]
    if time_limit_seconds is not None:
        command += ["--time-limit", f"{time_limit_seconds}"]
    command += ["--target-program", target_program, "--summary", rosa_dir]

    csv_path = f"{os.path.join(output_dir, output_prefix)}-dedup.csv"
    stderr_target = None if verbose else subprocess.DEVNULL
    with open(csv_path, "w") as csv_file:
        subprocess.run(command, stdout=csv_file, stderr=stderr_target)


def _run_make_target(root_dir: str, make_target: str, verbose: bool) -> None:
    """Run `make -C <root_dir> <make_target>`, silencing its output unless verbose."""
    subprocess.call(
        ["make", "-C", root_dir, make_target],
        stdout=None if verbose else subprocess.DEVNULL,
        stderr=None if verbose else subprocess.STDOUT,
    )


def _collect_run_dirs(benchmarks_dir: str, target: str) -> list[tuple[str, str]]:
    """Return (ROSA dir, fuzzer dir) pairs for every `run-*` dir of every instance."""
    run_dirs: list[tuple[str, str]] = []
    for instance_name in sorted(os.listdir(benchmarks_dir)):
        instance_dir = os.path.join(benchmarks_dir, instance_name)
        if not os.path.isdir(instance_dir):
            continue

        for run_name in sorted(os.listdir(instance_dir)):
            run_dir = os.path.join(instance_dir, run_name)
            if run_name.startswith("run-") and os.path.isdir(run_dir):
                run_dirs.append(
                    (
                        os.path.join(run_dir, f"rosa-out-{target}"),
                        os.path.join(run_dir, f"fuzzer-out-{target}"),
                    )
                )

    return run_dirs


def main() -> None:
    """Parse arguments and evaluate every benchmark run of the chosen target.

    For each `run-*` directory of each instance directory under `benchmarks_dir`,
    ROSA's output is evaluated with `rosa-evaluate`, the fuzzer queue is run
    through the ground-truth program, and ROSA's `stats.csv` is copied to the
    output directory. The target's `make setup` runs before the evaluations and
    `make teardown` always runs afterwards, including after Ctrl-C or an error.
    """
    targets: dict[str, dict[str, Union[str, list[str], dict[str, str]]]]
    with open(TARGET_FILE_PATH, "r") as targets_file:
        targets = json.load(targets_file)

    parser = argparse.ArgumentParser(
        description="Evaluate backdoor benchmarks with the ROSA backdoor detector."
    )
    parser.add_argument(
        "target",
        help="The target to evaluate.",
        choices=list(targets.keys()),
    )
    parser.add_argument(
        "benchmarks_dir",
        help=(
            "The directory containing the benchmark data (structured by instance and "
            "run ID)."
        ),
    )
    parser.add_argument(
        "output_dir", help="The output directory for the evaluation results."
    )
    parser.add_argument(
        "-g",
        "--ground-truth-marker",
        help=(
            "The ground-truth marker to check for in the stderr and stdout of the "
            "target."
        ),
        default=DEFAULT_GROUND_TRUTH_MARKER,
    )
    parser.add_argument(
        "-t",
        "--time-limit",
        help=(
            "Do not evaluate inputs/traces analyzed past a certain time limit "
            "(in seconds)."
        ),
        type=int,
    )
    parser.add_argument(
        "-v",
        "--verbose",
        help="Display more detailed output.",
        action="store_true",
    )

    args = parser.parse_args()

    chosen_target = targets[args.target]
    assert type(chosen_target["ground_truth_executable"]) is str
    assert type(chosen_target["ground_truth_arguments"]) is list
    assert type(chosen_target["baseline_evaluation_env"]) is dict
    ground_truth_config = GroundTruthConfig(
        executable=chosen_target["ground_truth_executable"],
        arguments=chosen_target["ground_truth_arguments"],
        environment=chosen_target["baseline_evaluation_env"],
    )
    # Validate the whole target config before touching the filesystem.
    target_root_dir = chosen_target["root_dir"]
    assert type(target_root_dir) is str

    # Let this fail on purpose if the directory exists.
    os.makedirs(args.output_dir)

    # Collect ROSA/fuzzer directories.
    run_dirs = _collect_run_dirs(args.benchmarks_dir, args.target)

    # Run target setup code.
    _run_make_target(target_root_dir, "setup", args.verbose)

    try:
        # Run the evaluation on each directory.
        for run_id, (rosa_dir, fuzzer_dir) in enumerate(run_dirs):
            padded_run_id = f"{run_id:02d}"
            output_prefix = f"results-{padded_run_id}"
            # Run an evaluation for ROSA.
            evaluate(
                rosa_dir=rosa_dir,
                ground_truth_config=ground_truth_config,
                output_dir=args.output_dir,
                output_prefix=output_prefix,
                time_limit_seconds=args.time_limit,
                verbose=args.verbose,
            )
            # Run an evaluation for the fuzzer.
            evaluate_baseline(
                baseline_inputs_dir=os.path.join(fuzzer_dir, "main", "queue"),
                ground_truth_config=ground_truth_config,
                ground_truth_marker=args.ground_truth_marker,
                output_dir=args.output_dir,
                output_prefix=output_prefix,
                time_limit_seconds=args.time_limit,
            )
            # Store the ROSA stats files.
            shutil.copyfile(
                os.path.join(rosa_dir, "stats.csv"),
                os.path.join(args.output_dir, f"stats-{padded_run_id}.csv"),
            )
    except KeyboardInterrupt:
        # Ctrl-C deliberately stops early; results written so far are kept.
        pass
    finally:
        # Run target teardown code, even if an evaluation step raised.
        _run_make_target(target_root_dir, "teardown", args.verbose)


if __name__ == "__main__":
    main()
