#! /usr/bin/env python3

# EF-SMT solver implementation

from utils import Solver, solve_with_timeout, to_smt2, get_verbosity
from pysmt.shortcuts import Not
from pysmt.logics import AUTO
from pysmt.fnode import FNode
from pysmt.exceptions import SolverReturnedUnknownResultError
# from pysmt.walkers import DagWalker, IdentityDagWalker
# import pysmt.solvers.solver as PySolver
# from pysmt.typing import BOOL
# from pysmt.shortcuts import TRUE, And

# Per-query limit handed to solve_with_timeout for every E- and F-solver call
# made by the solvers below.
TIMEOUT = 5


def set_timeout(val: int):
    """Replace the module-wide TIMEOUT used for each individual solver query."""
    global TIMEOUT
    TIMEOUT = val


# Iteration budget for the refinement loops of _efsmt/_efesmt; a value of
# None removes the bound entirely.
MAX_LOOPS = 50


def set_maxloops(val: int):
    """Replace the module-wide MAX_LOOPS budget (pass None for no limit)."""
    global MAX_LOOPS
    MAX_LOOPS = val


def efsmt(x2: list, phi: FNode, logic=AUTO,
          esolver_name=None, fsolver_name=None):
    """Solve ``exists x1. forall x2. phi(x1, x2)``.

    The existential block ``x1`` is not passed in: it is every free variable
    of ``phi`` that does not appear in ``x2``.  The actual search is
    delegated to ``_efsmt``, whose result is returned unchanged.
    """
    universal = set(x2)
    existential = phi.get_free_variables() - universal
    return _efsmt(existential, universal, phi,
                  logic=logic,
                  esolver_name=esolver_name,
                  fsolver_name=fsolver_name)


def _efsmt(x1: set, x2: set, phi: FNode, logic=AUTO,
           esolver_name=None, fsolver_name=None):
    """Solves exists x1. forall x2. phi(x1, x2)

    Counterexample-guided refinement: a persistent E-solver proposes a
    candidate assignment to ``x1``; a fresh F-solver then searches for an
    assignment to ``x2`` that falsifies ``phi`` under that candidate.  Each
    such counterexample ``sigma`` is learnt on the E-solver as the
    constraint ``phi(x1, sigma)`` and the loop repeats.

    Returns:
        - a dict mapping every variable of ``x1`` to its value, when a
          candidate survives the F-check;
        - the E-solver's non-True answer (presumably False, i.e. no candidate
          is left) when the E-query does not succeed;
        - None when either query answers unknown/times out, or when the
          iteration budget is exhausted.  Note the loop body runs up to
          ``MAX_LOOPS + 1`` times (``MAX_LOOPS = None`` means unbounded).
    """

    def _check(solver):
        # A backend "unknown" answer is folded into None, same as a timeout.
        try:
            return solve_with_timeout(TIMEOUT, solver)
        except SolverReturnedUnknownResultError:
            return None

    with Solver(logic=logic, name=esolver_name) as esolver:
        # Seed the guesser with phi itself: a witness for x1 must at least
        # admit some x2 satisfying phi.
        esolver.add_assertion(phi)
        iteration = 0
        while MAX_LOOPS is None or iteration <= MAX_LOOPS:
            iteration += 1

            guess_res = _check(esolver)
            if guess_res is not True:
                if guess_res is None and get_verbosity():
                    print("\t\tE(F)-SMT timeout:\n{}"
                          .format(to_smt2(esolver.assertions)))
                return guess_res

            candidate = {var: esolver.get_value(var) for var in x1}
            instantiated = phi.substitute(candidate).simplify()

            with Solver(logic=logic, name=fsolver_name) as fsolver:
                # Look for an x2 that breaks phi under the current candidate.
                fsolver.add_assertion(Not(instantiated))
                refute_res = _check(fsolver)

                if refute_res is None:
                    if get_verbosity():
                        print("\t\t(E)F-SMT timeout:\n{}"
                              .format(to_smt2(fsolver.assertions)))
                    return None

                if refute_res is False:
                    # No falsifying x2 exists: the candidate is a witness.
                    return candidate

                counterexample = {var: fsolver.get_value(var) for var in x2}
                lemma = phi.substitute(counterexample).simplify()
                esolver.add_assertion(lemma)

        if get_verbosity():
            print("\tEF-solver reached max number of iterations on: {}"
                  .format(to_smt2(esolver.assertions)))
        return None


def efesmt(x1: list, x2: list, phi: FNode, logic=AUTO,
           esolver_name=None, fsolver_name=None):
    """Solve ``exists x0. forall x1. exists x2. phi(x0, x1, x2)``.

    The outermost existential block ``x0`` is derived rather than passed: it
    is every free variable of ``phi`` listed in neither ``x1`` nor ``x2``.
    The result of ``_efesmt`` on the derived blocks is returned unchanged.
    """
    universal = set(x1)
    inner = set(x2)
    outer = phi.get_free_variables() - universal - inner
    return _efesmt(outer, universal, inner, phi,
                   logic=logic,
                   esolver_name=esolver_name,
                   fsolver_name=fsolver_name)


def _efesmt(x0: set, x1: set, x2: set, phi: FNode, logic=AUTO,
            esolver_name=None, fsolver_name=None):
    """Solves exists x0. forall x1. exists x2. phi(x0, x1, x2)

    Counterexample-guided refinement: the E-solver proposes a candidate
    ``tau`` for ``x0``.  The nested query
    ``exists x1. forall x2. not phi(tau, x1, x2)`` (solved by ``_efsmt``)
    then searches for an ``x1`` that no ``x2`` can satisfy.  If there is
    none, ``tau`` is a witness.  Otherwise the counterexample ``sigma`` is
    learnt on the E-solver as ``phi(x0, sigma, x2')``, where ``x2'`` are
    FRESH copies of the ``x2`` symbols.

    Renaming ``x2`` apart per lemma is required for soundness: every
    counterexample may need its own ``x2`` witness.  Reusing the same ``x2``
    symbols across lemmas (or with the seed assertion) would force one
    common witness.  That can make the E-solver report unsat, and this
    function return False, for a formula that actually holds.

    Returns:
        - a dict mapping every variable of ``x0`` to its value when a
          witness is found;
        - the E-solver's non-True answer (presumably False, i.e. no
          candidate is left) when the E-query does not succeed;
        - None on an unknown/timeout answer (from the E-solver or the
          nested ``_efsmt``), or when an iteration budget is exhausted.
          The loop body runs up to ``MAX_LOOPS + 1`` times
          (``MAX_LOOPS = None`` means unbounded).
    """
    # Local import keeps this fix self-contained; FreshSymbol creates the
    # renamed copies of x2 used in each learnt lemma.
    from pysmt.shortcuts import FreshSymbol

    with Solver(logic=logic, name=esolver_name) as esolver:
        # Seed: a witness for x0 must at least admit some x1, x2 with phi.
        esolver.add_assertion(phi)
        loops = 0
        while MAX_LOOPS is None or loops <= MAX_LOOPS:
            loops += 1

            try:
                eres = solve_with_timeout(TIMEOUT, esolver)
            except SolverReturnedUnknownResultError:
                eres = None

            if eres is not True:
                if eres is None and get_verbosity():
                    print("\t\tE(FE)-SMT timeout:\n{}"
                          .format(to_smt2(esolver.assertions)))
                return eres

            # Candidate assignment for the outermost existential block.
            tau = {v: esolver.get_value(v) for v in x0}
            sub_phi = phi.substitute(tau).simplify()

            # tau fails iff exists x1. forall x2. not phi(tau, x1, x2).
            fmodel = _efsmt(x1, x2, Not(sub_phi), logic=logic,
                            esolver_name=esolver_name,
                            fsolver_name=fsolver_name)
            if fmodel is False:
                return tau
            if fmodel is None:
                return None

            # fmodel is the x1 counterexample; learn phi(x0, sigma, x2')
            # with x2 renamed apart so each lemma gets its own x2 witness.
            sigma = {v: fmodel[v] for v in x1}
            fresh_x2 = {v: FreshSymbol(v.symbol_type()) for v in x2}
            lemma = phi.substitute({**sigma, **fresh_x2}).simplify()
            esolver.add_assertion(lemma)

        if get_verbosity():
            print("\tEFE-solver reached max number of iterations on: {}"
                  .format(to_smt2(esolver.assertions)))

        return None


# def generalise(phi: FNode, model: PySolver) -> FNode:
#     generaliser = Generaliser(phi, model)
#     return generaliser.generalise()


# class Generaliser(DagWalker):
#     def __init__(self, phi: FNode, model: PySolver,
#                  env=None, invalidate_memoization=False):
#         DagWalker.__init__(self, env=env,
#                            invalidate_memoization=invalidate_memoization)
#         self._model = model
#         self._phi = phi
#         self._res = TRUE()
#         return

#     def _get_key(self, formula, **kwargs):
#         return formula

#     def _push_with_children_to_stack(self, pair, **kwargs):
#         phi, polarity = pair
#         model = self._model

#         self.stack.append((True, phi))
#         to_push = []

#         if phi.get_type() == BOOL:
#             assert not phi.is_quantifier(), "quantifiers are not supported"
#             assert model.get_py_value(phi) == polarity, \
#                 "Expected: {}, got: {}".format(polarity, model.get_value(phi))
#             if (phi.is_and() and polarity) or \
#                (phi.is_or() and not polarity):
#                 to_push.extend([(s, polarity) for s in phi.args()])

#             elif (phi.is_and() and not polarity) or \
#                  (phi.is_or() and polarity):
#                 to_push.append((next(s for s in phi.args()
#                                      if model.get_py_value(s) is False),
#                                 polarity))
#             elif phi.is_not():
#                 to_push.append((phi.args()[0], not polarity))

#             elif phi.is_implies() and polarity:
#                 assert len(phi.args()) == 2
#                 lhs = phi.args()[0]
#                 rhs = phi.args()[1]
#                 if model.get_py_value(lhs):
#                     to_push.extend([(lhs, True), (rhs, True)])
#                 else:
#                     to_push.append((lhs, False))

#             elif phi.is_implies() and not polarity:
#                 assert len(phi.args()) == 2
#                 lhs = phi.args()[0]
#                 rhs = phi.args()[1]
#                 to_push.extend([(lhs, True), (rhs, False)])

#             elif phi.is_iff() and polarity:
#                 assert len(phi.args()) == 2
#                 lhs = phi.args()[0]
#                 rhs = phi.args()[1]
#                 if model.get_py_value(lhs):
#                     to_push.extend([(lhs, True), (rhs, True)])
#                 else:
#                     to_push.extend([(lhs, False), (rhs, False)])

#             elif phi.is_iff() and not polarity:
#                 assert len(phi.args()) == 2
#                 lhs = phi.args()[0]
#                 rhs = phi.args()[1]
#                 if model.get_py_value(lhs):
#                     to_push.extend([(lhs, True), (rhs, False)])
#                 else:
#                     to_push.extend([(lhs, False), (rhs, True)])

#             for s in to_push:
#                 # Add only if not memoized already
#                 key = self._get_key(s, **kwargs)
#                 if key not in self.memoization:
#                     self.stack.append((False, s))

#         else:  # skip non-boolean operators
#             key = self._get_key(pair)
#             leaf = phi if polarity else Not(phi)
#             self.memoization[key] = leaf
#             self._res = And(self._res, leaf)

#     def walk(self, phi, **kwargs):
#         pass

#     def generalise(self) -> FNode:
#         self._res = TRUE()
#         self.walk((self._phi, True))
#         return self._res
