from typing import Optional

from pysmt.shortcuts import Symbol, TRUE, Int
from pysmt.shortcuts import Not, And, Or, Implies
from pysmt.shortcuts import GT, GE, Equals, LE, LT
from pysmt.shortcuts import Plus, Minus, Times
from pysmt.typing import INT

from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> Optional[bool]:
    """Test corresponding to benchmarks/lasso_example3.smv.

    Encodes the benchmark program as a symbolic transition system over the
    integer variables ``i``, ``j``, ``k`` and the program counter ``pc``.
    It then defines one AGAutomaton per variable and asks
    ``find_composition`` for a composition of them.

    Program structure encoded by ``cfg`` / ``labels``:
      pc 0: if j >= 1 go to 1 else terminate (pc = -1);
      pc 1: if k >= 1 go to 2 else terminate;
      pc 2: if i >= 0 go to 3 else terminate;
      pc 3: i := j * k, go to 4;
      pc 4: j := j + 1, go to 5;
      pc 5: k := k + 1, go to 2.
    pc = -1 is a sink where nothing changes.

    Args:
        nuxmv_path: forwarded unchanged to ``find_composition``.
        model_file: forwarded unchanged to ``find_composition``.
        trace_file: forwarded unchanged to ``find_composition``.
        cmd_file: forwarded unchanged to ``find_composition``.
        output_file: if ``find_composition`` returns a composition (not
            None), its ``str()`` is written to this file (overwritten).

    Returns:
        False if no (truthy) composition was found.
        None if a composition was found but ``find_composition`` also
        reported a non-empty ``undefs`` (treated as inconclusive).
        True otherwise.
    """
    # symbols.
    i = Symbol("i", INT)
    j = Symbol("j", INT)
    k = Symbol("k", INT)
    pc = Symbol("pc", INT)
    x_i = Symbol(symb_next("i"), INT)
    x_j = Symbol(symb_next("j"), INT)
    x_k = Symbol(symb_next("k"), INT)
    x_pc = Symbol(symb_next("pc"), INT)
    symbols = [i, j, k, pc]

    # initial location.
    init = Equals(pc, Int(0))

    # control flow graph.
    cfg = And(
        # pc = -1 : -1,
        Implies(Equals(pc, Int(-1)), Equals(x_pc, Int(-1))),
        # pc = 0 & !(j >= 1) : -1,
        Implies(And(Equals(pc, Int(0)), Not(GE(j, Int(1)))),
                Equals(x_pc, Int(-1))),
        # pc = 0 & j >= 1 : 1,
        Implies(And(Equals(pc, Int(0)), GE(j, Int(1))), Equals(x_pc, Int(1))),
        # pc = 1 & !(k >= 1) : -1,
        Implies(And(Equals(pc, Int(1)), Not(GE(k, Int(1)))),
                Equals(x_pc, Int(-1))),
        # pc = 1 & k >= 1 : 2,
        Implies(And(Equals(pc, Int(1)), GE(k, Int(1))), Equals(x_pc, Int(2))),
        # pc = 2 & !(i >= 0) : -1,
        Implies(And(Equals(pc, Int(2)), Not(GE(i, Int(0)))),
                Equals(x_pc, Int(-1))),
        # pc = 2 & i >= 0 : 3,
        Implies(And(Equals(pc, Int(2)), GE(i, Int(0))), Equals(x_pc, Int(3))),
        # pc = 3 : 4,
        Implies(Equals(pc, Int(3)), Equals(x_pc, Int(4))),
        # pc = 4 : 5,
        Implies(Equals(pc, Int(4)), Equals(x_pc, Int(5))),
        # pc = 5 : 2,
        Implies(Equals(pc, Int(5)), Equals(x_pc, Int(2)))
    )

    # frame condition shared by every edge that does not update data:
    # i' = i & j' = j & k' = k.
    unchanged = And(Equals(x_i, i), Equals(x_j, j), Equals(x_k, k))

    # transition labels: one per CFG edge, so that no edge leaves i, j, k
    # unconstrained.
    labels = And(
        # (pc = -1 & pc' = -1) -> (i' = i & j' = j & k' = k),
        Implies(And(Equals(pc, Int(-1)), Equals(x_pc, Int(-1))), unchanged),
        # (pc = 0 & pc' = -1) -> (i' = i & j' = j & k' = k),
        Implies(And(Equals(pc, Int(0)), Equals(x_pc, Int(-1))), unchanged),
        # (pc = 0 & pc' = 1)  -> (i' = i & j' = j & k' = k),
        Implies(And(Equals(pc, Int(0)), Equals(x_pc, Int(1))), unchanged),
        # (pc = 1 & pc' = -1) -> (i' = i & j' = j & k' = k),
        Implies(And(Equals(pc, Int(1)), Equals(x_pc, Int(-1))), unchanged),
        # (pc = 1 & pc' = 2)  -> (i' = i & j' = j & k' = k),
        Implies(And(Equals(pc, Int(1)), Equals(x_pc, Int(2))), unchanged),
        # (pc = 2 & pc' = -1) -> (i' = i & j' = j & k' = k),
        Implies(And(Equals(pc, Int(2)), Equals(x_pc, Int(-1))), unchanged),
        # (pc = 2 & pc' = 3)  -> (i' = i & j' = j & k' = k),
        Implies(And(Equals(pc, Int(2)), Equals(x_pc, Int(3))), unchanged),
        # (pc = 3 & pc' = 4)  -> (i' = j*k & j' = j & k' = k),
        Implies(And(Equals(pc, Int(3)), Equals(x_pc, Int(4))),
                And(Equals(x_i, Times(j, k)), Equals(x_j, j), Equals(x_k, k))),
        # (pc = 4 & pc' = 5)  -> (i' = i & j' = j+1 & k' = k),
        # NOTE: the CFG only has the edge 4 -> 5 (there is no 3 -> 5 edge),
        # so the j increment must be attached here.
        Implies(And(Equals(pc, Int(4)), Equals(x_pc, Int(5))),
                And(Equals(x_i, i), Equals(x_j, Plus(j, Int(1))),
                    Equals(x_k, k))),
        # (pc = 5 & pc' = 2)  -> (i' = i & j' = j & k' = k+1),
        Implies(And(Equals(pc, Int(5)), Equals(x_pc, Int(2))),
                And(Equals(x_i, i), Equals(x_j, j),
                    Equals(x_k, Plus(k, Int(1)))))
    )

    # transition relation.
    trans = And(cfg, labels)

    # fairness: runs must never reach the termination sink pc = -1.
    fairness = Not(Equals(pc, Int(-1)))

    # define automata to be composed.
    # aut_i: assuming j, k >= 0, i stays >= 0 and either keeps its value
    # or becomes j * k.
    aut_i = AGAutomaton(symbols, [i], "aut_i", 1)
    aut_i.set_assume(0, And(GE(j, Int(0)), GE(k, Int(0))))
    aut_i.set_invar(0, GE(i, Int(0)))
    aut_i.set_transitions(0, [(0, [Equals(x_i, i),
                                   Equals(x_i, Times(j, k))])])

    # aut_j: j stays >= 0 and either keeps its value or is incremented.
    aut_j = AGAutomaton(symbols, [j], "aut_j", 1)
    aut_j.set_assume(0, TRUE())
    aut_j.set_invar(0, GE(j, Int(0)))
    aut_j.set_transitions(0, [(0, [Equals(x_j, j),
                                   Equals(x_j, Plus(j, Int(1)))])])

    # aut_k: k stays >= 0 and either keeps its value or is incremented.
    aut_k = AGAutomaton(symbols, [k], "aut_k", 1)
    aut_k.set_assume(0, TRUE())
    aut_k.set_invar(0, GE(k, Int(0)))
    aut_k.set_transitions(0, [(0, [Equals(x_k, k),
                                   Equals(x_k, Plus(k, Int(1)))])])

    # aut_pc: a 4-location cycle mirroring the loop; location `loc` has
    # pc = loc + 2, i.e. locations 0..3 correspond to pc 2, 3, 4, 5.
    aut_pc = AGAutomaton(symbols, [pc], "aut_pc", 4)
    for loc in range(aut_pc.num_locations):
        n_loc = (loc + 1) % aut_pc.num_locations
        c_pc = Int(loc + 2)
        n_pc = Int(n_loc + 2)
        aut_pc.set_assume(loc, TRUE())
        aut_pc.set_invar(loc, Equals(pc, c_pc))
        aut_pc.set_transitions(loc, [(n_loc, [Equals(x_pc, n_pc)])])

    automata = [aut_i, aut_j, aut_k, aut_pc]

    # search composition.
    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w') as out:
            out.write(str(comp))

    if not comp:
        # no composition found.
        return False
    if undefs:
        # composition found, but find_composition also reported `undefs`:
        # the outcome is inconclusive. (The original code asserted
        # `not comp` here, which can never hold on this path.)
        return None
    return True
