from typing import Optional

from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next

from pysmt.shortcuts import Symbol, TRUE, Real
from pysmt.shortcuts import Not, And, Or, Implies
from pysmt.shortcuts import GT, GE, Equals, LE, LT
from pysmt.shortcuts import Plus, Minus, Times, Div
from pysmt.typing import REAL


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> Optional[bool]:
    """Test corresponding to benchmarks/1.smv.

    Encodes a system over the real variables ``h``, ``v`` and ``d``.
    A normal step is constant-acceleration kinematics with ``g = 981/100``
    over a duration ``d``. When ``h = 0`` and ``v < 0``, the velocity is
    instead reversed and halved, so by the look of it this is a bouncing
    ball. The function builds candidate assume-guarantee automata and asks
    ``find_composition`` for a composition of them.

    Args:
        nuxmv_path: forwarded to ``find_composition``. Presumably the path
            to the nuXmv executable; confirm against ``find_composition``.
        model_file: forwarded to ``find_composition``. Presumably a scratch
            file for the generated model.
        trace_file: forwarded to ``find_composition``. Presumably a scratch
            file for the trace.
        cmd_file: forwarded to ``find_composition``. Presumably a scratch
            file for the command script.
        output_file: if ``find_composition`` returns a composition that is
            not ``None``, its ``str`` form is written here. Any existing
            content is overwritten.

    Returns:
        ``True`` if a composition was found and ``undefs`` is empty.
        ``False`` if no composition was found (``comp`` is falsy).
        ``None`` if a composition was found but ``undefs`` is non-empty,
        i.e. the result is inconclusive.
    """
    # symbols: current-state variables and their next-state copies.
    h = Symbol("h", REAL)
    v = Symbol("v", REAL)
    d = Symbol("d", REAL)
    x_h = Symbol(symb_next("h"), REAL)
    x_v = Symbol(symb_next("v"), REAL)
    x_d = Symbol(symb_next("d"), REAL)
    symbols = [h, v, d]

    # constants shared by the transition relation and the automata.
    g = Div(Real(981), Real(100))
    min_half = Div(Real(-1), Real(2))
    two_over_g = Div(Real(2), g)
    min_one_over_g = Div(Real(-1), g)

    # initial location.
    init = And(Equals(h, Real(0)), Equals(v, g))

    # transition relation.
    # Regular step of duration d under constant acceleration -g:
    #   h' = h + v*d - g*d^2/2,   v' = v - g*d.
    n_h = Plus(h, Minus(Times(v, d), Times(Div(Real(1), Real(2)), g, d, d)))
    n_v = Minus(v, Times(g, d))
    # At h = 0 while moving down, h stays 0 and v is reversed and halved.
    bounce = And(Equals(h, Real(0)), LT(v, Real(0)))
    trans = And(
        Implies(bounce,
                And(Equals(x_h, Real(0)), Equals(x_v, Times(min_half, v)))),
        Implies(Not(bounce),
                And(Equals(x_h, n_h), Equals(x_v, n_v)))
    )

    # fairness. An alternative that was tried is GE(d, Real(0)).
    fairness = Equals(h, Real(0))

    # define automata to be composed.
    #
    # In all automata below:
    #   - location 1 fixes d = 2v/g. Starting from h = 0, a step of that
    #     length gives h' = v*d - g*d^2/2 = 0 and v' = v - g*d = -v.
    #   - location 2 applies the bounce v' = -v/2. It then returns either
    #     to location 0 (d' = 0) or to location 1 with
    #     d' = -v/g = 2*v'/g.

    # Single automaton over d, h and v. It is kept as an alternative to the
    # aut_h / aut_dv pair, but is not currently passed to find_composition.
    aut_dhv = AGAutomaton(symbols, [d, h, v], "aut_dhv", 3)
    aut_dhv.set_assume(0, TRUE())
    aut_dhv.set_assume(1, TRUE())
    aut_dhv.set_assume(2, TRUE())
    aut_dhv.set_invar(0, And(Equals(h, Real(0)), Equals(d, Real(0)),
                             GE(v, Real(0))))
    aut_dhv.set_transitions(0, [(0, [And(Equals(x_h, h), Equals(x_v, v),
                                         Equals(x_d, d))]),
                                (1, [And(Equals(x_h, h), Equals(x_v, v),
                                         Equals(x_d, Times(two_over_g, v)))])])
    aut_dhv.set_invar(1, And(Equals(h, Real(0)), GE(v, Real(0)),
                             Equals(d, Times(two_over_g, v))))
    aut_dhv.set_transitions(1, [(2, [And(Equals(x_d, Real(0)),
                                         Equals(x_v, n_v),
                                         Equals(x_h, n_h))])])
    aut_dhv.set_invar(2, And(Equals(h, Real(0)), LE(v, Real(0)),
                             Equals(d, Real(0))))
    aut_dhv.set_transitions(2, [(0, [And(Equals(x_h, h), Equals(x_d, Real(0)),
                                         Equals(x_v, Times(min_half, v)))]),
                                (1, [And(Equals(x_h, h),
                                         Equals(x_d, Times(min_one_over_g, v)),
                                         Equals(x_v, Times(min_half, v)))])])

    # Automaton over h only, assuming d is either 0 or 2v/g.
    # NOTE(review): the two entries in the transition list are presumably
    # alternative constraints; confirm against AGAutomaton.set_transitions.
    aut_h = AGAutomaton(symbols, [h], "aut_h", 1)
    aut_h.set_assume(0, Or(Equals(d, Real(0)),
                           Equals(d, Times(two_over_g, v))))
    aut_h.set_invar(0, Equals(h, Real(0)))
    aut_h.set_transitions(0, [(0, [Equals(x_h, h),
                                   Equals(x_h, n_h)])])

    # Automaton over d and v: the same three locations as aut_dhv with h
    # projected away.
    aut_dv = AGAutomaton(symbols, [d, v], "aut_dv", 3)
    aut_dv.set_assume(0, TRUE())
    aut_dv.set_assume(1, TRUE())
    aut_dv.set_assume(2, TRUE())
    aut_dv.set_invar(0, And(Equals(d, Real(0)), GE(v, Real(0))))
    aut_dv.set_transitions(0, [(0, [And(Equals(x_v, v), Equals(x_d, d))]),
                               (1, [And(Equals(x_v, v),
                                        Equals(x_d, Times(two_over_g, v)))])])
    aut_dv.set_invar(1, And(GE(v, Real(0)),
                            Equals(d, Times(two_over_g, v))))
    aut_dv.set_transitions(1, [(2, [And(Equals(x_d, Real(0)),
                                        Equals(x_v, n_v))])])
    aut_dv.set_invar(2, And(LE(v, Real(0)), Equals(d, Real(0))))
    aut_dv.set_transitions(2, [(0, [And(Equals(x_d, Real(0)),
                                        Equals(x_v, Times(min_half, v)))]),
                               (1, [And(Equals(x_d, Times(min_one_over_g, v)),
                                        Equals(x_v, Times(min_half, v)))])])

    automata = [aut_h, aut_dv]

    # search composition.
    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w') as out:
            out.write(str(comp))

    if not comp:
        return False
    # A composition was found. Remaining undefined parts make the outcome
    # inconclusive rather than a failure.
    return None if undefs else True
