#! /usr/bin/env python3

from typing import Iterator

from six.moves import cStringIO
from pysmt.fnode import FNode
from pysmt.shortcuts import And, Or

from smv_prefixes import LOC_VAR, LOC_ID_PREFIX
from smv_printer import to_smv, smv_type


class Automaton:
    """Automaton with a finite number of locations.

    Locations are identified by the integers `0 .. num_locs - 1`.
    Each location can be associated with an invariant over the `symbols`.
    Each transition can be associated with a list of formulae over
    `symbols` and their next-state versions; the list is serialized as a
    disjunction.
    The fairness is a set of location ids.
    """

    def __init__(self, name: str, symbols: list, num_locs: int,
                 init_locs: set = None, fairness: set = None,
                 init: FNode = None,
                 invars: list = None,
                 trans: list = None):
        """Create an automaton.

        `num_locs` is only used to size the default `invars` and `trans`
        containers; when `invars` is given, the number of locations is
        `len(invars)`.
        `init_locs` equal to None means that every location is initial.
        `trans`, when given, is indexed by source location and each item
        maps a destination location to its constraints.
        """
        self._name = name
        self._symbols = symbols
        self._init_locs = init_locs
        self._init = init
        self._invars = invars if invars else [None] * num_locs
        # one distinct dict per location: they are mutated independently.
        self._trans = trans if trans else [{} for _ in range(num_locs)]
        self._fair = fairness if fairness else set()

    @property
    def name(self) -> str:
        """Name of automaton"""
        return self._name

    @property
    def symbols(self) -> list:
        """Symbols of this automaton"""
        return self._symbols

    @property
    def num_locs(self) -> int:
        """Number of locations of the automaton"""
        return len(self._invars)

    @property
    def fairness(self) -> set:
        """Set of fair locations"""
        return self._fair

    @fairness.setter
    def fairness(self, fairness: set) -> None:
        """Set which locations are fair"""
        self._fair = fairness

    def add_fair(self, loc: int) -> None:
        """Add location to the set of fair locations"""
        self._fair.add(loc)

    @property
    def init_locs(self) -> set:
        """Set of initial locations: None means ALL"""
        return self._init_locs

    @init_locs.setter
    def init_locs(self, init_locs: set) -> None:
        """Set which locations are initial,
        None means all locations are initial"""
        self._init_locs = init_locs

    def add_init_loc(self, loc: int) -> None:
        """Add loc to the set of initial locations.

        If no explicit set of initial locations exists yet (None), a new
        set containing only `loc` is created: from then on only the
        explicitly added locations are initial.
        """
        if self._init_locs is None:
            self._init_locs = set()
        self._init_locs.add(loc)

    @property
    def init(self) -> FNode:
        """Return init constraint of symbols"""
        return self._init

    @init.setter
    def init(self, init: FNode) -> None:
        """Set initialisation constraint for symbols"""
        self._init = init

    def add_init(self, init: FNode) -> None:
        """Add initialisation constraints for symbols.

        The new constraint is conjoined with the existing one, if any.
        """
        if self._init is None:
            self._init = init
        else:
            self._init = And(self._init, init)

    def set_invar(self, loc: int, invar: FNode) -> None:
        """Set invariant of location `loc`"""
        self._invars[loc] = invar

    def add_invar(self, loc: int, invar: FNode) -> None:
        """Add `invar` to the invariants of location `loc`.

        The new invariant is conjoined with the existing one, if any.
        """
        current = self._invars[loc]
        self._invars[loc] = invar if current is None else And(current, invar)

    def get_invar(self, loc: int) -> FNode:
        """Return invariant of location `loc`, None if not set"""
        return self._invars[loc]

    def set_trans(self, src: int, trans: dict) -> None:
        """Set outgoing transitions of location src,
        `trans` must map dst locations to the constraint"""
        self._trans[src] = trans

    def add_trans(self, src: int, dst: int, trans: FNode) -> None:
        """Add possible relation to go from location `src` to `dst`.

        A disjunction is split into its disjuncts; the constraints of an
        edge are kept as a list that `serialize_smv` joins with Or.
        """
        # Copy into a fresh list: it must support `extend` on later calls
        # and must not alias the children sequence of the formula.
        disjuncts = list(trans.args) if trans.is_or() else [trans]
        # setdefault also covers the case where `src` already has outgoing
        # edges but none towards `dst` yet.
        self._trans[src].setdefault(dst, []).extend(disjuncts)

    def get_trans_from(self, src: int) -> dict:
        """Get map from dst loc to constraint"""
        return self._trans[src]

    def get_trans(self, src: int, dst: int) -> list:
        """Get list of constraints relating `src` to `dst`.

        Raises KeyError if there is no transition from `src` to `dst`.
        """
        return self._trans[src][dst]

    def serialize_smv(self, buf) -> None:
        """Write self in SMV format on buffer `buf`.

        `buf` only needs to provide a `write(str)` method.
        """
        locs = range(self.num_locs)

        buf.write("MODULE {}\n".format(self.name))
        # write state vars
        buf.write("  VAR\n")
        buf.write("  {} : {};\n".format(LOC_VAR, self._smv_loc_set(locs)))
        buf.write("\n".join(["  {} : {};".format(to_smv(s),
                                                 smv_type(s.get_type()))
                             for s in self.symbols]))
        buf.write("\n")

        # write control-flow graph of loc
        # no INIT constraint is needed if every location is initial.
        if self.init_locs and len(self.init_locs) < self.num_locs:
            buf.write("\n  -- init locations\n")
            buf.write("  INIT  {} in {};\n"
                      .format(LOC_VAR, self._smv_loc_set(self.init_locs)))
        buf.write("\n")
        buf.write("  -- transitions between locations\n")
        buf.write("\n".join(["  TRANS {0} = {1}{2} -> next({0}) in {3};"
                             .format(LOC_VAR, LOC_ID_PREFIX, loc,
                                     self._smv_loc_set(
                                         self.get_trans_from(loc)))
                             for loc in locs
                             if self.get_trans_from(loc)]))
        buf.write("\n")
        if self.fairness:
            buf.write("\n  FAIRNESS {} in {};\n\n"
                      .format(LOC_VAR, self._smv_loc_set(self.fairness)))

        buf.write("  -- constraints over symbols\n")
        if self.init is not None:
            buf.write("  INIT {};\n\n".format(to_smv(self.init)))

        # write invars
        buf.write("\n".join(["  INVAR ({} = {}{}) -> {};"
                             .format(LOC_VAR, LOC_ID_PREFIX, loc,
                                     to_smv(self.get_invar(loc)))
                             for loc in locs
                             if self.get_invar(loc) is not None]))
        buf.write("\n\n")

        # write constraints of every edge, one block per source location
        for src in locs:
            buf.write("\n".join(["  TRANS ({0} = {1}{2} & next({0}) = {1}{3}) "
                                 "-> {4};"
                                 .format(LOC_VAR, LOC_ID_PREFIX, src, dst,
                                         to_smv(Or(*self.get_trans(src, dst))))
                                 for dst in self.get_trans_from(src)]))
            buf.write("\n")

    def strongly_connected_components(self,
                                      include_no_edge_scc=False
                                      ) -> Iterator[set]:
        """Iterator over strongly connected components,
        each SCC is represented as a set of location indexes.

        This is a generator. Components made of a single location without
        self-loop are yielded only if `include_no_edge_scc` is true.
        The search is iterative: an explicit to-do list replaces recursion.
        """
        num_locs = self.num_locs
        edges = tuple(list(self.get_trans_from(loc))
                      for loc in range(num_locs))

        # identified[n]: n already belongs to a completed component.
        identified = [False] * num_locs
        # locations of the current search path, in visit order.
        stack = []
        # index[n]: position of n in `stack` when it was first visited.
        index = [None] * num_locs
        # stack positions where candidate components start.
        boundaries = []

        for root in range(num_locs):
            if index[root] is not None:
                continue
            to_do = [('V', root)]
            while to_do:
                operation_type, node = to_do.pop()
                if operation_type == 'V':  # VISIT
                    index[node] = len(stack)
                    boundaries.append(index[node])
                    stack.append(node)
                    to_do.append(('P', node))
                    to_do.extend([('E', succ) for succ in edges[node]])
                elif operation_type == 'E':  # VISIT-EDGE
                    if index[node] is None:
                        to_do.append(('V', node))
                    elif not identified[node]:
                        # edge back into the current path: merge every
                        # candidate component started after `node`.
                        while index[node] < boundaries[-1]:
                            boundaries.pop()
                elif boundaries[-1] == index[node]:  # POST-VISIT of a root
                    boundaries.pop()
                    scc = set(stack[index[node]:])
                    del stack[index[node]:]
                    # Always mark the locations as done, also when the
                    # component is not reported: stack positions are
                    # reused, so a stale index of an unmarked location
                    # would wrongly merge unrelated components.
                    for member in scc:
                        identified[member] = True
                    # a single-location component is necessarily {node}.
                    if include_no_edge_scc or len(scc) > 1 or \
                       node in edges[node]:
                        yield scc

    def __str__(self) -> str:
        """SMV encoding of current automaton"""
        buf = cStringIO()
        try:
            self.serialize_smv(buf)
            return buf.getvalue()
        finally:
            buf.close()

    @staticmethod
    def _smv_loc_set(l: set) -> str:
        """SMV set literal listing the locations in the iterable `l`"""
        return "{{{}}}".format(", ".join("{}{}".format(LOC_ID_PREFIX, i)
                                         for i in l))
