#! /usr/bin/env python3

from automaton import Automaton
from utils import symb_next, to_next, Solver, solve_with_timeout, \
     reset_after_timeout, to_smt2, get_verbosity


from pysmt.shortcuts import TRUE
from pysmt.shortcuts import Not, And
from pysmt.shortcuts import Exists, ForAll, qelim
from pysmt.fnode import FNode
from pysmt.exceptions import SolverReturnedUnknownResultError, \
    NoSolverAvailableError, NoLogicAvailableError


# Default timeout handed to `solve_with_timeout` by every check of this module.
TIMEOUT = 5


def set_timeout(val) -> None:
    """Replace the module-wide `TIMEOUT` with `val`."""
    global TIMEOUT
    TIMEOUT = val


def check_composition(composition: Automaton, transition: FNode, fair: FNode,
                      automata: list):
    """Check whether `composition` has a non-empty language wrt `fair` and is
    an underapproximation of `transition`.

    The following checks are run in order:
    1. every region is closed wrt its outgoing transitions;
    2. every transition under-approximates `transition`;
    3. every transition admits a successor; when this cannot be decided
       directly, `composition` is compared against `automata` and its
       transitions are checked for emptiness;
    4. every location can reach a fair loop.

    Errors are always printed; inconclusive results are printed only when
    `get_verbosity()` is truthy.

    Returns `True` if no check reported a problem, `False` if at least one
    check reported an error, `None` if there are no errors but some check
    could not be decided.
    """
    res = True

    # check regions closed wrt transition relation.
    errors, undefs = check_region_closed(composition)

    if undefs:
        res = None
    if errors:
        res = False
    for src, trans, dst, model in errors:
        print("src region: `{}`, transition `{}`, dst region: `{}` "
              "is not closed:\n{}\n".format(src, trans, dst, model))
    if get_verbosity():
        for src, trans, dst in undefs:
            print("Could not decide whether `{}`\n  "
                  "is closed relative to src:\n`{}`\n  dst:\n"
                  "`{}`\n".format(trans, src, dst))

    # check composition under-approximates transition relation
    errors, undefs = check_is_underapprox(composition, transition)

    # an undecided check can only downgrade `True`, never hide an error.
    if res and undefs:
        res = None
    if errors:
        res = False
    for src, trans, model in errors:
        print("`{}` does not under-approximate `{}`, counterexample: `{}`\n"
              .format(And(src, trans), transition, model))
    if get_verbosity():
        for src, trans in undefs:
            print("Could not decide whether `{}` underapproximates `{}`\n"
                  .format(And(src, trans), transition))

    # try to prove existence of successor via quantifier elimination
    errors, undefs = check_successor_existence(composition)

    if errors:
        res = False
    # `check_successor_existence` reports (region, transition) pairs,
    # no model is attached to them.
    for src, trans in errors:
        print("Transition `{}` from region: `{}` not always admits "
              "successor\n".format(trans, src))
    # previous method is often unable to provide answers,
    # in such cases we look for an in-direct proof: if automata satisfy the
    # hypothesis, then every valid composition satisfies the existence of a
    # successor.
    if undefs:
        # check `composition` under-approximates a composition of automata.
        (loc_errors, trans_errors, loc_undefs,
         trans_undefs) = check_composition_existence(composition,
                                                     automata)

        if res and (loc_undefs or trans_undefs):
            res = None
        if loc_errors or trans_errors:
            res = False
        for src in loc_errors:
            print("Region `{}` cannot be expressed as composition".format(src))
        for trans in trans_errors:
            print("Transition `{}` cannot be expressed as composition"
                  .format(trans))
        if get_verbosity():
            for src in loc_undefs:
                print("Could not decide whether `{}` is expressible as composition"
                      .format(src))
            for trans in trans_undefs:
                print("Could not decide whether transition `{}` is "
                      "expressible as composition".format(trans))

        # ensure transitions and regions are not empty
        errors, undefs = check_not_empty(composition)

        if res and undefs:
            res = None
        if errors:
            res = False
        for src, trans in errors:
            print("Transition `{}` from region `{}` is unsat"
                  .format(trans, src))
        if get_verbosity():
            for src, trans in undefs:
                print("Could not decide whether `{}` is empty relative to "
                      "`{}`\n".format(trans, src))

    # check that every location can reach a fair loop.
    errors, undefs = check_finally_fair_loop(composition, fair)

    if res and undefs:
        res = None
    if errors:
        res = False
    # `check_finally_fair_loop` mixes three kinds of errors in one list:
    # location indexes (dead-ends), (location, model) pairs and SCCs.
    for err in errors:
        if isinstance(err, int):
            print("No outgoing transition from location {}, region: `{}`"
                  .format(err, composition.get_invar(err)))
        elif isinstance(err, tuple):
            print("Fair location {} does not imply fairness: "
                  "The following state is not fair:\n{}\n"
                  .format(err[0], err[1]))
        else:
            print("The following SCC does not include any fair region: {}"
                  .format(err))
    if get_verbosity():
        for undef in undefs:
            if isinstance(undef, int):
                print("Could not verify that location {} implies the fairness"
                      .format(undef))
            else:
                print("Could not decide whether the SCC contains a fair "
                      "region: {}".format(undef))

    return res


def check_region_closed(a: Automaton) -> tuple:
    """Check that every transition of `a` ends in its destination region.

    For each edge `src - trans -> dst` the query `src & trans & !dst'` is
    handed to the solver, where `dst'` is the invariant of `dst` rewritten
    by `to_next`; the edge is closed iff the query is unsat.

    Returns a pair `(errors, undef)`:
    - `errors` holds `(src_region, trans, dst', model)` for every sat query;
    - `undef` holds `(src_region, trans, dst')` for every query that could
      not be decided (timeout or unknown solver result).
    """
    errors = []
    undef = []
    with Solver() as solver:
        for src_idx in range(a.num_locs):
            solver.push()
            src_region = a.get_invar(src_idx)
            solver.add_assertion(src_region)
            for dst_idx, trans_list in a.get_trans_from(src_idx).items():
                dst_region = to_next(a.get_invar(dst_idx), a.symbols)
                solver.push()
                solver.add_assertion(Not(dst_region))
                for trans in trans_list:
                    solver.push()
                    solver.add_assertion(trans)
                    try:
                        res = solve_with_timeout(TIMEOUT, solver)
                    except SolverReturnedUnknownResultError:
                        res = None
                    if res is None:
                        # NOTE(review): like every other call site in this
                        # file, the innermost assertion (`trans`) is not
                        # handed to the helper, so that the three `pop`s
                        # below match the number of formulas passed here.
                        # Presumably the helper pushes after each formula;
                        # confirm against `utils.reset_after_timeout`.
                        reset_after_timeout(solver, TRUE(), src_region,
                                            Not(dst_region))
                        undef.append((src_region, trans, dst_region))
                        if get_verbosity():
                            print('UNKNOWN:\n', to_smt2(src_region, trans,
                                                        Not(dst_region)))
                    elif res:
                        # sat: `trans` can leave `dst_region`.
                        errors.append((src_region, trans, dst_region,
                                       solver.get_model()))
                    solver.pop()  # remove trans
                solver.pop()  # remove !dst_region
            solver.pop()  # remove src_region
    return errors, undef


def check_is_underapprox(a: Automaton, transition: FNode) -> tuple:
    """Check that each `region & trans` of `a` implies `transition`.

    The implication is valid iff `!transition & region & trans` is unsat,
    which is the query sent to the solver for every outgoing transition of
    every location.

    Returns a pair `(errors, undefs)`:
    - `errors` holds `(region, trans, model)` for every sat query;
    - `undefs` holds `(region, trans)` for every undecided query.
    """
    violations = []
    unknowns = []
    neg_transition = Not(transition)
    with Solver() as solver:
        # asserted once, shared by all the queries below.
        solver.add_assertion(neg_transition)
        for loc in range(a.num_locs):
            solver.push()
            invar = a.get_invar(loc)
            solver.add_assertion(invar)
            outgoing = (step
                        for _, steps in a.get_trans_from(loc).items()
                        for step in steps)
            for step in outgoing:
                solver.push()
                solver.add_assertion(step)
                try:
                    outcome = solve_with_timeout(TIMEOUT, solver)
                except SolverReturnedUnknownResultError:
                    outcome = None
                if outcome is None:
                    reset_after_timeout(solver, neg_transition, invar)
                    unknowns.append((invar, step))
                elif outcome:
                    violations.append((invar, step, solver.get_model()))
                solver.pop()  # drop the transition
            solver.pop()  # drop the region

    return violations, unknowns


def check_successor_existence(a: Automaton) -> tuple:
    """Check that every transition, for every state in the source region admits
    a successor state.

    For each transition `trans` leaving a region `src` the closed formula
    `forall s. (src(s) -> exists s'. trans(s, s'))` is built, where `s`
    ranges over `a.symbols` and `s'` over their `symb_next` counterparts.
    The quantifiers are removed with `qelim` and the result is given to the
    solver.

    Returns a pair `(errors, undef)` of lists of `(src_region, trans)`:
    `errors` for formulae found unsat, `undef` for those that could not be
    decided (timeout, unknown result, or quantifier elimination failure).
    """
    errors = []
    undef = []
    next_symbols = [symb_next(s) for s in a.symbols]
    with Solver() as solver:
        for src_idx in range(a.num_locs):
            src_region = a.get_invar(src_idx)
            for _, trans_list in a.get_trans_from(src_idx).items():
                for trans in trans_list:
                    # Only states inside `src_region` must have a successor:
                    # conjoining the region under the quantifiers would
                    # instead demand that *every* state is in the region.
                    # The implication is written as !(src & !exists ...).
                    has_succ = Exists(next_symbols, trans)
                    constr = Not(And(src_region, Not(has_succ)))
                    constr = ForAll(a.symbols, constr)
                    try:
                        constr = qelim(constr)
                        res = solve_with_timeout(TIMEOUT, solver, [constr])
                        if res is None:
                            solver.reset_assertions()
                            undef.append((src_region, trans))
                        elif not res:
                            errors.append((src_region, trans))
                    except (SolverReturnedUnknownResultError,
                            NoSolverAvailableError, NoLogicAvailableError,
                            AttributeError):
                        # no usable quantifier elimination / answer.
                        undef.append((src_region, trans))
    return errors, undef


def check_composition_existence(a: Automaton, automata: list) -> tuple:
    r"""Check that there exists a composition of `automata` that
    overapproximates `a`.

    If the automata satisfy the existence of a successor for every state
    in their regions, then this implies their composition does.
    For every automaton `i` we have `region_i` and a transition relation
    `f_i`; we prove that the composition is given by `region`, `f` such that
    `region \impl \big_wedge_{i} region_i` and
    `f \impl \big_wedge_{i} f_i`.
    If for all `i`, for all `s in region_i` exists `s'` such that `f_i(s, s')`
    holds, then for all `s in region` exists `s'` such that `f(s, s')`.
    NOTE(review): the argument presumably relies on the automata constraining
    pairwise disjoint sets of next-state symbols
    (`s'_i \cap s'_j = \emptyset` for `i != j`); confirm.

    Returns `(loc_errs, trans_errs, loc_undefs, trans_undefs)`:
    - `loc_errs`: regions of `a` not contained in any region of some
      automaton (one entry per such automaton);
    - `trans_errs`: transitions of `a` that do not imply any transition of
      the candidate regions of some automaton;
    - `loc_undefs`, `trans_undefs`: same as above, but at least one of the
      involved queries could not be decided.
    """
    loc_errs = []
    trans_errs = []
    loc_undefs = []
    trans_undefs = []
    with Solver() as solver:
        # for each region of `a` compute list of possible regions of the
        # automata in `automata` that contain the region of `a`.
        for src_idx in range(a.num_locs):
            src_region = a.get_invar(src_idx)
            solver.push()
            solver.add_assertion(src_region)

            # collect candidate regions for each automaton
            candidates_list = [[] for _ in automata]
            for idx, c in enumerate(automata):
                is_undef = False
                for c_src_idx in range(c.num_locations):
                    solver.push()
                    c_region = And(c.get_invar(c_src_idx),
                                   c.get_assume(c_src_idx))
                    # src_region -> c_region is valid iff this is unsat.
                    solver.add_assertion(Not(c_region))
                    try:
                        res = solve_with_timeout(TIMEOUT, solver)
                    except SolverReturnedUnknownResultError:
                        res = None
                    if res is None:
                        reset_after_timeout(solver, TRUE(), src_region)
                        is_undef = True
                    elif not res:
                        candidates_list[idx].append(c_src_idx)
                    solver.pop()  # pop c_region

                if not candidates_list[idx]:
                    if is_undef:
                        loc_undefs.append(src_region)
                    else:
                        loc_errs.append(src_region)
            solver.pop()  # pop src_region

            # The region has already been reported above: without a
            # candidate for every automaton there is nothing to compare
            # the transitions against.
            if not all(candidates_list):
                continue

            # check existence of transition composition
            for _, trans_list in a.get_trans_from(src_idx).items():
                for trans in trans_list:
                    solver.push()
                    solver.add_assertion(trans)
                    for c, candidates in zip(automata, candidates_list):
                        found_trans = False
                        # remembered across all the candidates of `c`, so
                        # an undecided query is never reported as an error.
                        is_undef = False
                        for c_loc_idx in candidates:
                            c_trans_lists = c.get_transitions(c_loc_idx)
                            for _, c_trans_list in c_trans_lists:
                                for c_trans in c_trans_list:
                                    solver.push()
                                    # trans -> c_trans valid iff unsat.
                                    solver.add_assertion(Not(c_trans))
                                    try:
                                        res = solve_with_timeout(TIMEOUT,
                                                                 solver)
                                    except SolverReturnedUnknownResultError:
                                        res = None
                                    if res is None:
                                        reset_after_timeout(solver,
                                                            TRUE(), trans)
                                        is_undef = True
                                    elif not res:
                                        found_trans = True
                                    solver.pop()  # pop c_trans
                            if found_trans:
                                # one matching transition is enough.
                                break
                        if not found_trans:
                            if is_undef:
                                trans_undefs.append(trans)
                            else:
                                trans_errs.append(trans)
                            break  # go to next trans
                    solver.pop()  # pop trans

    return loc_errs, trans_errs, loc_undefs, trans_undefs


def check_not_empty(a: Automaton) -> tuple:
    """Ensure that every transition is satisfiable from its source region.

    For every location the query `region & trans` is sent to the solver for
    each outgoing transition.

    Returns a pair `(errors, undefs)` of lists of `(region, trans)`:
    `errors` for unsat queries, `undefs` for undecided ones.
    """
    empty = []
    unknowns = []
    with Solver() as solver:
        for loc in range(a.num_locs):
            invar = a.get_invar(loc)
            solver.push()
            solver.add_assertion(invar)

            outgoing = (step
                        for _, steps in a.get_trans_from(loc).items()
                        for step in steps)
            for step in outgoing:
                solver.push()
                solver.add_assertion(step)
                try:
                    outcome = solve_with_timeout(TIMEOUT, solver)
                except SolverReturnedUnknownResultError:
                    outcome = None
                if outcome is None:
                    reset_after_timeout(solver, TRUE(), invar)
                    unknowns.append((invar, step))
                elif not outcome:
                    empty.append((invar, step))
                solver.pop()  # drop the transition
            solver.pop()  # drop the region
    return empty, unknowns


def check_finally_fair_loop(a: Automaton, fair: FNode) -> tuple:
    """Check dead-ends, fair locations and SCCs of the location graph of `a`.

    Returns a pair `(errors, undefs)`.
    `errors` collects, in this order:
    - the index of every location without outgoing transitions;
    - `(location, model)` for every fair location whose invariant is
      satisfiable together with `!fair`;
    - every strongly connected component containing no fair location.
    `undefs` collects the fair locations for which the solver gave no answer.
    """
    errors = []
    undefs = []

    # no dead-end: each location needs at least 1 outgoing transition.
    for loc in range(a.num_locs):
        if not a.get_trans_from(loc).items():
            errors.append(loc)

    # exists at least 1 fair location.
    assert a.fairness, "Automaton must have at least one fair location"

    # invariant of a fair location -> `fair` is valid iff
    # `!fair & invariant` is unsat.
    neg_fair = Not(fair)
    with Solver() as solver:
        solver.add_assertion(neg_fair)

        for fair_loc in a.fairness:
            solver.push()
            solver.add_assertion(a.get_invar(fair_loc))
            try:
                outcome = solve_with_timeout(TIMEOUT, solver)
            except SolverReturnedUnknownResultError:
                outcome = None
            if outcome is None:
                reset_after_timeout(solver, neg_fair)
                undefs.append(fair_loc)
            elif outcome:
                errors.append((fair_loc, solver.get_model()))
            solver.pop()

    # every SCC must contain a fair location
    for scc in a.strongly_connected_components():
        if all(loc not in scc for loc in a.fairness):
            errors.append(scc)

    return errors, undefs
