import curses
from io import TextIOWrapper
import os
import subprocess
import shutil
import copy
import string
import time
import json
from argparse import ArgumentParser, Namespace
from random import Random
import dill as pickle
from typing import NamedTuple
from lark import ParseTree, Token, UnexpectedCharacters
from lark import ParseTree as Tree
from typing import cast
import mlir  # type: ignore

from mlirmut.match_template import (
    MatchTemplate,
    instantiate_template,
    extract_params,
)

from mlirmut.utils import get_target_node, load_templates
from mlirmut.tree_utils import TreeView, anchor_pattern, exact_tree_match

# Type aliases.
# A corpus entry pairs the path of a seed file with its parsed tree.
CorpusEntry = tuple[str, ParseTree]
# The corpus is the list of all successfully parsed seed inputs.
Corpus = list[CorpusEntry]


def random_update_and_apply_template(
    seed_tree: Tree, template: MatchTemplate, rand: Random
):
    """Randomly specialize/generalize a template, then apply it to a seed copy.

    The seed tree is deep-copied, so ``seed_tree`` itself is not modified.
    The template's masked context is first specialized step by step for as
    long as it still anchors in the copy, then generalized step by step
    until it anchors. Every step is wrapped in the template's
    ``start_update`` / ``commit_update`` / ``abort_update`` calls, so the
    ``template`` object passed in is presumably left modified by this call
    (verify against ``MatchTemplate``).

    Args:
        seed_tree: Parse tree to mutate; only a deep copy is touched.
        template: Template whose context is adjusted and whose fragment is
            applied.
        rand: Random source handed to the specialize/generalize steps.

    Returns:
        The mutated copy of ``seed_tree``, or ``None`` when the template
        cannot be generalized any further and still does not anchor.
    """
    tree = copy.deepcopy(seed_tree)

    anchoring = anchor_pattern(tree=tree, pattern=template.masked_context)

    # specialize as far as we can
    while True:
        template.start_update()
        if not template.specialize_context(rand=rand):
            # if we cannot specialize further, then we stop specializing
            template.commit_update()
            break
        # try anchoring again with the specialized template
        anchoring = anchor_pattern(tree=tree, pattern=template.masked_context)
        if anchoring is None:
            # if specializing the pattern causes anchoring to fail,
            # then we discard changes and stop specializing
            template.abort_update()
            # anchor again with the original template
            anchoring = anchor_pattern(tree=tree, pattern=template.masked_context)
            break
        # otherwise, we save the update and continue specializing
        template.commit_update()

    # generalize the template until we can anchor it
    while anchoring is None:
        template.start_update()
        if not template.generalize_context(rand=rand):
            # if we cannot generalize further, then exit
            template.abort_update()
            # we return here since we cannot anchor
            return None
        template.commit_update()
        # if generalizing succeeds, then we try anchoring again
        anchoring = anchor_pattern(tree=tree, pattern=template.masked_context)

    # apply the template: locate the node addressed by the fragment path
    # (`.original` is presumably the underlying tree node behind the view)
    node_to_replace = get_target_node(
        root=anchoring.tree_view, path=template.masked_fragment_path
    ).original

    if isinstance(node_to_replace, Token):
        # The target is a token, so the edit happens on its parent: the
        # parent is addressed by the path without its last element, and the
        # last element is the child index.
        parent_node = get_target_node(
            root=anchoring.tree_view, path=template.masked_fragment_path[:-1]
        ).original
        insert_idx = template.masked_fragment_path[-1]

        # The two slices together keep every existing child, so the fragment
        # is inserted in front of the token rather than replacing it.
        prior_siblings = parent_node.children[:insert_idx]
        later_siblings = parent_node.children[insert_idx:]

        new_children = prior_siblings + [template.fragment] + later_siblings
        parent_node.set(parent_node.data, new_children)
    else:
        # The target is a subtree: instantiate the template with the
        # parameters extracted from the anchored context and overwrite the
        # node in place with the instance's data and children.
        params = extract_params(
            concrete=anchoring.tree_view, context=template.masked_context
        )
        instance = instantiate_template(template=template, params=params)
        node_to_replace.set(instance.data, instance.children)

    return tree


def update_and_apply_template(seed_tree: Tree, template: MatchTemplate, disable_params=False, initial_anchoring=None):
    """Apply a template to a deep copy of the seed tree.

    Args:
        seed_tree: Parse tree to mutate; only a deep copy is modified here.
        template: Template providing the masked context, the fragment path
            and the fragment to apply.
        disable_params: When true, the template is not specialized and the
            fragment is instantiated with an empty parameter mapping.
        initial_anchoring: Optional precomputed anchoring. When omitted, the
            template's masked context is anchored against the copied tree.

    Returns:
        A ``(tree, params)`` pair. ``tree`` is the mutated copy, or ``None``
        when the template does not anchor (``params`` is then empty).
        ``params`` holds the parameters used for instantiation and stays
        empty when the target node is a token or parameters are disabled.

    Raises:
        ValueError: If specialization yields no anchoring.
    """
    mutated = copy.deepcopy(seed_tree)

    if initial_anchoring is None:
        initial_anchoring = anchor_pattern(tree=mutated, pattern=template.masked_context)
    if initial_anchoring is None:
        # the template does not match this seed at all
        return None, {}

    # without parameters the initial anchoring is used as-is; otherwise the
    # template is specialized as far as possible first
    anchoring = (
        initial_anchoring
        if disable_params
        else template.fully_specialize(
            target_tree=mutated, initial_anchoring=initial_anchoring
        )
    )
    if anchoring is None:
        raise ValueError(
            "Failed to anchor template after specialization. "
            "Specialization left the template in an invalid state."
        )

    target = get_target_node(
        root=anchoring.tree_view, path=template.masked_fragment_path
    ).original

    if isinstance(target, Token):
        # token target: splice the fragment into the parent's child list at
        # the index given by the last path element, keeping all siblings
        parent = get_target_node(
            root=anchoring.tree_view, path=template.masked_fragment_path[:-1]
        ).original
        position = template.masked_fragment_path[-1]
        siblings = parent.children
        parent.set(
            parent.data,
            siblings[:position] + [template.fragment] + siblings[position:],
        )
        return mutated, {}

    # subtree target: overwrite the node with an instantiated fragment
    params = (
        {}
        if disable_params
        else extract_params(
            concrete=anchoring.tree_view, context=template.masked_context
        )
    )
    instance = instantiate_template(template=template, params=params)
    target.set(instance.data, instance.children)
    return mutated, params


# Outcome record of one mutation attempt: a return code plus the captured
# standard output and standard error text.
Status = NamedTuple(
    "Status",
    [
        ("returncode", int),
        ("stdout", str),
        ("stderr", str),
    ],
)


def fuzz_one(
    seed_input: Tree,
    template: MatchTemplate,
    mlir_parser: mlir.Parser,
    corpus_hashes: set[int],
):
    """Mutate one seed with one template and run the compiler on the result.

    Args:
        seed_input: Parsed seed tree to mutate.
        template: Template to apply to the seed.
        mlir_parser: Parser whose transformer turns the mutated tree back
            into MLIR text.
        corpus_hashes: Hashes of trees already in the corpus; a mutant whose
            hash is in this set is rejected as a duplicate.

    Returns:
        A ``(mutant_str, status)`` pair. ``mutant_str`` is ``None`` whenever
        the compiler was not run to completion; ``status`` then carries one
        of the sentinel codes -9991 (mutation failed), -9992 (duplicate),
        -9990 (mutant could not be dumped) or -9993 (timeout). Otherwise
        ``status`` is the completed-process object from ``subprocess.run``.
    """
    # update_and_apply_template returns a (tree, params) pair; the tree must
    # be unpacked, otherwise the None check and the duplicate check below
    # would operate on the tuple instead of the mutant.
    mutant, _params = update_and_apply_template(
        seed_tree=seed_input, template=template
    )
    if mutant is None:
        return None, Status(-9991, "", "Failed to Mutate")
    if hash(mutant) in corpus_hashes:
        return None, Status(-9992, "", "Mutant already in corpus")
    try:
        mutant_str = mlir_parser.transformer.transform(mutant).dump()
    except Exception:
        # Any failure while rebuilding the text marks the mutant as invalid;
        # KeyboardInterrupt/SystemExit are deliberately not swallowed.
        return None, Status(-9990, "", "Invalid Syntax")
    # "-" is passed in place of an input file; the mutant is fed via stdin
    command = ["/path/to/Projects/onnx-mlir-2022/build/Debug/bin/onnx-mlir", "-"]
    try:
        status = subprocess.run(
            command, input=mutant_str.encode(), capture_output=True, timeout=5
        )
    except subprocess.TimeoutExpired:
        return None, Status(-9993, "", "Timed out")
    return mutant_str, status


def generate_one(
    seed_input: Tree,
    template: MatchTemplate,
    mlir_parser: mlir.Parser,
    disable_params: bool
):
    """Mutate one seed with one template without running the compiler.

    Args:
        seed_input: Parsed seed tree to mutate.
        template: Template to apply to the seed.
        mlir_parser: Parser whose transformer turns the mutated tree back
            into MLIR text.
        disable_params: Forwarded to ``update_and_apply_template``.

    Returns:
        A ``(mutant_str, status, params)`` triple. ``mutant_str`` is ``None``
        when mutation failed (status code -9991) or the mutant could not be
        dumped (status code -9990); on success the status code is 0.
    """
    mutant, params = update_and_apply_template(
        seed_tree=seed_input, template=template, disable_params=disable_params
    )
    if mutant is None:
        return None, Status(-9991, "", "Failed to Mutate"), params
    try:
        mutant_str = mlir_parser.transformer.transform(mutant).dump()
    except Exception:
        # Any failure while rebuilding the text marks the mutant as invalid;
        # KeyboardInterrupt/SystemExit are deliberately not swallowed so the
        # generation loop can still be interrupted.
        return None, Status(-9990, "", "Invalid Syntax"), params
    return mutant_str, Status(0, "", ""), params


def get_last_modified_time(path):
    """Return the most recent modification time found at ``path``.

    Args:
        path: A file or directory path.

    Returns:
        For a regular file, its own modification time. For a directory, the
        newest modification time among its direct entries (not recursive),
        or 0 when the directory is empty. For a path that does not exist,
        0, so callers can treat a missing path as "older than everything".
    """
    if os.path.isfile(path):
        return os.stat(path).st_mtime
    if not os.path.exists(path):
        # e.g. a cache file that has not been written yet
        return 0

    # the context manager closes the directory iterator promptly
    with os.scandir(path) as entries:
        return max([0, *(entry.stat().st_mtime for entry in entries)])


def init_screen(stdscr):
    """Draw the static layout of the fuzzing monitor.

    Args:
        stdscr: Curses window to draw on.
    """
    # (row, text) pairs: the title on row 0, then the labels of the fields
    # that are filled in later
    layout = (
        (0, "Fuzzing Monitor"),
        (2, "Number of Iterations:"),
        (3, "Avg Time Per Iteration (ms):"),
        (4, "Return Codes:"),
    )

    stdscr.clear()
    for row, text in layout:
        stdscr.addstr(row, 0, text)
    stdscr.refresh()


def update_stats(
    stdscr, fuzz_count: int, return_code_counts: dict[int, int], avg_time: float
):
    """Refresh the dynamic fields of the fuzzing monitor.

    Args:
        stdscr: Curses window to draw on.
        fuzz_count: Iteration count to display.
        return_code_counts: Mapping from return code to the number of times
            it was observed; one line is drawn per entry, in mapping order.
        avg_time: Average time per iteration in milliseconds.
    """
    # Human-readable names for the sentinel codes; any other key (real
    # process return codes, or keys that are already text) is shown as-is.
    labels = {
        -9990: "Invalid Syntax",
        -9991: "Failed to Mutate",
        -9992: "Duplicate Mutant",
        -9993: "Timed Out",
    }

    stdscr.addstr(2, 23, str(fuzz_count))
    stdscr.addstr(3, 30, str(avg_time))
    # the average can get shorter between updates; erase leftover digits
    stdscr.clrtoeol()
    for idx, (ret_code, count) in enumerate(return_code_counts.items()):
        label = labels.get(ret_code, ret_code)
        stdscr.addstr(5 + idx, 4, f"{label}: {count}")
    # park the cursor on the empty line under the title
    stdscr.move(1, 0)
    stdscr.refresh()


def generate(
    log_file: TextIOWrapper,
    args: Namespace,
    rand: Random,
    corpus: Corpus,
    template_paths: list[str],
    stdscr,
):
    """Generate mutants in a loop without running the compiler on them.

    Each iteration picks a random seed and a random template, applies the
    template, optionally saves the mutant text and the parameters used, and
    appends one tab-separated row to the log.

    Args:
        log_file: Open text file receiving the TSV log.
        args: Parsed command line options; reads ``max_iterations``,
            ``disable_params``, ``save_mutants``, ``mutant_path``,
            ``save_params`` and ``params_path``.
        rand: Random source for picking seeds and templates.
        corpus: Parsed seed inputs to choose from.
        template_paths: Paths of pickled templates to choose from.
        stdscr: Curses window used for the progress display.
    """
    log_file.write("seed\ttemplate\tmutant_filename\tstatus\n")
    mlir_parser = mlir.Parser()
    gen_count = 0
    total_time = 0
    return_code_counts = dict()
    while True:
        if args.max_iterations >= 0 and gen_count >= args.max_iterations:
            print("Reached max iterations, exiting...")
            break
        start_time = time.time()
        seed_filepath, seed_input = rand.choice(corpus)
        template_path = rand.choice(template_paths)
        # NOTE: unpickling executes arbitrary code; template files must come
        # from a trusted source.
        with open(template_path, "rb") as template_file:
            template = pickle.load(template_file)
        mutant_str, status, params = generate_one(
            seed_input=seed_input,
            template=template,
            mlir_parser=mlir_parser,
            disable_params=args.disable_params,
        )
        out_filename = f"{gen_count}.mlir"
        if args.save_mutants and mutant_str is not None:
            with open(os.path.join(args.mutant_path, out_filename), "w") as mutant_file:
                mutant_file.write(mutant_str)
        if args.save_params:
            with open(os.path.join(args.params_path, f"{gen_count}.pkl"), "wb") as params_file:
                pickle.dump(params, params_file)

        columns = f"{seed_filepath}\t{template_path}\t{out_filename}\t{status.returncode}\n"
        elapsed_time = time.time() - start_time
        total_time += elapsed_time
        log_file.write(columns)
        # keep the log usable even if the run is interrupted
        log_file.flush()

        # tally the outcome so the monitor has something to show
        return_code_counts[status.returncode] = (
            return_code_counts.get(status.returncode, 0) + 1
        )
        # this iteration is finished, so it counts towards the average
        completed = gen_count + 1
        update_stats(
            stdscr,
            completed,
            return_code_counts,
            avg_time=total_time / completed * 1000,
        )
        gen_count = completed


def fuzz(
    log_file: TextIOWrapper,
    args: Namespace,
    rand: Random,
    corpus: Corpus,
    mlir_parser: mlir.Parser,
    corpus_hashes: set[int],
    template_paths: list[str],
    stdscr,
):
    """Run the fuzzing loop: mutate, run the compiler, record the outcome.

    Each iteration picks a random seed and a random template, runs
    ``fuzz_one`` on them, optionally saves the mutant text, updates the
    progress display and appends one tab-separated row to the log.

    Args:
        log_file: Open text file receiving the TSV log.
        args: Parsed command line options; reads ``max_iterations``,
            ``save_mutants``, ``mutant_path`` and ``verbose_logging``.
        rand: Random source for picking seeds and templates.
        corpus: Parsed seed inputs to choose from.
        mlir_parser: Parser forwarded to ``fuzz_one``.
        corpus_hashes: Hashes of the corpus trees, forwarded to ``fuzz_one``
            for duplicate detection.
        template_paths: Paths of pickled templates to choose from.
        stdscr: Curses window used for the progress display.
    """
    gen_count = 0
    return_code_counts = dict()
    total_time = 0
    while True:
        if args.max_iterations >= 0 and gen_count >= args.max_iterations:
            print("Reached max iterations, exiting...")
            break
        start_time = time.time()
        seed_filepath, seed_input = rand.choice(corpus)
        template_path = rand.choice(template_paths)
        # NOTE: unpickling executes arbitrary code; template files must come
        # from a trusted source.
        with open(template_path, "rb") as template_file:
            template = pickle.load(template_file)
        mutant_str, status = fuzz_one(
            seed_input=seed_input,
            template=template,
            mlir_parser=mlir_parser,
            corpus_hashes=corpus_hashes,
        )
        out_filename = f"{gen_count}.mlir"
        if args.save_mutants and mutant_str is not None:
            with open(os.path.join(args.mutant_path, out_filename), "w") as mutant_file:
                mutant_file.write(mutant_str)
        elapsed_time = time.time() - start_time
        total_time += elapsed_time

        # tally the outcome; a missing status is counted as a failed mutation
        count_key = status.returncode if status is not None else "Failed to Mutate"
        return_code_counts[count_key] = return_code_counts.get(count_key, 0) + 1

        # this iteration is finished, so it counts towards the average
        completed = gen_count + 1
        update_stats(
            stdscr,
            completed,
            return_code_counts,
            avg_time=total_time / completed * 1000,
        )

        common_columns = f"{seed_filepath}\t{template_path}\t{elapsed_time}"

        if args.verbose_logging:
            status_columns = (
                f"{status.returncode}\t{repr(status.stdout)}\t{repr(status.stderr)}"
                if status is not None
                else "N/A\tN/A\tN/A"
            )
            log_file.write(f"{common_columns}\t{out_filename}\t{status_columns}\n")
        else:
            return_code = status.returncode if status is not None else "N/A"
            log_file.write(f"{common_columns}\t{return_code}\n")

        # keep the log usable even if the run is interrupted
        log_file.flush()
        gen_count = completed


def main(stdscr):
    """Entry point: parse options, load templates and seeds, then run.

    Args:
        stdscr: Curses window used for the optional progress display.
    """
    # Parse command line arguments
    parser = ArgumentParser()
    parser.add_argument("--mutant-path", type=str, default="data/mutants4")
    parser.add_argument("--seeds", type=str, default="data/4-processed-test-cases")
    parser.add_argument(
        "--templates", type=str, default="data/5-templates/general_first"
    )
    parser.add_argument("--cache-path", type=str, default="data/seed_cache.pkl")
    parser.add_argument("--rand-seed", type=str, default="mlirfuzzer2023")
    parser.add_argument("--log-path", type=str, default="data/log.v2.tsv")
    parser.add_argument("--params-path", type=str, default="params")
    parser.add_argument("--max-iterations", type=int, default=-1)
    parser.add_argument("--save-params", action="store_true")
    parser.add_argument("--generate-only", action="store_true")
    parser.add_argument("--ignore-cache", action="store_true")
    parser.add_argument("--save-mutants", action="store_true")
    parser.add_argument("--verbose-logging", action="store_true")
    parser.add_argument("--display-progress", action="store_true")
    parser.add_argument("--disable-params", action="store_true")
    args = parser.parse_args()

    # initialize MLIR parser
    mlir_parser = mlir.Parser()

    # initialize curses status screen
    if args.display_progress:
        init_screen(stdscr)

    # initialize RNG
    rand = Random(args.rand_seed)

    # retrieve list of available templates: every file inside every
    # sub-directory of the templates root
    template_paths = []
    with os.scandir(args.templates) as template_dirs:
        for directory in template_dirs:
            if not directory.is_dir():
                # stray files in the templates root are not template groups
                continue
            template_paths += [
                os.path.join(directory.path, filename)
                for filename in os.listdir(directory.path)
            ]
    # sort so that the RNG picks are reproducible for a given seed
    template_paths.sort()

    # retrieve cached seed inputs or parse them if they have been modified;
    # the existence check comes first so that a missing cache file (first
    # run) is never stat'ed
    use_cache = (
        not args.ignore_cache
        and os.path.exists(args.cache_path)
        and get_last_modified_time(args.cache_path)
        >= get_last_modified_time(args.seeds)
    )
    if use_cache:
        # NOTE: unpickling executes arbitrary code; the cache file must come
        # from a trusted source.
        with open(args.cache_path, "rb") as corpus_file:
            corpus = pickle.load(corpus_file)
    else:
        corpus = list()
        # deterministically load seed inputs
        seed_filenames = os.listdir(args.seeds)
        seed_filenames.sort()
        decode_errors = []
        parse_errors = []
        for seed_filename in seed_filenames:
            seed_filepath = os.path.join(args.seeds, seed_filename)
            with open(seed_filepath) as seed_file:
                try:
                    code = seed_file.read()
                except UnicodeDecodeError:
                    decode_errors.append(seed_filepath)
                    continue
            # textual rewrites applied to every seed before parsing
            code = code.replace(" floordiv ", "&floordiv&")
            code = code.replace(" ceildiv ", "&ceildiv&")
            code = code.replace(" mod ", "&mod&")
            code = code.replace(
                '"onnx.NoValue"() {value}', '"onnx.NoValue"() {value="unit"}'
            )
            try:
                tree = mlir_parser.parser.parse(code)
            except UnexpectedCharacters:
                parse_errors.append(seed_filepath)
            else:
                corpus.append((seed_filepath, tree))
        if len(decode_errors) > 0:
            print(f"Warning: Could not decode {len(decode_errors)} files. Saved to failed_to_decode_files.log.")
            with open("failed_to_decode_files.log", "w") as error_log:
                error_log.write("\n".join(decode_errors))
        if len(parse_errors) > 0:
            print(f"Warning: Could not parse {len(parse_errors)} files. Saved to failed_to_parse_files.log.")
            with open("failed_to_parse_files.log", "w") as error_log:
                error_log.write("\n".join(parse_errors))
        with open(args.cache_path, "wb") as corpus_cache_file:
            pickle.dump(corpus, corpus_cache_file)
    corpus_hashes = {hash(tree) for _, tree in corpus}

    # setup folders
    os.makedirs(args.mutant_path, exist_ok=True)
    os.makedirs(args.params_path, exist_ok=True)

    # begin fuzzing
    with open(args.log_path, "w") as log_file:
        if args.generate_only:
            generate(
                log_file=log_file,
                args=args,
                stdscr=stdscr,
                template_paths=template_paths,
                corpus=corpus,
                rand=rand,
            )
        else:
            fuzz(
                log_file=log_file,
                args=args,
                rand=rand,
                corpus=corpus,
                mlir_parser=mlir_parser,
                corpus_hashes=corpus_hashes,
                template_paths=template_paths,
                stdscr=stdscr,
            )


# Run only when executed as a script. curses.wrapper sets up the terminal,
# calls main with the screen window, and restores the terminal afterwards,
# including when main raises.
if __name__ == "__main__":
    curses.wrapper(main)
