import os
import subprocess
import time
import signal

from automaton import Automaton
from automata_composition import smv_composition, AGAutomaton
from parse_trace import Trace
from check_composition import check_composition
from utils import to_next, Solver, solve_with_timeout, reset_after_timeout, \
    get_solvers, set_solver
from smv_prefixes import LOC_VAR, LOC_ID_PREFIX, IS_PREFIX, ENABLED, \
    LOOPBACK, MODULE_INST_PREF, FAIR, PREFIX_LENGTH, TRANS_LABEL

from pysmt.shortcuts import Not, And, Or, TRUE, FALSE
from pysmt.fnode import FNode
from pysmt.exceptions import SolverReturnedUnknownResultError


# Timeout value handed to `solve_with_timeout` by the SMT checks below.
TIMEOUT = 15


def set_timeout(val):
    """Replace the module-wide SMT-query timeout `TIMEOUT` with `val`."""
    global TIMEOUT
    TIMEOUT = val


# Value passed to `runlim --time-limit=` when nuXmv is launched.
NUXMV_TIMEOUT = 250


def set_nuxmv_timeout(val):
    """Replace the module-wide nuXmv time limit `NUXMV_TIMEOUT` with `val`."""
    global NUXMV_TIMEOUT
    NUXMV_TIMEOUT = val


# When true, `find_composition` also checks the synchronous product.
TRY_SYNC_PRODUCT = False


def set_try_sync_product(val):
    """Replace the module-wide flag `TRY_SYNC_PRODUCT` with `val`."""
    global TRY_SYNC_PRODUCT
    TRY_SYNC_PRODUCT = val


def run_nuxmv(exe: str, model_file: str, cmd_file: str) -> bool:
    """Run the nuXmv executable `exe` on `model_file`, sourcing the commands
    in `cmd_file` (`-source cmd_file`).

    The process is launched through `runlim` with `--time-limit` set to
    NUXMV_TIMEOUT and runlim's `-o` output sent to /dev/null.

    Returns True if the command exits with status 0; otherwise prints the
    command line and its combined stdout/stderr and returns False.
    """
    assert os.path.isfile(exe), "Not a file: {}".format(exe)
    assert os.path.isfile(model_file), "Not a file: {}".format(model_file)
    assert os.path.isfile(cmd_file), "Not a file: {}".format(cmd_file)

    cmd = ["runlim", "-o", "/dev/null", "--time-limit={}".format(NUXMV_TIMEOUT),
           exe, "-source", cmd_file, model_file]
    res = subprocess.run(cmd,
                         stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    if res.returncode != 0:
        print("Command: {}".format(" ".join(res.args)))
        # output is captured as bytes: decode it so it prints as text.
        print("stdout:\n{}".format(res.stdout.decode(errors="replace")))
        return False

    return True


def composition_from_trace(trace: Trace, _automata: list) -> tuple:
    """Build the candidate composition described by a nuXmv lasso trace.

    Only the automata of `_automata` whose ENABLED flag is set in the first
    frame take part. The prefix frames are skipped (their number is read
    from the PREFIX_LENGTH variable of the first frame). Each remaining frame
    except the last one becomes a location of the composition; the last frame
    is equal to the loop-back frame, so the last location jumps back to the
    first frame flagged LOOPBACK. Frames flagged FAIR become fair locations.
    For every location and every enabled automaton, the negated assumption of
    that automaton in the next frame is checked for satisfiability together
    with the current invariant and the transitions of the automata owning
    the assumption's symbols.

    Returns a 4-tuple `(aut, automata, errors, undefs)`:
    - `aut` is the composed `Automaton` named "main";
    - `automata` is the list of enabled automata;
    - `errors` lists the assume-guarantee violations; each one is a list of
      triples `(automaton, location_id, trans_id)`: the automaton whose
      assumption fails, followed by the automata owning its symbols;
    - `undefs` lists, for each check the SMT-solver could not decide, the
      solver assertions after `reset_after_timeout`.

    Raises ValueError if `trace` is empty.
    """
    def get_loc(frame: dict, a: AGAutomaton) -> int:
        """Location index of `a` in `frame` (LOC_ID_PREFIX stripped)."""
        loc = "{}{}.{}".format(MODULE_INST_PREF, a.name, LOC_VAR)
        return int(frame[loc][len(LOC_ID_PREFIX):])

    def get_trans_id(frame: dict, a: AGAutomaton) -> int:
        """Value of the TRANS_LABEL variable of `a` in `frame`."""
        return frame["{}{}.{}".format(MODULE_INST_PREF, a.name, TRANS_LABEL)]

    if len(trace) == 0:
        # callers unpack four values: fail loudly instead of returning None.
        raise ValueError("empty trace: no candidate composition")
    init_frame = trace[0]
    # consider only enabled automata.
    automata = [a for a in _automata
                if init_frame["{}{}.{}".format(MODULE_INST_PREF,
                                               a.name, ENABLED)]]
    assert len(automata) > 0
    symbols = automata[0].all_symbols
    # map every local symbol to the (unique) automaton owning it.
    symb_to_aut = {}
    for a in automata:
        for s in a.local_symbols:
            assert s not in symb_to_aut
            symb_to_aut[s] = a

    # skip prefix
    first_frame_idx = int(init_frame[PREFIX_LENGTH])
    assert not trace[first_frame_idx][IS_PREFIX]
    # last frame is equal to loop-back frame: it is not a new location.
    loop_frames = trace[first_frame_idx:-1]
    num_locations = len(trace) - first_frame_idx - 1
    aut = Automaton("main", symbols, num_locations)
    errors = []
    undefs = []

    loopback_loc_idx = None
    assert len(loop_frames) == aut.num_locs
    with Solver() as solver:
        for loc_idx, frame in enumerate(loop_frames):
            assert not frame[IS_PREFIX]

            # the first frame flagged LOOPBACK is where the lasso closes.
            if frame[LOOPBACK] and loopback_loc_idx is None:
                loopback_loc_idx = loc_idx

            if frame[FAIR]:
                aut.add_fair(loc_idx)

            # the last location jumps back to the loop-back location.
            next_loc_idx = (loc_idx + 1 if loc_idx < aut.num_locs - 1
                            else loopback_loc_idx)
            assert next_loc_idx is not None, \
                "loc: {}, loopback: {}".format(loc_idx, loopback_loc_idx)
            next_frame = trace[first_frame_idx + next_loc_idx]

            # extract location of each automaton in `automata` in current frame
            curr_locs = [get_loc(frame, a) for a in automata]

            # build invariant associated with current location of composition
            invar = And(*[And(a.get_invar(l), a.get_assume(l))
                          for a, l in zip(automata, curr_locs)])
            aut.set_invar(loc_idx, invar)

            # build transition: conjunction of the step of every automaton.
            trans = And(*[a.get_transition(l, get_loc(next_frame, a),
                                           get_trans_id(frame, a))
                          for a, l in zip(automata, curr_locs)])
            aut.set_trans(loc_idx, {next_loc_idx: [trans]})

            solver.push()
            solver.add_assertion(invar)

            # check assume guarantee
            for a in automata:
                solver.push()
                assume = a.get_assume(get_loc(next_frame, a))
                assume_symbs = assume.get_free_variables()
                # here we consider only automata directly involved in the assume.
                related_automata = [symb_to_aut[s] for s in assume_symbs]
                assume = to_next(assume, assume_symbs)
                solver.add_assertion(Not(assume))
                aut_loc_trans_tuples = [(a, get_loc(frame, a),
                                         get_trans_id(frame, a))]
                # add the transitions of the automata owning the symbols.
                for rel_a in related_automata:
                    rel_a_loc = get_loc(frame, rel_a)
                    rel_a_n_loc = get_loc(next_frame, rel_a)
                    rel_a_trans_id = get_trans_id(frame, rel_a)
                    aut_loc_trans_tuples.append((rel_a, rel_a_loc,
                                                 rel_a_trans_id))
                    rel_a_trans = rel_a.get_transition(rel_a_loc, rel_a_n_loc,
                                                       rel_a_trans_id)
                    solver.add_assertion(rel_a_trans)
                try:
                    res = solve_with_timeout(TIMEOUT, solver)
                except SolverReturnedUnknownResultError:
                    res = None
                if res is None:
                    # undecided query. NOTE(review): presumably
                    # reset_after_timeout rebuilds the solver state with
                    # `invar` asserted -- confirm against utils.
                    reset_after_timeout(solver, TRUE(), invar)
                    undefs.append(solver.assertions)
                elif res:
                    # SAT: the assumption can be violated by this step.
                    m = solver.get_model()
                    m = ["{} : {}".format(s, v) for s, v in m]
                    print("\tAssumption `{}` of `{}` not satisfied\n"
                          "\tSMT-Assertions: {}\n"
                          "\tModel:\n\t{}"
                          .format(assume, a.name, solver.assertions,
                                  "\n\t".join(m)))
                    errors.append(aut_loc_trans_tuples)
                solver.pop()  # remove Not(assume) and trans
            solver.pop()  # remove region

    return aut, automata, errors, undefs


def find_composition(automata: list, init: FNode, trans: FNode,
                     fairness: FNode, nuxmv_exe: str, model_file: str,
                     trace_file: str, cmd_file: str,
                     check_automata=True) -> tuple:
    """Search a composition of the AG-automata `automata` using nuXmv.

    Steps:
    1. if `check_automata`, check every AG-abstraction; if one is incorrect
       return `(None, [])` immediately;
    2. for each SMT-solver of `get_solvers()`: write the SMV reachability
       problem built by `smv_composition` to `model_file`, run `nuxmv_exe`
       with the commands in `cmd_file`, parse `trace_file` and check the
       candidate with `composition_from_trace`. Every assume-guarantee
       violation is excluded with an extra SMV constraint and the search is
       repeated, until a candidate without violations is found or nuXmv
       produces no trace. The next solver is tried only if something was
       left undecided;
    3. double-check a found composition with `check_composition`;
    4. if TRY_SYNC_PRODUCT is set and a composition was found, report whether
       the synchronous product of the selected automata implies `trans`.

    Returns `(comp, undefs)`: `comp` is the composed `Automaton`, or None if
    none was found. `undefs` is `["nuXmv"]` when nuXmv found no trace,
    otherwise the undecided queries reported by `composition_from_trace`
    for the last candidate (None if `get_solvers()` yields no solver).
    """
    trans = trans.simplify()
    fairness = fairness.simplify()

    if check_automata and not _check_ag_abstractions(automata):
        return None, []

    comp = None
    enabled = []  # automata enabled in the last parsed trace.
    undefs = None
    for solver in get_solvers():  # if one solver fails try the next one.
        print("\tUsing SMT-solver: {}\n".format(solver))
        set_solver(solver)
        extra_trans = []  # SMV constraints excluding known violations.
        comp = None
        undefs = None

        while comp is None:
            print("\tCreate reachability problem", flush=True)
            start = time.time()
            # write smv model to search composition,
            # giving also the additional constraints for search.
            with open(model_file, 'w') as out_buf:
                smv_composition(automata, init, fairness, trans, buf=out_buf,
                                extra_trans=extra_trans)
            print("\tReachability problem `{}` generated in {}s\n"
                  .format(model_file, time.time() - start))

            # remove any previous trace so that a stale one is never parsed.
            try:
                os.remove(trace_file)
            except OSError:
                pass
            print("\tRun nuXmv", flush=True)
            start = time.time()
            nuxmv_ok = run_nuxmv(nuxmv_exe, model_file, cmd_file)
            if not nuxmv_ok or not os.path.isfile(trace_file):
                print("\tnuXmv failed to identify a trace, terminated in: "
                      "{}s\n".format(time.time() - start))
                # reset cached fairness and transition masks.
                for a in automata:
                    a.reset_fairness_predicates()
                    a.reset_transition_predicates()
                undefs = ["nuXmv"]
                break
            print("\tnuXmv returned in {}s\n".format(time.time() - start))

            # parse trace and build composed model.
            print("\tparse trace", flush=True)
            start = time.time()
            trace = Trace()
            with open(trace_file, 'r') as in_buf:
                trace.parse_trace(in_buf)
            print("\tTrace `{}` parsed in {}s\n"
                  .format(trace_file, time.time() - start))

            print("\tcheck candidate composition assumptions", flush=True)
            start = time.time()
            comp, enabled, errors, undefs = \
                composition_from_trace(trace, automata)
            duration = time.time() - start
            for undef in undefs:
                print("\tUnable to decide satisfiability of:\n{}"
                      .format(undef))
            print("\tCandidate composition checked in {}s"
                  .format(duration))
            print("\tautomata: {}\n".format([a.name for a in enabled]))
            if errors:
                extra_trans.extend(_smv_constr_from_error(error)
                                   for error in errors if error)
                print("\tRefine search with: {}\n".format(extra_trans),
                      flush=True)
                comp = None

        if len(undefs) == 0:
            # no undefined result: exit, otherwise try other solver.
            break

    if comp is not None:
        print("\tAdditional correctness check on final result", flush=True)
        start = time.time()
        res = check_composition(comp, trans, fairness, enabled)
        duration = time.time() - start
        if res is True:
            print("\tComposition VERIFIED, {}s".format(duration), flush=True)
        elif res is None:
            print("\tComposition NOT VERIFIED, {}s".format(duration),
                  flush=True)
        elif res is False:
            print("\tComposition WRONG, {}s".format(duration),
                  flush=True)

        # only meaningful when a set of automata has actually been selected.
        if TRY_SYNC_PRODUCT:
            _check_sync_product(enabled, trans)

    return comp, undefs


def _smv_constr_from_error(error: list) -> str:
    r"""Build the SMV constraint excluding one assume-guarantee violation.

    For `error = [(aut_i, loc_i, trans_i), ...]` (non-empty) return
        !(\bigwedge_i aut_i.enabled & aut_i.loc = loc_i &
                      aut_i.trans = trans_i)
    written as a disjunction of negated literals.
    """
    disjuncts = []
    for a, loc_id, trans_id in error:
        mod_inst = "{}{}".format(MODULE_INST_PREF, a.name)
        disjuncts.append("!{}.{}".format(mod_inst, ENABLED))
        disjuncts.append("{}.{} != {}{}".format(mod_inst, LOC_VAR,
                                                LOC_ID_PREFIX, loc_id))
        disjuncts.append("{}.{} != {}".format(mod_inst, TRANS_LABEL,
                                              trans_id))
    return "({})".format(" | ".join(disjuncts))


def _check_ag_abstractions(automata: list) -> bool:
    """Check the correctness of every AG-abstraction in `automata`.

    For each automaton, `is_correct()` is run with the SMT-solvers of
    `get_solvers()` in turn until one returns a decided result without a
    segmentation fault being signalled. Returns False (after printing the
    counterexample) as soon as an abstraction is incorrect; True otherwise,
    including when some abstraction could not be decided (only reported).
    """
    print("\tCheck correctness of AG-Abstractions", flush=True)
    undefs = []
    check_time = 0
    segfault = False

    # we catch segfault since sometimes z3 crashes.
    # NOTE: the Python `signal` docs warn that returning from a handler for
    # a synchronous SIGSEGV raised in C code usually re-raises the signal,
    # so this detection is best-effort only.
    def sig_handler(signum, frame):
        nonlocal segfault
        segfault = True

    prev_handler = signal.signal(signal.SIGSEGV, sig_handler)
    try:
        for a in automata:
            segfault = False
            res = None
            # if one solver fails try the next one.
            for solver in get_solvers():
                print("\t\tCheck `{}` using SMT-solver: {}"
                      .format(a.name, solver), flush=True)
                set_solver(solver)
                start = time.time()
                res = a.is_correct()
                check_time += time.time() - start
                if res is not None and not segfault:
                    print("\t\tDone", flush=True)
                    break
                if segfault:
                    print("Solver segmentation fault")
                segfault = False
                print("\t\tFailed, try next solver if any", flush=True)

            if res is None:
                undefs.append(a.name)
            elif res is not True:
                print("\t\tAG-Abstraction {} is not correct"
                      .format(a.name))
                if res[0] == "UNSAT":
                    print("\t\tUNSAT transition from "
                          "location {} to location {}, via: {}\n"
                          .format(res[1][0], res[1][1], res[1][2]))
                elif res[0] == "SAT":
                    print("\t\tHypothesis violation from {} to {}, "
                          "via: {}\n\t\tmodel: {}"
                          .format(res[1][0], res[1][1], res[1][2],
                                  res[2]))
                return False
    finally:
        # do not leave our handler installed for the rest of the process;
        # None means the previous handler was not installed from Python.
        if prev_handler is not None:
            signal.signal(signal.SIGSEGV, prev_handler)

    if len(undefs) > 0:
        print("\tUnable to decide correctness of "
              "AG-Abstractions: {}, in {}s\n"
              .format(undefs, check_time))
    else:
        print("\tAG-Abstractions are correct, verified in {}s\n"
              .format(check_time))
    return True


def _check_sync_product(enabled: list, trans: FNode) -> None:
    """Report whether the synchronous product of `enabled` implies `trans`.

    Checks satisfiability of `Not(trans)` together with the conjunction of
    `encode_transition_rel(a)` for every `a` in `enabled`, trying each
    SMT-solver until one decides; the outcome is only printed.
    """
    prod_trans_rel = TRUE()
    for a in enabled:
        prod_trans_rel = And(prod_trans_rel, encode_transition_rel(a))

    for solver_name in get_solvers():
        set_solver(solver_name)
        with Solver() as solver:
            solver.add_assertion(Not(trans))
            solver.add_assertion(prod_trans_rel)
            try:
                res = solve_with_timeout(TIMEOUT, solver)
            except SolverReturnedUnknownResultError:
                res = None
            if res is None:
                reset_after_timeout(solver, TRUE())
                continue
            if res:
                m = solver.get_model()
                m = ["{} : {}".format(s, v) for s, v in m]
                print("\tSynchronous product of selected AGAutomata "
                      "does not imply the transition relation: {}"
                      .format(m))
            else:
                print("\tSynchronous product of selected AGAutomata "
                      "implies the transition relation")
            return


def encode_transition_rel(a: AGAutomaton) -> FNode:
    r"""Encode the whole transition relation of the AG-automaton `a`.

    Starting from FALSE(), one disjunct is added (by a binary `Or`) for each
    location `src` and each pair `(dst, labels)` of `a.get_transitions(src)`:
        region(src) /\ assume(src) /\ region(dst)' /\ assume(dst)' /\ \/labels
    where the primed formulas are obtained with `to_next` over
    `a.all_symbols`.
    """
    loc_count = a.num_locations
    symbols = a.all_symbols
    relation = FALSE()
    for src_idx in range(loc_count):
        src_region, src_assume = a.get_invar(src_idx), a.get_assume(src_idx)
        for dst_idx, labels in a.get_transitions(src_idx):
            step = And(src_region, src_assume,
                       to_next(a.get_invar(dst_idx), symbols),
                       to_next(a.get_assume(dst_idx), symbols),
                       Or(labels))
            relation = Or(relation, step)
    return relation
