from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next, to_next

from pysmt.shortcuts import Symbol, TRUE, Int, Real
from pysmt.shortcuts import Not, And, Or, Implies, Iff
from pysmt.shortcuts import GT, GE, Equals, LE
from pysmt.shortcuts import Plus, Minus, Div
from pysmt.typing import INT, REAL
from pysmt.fnode import FNode


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> bool:
    """Test corresponding to benchmarks/synchronisation_ttethernet_network.smv

    Build 2 ComprMaster and 5 SyncMaster components sharing the real symbol
    `delta`, define the AGAutomaton instances to be composed and call
    `find_composition` on them, forwarding `nuxmv_path`, `model_file`,
    `trace_file` and `cmd_file`.  If a composition is returned (not None)
    its string form is written to `output_file`.

    Return:
      - True  if a composition was found and `undefs` is empty;
      - False if no composition was found and `undefs` is empty;
      - None  if no composition was found but `undefs` is non-empty
        (presumably an inconclusive search).
    A composition reported together with a non-empty `undefs` triggers
    an AssertionError.
    """
    # symbols.
    zero = Real(0)
    delay = Real(2)
    drifts = [Real(3), Real(2), Real(4), Real(1), Real(10)]
    delta = Symbol("delta", REAL)
    x_delta = Symbol(symb_next("delta"), REAL)
    cms = [ComprMaster("cm{}".format(i), delay, delta, x_delta)
           for i in range(2)]
    # one SyncMaster per drift value.
    sms = [SyncMaster("sm{}".format(i), drift, delta, x_delta)
           for i, drift in enumerate(drifts)]
    components = cms + sms
    symbols = [delta]
    for comp in components:
        symbols += comp.symbols

    sync_constr = sync_on_broadcast_messages(cms, sms)
    for cm in cms:
        sync_constr = And(sync_constr,
                          And([sync_comprmaster_syncmaster(cm, sm)
                               for sm in sms]))

    # initial location.
    init = And([And(comp.init, comp.invar) for comp in components])
    init = And(init, GE(delta, zero))

    # transition relation: the invariants must also hold in the next state.
    trans = And([And(to_next(comp.invar, symbols), comp.trans)
                 for comp in components])
    trans = And(trans, GE(x_delta, zero), sync_constr)

    # fairness.
    cm_val = Div(Plus([sm.sm for sm in sms]), Real(len(sms)))
    fairness = And([Equals(cm.cm, cm_val) for cm in cms])

    # define automata to be composed.
    aut_delta = AGAutomaton(symbols, [delta], "aut_delta", 2)
    aut_delta.set_invar(0, Equals(delta, delay))
    aut_delta.set_assume(0, TRUE())
    aut_delta.set_transitions(0, [(1, [Equals(x_delta, zero)])])
    aut_delta.set_invar(1, Equals(delta, zero))
    aut_delta.set_assume(1, TRUE())
    aut_delta.set_transitions(1, [(0, [Equals(x_delta, delay)]),
                                  (1, [Equals(x_delta, delta)])])

    automata = [aut_delta]
    for cm in cms:
        automata += aut_for_comprmaster(symbols, cm, sms)
    for sm in sms:
        automata += aut_for_syncmaster(symbols, sm, cms)

    # search composition.
    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w') as out:
            out.write(str(comp))

    if comp and not undefs:
        res = True
    elif not comp and not undefs:
        res = False
    else:
        # Reached only when `undefs` is non-empty; this is expected to occur
        # only when no composition was found.
        assert not comp and undefs
        res = None
    return res


class ComprMaster:
    """Transition system describing a ComprMaster.

    The location is stored in the integer symbol ``<name>_mode``
    (0 = waiting, 1 = receive, 2 = correct1, 3 = correct2); ``<name>_cm``
    and ``<name>_x`` are real-valued state variables.  Each state symbol
    has a next-state copy obtained through `symb_next`.  `delta` is shared
    with the other components: when ``delta > 0`` the clock ``x`` advances
    by ``delta`` (presumably the global time elapse).
    """
    # Integer encoding of the ComprMaster locations.
    _WAITING, _RECEIVE, _CORRECT1, _CORRECT2 = range(4)

    def __init__(self, name: str, delay: FNode, delta: FNode, x_delta: FNode):
        self.name = name
        self.delay = delay
        self.delta = delta
        self.x_delta = x_delta
        mode_name = "{}_mode".format(name)
        cm_name = "{}_cm".format(name)
        x_name = "{}_x".format(name)
        # current-state symbols.
        self.mode = Symbol(mode_name, INT)
        self.cm = Symbol(cm_name, REAL)
        self.x = Symbol(x_name, REAL)
        # next-state symbols.
        self.x_mode = Symbol(symb_next(mode_name), INT)
        self.x_cm = Symbol(symb_next(cm_name), REAL)
        self.x_x = Symbol(symb_next(x_name), REAL)

    @staticmethod
    def _at(mode_var: FNode, loc: int) -> FNode:
        """Return formula stating that `mode_var` holds location `loc`."""
        return Equals(mode_var, Int(loc))

    @property
    def waiting(self) -> FNode:
        """Return formula that holds iff state is waiting"""
        return self._at(self.mode, self._WAITING)

    @property
    def x_waiting(self) -> FNode:
        """Return formula that holds iff next state is waiting"""
        return self._at(self.x_mode, self._WAITING)

    @property
    def receive(self) -> FNode:
        """Return formula that holds iff state is receive"""
        return self._at(self.mode, self._RECEIVE)

    @property
    def x_receive(self) -> FNode:
        """Return formula that holds iff next state is receive"""
        return self._at(self.x_mode, self._RECEIVE)

    @property
    def correct1(self) -> FNode:
        """Return formula that holds iff state is correct1"""
        return self._at(self.mode, self._CORRECT1)

    @property
    def x_correct1(self) -> FNode:
        """Return formula that holds iff next state is correct1"""
        return self._at(self.x_mode, self._CORRECT1)

    @property
    def correct2(self) -> FNode:
        """Return formula that holds iff state is correct2"""
        return self._at(self.mode, self._CORRECT2)

    @property
    def x_correct2(self) -> FNode:
        """Return formula that holds iff next state is correct2"""
        return self._at(self.x_mode, self._CORRECT2)

    @property
    def symbols(self) -> list:
        """Return a fresh list of the current-state symbols."""
        return [self.mode, self.cm, self.x]

    @property
    def init(self) -> FNode:
        """Return formula for the initial states: cm = x = 0, waiting."""
        zero = Real(0)
        return And(Equals(self.cm, zero), Equals(self.x, zero),
                   self.waiting)

    @property
    def invar(self) -> FNode:
        """Return formula representing the invariant.

        Outside waiting both `delta` and `x` are 0; while waiting `x`
        does not exceed `delay`.
        """
        zero = Real(0)
        off_waiting = And(Equals(self.delta, zero), Equals(self.x, zero))
        return And(Implies(Not(self.waiting), off_waiting),
                   Implies(self.waiting, LE(self.x, self.delay)))

    @property
    def trans(self) -> FNode:
        """Return formula representing the transition relation."""
        zero = Real(0)
        delta_eq_0 = Equals(self.delta, zero)
        # delta > 0: x grows by delta and the location is kept.
        timed = And(Equals(self.x_x, Plus(self.x, self.delta)),
                    Equals(self.x_mode, self.mode))
        # delta = 0: waiting -> receive -> correct1 -> correct2 -> waiting.
        cycle = [(self.waiting, self.x_receive),
                 (self.receive, self.x_correct1),
                 (self.correct1, self.x_correct2),
                 (self.correct2, self.x_waiting)]
        conjuncts = [Implies(GT(self.delta, zero), timed)]
        conjuncts += [Implies(And(delta_eq_0, src), dst)
                      for src, dst in cycle]
        # guard x >= delay and reset x on waiting -> receive.
        conjuncts.append(Implies(And(self.waiting, self.x_receive),
                                 And(GE(self.x, self.delay),
                                     Equals(self.x_x, zero))))
        # cm may change only on receive -> correct1.
        conjuncts.append(Implies(Not(And(self.receive, self.x_correct1)),
                                 Equals(self.x_cm, self.cm)))
        return And(conjuncts)


class SyncMaster:
    """Transition system describing a SyncMaster.

    The location is stored in the integer symbol ``<name>_mode``
    (0 = work, 1 = send, 2 = sync1, 3 = sync2); ``<name>_sm`` is a
    real-valued state variable.  Each state symbol has a next-state copy
    obtained through `symb_next`.  When ``delta > 0`` the value ``sm``
    advances by ``delta`` (presumably the global time elapse).
    """
    # Integer encoding of the SyncMaster locations.
    _WORK, _SEND, _SYNC1, _SYNC2 = range(4)

    def __init__(self, name: str, drift: FNode, delta: FNode, x_delta: FNode):
        self.name = name
        self.drift = drift
        self.delta = delta
        self.x_delta = x_delta

        mode_name = "{}_mode".format(name)
        sm_name = "{}_sm".format(name)
        self.mode = Symbol(mode_name, INT)
        self.sm = Symbol(sm_name, REAL)
        self.x_mode = Symbol(symb_next(mode_name), INT)
        self.x_sm = Symbol(symb_next(sm_name), REAL)

    @staticmethod
    def _at(mode_var: FNode, loc: int) -> FNode:
        """Return formula stating that `mode_var` holds location `loc`."""
        return Equals(mode_var, Int(loc))

    @property
    def work(self) -> FNode:
        """Return formula that holds iff state is work"""
        return self._at(self.mode, self._WORK)

    @property
    def x_work(self) -> FNode:
        """Return formula that holds iff next state is work"""
        return self._at(self.x_mode, self._WORK)

    @property
    def send(self) -> FNode:
        """Return formula that holds iff state is send"""
        return self._at(self.mode, self._SEND)

    @property
    def x_send(self) -> FNode:
        """Return formula that holds iff next state is send"""
        return self._at(self.x_mode, self._SEND)

    @property
    def sync1(self) -> FNode:
        """Return formula that holds iff state is sync1"""
        return self._at(self.mode, self._SYNC1)

    @property
    def x_sync1(self) -> FNode:
        """Return formula that holds iff next state is sync1"""
        return self._at(self.x_mode, self._SYNC1)

    @property
    def sync2(self) -> FNode:
        """Return formula that holds iff state is sync2"""
        return self._at(self.mode, self._SYNC2)

    @property
    def x_sync2(self) -> FNode:
        """Return formula that holds iff next state is sync2"""
        return self._at(self.x_mode, self._SYNC2)

    @property
    def symbols(self) -> list:
        """Return a fresh list of the current-state symbols."""
        return [self.mode, self.sm]

    @property
    def init(self) -> FNode:
        """Return formula for the initial states: sm = 0, work."""
        return And(Equals(self.sm, Real(0)), self.work)

    @property
    def invar(self) -> FNode:
        """Return formula representing the invariant (always true)."""
        return TRUE()

    @property
    def trans(self) -> FNode:
        """Return formula representing the transition relation."""
        zero = Real(0)
        delta_eq_0 = Equals(self.delta, zero)
        # delta > 0: sm grows by delta and the location is kept.
        timed = And(Equals(self.x_sm, Plus(self.sm, self.delta)),
                    Equals(self.x_mode, self.mode))
        # delta = 0: work -> send -> sync1 -> sync2 -> work.
        cycle = [(self.work, self.x_send),
                 (self.send, self.x_sync1),
                 (self.sync1, self.x_sync2),
                 (self.sync2, self.x_work)]
        conjuncts = [Implies(GT(self.delta, zero), timed)]
        conjuncts += [Implies(And(delta_eq_0, src), dst)
                      for src, dst in cycle]
        # work -> send: sm may move by at most `drift` in either direction.
        drift_window = And(LE(Minus(self.sm, self.drift), self.x_sm),
                           LE(self.x_sm, Plus(self.sm, self.drift)))
        conjuncts.append(Implies(And(self.work, self.x_send), drift_window))
        # send -> sync1 and sync2 -> work keep sm unchanged.
        keeps_sm = Or(And(self.send, self.x_sync1),
                      And(self.sync2, self.x_work))
        conjuncts.append(Implies(keeps_sm, Equals(self.x_sm, self.sm)))
        return And(conjuncts)


def sync_comprmaster_syncmaster(cm: ComprMaster, sm: SyncMaster) -> FNode:
    """Synchronise ComprMaster with SyncMaster on discrete transitions.

    Each of the four location changes of `cm` happens iff the
    corresponding location change of `sm` happens in the same step.
    """
    cm_steps = [And(cm.waiting, cm.x_receive),
                And(cm.receive, cm.x_correct1),
                And(cm.correct1, cm.x_correct2),
                And(cm.correct2, cm.x_waiting)]
    sm_steps = [And(sm.work, sm.x_send),
                And(sm.send, sm.x_sync1),
                And(sm.sync1, sm.x_sync2),
                And(sm.sync2, sm.x_work)]
    return And([Iff(cm_step, sm_step)
                for cm_step, sm_step in zip(cm_steps, sm_steps)])


def sync_on_broadcast_messages(cms: list, sms: list) -> FNode:
    """Synchronisation of clocks on broadcast messages.

    On sync1 -> sync2 every SyncMaster's `sm` takes the average of the
    ComprMasters' `cm` values; on receive -> correct1 every ComprMaster's
    `cm` takes the average of the SyncMasters' `sm` values.
    """
    avg_cm = Div(Plus([c.cm for c in cms]), Real(len(cms)))
    avg_sm = Div(Plus([s.sm for s in sms]), Real(len(sms)))
    sm_updates = [Implies(And(s.sync1, s.x_sync2), Equals(s.x_sm, avg_cm))
                  for s in sms]
    cm_updates = [Implies(And(c.receive, c.x_correct1),
                          Equals(c.x_cm, avg_sm))
                  for c in cms]
    return And(And(sm_updates), And(cm_updates))


def aut_for_comprmaster(symbols: list, cm: ComprMaster,
                        sms: list) -> list:
    """Return list of AGAutomaton corresponding to the given instance of
    ComprMaster synchronised with the given list of SyncMaster `sms`.

    The result is ``[aut_x, aut_mode, aut_cm]``, automata over the `x`,
    `mode` and `cm` variables of `cm` respectively.
    """
    zero = Real(0)

    # automaton for x: each entry is (invar, assume, transitions).
    aut_x = AGAutomaton(symbols, [cm.x], "aut_{}_x".format(cm.name), 3)
    x_same = Equals(cm.x_x, cm.x)
    x_reset = Equals(cm.x_x, zero)
    x_specs = [
        (Equals(cm.x, zero), Equals(cm.delta, zero),
         [(0, [x_same]), (1, [x_same])]),
        (Equals(cm.x, zero), Equals(cm.delta, cm.delay),
         [(2, [Equals(cm.x_x, cm.delta)])]),
        (Equals(cm.x, cm.delay), Equals(cm.delta, zero),
         [(0, [x_reset]), (1, [x_reset])]),
    ]
    for loc, (invar, assume, succs) in enumerate(x_specs):
        aut_x.set_invar(loc, invar)
        aut_x.set_assume(loc, assume)
        aut_x.set_transitions(loc, succs)

    # automaton for mode: one state per location, either stay or advance;
    # each state assumes all SyncMasters are in the matching location.
    aut_mode = AGAutomaton(symbols, [cm.mode], "aut_{}_mode".format(cm.name),
                           4)
    mode_same = Equals(cm.x_mode, cm.mode)
    mode_specs = [
        (cm.waiting, [sm.work for sm in sms], cm.x_receive),
        (cm.receive, [sm.send for sm in sms], cm.x_correct1),
        (cm.correct1, [sm.sync1 for sm in sms], cm.x_correct2),
        (cm.correct2, [sm.sync2 for sm in sms], cm.x_waiting),
    ]
    for loc, (here, sms_here, x_succ) in enumerate(mode_specs):
        succ = (loc + 1) % len(mode_specs)
        aut_mode.set_invar(loc, here)
        aut_mode.set_assume(loc, And(sms_here))
        aut_mode.set_transitions(loc, [(loc, [mode_same]),
                                       (succ, [x_succ])])

    # automaton for cm: state 1 holds when cm equals the SyncMasters' average.
    avg_sm = Div(Plus([sm.sm for sm in sms]), Real(len(sms)))
    aut_cm = AGAutomaton(symbols, [cm.cm], "aut_{}_cm".format(cm.name), 2)
    cm_same = Equals(cm.x_cm, cm.cm)
    cm_to_avg = Equals(cm.x_cm, avg_sm)
    cm_specs = [
        (TRUE(), TRUE(), [(0, [cm_same]), (1, [cm_to_avg])]),
        (Equals(cm.cm, avg_sm), TRUE(), [(0, [cm_same, cm_to_avg])]),
    ]
    for loc, (invar, assume, succs) in enumerate(cm_specs):
        aut_cm.set_invar(loc, invar)
        aut_cm.set_assume(loc, assume)
        aut_cm.set_transitions(loc, succs)

    return [aut_x, aut_mode, aut_cm]


def aut_for_syncmaster(symbols: list, sm: SyncMaster,
                       cms: list) -> list:
    """Return list of AGAutomaton corresponding to the given instance of
    SyncMaster synchronised with the given list of ComprMaster `cms`.

    The result is ``[aut_mode, aut_sm]``, automata over the `mode` and
    `sm` variables of `sm` respectively.
    """
    # automaton for mode: one state per location, either stay or advance;
    # each state assumes all ComprMasters are in the matching location.
    aut_mode = AGAutomaton(symbols, [sm.mode], "aut_{}_mode".format(sm.name),
                           4)
    mode_same = Equals(sm.x_mode, sm.mode)
    mode_specs = [
        (sm.work, [cm.waiting for cm in cms], sm.x_send),
        (sm.send, [cm.receive for cm in cms], sm.x_sync1),
        (sm.sync1, [cm.correct1 for cm in cms], sm.x_sync2),
        (sm.sync2, [cm.correct2 for cm in cms], sm.x_work),
    ]
    for loc, (here, cms_here, x_succ) in enumerate(mode_specs):
        succ = (loc + 1) % len(mode_specs)
        aut_mode.set_invar(loc, here)
        aut_mode.set_assume(loc, And(cms_here))
        aut_mode.set_transitions(loc, [(loc, [mode_same]),
                                       (succ, [x_succ])])

    # automaton for sm: state 1 holds when sm equals the ComprMasters'
    # average (assumed only in sync2).
    avg_cm = Div(Plus([cm.cm for cm in cms]), Real(len(cms)))
    aut_sm = AGAutomaton(symbols, [sm.sm], "aut_{}_sm".format(sm.name), 2)
    drift_window = And(LE(Minus(sm.sm, sm.drift), sm.x_sm),
                       LE(sm.x_sm, Plus(sm.sm, sm.drift)))
    # alternative updates of sm when not taking the averaging step.
    local_steps = [Equals(sm.x_sm, sm.sm),
                   Equals(sm.x_sm, Plus(sm.sm, sm.delta)),
                   drift_window]
    sm_specs = [
        (TRUE(), Not(sm.sync2),
         [(0, list(local_steps)), (1, [Equals(sm.x_sm, avg_cm)])]),
        (Equals(sm.sm, avg_cm), sm.sync2,
         [(0, list(local_steps))]),
    ]
    for loc, (invar, assume, succs) in enumerate(sm_specs):
        aut_sm.set_invar(loc, invar)
        aut_sm.set_assume(loc, assume)
        aut_sm.set_transitions(loc, succs)
    return [aut_mode, aut_sm]
