from typing import Optional

from pysmt.shortcuts import Symbol, TRUE, Int
from pysmt.shortcuts import Not, And, Implies
from pysmt.shortcuts import GE, Equals, LE, LT
from pysmt.shortcuts import Plus, Times
from pysmt.typing import INT

from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> Optional[bool]:
    """Test corresponding to benchmarks/integer_log_by_mul.smv

    Builds the symbolic model of the benchmark (initial condition, control
    flow graph, transition labels, fairness) together with a set of
    AG-automata, and asks ``find_composition`` for a composition of them.

    Args:
        nuxmv_path: forwarded to ``find_composition`` (presumably the path
            to the nuXmv executable -- confirm against find_composition).
        model_file: forwarded to ``find_composition``.
        trace_file: forwarded to ``find_composition``.
        cmd_file: forwarded to ``find_composition``.
        output_file: if a composition is returned, its ``str()`` is
            written to this file.

    Returns:
        True if a composition was found and ``undefs`` is empty,
        False if no composition was found,
        None if a composition was found but ``undefs`` is non-empty.
    """
    # symbols.
    d = Symbol("d", INT)
    f_mul = Symbol("f_mul", INT)
    log = Symbol("log", INT)
    n = Symbol("n", INT)
    pc = Symbol("pc", INT)
    x_d = Symbol(symb_next("d"), INT)
    x_f_mul = Symbol(symb_next("f_mul"), INT)
    x_log = Symbol(symb_next("log"), INT)
    x_n = Symbol(symb_next("n"), INT)
    x_pc = Symbol(symb_next("pc"), INT)
    symbols = [d, f_mul, log, n, pc]

    # initial location.
    init = Equals(pc, Int(0))

    # control flow graph.
    cfg = And(
        # pc = -1 : -1,
        Implies(Equals(pc, Int(-1)), Equals(x_pc, Int(-1))),
        # pc = 0 & !(n >= 1) : -1,
        Implies(And(Equals(pc, Int(0)), Not(GE(n, Int(1)))),
                Equals(x_pc, Int(-1))),
        # pc = 0 & n >= 1 : 1,
        Implies(And(Equals(pc, Int(0)), GE(n, Int(1))), Equals(x_pc, Int(1))),
        # pc = 1 & !(d >= 0) : -1,
        Implies(And(Equals(pc, Int(1)), Not(GE(d, Int(0)))),
                Equals(x_pc, Int(-1))),
        # pc = 1 & d >= 0 : 2,
        Implies(And(Equals(pc, Int(1)), GE(d, Int(0))), Equals(x_pc, Int(2))),
        # pc = 2 & d < 2 : 3,
        Implies(And(Equals(pc, Int(2)), LT(d, Int(2))), Equals(x_pc, Int(3))),
        # pc = 2 & !(d < 2) : 8,
        Implies(And(Equals(pc, Int(2)), Not(LT(d, Int(2)))),
                Equals(x_pc, Int(8))),
        # pc = 3 : 4,
        Implies(Equals(pc, Int(3)), Equals(x_pc, Int(4))),
        # pc = 4 : 5,
        Implies(Equals(pc, Int(4)), Equals(x_pc, Int(5))),
        # pc = 5 & !(d <= n) : -1,
        Implies(And(Equals(pc, Int(5)), Not(LE(d, n))), Equals(x_pc, Int(-1))),
        # pc = 5 & d <= n : 6,
        Implies(And(Equals(pc, Int(5)), LE(d, n)), Equals(x_pc, Int(6))),
        # pc = 6 : 7,
        Implies(Equals(pc, Int(6)), Equals(x_pc, Int(7))),
        # pc = 7 : 5,
        Implies(Equals(pc, Int(7)), Equals(x_pc, Int(5))),
        # pc = 8 : -1,
        Implies(Equals(pc, Int(8)), Equals(x_pc, Int(-1)))
    )

    # transition labels.
    labels = And(
        # (pc = -1 & pc' = -1) -> (n' = n & d' = d & log' = log & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(-1)), Equals(x_pc, Int(-1))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, log),
                    Equals(x_f_mul, f_mul))),
        # (pc = 0 & pc' = -1) -> (n' = n & d' = d & log' = log & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(0)), Equals(x_pc, Int(-1))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, log),
                    Equals(x_f_mul, f_mul))),
        # (pc = 0 & pc' = 1)  -> (n' = n & d' = d & log' = log & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(0)), Equals(x_pc, Int(1))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, log),
                    Equals(x_f_mul, f_mul))),
        # (pc = 1 & pc' = -1) -> (n' = n & d' = d & log' = log & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(1)), Equals(x_pc, Int(-1))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, log),
                    Equals(x_f_mul, f_mul))),
        # (pc = 1 & pc' = 2)  -> (n' = n & d' = d & log' = log & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(1)), Equals(x_pc, Int(2))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, log),
                    Equals(x_f_mul, f_mul))),
        # (pc = 2 & pc' = 3)  -> (n' = n & d' = d & log' = log & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(2)), Equals(x_pc, Int(3))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, log),
                    Equals(x_f_mul, f_mul))),
        # (pc = 2 & pc' = 8)  -> (n' = n & d' = d & log' = log & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(2)), Equals(x_pc, Int(8))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, log),
                    Equals(x_f_mul, f_mul))),
        # (pc = 3 & pc' = 4)  -> (n' = n & d' = d & log' = 0 & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(3)), Equals(x_pc, Int(4))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, Int(0)),
                    Equals(x_f_mul, f_mul))),
        # (pc = 4 & pc' = 5)  -> (n' = n & d' = d & log' = log & f_mul' = d),
        Implies(And(Equals(pc, Int(4)), Equals(x_pc, Int(5))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, log),
                    Equals(x_f_mul, d))),
        # (pc = 5 & pc' = -1) -> (n' = n & d' = d & log' = log & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(5)), Equals(x_pc, Int(-1))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, log),
                    Equals(x_f_mul, f_mul))),
        # (pc = 5 & pc' = 6)  -> (n' = n & d' = d & log' = log & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(5)), Equals(x_pc, Int(6))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, log),
                    Equals(x_f_mul, f_mul))),
        # (pc = 6 & pc' = 7)  -> (n' = n & d' = d & log' = log+1 & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(6)), Equals(x_pc, Int(7))),
                And(Equals(x_n, n), Equals(x_d, d),
                    Equals(x_log, Plus(log, Int(1))), Equals(x_f_mul, f_mul))),
        # (pc = 7 & pc' = 5)  -> (n' = n & d' = d*f_mul & log' = log & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(7)), Equals(x_pc, Int(5))),
                And(Equals(x_n, n), Equals(x_d, Times(d, f_mul)),
                    Equals(x_log, log), Equals(x_f_mul, f_mul))),
        # (pc = 8 & pc' = -1) -> (n' = n & d' = d & log' = log & f_mul' = f_mul),
        Implies(And(Equals(pc, Int(8)), Equals(x_pc, Int(-1))),
                And(Equals(x_n, n), Equals(x_d, d), Equals(x_log, log),
                    Equals(x_f_mul, f_mul)))
    )

    # transition relation.
    trans = And(cfg, labels)

    # fairness: the sink location pc = -1 must be avoided infinitely often.
    fairness = Not(Equals(pc, Int(-1)))

    # define automata to be composed.
    aut_d = AGAutomaton(symbols, [d], "aut_d", 1)
    aut_d.set_assume(0, Equals(f_mul, Int(1)))
    aut_d.set_invar(0, Equals(d, Int(1)))
    aut_d.set_transitions(0, [(0, [Equals(x_d, d),
                                   Equals(x_d, Times(d, f_mul))])])

    aut_f_mul = AGAutomaton(symbols, [f_mul], "aut_f_mul", 1)
    aut_f_mul.set_assume(0, TRUE())
    aut_f_mul.set_invar(0, TRUE())
    aut_f_mul.set_transitions(0, [(0, [Equals(x_f_mul, f_mul)])])

    aut_log = AGAutomaton(symbols, [log], "aut_log", 1)
    aut_log.set_assume(0, TRUE())
    aut_log.set_invar(0, TRUE())
    aut_log.set_transitions(0, [(0, [Equals(x_log, log),
                                     Equals(x_log, Plus(log, Int(1)))])])

    aut_n = AGAutomaton(symbols, [n], "aut_n", 1)
    aut_n.set_assume(0, TRUE())
    aut_n.set_invar(0, GE(n, Int(2)))
    aut_n.set_transitions(0, [(0, [Equals(x_n, n)])])

    aut_dn = AGAutomaton(symbols, [d, n], "aut_dn", 1)
    aut_dn.set_assume(0, And(GE(f_mul, Int(-1)), LE(f_mul, Int(1))))
    aut_dn.set_invar(0, And(GE(n, Int(2)), GE(d, Int(-2)), LE(d, Int(2))))
    aut_dn.set_transitions(0, [(0, [And(Equals(x_n, n), Equals(x_d, d)),
                                    And(Equals(x_n, n),
                                        Equals(x_d, Times(d, f_mul)))])])

    # aut_pc cycles through the loop locations pc = 5 -> 6 -> 7 -> 5.
    aut_pc = AGAutomaton(symbols, [pc], "aut_pc", 3)
    for loc in range(aut_pc.num_locations):
        n_loc = (loc + 1) % aut_pc.num_locations
        c_pc = Int(loc + 5)
        n_pc = Int(n_loc + 5)
        aut_pc.set_assume(loc, TRUE())
        aut_pc.set_invar(loc, Equals(pc, c_pc))
        aut_pc.set_transitions(loc, [(n_loc, [Equals(x_pc, n_pc)])])

    automata = [aut_d, aut_f_mul, aut_log, aut_n, aut_pc, aut_dn]

    # search composition.
    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w') as out:
            out.write(str(comp))

    # Tri-state result. The original `else` branch asserted
    # `not comp and undefs`, which can never hold there (that branch is only
    # reached when `comp` is truthy), so it always raised AssertionError.
    if not comp:
        return False
    if undefs:
        return None
    return True
