from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next, to_next

from pysmt.shortcuts import Symbol, TRUE, Real
from pysmt.shortcuts import Not, And, Or, Implies
from pysmt.shortcuts import GT, GE, Equals, LE, LT
from pysmt.shortcuts import Plus, Minus, Times, Div
from pysmt.typing import REAL
from pysmt.fnode import FNode

from math import gcd


def math_lcm(a, b):
    """Return the least common multiple of the integers `a` and `b`.

    The result is always non-negative. By convention ``lcm(x, 0) == 0``;
    handling that case explicitly also avoids the ZeroDivisionError that
    ``abs(a*b) // gcd(a, b)`` raises when both arguments are 0.
    """
    if a == 0 or b == 0:
        return 0
    return abs(a * b) // gcd(a, b)


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> bool:
    """Test corresponding to benchmarks/adaptive_cruise_control.smv

    Builds one Leader followed by `num_followers` Follower vehicles (each
    following the vehicle directly ahead), all sharing the time step
    `delta`, and asks `find_composition` for a composition of the
    automata built below.

    Args:
        nuxmv_path, model_file, trace_file, cmd_file: forwarded unchanged
            to `find_composition`.
        output_file: path the composition is written to when one is found.

    Returns:
        True if a composition was found and no undefined parts were
        reported, False if no composition was found, and None if a
        composition was found but undefined parts were reported.
    """
    # symbols.
    zero = Real(0)
    num_followers = 1
    # periods[0] is the leader's period, periods[i + 1] the i-th follower's.
    periods = [1, 2]
    min_acc = Real(-1)
    max_acc = Real(2)

    leader_period = Real(periods[0])
    delta = Symbol("delta", REAL)
    x_delta = Symbol(symb_next("delta"), REAL)
    leader = Leader("leader", leader_period, min_acc, max_acc,
                    delta, x_delta)

    # Each follower tracks the vehicle immediately ahead of it.
    followers = []
    ahead = leader
    for i in range(num_followers):
        follower = Follower("follower{}".format(i), ahead,
                            Real(periods[i + 1]), min_acc, max_acc,
                            delta, x_delta)
        followers.append(follower)
        ahead = follower

    components = [leader] + followers
    symbols = [delta]
    for comp in components:
        symbols += comp.symbols

    # initial location.
    init = And([And(comp.init, comp.invar) for comp in components])
    init = And(init, GE(delta, zero))

    # transition relation.
    trans = And([And(to_next(comp.invar, symbols), comp.trans)
                 for comp in components])
    trans = And(trans, GE(x_delta, zero))

    # fairness.
    fairness = And(GT(delta, zero), GT(leader.v, zero))

    # define automata to be composed.
    aut_delta_c = automaton_delta_c(symbols, periods, leader, followers,
                                    delta, x_delta)
    automata = [aut_delta_c]

    # Alternative: schedule only `delta`, without constraining the clocks.
    # aut_delta = automaton_delta(symbols, periods, leader, followers,
    #                             delta, x_delta)
    # automata += [aut_delta]

    automata += aut_for_leader(symbols, leader)
    for follower in followers:
        automata += aut_for_follower(symbols, follower)

    # search composition.
    comp, undefs = find_composition(automata, init, trans, fairness,
                                    nuxmv_path, model_file, trace_file,
                                    cmd_file)
    if comp is not None:
        with open(output_file, 'w') as out:
            out.write(str(comp))

    if not comp:
        return False
    if undefs:
        # A composition exists but some of its parts are undefined: report
        # neither success nor failure.
        return None
    return True


class Leader:
    """Transition system of the leading vehicle.

    The vehicle owns three REAL state variables, named ``<name>_a``,
    ``<name>_v`` and ``<name>_c``, plus their next-state copies:
    acceleration (kept in [min_acc, max_acc]), velocity (integrated as
    v + a*delta) and a controller clock in [0, period]. The time step
    `delta` / `x_delta` is shared with the rest of the system.
    """
    def __init__(self, name: str, period: FNode,
                 min_acc: FNode, max_acc: FNode,
                 delta: FNode, x_delta: FNode):
        self.name = name
        self.period = period
        self.min_acc = min_acc
        self.max_acc = max_acc
        self.delta = delta
        self.x_delta = x_delta

        base_names = ["{}_{}".format(name, var) for var in ("a", "v", "c")]
        # current-state symbols: acceleration, velocity, clock
        self.a, self.v, self.c = [Symbol(n, REAL) for n in base_names]
        # next-state copies, created in the same order
        self.x_a, self.x_v, self.x_c = [Symbol(symb_next(n), REAL)
                                        for n in base_names]

    @property
    def symbols(self) -> list:
        """Current-state symbols owned by this vehicle, as [a, v, c]."""
        return [self.a, self.v, self.c]

    @property
    def init(self) -> FNode:
        """Initial condition: acceleration, velocity and clock are all 0."""
        zero = Real(0)
        return And([Equals(var, zero) for var in (self.a, self.v, self.c)])

    @property
    def invar(self) -> FNode:
        """Invariant: 0 <= c <= period, v >= 0, v + a*delta >= 0 and
        min_acc <= a <= max_acc."""
        zero = Real(0)
        clock_ok = (LE(zero, self.c), LE(self.c, self.period))
        speed_ok = (GE(self.v, zero),
                    GE(Plus(self.v, Times(self.a, self.delta)), zero))
        acc_ok = (LE(self.min_acc, self.a), LE(self.a, self.max_acc))
        return And(*clock_ok, *speed_ok, *acc_ok)

    @property
    def trans(self) -> FNode:
        """Transition relation of the leader.

        * when the clock equals the period it is reset and delta must be 0;
        * while below the period the clock advances by delta;
        * outside a controller step (delta == 0, c == period, c' == 0) the
          acceleration is either kept or set to 0;
        * the next velocity is v + a*delta, or velocity and acceleration
          both become 0.
        """
        zero = Real(0)
        at_deadline = Equals(self.c, self.period)
        clock_reset = Implies(at_deadline,
                              And(Equals(self.x_c, zero),
                                  Equals(self.delta, zero)))
        clock_elapse = Implies(LT(self.c, self.period),
                               Equals(self.x_c, Plus(self.c, self.delta)))
        control_step = And(Equals(self.delta, zero), at_deadline,
                           Equals(self.x_c, zero))
        acc_update = Implies(Not(control_step),
                             Or(Equals(self.x_a, self.a),
                                Equals(self.x_a, zero)))
        integrated_speed = Plus(self.v, Times(self.a, self.delta))
        speed_update = Or(Equals(self.x_v, integrated_speed),
                          And(Equals(self.x_v, zero), Equals(self.x_a, zero)))
        return And(clock_reset, clock_elapse, acc_update, speed_update)


class Follower:
    """Transition system of a vehicle following `vehicle`.

    Besides acceleration, velocity and clock (``<name>_a``, ``<name>_v``,
    ``<name>_c``), a follower tracks ``<name>_dist``, its distance to the
    vehicle directly ahead. The acceleration may only change at controller
    steps, and a new value must keep `acc_bound` (see `trans`) positive.
    """
    def __init__(self, name: str, vehicle, period: FNode,
                 min_acc: FNode, max_acc: FNode,
                 delta: FNode, x_delta: FNode):
        self.name = name
        self.vehicle = vehicle
        self.period = period
        self.min_acc = min_acc
        self.max_acc = max_acc
        self.delta = delta
        self.x_delta = x_delta

        base_names = ["{}_{}".format(name, var)
                      for var in ("a", "v", "c", "dist")]
        # current-state symbols: acceleration, velocity, clock, distance
        self.a, self.v, self.c, self.dist = [Symbol(n, REAL)
                                             for n in base_names]
        # next-state copies, created in the same order
        self.x_a, self.x_v, self.x_c, self.x_dist = [
            Symbol(symb_next(n), REAL) for n in base_names]

    @property
    def symbols(self) -> list:
        """Current-state symbols owned by this vehicle, as [a, v, c, dist]."""
        return [self.a, self.v, self.c, self.dist]

    @property
    def init(self) -> FNode:
        """Initial condition: a, v and c are 0 and the distance is 1."""
        zero = Real(0)
        at_rest = [Equals(var, zero) for var in (self.a, self.v, self.c)]
        return And(*at_rest, Equals(self.dist, Real(1)))

    @property
    def invar(self) -> FNode:
        """Invariant: 0 <= c <= period, v >= 0, v + a*delta >= 0 and
        min_acc <= a <= max_acc."""
        zero = Real(0)
        clock_ok = (LE(zero, self.c), LE(self.c, self.period))
        speed_ok = (GE(self.v, zero),
                    GE(Plus(self.v, Times(self.a, self.delta)), zero))
        acc_ok = (LE(self.min_acc, self.a), LE(self.a, self.max_acc))
        return And(*clock_ok, *speed_ok, *acc_ok)

    @property
    def trans(self) -> FNode:
        """Transition relation of the follower.

        Clock reset/advance mirrors the leader. At a controller step
        (delta == 0, c == period, c' == 0) the new acceleration must keep
        `acc_bound` positive; otherwise the acceleration is unchanged.
        Velocity integrates a over delta, and the distance is updated with
        both vehicles' displacement over delta.
        """
        zero = Real(0)
        # speed reached after one full period with the chosen acceleration
        speed_at_period_end = Plus(self.v, Times(self.x_a, self.period))
        control_step = And(Equals(self.delta, zero),
                           Equals(self.c, self.period),
                           Equals(self.x_c, zero))
        ahead = self.vehicle
        # position of the vehicle ahead after `delta`, measured from self's
        # current position.
        ahead_pos = Plus(self.dist, Times(ahead.v, self.delta),
                         Div(Times(ahead.a, self.delta, self.delta),
                             Real(2)))
        # distance travelled by self during `delta`.
        own_travel = Plus(Times(self.v, self.delta),
                          Div(Times(self.a, self.delta, self.delta), Real(2)))
        gap_after_delta = Minus(ahead_pos, own_travel)
        # dist - v*period - x_a*period^2/2 + speed_at_period_end^2/(2*min_acc);
        # with min_acc < 0 (as in the benchmark) the last term subtracts
        # the braking distance from speed_at_period_end.
        acc_bound = Plus(self.dist, Times(Real(-1), self.v, self.period),
                         Div(Times(self.x_a, self.period, self.period),
                             Real(-2)),
                         Div(Times(speed_at_period_end, speed_at_period_end),
                             Times(Real(2), self.min_acc)))

        clock_reset = Implies(Equals(self.c, self.period),
                              And(Equals(self.x_c, zero),
                                  Equals(self.delta, zero)))
        clock_elapse = Implies(LT(self.c, self.period),
                               Equals(self.x_c, Plus(self.c, self.delta)))
        safe_choice = Implies(control_step, GT(acc_bound, zero))
        acc_frozen = Implies(Not(control_step), Equals(self.x_a, self.a))
        speed_update = Equals(self.x_v,
                              Plus(self.v, Times(self.a, self.delta)))
        dist_update = Equals(self.x_dist, gap_after_delta)
        return And(clock_reset, clock_elapse, safe_choice, acc_frozen,
                   speed_update, dist_update)


def aut_for_leader(symbols: list, lead: Leader) -> list:
    """Build the single-location AG automata abstracting a Leader.

    Returns ``[aut_<name>_a, aut_<name>_v]``:
    * the acceleration automaton keeps 0 <= a <= max_acc and offers three
      candidate constraints on the next acceleration;
    * the velocity automaton requires v > 0 and a >= 0, assumes
      delta >= 0, and sets v' = v + delta*a.
    """
    zero = Real(0)
    label = "aut_{}_".format(lead.name)

    acc_aut = AGAutomaton(symbols, [lead.a], label + "a", 1)
    acc_aut.set_invar(0, And(LE(zero, lead.a), LE(lead.a, lead.max_acc)))
    acc_aut.set_assume(0, TRUE())
    acc_candidates = [And(LE(zero, lead.x_a), LE(lead.x_a, lead.max_acc)),
                      Equals(lead.x_a, lead.a),
                      Equals(lead.x_a, zero)]
    acc_aut.set_transitions(0, [(0, acc_candidates)])

    speed_aut = AGAutomaton(symbols, [lead.v], label + "v", 1)
    speed_aut.set_invar(0, And(GT(lead.v, zero), GE(lead.a, zero)))
    speed_aut.set_assume(0, GE(lead.delta, zero))
    next_speed = Plus(lead.v, Times(lead.delta, lead.a))
    speed_aut.set_transitions(0, [(0, [Equals(lead.x_v, next_speed)])])

    # Disabled clock automaton (the clock is instead constrained by
    # automaton_delta_c, which lists lead.c among its variables).
    # aut_c = AGAutomaton(symbols, [lead.c], "aut_{}_c".format(lead.name), 3)
    # aut_c.set_assume(0, LT(Plus(lead.c, lead.delta), lead.period))
    # aut_c.set_invar(0, LT(lead.c, lead.period))
    # aut_c.set_transitions(0, [(0, [Equals(lead.x_c,
    #                                       Plus(lead.c, lead.delta))]),
    #                           (1, [Equals(lead.x_c,
    #                                       Plus(lead.c, lead.delta))])])
    # aut_c.set_assume(1, Equals(Plus(lead.c, lead.delta), lead.period))
    # aut_c.set_invar(1, LT(lead.c, lead.period))
    # aut_c.set_transitions(1, [(2, [Equals(lead.x_c,
    #                                       Plus(lead.c, lead.delta))])])
    # aut_c.set_assume(2, Equals(lead.delta, zero))
    # aut_c.set_invar(2, Equals(lead.c, lead.period))
    # aut_c.set_transitions(2, [(0, [Equals(lead.x_c, zero)])])

    return [acc_aut, speed_aut]


def aut_for_follower(symbols: list, foll: Follower) -> list:
    """Build the single-location AG automata abstracting a Follower.

    Returns ``[aut_<name>_a, aut_<name>_v_dist]``:
    * the acceleration automaton keeps a == 0 and a' == a;
    * the velocity/distance automaton requires
      dist > period*v + v^2/(-2*min_acc) and v > 0. It assumes delta >= 0,
      a == 0, and that the vehicle ahead has a >= 0 and a velocity at
      least v. It keeps v constant and updates dist with both vehicles'
      displacement over delta.
    """
    zero = Real(0)
    ahead = foll.vehicle
    label = "aut_{}_".format(foll.name)

    acc_aut = AGAutomaton(symbols, [foll.a], label + "a", 1)
    acc_aut.set_invar(0, And(Equals(foll.a, zero)))
    acc_aut.set_assume(0, TRUE())
    acc_aut.set_transitions(0, [(0, [Equals(foll.x_a, foll.a)])])

    # next distance: the vehicle ahead moves v*delta + a*delta^2/2, while
    # self moves v*delta (its acceleration is assumed to be 0 below).
    gap_next = Plus(foll.dist,
                    Times(ahead.v, ahead.delta),
                    Div(Times(ahead.a, ahead.delta, ahead.delta), Real(2)),
                    Times(Real(-1), foll.v, foll.delta))
    pair_aut = AGAutomaton(symbols, [foll.v, foll.dist],
                           label + "v_dist", 1)
    # distance covered in one period at speed v plus v^2/(-2*min_acc).
    safe_gap = Plus(Times(foll.period, foll.v),
                    Div(Times(foll.v, foll.v),
                        Times(Real(-2), foll.min_acc)))
    pair_aut.set_invar(0, And(GT(foll.dist, safe_gap), GT(foll.v, zero)))
    pair_aut.set_assume(0, And(GE(foll.delta, zero), Equals(foll.a, zero),
                               GE(ahead.a, zero),
                               GE(ahead.v, foll.v)))
    pair_step = And(Equals(foll.x_dist, gap_next), Equals(foll.x_v, foll.v))
    pair_aut.set_transitions(0, [(0, [pair_step])])

    # Disabled clock automaton (the clock is instead constrained by
    # automaton_delta_c, which lists each follower's c among its variables).
    # aut_c = AGAutomaton(symbols, [foll.c], "aut_{}_c".format(foll.name), 3)
    # aut_c.set_assume(0, LT(Plus(foll.c, foll.delta), foll.period))
    # aut_c.set_invar(0, LT(foll.c, foll.period))
    # aut_c.set_transitions(0, [(0, [Equals(foll.x_c,
    #                                       Plus(foll.c, foll.delta))]),
    #                           (1, [Equals(foll.x_c,
    #                                       Plus(foll.c, foll.delta))])])
    # aut_c.set_assume(1, Equals(Plus(foll.c, foll.delta), foll.period))
    # aut_c.set_invar(1, LT(foll.c, foll.period))
    # aut_c.set_transitions(1, [(2, [Equals(foll.x_c,
    #                                       Plus(foll.c, foll.delta))])])
    # aut_c.set_assume(2, Equals(foll.delta, zero))
    # aut_c.set_invar(2, Equals(foll.c, foll.period))
    # aut_c.set_transitions(2, [(0, [Equals(foll.x_c, zero)])])

    return [acc_aut, pair_aut]


def automaton_delta_c(symbols: list, periods: list, leader: Leader,
                      followers: list, delta: FNode, x_delta: FNode):
    """Build the AG automaton fixing `delta` and every component clock.

    The integers in [1, lcm(periods)] divisible by some period are the
    scheduling instants. Each instant ``times[i]`` yields two locations:

    * even location ``2*i``: ``delta == 0`` and every clock holds its value
      at that instant (a clock whose period divides the instant holds
      exactly its period);
    * odd location ``2*i + 1``: ``delta`` equals the time to the next
      instant and the clocks whose period divides the instant are 0.

    Entering an even location requires ``x_delta == 0`` and advances every
    clock by ``delta``. Leaving it sets ``x_delta`` to the next step and
    resets (``x_c == 0``) the clocks whose period divides the instant. The
    last instant equals lcm(periods) and is treated as time 0, so the last
    odd location leads back to location 0.

    Args:
        symbols: current-state symbols of the whole system.
        periods: clock periods of all components (assumed integers, since
            `range`/`math_lcm` are applied to them).
        leader, followers: components whose clocks are constrained.
        delta, x_delta: current and next time step.

    Returns:
        The configured AGAutomaton over ``delta`` and all clocks.
    """
    components = followers + [leader]

    hyper_period = periods[0]
    for period in periods[1:]:
        hyper_period = math_lcm(hyper_period, period)

    # instants at which at least one clock expires.
    times = [t for t in range(1, hyper_period + 1)
             if any(t % period == 0 for period in periods)]
    zero = Real(0)
    aut_delta_c = AGAutomaton(symbols, [delta, leader.c] +
                              [foll.c for foll in followers],
                              "aut_delta_c", 2 * len(times))

    def clock_value(now: int, comp, keep_period: bool) -> FNode:
        """Clock of `comp` at instant `now`; its period instead of 0 when
        `keep_period` is set."""
        remainder = now % comp.period.constant_value().numerator
        if remainder == 0 and keep_period:
            return comp.period
        return Real(remainder)

    def clock_next_val(now: int, comp) -> FNode:
        """Next clock of `comp` when leaving instant `now`."""
        if now % comp.period.constant_value().numerator == 0:
            return Real(0)
        return Plus(comp.c, comp.delta)

    for idx in range(len(times)):
        nxt = (idx + 1) % len(times)
        # the last instant is the hyper-period itself: treat it as time 0.
        now = 0 if nxt == 0 else times[idx]
        upcoming = times[nxt]
        assert upcoming > now, \
            "curr: {}; next: {}".format(now, upcoming)
        step = Real(upcoming - now)

        disc_loc = 2 * idx
        timed_loc = disc_loc + 1
        if disc_loc > 0:
            entry_loc = disc_loc - 1
        else:
            entry_loc = aut_delta_c.num_locations - 1

        clock_invar = And([Equals(comp.c, clock_value(now, comp, True))
                           for comp in components])
        clock_disc_trans = And([Equals(comp.x_c, clock_next_val(now, comp))
                                for comp in components])
        clock_next_invar = And([Equals(comp.c, clock_value(now, comp, False))
                                for comp in components])
        clock_timed_trans = And([Equals(comp.x_c, Plus(comp.c, comp.delta))
                                 for comp in components])

        aut_delta_c.set_transitions(
            entry_loc,
            [(disc_loc, [And(Equals(x_delta, zero), clock_timed_trans)])])
        aut_delta_c.set_invar(disc_loc, And(Equals(delta, zero), clock_invar))
        aut_delta_c.set_assume(disc_loc, TRUE())
        aut_delta_c.set_transitions(
            disc_loc,
            [(timed_loc, [And(Equals(x_delta, step), clock_disc_trans)])])
        aut_delta_c.set_invar(timed_loc,
                              And(Equals(delta, step), clock_next_invar))
        aut_delta_c.set_assume(timed_loc, TRUE())
    return aut_delta_c


def automaton_delta(symbols: list, periods: list, leader: Leader,
                    followers: list, delta: FNode, x_delta: FNode):
    """Build an AG automaton that only fixes the time step `delta`.

    It uses the same location layout as `automaton_delta_c`, but without
    clock constraints. For each scheduling instant ``times[i]`` (integers in
    [1, lcm(periods)] divisible by some period), the even location ``2*i``
    has ``delta == 0`` and the odd location ``2*i + 1`` has ``delta`` equal
    to the time until the next instant. The last instant is treated as time
    0 and leads back to location 0.

    `leader` and `followers` are accepted but not read here.
    """
    hyper_period = periods[0]
    for period in periods[1:]:
        hyper_period = math_lcm(hyper_period, period)

    # instants at which at least one clock expires.
    times = [t for t in range(1, hyper_period + 1)
             if any(t % period == 0 for period in periods)]
    zero = Real(0)
    aut_delta = AGAutomaton(symbols, [delta],
                            "aut_delta", 2 * len(times))

    for idx in range(len(times)):
        nxt = (idx + 1) % len(times)
        # the last instant is the hyper-period itself: treat it as time 0.
        now = 0 if nxt == 0 else times[idx]
        upcoming = times[nxt]
        assert upcoming > now, \
            "curr: {}; next: {}".format(now, upcoming)
        step = Real(upcoming - now)

        disc_loc = 2 * idx
        timed_loc = disc_loc + 1
        if disc_loc > 0:
            entry_loc = disc_loc - 1
        else:
            entry_loc = aut_delta.num_locations - 1

        aut_delta.set_transitions(entry_loc,
                                  [(disc_loc, [Equals(x_delta, zero)])])
        aut_delta.set_invar(disc_loc, Equals(delta, zero))
        aut_delta.set_assume(disc_loc, TRUE())
        aut_delta.set_transitions(disc_loc,
                                  [(timed_loc, [Equals(x_delta, step)])])
        aut_delta.set_invar(timed_loc, Equals(delta, step))
        aut_delta.set_assume(timed_loc, TRUE())
    return aut_delta
