#!/usr/bin/python3

"""
Produce Table II and Figure 2 (from the paper) given result CSV files.

Convert the resulting CSV files of evaluation runs into a partial or complete version
of Table II and Figure 2 from the paper. The CSV files should be generated with
`run-target.sh`, since a specific directory layout is expected:

    <TARGET>__<MINUTES_PER_RUN>mpr_<RUNS>r_<DATE>/
        phase_1_30s/
            ...
            results/
        phase_1_60s/
            ...
            results/
        phase_1_300s/
            ...
            results/
        phase_1_600s/
            ...
            results/
        phase_1_900s/
            ...
            results/
        phase_1_1200s/
            ...
            results/

The `results/` directories should contain the CSV files.
"""

from __future__ import annotations

import argparse
import datetime
import json
import math
import os
import shutil
import statistics
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from typing import Optional, Union

# Runs that never find a backdoor are counted as lasting this long when the
# time-to-first-backdoor statistics are computed (see `main`), and a duration
# equal to it is rendered as "Timeout" by `format_timedelta`.
TIMEOUT_SECONDS = 60 * 60 * 8  # 8 hours
# JSON description of the evaluated targets, loaded into `TARGETS` at import time.
TARGET_FILE_PATH = os.path.join("/root", "artifact", "targets.json")
# NOTE(review): not referenced by the code shown here; the default output path
# of `main` is built independently.
PDF_DIR = os.path.join("/root", "evaluation")
# Desired probability of hitting at least one true positive while manually
# inspecting inputs; passed to `RunResults.min_samples` in `main`.
AUTOMATION_PERCENTAGE = 0.95

# Preamble and closing of the generated LaTeX document. `multirow` is needed by
# Table II and `pgfplots` by Figure 2; `\rosa` typesets the tool name.
TEX_HEADER = r"""
\documentclass[10pt, conference]{IEEEtran}
\usepackage{multirow}
\usepackage{xspace}
\usepackage{pgfplots}
\usetikzlibrary{arrows, shadows}
\newcommand{\rosa}{\textsc{Rosa}\xspace}
\begin{document}
"""
TEX_FOOTER = r"""
\end{document}
"""

# Column headers of Table II (9 columns: backdoor name, failed runs, min/avg/max
# time to first backdoor, baseline average, min/avg/max inspected inputs).
# The last header row is intentionally left unterminated: `generate_table_ii`
# prefixes every data row with `\\\hline`, and TABLE_II_FOOTER terminates the
# final row before closing the environments.
TABLE_II_HEADER = r"""
\begin{table*}[!h]
\centering
\begin{tabular}{| l || c | c c c | c | c c c |}
\hline

\multicolumn{1}{|c||}{\multirow{4}{*}{\textbf{Backdoor}}}
    & \multicolumn{8}{c |}{\textbf{\rosa} --- 1 minute of fuzzing for phase 1}\\

{}
    & \multicolumn{4}{c |}{\textbf{Robustness + speed}}
    & \multicolumn{4}{c |}{\textbf{Automation level}}\\

{}
    & \multicolumn{1}{c|}{\multirow{2}{1cm}{\centering\textit{\textbf{Failed runs}}}}
    & \multicolumn{3}{c|}{\textit{\textbf{Time to first backdoor input}}}
    & \multicolumn{1}{c|}{{\centering\textit{\textbf{Baseline}}}}
    & \multicolumn{3}{c|}{\textit{\textbf{Manually inspected inputs}}}\\

{}
    &
    & {\scriptsize \textbf{Min.}}
    & {\scriptsize \textbf{Avg.}}
    & {\scriptsize \textbf{Max.}}
    & {\scriptsize \textbf{Avg. seeds}}
    & {\scriptsize \textbf{Min.}}
    & {\scriptsize \textbf{Avg.}}
    & {\scriptsize \textbf{Max.}}
"""
TABLE_II_FOOTER = r"""
\\\hline
\end{tabular}
\end{table*}
"""

# Figure 2 template, filled in with `str.format` (literal braces are therefore
# doubled). The first `{}` receives the coordinates of the manually inspected
# inputs plot (left y axis), the second those of the failed runs plot (right
# y axis). `\label{inputs}` lets the second axis reuse the first plot's style
# in its legend via `refstyle`.
FIGURE_2_TEMPLATE = r"""
\begin{{figure*}}
\centering
\begin{{tikzpicture}}
\pgfplotsset{{
    % Important to be able to have a left and right Y axis.
    compat = 1.3,
    width=15cm,
}}

\begin{{axis}}[
    axis y line* = left,
    xtick = {{60, 300, 600, 900, 1200}},
    xticklabels = {{1m, 5m, 10m, 15m, 20m}},
    x tick label style = {{ rotate = 45 }},
    extra x ticks = {{30}},
    extra x tick labels = {{30s}},
    extra x tick style = {{ xticklabel style = {{ yshift = 5pt }}}},
    xlabel = Phase 1 duration,
    xmin = 0,
    xmax = 1240,
    ymin = 0,
    ymax = 10,
    ytick distance = 2,
    ylabel = {{\textcolor{{red}}{{Inputs}}}},
]
    \addplot[mark = o, red]
    coordinates{{
        {}
    }}; \label{{inputs}}
\end{{axis}}

\begin{{axis}}[
    % legend style = {{ at = {{(0.5, 1)}}, anchor = north }},
    legend cell align = left,
    axis x line = none,
    axis y line* = right,
    xmin = 0,
    xmax = 1240,
    ymin = 0,
    ymax = 180,
    ytick distance = 20,
    ylabel = {{\textcolor{{blue}}{{Runs}}}},
]
    \addlegendimage{{/pgfplots/refstyle=inputs}}\addlegendentry{{%
        Manually inspected inputs (avg. per run)
    }}
    \addplot[mark = *, blue]
    coordinates{{
        {}
    }}; \addlegendentry{{Failed runs}}
\end{{axis}}
\end{{tikzpicture}}
\end{{figure*}}
"""

# Target metadata, looked up by the `<TARGET>` prefix of evaluation directory
# names (see `main`); `generate_table_ii` reads each entry's "name" field.
# Loaded at import time, so importing this module fails if TARGET_FILE_PATH
# does not exist.
TARGETS: dict[str, dict[str, Union[str, list[str], dict[str, str]]]] = {}
with open(TARGET_FILE_PATH, "r") as targets_file:
    TARGETS = json.load(targets_file)


@dataclass
class RunResults:
    """Outcome of a single ROSA backdoor-detection run.

    Stores the confusion-matrix counts of the run together with the number of
    seconds until the first backdoor input, or `None` when none was found.
    """

    run_id: int
    true_positives: int
    false_positives: int
    true_negatives: int
    false_negatives: int
    seconds_to_first_backdoor: Optional[int]

    def true_positive_probability(self, samples: int) -> float:
        """Return the chance that `samples` flagged inputs contain a true positive.

        The inputs are drawn without replacement from the TP + FP flagged inputs;
        the result is one minus the chance of drawing false positives only.
        """
        flagged = self.false_positives + self.true_positives
        only_false_positives = math.comb(self.false_positives, samples) / math.comb(
            flagged, samples
        )
        return 1 - only_false_positives

    def min_samples(self, percentage: float) -> int:
        """Return the fewest samples reaching a TP probability of `percentage`.

        When the run has no true positives at all, the number of false
        positives is returned instead.
        """
        if self.true_positives <= 0:
            return self.false_positives
        flagged = self.true_positives + self.false_positives
        # Smallest sample size meeting the target; falls back to all flagged inputs.
        return next(
            (
                size
                for size in range(1, flagged + 1)
                if self.true_positive_probability(size) >= percentage
            ),
            flagged,
        )


@dataclass
class CampaignResults:
    """Results from the full backdoor detection campaign.

    A campaign can contain many runs. In `main`, one instance summarizes all
    runs of a single target for a single phase-1 duration.
    """

    # Target name, taken from the `<TARGET>` prefix of the evaluation directory.
    target_name: str
    # Number of ROSA runs found for the campaign.
    total_runs: int
    # ROSA runs that never reached a backdoor input.
    failed_runs: int
    # Time-to-first-backdoor statistics, already rendered with `format_timedelta`;
    # failed runs are counted as TIMEOUT_SECONDS.
    time_to_first_backdoor_min: str
    time_to_first_backdoor_avg: str
    time_to_first_backdoor_max: str
    # Average over baseline runs of TP + FP + TN + FN, rounded.
    baseline_inputs_avg: int
    # Inputs to inspect manually to hit a true positive with probability
    # AUTOMATION_PERCENTAGE (`RunResults.min_samples`), rounded.
    manually_inspected_inputs_min: int
    manually_inspected_inputs_avg: int
    manually_inspected_inputs_max: int


def format_timedelta(td: datetime.timedelta) -> str:
    """Format a `datetime.timedelta` into a human-readable time.

    We follow the format used in the paper:
    - `HhMMmSSs` format (e.g., `1h02m03s`), with sub-second precision truncated;
    - No "prefixing" if not necessary (e.g., 0h05m02s should be 5m02s, 0h00m07s
      should be 7s, and a zero duration is 0s);
    - A duration of exactly `TIMEOUT_SECONDS` is rendered as "Timeout".

    Durations of a day or more are expressed in hours (e.g., 26h03m04s), and
    negative durations get a leading "-".
    """
    if td == datetime.timedelta(seconds=TIMEOUT_SECONDS):
        return "Timeout"

    negative = td < datetime.timedelta(0)
    magnitude = -td if negative else td
    # timedelta normalizes `seconds` into [0, 86400); microseconds are dropped.
    total_seconds = magnitude.days * 86_400 + magnitude.seconds
    # Avoid rendering "-0s" for sub-second negative durations.
    sign = "-" if negative and total_seconds else ""

    hours, remainder = divmod(total_seconds, 3_600)
    minutes, seconds = divmod(remainder, 60)

    if hours:
        return f"{sign}{hours}h{minutes:02d}m{seconds:02d}s"
    if minutes:
        return f"{sign}{minutes}m{seconds:02d}s"
    return f"{sign}{seconds}s"


def extract_results(output_dir: str) -> tuple[list[RunResults], list[RunResults]]:
    """Extract the ROSA and baseline fuzzer run results from an evaluation directory.

    The tuple (ROSA, baseline) returned contains all of the results found in
    `<output_dir>/results/`. Only files named `results-<RUN_ID>-...` are read:
    those ending in `-dedup.csv` are ROSA runs and those ending in
    `-baseline.csv` are baseline runs; any other file is ignored. The last
    non-blank line of each file must be a summary `TP,FP,TN,FN,SECONDS`, where
    SECONDS is `N/A` when no backdoor was found. Both lists are sorted by run ID.

    Raises:
        NotADirectoryError: if `<output_dir>/results` is not a directory.
        ValueError: if a results file has no summary line or a malformed one.
    """
    results_dir = os.path.join(output_dir, "results")
    if not os.path.isdir(results_dir):
        raise NotADirectoryError(f"Missing results directory: {results_dir}")

    rosa_results: list[RunResults] = []
    baseline_results: list[RunResults] = []
    for element in os.listdir(results_dir):
        full_element_path = os.path.join(results_dir, element)
        if not element.startswith("results-") or not os.path.isfile(full_element_path):
            continue
        # Classify by suffix first so that unrelated files are never parsed.
        if element.endswith("-dedup.csv"):
            destination = rosa_results
        elif element.endswith("-baseline.csv"):
            destination = baseline_results
        else:
            continue

        with open(full_element_path, "r") as handle:
            lines = [line.strip() for line in handle]
        # The summary is the last non-blank line (tolerates trailing newlines).
        summary = next((line for line in reversed(lines) if line), None)
        if summary is None:
            raise ValueError(f"No summary line in results file: {full_element_path}")
        fields = [field.strip() for field in summary.split(",")]
        if len(fields) != 5:
            raise ValueError(
                f"Malformed summary line in {full_element_path}: {summary!r}"
            )
        true_positives, false_positives, true_negatives, false_negatives, seconds = (
            fields
        )

        destination.append(
            RunResults(
                run_id=int(element.split("-")[1]),
                true_positives=int(true_positives),
                false_positives=int(false_positives),
                true_negatives=int(true_negatives),
                false_negatives=int(false_negatives),
                seconds_to_first_backdoor=None if seconds == "N/A" else int(seconds),
            )
        )

    # os.listdir order is arbitrary; sort for deterministic output.
    rosa_results.sort(key=lambda run: run.run_id)
    baseline_results.sort(key=lambda run: run.run_id)
    return (rosa_results, baseline_results)


def generate_table_ii(results: dict[str, dict[str, CampaignResults]]) -> str:
    """Generate Table II from the paper.

    One row is emitted per target, using its campaign with 60 seconds of
    phase-1 fuzzing (the setting announced in `TABLE_II_HEADER`). Targets
    without such a campaign are skipped, so partial results still produce a
    valid table. The display name comes from `TARGETS`; targets missing there
    fall back to their key.
    """
    rows = []
    for target, campaigns in results.items():
        data = campaigns.get("60s")
        if data is None:
            continue
        name = TARGETS.get(target, {}).get("name", target)
        cells = [
            f"{name}",
            f"{data.failed_runs} / {data.total_runs}",
            f"{data.time_to_first_backdoor_min}",
            f"{data.time_to_first_backdoor_avg}",
            f"{data.time_to_first_backdoor_max}",
            f"{data.baseline_inputs_avg}",
            f"{data.manually_inspected_inputs_min}",
            f"{data.manually_inspected_inputs_avg}",
            f"{data.manually_inspected_inputs_max}",
        ]
        # Each row first terminates the previous one (header or data row);
        # TABLE_II_FOOTER terminates the last row.
        rows.append("\\\\\\hline\n" + " & ".join(cells))

    return TABLE_II_HEADER + "".join(rows) + TABLE_II_FOOTER


def generate_figure_2(results: dict[str, dict[str, CampaignResults]]) -> str:
    """Generate Figure 2 from the paper.

    For every phase-1 duration, the figure plots:
    - the mean over targets of each target's (rounded) average number of
      manually inspected inputs, rounded;
    - the total number of failed runs over all targets.

    Only targets that have a campaign for a given duration contribute to its
    points, and durations without any campaign are omitted, so partial results
    still produce a figure.
    """
    durations = ("30s", "60s", "300s", "600s", "900s", "1200s")
    input_points = []
    failed_run_points = []
    for duration in durations:
        campaigns = [
            campaigns_by_duration[duration]
            for campaigns_by_duration in results.values()
            if duration in campaigns_by_duration
        ]
        if not campaigns:
            continue
        # The x coordinate is the duration in seconds ("300s" -> 300).
        x = duration[:-1]
        avg_inputs = round(
            statistics.mean(c.manually_inspected_inputs_avg for c in campaigns)
        )
        total_failed_runs = sum(c.failed_runs for c in campaigns)
        input_points.append(f"({x}, {avg_inputs})")
        failed_run_points.append(f"({x}, {total_failed_runs})")

    return FIGURE_2_TEMPLATE.format(
        "\n".join(input_points), "\n".join(failed_run_points)
    )


def _summarize_phase(
    target_name: str, phase_dir_path: str
) -> Optional[CampaignResults]:
    """Aggregate every run of one phase directory into a `CampaignResults`.

    Returns `None` (after printing a warning to stderr) when the phase has no
    `results/` directory, no ROSA runs, or no baseline runs, because the
    statistics below cannot be computed from empty data.
    """
    if not os.path.isdir(os.path.join(phase_dir_path, "results")):
        print(
            f"Warning: skipping {phase_dir_path}: no `results/` directory.",
            file=sys.stderr,
        )
        return None

    rosa_results, baseline_results = extract_results(phase_dir_path)
    if not rosa_results or not baseline_results:
        print(
            f"Warning: skipping {phase_dir_path}: found {len(rosa_results)} ROSA "
            f"and {len(baseline_results)} baseline run(s).",
            file=sys.stderr,
        )
        return None

    # Runs that never found a backdoor count as having lasted until the timeout.
    seconds_to_first_backdoor = [
        (
            run.seconds_to_first_backdoor
            if run.seconds_to_first_backdoor is not None
            else TIMEOUT_SECONDS
        )
        for run in rosa_results
    ]
    # Total inputs reported by each baseline run.
    baseline_inputs = [
        run.true_positives
        + run.false_positives
        + run.true_negatives
        + run.false_negatives
        for run in baseline_results
    ]
    manually_inspected_inputs = [
        run.min_samples(percentage=AUTOMATION_PERCENTAGE) for run in rosa_results
    ]

    def as_time(seconds: float) -> str:
        return format_timedelta(datetime.timedelta(seconds=round(seconds)))

    return CampaignResults(
        target_name=target_name,
        total_runs=len(rosa_results),
        failed_runs=sum(run.seconds_to_first_backdoor is None for run in rosa_results),
        time_to_first_backdoor_min=as_time(min(seconds_to_first_backdoor)),
        time_to_first_backdoor_avg=as_time(statistics.mean(seconds_to_first_backdoor)),
        time_to_first_backdoor_max=as_time(max(seconds_to_first_backdoor)),
        baseline_inputs_avg=round(statistics.mean(baseline_inputs)),
        manually_inspected_inputs_min=round(min(manually_inspected_inputs)),
        manually_inspected_inputs_avg=round(statistics.mean(manually_inspected_inputs)),
        manually_inspected_inputs_max=round(max(manually_inspected_inputs)),
    )


def _compile_pdf(tex_output: str, output_path: str) -> bool:
    """Compile `tex_output` with latexmk and copy the resulting PDF to `output_path`.

    Returns whether latexmk exited successfully. If no PDF was produced at all,
    the TEX source is saved next to `output_path` (same name, `.tex` extension)
    and the program exits with an error message.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        tex_file = os.path.join(tmp_dir, "replication.tex")
        pdf_file = os.path.splitext(tex_file)[0] + ".pdf"
        with open(tex_file, "w", encoding="utf-8") as f:
            f.write(tex_output)
        completed = subprocess.run(
            [
                "latexmk",
                "-pdf",
                "-cd",
                "-interaction=nonstopmode",
                f"-output-directory={tmp_dir}",
                tex_file,
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT,
            check=False,
        )
        if not os.path.isfile(pdf_file):
            # Keep the source around, the temporary directory is about to vanish.
            tex_copy = os.path.splitext(output_path)[0] + ".tex"
            shutil.copyfile(tex_file, tex_copy)
            sys.exit(
                f"PDF compilation failed (latexmk exit code {completed.returncode}); "
                f"TEX source saved to: {tex_copy}"
            )
        shutil.copyfile(pdf_file, output_path)
    return completed.returncode == 0


def main() -> None:
    """Parse arguments, aggregate the evaluation results and compile the PDF."""
    parser = argparse.ArgumentParser(
        description="Generate a TEX with Table II given result CSV files."
    )
    parser.add_argument(
        "evaluation_dirs",
        help=(
            "The evaluation directory of the benchmark "
            "(generated with `run-target.sh`)."
        ),
        nargs="+",
        metavar="EVALUATION_DIR",
    )
    parser.add_argument(
        "-o",
        "--output",
        help="The path to the output PDF file that will be generated.",
        default=os.path.join("/root", "evaluation", "replication.pdf"),
    )

    args = parser.parse_args()

    results: dict[str, dict[str, CampaignResults]] = {}
    for evaluation_dir in args.evaluation_dirs:
        # normpath drops trailing separators so that basename() is not empty.
        evaluation_dir = os.path.normpath(evaluation_dir)
        name = os.path.basename(evaluation_dir).split("__")[0]
        # Several evaluation directories of the same target are merged.
        target_results = results.setdefault(name, {})

        for phase_dir in sorted(os.listdir(evaluation_dir)):
            full_phase_dir_path = os.path.join(evaluation_dir, phase_dir)
            # Only `phase_*` directories (e.g. `phase_1_300s`) hold results.
            if not phase_dir.startswith("phase_") or not os.path.isdir(
                full_phase_dir_path
            ):
                continue
            campaign = _summarize_phase(name, full_phase_dir_path)
            if campaign is not None:
                # "phase_1_300s" -> "300s"
                target_results[phase_dir.split("_")[-1]] = campaign

    tex_output = (
        TEX_HEADER
        + generate_table_ii(results)
        + "\n\n"
        + generate_figure_2(results)
        + TEX_FOOTER
    )
    if _compile_pdf(tex_output, args.output):
        print(f"PDF compiled successfully and copied to: {args.output}")
    else:
        print(
            f"PDF compiled with LaTeX errors and copied to: {args.output}",
            file=sys.stderr,
        )


if __name__ == "__main__":
    main()
