from typing import Optional

from pysmt.shortcuts import Symbol, TRUE, Int
from pysmt.shortcuts import And, Implies
from pysmt.shortcuts import Equals, LT
from pysmt.shortcuts import Plus
from pysmt.typing import INT

from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next


def _cycle_automaton(symbols, var, x_var, name: str,
                     bound: int) -> "AGAutomaton":
    """Build an automaton with ``bound + 1`` locations for a modular counter.

    Location ``l`` assumes TRUE, has invariant ``var == l`` and a single
    transition to ``(l + 1) mod (bound + 1)`` labelled with
    ``x_var == (l + 1) mod (bound + 1)``.

    Args:
        symbols: all state symbols of the system, passed to `AGAutomaton`.
        var: the counter symbol this automaton describes.
        x_var: the next-state version of ``var``.
        name: name given to the automaton.
        bound: largest value of the counter; the automaton gets
            ``bound + 1`` locations.
    """
    aut = AGAutomaton(symbols, [var], name, bound + 1)
    num_locs = aut.num_locations
    for loc in range(num_locs):
        nxt = (loc + 1) % num_locs
        aut.set_assume(loc, TRUE())
        aut.set_invar(loc, Equals(var, Int(loc)))
        aut.set_transitions(loc, [(nxt, [Equals(x_var, Int(nxt))])])
    return aut


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> Optional[bool]:
    """Test corresponding to benchmarks/counters.smv.

    Encodes two independent counters, ``c0`` cycling through 0..13 and
    ``c1`` cycling through 0..3, and asks `find_composition` to compose one
    cyclic automaton per counter.

    Args:
        nuxmv_path, model_file, trace_file, cmd_file: forwarded unchanged
            to `find_composition`.
        output_file: if a composition is found, its string form is
            written to this path.

    Returns:
        True if a composition was found and ``undefs`` is empty;
        False if no composition was found;
        None if a composition was found but ``undefs`` is non-empty.
    """
    c0_bound = 13
    c1_bound = 3
    # symbols.
    c0 = Symbol("c0", INT)
    c1 = Symbol("c1", INT)
    x_c0 = Symbol(symb_next("c0"), INT)
    x_c1 = Symbol(symb_next("c1"), INT)
    symbols = [c0, c1]

    n_0 = Int(0)
    n_1 = Int(1)
    n_c0 = Int(c0_bound)
    n_c1 = Int(c1_bound)
    # initial location: both counters start at 0.
    init = And(Equals(c0, n_0), Equals(c1, n_0))

    # transition relation: each counter increments and wraps to 0 at its
    # bound.
    trans = And(
        Implies(Equals(c0, n_c0), Equals(x_c0, n_0)),
        Implies(LT(c0, n_c0), Equals(x_c0, Plus(c0, n_1))),
        Implies(Equals(c1, n_c1), Equals(x_c1, n_0)),
        Implies(LT(c1, n_c1), Equals(x_c1, Plus(c1, n_1)))
    )

    # fairness: both counters are 0.
    fairness = And(Equals(c0, n_0), Equals(c1, n_0))

    # define automata to be composed (one cyclic automaton per counter).
    aut_c0 = _cycle_automaton(symbols, c0, x_c0, "aut_c0", c0_bound)
    aut_c1 = _cycle_automaton(symbols, c1, x_c1, "aut_c1", c1_bound)
    automata = [aut_c0, aut_c1]

    # search composition.
    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w', encoding="utf-8") as out:
            out.write(str(comp))

    if not comp:
        return False
    if undefs:
        # A composition exists but some parts were reported undefined:
        # neither success nor failure. The original code reached an
        # always-failing assert here instead of returning None.
        return None
    return True
