#! /usr/bin/env python3

import sys
import os
import re


def _get_pysmt_type(smv_type: str) -> str:
    """Map an SMV type to the name of the matching pysmt type constant.

    ``boolean`` maps to ``"BOOL"``, ``real`` to ``"REAL"``; ``integer`` and
    bounded ranges written as ``<int>..<int>`` (e.g. ``-1..5``) both map to
    ``"INT"``.  Any other type trips an assertion.
    """
    scalar_types = {"boolean": "BOOL", "integer": "INT", "real": "REAL"}
    pysmt_type = scalar_types.get(smv_type)
    if pysmt_type is not None:
        return pysmt_type
    # Bounded integer ranges are modelled as plain integers.
    if re.compile(r"(-)?\d+\.\.(-)?\d+").fullmatch(smv_type):
        return "INT"

    assert False, "unknown type: {}".format(smv_type)
    return None


def _get_pysmt_op(_op: str) -> str:
    """Return the pysmt shortcut name for an SMV comparison/arithmetic operator.

    Surrounding whitespace in ``_op`` is ignored.  An operator that is not in
    the table below trips an assertion.
    """
    shortcut_names = {
        ">": "GT",
        ">=": "GE",
        "=": "Equals",
        "<=": "LE",
        "<": "LT",
        "+": "Plus",
        "-": "Minus",
        "*": "Times",
        "/": "Div",
    }
    op = _op.strip()
    shortcut = shortcut_names.get(op)
    assert shortcut is not None, "unknown operator: {}".format(op)
    return shortcut


def _parse_symbols(smv_f: str) -> list:
    """Return the symbols declared in the VAR section of the model.

    The result is a list of ``(name, pysmt_type_name)`` pairs, one for each
    line of the form ``<name> : <type>;`` seen while inside a VAR section.
    A section starts at any line containing ``VAR`` and ends at any line
    containing ``INIT``, ``TRANS``, ``ASSIGN`` or ``FAIRNESS``.
    """
    section_end_markers = ("INIT", "TRANS", "ASSIGN", "FAIRNESS")
    decl_pattern = re.compile(r"(?P<symb>\S+) : (?P<type>\S+);")
    declared = []
    inside_var = False
    with open(smv_f, 'r') as smv:
        for raw in smv:
            text = raw.strip()
            if "VAR" in text:
                inside_var = True
            # An end marker wins even when it shares the line with VAR.
            if any(marker in text for marker in section_end_markers):
                inside_var = False
            if not inside_var:
                continue
            found = decl_pattern.match(text)
            if found is None:
                continue
            name, smv_type = found.group("symb", "type")
            assert smv_type, "{}".format(found.groupdict())
            declared.append((name, _get_pysmt_type(smv_type)))
    return declared


def _parse_init(smv_f: str) -> list:
    """Return what follows the ``INIT`` keyword on every line starting with it.

    Lines are stripped before the check; the text after the keyword is kept
    as is (including a leading space and the trailing ``;`` when present).
    """
    keyword = "INIT"
    with open(smv_f, 'r') as smv:
        stripped_lines = (raw.strip() for raw in smv)
        return [text[len(keyword):]
                for text in stripped_lines
                if text.startswith(keyword)]


def _parse_cfg(smv_f: str) -> list:
    """Translate the ``pc`` assignments of the ASSIGN section to pysmt text.

    Returns a list of strings: first one ``init = Equals(pc, Int(<n>))``
    entry for each ``init(pc) := <n>;`` line, then, for every branch of a
    ``case`` ... ``esac`` block, a ``# pc = ...`` comment followed by the
    matching ``Implies(<precondition>, <postcondition>)`` expression.

    A branch has the form ``pc = <src> [& <guard>] : <dst>;`` or
    ``pc = <src> [& <guard>] : {<dst0>, <dst1>};`` where the guard is a
    single, optionally negated, comparison ``<symbol> <op> <arg>``.  A guard
    argument that is not an integer literal is emitted between backticks.

    The ASSIGN section starts at a line containing ``ASSIGN`` and ends at a
    line containing ``INIT``, ``TRANS``, ``VAR`` or ``FAIRNESS``.

    Raises:
        AssertionError: if a line inside a ``case`` block, or its guard,
            does not have the expected form.
    """
    # match init of pc.
    init_re = re.compile(r"init\(pc\) := (?P<init_loc>\d+);")
    # match conditional assign for pc.
    case_re = re.compile(r"pc = (?P<src>(-)?\d+) "
                         r"(?P<extra>.+)?: {?(?P<dst0>(-)?\d+)(, )?"
                         r"(?P<dst1>(-)?\d+)?}?;")
    # match the optional guard of a branch.
    extra_re = re.compile(r"(?P<not>!)?\s*\(*\s*"
                          r"(?P<symb>\S+)\s*"
                          r"(?P<op>(>=)|(>)|(=)|(<=)|(<)|(!=))\s*"
                          r"(?P<arg>[^\)]+)\)*")
    section_starts = ("INIT", "TRANS", "VAR", "FAIRNESS")
    assign_found = False
    in_case = False
    init_out = []
    cfg_out = []
    with open(smv_f, 'r') as smv:
        for line in smv:
            line = line.strip()
            if "ASSIGN" in line:
                assign_found = True
            # Checked on every line (not only on the ASSIGN line itself) so
            # that the start of another section really ends the ASSIGN one.
            if any(keyword in line for keyword in section_starts):
                assign_found = False
            if not assign_found:
                continue

            if "case" in line:
                in_case = True
            if "esac" in line:
                in_case = False

            init_m = init_re.fullmatch(line)
            if init_m:
                init_out.append("init = Equals(pc, Int({}))"
                                .format(init_m.group("init_loc")))
                continue
            if not in_case or not line or line == "case":
                continue

            m = case_re.fullmatch(line)
            assert m, "`{}`".format(line)
            src = int(m.group("src"))
            dst0 = int(m.group("dst0"))
            # Normalise the optional second destination like the first one.
            dst1 = m.group("dst1")
            if dst1 is not None:
                dst1 = int(dst1)

            prec = "Equals(pc, Int({}))".format(src)
            guard = (m.group("extra") or "").strip()
            if guard.startswith("&"):
                guard = guard[1:].strip()
            if guard:
                guard_m = extra_re.fullmatch(guard)
                assert guard_m, "`{}`".format(guard)
                symb = guard_m.group("symb")
                arg = guard_m.group("arg")
                try:
                    arg = "Int({})".format(int(arg))
                except ValueError:
                    arg = "`{}`".format(arg)
                smv_op = guard_m.group("op")
                if smv_op == "!=":
                    # The operator table has no entry for `!=`: spell it as
                    # a negated equality.
                    extra = "Not(Equals({}, {}))".format(symb, arg)
                else:
                    extra = "{}({}, {})".format(_get_pysmt_op(smv_op),
                                                symb, arg)
                if guard_m.group("not"):
                    extra = "Not({})".format(extra)
                prec = "And({}, {})".format(prec, extra)

            post = "Equals(x_pc, Int({}))".format(dst0)
            dst_txt = str(dst0)
            if dst1 is not None:
                post = "Or({}, Equals(x_pc, Int({})))".format(post, dst1)
                dst_txt = "{{{}, {}}}".format(dst0, dst1)

            if guard:
                comment = "# pc = {} & {} : {}".format(src, guard, dst_txt)
            else:
                comment = "# pc = {} : {}".format(src, dst_txt)
            cfg_out.append(comment)
            cfg_out.append("Implies({}, {})".format(prec, post))
    return init_out + cfg_out


def _parse_labels(smv_f: str, symbols) -> list:
    """Translate every ``TRANS <prec> -> <post>`` line to pysmt text.

    Args:
        smv_f: path of the SMV model.
        symbols: iterable of ``(name, type)`` pairs; only the names are
            used, to rewrite ``next(<name>)`` occurrences.

    Returns:
        A list holding, for each TRANS line, a ``# ...`` comment (the line
        with ``next(s)`` shown as ``s'``) followed by
        ``Implies(<prec>, And(<assignments>))``.

    The precondition must be ``pc = <src>``, optionally followed by
    ``& next(pc) = <dst>``.  The postcondition is split on ``&``; each
    conjunct is read as ``<symbol> = <operand> [<op> <operand>]`` with
    ``next(s)`` replaced by ``x_s``.  Integer operands become ``Int(n)``;
    a right-hand side that is neither of that form nor an integer literal
    is emitted between backticks.

    Raises:
        AssertionError: if a TRANS line has no ``->``, if its precondition
            has an unexpected form, or if a conjunct contains no ``=``.
    """
    def _next_replacer(template: str):
        """Return a function rewriting each ``next(<symbol>)`` via template."""
        table = {"next({})".format(s): template.format(s) for s, _ in symbols}
        if not table:
            # An empty alternation would match the empty string everywhere.
            return lambda text: text
        pattern = re.compile("|".join(re.escape(key) for key in table))
        return lambda text: pattern.sub(lambda m: table[m.group(0)], text)

    def _operand(text: str) -> str:
        """Return ``Int(n)`` for an integer literal, the text otherwise."""
        try:
            return "Int({})".format(int(text))
        except ValueError:
            return text

    # parse precondition: describe edge from src to dst.
    prec_re = re.compile(r"\(?pc = (?P<src>(-)?\d+)( & "
                         r"next\(pc\) = (?P<dst>(-)?\d+))?\)?")
    # parse expression
    post_re = re.compile(r"(?P<symb>[^ \t\n\r\f\v=]+)\s*\=\s*"
                         r"(?P<lhs>[^ \t\n\r\f\v+\-*/]+)\s*"
                         r"((?P<op>(\+)|(-)|(\*)|(/))\s*"
                         r"(?P<rhs>\S+))?")
    # replace next(symbol) with x_symbol (emitted code) ...
    to_next_var = _next_replacer("x_{}")
    # ... and with symbol' (emitted comments).
    to_primed = _next_replacer("{}'")

    retval = []
    with open(smv_f, 'r') as smv:
        for line in smv:
            line = line.strip()
            if not line.startswith("TRANS"):
                continue
            line = line[len("TRANS"):].strip()
            assert "->" in line, "`{}`".format(line)
            prec, post = (part.strip() for part in line.split("->"))

            m = prec_re.fullmatch(prec)
            assert m, "`{}`".format(prec)
            prec = "Equals(pc, Int({}))".format(int(m.group("src")))
            # The destination is optional in the pattern: without it the
            # label constrains every edge leaving src.
            if m.group("dst") is not None:
                prec = "And({}, Equals(x_pc, Int({})))".format(
                    prec, int(m.group("dst")))

            post = to_next_var(post)
            if post.endswith(";"):
                post = post[:-1]
            assignments = []
            for conjunct in post.split("&"):
                conjunct = conjunct.strip()
                # Drop one enclosing parenthesis on each side, if any.
                if conjunct.startswith("("):
                    conjunct = conjunct[1:]
                if conjunct.endswith(")"):
                    conjunct = conjunct[:-1]
                conjunct = conjunct.strip()
                m = post_re.fullmatch(conjunct)
                if m:
                    symb = m.group("symb")
                    snd = _operand(m.group("lhs"))
                    if m.group("op"):
                        snd = "{}({}, {})".format(
                            _get_pysmt_op(m.group("op")), snd,
                            _operand(m.group("rhs")))
                else:
                    symb, sep, snd = conjunct.partition("=")
                    assert sep, "`{}`".format(conjunct)
                    symb = symb.strip()
                    snd = snd.strip()
                    try:
                        # e.g. negative literals, which post_re rejects.
                        snd = "Int({})".format(int(snd))
                    except ValueError:
                        snd = "`{}`".format(snd)
                assignments.append("Equals({}, {})".format(symb, snd))

            post = "And({})".format(", ".join(assignments))
            if line.endswith(';'):
                line = line[:-1]
            retval.append("# {}".format(to_primed(line)))
            retval.append("Implies({}, {})".format(prec, post))
    return retval


def _parse_fairness(smv_f: str) -> list:
    """Return what follows ``FAIRNESS`` on every line starting with it.

    Lines are stripped before the check; the remainder of the line is
    returned untouched, in file order.
    """
    keyword = "FAIRNESS"
    constraints = []
    with open(smv_f, 'r') as smv:
        for text in map(str.strip, smv):
            if text.startswith(keyword):
                constraints.append(text[len(keyword):])
    return constraints


def init_cfg_labels_fair_from_smv(smv_f: str) -> tuple:
    """Run every parser of this module on the SMV file ``smv_f``.

    Returns the tuple ``(symbols, init, cfg, labels, fairness)``; the
    symbols are parsed first because the label parser needs them.
    """
    symbols = _parse_symbols(smv_f)
    return (
        symbols,
        _parse_init(smv_f),
        _parse_cfg(smv_f),
        _parse_labels(smv_f, symbols),
        _parse_fairness(smv_f),
    )


def extract_script_from_smv(f_name: str, buf) -> int:
    """Write to ``buf`` a draft pysmt test script for the SMV model ``f_name``.

    The script defines a ``test`` function that declares one pysmt symbol
    (plus its ``x_`` next-state copy) per model variable, the initial
    location, the control flow graph, the transition labels, the fairness
    condition and one ``AGAutomaton`` per symbol, then calls
    ``find_composition``.

    INIT constraints are emitted verbatim behind an ``INIT`` marker and
    FAIRNESS constraints verbatim as well, i.e. still as SMV text
    (NOTE(review): presumably meant to be translated by hand, like the
    backtick-quoted fragments produced by the parsers - confirm).

    Args:
        f_name: path of the SMV model; also quoted in the generated
            docstring.
        buf: object with a ``write(str)`` method receiving the script.

    Returns:
        0 once the script has been written.

    Raises:
        AssertionError: if ``f_name`` is not an existing file.
    """
    smv_f = os.path.abspath(f_name)
    assert os.path.isfile(smv_f), "not a file: {}".format(smv_f)
    symbols, init, cfg, labels, fairness = init_cfg_labels_fair_from_smv(smv_f)
    symbols.sort()
    # Div, BOOL and REAL are imported next to the names the original
    # template listed: the translators of this module can emit all of them
    # ("/" operator, boolean and real variables).
    buf.write("""from find_composition import find_composition
from automata_composition import AGAutomaton
from utils import symb_next

from pysmt.shortcuts import Symbol, TRUE, Int
from pysmt.shortcuts import Not, And, Or, Implies
from pysmt.shortcuts import GT, GE, Equals, LE, LT
from pysmt.shortcuts import Plus, Minus, Times, Div
from pysmt.typing import BOOL, INT, REAL


def test(nuxmv_path: str, model_file: str, trace_file: str, cmd_file: str,
         output_file: str) -> bool:
    \"\"\"Test corresponding to benchmarks/{}\"\"\"\n""".format(f_name))

    buf.write("    # symbols.\n")
    for symb, symb_t in symbols:
        buf.write('    {0} = Symbol("{0}", {1})\n'.format(symb, symb_t))
    for symb, symb_t in symbols:
        buf.write('    x_{0} = Symbol(symb_next("{0}"), {1})\n'
                  .format(symb, symb_t))
    buf.write("    symbols = [{}]\n"
              .format(", ".join([s for s, _ in symbols])))

    if init:
        buf.write("\n    # init.\n")
        for i in init:
            buf.write("    INIT {}\n".format(i))

    # Separate the `init = ...` statements from the edges by content rather
    # than by position, so that a model without init(pc) (or without any
    # edge) neither raises IndexError nor loses its first edge comment.
    init_marker = "init = "
    init_loc = [entry for entry in cfg if entry.startswith(init_marker)]
    edges = [entry for entry in cfg if not entry.startswith(init_marker)]

    buf.write("\n    # initial location.\n")
    for entry in init_loc:
        buf.write("    {}\n".format(entry))

    buf.write("\n    # control flow graph.\n")
    buf.write("    cfg = And(\n        {}\n    )\n"
              .format(",\n        ".join(edges)))

    buf.write("\n    # transition labels.\n")
    buf.write("    labels = And(\n        {}\n    )\n"
              .format(",\n        ".join(labels)))

    buf.write("\n    # transition relation.\n")
    buf.write("    trans = And(cfg, labels)\n")

    buf.write("\n    # fairness.\n")
    if fairness:
        for fair in fairness:
            buf.write("    {}\n".format(fair))
    else:
        buf.write("    fairness = Not(Equals(pc, Int(-1)))\n")

    buf.write("\n    # define automata to be composed.\n")
    for s, _ in symbols:
        buf.write("    aut_{0} = AGAutomaton(symbols, [{0}], "
                  "\"aut_{0}\", 1)\n    "
                  "aut_{0}.set_assume(0, TRUE())\n    "
                  "aut_{0}.set_invar(0, TRUE())\n    "
                  "aut_{0}.set_transitions(0, [(0, [Equals(x_{0}, {0})])])\n\n"
                  .format(s))

    buf.write("    automata = [{}]\n"
              .format(", ".join(["aut_{}".format(s) for s, _ in symbols])))

    buf.write("\n    # search composition.\n")
    buf.write("    comp, undefs = find_composition(automata, init, trans, "
              "fairness,\n                                    "
              "nuxmv_path, model_file, trace_file,\n"
              "                                    cmd_file)\n")
    # The final `else` of the generated code is reached only when both comp
    # and undefs are truthy, hence `assert comp and undefs`.
    buf.write("""    if comp is not None:
        with open(output_file, 'w') as out:
            out.write(str(comp))

    if comp and not undefs:
        res = True
    elif not comp:
        res = False
    else:
        assert comp and undefs
        res = None
    return res\n""")
    return 0


def main(argv: list):
    """Extract INIT, TRANS, FAIRNESS from smv file

    ``argv`` is ``[program, smv_file]`` or ``[program, smv_file, out_file]``;
    without ``out_file`` the script is written to standard output.  Returns
    1 after printing the usage line when the argument count is wrong,
    otherwise the value returned by ``extract_script_from_smv``.
    """
    n_args = len(argv)
    if not 2 <= n_args <= 3:
        print("Usage: {} <smv_file> [out_file=stdout]".format(argv[0]))
        return 1
    if n_args == 2:
        return extract_script_from_smv(argv[1], sys.stdout)
    with open(os.path.abspath(argv[2]), 'w') as out_buf:
        return extract_script_from_smv(argv[1], out_buf)


if __name__ == "__main__":
    # Propagate main's status (1 on a usage error, 0 on success) as the
    # process exit code instead of discarding it.
    sys.exit(main(sys.argv))
