from typing import Optional

from pysmt.shortcuts import Symbol, TRUE, Int, Real
from pysmt.shortcuts import Not, And, Implies
from pysmt.shortcuts import GT, GE, Equals, LE, LT
from pysmt.shortcuts import Plus, Times, Div
from pysmt.shortcuts import ToReal
from pysmt.typing import INT, REAL, BOOL

from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> Optional[bool]:
    """Test corresponding to benchmarks/bouncing_ball_harmonic.smv.

    Builds the initial condition, transition relation and fairness
    condition over the symbols `h`, `v`, `delta`, `counter` and `stop`,
    defines a set of candidate AG-automata and hands everything to
    `find_composition`.

    Args:
        nuxmv_path: forwarded unchanged to `find_composition`.
        model_file: forwarded unchanged to `find_composition`.
        trace_file: forwarded unchanged to `find_composition`.
        cmd_file: forwarded unchanged to `find_composition`.
        output_file: path that receives `str(comp)` whenever the returned
            composition is not None.

    Returns:
        True if a composition was found and `undefs` is empty, False if no
        composition was found, None if a composition was found but
        `undefs` is non-empty.
    """
    # current-state symbols.
    h = Symbol("h", REAL)
    v = Symbol("v", REAL)
    d = Symbol("delta", REAL)
    c = Symbol("counter", INT)
    stop = Symbol("stop", BOOL)
    # next-state symbols.
    x_h = Symbol(symb_next("h"), REAL)
    x_v = Symbol(symb_next("v"), REAL)
    x_d = Symbol(symb_next("delta"), REAL)
    x_c = Symbol(symb_next("counter"), INT)
    x_stop = Symbol(symb_next("stop"), BOOL)

    symbols = [c, d, h, v, stop]

    # g = 9.81 (presumably the gravitational acceleration).
    g = Div(Real(981), Real(100))
    c_r = ToReal(c)

    # sub-formulas shared by the transition relation and the automata.
    zero = Real(0)
    h_eq_0 = Equals(h, zero)
    # h + v * delta - 1/2 * g * delta^2
    n_h = Plus(h, Times(v, d), Times(Div(Real(-1), Real(2)), g, d, d))
    # v - g * delta
    n_v = Plus(v, Times(Real(-1), g, d))
    # - v * counter / (counter + 1)
    bounce_v = Times(Real(-1), v, Div(c_r, Plus(c_r, Real(1))))

    # initial location.
    init = And(Equals(c, Int(1)),
               h_eq_0,
               Equals(v, Div(g, Real(2))))

    # transition relation.
    h_eq_0_v_lt_0 = And(h_eq_0, LT(v, zero))
    h_eq_0_v_le_0 = And(h_eq_0, LE(v, zero))
    trans = And(
        # (h = 0 & v < 0) -> next(c) = c + 1
        Implies(h_eq_0_v_lt_0, Equals(x_c, Plus(c, Int(1)))),
        # !(h = 0 & v < 0) -> next(c) = c
        Implies(Not(h_eq_0_v_lt_0), Equals(x_c, c)),
        # (stop & h = 0) -> (next(h) = 0 & next(v) = 0)
        Implies(And(stop, h_eq_0),
                And(Equals(x_h, zero), Equals(x_v, zero))),
        # (!stop & h = 0 & v <= 0) ->
        #     (next(h) = 0 & next(v) = - v * counter / (counter + 1.0))
        Implies(And(Not(stop), h_eq_0_v_le_0),
                And(Equals(x_h, zero), Equals(x_v, bounce_v))),
        # (!stop & !(h = 0 & v <= 0)) -> (next(h) = n_h & next(v) = n_v)
        Implies(And(Not(stop), Not(h_eq_0_v_le_0)),
                And(Equals(x_h, n_h), Equals(x_v, n_v)))
    )

    # fairness.
    # alternative (disabled): fairness = GE(d, Div(Real(1.0), c_r))
    fairness = And(h_eq_0, GT(v, zero))

    # define automata to be composed.

    def counter_automaton(name, lower_bound):
        """Single-location automaton for `counter`, bounded from below.

        The counter either keeps its value or is incremented by one.
        """
        aut = AGAutomaton(symbols, [c], name, 1)
        aut.set_assume(0, TRUE())
        aut.set_invar(0, GE(c, Int(lower_bound)))
        aut.set_transitions(0, [(0, [Equals(x_c, c),
                                     Equals(x_c, Plus(c, Int(1)))])])
        return aut

    def delta_velocity_automaton(name, min_counter):
        """Two-location automaton for `delta` and `v`.

        Both locations assume `counter >= min_counter & h = 0`; the
        automaton alternates between location 0 and location 1.
        """
        assume = And(GE(c, Int(min_counter)), h_eq_0)
        aut = AGAutomaton(symbols, [d, v], name, 2)
        aut.set_assume(0, assume)
        aut.set_invar(0, And(Equals(d, zero),
                             Equals(v, Div(g, Times(Real(-2), c_r)))))
        aut.set_transitions(
            0,
            [(1, [And(Equals(x_d, Div(Real(1), Plus(c_r, Real(1)))),
                      Equals(x_v, bounce_v))])])
        aut.set_assume(1, assume)
        aut.set_invar(1, And(Equals(d, Div(Real(1), c_r)),
                             Equals(v, Div(g, Times(Real(2), c_r)))))
        aut.set_transitions(1, [(0, [And(Equals(x_d, zero),
                                         Equals(x_v, n_v))])])
        return aut

    aut_c = counter_automaton("aut_c", 1)
    aut_c1 = counter_automaton("aut_c1", 10)

    aut_h = AGAutomaton(symbols, [h], "aut_h", 2)
    aut_h.set_assume(0, Equals(d, zero))
    aut_h.set_invar(0, h_eq_0)
    aut_h.set_transitions(0, [(0, [Equals(x_h, h)]),
                              (1, [Equals(x_h, h)])])
    aut_h.set_assume(1, Equals(d, Times(Div(Real(2), g), v)))
    aut_h.set_invar(1, h_eq_0)
    aut_h.set_transitions(1, [(0, [Equals(x_h, n_h),
                                   Equals(x_h, h)])])

    aut_h1 = AGAutomaton(symbols, [h], "aut_h1", 1)
    aut_h1.set_assume(0, Equals(d, zero))
    aut_h1.set_invar(0, h_eq_0)
    aut_h1.set_transitions(0, [(0, [Equals(x_h, h)])])

    aut_dv = delta_velocity_automaton("aut_dv", 1)
    aut_dv1 = delta_velocity_automaton("aut_dv1", 2)

    aut_stop = AGAutomaton(symbols, [stop], "aut_stop", 1)
    aut_stop.set_invar(0, Not(stop))
    aut_stop.set_transitions(0, [(0, [Not(x_stop)])])

    aut_stop1 = AGAutomaton(symbols, [stop], "aut_stop1", 2)
    aut_stop1.set_invar(0, stop)
    aut_stop1.set_transitions(0, [(0, [x_stop]), (1, [x_stop])])
    aut_stop1.set_invar(1, stop)
    aut_stop1.set_assume(1, GE(v, zero))
    aut_stop1.set_transitions(1, [(0, [x_stop]), (1, [x_stop])])

    aut_stop2 = AGAutomaton(symbols, [stop], "aut_stop2", 1)
    aut_stop2.set_invar(0, stop)
    aut_stop2.set_transitions(0, [(0, [x_stop])])

    automata = [aut_c, aut_h, aut_dv, aut_stop,
                aut_stop1, aut_dv1, aut_h1, aut_c1,
                aut_stop2]

    # search composition.
    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w') as out:
            out.write(str(comp))

    # three-valued outcome: no composition -> False; composition that still
    # has undefs -> None; composition without undefs -> True.
    if not comp:
        return False
    return None if undefs else True
