from typing import Optional

from pysmt.shortcuts import Symbol, TRUE, Real, Int
from pysmt.shortcuts import And, Or, Implies, Ite
from pysmt.shortcuts import GT, GE, Equals, LE, LT
from pysmt.shortcuts import Plus, Minus, Times, Div
from pysmt.typing import REAL, INT
from pysmt.fnode import FNode

from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next, to_next


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> Optional[bool]:
    """Test corresponding to benchmarks/tanks.smv.

    Builds a system of three pipes and two tanks (pipe0 feeds tank0, pipe1
    moves liquid from tank0 to tank1, pipe2 drains tank1). It then defines
    one AG-automaton for `delta` and one per component, and asks
    `find_composition` for a composition of them.

    Args:
        nuxmv_path, model_file, trace_file, cmd_file: forwarded unchanged to
            `find_composition`.
        output_file: path where `str(comp)` is written when a composition
            is returned.

    Returns:
        False if no composition was found; None if a composition was found
        but `find_composition` also reported undefined parts (inconclusive);
        True otherwise.
    """
    # symbols.
    zero = Real(0)
    delta = Symbol("delta", REAL)
    x_delta = Symbol(symb_next("delta"), REAL)

    # components
    max_flows = [Real(10), Real(1), Real(1)]
    max_speeds = [Real(1), Real(1), Real(1)]
    assert len(max_speeds) == len(max_flows)
    pipes = [Pipe("pipe{}".format(i), delta, x_delta,
                  max_flows[i], max_speeds[i]) for i in range(len(max_speeds))]

    max_vols = [Real(100), Real(10)]
    # in_pipes[i] / out_pipes[i]: pipes filling / draining tank i.
    in_pipes = [[pipes[0]], [pipes[1]]]
    out_pipes = [[pipes[1]], [pipes[2]]]
    assert len(max_vols) == len(in_pipes)
    assert len(max_vols) == len(out_pipes)

    # Net flow of each tank: sum of inflows minus sum of outflows.
    in_flows = [Plus([f.flow for f in in_pipes[i]])
                if in_pipes[i] else Real(0)
                for i in range(len(max_vols))]
    out_flows = [Plus([f.flow for f in out_pipes[i]])
                 if out_pipes[i] else Real(0)
                 for i in range(len(max_vols))]
    flows = [Minus(in_flows[i], out_flows[i])
             for i in range(len(max_vols))]

    # Same combination for the rate of change of the flows.
    in_d_flows = [Plus([f.d_flow for f in in_pipes[i]])
                  if in_pipes[i] else Real(0)
                  for i in range(len(max_vols))]
    out_d_flows = [Plus([f.d_flow for f in out_pipes[i]])
                   if out_pipes[i] else Real(0)
                   for i in range(len(max_vols))]
    d_flows = [Minus(in_d_flows[i], out_d_flows[i])
               for i in range(len(max_vols))]

    tanks = [Tank("tank{}".format(i), delta, x_delta, max_vols[i],
                  flows[i], d_flows[i]) for i in range(len(max_vols))]
    components = pipes + tanks

    symbols = [delta]
    for comp in components:
        symbols += comp.symbols

    # initial location.
    init = And([And(comp.init, comp.invar) for comp in components])
    init = And(init, GE(delta, Real(0)))

    # transition relation: every component's invariant must also hold in
    # the next state.
    trans = And([And(to_next(comp.invar, symbols), comp.trans)
                 for comp in components])
    trans = And(trans, GE(x_delta, Real(0)))

    # fairness.
    # fairness = GT(delta, zero)
    fairness = GT(pipes[-1].flow, zero)

    # define automata to be composed.
    # aut_delta cycles delta through 7, 7, 35.
    seven = Real(7)
    r35 = Real(35)
    aut_delta = AGAutomaton(symbols, [delta], "aut_delta", 3)
    aut_delta.set_assume(0, TRUE())
    aut_delta.set_invar(0, Equals(delta, seven))
    aut_delta.set_transitions(0, [(1, [Equals(x_delta, seven)])])
    aut_delta.set_assume(1, TRUE())
    aut_delta.set_invar(1, Equals(delta, seven))
    aut_delta.set_transitions(1, [(2, [Equals(x_delta, r35)])])
    aut_delta.set_assume(2, TRUE())
    aut_delta.set_invar(2, Equals(delta, r35))
    aut_delta.set_transitions(2, [(0, [Equals(x_delta, seven)])])

    # source pipe, input to tanks[0]: opening (flow 0) -> closing (flow 7)
    # -> close (flow 0) -> opening.
    aut_pipe0 = AGAutomaton(symbols, [pipes[0].mode, pipes[0].flow],
                            "aut_pipe0", 3)
    aut_pipe0.set_invar(0, And(pipes[0].opening, Equals(pipes[0].flow, zero)))
    aut_pipe0.set_assume(0, Equals(delta, seven))
    aut_pipe0.set_transitions(0, [(1, [And(pipes[0].x_closing,
                                           Equals(pipes[0].x_flow, seven))])])
    aut_pipe0.set_invar(1, And(pipes[0].closing, Equals(pipes[0].flow, seven)))
    aut_pipe0.set_assume(1, Equals(delta, seven))
    aut_pipe0.set_transitions(1, [(2, [And(pipes[0].x_close,
                                           Equals(pipes[0].x_flow, zero))])])
    aut_pipe0.set_invar(2, And(pipes[0].close, Equals(pipes[0].flow, zero)))
    aut_pipe0.set_assume(2, Equals(delta, r35))
    aut_pipe0.set_transitions(2, [(0, [And(pipes[0].x_opening,
                                           Equals(pipes[0].x_flow, zero))])])

    # pipes[1] connects the 2 tanks.
    aut_pipe1 = AGAutomaton(symbols, [pipes[1].mode, pipes[1].flow],
                            "aut_pipe1", 1)
    aut_pipe1.set_invar(0, And(pipes[1].open,
                               Equals(pipes[1].flow, pipes[1].max_flow)))
    aut_pipe1.set_assume(0, And(GT(tanks[0].vol, zero),
                                LT(tanks[1].vol, tanks[1].max_vol)))
    aut_pipe1.set_transitions(0, [(0, [And(pipes[1].x_open,
                                           Equals(pipes[1].x_flow,
                                                  pipes[1].max_flow))])])

    # pipes[2] throws away from tanks[1].
    aut_pipe2 = AGAutomaton(symbols, [pipes[2].mode, pipes[2].flow],
                            "aut_pipe2", 1)
    aut_pipe2.set_invar(0, And(pipes[2].open,
                               Equals(pipes[2].flow, pipes[2].max_flow)))
    aut_pipe2.set_assume(0, GT(tanks[1].vol, zero))
    aut_pipe2.set_transitions(0, [(0, [And(pipes[2].x_open,
                                           Equals(pipes[2].x_flow,
                                                  pipes[2].max_flow))])])

    # tank0 volume cycles through 5, 45/2 and 40
    # (5 - 7 + 49/2 = 45/2; 45/2 + 42 - 49/2 = 40; 40 - 35 = 5).
    div_45_2 = Div(Real(45), Real(2))
    r40 = Real(40)
    r5 = Real(5)
    aut_tank0 = AGAutomaton(symbols, [tanks[0].vol], "aut_tank0", 3)
    aut_tank0.set_invar(0, Equals(tanks[0].vol, r5))
    aut_tank0.set_assume(0, And(Equals(tanks[0].flow,
                                       Times(Real(-1), pipes[1].max_flow)),
                                Equals(tanks[0].d_flow, Real(1)),
                                Equals(delta, seven)))
    aut_tank0.set_transitions(0, [(1, [Equals(tanks[0].x_vol, div_45_2)])])
    aut_tank0.set_invar(1, Equals(tanks[0].vol, div_45_2))
    aut_tank0.set_assume(1, And(Equals(tanks[0].flow, Real(6)),
                                Equals(tanks[0].d_flow, Real(-1)),
                                Equals(delta, seven)))
    aut_tank0.set_transitions(1, [(2, [Equals(tanks[0].x_vol, r40)])])
    aut_tank0.set_invar(2, Equals(tanks[0].vol, r40))
    aut_tank0.set_assume(2, And(Equals(tanks[0].flow,
                                       Times(Real(-1), pipes[1].max_flow)),
                                Equals(tanks[0].d_flow, zero),
                                Equals(delta, r35)))
    aut_tank0.set_transitions(2, [(0, [Equals(tanks[0].x_vol, r5)])])

    # tank1 volume remains constant.
    aut_tank1 = AGAutomaton(symbols, [tanks[1].vol], "aut_tank1", 1)
    aut_tank1.set_invar(0, And(LE(Real(5), tanks[1].vol),
                               LE(tanks[1].vol, tanks[1].max_vol)))
    aut_tank1.set_assume(0, And(Equals(tanks[1].flow, zero),
                                Equals(tanks[1].d_flow, zero)))
    aut_tank1.set_transitions(0, [(0, [Equals(tanks[1].x_vol, tanks[1].vol)])])

    automata = [aut_delta, aut_pipe0, aut_pipe1, aut_pipe2, aut_tank0,
                aut_tank1]

    # search composition.
    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w') as out:
            out.write(str(comp))

    # Classify the outcome. The previous version asserted `not comp` in the
    # branch reached only when `comp` was truthy, so the inconclusive case
    # raised AssertionError instead of returning None.
    if not comp:
        return False
    if undefs:
        return None
    return True


class Tank:
    """Transition system of a tank whose volume integrates a net flow.

    The net `flow` and its rate of change `d_flow` are formulas supplied by
    the caller. Over a step of length `delta` the volume evolves as
    vol' = vol + flow * delta + d_flow * delta * delta / 2.
    """
    def __init__(self, name: str, delta: FNode, x_delta: FNode, max_vol: FNode,
                 flow: FNode, d_flow: FNode):
        vol_name = "{}_vol".format(name)

        self.name = name
        self.max_vol = max_vol
        self.flow, self.d_flow = flow, d_flow
        self.delta, self.x_delta = delta, x_delta

        # Current- and next-state copies of the volume variable.
        self.vol = Symbol(vol_name, REAL)
        self.x_vol = Symbol(symb_next(vol_name), REAL)

    @property
    def symbols(self) -> list:
        """Current-state variables owned by this tank."""
        return [self.vol]

    @property
    def init(self) -> FNode:
        """Initial-state formula: the tank starts empty."""
        return Equals(self.vol, Real(0))

    @property
    def invar(self) -> FNode:
        """Invariant formula: 0 <= vol <= max_vol."""
        lower_bound = LE(Real(0), self.vol)
        upper_bound = LE(self.vol, self.max_vol)
        return And(lower_bound, upper_bound)

    @property
    def trans(self) -> FNode:
        """Transition formula: second-order update of the volume."""
        linear_term = Times(self.flow, self.delta)
        quadratic_term = Div(Times(self.d_flow, self.delta, self.delta),
                             Real(2))
        return Equals(self.x_vol, Plus(self.vol, linear_term, quadratic_term))


class Pipe:
    """Transition system describing a pipe whose flow ramps at `speed`.

    The integer `mode` encodes the pipe state: 0 close, 1 opening, 2 open,
    3 closing. While opening (resp. closing) the flow increases (resp.
    decreases) by `speed * delta` per step. The invariant forces the flow to
    0 when close and to `max_flow` when open.
    """
    def __init__(self, name: str, delta: FNode, x_delta: FNode,
                 max_flow: FNode, speed: FNode):
        self.name = name
        self.delta = delta
        self.x_delta = x_delta
        self.max_flow = max_flow
        self.speed = speed

        self.mode = Symbol("{}_mode".format(name), INT)
        self.flow = Symbol("{}_flow".format(name), REAL)
        self.x_mode = Symbol(symb_next("{}_mode".format(name)), INT)
        self.x_flow = Symbol(symb_next("{}_flow".format(name)), REAL)

        # Rate of change of `flow`: 0 in close/open, +speed while opening,
        # -speed otherwise. It must agree with the `flow +/- speed * delta`
        # ramp in `trans` (tanks integrate it), so it scales with `speed`
        # instead of being fixed at +/-1.
        self.d_flow = Ite(Or(self.close, self.open), Real(0),
                          Ite(self.opening, self.speed,
                              Times(Real(-1), self.speed)))

    @property
    def symbols(self) -> list:
        """Return list of current-state symbols"""
        return [self.mode, self.flow]

    @property
    def close(self) -> FNode:
        """Current mode is close (0)."""
        return Equals(self.mode, Int(0))

    @property
    def x_close(self) -> FNode:
        """Next mode is close (0)."""
        return Equals(self.x_mode, Int(0))

    @property
    def opening(self) -> FNode:
        """Current mode is opening (1)."""
        return Equals(self.mode, Int(1))

    @property
    def x_opening(self) -> FNode:
        """Next mode is opening (1)."""
        return Equals(self.x_mode, Int(1))

    @property
    def open(self) -> FNode:
        """Current mode is open (2)."""
        return Equals(self.mode, Int(2))

    @property
    def x_open(self) -> FNode:
        """Next mode is open (2)."""
        return Equals(self.x_mode, Int(2))

    @property
    def closing(self) -> FNode:
        """Current mode is closing (3)."""
        return Equals(self.mode, Int(3))

    @property
    def x_closing(self) -> FNode:
        """Next mode is closing (3)."""
        return Equals(self.x_mode, Int(3))

    @property
    def init(self) -> FNode:
        """Return formula representing the initial states"""
        return self.close

    @property
    def invar(self) -> FNode:
        """Return formula representing the invariant.

        Flow is 0 when close, max_flow when open, and always within
        [0, max_flow].
        """
        # NOTE(review): `mode` is not constrained to {0, 1, 2, 3} here;
        # confirm whether other mode values are meant to be allowed.
        zero = Real(0)
        return And(Implies(self.close, Equals(self.flow, zero)),
                   Implies(self.open, Equals(self.flow, self.max_flow)),
                   LE(zero, self.flow), LE(self.flow, self.max_flow))

    @property
    def trans(self) -> FNode:
        """Returns formula representing the transition relation.

        In opening/closing mode the flow moves by +/- speed * delta. When the
        ramp reaches max_flow (resp. 0) the pipe may become open (resp.
        close); otherwise it keeps ramping or reverses direction.
        """
        # NOTE(review): no constraint on x_mode from close/open modes is
        # imposed here; confirm this is intended.
        zero = Real(0)
        step = Times(self.speed, self.delta)
        rising = Plus(self.flow, step)
        falling = Minus(self.flow, step)

        return And(
            Implies(self.opening, Equals(self.x_flow, rising)),
            Implies(self.closing, Equals(self.x_flow, falling)),
            # Opening and reaching max_flow: become open or start closing.
            Implies(And(self.opening, Equals(rising, self.max_flow)),
                    Or(self.x_open, self.x_closing)),
            # Opening still below max_flow: keep opening or start closing.
            Implies(And(self.opening, LT(rising, self.max_flow)),
                    Or(self.x_opening, self.x_closing)),
            # Closing down to 0: become close or start opening.
            Implies(And(self.closing, Equals(falling, zero)),
                    Or(self.x_close, self.x_opening)),
            # Closing still above 0: keep closing or start opening.
            Implies(And(self.closing, GT(falling, zero)),
                    Or(self.x_closing, self.x_opening)))
