#! /usr/bin/env python3

"""
Parse a counterexample trace and check the assume-guarantee constraints on it;
if the check fails, add new constraints for the composition.
"""

from itertools import combinations

from pysmt.shortcuts import TRUE
from pysmt.shortcuts import Not, And
from pysmt.shortcuts import simplify
from pysmt.typing import PySMTType, INT, REAL, BOOL
from pysmt.fnode import FNode
from pysmt.rewritings import NNFizer
from pysmt.logics import AUTO
from pysmt.exceptions import SolverReturnedUnknownResultError

from efesolver import efesmt
from utils import symb_next, symb_is_next, var_to_curr, var_to_next, to_next, \
    solve_with_timeout, Solver, reset_after_timeout, to_smt2, get_verbosity, \
    get_solver
from smv_printer import to_smv
from smv_prefixes import LOC_VAR, LOC_ID_PREFIX, LOOPBACK_LOC_VAR, \
    IS_PREFIX, ENABLED, ACTIVE, LOOPBACK, SYMB_MASK, MODULE_INST_PREF, \
    FAIR_MASK, FAIR, FAIR_PRED, N_FAIR_PRED, \
    UAPPR_MASK, UAPPR, UAPPR_PRED, N_UAPPR_PRED, IN_LOOP, FAIR_LOOP, \
    PREFIX_LENGTH, TRANS_LABEL, PROP_FIND_LOOP, PROP_CHECK_UAPPR, \
    PROP_CHECK_FAIR, PROP_REACH, PROP_SAT_UAPPR, PROP_FIND_ANY_LOOP


TIMEOUT = 5  # timeout for solve calls in seconds, None to disable.


def set_timeout(val):
    """Replace the module-wide solver timeout.

    `val` is the number of seconds allowed to each solve call performed in
    this module; None disables the timeout.
    """
    global TIMEOUT
    TIMEOUT = val


def smv_composition(automata: list, initialisation: FNode,
                    fairness: FNode, transition: FNode, buf,
                    extra_init: list = None, extra_invar: list = None,
                    extra_trans: list = None) -> None:
    """Write on `buf` the SMV model of the composition of `automata`.

    Nothing is returned. The following are written on `buf`, in order: the
    `main` module, the optional extra constraints, the debug invariant
    specifications and one module definition per automaton.

    Side effect: the atoms of the NNF of `fairness` and of `transition` are
    set as fairness / transition predicates on every automaton. When the
    predicates differ from the current ones, this recomputes the
    automaton masks through SMT queries.

    :param automata: automata to compose.
    :param initialisation: initial condition of the system.
    :param fairness: fairness condition of the system.
    :param transition: transition relation of the system.
    :param buf: text buffer; only its `write` method is used.
    :param extra_init: optional items emitted (via str()) as `INIT`
        constraints of the main module.
    :param extra_invar: optional items emitted (via str()) as `INVAR`
        constraints of the main module.
    :param extra_trans: optional items emitted (via str()) as `TRANS`
        constraints of the main module.
    """
    # normalise expressions (NNF) and extract their predicates (atoms).
    nnf = NNFizer()
    fairness = nnf.convert(fairness)
    smv_fairness = to_smv(fairness)
    fairness_preds = tuple(fairness.get_atoms())

    transition = nnf.convert(transition)
    smv_transition = to_smv(transition)
    transition_preds = tuple(transition.get_atoms())

    # these calls have side-effects: they (re)compute the automata masks.
    for automaton in automata:
        automaton.set_fairness_predicates(fairness_preds)
        automaton.set_transition_predicates(transition_preds)

    # write main module on buffer `buf`
    main_smv(automata, initialisation, transition, fairness,
             transition_preds, fairness_preds, buf)

    # optional user-provided constraints, appended to the main module.
    for label, keyword, constraints in (("init", "INIT", extra_init),
                                        ("invar", "INVAR", extra_invar),
                                        ("trans", "TRANS", extra_trans)):
        if constraints:
            buf.write("\n  -- extra {}\n".format(label))
            for constr in constraints:
                buf.write("  {} {};\n".format(keyword, constr))

    # debug invariants
    buf.write("\n  INVARSPEC NAME {} := {};".format(PROP_REACH, IS_PREFIX))
    buf.write("\n  INVARSPEC NAME {} := "
              "(!{} & {}) -> {};\n"
              .format(PROP_CHECK_FAIR, IS_PREFIX, FAIR, smv_fairness))
    buf.write("\n  INVARSPEC NAME {} := !(!{} & {});\n"
              .format(PROP_SAT_UAPPR, IS_PREFIX, UAPPR))
    buf.write("\n  INVARSPEC NAME {} := (!{} & "
              "{}) -> {};\n"
              .format(PROP_CHECK_UAPPR, IS_PREFIX, UAPPR, smv_transition))

    for automaton in automata:
        buf.write("\n-- Definition of automaton `{}`\n".format(automaton.name))
        buf.write(str(automaton))


def main_smv(automata: list,
             sys_init: FNode, sys_trans: FNode, sys_fairness: FNode,
             trans_preds: tuple, fairness_preds: tuple, buf) -> None:
    """Write on `buf` the SMV `main` module of the composition.

    The transition relation is rewritten over the per-predicate symbols
    `UAPPR_PRED<i>` / `N_UAPPR_PRED<i>`, and the fairness over
    `FAIR_PRED<i>` / `N_FAIR_PRED<i>`. These are the same names declared in
    the main module by `_define_underapproximate` and `_define_fair`.

    :param automata: automata being composed; the symbols of the first one
        are used as the state variables of the system.
    :param sys_init: initial condition of the system.
    :param sys_trans: transition relation of the system.
    :param sys_fairness: fairness condition of the system.
    :param trans_preds: atoms of `sys_trans` (none may be a negation),
        identified by their position.
    :param fairness_preds: atoms of `sys_fairness` (none may be a negation),
        identified by their position.
    :param buf: text buffer; only its `write` method is used.
    """
    smv_init = to_smv(sys_init)

    # rewrite the transition relation over the underapproximation predicate
    # symbols; names must match the ones defined by `_define_underapproximate`.
    uappr_subs = {}
    for idx, pred in enumerate(trans_preds):
        assert not pred.is_not()
        uappr_subs[pred] = "{}{}".format(UAPPR_PRED, idx)
        uappr_subs[Not(pred)] = "{}{}".format(N_UAPPR_PRED, idx)
    underapprox = to_smv(sys_trans, uappr_subs)
    smv_trans = to_smv(sys_trans)

    # rewrite the fairness condition over the fairness predicate symbols;
    # names must match the ones defined by `_define_fair`.
    fair_subs = {}
    for idx, pred in enumerate(fairness_preds):
        assert not pred.is_not()
        fair_subs[pred] = "{}{}".format(FAIR_PRED, idx)
        fair_subs[Not(pred)] = "{}{}".format(N_FAIR_PRED, idx)
    loc_fairness = to_smv(sys_fairness, subs=fair_subs)

    # human-readable legend: which predicate each index refers to.
    if trans_preds:
        buf.write("-- underapproximation mask predicates:\n  --")
        buf.write("\n  -- ".join(["{}{} := {}".format(UAPPR_PRED, i, p)
                                 for i, p in enumerate(trans_preds)]))
        buf.write("\n")

    if fairness_preds:
        buf.write("-- fairness mask predicates\n  -- ")
        buf.write("\n  -- ".join(["{}{} := {}".format(FAIR_PRED, i, p)
                                 for i, p in enumerate(fairness_preds)]))
        buf.write("\n")

    symbols = automata[0].all_symbols if automata else []

    var_decls = "\n".join([_define_fair(loc_fairness, len(fairness_preds),
                                        automata),
                           _define_underapproximate(underapprox,
                                                    len(trans_preds),
                                                    automata),
                           _add_reachability_constr([to_smv(s)
                                                     for s in symbols],
                                                    smv_init, smv_trans),
                           _add_liveness2safety_monitors(automata),
                           _state_var_declaration(symbols),
                           _declare_automata(automata)])
    init = ""
    invar = _composition_partitions_symbols(automata)
    # each transition must be an underapproximation of the transition relation.
    trans = "  TRANS (!{}) -> {};".format(IS_PREFIX, UAPPR)
    buf.write(AGAutomaton._smv_template.format(automaton_name="main",
                                               parameters="",
                                               var_decls=var_decls,
                                               init=init,
                                               invar=invar,
                                               transitions=trans))


class AGAutomaton:
    """Assume-guarantee automaton.

    Location `i` is associated with the region `invariants[i] &
    assumptions[i]` and with the transitions leaving it, grouped by
    destination location. Given the atoms (predicates) of the system
    fairness and transition relation, the automaton computes masks. For each
    predicate, a mask tells whether the predicate is implied positively
    (True), negatively (False) or undecided (None). `str(self)` renders the
    automaton as an SMV module.
    """

    _smv_template = """MODULE {automaton_name}({parameters})
{var_decls}
{init}
{invar}
{transitions}
"""

    def __init__(self, all_symbols: list, local_symbols: list,
                 name: str, loc_num: int = 1):
        """Create a new automaton called `name` with `loc_num` locations.

        `all_symbols` is the ordered list of all symbols in the model.
        `local_symbols` holds the symbols to which this automaton can assign
        a next value. Every location starts with TRUE invariant and
        assumption and with no transitions.
        """
        assert name
        assert all_symbols
        assert local_symbols

        self._name = name
        # ordered list of all symbols in the model.
        self._all_symbols = all_symbols
        # symbols to which current automaton can assign a next value.
        self._local_symbols = local_symbols

        # region of location i: invariants[i] & assumptions[i].
        self._invariants = [TRUE() for _ in range(loc_num)]
        self._assumptions = [TRUE() for _ in range(loc_num)]
        # transitions[i]: list of (dst_loc, trans_list) leaving location i.
        self._transitions = [None for _ in range(loc_num)]

        # list of predicates in the fairness.
        self._fairness_predicates = None
        # mask over predicates of fairness for each location.
        self._fairness_masks = [[None] for _ in range(loc_num)]
        # list of predicates in the transition relation.
        self._transition_predicates = None
        # mask over predicates of transition relation for each location.
        self._transition_masks = [[None] for _ in range(loc_num)]

    @property
    def name(self) -> str:
        """Return name of automaton"""
        return self._name

    @property
    def num_locations(self) -> int:
        """Return number of locations"""
        return len(self._invariants)

    @property
    def all_symbols(self) -> list:
        """Return the ordered list of all symbols of the model"""
        return self._all_symbols

    @property
    def local_symbols_mask(self) -> list:
        """Return a list of booleans aligned with `all_symbols`: entry i is
        True iff `all_symbols[i]` is assigned by this automaton"""
        local = set(self.local_symbols)
        return [s in local for s in self._all_symbols]

    @property
    def local_symbols(self) -> list:
        """Return list of symbols assigned by this automaton"""
        return self._local_symbols

    def set_invar(self, i: int, invar: FNode) -> None:
        """Set invariant condition of location i (stored simplified)"""
        self._invariants[i] = invar.simplify()

    def get_invar(self, i: int) -> FNode:
        """Return invariant condition of location i"""
        return self._invariants[i]

    def set_assume(self, i: int, assume: FNode) -> None:
        """Set assumptions of location i (stored simplified)"""
        self._assumptions[i] = assume.simplify()

    def get_assume(self, i: int) -> FNode:
        """Return assumption of location i"""
        return self._assumptions[i]

    def set_transitions(self, src_loc: int, trans: list) -> None:
        """Set the transitions leaving location `src_loc`.

        `trans` must be a list of pairs (dst_loc, trans_list). dst_loc is the
        index of the target location and may appear at most once. trans_list
        is a list of FNodes representing the possible ways to reach dst_loc
        from src_loc. A transition may constrain the next value only of
        symbols in `self.local_symbols`. The pairs are stored sorted by
        dst_loc.
        """
        seen_locs = set()
        for dst_loc, t_list in trans:
            assert dst_loc not in seen_locs
            seen_locs.add(dst_loc)
            # check transition contains next only of allowed symbols.
            for t in t_list:
                for symb in t.get_free_variables():
                    if symb_is_next(to_smv(symb)):
                        assert var_to_curr(symb) in self.local_symbols
        # keep sorted by dst location.
        self._transitions[src_loc] = sorted(trans, key=lambda k: k[0])

    def get_transitions(self, src_loc: int) -> list:
        """Return the (dst_loc, trans_list) pairs leaving location src_loc"""
        return self._transitions[src_loc]

    def get_transitions_from_to(self, src: int, dst: int) -> list:
        """Return list of transitions from `src` to `dst` (empty if none)"""
        for loc, t_list in self.get_transitions(src):
            if loc == dst:
                return t_list
        return []

    def get_transition(self, src: int, dst: int, trans_id: int) -> FNode:
        """Return transition from `src` to `dst` with id `trans_id`.

        Raises AssertionError if there is no transition from `src` to `dst`.
        """
        assert trans_id >= 0
        for loc, t_list in self.get_transitions(src):
            if loc == dst:
                return t_list[trans_id]
        # explicit raise: a bare `assert False` would be stripped under -O.
        raise AssertionError("no transition from location {} to {}"
                             .format(src, dst))

    def max_transitions_between_locs(self) -> int:
        """Return the maximum number of transitions between any pair of
        locations (0 if there are none)"""
        return max((len(t_list)
                    for loc in range(self.num_locations)
                    for _, t_list in (self.get_transitions(loc) or [])),
                   default=0)

    def set_fairness_predicates(self, predicates: list) -> None:
        """Set list of fairness predicates, recomputing the fairness masks
        only if the predicates changed"""
        if self._fairness_predicates != predicates:
            self._fairness_predicates = predicates
            self._compute_fairness_mask()

    def reset_fairness_predicates(self) -> None:
        """Forget the fairness predicates, so that the next
        `set_fairness_predicates` recomputes the masks"""
        self._fairness_predicates = None

    def get_fair_mask(self, loc: int) -> list:
        """Return the fairness mask of location loc: one entry per fairness
        predicate, True / False / None (implied / negation implied /
        undecided by the region of loc)"""
        return self._fairness_masks[loc]

    def set_transition_predicates(self, predicates: list) -> None:
        """Set list of transition predicates, recomputing the transition
        masks only if the predicates changed"""
        if self._transition_predicates != predicates:
            self._transition_predicates = predicates
            self._compute_transitions_mask()

    def reset_transition_predicates(self) -> None:
        """Forget the transition predicates, so that the next
        `set_transition_predicates` recomputes the masks"""
        self._transition_predicates = None

    def get_uappr_masks(self, loc: int) -> list:
        """Return the masks of the transitions leaving `loc`.

        One entry per destination, in the order of `get_transitions(loc)`.
        Each entry is a list with one mask per transition towards that
        destination, and each mask has one True / False / None value per
        transition predicate."""
        return self._transition_masks[loc]

    @staticmethod
    def _solve(solver):
        """Solve the current assertions of `solver` within TIMEOUT.

        Return the result of `solve_with_timeout`, or None if the solver
        answered unknown."""
        try:
            return solve_with_timeout(TIMEOUT, solver)
        except SolverReturnedUnknownResultError:
            return None

    @staticmethod
    def _on_timeout(solver, *assertions) -> None:
        """Report (if verbose) the query that timed out, then reset `solver`
        through `reset_after_timeout(solver, *assertions)`"""
        # print before resetting: afterwards the timed-out query is lost.
        if get_verbosity():
            print("\tSMT timeout: {}".format(to_smt2(solver.assertions)))
        reset_after_timeout(solver, *assertions)

    @staticmethod
    def _mask_value(val) -> str:
        """Map a mask entry (True / False / None) to its SMV enum value"""
        if val is True:
            return "TOP"
        if val is False:
            return "BOT"
        return "UNDEF"

    def is_correct(self) -> bool:
        """Check whether self meets the hypothesis of AG-Abstractions.

        Return:
          - True if every check passes;
          - None if an SMT or exists-forall-exists query could not be
            decided (timeout / unknown);
          - ("UNSAT", (src, dst, trans)) if `trans` is unsatisfiable together
            with the region of `src` and the (next-state) region of `dst`;
          - ("SAT", (src, dst, trans), res) if `efesmt` returned the
            non-False result `res` for the query built from `trans`.
            NOTE(review): presumably `res` is a witness that the invariant
            of `dst` can be violated; confirm against efesolver.
        """
        symbs = self.all_symbols

        # every transition must be satisfiable between its source region and
        # its destination region (in next state).
        with Solver() as solver:
            for src_loc in range(self.num_locations):
                solver.push()
                src_region = And(self.get_invar(src_loc),
                                 self.get_assume(src_loc))
                solver.add_assertion(src_region)
                for dst_loc, trans_l in self.get_transitions(src_loc):
                    solver.push()
                    dst_region = And(self.get_invar(dst_loc),
                                     self.get_assume(dst_loc))
                    x_dst_region = to_next(dst_region, symbs)
                    solver.add_assertion(x_dst_region)
                    for trans in trans_l:
                        solver.push()
                        solver.add_assertion(trans)
                        res = self._solve(solver)
                        if res is None:
                            self._on_timeout(solver, src_region, x_dst_region)
                            return None
                        if res is False:
                            # empty transition.
                            return "UNSAT", (src_loc, dst_loc, trans)
                        solver.pop()
                    solver.pop()
                solver.pop()

        # check exists-forall-exists
        next_local = {var_to_next(s) for s in self.local_symbols}
        next_not_local = {var_to_next(s)
                          for s in set(symbs) - set(self.local_symbols)}
        # no explicit logic for msat, AUTO for the other solvers.
        logic = None if get_solver() == "msat" else AUTO

        for src_loc in range(self.num_locations):
            src_region = self.get_invar(src_loc)
            src_assume = self.get_assume(src_loc)
            for dst_loc, trans_l in self.get_transitions(src_loc):
                dst_region = to_next(self.get_invar(dst_loc), symbs)
                dst_assume = to_next(self.get_assume(dst_loc), symbs)
                for trans in trans_l:
                    problem = And(src_region, src_assume, trans,
                                  dst_assume,
                                  Not(dst_region))
                    res = efesmt(next_local, next_not_local, problem,
                                 logic=logic)
                    if res is None:
                        return None
                    if res is not False:
                        return "SAT", (src_loc, dst_loc, trans), res
        return True

    def _curr_to_next(self, expr: FNode) -> FNode:
        """Replace symbols with the corresponding monitor for next(symbol)"""
        return to_next(expr, self._all_symbols)

    def _compute_fairness_mask(self) -> None:
        """Given an ordered list of predicates it computes the mask for
        each location."""
        self._fairness_masks = [None for _ in range(self.num_locations)]
        if self._fairness_predicates is None:
            return
        with Solver() as solver:
            # for each location check region -> predicate valid or
            # region -> !predicate valid
            for loc in range(self.num_locations):
                mask = [None for _ in self._fairness_predicates]
                solver.reset_assertions()
                # valid(region -> predicate) == unsat(region & !predicate)
                # valid(region -> !predicate) == unsat(region & predicate)
                region = And(self.get_invar(loc), self.get_assume(loc))
                solver.add_assertion(region)

                for i, pred in enumerate(self._fairness_predicates):
                    solver.push()
                    solver.add_assertion(Not(pred))
                    res = self._solve(solver)
                    if res is False:
                        mask[i] = True
                    else:
                        if res is None:
                            self._on_timeout(solver, region)
                        solver.pop()
                        solver.push()
                        solver.add_assertion(pred)
                        res = self._solve(solver)
                        if res is False:
                            mask[i] = False
                        elif res is None:
                            self._on_timeout(solver, region)
                    solver.pop()  # remove predicate

                self._fairness_masks[loc] = mask

    def _compute_transitions_mask(self) -> None:
        """Given an ordered list of predicates it computes the mask for
        each transition.
        The mask is a list with the same ordering as the transitions.
        To each predicate assign True, False or None, depending on whether the
        predicate is implied positively (True), negatively (False) or cannot be
        decided (None)"""
        self._transition_masks = [None for _ in range(self.num_locations)]
        if self._transition_predicates is None:
            return
        with Solver() as solver:
            # for each location valid(region & trans & region' -> predicate) or
            # valid(region & trans & region' -> !predicate)
            for loc in range(self.num_locations):
                solver.reset_assertions()
                src_region = And(self.get_invar(loc), self.get_assume(loc))
                solver.add_assertion(src_region)
                solver.push()

                transitions = self.get_transitions(loc)
                assert transitions, \
                    "{}, loc {}, {}".format(self.name, loc, transitions)
                masks = [None for _ in transitions]
                for dst_idx, (dst_loc, t_list) in enumerate(transitions):
                    solver.push()
                    dst_region = self._curr_to_next(
                        And(self.get_invar(dst_loc), self.get_assume(dst_loc)))
                    solver.add_assertion(dst_region)
                    src_to_dst_masks = [None for _ in t_list]

                    for t_idx, trans in enumerate(t_list):
                        solver.push()
                        solver.add_assertion(trans)

                        t_mask = [None for _ in self._transition_predicates]
                        for p_idx, pred in enumerate(
                                self._transition_predicates):
                            solver.push()
                            solver.add_assertion(Not(pred))
                            res = self._solve(solver)
                            if res is False:
                                # unsat(region & region' & trans & !p)
                                t_mask[p_idx] = True
                            else:
                                if res is None:
                                    self._on_timeout(solver, src_region,
                                                     dst_region, trans)
                                solver.pop()
                                solver.push()
                                solver.add_assertion(pred)
                                res = self._solve(solver)
                                if res is False:
                                    # unsat(region & region' & trans & p)
                                    t_mask[p_idx] = False
                                elif res is None:
                                    self._on_timeout(solver, src_region,
                                                     dst_region, trans)
                            solver.pop()  # remove predicate.
                        src_to_dst_masks[t_idx] = t_mask
                        solver.pop()  # remove trans

                    solver.pop()  # remove dst_region
                    assert len(src_to_dst_masks) == len(t_list), \
                        "{} != {}".format(src_to_dst_masks, t_list)
                    masks[dst_idx] = src_to_dst_masks

                assert len(masks) == len(transitions), \
                    "{} != {}".format(masks, transitions)
                self._transition_masks[loc] = masks
        assert len(self._transition_masks) == self.num_locations

    def _smv_symb_decls(self) -> str:
        """SMV encoding of declaration of symbols"""
        # variable to encode locations of automaton.
        locations_enum = "{{{}}}".format(
            ", ".join("{}{}".format(LOC_ID_PREFIX, i)
                      for i in range(self.num_locations)))
        # define over the symbols mask.
        num_symbs = len(self.all_symbols)
        word1 = "0b_{}".format("1" * num_symbs)
        word0 = "0b_{}".format("0" * num_symbs)
        local_word = "".join("1" if b else "0"
                             for b in self.local_symbols_mask)
        res = "  DEFINE\n" \
              "    -- defines for symbols partitioning\n" \
              "    {} := ({}? {} : {}) & " \
              "0b_{};\n".format(SYMB_MASK, ENABLED, word1, word0, local_word)
        res += "    {} := (!{} | {} = {});\n".format(LOOPBACK, ENABLED,
                                                     LOC_VAR, LOOPBACK_LOC_VAR)

        # defines for fairness predicates
        res += "\n    -- defines for fairness predicates\n"
        for pred_idx in range(len(self._fairness_predicates)):
            res += "    {2}{0}     := {4} & {1}[{0}] = TOP;\n" \
                   "    {3}{0} := {4} & {1}[{0}] = BOT;\n" \
                   .format(pred_idx, FAIR_MASK, FAIR_PRED, N_FAIR_PRED,
                           ENABLED)

        # defines for underapproximation predicates
        res += "\n    -- defines for underapproximation predicates\n"
        for pred_idx in range(len(self._transition_predicates)):
            res += "    {2}{0}     := {1}[{0}] = TOP;\n" \
                   "    {3}{0} := {1}[{0}] = BOT;\n" \
                   .format(pred_idx, UAPPR_MASK, UAPPR_PRED, N_UAPPR_PRED)

        # boolean frozenvar to enable / disable automaton.
        res += "\n  FROZENVAR\n    {} : boolean;\n".format(ENABLED)
        # symbol used to guess the loopback location
        res += "    {} : {};\n".format(LOOPBACK_LOC_VAR, locations_enum)
        # mask to keep track of under-approximation; `array 0..-1` is not
        # valid SMV, so the array is omitted when there are no predicates.
        res += "\n  IVAR\n"
        if self._transition_predicates:
            res += "    {} : array 0..{} of {{TOP, UNDEF, BOT}};\n" \
                   .format(UAPPR_MASK, len(self._transition_predicates) - 1)
        res += "    {} : -1..{};\n" \
               .format(TRANS_LABEL, self.max_transitions_between_locs() - 1)
        # mask to keep track of fair locations (omitted if no predicates).
        res += "\n  VAR\n"
        if self._fairness_predicates:
            res += "    {} : array 0..{} of {{TOP, UNDEF, BOT}};\n" \
                   .format(FAIR_MASK, len(self._fairness_predicates) - 1)
        res += "    {} : {};\n".format(LOC_VAR, locations_enum)
        return res

    def _smv_invar(self) -> str:
        """SMV encoding of region associated to each location"""
        # if automaton not active set mask to UNDEF and location to loc 0.
        disabled = ["{}[{}] = UNDEF".format(FAIR_MASK, i)
                    for i in range(len(self._fairness_predicates))]
        disabled.append("{} = {}0".format(LOC_VAR, LOC_ID_PREFIX))
        res = "  -- if disabled or inactive UNDEF mask\n  "
        res += "INVAR (!{} | !{}) -> ({});\n" \
               .format(ACTIVE, ENABLED, " & ".join(disabled))

        res += "\n  -- regions of each location: loc -> " \
               "(fairness_mask) & (guarantee) & (assume)\n"
        for loc in range(self.num_locations):
            mask_constr = "".join(
                "{}[{}] = {} & ".format(FAIR_MASK, i, self._mask_value(b))
                for i, b in enumerate(self.get_fair_mask(loc)))
            region = And(self.get_invar(loc), self.get_assume(loc))
            region = to_smv(simplify(region))
            res += "  INVAR ({} & {} & {} = {}{}) ->\n    ({} \n" \
                   "      {}\n    );\n" \
                   .format(ACTIVE, ENABLED, LOC_VAR, LOC_ID_PREFIX, loc,
                           mask_constr, region)
        return res

    def _smv_trans(self) -> str:
        """SMV encoding of transition relation"""
        # if automaton not active: UNDEF mask and default (-1) label.
        disabled = ["{}[{}] = UNDEF".format(UAPPR_MASK, i)
                    for i in range(len(self._transition_predicates))]
        disabled.append("{} = -1".format(TRANS_LABEL))
        disable_trans_mask = " -- if disabled or inactive UNDEF mask and " \
                             "default transition label\n  " \
                             "TRANS (!{} | !{}) -> ({});\n" \
                             .format(ACTIVE, ENABLED, " & ".join(disabled))
        cfg = "  -- control flow\n"
        labels = "  -- labels\n"
        for src_loc in range(self.num_locations):
            cfg += "  TRANS ({} & {} & {} = {}{}) -> (" \
                   .format(ACTIVE, ENABLED, LOC_VAR, LOC_ID_PREFIX, src_loc)
            masks = self.get_uappr_masks(src_loc)
            transitions = self.get_transitions(src_loc)
            assert len(transitions) == len(masks), \
                "{} != {}".format(len(transitions), len(masks))

            for (dst_loc, t_list), c_masks in zip(transitions, masks):
                assert len(t_list) == len(c_masks), \
                    "{} != {}".format(t_list, c_masks)

                cfg += "next({}) = {}{} |".format(LOC_VAR, LOC_ID_PREFIX,
                                                  dst_loc)
                labels += "  TRANS\n    ({0} & {1} & {2} = {3}{4} & " \
                          " next({2}) = {3}{5}) ->\n      (" \
                          .format(ACTIVE, ENABLED, LOC_VAR, LOC_ID_PREFIX,
                                  src_loc, dst_loc)
                for trans_id, (trans, mask) in enumerate(zip(t_list,
                                                             c_masks)):
                    trans_mask_constr = "".join(
                        "{}[{}] = {} & ".format(UAPPR_MASK, i,
                                                self._mask_value(b))
                        for i, b in enumerate(mask or []))
                    constr = to_smv(simplify(trans))
                    trans_label = "({} = {})".format(TRANS_LABEL, trans_id)
                    labels += "({} & {} {}) |\n"\
                              .format(trans_label, trans_mask_constr,
                                      constr)
                # drop the trailing "|\n" of the last alternative.
                labels = labels[:-2] + "\n      );\n"
            # drop the trailing " |" of the last destination.
            cfg = cfg[:-2] + ");\n"
        return "{}\n{}\n{}".format(disable_trans_mask, cfg, labels)

    def __str__(self) -> str:
        """SMV encoding of self"""
        aut_name = "{}".format(self._name)
        parameters = "{}, ".format(ACTIVE)
        parameters += ", ".join([to_smv(p) for p in self.all_symbols])
        parameters += ", "
        parameters += ", ".join([symb_next(to_smv(p))
                                 for p in self.all_symbols])

        return AGAutomaton._smv_template.format(
            automaton_name=aut_name,
            parameters=parameters,
            var_decls=self._smv_symb_decls(),
            init="  INIT TRUE;\n",
            invar=self._smv_invar(),
            transitions=self._smv_trans())


def _type_to_smv_type(symb_sort: PySMTType) -> str:
    """Return the SMV type name for the pysmt type `symb_sort`.

    INT maps to "integer", REAL to "real" and BOOL to "boolean"; any other
    type fails an assertion (and yields "" if assertions are disabled).
    """
    if symb_sort == INT:
        return "integer"
    if symb_sort == REAL:
        return "real"
    if symb_sort == BOOL:
        return "boolean"
    assert False, "unknown type `{}`".format(symb_sort)
    return ""


def _define_fair(fairness: str, num_fair_preds: int, automata: list) -> str:
    """Define boolean symbol `_fair` that holds if the current locations describe
    a fair region"""
    res = "  -- fairness based on locations of automata\n" \
          "  DEFINE\n    {} := {};\n".format(FAIR, fairness)
    for pred_idx in range(num_fair_preds):
        pos_pred = " | ".join("{}{}.{}{}".format(MODULE_INST_PREF, a.name,
                                                 FAIR_PRED, pred_idx)
                              for a in automata)
        neg_pred = " | ".join("{}{}.{}{}".format(MODULE_INST_PREF, a.name,
                                                 N_FAIR_PRED, pred_idx)
                              for a in automata)
        res += "    {}{} := ({}) & !({});\n".format(FAIR_PRED, pred_idx,
                                                    pos_pred, neg_pred)
        res += "    {}{} := ({}) & !({});\n".format(N_FAIR_PRED, pred_idx,
                                                    neg_pred, pos_pred)
    return res


def _define_underapproximate(transition: str, num_trans_preds: int,
                             automata: list) -> str:
    """Return the SMV DEFINE section for the transition underapproximation.

    The symbol `UAPPR` is bound to the SMV expression `transition`. For
    every index `i` in `range(num_trans_preds)` two further symbols are
    defined:
      * `UAPPR_PRED<i>`: true iff some component instance sets its
        `UAPPR_PRED<i>` and no instance sets its `N_UAPPR_PRED<i>`;
      * `N_UAPPR_PRED<i>`: the symmetric case for the negated predicate.
    """
    def _disjunction(pred_name: str, pred_idx: int) -> str:
        # `pred_name<pred_idx>` of every component instance, or-ed together.
        return " | ".join("{}{}.{}{}".format(MODULE_INST_PREF, aut.name,
                                             pred_name, pred_idx)
                          for aut in automata)

    chunks = ["  -- underappoximate transition relation\n"
              "  DEFINE\n    {} := {};\n".format(UAPPR, transition)]
    for idx in range(num_trans_preds):
        some_pos = _disjunction(UAPPR_PRED, idx)
        some_neg = _disjunction(N_UAPPR_PRED, idx)
        chunks.append("    {}{} := ({}) & !({});\n".format(
            UAPPR_PRED, idx, some_pos, some_neg))
        chunks.append("    {}{} := ({}) & !({});\n".format(
            N_UAPPR_PRED, idx, some_neg, some_pos))
    return "".join(chunks)


def _add_liveness2safety_monitors(automata: list) -> str:
    """Return SMV monitors and properties that detect a (fair) loop.

    Declares the boolean state variables `IN_LOOP` and `FAIR_LOOP`. Defines
    `LOOPBACK` as `!IS_PREFIX` conjoined with the `LOOPBACK` symbol of every
    component instance. Adds two INVARSPEC properties:
      * `PROP_FIND_LOOP`: `!(FAIR_LOOP & LOOPBACK)`;
      * `PROP_FIND_ANY_LOOP`: `!(IN_LOOP & LOOPBACK)`.
    A counterexample to `PROP_FIND_ANY_LOOP` reaches a `LOOPBACK` state
    strictly after an earlier `LOOPBACK` state. A counterexample to
    `PROP_FIND_LOOP` additionally has `FAIR` holding in some state, strictly
    before the final one, where `IN_LOOP` held.

    NOTE(review): with an empty `automata` list the `LOOPBACK` definition
    becomes `(!IS_PREFIX & )`, which is not valid SMV.
    """
    res = "  -- symbols and constraints for liveness to safety\n"
    res += "  VAR\n    {} : boolean;\n".format(IN_LOOP)
    res += "    {} : boolean;\n".format(FAIR_LOOP)
    # LOOPBACK: outside the prefix and every component reports LOOPBACK.
    res += "  DEFINE\n    {} := (!{} & {});"\
           .format(LOOPBACK, IS_PREFIX,
                   " & ".join("{}{}.{}".format(MODULE_INST_PREF,
                                               a.name, LOOPBACK)
                              for a in automata))
    # IN_LOOP latches in the successor of the first LOOPBACK state.
    # FAIR_LOOP latches in the successor of a state where both IN_LOOP and
    # FAIR hold. Since init(IN_LOOP) is FALSE, init(FAIR_LOOP) is FALSE too.
    res += "\n  ASSIGN\n    init({0}) := FALSE;\n    " \
           "next({0}) := {0} | {1};\n    " \
           "init({2}) := {0} & {3};\n    " \
           "next({2}) := {2} | ({0} & {3});\n" \
           .format(IN_LOOP, LOOPBACK, FAIR_LOOP, FAIR)
    # Invariants whose counterexamples close a fair loop / any loop.
    res += "\n  INVARSPEC NAME {} := " \
           "!({} & {});\n".format(PROP_FIND_LOOP, FAIR_LOOP, LOOPBACK)
    res += "\n  INVARSPEC NAME {} := " \
           "!({} & {});\n".format(PROP_FIND_ANY_LOOP, IN_LOOP, LOOPBACK)
    return res


def _add_reachability_constr(symbols: list, init: str, transition: str) -> str:
    """Declare the symbols needed to check reachability of the composition
    and add the constraints that encode it.

    Emitted SMV:
      * an integer variable `PREFIX_LENGTH`, constrained to be >= 0;
      * `IS_PREFIX` defined as `PREFIX_LENGTH > 0`;
      * `INIT init`;
      * `INVAR IS_PREFIX -> (transition)`;
      * while `IS_PREFIX` holds, `PREFIX_LENGTH` decreases by one per step,
        and every symbol `s` satisfies `next(s) = symb_next(s)`;
      * once `IS_PREFIX` is false, `PREFIX_LENGTH` stays 0.

    :param symbols: state symbols. Each is formatted with `str.format` and
        passed to `symb_next` to get its next-state copy.
    :param init: SMV expression used as the INIT constraint.
    :param transition: SMV expression required to hold while `IS_PREFIX`.
    :return: the SMV text of the declarations and constraints.
    """
    # While in the prefix, the SMV successor value of every symbol must equal
    # its explicit next-state copy. This presumably ties `transition`, a state
    # invariant over current and `next` copies, to the step actually taken.
    next_curr = " & ".join("next({}) = {}".format(s, symb_next(s))
                           for s in symbols)
    if not next_curr:
        # With no symbols, the empty conjunction would render as `... & )`,
        # which is not valid SMV. Use the neutral element instead.
        next_curr = "TRUE"

    return "  -- components for ensuring reachability of composition\n  " \
        "VAR {0} : integer;\n  " \
        "DEFINE {1} := ({0} > 0);\n  " \
        "INIT {2};\n  " \
        "INVAR {0} >= 0;\n  " \
        "INVAR {1} -> ({3});\n  " \
        "TRANS {1} -> (next({0}) = {0} -1 & {4});\n  " \
        "TRANS (!{1}) -> next({0}) = 0;\n" \
        .format(PREFIX_LENGTH, IS_PREFIX, init, transition, next_curr)


def _state_var_declaration(symbols: list) -> str:
    """Return the SMV VAR section declaring every symbol and its next copy.

    All current-state declarations come first, in the order of `symbols`.
    They are followed by the next-state copies (`symb_next(name)`) in the
    same order, each with the SMV type of the original symbol.
    """
    # Resolve all names and sorts up front, before any type translation.
    named_sorts = [(to_smv(symb), symb.get_type()) for symb in symbols]
    curr_lines = []
    next_lines = []
    for smv_name, sort in named_sorts:
        smv_type = _type_to_smv_type(sort)
        curr_lines.append("    {} : {};\n".format(smv_name, smv_type))
        next_lines.append("    {} : {};\n".format(symb_next(smv_name),
                                                  smv_type))
    return ("  -- declare state variables\n  VAR\n"
            + "".join(curr_lines) + "".join(next_lines))


def _declare_automata(automata: list) -> str:
    """Return the SMV VAR section that instantiates every automaton.

    Each automaton `a` becomes the variable `<MODULE_INST_PREF><a.name>`
    of module type `<a.name>`. Every instance receives the same actual
    parameters, in this order:
      1. `!IS_PREFIX`;
      2. the SMV names of the symbols of the first automaton;
      3. the next-state copies of those names.
    Groups 2 and 3 are omitted when there are no symbols, so the argument
    list never contains empty entries.

    NOTE(review): only `automata[0].all_symbols` is used. This presumably
    relies on all automata sharing the same symbol list; confirm with the
    callers.
    """
    symbols = automata[0].all_symbols if automata else []
    curr_names = [to_smv(s) for s in symbols]

    actual_params = ["!{}".format(IS_PREFIX)]
    actual_params.extend(curr_names)
    actual_params.extend(symb_next(name) for name in curr_names)
    # Joining a list, rather than gluing fixed ", " separators, avoids the
    # invalid `!IS_PREFIX, , ` when there are no symbols.
    params = ", ".join(actual_params)

    lines = ["  -- instantiate components\n  VAR\n"]
    lines.extend("    {2}{0} : {0}({1});\n".format(a.name, params,
                                                   MODULE_INST_PREF)
                 for a in automata)
    return "".join(lines)


def _composition_partitions_symbols(automata: list) -> str:
    """Returns invariant requiring the list of symbols handled by the enabled
    automata to be a partitioning of the whole set of symbols"""
    num_symbols = 0
    if not automata:
        return ""
    num_symbols = len(automata[0].all_symbols)
    if num_symbols == 0:
        return ""

    union_target_mask = "0b_{}".format("1" * num_symbols)
    intersection_target_mask = "0b_{}".format("0" * num_symbols)
    symb_masks_list = ["{}{}.{}".format(MODULE_INST_PREF, a.name, SYMB_MASK)
                       for a in automata]

    union = "  -- union of components' symbols must cover all symbols\n" \
            "  INVAR {} = ({});\n".format(union_target_mask,
                                          " | ".join(symb_masks_list))
    intersection = "  -- components' symbols must not intersect\n"
    for mask0, mask1 in combinations(symb_masks_list, 2):
        intersection += "  INVAR {} = ({} & {});\n"\
                        .format(intersection_target_mask,
                                mask0, mask1)

    return union + intersection
