from typing import Optional

from pysmt.shortcuts import Symbol, TRUE, Int
from pysmt.shortcuts import Not, And, Or, Implies
from pysmt.shortcuts import GT, GE, Equals, LE, LT
from pysmt.shortcuts import Plus, Times
from pysmt.typing import INT

from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> Optional[bool]:
    """Test corresponding to benchmarks/10.smv.

    Encode the program of benchmarks/10.smv as a symbolic transition system
    (initial condition, transition relation, fairness condition) over the
    integer variables a, b, pc, x, y, z, define the assume-guarantee
    automata to be composed, and search for a composition with
    `find_composition`.

    Args:
        nuxmv_path: forwarded to `find_composition` (presumably the path of
            the nuXmv executable).
        model_file: forwarded to `find_composition`.
        trace_file: forwarded to `find_composition`.
        cmd_file: forwarded to `find_composition`.
        output_file: if a composition is returned (not None), its string
            representation is written to this file.

    Returns:
        True if a composition was found and `undefs` (the second value
        returned by `find_composition`) is empty; False if no composition
        was found; None if a composition was found but `undefs` is
        non-empty, i.e. the result is inconclusive.
    """
    # symbols.
    a = Symbol("a", INT)
    b = Symbol("b", INT)
    pc = Symbol("pc", INT)
    x = Symbol("x", INT)
    y = Symbol("y", INT)
    z = Symbol("z", INT)
    x_a = Symbol(symb_next("a"), INT)
    x_b = Symbol(symb_next("b"), INT)
    x_pc = Symbol(symb_next("pc"), INT)
    x_x = Symbol(symb_next("x"), INT)
    x_y = Symbol(symb_next("y"), INT)
    x_z = Symbol(symb_next("z"), INT)
    symbols = [a, b, pc, x, y, z]

    # initial location.
    init = Equals(pc, Int(0))

    # control flow graph: for each pc value, the admissible next pc values.
    # pc = -1 is a sink location.
    cfg = And(
        # pc = -1 : -1,
        Implies(Equals(pc, Int(-1)), Equals(x_pc, Int(-1))),
        # pc = 0 & !(x >= 1) : -1,
        Implies(And(Equals(pc, Int(0)), Not(GE(x, Int(1)))),
                Equals(x_pc, Int(-1))),
        # pc = 0 & x >= 1 : 1,
        Implies(And(Equals(pc, Int(0)), GE(x, Int(1))), Equals(x_pc, Int(1))),
        # pc = 1 & !(b < 0) : -1,
        Implies(And(Equals(pc, Int(1)), Not(LT(b, Int(0)))),
                Equals(x_pc, Int(-1))),
        # pc = 1 & b < 0 : 2,
        Implies(And(Equals(pc, Int(1)), LT(b, Int(0))), Equals(x_pc, Int(2))),
        # pc = 2 & !(a >= 0) : -1,
        Implies(And(Equals(pc, Int(2)), Not(GE(a, Int(0)))),
                Equals(x_pc, Int(-1))),
        # pc = 2 & a >= 0 : 3,
        Implies(And(Equals(pc, Int(2)), GE(a, Int(0))), Equals(x_pc, Int(3))),
        # pc = 3 & !(x >= y & z < 42) : -1,
        Implies(And(Equals(pc, Int(3)), Not(And(GE(x, y), LT(z, Int(42))))),
                Equals(x_pc, Int(-1))),
        # pc = 3 & (x >= y & z < 42) : 4,
        Implies(And(Equals(pc, Int(3)), And(GE(x, y), LT(z, Int(42)))),
                Equals(x_pc, Int(4))),
        # pc = 4 : {5, 7},  (nondeterministic branch)
        Implies(Equals(pc, Int(4)),
                Or(Equals(x_pc, Int(5)), Equals(x_pc, Int(7)))),
        # pc = 5 : 6,
        Implies(Equals(pc, Int(5)), Equals(x_pc, Int(6))),
        # pc = 6 : 3,
        Implies(Equals(pc, Int(6)), Equals(x_pc, Int(3))),
        # pc = 7 : 8,
        Implies(Equals(pc, Int(7)), Equals(x_pc, Int(8))),
        # pc = 8 : 9,
        Implies(Equals(pc, Int(8)), Equals(x_pc, Int(9))),
        # pc = 9 : 3,
        Implies(Equals(pc, Int(9)), Equals(x_pc, Int(3)))
    )

    # transition labels: for each CFG edge (pc, pc'), the update of the
    # data variables.
    labels = And(
        # (pc = -1 & pc' = -1) -> (a' = a & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(-1)), Equals(x_pc, Int(-1))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 0 & pc' = -1) -> (a' = a & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(0)), Equals(x_pc, Int(-1))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 0 & pc' = 1)  -> (a' = a & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(0)), Equals(x_pc, Int(1))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 1 & pc' = -1) -> (a' = a & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(1)), Equals(x_pc, Int(-1))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 1 & pc' = 2)  -> (a' = a & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(1)), Equals(x_pc, Int(2))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 2 & pc' = -1) -> (a' = a & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(2)), Equals(x_pc, Int(-1))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 2 & pc' = 3)  -> (a' = a & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(2)), Equals(x_pc, Int(3))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 3 & pc' = -1) -> (a' = a & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(3)), Equals(x_pc, Int(-1))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 3 & pc' = 4)  -> (a' = a & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(3)), Equals(x_pc, Int(4))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 4 & pc' = 5)  -> (a' = a & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(4)), Equals(x_pc, Int(5))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 4 & pc' = 7)  -> (a' = a & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(4)), Equals(x_pc, Int(7))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 5 & pc' = 6)  -> (a' = a & b' = b & x' = 1 & y' = y & z' = z),
        Implies(And(Equals(pc, Int(5)), Equals(x_pc, Int(6))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, Int(1)),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 6 & pc' = 3)  -> (a' = a & b' = b & x' = x & y' = 15 & z' = z),
        Implies(And(Equals(pc, Int(6)), Equals(x_pc, Int(3))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, Int(15)), Equals(x_z, z))),
        # (pc = 7 & pc' = 8)  -> (a' = a & b' = b &           y' = y & z' = z),
        # x' is deliberately left unconstrained on this edge (arbitrary value).
        Implies(And(Equals(pc, Int(7)), Equals(x_pc, Int(8))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 8 & pc' = 9)  -> (a' = a & b' = b & x' = x & y' = y & z' = a*b),
        Implies(And(Equals(pc, Int(8)), Equals(x_pc, Int(9))),
                And(Equals(x_a, a), Equals(x_b, b), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, Times(a, b)))),
        # (pc = 9 & pc' = 3)  -> (a' = a+1 & b' = b & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(9)), Equals(x_pc, Int(3))),
                And(Equals(x_a, Plus(a, Int(1))), Equals(x_b, b),
                    Equals(x_x, x), Equals(x_y, y), Equals(x_z, z)))
    )

    # transition relation: control flow and data updates must both hold.
    trans = And(cfg, labels)

    # fairness condition: pc != -1.
    fairness = Not(Equals(pc, Int(-1)))

    # define automata to be composed.
    # aut_a: a >= 0; a either stays the same or is incremented by 1.
    aut_a = AGAutomaton(symbols, [a], "aut_a", 1)
    aut_a.set_assume(0, TRUE())
    aut_a.set_invar(0, GE(a, Int(0)))
    aut_a.set_transitions(0, [(0, [Equals(x_a, a),
                                   Equals(x_a, Plus(a, Int(1)))])])

    # aut_b: b <= 0 and never changes.
    aut_b = AGAutomaton(symbols, [b], "aut_b", 1)
    aut_b.set_assume(0, TRUE())
    aut_b.set_invar(0, LE(b, Int(0)))
    aut_b.set_transitions(0, [(0, [Equals(x_b, b)])])

    # aut_pc: cycles through the loop locations 3 -> 4 -> 7 -> 8 -> 9 -> 3,
    # one automaton location per pc value in loc2pc.
    loc2pc = [Int(3), Int(4), Int(7), Int(8), Int(9)]
    aut_pc = AGAutomaton(symbols, [pc], "aut_pc", len(loc2pc))
    for loc in range(aut_pc.num_locations):
        n_loc = (loc + 1) % aut_pc.num_locations
        aut_pc.set_assume(loc, TRUE())
        aut_pc.set_invar(loc, Equals(pc, loc2pc[loc]))
        aut_pc.set_transitions(loc, [(n_loc, [Equals(x_pc, loc2pc[n_loc])])])

    # aut_x: location 0 keeps x = 1; location 1 keeps x > 15 under the
    # assumption y <= 15; the automaton may switch between the two.
    aut_x = AGAutomaton(symbols, [x], "aut_x", 2)
    aut_x.set_assume(0, TRUE())
    aut_x.set_invar(0, Equals(x, Int(1)))
    aut_x.set_transitions(0, [(0, [Equals(x_x, x)]),
                              (1, [GT(x_x, Int(15))])])
    aut_x.set_assume(1, LE(y, Int(15)))
    aut_x.set_invar(1, GT(x, Int(15)))
    aut_x.set_transitions(1, [(1, [Equals(x_x, x)]),
                              (0, [Equals(x_x, Int(1))])])

    # aut_y: y = 15 always.
    aut_y = AGAutomaton(symbols, [y], "aut_y", 1)
    aut_y.set_assume(0, TRUE())
    aut_y.set_invar(0, Equals(y, Int(15)))
    aut_y.set_transitions(0, [(0, [Equals(x_y, y),
                                   Equals(x_y, Int(15))])])

    # aut_z: z <= 0, assuming a >= 0 and b <= 0 (so that a*b <= 0).
    aut_z = AGAutomaton(symbols, [z], "aut_z", 1)
    aut_z.set_assume(0, And(GE(a, Int(0)), LE(b, Int(0))))
    aut_z.set_invar(0, LE(z, Int(0)))
    aut_z.set_transitions(0, [(0, [Equals(x_z, z),
                                   Equals(x_z, Times(a, b))])])

    automata = [aut_a, aut_b, aut_pc, aut_x, aut_y, aut_z]

    # search composition.
    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w') as out:
            out.write(str(comp))

    if not comp:
        # no composition found.
        return False
    # A composition was found. It is conclusive only when `undefs` is empty.
    # Otherwise the outcome is unknown (None). The previous code asserted
    # `not comp and undefs` on this path, which could never hold here.
    return None if undefs else True
