#!/usr/bin/env python3

from timeout_decorator import timeout, TimeoutError

from smv_printer import to_smv

from smv_prefixes import NEXT_MONITOR_PREFIX
from pysmt.environment import get_env
from pysmt.fnode import FNode
from pysmt.shortcuts import substitute, And
from pysmt.smtlib.script import smtlibscript_from_formula
import time
import io

# Module-wide verbosity flag, toggled via set_verbosity()/get_verbosity().
VERBOSE = False


def set_verbosity(b: bool):
    """Set the module-wide verbosity flag to ``b``."""
    global VERBOSE
    VERBOSE = b


def get_verbosity() -> bool:
    """Return the current module-wide verbosity flag."""
    # Reading a module global needs no ``global`` declaration.
    return VERBOSE


# List of solver names (defaults to MathSAT and Z3); see set_solvers()/get_solvers().
SOLVERS_LIST = ["msat", "z3"]


def set_solvers(names: list):
    """Replace the module-wide solver-name list with ``names``.

    The list object itself is stored, not a copy, so later mutations of
    ``names`` by the caller are visible through get_solvers().
    """
    global SOLVERS_LIST
    SOLVERS_LIST = names


def get_solvers() -> list:
    """Return the module-wide solver-name list itself (not a copy)."""
    return SOLVERS_LIST


# Name of the backend Solver() uses when called without an explicit name.
# NOTE: pySMT's MathSAT backend does not support NIRA (nonlinear
# integer-real arithmetic).
SOLVER_NAME = "msat"


def set_solver(name: str):
    """Make ``name`` the default backend for Solver()."""
    global SOLVER_NAME
    SOLVER_NAME = name


def get_solver() -> str:
    """Return the name of the current default backend."""
    return SOLVER_NAME


try:
    import mathsat
    import pysmt.shortcuts

    from pysmt.solvers.msat import MathSAT5Solver, MSatConverter
    from pysmt.logics import PYSMT_QF_LOGICS

    class MSatConverterExt(MSatConverter):
        """MSatConverter with n-ary multiplication and binary division.

        Installed as the converter of every MathSAT solver built by the
        ``Solver`` factory below.
        """

        def walk_times(self, formula, args, **kwargs):
            # Left-fold the (possibly more than two) operands into nested
            # two-argument msat_make_times terms.
            res = args[0]
            for x in args[1:]:
                res = mathsat.msat_make_times(self.msat_env(), res, x)
            return res

        def walk_div(self, formula, args, **kwargs):
            return mathsat.msat_make_divide(self.msat_env(), args[0], args[1])

    # Global monkey-patch: advertise every quantifier-free pySMT logic as
    # supported by MathSAT (NOTE(review): presumably so pySMT will pick
    # MathSAT for nonlinear QF logics too -- confirm).
    MathSAT5Solver.LOGICS = PYSMT_QF_LOGICS

    # Counter giving each MathSAT debug API trace its own file name.
    _t = [0]
    # Set to True to dump MathSAT API call traces to /tmp/trace-<n>.smt2.
    _DEBUG_MSAT = False

    def Solver(name=None, logic=None, **kwargs):
        """Create a pySMT solver via ``pysmt.shortcuts.Solver``.

        :param name: backend name; defaults to the module-wide SOLVER_NAME.
        :param logic: forwarded to pySMT unchanged.
        :param kwargs: forwarded to pySMT. For 'msat', any caller-supplied
            ``solver_options`` are merged over this module's MathSAT
            defaults (caller values win), and the returned solver's
            converter is replaced by MSatConverterExt.
        :returns: the solver object created by pySMT.
        """
        if name is None:
            name = SOLVER_NAME
        msat_opts = {'preprocessor.simplification': '24',
                     'theory.na.div_by_zero_mode': '0'}
        if _DEBUG_MSAT:
            msat_opts = dict(msat_opts, **{
                'debug.api_call_trace': '1',
                'debug.api_call_trace_filename': '/tmp/trace-%d.smt2' % _t[0],
                'debug.api_call_trace_dump_config': 'false'
                })
            _t[0] += 1
        if name == 'msat':
            # Merge rather than overwrite: previously any solver_options
            # passed by the caller were silently discarded.
            merged = dict(msat_opts)
            merged.update(kwargs.get('solver_options') or {})
            kwargs['solver_options'] = merged
        ret = pysmt.shortcuts.Solver(name, logic, **kwargs)
        if ret and name == 'msat':
            ret.converter = MSatConverterExt(get_env(), ret.msat_env)
        return ret

except ImportError:
    import pysmt.shortcuts

    class MathSAT5Solver:
        """Placeholder used when the MathSAT bindings are unavailable.

        Nothing instantiates it, so ``isinstance(solver, MathSAT5Solver)``
        checks fail for every real solver.
        """

    def Solver(name=None, logic=None, **kwargs):
        """Create a pySMT solver, defaulting ``name`` to SOLVER_NAME.

        All arguments are forwarded to ``pysmt.shortcuts.Solver`` unchanged.
        """
        if name is None:
            name = SOLVER_NAME
        return pysmt.shortcuts.Solver(name, logic, **kwargs)

try:
    # Probe the z3 bindings first, the same way MathSAT is probed with
    # `import mathsat`. pySMT's own z3 module signals a missing z3 with a
    # pySMT exception (NOTE(review): SolverAPINotFound, which is not an
    # ImportError subclass -- confirm for the pinned pySMT version). That
    # exception would escape this handler and make importing this module fail.
    import z3  # noqa: F401
    from pysmt.solvers.z3 import Z3Solver
except ImportError:
    class Z3Solver:
        """Placeholder used when the z3 backend is unavailable.

        Nothing instantiates it, so ``isinstance(solver, Z3Solver)`` checks
        fail for every real solver.
        """


def pysmt_dump_whole_expr() -> None:
    """Make ``str()`` of a pySMT formula print the complete expression.

    This is a workaround that rebinds ``FNode.__str__`` to ``FNode.serialize``
    on the class, so it affects every FNode in the process. str() then
    calls serialize() with its default arguments. NOTE(review): the stock
    __str__ presumably abbreviates deep formulas -- confirm.
    """
    setattr(FNode, "__str__", FNode.serialize)


def to_smt2(*formulas):
    """Return SMT-LIB text for the conjunction of ``formulas``.

    The formulas are combined with ``And``. The script is built by pySMT's
    smtlibscript_from_formula with logic QF_NRA and serialized to a string.
    """
    from pysmt.logics import QF_NRA
    conjunction = And(*formulas)
    script = smtlibscript_from_formula(conjunction, logic=QF_NRA)
    with io.StringIO() as out:
        script.serialize(out)
        return out.getvalue()


def solve_with_timeout(timeout_sec, solver, assumptions=None):
    """Run ``solver.solve(assumptions)`` under a ``timeout_sec``-second budget.

    The mechanism depends on the backend:

    * MathSAT: a termination-test callback is registered on the solver's
      MathSAT environment. It requests termination once more than
      ``timeout_sec`` seconds have elapsed since this call started. The clock
      is read only on every 100th invocation, to keep the callback cheap.
      The callback is disarmed when this call returns, so a stale deadline
      cannot abort later ``solve`` calls on the same solver.
    * Z3: Z3's ``timeout`` parameter (in milliseconds) is set on the
      underlying z3 solver. It stays set for subsequent solve calls.
    * Anything else: the call is wrapped with ``timeout_decorator.timeout``
      (``use_signals=False``), which raises ``TimeoutError`` on expiry.
      NOTE(review): in that mode timeout_decorator runs the call in a
      separate process, so solver state changes (e.g. models) may not be
      visible afterwards -- confirm.

    :returns: whatever ``solver.solve`` returns. How a MathSAT/Z3 timeout is
        reported is up to pySMT (NOTE(review): presumably as an "unknown"
        result error -- confirm).
    """
    if isinstance(solver, MathSAT5Solver):
        start = time.time()
        polls = 0
        armed = True

        def ttest():
            # Returning 1 asks MathSAT to stop; 0 lets it continue.
            nonlocal polls
            if not armed:
                return 0
            polls += 1
            if polls < 100:
                return 0
            polls = 0
            return int(time.time() - start > timeout_sec)

        mathsat.msat_set_termination_test(solver.msat_env(), ttest)
        try:
            return solver.solve(assumptions)
        finally:
            # The callback remains registered on the environment after we
            # return. Disarm it so that solve() calls made later without
            # a timeout are not terminated by this call's expired deadline.
            armed = False

    if isinstance(solver, Z3Solver):
        # Z3's ``timeout`` is an unsigned-integer parameter, so pass whole
        # milliseconds even when timeout_sec is fractional.
        solver.z3.set(timeout=int(timeout_sec * 1000))
        return solver.solve(assumptions)

    @timeout(timeout_sec, timeout_exception=TimeoutError, use_signals=False)
    def call():
        return solver.solve(assumptions)
    return call()


def reset_after_timeout(solver, *assertions):
    """Clear all of ``solver``'s assertions, then re-add ``assertions`` in order.

    A ``push()`` follows every re-added assertion, so the solver ends with
    one extra backtracking level per assertion. NOTE(review): confirm
    whether a single push after the loop was intended.
    """
    solver.reset_assertions()
    for formula in assertions:
        solver.add_assertion(formula)
        solver.push()


def symb_next(symb: str) -> str:
    """Return the smv monitor name for the next assignment of ``symb``.

    The name is ``symb`` with NEXT_MONITOR_PREFIX prepended.
    """
    return "%s%s" % (NEXT_MONITOR_PREFIX, symb)


def symb_is_next(symb: str) -> bool:
    """Return True iff ``symb`` names a next-assignment monitor.

    A next-assignment monitor name is one that starts with
    NEXT_MONITOR_PREFIX.
    """
    prefix = NEXT_MONITOR_PREFIX
    return symb.startswith(prefix)


def symb_next_to_curr(symb: str) -> str:
    """Inverse of symb_next: strip NEXT_MONITOR_PREFIX from ``symb``.

    ``symb`` must already be a next-assignment monitor name; this is
    checked with an assert.
    """
    assert symb_is_next(symb)
    prefix_len = len(NEXT_MONITOR_PREFIX)
    return symb[prefix_len:]


def var_to_next(s: FNode) -> FNode:
    """Return the monitor symbol standing for next(``s``).

    ``s`` is rendered with to_smv, and the rendered name must not already
    carry the next-monitor prefix (checked with an assert). The monitor is
    fetched with get_symbol from the global environment's formula manager.
    """
    curr_name = to_smv(s)
    assert not symb_is_next(curr_name)
    return get_env().formula_manager.get_symbol(symb_next(curr_name))


def to_next(expr: FNode, symbols: list) -> FNode:
    """Substitute each of ``symbols`` in ``expr`` with its next() monitor.

    For each ``s``, the monitor is the formula-manager symbol named
    symb_next(to_smv(s)). Unlike var_to_next, no check is made that ``s``
    is not already a next-assignment symbol.
    """
    manager = get_env().formula_manager
    mapping = {s: manager.get_symbol(symb_next(to_smv(s))) for s in symbols}
    return substitute(expr, mapping)


def var_to_curr(n_s: FNode) -> FNode:
    """Map a next-assignment monitor symbol back to its current symbol.

    ``n_s`` is rendered with to_smv, and the rendered name must carry the
    next-monitor prefix (checked with an assert). The current symbol is
    fetched with get_symbol from the global environment's formula manager.
    """
    next_name = to_smv(n_s)
    assert symb_is_next(next_name)
    return get_env().formula_manager.get_symbol(symb_next_to_curr(next_name))


def to_curr(expr: FNode, symbols: list) -> FNode:
    """Substitute next() monitors in ``expr`` with their current symbols.

    ``symbols`` lists the current symbols. For each ``s``, the monitor named
    symb_next(to_smv(s)) is replaced by ``s`` itself.
    """
    manager = get_env().formula_manager
    mapping = {manager.get_symbol(symb_next(to_smv(s))): s for s in symbols}
    return substitute(expr, mapping)
