#! /usr/bin/env python3

import re

from automata_composition import AGAutomaton
from smv_prefixes import MODULE_INST_PREF, UAPPR_MASK


class Trace:
    """Represent trace of composed model.

    A trace is an ordered list of states; each state is a dict mapping the
    symbol names printed in the trace to their parsed Python values.
    The container is read-only: states can be indexed, iterated and
    counted, but item assignment and deletion raise ``TypeError``.
    """

    # Header line that opens a new state, e.g. "-> State: 1.2 <-".
    _BEGIN_STATE = re.compile(r"->\s+State:\s+\d+\.\d+\s+<-")
    # "symbol = value" line; only the first whitespace-free token after '='
    # is captured as the value.
    _ASSIGNMENT = re.compile(r"(?P<symb>[\S]+)\s*="
                             r"\s*(?P<val>[\S]+)")
    # Optional leading '-' so negative literals are converted as well.
    _INTEGER = re.compile(r"-?\d+")
    _DECIMAL = re.compile(r"-?\d+\.\d*")
    # Rational literal such as f'1/2' (sign allowed before 'f' or on num).
    _FRACTION = re.compile(r"(?P<sign>-?)f'(?P<num>-?\d+)/(?P<den>\d+)")
    # Keyword literals and the Python values they are mapped to.
    _KEYWORDS = {"TRUE": True, "TOP": True,
                 "FALSE": False, "BOT": False,
                 "UNDEF": None}

    def __init__(self):
        self._states = []

    @classmethod
    def _parse_value(cls, val):
        """Convert one textual trace value into a Python value.

        Keywords (TRUE/TOP, FALSE/BOT, UNDEF) map to True/False/None;
        integer literals become ``int``; decimal and f'num/den' literals
        become ``float``; anything else is returned unchanged as ``str``.
        """
        if val in cls._KEYWORDS:
            return cls._KEYWORDS[val]
        if cls._INTEGER.fullmatch(val):
            return int(val)
        if cls._DECIMAL.fullmatch(val):
            return float(val)
        frac = cls._FRACTION.fullmatch(val)
        if frac:
            result = int(frac.group("num")) / int(frac.group("den"))
            return -result if frac.group("sign") else result
        return val

    def parse_trace(self, buf):
        """Parse trace text and append the states it contains.

        :param buf: iterable of text lines (e.g. an open file or a list of
            strings). A line matching ``-> State: N.M <-`` starts a new
            state; each following ``symbol = value`` line is stored in the
            most recent state. All other lines are ignored, including
            assignment-like lines appearing before the first state header.
        """
        for line in buf:
            line = line.strip()
            if self._BEGIN_STATE.match(line):
                self._states.append({})
                continue
            m = self._ASSIGNMENT.match(line)
            # With no state header seen yet there is no state to store the
            # assignment in; treat it like any other non-trace line.
            if m is None or not self._states:
                continue
            self._states[-1][m.group("symb")] = \
                self._parse_value(m.group("val"))

    def __str__(self) -> str:
        """Render each state as a "-> State: i <-" header (0-based index)
        followed by its "  key = value" lines sorted by key."""
        lines = []
        for idx, state in enumerate(self._states):
            lines.append("-> State: {} <-".format(idx))
            lines.extend("  {} = {}".format(key, val)
                         for key, val in sorted(state.items(),
                                                key=lambda kv: kv[0]))
        return "\n".join(lines)

    def __len__(self):
        return len(self._states)

    def __iter__(self):
        return iter(self._states)

    def __getitem__(self, key):
        return self._states[key]

    def __setitem__(self, key, value):
        # Raise explicitly: an `assert` would be stripped under `python -O`.
        raise TypeError("Trace does not support item assignment")

    def __delitem__(self, key):
        raise TypeError("Trace does not support item deletion")

    # END OF TRACE


def _get_trans_mask(a: AGAutomaton, f) -> list:
    """Get transition mask of automaton `a` at frame `f`.

    Reads the mask bits ``<MODULE_INST_PREF><a.name>.<UAPPR_MASK>[i]`` from
    frame `f` (a mapping from symbol name to value), for every bit index
    ``i``. The number of bits is the length of the first mask returned by
    ``a.get_uappr_masks(0)``.
    """
    mask_width = len(a.get_uappr_masks(0)[0])
    mask_symbols = (f"{MODULE_INST_PREF}{a.name}.{UAPPR_MASK}[{bit}]"
                    for bit in range(mask_width))
    return [f[symbol] for symbol in mask_symbols]


def _filter_trans(a: AGAutomaton, src: int, dst: int,
                  trans_mask: list) -> list:
    """Return the transitions of `a` from `src` to `dst` whose mask matches.

    ``a.get_transitions(src)`` yields ``(target, transition)`` pairs, which
    are walked in lockstep with ``a.get_uappr_masks(src)``. A transition is
    kept when its target equals `dst` and its mask equals `trans_mask`.
    """
    outgoing = a.get_transitions(src)
    masks = a.get_uappr_masks(src)
    selected = []
    for (target, transition), mask in zip(outgoing, masks):
        if target == dst and mask == trans_mask:
            selected.append(transition)
    return selected
