from .. import parser
from .. import parser_tools as parset
from .. import AbstractBaseClass

from ..bridges import SamplerBridge
from ..bridges.sampling_bridges import StopSamplingException
from ..parser_tools import main_register, ArgumentException
from ..variable import Variable

import abc
import time
import sys


class InvalidMethodCallException(Exception):
    pass


class Sampler(AbstractBaseClass):
    """
    Base class for all samplers.

    Do not forget to register your sampler subclass in this package's
    'register' dictionary via 'append_register' of the main package.

    Life cycle: ``initialize`` -> any number of ``sample`` calls ->
    ``finalize``. Subclasses implement the hooks ``_initialize``,
    ``_sample`` and ``_finalize``; the public methods wrap them with state
    checks and raise ``InvalidMethodCallException`` when called out of
    order.
    """

    # Argument specification consumed by the parser. The tuples appear to be
    # (name, optional, default, type or register[, description]) -- TODO
    # confirm against parser_tools.ClassArguments. 'sample_calls' is a
    # tracked variable that ``sample`` increments on every call.
    arguments = parset.ClassArguments(
        "Sampler", None,
        ("sampler_bridge", False, None,
         main_register.get_register(SamplerBridge)),
        ('timeout', True, None, int,
         "Maximum time in seconds to run a single sample call. "
         "This is a soft timeout and the concrete samplers used "
         "have to support the timeouts."),
        ('variables', True, None, main_register.get_register(Variable)),
        ('id', True, None, str),
        variables=[('sample_calls', 0, int)],
    )

    def __init__(self, sampler_bridge, timeout=None, variables=None, id=None):
        """
        :param sampler_bridge: a SamplerBridge or a list of SamplerBridges
            that perform the actual sampling. A single bridge is wrapped
            into a one-element list.
        :param timeout: soft time limit in seconds for a single ``sample``
            call, or None for no limit. Concrete samplers have to honour it
            themselves (e.g. by checking ``timed_out``).
        :param variables: dict of Variable objects ({name=VARIABLE, ...}) or
            None for no variables.
        :param id: optional identifier of this sampler.
        :raises ArgumentException: if ``variables`` is not a dict.
        """
        variables = {} if variables is None else variables
        if not isinstance(variables, dict):
            raise ArgumentException("The provided variables have to be a map. "
                                    "Please define them as {name=VARIABLE,...}.")
        if not isinstance(sampler_bridge, list):
            sampler_bridge = [sampler_bridge]
        self.sbridges = sampler_bridge
        # One "has stopped" flag per bridge; created in initialize().
        self.stopped_bridges = None
        # Start time of the running sample() call; None while not sampling.
        self.start_sampling_time = None
        self.max_sampling_time = timeout
        self.variables = variables
        self.id = id

        # May be None (sample() checks for it) -- presumably when the
        # variable was not supplied; TODO confirm in parser_tools.
        self.var_sample_calls, = Sampler.arguments.validate_and_return_variables(
            variables)

        self.out_log = None  # Message objects for messages from this sampler
        self.in_logs = None  # Dictionary of message objects for ingoing communication
        self.initialized = False
        self.finalized = False

    def initialize(self, out_log=None, in_logs=None):
        """
        Initialize all bridges and then the sampler itself (once only).

        :param out_log: message object for messages sent by this sampler.
        :param in_logs: dictionary of message objects for ingoing
            communication.
        :raises InvalidMethodCallException: if already initialized.
        """
        if self.initialized:
            raise InvalidMethodCallException("Multiple initializations of "
                                             "sampler.")
        self.in_logs = in_logs
        self.out_log = out_log
        for sb in self.sbridges:
            sb.initialize()
        self.stopped_bridges = [False] * len(self.sbridges)
        self._initialize()
        self.initialized = True

    def timed_out(self):
        """
        Return True if the running ``sample`` call exceeded the timeout.

        Always False when no timeout was configured.

        :raises InvalidMethodCallException: if no ``sample`` call is running.
        """
        if self.start_sampling_time is None:
            raise InvalidMethodCallException("Cannot check timed out status if "
                                             "sampling is not running.")
        if self.max_sampling_time is None:
            return False
        return time.time() - self.start_sampling_time > self.max_sampling_time

    def _call_bridge_sample(self, problem, **kwargs):
        """
        Sample ``problem`` with every bridge that has not stopped yet.

        The following keyword arguments are consumed here; all remaining
        ``kwargs`` are forwarded to ``SamplerBridge.sample``:
            check_timeouts (default False): stop iterating over the bridges
                as soon as ``timed_out()`` reports True.
            do_merge (default False): pass the most recent non-None bridge
                result on as ``data_container`` of the next bridge call.
            merge_container (default None): initial ``data_container``
                used when ``do_merge`` is set.

        A bridge raising StopSamplingException is marked as stopped and is
        skipped in all later calls.

        :return: tuple (list of bridge results, final merge container)
        :raises StopSamplingException: when the last active bridge stops.
        """
        check_timeouts = kwargs.pop("check_timeouts", False)
        do_merge = kwargs.pop("do_merge", False)
        merge_container = kwargs.pop("merge_container", None)

        datas = []
        for no, sb in enumerate(self.sbridges):
            if self.stopped_bridges[no]:
                continue
            if check_timeouts and self.timed_out():
                break
            # Presumably flushed so buffered output of this process appears
            # before anything the bridge prints -- TODO confirm.
            sys.stdout.flush()
            try:
                data = sb.sample(
                    problem,
                    data_container=merge_container if do_merge else None,
                    **kwargs)
            except StopSamplingException:
                self.stopped_bridges[no] = True
                if all(self.stopped_bridges):
                    raise StopSamplingException(
                        "All SamplingBridges have stopped")
                continue
            datas.append(data)
            if do_merge and data is not None:
                merge_container = data
        # NOTE(review): once every bridge has stopped, later calls return
        # ([], merge_container) without raising again.
        return datas, merge_container

    def sample(self, **kwargs):
        """
        Run one sampling call and return its non-empty results.

        Increments the 'sample_calls' variable (if present), starts the
        timeout clock and delegates to ``_sample``.

        :param kwargs: forwarded to ``_sample``.
        :return: list of the data objects produced by ``_sample`` with None
            and empty entries removed.
        :raises InvalidMethodCallException: if not initialized.
        """
        if not self.initialized:
            raise InvalidMethodCallException("Cannot call sample without "
                                             "initializing the sampler.")
        if self.var_sample_calls is not None:
            self.var_sample_calls.value += 1

        self.start_sampling_time = time.time()
        try:
            datas = self._sample(**kwargs)
        finally:
            # Always stop the clock, also when _sample raises (e.g. a
            # StopSamplingException); otherwise timed_out() would later
            # measure against a stale start time instead of raising.
            self.start_sampling_time = None

        return [data for data in datas
                if data is not None and not data.empty()]

    def _problem_path_printer(self, problem):
        """Print ``problem``; usable as ``problem_printer`` callback."""
        # Wrap in a 1-tuple so tuple-valued problems are printed as a whole
        # instead of being spread over the format placeholders.
        print("Next problem: %s" % (problem,))

    def _sample_with_timeout_skeleton(self, problem_generator,
                                      problem_printer=None,
                                      bridge_kwargs=None,
                                      sampling_callback=None,
                                      verbose=0):
        """
        Sample the problems of ``problem_generator`` one after another.

        Stops when the generator is exhausted, the timeout is reached, or a
        StopSamplingException is raised (by the bridges or the callback).
        Must be called while ``sample`` is running, because it uses
        ``timed_out``.

        :param problem_generator: iterable of problems passed to the bridges.
        :param problem_printer: callable invoked with each problem when
            ``verbose > 0``.
        :param bridge_kwargs: extra keyword arguments for
            ``_call_bridge_sample`` (e.g. do_merge, merge_container).
        :param sampling_callback: called with the (datas, merge_container)
            tuple returned by ``_call_bridge_sample`` for every problem.
        :param verbose: print progress information if > 0.
        """
        bridge_kwargs = {} if bridge_kwargs is None else bridge_kwargs
        for problem in problem_generator:
            if verbose > 0 and problem_printer is not None:
                problem_printer(problem)

            if self.timed_out():
                if verbose > 0:
                    print("SamplingTimeout")
                break
            try:
                sample_output = self._call_bridge_sample(
                    problem,
                    check_timeouts=self.max_sampling_time is not None,
                    **bridge_kwargs)
                # Kept inside the try: a callback may raise
                # StopSamplingException to end the loop.
                if sampling_callback is not None:
                    sampling_callback(sample_output)
            except StopSamplingException:
                break

    def finalize(self):
        """
        Finalize all bridges and then the sampler itself (once only).

        :raises InvalidMethodCallException: if the sampler was not
            initialized or was already finalized.
        """
        if not self.initialized:
            raise InvalidMethodCallException("Cannot finalize the sampler "
                                             "without initializing first.")
        if self.finalized:
            raise InvalidMethodCallException("Multiple finalization calls of "
                                             "sampler.")
        for sb in self.sbridges:
            sb.finalize()
        self._finalize()
        self.finalized = True

    @abc.abstractmethod
    def _initialize(self):
        """Subclass hook, called once by ``initialize`` after the bridges."""
        pass

    @abc.abstractmethod
    def _sample(self, **kwargs):
        """
        Subclass hook, called by ``sample`` with its keyword arguments.

        Has to return an iterable of data objects providing ``empty()``;
        None and empty entries are dropped by ``sample``.
        """
        pass

    @abc.abstractmethod
    def _finalize(self):
        """Subclass hook, called once by ``finalize`` after the bridges."""
        pass

    @staticmethod
    def parse(tree, item_cache):
        """
        Resolve a ``Sampler(id=ID)`` reference to a previously defined object.

        The base sampler itself cannot be built from a definition; it only
        supports lookups via ``parser.try_lookup_obj``.

        :raises ArgumentException: if the lookup finds nothing.
        """
        obj = parser.try_lookup_obj(tree, item_cache, Sampler, None)
        if obj is None:
            raise ArgumentException("The definition of the base sampler can "
                                    "only be used for look up of any previously"
                                    " defined schema via 'Sampler(id=ID)'")
        return obj


# Register the Sampler base class with the main register under the name
# "sampler" (presumably the name used to refer to samplers in definitions --
# TODO confirm in parser_tools).
main_register.append_register(Sampler, "sampler")
# Module-level handle to the register associated with Sampler in the main
# register; must come after the append_register call above.
saregister = main_register.get_register(Sampler)

