from typing import Optional

from pysmt.shortcuts import Symbol, TRUE, Int, Real, ToReal
from pysmt.shortcuts import Not, And, Implies
from pysmt.shortcuts import GT, GE, Equals, LE, LT
from pysmt.shortcuts import Plus, Times, Div
from pysmt.typing import INT, REAL

from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> Optional[bool]:
    """Example 2 of document.

    Encode the example program as a transition system (``init``,
    ``trans``, ``fairness``) over ``pc`` (int) and ``i``, ``j`` (real),
    build a pool of candidate AG-automata for each variable and ask
    ``find_composition`` to combine them.

    :param nuxmv_path: path to the nuXmv executable, forwarded to
        ``find_composition``.
    :param model_file: forwarded to ``find_composition`` (presumably the
        file the generated nuXmv model is written to).
    :param trace_file: forwarded to ``find_composition`` (presumably the
        file a nuXmv trace is written to).
    :param cmd_file: forwarded to ``find_composition`` (presumably the
        nuXmv command script).
    :param output_file: when a composition is returned, its string
        representation is written to this file.
    :return: ``True`` if a composition was found and ``undefs`` is empty,
        ``False`` if no composition was found, ``None`` if a composition
        was found but ``undefs`` is non-empty (inconclusive result).
    """
    # symbols.
    pc = Symbol("pc", INT)
    i = Symbol("i", REAL)
    j = Symbol("j", REAL)
    x_pc = Symbol(symb_next("pc"), INT)
    x_i = Symbol(symb_next("i"), REAL)
    x_j = Symbol(symb_next("j"), REAL)
    symbols = [pc, i, j]

    # constants
    zero_r = Real(0)
    zero_i = Int(0)
    one_r = Real(1)
    m_one_r = Real(-1)
    one_i = Int(1)
    m_two_r = Real(-2)
    two_i = Int(2)
    three_i = Int(3)
    three_r = Real(3)
    four_r = Real(4)
    one_div_four = Div(one_r, four_r)

    # Update expressions shared by the program and many candidate automata.
    # i + j
    i_plus_j = Plus(i, j)
    # -j^3/3 + 1/4
    j_poly = Plus(Div(Times(m_one_r, j, j, j), three_r), one_div_four)

    # initial condition.
    init = And(Equals(pc, zero_i), Equals(j, one_r))

    # transition relation.
    loop_cond = And(GE(i, zero_r), LE(m_two_r, j), LE(j, one_r))
    trans = And(
        # pc = 0 & !(i >= 0 & -2 <= j <= 1)  -> pc' = 3
        Implies(And(Equals(pc, zero_i), Not(loop_cond)),
                Equals(x_pc, three_i)),
        # pc = 0 & i >= 0 & -2 <= j <= 1 -> pc' = 1 & i' = i & j' = j
        Implies(And(Equals(pc, zero_i), loop_cond),
                And(Equals(x_pc, one_i), Equals(x_i, i),
                    Equals(x_j, j))),
        # pc = 1 -> pc' = 2 & i' = i + j & j' = j
        Implies(Equals(pc, one_i),
                And(Equals(x_pc, two_i), Equals(x_j, j),
                    Equals(x_i, i_plus_j))),
        # pc = 2 -> pc' = 0 & i' = i & j' = -j^3/3 + 1/4
        Implies(Equals(pc, two_i),
                And(Equals(x_pc, zero_i), Equals(x_i, i),
                    Equals(x_j, j_poly))),
        # pc = 3 -> pc' = 3
        Implies(Equals(pc, three_i), Equals(x_pc, three_i))
    )
    # fairness condition: the program never settles in the exit state.
    fairness = Not(Equals(pc, three_i))

    # ---- candidate automata over `pc` ----
    aut_pc0 = AGAutomaton(symbols, [pc], "aut_pc0", 3)
    aut_pc0.set_invar(0, Equals(pc, zero_i))
    aut_pc0.set_invar(1, Equals(pc, one_i))
    aut_pc0.set_invar(2, Equals(pc, two_i))
    aut_pc0.set_transitions(0, [(1, [Equals(x_pc, one_i)])])
    aut_pc0.set_transitions(1, [(2, [Equals(x_pc, two_i)])])
    aut_pc0.set_transitions(2, [(0, [Equals(x_pc, zero_i)])])
    for loc_idx in range(aut_pc0.num_locations):
        aut_pc0.set_assume(loc_idx, TRUE())

    # Cycles through pc = 1, 2, 3 (shifted by one w.r.t. the program).
    aut_pc1 = AGAutomaton(symbols, [pc], "aut_pc1", 3)
    for loc in range(aut_pc1.num_locations):
        n_loc = (loc + 1) % aut_pc1.num_locations
        c_pc = Int(loc + 1)
        n_pc = Int(n_loc + 1)
        aut_pc1.set_assume(loc, TRUE())
        aut_pc1.set_invar(loc, Equals(pc, c_pc))
        aut_pc1.set_transitions(loc, [(n_loc, [Equals(x_pc, n_pc)])])

    aut_pc2 = AGAutomaton(symbols, [pc], "aut_pc2", 3)
    for loc in range(aut_pc2.num_locations):
        n_loc = (loc + 1) % aut_pc2.num_locations
        c_pc = Int(loc)
        n_pc = Int(n_loc)
        aut_pc2.set_assume(loc, Equals(i, zero_r))
        aut_pc2.set_invar(loc, Equals(pc, c_pc))
        aut_pc2.set_transitions(loc, [(n_loc, [Equals(x_pc, n_pc)])])

    aut_pc3 = AGAutomaton(symbols, [pc], "aut_pc3", 3)
    for loc in range(aut_pc3.num_locations):
        n_loc = (loc + 1) % aut_pc3.num_locations
        c_pc = Int(loc)
        n_pc = Int(n_loc)
        aut_pc3.set_assume(loc, LE(i, zero_r))
        aut_pc3.set_invar(loc, Equals(pc, c_pc))
        aut_pc3.set_transitions(loc, [(n_loc, [Equals(x_pc, n_pc)])])

    aut_pc4 = AGAutomaton(symbols, [pc], "aut_pc4", 3)
    for loc in range(aut_pc4.num_locations):
        n_loc = (loc + 1) % aut_pc4.num_locations
        c_pc = Int(loc)
        n_pc = Int(n_loc)
        aut_pc4.set_assume(loc, GE(j, one_r))
        aut_pc4.set_invar(loc, Equals(pc, c_pc))
        aut_pc4.set_transitions(loc, [(n_loc, [Equals(x_pc, n_pc)])])

    aut_pc5 = AGAutomaton(symbols, [pc], "aut_pc5", 3)
    for loc in range(aut_pc5.num_locations):
        n_loc = (loc + 1) % aut_pc5.num_locations
        c_pc = Int(loc)
        n_pc = Int(n_loc)
        aut_pc5.set_assume(loc, LE(i, Real(10)))
        aut_pc5.set_invar(loc, Equals(pc, c_pc))
        aut_pc5.set_transitions(loc, [(n_loc, [Equals(x_pc, n_pc)])])

    # No assumption set explicitly for aut_pc6 / aut_pc7 / aut_pc9.
    aut_pc6 = AGAutomaton(symbols, [pc], "aut_pc6", 4)
    for loc in range(aut_pc6.num_locations):
        n_loc = (loc + 1) % aut_pc6.num_locations
        c_pc = Int(loc)
        n_pc = Int(n_loc)
        aut_pc6.set_invar(loc, Equals(pc, c_pc))
        aut_pc6.set_transitions(loc, [(n_loc, [Equals(x_pc, n_pc)])])

    aut_pc7 = AGAutomaton(symbols, [pc], "aut_pc7", 2)
    for loc in range(aut_pc7.num_locations):
        n_loc = (loc + 1) % aut_pc7.num_locations
        c_pc = Int(loc)
        n_pc = Int(n_loc)
        aut_pc7.set_invar(loc, Equals(pc, c_pc))
        aut_pc7.set_transitions(loc, [(n_loc, [Equals(x_pc, n_pc)])])

    aut_pc8 = AGAutomaton(symbols, [pc], "aut_pc8", 3)
    for loc in range(aut_pc8.num_locations):
        n_loc = (loc + 1) % aut_pc8.num_locations
        c_pc = Int(loc)
        n_pc = Int(n_loc)
        aut_pc8.set_invar(loc, Equals(pc, c_pc))
        aut_pc8.set_assume(loc, LE(i, ToReal(pc)))
        aut_pc8.set_transitions(loc, [(n_loc, [Equals(x_pc, n_pc)])])

    aut_pc9 = AGAutomaton(symbols, [pc], "aut_pc9", 1)
    aut_pc9.set_transitions(0, [(0, [Equals(ToReal(x_pc), i)])])

    # ---- candidate automata over `i` ----
    aut_i0 = AGAutomaton(symbols, [i], "aut_i0", 1)
    aut_i0.set_invar(0, GE(i, zero_r))
    aut_i0.set_assume(0, GE(j, zero_r))
    aut_i0.set_transitions(0, [(0, [Equals(x_i, i),
                                    Equals(x_i, i_plus_j)])])

    aut_i1 = AGAutomaton(symbols, [i], "aut_i1", 1)
    aut_i1.set_invar(0, LE(i, zero_r))
    aut_i1.set_assume(0, LE(j, zero_r))
    aut_i1.set_transitions(0, [(0, [Equals(x_i, i),
                                    Equals(x_i, i_plus_j)])])

    aut_i2 = AGAutomaton(symbols, [i], "aut_i2", 1)
    aut_i2.set_invar(0, GE(i, zero_r))
    aut_i2.set_assume(0, Equals(pc, one_i))
    aut_i2.set_transitions(0, [(0, [Equals(x_i, i),
                                    Equals(x_i, i_plus_j)])])

    aut_i3 = AGAutomaton(symbols, [i], "aut_i3", 1)
    aut_i3.set_invar(0, GE(i, zero_r))
    aut_i3.set_assume(0, And(GE(j, zero_r), LE(j, Div(one_r, Real(10)))))
    aut_i3.set_transitions(0, [(0, [Equals(x_i, i),
                                    Equals(x_i, i_plus_j)])])

    aut_i4 = AGAutomaton(symbols, [i], "aut_i4", 1)
    aut_i4.set_invar(0, TRUE())
    aut_i4.set_assume(0, GE(j, Real(-10)))
    aut_i4.set_transitions(0, [(0, [Equals(x_i, i),
                                    Equals(x_i, i_plus_j)])])

    aut_i5 = AGAutomaton(symbols, [i], "aut_i5", 1)
    aut_i5.set_invar(0, GE(i, zero_r))
    aut_i5.set_assume(0, GE(j, one_r))
    aut_i5.set_transitions(0, [(0, [Equals(x_i, i),
                                    Equals(x_i, i_plus_j)])])

    aut_i6 = AGAutomaton(symbols, [i], "aut_i6", 1)
    aut_i6.set_invar(0, GE(i, zero_r))
    aut_i6.set_assume(0, GE(j, Div(Real(33), Real(10))))
    aut_i6.set_transitions(0, [(0, [Equals(x_i, i),
                                    Equals(x_i, i_plus_j)])])

    aut_i7 = AGAutomaton(symbols, [i], "aut_i7", 2)
    aut_i7.set_invar(0, GE(i, zero_r))
    aut_i7.set_assume(0, LE(j, zero_r))
    aut_i7.set_transitions(0, [(0, [Equals(x_i, i)]),
                               (1, [Equals(x_i, i)])])
    aut_i7.set_invar(1, GE(i, zero_r))
    aut_i7.set_assume(1, LE(j, i))
    aut_i7.set_transitions(1, [(0, [Equals(x_i, i_plus_j)]),
                               (1, [Equals(x_i, i_plus_j)])])

    aut_i8 = AGAutomaton(symbols, [i], "aut_i8", 2)
    aut_i8.set_invar(0, GE(i, zero_r))
    aut_i8.set_assume(0, LE(j, zero_r))
    aut_i8.set_transitions(0, [(0, [Equals(x_i, i)]),
                               (1, [Equals(x_i, i)])])
    aut_i8.set_invar(1, GE(i, zero_r))
    aut_i8.set_assume(1, LE(j, i))
    aut_i8.set_transitions(1, [(0, [Equals(x_i, i_plus_j)]),
                               (1, [Equals(x_i, Times(i, j))])])

    aut_i9 = AGAutomaton(symbols, [i], "aut_i9", 2)
    aut_i9.set_invar(0, GE(i, zero_r))
    aut_i9.set_assume(0, LE(j, zero_r))
    aut_i9.set_transitions(0, [(0, [Equals(x_i, i)]),
                               (1, [Equals(x_i, i)])])
    aut_i9.set_invar(1, GE(i, zero_r))
    aut_i9.set_assume(1, Equals(j, one_r))
    aut_i9.set_transitions(1, [(0, [Equals(x_i, i_plus_j)]),
                               (1, [Equals(x_i, Times(i, j))])])

    # ---- candidate automata over `j` ----
    aut_j0 = AGAutomaton(symbols, [j], "aut_j0", 1)
    aut_j0.set_invar(0, And(GE(j, zero_r), LE(j, one_div_four)))
    aut_j0.set_assume(0, TRUE())
    aut_j0.set_transitions(0, [(0, [Equals(x_j, j), Equals(x_j, j_poly)])])

    aut_j1 = AGAutomaton(symbols, [j], "aut_j1", 1)
    aut_j1.set_invar(0, LT(j, three_r))
    aut_j1.set_assume(0, TRUE())
    aut_j1.set_transitions(0, [(0, [Equals(x_j, j), Equals(x_j, j_poly)])])

    aut_j2 = AGAutomaton(symbols, [j], "aut_j2", 1)
    aut_j2.set_invar(0, GT(j, Real(-3)))
    aut_j2.set_assume(0, TRUE())
    aut_j2.set_transitions(0, [(0, [Equals(x_j, j), Equals(x_j, j_poly)])])

    aut_j3 = AGAutomaton(symbols, [j], "aut_j3", 1)
    aut_j3.set_invar(0, And(GE(j, zero_r), LE(j, one_div_four)))
    aut_j3.set_assume(0, LT(i, Real(5)))
    aut_j3.set_transitions(0, [(0, [Equals(x_j, j), Equals(x_j, j_poly)])])

    aut_j4 = AGAutomaton(symbols, [j], "aut_j4", 1)
    aut_j4.set_invar(0, And(GE(j, zero_r), LE(j, one_div_four)))
    aut_j4.set_assume(0, And(Equals(i, zero_r), LT(pc, one_i)))
    aut_j4.set_transitions(0, [(0, [Equals(x_j, Plus(j, i)),
                                    Equals(x_j, j_poly)])])

    aut_j5 = AGAutomaton(symbols, [j], "aut_j5", 1)
    aut_j5.set_invar(0, And(GE(j, zero_r), LE(j, one_div_four)))
    aut_j5.set_assume(0, Equals(i, one_r))
    aut_j5.set_transitions(0, [(0, [Equals(x_j, Times(j, i)),
                                    Equals(x_j, j_poly)])])

    aut_j6 = AGAutomaton(symbols, [j], "aut_j6", 1)
    aut_j6.set_invar(0, And(GE(j, zero_r), LE(j, one_div_four)))
    aut_j6.set_assume(0, GE(i, one_r))
    aut_j6.set_transitions(0, [(0, [Equals(x_j, Div(j, i)),
                                    Equals(x_j, j_poly)])])

    aut_j7 = AGAutomaton(symbols, [j], "aut_j7", 2)
    aut_j7.set_invar(0, And(GE(j, zero_r), LE(j, one_div_four)))
    aut_j7.set_assume(0, TRUE())
    aut_j7.set_transitions(0, [(0, [Equals(x_j, j)]), (1, [Equals(x_j, j)])])
    aut_j7.set_invar(1, And(GE(j, zero_r), LE(j, one_div_four)))
    aut_j7.set_assume(1, GE(i, zero_r))
    aut_j7.set_transitions(1, [(1, [Equals(x_j, j_poly)]),
                               (0, [Equals(x_j, Plus(j, i))])])

    aut_j8 = AGAutomaton(symbols, [j], "aut_j8", 2)
    aut_j8.set_invar(0, And(GE(j, zero_r), LE(j, one_div_four)))
    aut_j8.set_assume(0, TRUE())
    aut_j8.set_transitions(0, [(0, [Equals(x_j, j)]), (1, [Equals(x_j, j)])])
    aut_j8.set_invar(1, And(GE(j, zero_r), LE(j, one_div_four)))
    aut_j8.set_assume(1, GE(pc, three_i))
    aut_j8.set_transitions(1, [(1, [Equals(x_j, Plus(j, i))]),
                               (0, [Equals(x_j, j)])])

    aut_j9 = AGAutomaton(symbols, [j], "aut_j9", 1)
    aut_j9.set_invar(0, And(GE(j, zero_r), LE(j, one_div_four)))
    aut_j9.set_transitions(0, [(0, [Equals(x_j, j)])])

    automata = [aut_pc0, aut_j0, aut_i0,
                aut_pc1, aut_j1, aut_i1,
                aut_pc2, aut_j2, aut_i2,
                aut_pc3, aut_j3, aut_i3,
                aut_pc4, aut_j4, aut_i4,
                aut_pc5, aut_j5, aut_i5,
                aut_pc6, aut_j6, aut_i6,
                aut_pc7, aut_j7, aut_i7,
                aut_pc8, aut_j8, aut_i8,
                aut_pc9, aut_j9, aut_i9]

    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w', encoding='utf-8') as out:
            out.write(str(comp))

    if not comp:
        return False
    # A composition was found. A non-empty `undefs` (presumably the parts
    # find_composition could not decide) makes the outcome inconclusive.
    if undefs:
        return None
    return True
