from typing import Optional

from pysmt.shortcuts import Symbol, TRUE, Real
from pysmt.shortcuts import Not, And, Or, Implies
from pysmt.shortcuts import Plus, Minus, Times, Div
from pysmt.shortcuts import GT, GE, Equals, LE, LT
from pysmt.shortcuts import simplify
from pysmt.typing import REAL, BOOL

from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next, to_next


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> Optional[bool]:
    """Test corresponding to benchmarks/etcs.smv.

    Builds a symbolic model of a radio block controller (RBC) and a train
    (initial condition, invariant, transition relation and fairness), defines
    a set of candidate assume-guarantee automata (`AGAutomaton`) and asks
    `find_composition` to compose them.

    Args:
        nuxmv_path: forwarded unchanged to `find_composition` (presumably the
            path of the nuXmv executable -- confirm against find_composition).
        model_file: forwarded unchanged to `find_composition`.
        trace_file: forwarded unchanged to `find_composition`.
        cmd_file: forwarded unchanged to `find_composition`.
        output_file: when `find_composition` returns a composition that is
            not None, its string representation is written to this file.

    Returns:
        True if a composition was found and `undefs` is empty,
        False if no composition was found (falsy result),
        None if a composition was found but `undefs` is non-empty.
    """
    # symbols.
    delta = Symbol("delta", REAL)
    rbc_s = Symbol("rbc_s", BOOL)
    rbc_m = Symbol("rbc_m", REAL)
    rbc_d = Symbol("rbc_d", REAL)
    rbc_v_des = Symbol("rbc_v_des", REAL)
    t_c = Symbol("train_c", REAL)
    t_z = Symbol("train_z", REAL)
    t_v = Symbol("train_v", REAL)
    t_a = Symbol("train_a", REAL)

    x_delta = Symbol(symb_next("delta"), REAL)
    x_rbc_s = Symbol(symb_next("rbc_s"), BOOL)
    x_rbc_m = Symbol(symb_next("rbc_m"), REAL)
    x_rbc_d = Symbol(symb_next("rbc_d"), REAL)
    x_rbc_v_des = Symbol(symb_next("rbc_v_des"), REAL)
    x_t_c = Symbol(symb_next("train_c"), REAL)
    x_t_z = Symbol(symb_next("train_z"), REAL)
    x_t_v = Symbol(symb_next("train_v"), REAL)
    x_t_a = Symbol(symb_next("train_a"), REAL)

    symbols = [delta, rbc_s, rbc_m, rbc_d, rbc_v_des, t_c, t_z, t_v, t_a]

    max_brake = Real(2)
    max_acc = Real(4)
    period = Real(1)
    zero = Real(0)
    rbc_brake = rbc_s
    rbc_drive = Not(rbc_s)
    # Safety braking distance, assumed to follow the ETCS case study
    # (b = max_brake, A = max_acc, eps = period):
    #   SB = (v^2 - d^2) / (2 b) + (A / b + 1) * ((A / 2) * eps^2 + eps * v)
    # The middle term must be (A / 2) * eps^2; writing A / (2 * eps^2)
    # would only coincide with it while period == 1.
    t_SB = Plus(Div(Minus(Times(t_v, t_v), Times(rbc_d, rbc_d)),
                    Times(Real(2), max_brake)),
                Times(Plus(Div(max_acc, max_brake), Real(1)),
                      Plus(Times(Div(max_acc, Real(2)), period, period),
                           Times(period, t_v))))
    x_rbc_brake = x_rbc_s
    x_rbc_drive = Not(x_rbc_s)

    # initial location.
    init = simplify(And(rbc_brake, Equals(rbc_v_des, zero),
                        LE(t_z, rbc_m), Equals(t_v, zero), Equals(t_a, zero),
                        Equals(t_c, zero)))
    invar = simplify(And(GE(delta, zero), GE(rbc_m, zero), GE(rbc_d, zero),
                         GE(rbc_v_des, zero), LE(t_c, period), GE(t_z, zero),
                         LE(t_a, max_acc), LE(Times(Real(-1), max_brake), t_a),
                         Implies(GE(t_v, rbc_v_des),
                                 LE(t_a, zero))))

    init = simplify(And(init, invar))

    x_invar = to_next(invar, symbols)

    # transition relation.
    trans = And(
        x_invar,
        # RadioBlockController
        Implies(x_rbc_brake, And(Equals(x_rbc_m, rbc_m),
                                 Equals(x_rbc_d, rbc_d))),
        Implies(x_rbc_drive, LE(Minus(Times(rbc_d, rbc_d),
                                      Times(x_rbc_d, x_rbc_d)),
                                Times(Real(2), max_brake,
                                      Minus(x_rbc_m, rbc_m)))),
        # Train
        Implies(LT(t_c, period), Equals(x_t_c, Plus(t_c, delta))),
        Implies(GE(t_c, period), And(Equals(x_t_c, zero),
                                     Equals(delta, zero))),
        Equals(x_t_z, Plus(t_z, Times(t_v, delta),
                           Div(Times(t_a, delta, delta), Real(2)))),
        Equals(x_t_v, Plus(t_v, Times(t_a, delta))),
        # When the clock resets, the train brakes if it is within the
        # safety distance of the movement authority or the RBC says brake.
        Implies(And(Equals(t_c, period), Equals(x_t_c, zero),
                    Or(LE(Minus(rbc_m, t_z), t_SB), rbc_brake)),
                Equals(x_t_a, Times(Real(-1), max_brake))),
        # Acceleration only changes at a clock reset.
        Implies(Not(And(Equals(t_c, period), Equals(x_t_c, zero))),
                Equals(x_t_a, t_a))
    )

    trans = simplify(trans)

    # fairness.
    fairness = And(GT(delta, zero), GT(t_v, zero))

    # define automata to be composed.
    aut_v_des = AGAutomaton(symbols, [rbc_v_des], "aut_v_des", 1)
    aut_v_des.set_assume(0, GE(rbc_d, zero))
    aut_v_des.set_invar(0, GE(rbc_v_des, rbc_d))
    aut_v_des.set_transitions(0, [(0, [Equals(x_rbc_v_des, rbc_v_des)])])

    aut_d = AGAutomaton(symbols, [rbc_d], "aut_d", 1)
    aut_d.set_assume(0, TRUE())
    aut_d.set_invar(0, GE(rbc_d, zero))
    aut_d.set_transitions(0, [(0, [Equals(x_rbc_d, rbc_d)])])

    aut_s = AGAutomaton(symbols, [rbc_s], "aut_s", 1)
    aut_s.set_assume(0, TRUE())
    aut_s.set_invar(0, rbc_drive)
    aut_s.set_transitions(0, [(0, [x_rbc_drive])])

    aut_c = AGAutomaton(symbols, [t_c], "aut_c", 2)
    aut_c.set_assume(0, Equals(delta, period))
    aut_c.set_invar(0, Equals(t_c, zero))
    aut_c.set_transitions(0, [(1, [Equals(x_t_c, period)])])
    aut_c.set_assume(1, Equals(delta, zero))
    aut_c.set_invar(1, Equals(t_c, period))
    aut_c.set_transitions(1, [(0, [Equals(x_t_c, zero)])])

    aut_delta = AGAutomaton(symbols, [delta], "aut_delta", 2)
    aut_delta.set_assume(0, TRUE())
    aut_delta.set_invar(0, Equals(delta, period))
    aut_delta.set_transitions(0, [(1, [Equals(x_delta, zero)])])
    aut_delta.set_assume(1, TRUE())
    aut_delta.set_invar(1, Equals(delta, zero))
    aut_delta.set_transitions(1, [(0, [Equals(x_delta, period)])])

    aut_v = AGAutomaton(symbols, [t_v], "aut_v", 1)
    aut_v.set_assume(0, And(Equals(t_a, zero), GE(delta, zero)))
    aut_v.set_invar(0, And(GT(t_v, zero), LE(t_v, Real(4))))
    aut_v.set_transitions(0, [(0, [Equals(x_t_v,
                                          Plus(t_v, Times(t_a, delta)))])])

    aut_a = AGAutomaton(symbols, [t_a], "aut_a", 1)
    aut_a.set_assume(0, TRUE())
    aut_a.set_invar(0, Equals(t_a, zero))
    aut_a.set_transitions(0, [(0, [Equals(x_t_a, zero)])])

    aut_mz = AGAutomaton(symbols, [rbc_m, t_z], "aut_mz", 1)
    # NOTE(review): LE(rbc_d, t_v) is implied by rbc_d == 0 and t_v >= 0;
    # kept as in the original model.
    aut_mz.set_assume(0, And(GE(delta, zero), GE(t_v, zero), GE(t_a, zero),
                             Equals(rbc_d, zero), LE(rbc_d, t_v),
                             LE(t_v, Real(4))))
    aut_mz.set_invar(0, And(GE(rbc_m, t_z), GE(t_z, zero),
                            GT(Minus(rbc_m, t_z), Real(22))))
    aut_mz.set_transitions(0, [(0, [And(GE(x_rbc_m,
                                           Plus(rbc_m, Times(t_v, delta),
                                                Div(Times(t_a, delta,
                                                          delta),
                                                    Real(2)))),
                                        GT(x_rbc_m, Plus(t_SB, x_t_z)),
                                        Equals(x_t_z,
                                               Plus(t_z, Times(t_v, delta),
                                                    Div(Times(t_a, delta,
                                                              delta),
                                                        Real(2)))))])])

    aut_v_des_v = AGAutomaton(symbols, [t_v, rbc_v_des], "aut_v_des_v", 1)
    aut_v_des_v.set_assume(0, And(GE(t_a, zero), GE(delta, zero),
                                  GE(rbc_d, zero)))
    aut_v_des_v.set_invar(0, And(GT(t_v, zero), GE(t_v, rbc_v_des),
                                 GE(rbc_v_des, rbc_d)))
    aut_v_des_v.set_transitions(0, [(0,
                                     [And(Equals(x_rbc_v_des, rbc_v_des),
                                          Equals(x_t_v,
                                                 Plus(t_v,
                                                      Times(t_a, delta))))])])

    automata = [aut_a, aut_v, aut_c, aut_d, aut_v_des, aut_s, aut_mz,
                aut_delta, aut_v_des_v]

    # search composition.
    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w', encoding="utf-8") as out:
            out.write(str(comp))

    if not comp:
        # No composition found.
        return False
    if undefs:
        # A composition was found, but find_composition also reported
        # undefined results (semantics of `undefs` not visible here).
        return None
    return True
