from typing import Optional

from pysmt.shortcuts import Symbol, TRUE, Int
from pysmt.shortcuts import Not, And, Or, Implies
from pysmt.shortcuts import GT, GE, Equals, LE, LT
from pysmt.shortcuts import Plus, Minus, Times
from pysmt.typing import INT

from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> Optional[bool]:
    """Test corresponding to benchmarks/4.smv.

    Encodes a program as a symbolic transition system over the integer
    variables `pc`, `w`, `x`, `y` and `z` (a control flow graph on `pc`
    plus, for every edge, the update of the data variables), declares a
    set of candidate automata and asks `find_composition` for a composition
    of them under the fairness condition `pc != -1`.

    Args:
        nuxmv_path: forwarded unchanged to `find_composition`
            (presumably the path of the nuXmv executable -- confirm
            against `find_composition`).
        model_file: forwarded unchanged to `find_composition`.
        trace_file: forwarded unchanged to `find_composition`.
        cmd_file: forwarded unchanged to `find_composition`.
        output_file: file that receives `str(comp)` whenever the
            composition returned by `find_composition` is not None.

    Returns:
        True if a composition was found and `undefs` is empty/falsy,
        False if no composition was found,
        None if a composition was found but `undefs` is non-empty
        (inconclusive outcome).
    """
    # symbols: current-state variables and their next-state counterparts.
    pc = Symbol("pc", INT)
    w = Symbol("w", INT)
    x = Symbol("x", INT)
    y = Symbol("y", INT)
    z = Symbol("z", INT)
    x_pc = Symbol(symb_next("pc"), INT)
    x_w = Symbol(symb_next("w"), INT)
    x_x = Symbol(symb_next("x"), INT)
    x_y = Symbol(symb_next("y"), INT)
    x_z = Symbol(symb_next("z"), INT)
    symbols = [pc, w, x, y, z]

    # initial location.
    init = Equals(pc, Int(0))

    # control flow graph: constrains pc' as a function of the current state.
    cfg = And(
        # pc = -1 : -1,
        Implies(Equals(pc, Int(-1)), Equals(x_pc, Int(-1))),
        # pc = 0 & !(z >= 4) : -1,
        Implies(And(Equals(pc, Int(0)), Not(GE(z, Int(4)))),
                Equals(x_pc, Int(-1))),
        # pc = 0 & z >= 4 : 1,
        Implies(And(Equals(pc, Int(0)), GE(z, Int(4))), Equals(x_pc, Int(1))),
        # pc = 1 : 2,
        Implies(Equals(pc, Int(1)), Equals(x_pc, Int(2))),
        # pc = 2 & x >= 0 : 3,
        Implies(And(Equals(pc, Int(2)), GE(x, Int(0))), Equals(x_pc, Int(3))),
        # pc = 2 & !(x >= 0) : 5,
        Implies(And(Equals(pc, Int(2)), Not(GE(x, Int(0)))),
                Equals(x_pc, Int(5))),
        # pc = 3 & !(w <= -5) : -1,
        Implies(And(Equals(pc, Int(3)), Not(LE(w, Int(-5)))),
                Equals(x_pc, Int(-1))),
        # pc = 3 & w <= -5 : 4,
        Implies(And(Equals(pc, Int(3)), LE(w, Int(-5))), Equals(x_pc, Int(4))),
        # pc = 4 : 6,
        Implies(Equals(pc, Int(4)), Equals(x_pc, Int(6))),
        # pc = 5 : 6,
        Implies(Equals(pc, Int(5)), Equals(x_pc, Int(6))),
        # pc = 6 : {7, 9},
        Implies(Equals(pc, Int(6)),
                Or(Equals(x_pc, Int(7)), Equals(x_pc, Int(9)))),
        # pc = 7 & !(x < 0) : -1,
        Implies(And(Equals(pc, Int(7)), Not(LT(x, Int(0)))),
                Equals(x_pc, Int(-1))),
        # pc = 7 & x < 0 : 8,
        Implies(And(Equals(pc, Int(7)), LT(x, Int(0))), Equals(x_pc, Int(8))),
        # pc = 8 : -1,
        Implies(Equals(pc, Int(8)), Equals(x_pc, Int(-1))),
        # pc = 9 & !(x >= w) : 18,
        Implies(And(Equals(pc, Int(9)), Not(GE(x, w))),
                Equals(x_pc, Int(18))),
        # pc = 9 & x >= w : 10,
        Implies(And(Equals(pc, Int(9)), GE(x, w)),
                Equals(x_pc, Int(10))),
        # pc = 10 : {11, 13},
        Implies(Equals(pc, Int(10)),
                Or(Equals(x_pc, Int(11)), Equals(x_pc, Int(13)))),
        # pc = 11 & !(x < 0) : -1,
        Implies(And(Equals(pc, Int(11)), Not(LT(x, Int(0)))),
                Equals(x_pc, Int(-1))),
        # pc = 11 & x < 0 : 12,
        Implies(And(Equals(pc, Int(11)), LT(x, Int(0))),
                Equals(x_pc, Int(12))),
        # pc = 12 : -1,
        Implies(Equals(pc, Int(12)), Equals(x_pc, Int(-1))),
        # pc = 13 & z <= 8 : 14,
        Implies(And(Equals(pc, Int(13)), LE(z, Int(8))),
                Equals(x_pc, Int(14))),
        # pc = 13 & !(z <= 8) : 15,
        Implies(And(Equals(pc, Int(13)), Not(LE(z, Int(8)))),
                Equals(x_pc, Int(15))),
        # pc = 14 : 16,
        Implies(Equals(pc, Int(14)), Equals(x_pc, Int(16))),
        # pc = 15 : 16,
        Implies(Equals(pc, Int(15)), Equals(x_pc, Int(16))),
        # pc = 16 : 17,
        Implies(Equals(pc, Int(16)), Equals(x_pc, Int(17))),
        # pc = 17 : 9,
        Implies(Equals(pc, Int(17)), Equals(x_pc, Int(9))),
        # pc = 18 : {-1, 19},
        Implies(Equals(pc, Int(18)),
                Or(Equals(x_pc, Int(-1)), Equals(x_pc, Int(19)))),
        # pc = 19 & !(x < 0) : -1,
        Implies(And(Equals(pc, Int(19)), Not(LT(x, Int(0)))),
                Equals(x_pc, Int(-1))),
        # pc = 19 & x < 0 : 20,
        Implies(And(Equals(pc, Int(19)), LT(x, Int(0))),
                Equals(x_pc, Int(20))),
        # pc = 20 : -1,
        Implies(Equals(pc, Int(20)), Equals(x_pc, Int(-1)))
    )

    # transition labels: for every (pc, pc') edge, the update of w, x, y, z.
    labels = And(
        # (pc = -1 & pc' = -1) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(-1)), Equals(x_pc, Int(-1))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 0 & pc' = -1)  -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(0)), Equals(x_pc, Int(-1))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 0 & pc' = 1)   -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(0)), Equals(x_pc, Int(1))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 1 & pc' = 2)   -> (w' = w & x' = x & y' = y & z' = z+1),
        Implies(And(Equals(pc, Int(1)), Equals(x_pc, Int(2))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, Plus(z, Int(1))))),
        # (pc = 2 & pc' = 3)   -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(2)), Equals(x_pc, Int(3))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 2 & pc' = 5)   -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(2)), Equals(x_pc, Int(5))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 3 & pc' = -1)  -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(3)), Equals(x_pc, Int(-1))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 3 & pc' = 4)   -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(3)), Equals(x_pc, Int(4))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 4 & pc' = 6)   -> (w' = w & x' = x & y' = y & z' = z+1),
        Implies(And(Equals(pc, Int(4)), Equals(x_pc, Int(6))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, Plus(z, Int(1))))),
        # (pc = 5 & pc' = 6)   -> (w' = w & x' = x & y' = y & z' = z-1),
        Implies(And(Equals(pc, Int(5)), Equals(x_pc, Int(6))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, Minus(z, Int(1))))),
        # (pc = 6 & pc' = 7)   -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(6)), Equals(x_pc, Int(7))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 6 & pc' = 9)   -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(6)), Equals(x_pc, Int(9))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 7 & pc' = -1)  -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(7)), Equals(x_pc, Int(-1))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 7 & pc' = 8)   -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(7)), Equals(x_pc, Int(8))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 8 & pc' = -1)  -> (w' = w & x' = x & y' = y & z' = z-1),
        Implies(And(Equals(pc, Int(8)), Equals(x_pc, Int(-1))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, Minus(z, Int(1))))),
        # (pc = 9 & pc' = 18)  -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(9)), Equals(x_pc, Int(18))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 9 & pc' = 10)  -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(9)), Equals(x_pc, Int(10))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 10 & pc' = 11) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(10)), Equals(x_pc, Int(11))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 10 & pc' = 13) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(10)), Equals(x_pc, Int(13))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 11 & pc' = -1) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(11)), Equals(x_pc, Int(-1))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 11 & pc' = 12) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(11)), Equals(x_pc, Int(12))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 12 & pc' = -1) -> (w' = w & x' = x & y' = y & z' = z-1),
        Implies(And(Equals(pc, Int(12)), Equals(x_pc, Int(-1))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, Minus(z, Int(1))))),
        # (pc = 13 & pc' = 14) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(13)), Equals(x_pc, Int(14))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 13 & pc' = 15) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(13)), Equals(x_pc, Int(15))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 14 & pc' = 16) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(14)), Equals(x_pc, Int(16))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 15 & pc' = 16) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(15)), Equals(x_pc, Int(16))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 16 & pc' = 17) -> (w' = w & x' = z*z & y' = y & z' = z),
        Implies(And(Equals(pc, Int(16)), Equals(x_pc, Int(17))),
                And(Equals(x_w, w), Equals(x_x, Times(z, z)), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 17 & pc' = 9)  -> (w' = w-1 & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(17)), Equals(x_pc, Int(9))),
                And(Equals(x_w, Minus(w, Int(1))), Equals(x_x, x),
                    Equals(x_y, y), Equals(x_z, z))),
        # (pc = 18 & pc' = -1) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(18)), Equals(x_pc, Int(-1))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 18 & pc' = 19) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(18)), Equals(x_pc, Int(19))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 19 & pc' = -1) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(19)), Equals(x_pc, Int(-1))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 19 & pc' = 20) -> (w' = w & x' = x & y' = y & z' = z),
        Implies(And(Equals(pc, Int(19)), Equals(x_pc, Int(20))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, z))),
        # (pc = 20 & pc' = -1) -> (w' = w & x' = x & y' = y & z' = z-1),
        Implies(And(Equals(pc, Int(20)), Equals(x_pc, Int(-1))),
                And(Equals(x_w, w), Equals(x_x, x), Equals(x_y, y),
                    Equals(x_z, Minus(z, Int(1)))))
    )

    # transition relation.
    trans = And(cfg, labels)

    # fairness: pc differs from the sink location -1.
    fairness = Not(Equals(pc, Int(-1)))

    # define automata to be composed.
    # aut_pc0: cycle over pc = 9 -> 10 -> 13 -> 14 -> 16 -> 17 -> 9.
    aut_pc0 = AGAutomaton(symbols, [pc], "aut_pc0", 6)
    loc2pc = [9, 10, 13, 14, 16, 17]
    for loc in range(6):
        n_loc = (loc + 1) % 6
        aut_pc0.set_assume(loc, TRUE())
        aut_pc0.set_invar(loc, Equals(pc, Int(loc2pc[loc])))
        aut_pc0.set_transitions(loc,
                                [(n_loc, [Equals(x_pc, Int(loc2pc[n_loc]))])])

    # aut_pc1: same cycle as aut_pc0 plus the alternative branch through
    # pc = 15 (location 6).
    aut_pc1 = AGAutomaton(symbols, [pc], "aut_pc1", 7)
    # loc 6 : pc = 15
    for loc in range(6):
        n_loc = (loc + 1) % 6
        aut_pc1.set_assume(loc, TRUE())
        aut_pc1.set_invar(loc, Equals(pc, Int(loc2pc[loc])))
        aut_pc1.set_transitions(loc,
                                [(n_loc, [Equals(x_pc, Int(loc2pc[n_loc]))])])
    aut_pc1.set_assume(6, TRUE())
    aut_pc1.set_invar(6, Equals(pc, Int(15)))
    # pc = 13 -> pc' in {14, 15}
    # (replaces the transitions set for location 2 in the loop above).
    aut_pc1.set_transitions(2, [(3, [Equals(x_pc, Int(14))]),
                                (6, [Equals(x_pc, Int(15))])])
    # pc = 15 -> pc' = 16
    aut_pc1.set_transitions(6, [(4, [Equals(x_pc, Int(16))])])

    aut_w = AGAutomaton(symbols, [w], "aut_w", 1)
    aut_w.set_assume(0, LT(w, x))
    aut_w.set_invar(0, LE(w, Int(0)))
    aut_w.set_transitions(0, [(0, [Equals(x_w, w),
                                   Equals(x_w, Minus(w, Int(1)))])])

    aut_x = AGAutomaton(symbols, [x], "aut_x", 1)
    aut_x.set_assume(0, Or(GT(z, Int(0)), LT(z, Int(0))))
    aut_x.set_invar(0, GE(x, Int(0)))
    aut_x.set_transitions(0, [(0, [Equals(x_x, x),
                                   Equals(x_x, Times(z, z))])])

    aut_y = AGAutomaton(symbols, [y], "aut_y", 1)
    aut_y.set_assume(0, TRUE())
    aut_y.set_invar(0, TRUE())
    aut_y.set_transitions(0, [(0, [Equals(x_y, y)])])

    aut_z_neg = AGAutomaton(symbols, [z], "aut_z_neg", 1)
    aut_z_neg.set_assume(0, TRUE())
    aut_z_neg.set_invar(0, LE(z, Int(0)))
    aut_z_neg.set_transitions(0, [(0, [Equals(x_z, z),
                                       Equals(x_z, Minus(z, Int(1)))])])

    aut_z_gt_8 = AGAutomaton(symbols, [z], "aut_z_gt_8", 1)
    aut_z_gt_8.set_assume(0, TRUE())
    aut_z_gt_8.set_invar(0, GT(z, Int(8)))
    aut_z_gt_8.set_transitions(0, [(0, [Equals(x_z, z),
                                        Equals(x_z, Plus(z, Int(1)))])])

    aut_z_le_8 = AGAutomaton(symbols, [z], "aut_z_le_8", 1)
    aut_z_le_8.set_assume(0, TRUE())
    aut_z_le_8.set_invar(0, And(GE(z, Int(0)), LE(z, Int(8))))
    aut_z_le_8.set_transitions(0, [(0, [Equals(x_z, z)])])

    automata = [aut_pc0, aut_pc1, aut_w, aut_x, aut_y, aut_z_neg,
                aut_z_gt_8, aut_z_le_8]

    # search composition.
    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w') as out:
            out.write(str(comp))

    # classify the outcome; the previous version asserted `not comp` in the
    # branch that is only reachable when `comp` is truthy, so the
    # inconclusive case always raised AssertionError instead of yielding None.
    if not comp:
        # no composition found.
        return False
    if undefs:
        # composition found but `undefs` is non-empty: inconclusive.
        return None
    return True
