from typing import Optional

from pysmt.shortcuts import Symbol, TRUE, Int, Real
from pysmt.shortcuts import Not, And, Implies
from pysmt.shortcuts import GE, Equals, LT
from pysmt.shortcuts import Plus, Times, Div
from pysmt.typing import INT, REAL

from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> Optional[bool]:
    """Example 1 from document.

    Encode a small program over the variables ``pc`` (integer), ``i`` and
    ``j`` (reals) as an initial condition, a transition relation and a
    fairness condition, build three ``AGAutomaton`` instances ("Loop", "Far"
    and "Div") and hand everything to ``find_composition``.

    Args:
        nuxmv_path: forwarded unchanged to ``find_composition``
            (presumably the path of the nuXmv executable -- confirm there).
        model_file: forwarded unchanged to ``find_composition``.
        trace_file: forwarded unchanged to ``find_composition``.
        cmd_file: forwarded unchanged to ``find_composition``.
        output_file: path that receives ``str(comp)`` whenever
            ``find_composition`` returns a composition that is not ``None``.

    Returns:
        ``True`` if a (truthy) composition is returned and ``undefs`` is
        falsy, ``False`` if the returned composition is falsy, and ``None``
        if a composition is returned together with a truthy ``undefs``.
    """
    # Current-state symbols.
    pc = Symbol("pc", INT)
    i = Symbol("i", REAL)
    j = Symbol("j", REAL)

    # Next-state counterparts; their names are derived through symb_next.
    x_pc = Symbol(symb_next("pc"), INT)
    x_i = Symbol(symb_next("i"), REAL)
    x_j = Symbol(symb_next("j"), REAL)

    symbols = [pc, i, j]

    # Constants; the suffix states the sort (_i: integer, _r: real).
    zero_r = Real(0)
    zero_i = Int(0)
    one_r = Real(1)
    one_i = Int(1)
    two_i = Int(2)
    three_i = Int(3)
    three_r = Real(3)
    four_r = Real(4)

    # Initial condition: pc = 0.
    init = Equals(pc, zero_i)

    # Transition relation: one implication per program location / guard.
    trans = And(
        # pc = 0 & i < 0 -> pc' = 3 (i' and j' are left unconstrained)
        Implies(And(Equals(pc, zero_i), LT(i, zero_r)),
                Equals(x_pc, three_i)),
        # pc = 0 & i >= 0 -> pc' = 1 & i' = i & j' = j
        Implies(And(Equals(pc, zero_i), GE(i, zero_r)),
                And(Equals(x_pc, one_i), Equals(x_i, i),
                    Equals(x_j, j))),
        # pc = 1 -> pc' = 2 & i' = i + j & j' = j
        Implies(Equals(pc, one_i),
                And(Equals(x_pc, two_i), Equals(x_j, j),
                    Equals(x_i, Plus(i, j)))),
        # pc = 2 -> pc' = 0 & i' = i & j' = j^3/3 + 1
        Implies(Equals(pc, two_i),
                And(Equals(x_pc, zero_i), Equals(x_i, i),
                    Equals(x_j,
                           Plus(Div(Times(j, j, j), three_r),
                                one_r)))),
        # pc = 3 -> pc' = 3 (i' and j' are left unconstrained)
        Implies(Equals(pc, three_i), Equals(x_pc, three_i))
    )

    # Fairness condition: pc != 3.
    fairness = Not(Equals(pc, three_i))

    # "Loop": three locations following pc through the cycle 0 -> 1 -> 2 -> 0,
    # each with a trivially true assumption.
    aut_L = AGAutomaton(symbols, [pc], "Loop", 3)
    aut_L.set_invar(0, Equals(pc, zero_i))
    aut_L.set_invar(1, Equals(pc, one_i))
    aut_L.set_invar(2, Equals(pc, two_i))
    aut_L.set_transitions(0, [(1, [Equals(x_pc, one_i)])])
    aut_L.set_transitions(1, [(2, [Equals(x_pc, two_i)])])
    aut_L.set_transitions(2, [(0, [Equals(x_pc, zero_i)])])
    for loc_idx in range(3):
        aut_L.set_assume(loc_idx, TRUE())

    # "Far": single location with invariant i >= 0, assuming j >= 0; its
    # self-loop lists the two updates of i that occur in `trans`.
    aut_F = AGAutomaton(symbols, [i], "Far", 1)
    aut_F.set_invar(0, GE(i, zero_r))
    aut_F.set_assume(0, GE(j, zero_r))
    aut_F.set_transitions(0, [(0, [Equals(x_i, i),
                                   Equals(x_i, Plus(i, j))])])

    # "Div": single location with invariant j >= 4 and no assumption; its
    # self-loop lists the two updates of j that occur in `trans`.
    aut_D = AGAutomaton(symbols, [j], "Div", 1)
    aut_D.set_invar(0, GE(j, four_r))
    aut_D.set_assume(0, TRUE())
    aut_D.set_transitions(0, [(0, [Equals(x_j, j),
                                   Equals(x_j,
                                          Plus(Div(Times(j, j, j), three_r),
                                               one_r))])])

    automata = [aut_L, aut_F, aut_D]

    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w') as out:
            out.write(str(comp))

    if not comp:
        return False
    # A composition was returned: the outcome is conclusive only when nothing
    # is reported in `undefs`; otherwise signal "unknown" with None.
    return None if undefs else True
