#! /usr/bin/env python3

import os
from importlib.util import spec_from_file_location, module_from_spec
from pysmt.shortcuts import reset_env

from utils import pysmt_dump_whole_expr, set_solvers, set_verbosity
from smv_prefixes import PROP_FIND_LOOP
import argparse
import automata_composition
import find_composition
import check_composition

# used for profiling
# import yappi


def getopts(argv=None):
    """Parse the command-line options of the test driver.

    :param argv: list of argument strings to parse. When ``None`` (the
        default) argparse falls back to ``sys.argv[1:]``, so existing
        zero-argument callers behave exactly as before.
    :return: the parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--nuxmv", required=True,
                   help="path to the nuXmv executable")
    # `clear` already defaults to True, so `--clear` only restates the
    # default; `--no-clear` is the switch that actually changes it.
    p.add_argument("--clear", action='store_true', default=True,
                   help="delete the generated temporary files (default)")
    p.add_argument("--no-clear", action="store_false", dest="clear",
                   help="keep the generated temporary files")
    p.add_argument("-o", "--output")
    # p.add_argument("-s", "--solver", choices=['z3', 'msat', 'z3-msat',
    #                                           'msat-z3'],
    #                default=['msat-z3'])
    p.add_argument('--smv-comp-timeout', type=int, default=5,
                   help="timeout passed to automata_composition")
    p.add_argument('--check-comp-timeout', type=int, default=15,
                   help="timeout passed to check_composition")
    p.add_argument('--nuxmv-timeout', type=int, default=250,
                   help="nuXmv timeout passed to find_composition")
    p.add_argument('--find-comp-timeout', type=int, default=15,
                   help="timeout passed to find_composition")
    p.add_argument('-v', '--verbose', action='store_true')
    p.add_argument("--try-sync-product", action='store_true', default=False)
    p.add_argument("inputs", nargs="+",
                   help="test .py files and/or directories containing them")
    return p.parse_args(argv)


def _collect_test_files(inputs):
    """Expand the command-line *inputs* into a sorted list of test modules.

    Each input is either a ``.py`` file or a directory. For a directory,
    every regular ``.py`` file directly inside it (no recursion) whose name
    does not start with ``_`` is collected.

    :param inputs: iterable of file and/or directory paths.
    :return: list of ``(label, absolute_path)`` pairs, sorted, where the
        label is the file name without its ``.py`` suffix.
    :raises ValueError: if a file input does not end in ``.py``.
    :raises FileNotFoundError: if an input is neither a file nor a directory.
    """
    test_files = []
    for raw_path in inputs:
        curr_path = os.path.abspath(raw_path)
        if os.path.isfile(curr_path):
            # Explicit raise instead of assert: asserts vanish under -O and
            # a non-.py name would then be silently mangled into a label.
            if not curr_path.endswith(".py"):
                raise ValueError("Not a .py file: {}".format(curr_path))
            test_files.append((os.path.basename(raw_path)[:-3], curr_path))
        elif os.path.isdir(curr_path):
            for f_name in os.listdir(curr_path):
                f_path = os.path.join(curr_path, f_name)
                if os.path.isfile(f_path) and f_name.endswith(".py") and \
                   not f_name.startswith("_"):
                    test_files.append((f_name[:-3], f_path))
        else:
            raise FileNotFoundError("{} does not exist".format(curr_path))
    test_files.sort()
    return test_files


def _remove_files(paths):
    """Delete every file in *paths*, ignoring the ones that do not exist.

    Each removal is attempted independently, so a missing file does not
    prevent the remaining ones from being deleted.
    """
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass


def main(opts):
    """Configure the composition modules and run every requested test.

    For each test module found in ``opts.inputs`` this writes a nuXmv
    command script, loads the module, and calls its
    ``test(nuxmv, model_file, trace_file, cmd_file, composition_file)``.
    A return value of ``True`` is reported as SUCCESS, ``False`` as FAILURE,
    and anything else as UNKNOWN.

    :param opts: namespace produced by ``getopts()``.
    :raises FileNotFoundError: if ``opts.nuxmv`` is not a file, or an input
        path does not exist.
    :raises ValueError: if an input file does not end in ``.py``.
    """
    # if opts.solver == "z3":
    #     set_solvers(["z3"])
    # elif opts.solver == "msat":
    #     set_solvers(["msat"])
    # elif opts.solver == "z3-msat":
    #     set_solvers(["z3", "msat"])
    # elif opts.solver == "msat-z3":
    #     set_solvers(["msat", "z3"])

    automata_composition.set_timeout(opts.smv_comp_timeout)
    check_composition.set_timeout(opts.check_comp_timeout)
    find_composition.set_timeout(opts.find_comp_timeout)
    find_composition.set_nuxmv_timeout(opts.nuxmv_timeout)
    find_composition.set_try_sync_product(opts.try_sync_product)
    set_verbosity(opts.verbose)

    pysmt_dump_whole_expr()
    # The single `{}` is filled with PROP_FIND_LOOP right away. The escaped
    # `{{}}` survives as `{}` and later receives each test's trace path.
    ic3_cmd_template = """set on_failure_script_quits 1
go_msat
check_invar_ic3 -P {}
show_traces -v -o {{}}
quit
""".format(PROP_FIND_LOOP)
    bmc_cmd_template = """set on_failure_script_quits 1
go_msat
msat_check_invar_bmc -k 30 -a falsification -P {}
show_traces -v -o {{}}
quit
""".format(PROP_FIND_LOOP)
    # Tests that are run with the BMC script instead of the IC3 one.
    bmc_labels = {"bench_2a", "10-bouncing_ball_harmonic"}

    nuxmv = os.path.abspath(opts.nuxmv)
    if not os.path.isfile(nuxmv):
        raise FileNotFoundError("Not a file: {}".format(nuxmv))

    test_files = _collect_test_files(opts.inputs)
    # NOTE(review): paths are keyed by label only, so two test files with
    # the same name in different directories share (and overwrite) these.
    model_path = "/tmp/{}_model.smv"
    trace_path = "/tmp/{}_trace.txt"
    invar_ic3_path = "/tmp/{}_invar_ic3.cmd"
    composition_path = "/tmp/{}_composed.smv"

    num_tests = len(test_files)
    for test_num, (label, test_file) in enumerate(test_files, start=1):
        model_file = model_path.format(label)
        trace_file = trace_path.format(label)
        cmd_file = invar_ic3_path.format(label)
        composition_file = composition_path.format(label)

        # write cmd file
        cmd_template = (bmc_cmd_template if label in bmc_labels
                        else ic3_cmd_template)
        with open(cmd_file, 'w') as out:
            out.write(cmd_template.format(trace_file))

        # load test file.
        test_spec = spec_from_file_location("test", test_file)
        test_module = module_from_spec(test_spec)
        test_spec.loader.exec_module(test_module)

        # execute test function.
        print("\n\nrun {}/{}: `{}`".format(test_num, num_tests, label))
        res = test_module.test(nuxmv, model_file, trace_file, cmd_file,
                               composition_file)
        # yappi.stop()
        # Identity checks on purpose: only the real booleans count, so a
        # truthy/falsy non-bool result is reported as UNKNOWN.
        if res is True:
            result_str = "SUCCESS"
        elif res is False:
            result_str = "FAILURE"
        else:
            result_str = "UNKNOWN"
        if opts.clear:
            _remove_files((model_file, trace_file, cmd_file,
                           composition_file))
        print("end of `{}`: {}".format(label, result_str))

        # delete all previously generated symbols and expressions.
        reset_env()
    # func_stats = yappi.get_func_stats()
    # func_stats.save("callgrind.out", "CALLGRIND")
    # yappi.clear_stats()


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the tests.
    cli_opts = getopts()
    main(cli_opts)
