import keras
from keras.callbacks import Callback, ModelCheckpoint, EarlyStopping

from . import keras_layers

from ... import misc
from ... import parser
from ... import parser_tools as parset
from ... import AbstractBaseClass


from ...parser_tools import main_register, ArgumentException

import abc
import math
import os
import shutil
import time


class BaseKerasCallback(Callback):
    """
    Base class of the Keras callbacks defined in this module.

    Besides the Keras ``Callback`` interface it provides hooks for the code
    driving a training: ``setup``/``finalize`` around a training and the
    queries ``training_failed``, ``shall_reinitialize`` and
    ``supports_reinforcement_learning``. The defaults describe a callback
    that never fails a training, has no opinion on restarts and does not
    support reinforcement learning.
    """
    arguments = parset.ClassArguments("BaseKerasCallback", None,
                                      ('id', True, None, str))

    # Names of the flags this callback currently sets (none by default).
    flags = property(lambda self: set())

    def __init__(self, id=None):
        Callback.__init__(self)
        self._id = id

    def setup(self, network, *args, **kwargs):
        """
        Prepares the callback before it is used for training. If this
        callback object is used for multiple trainings, setup is called before
        each of them. It has the same purpose as on_train_begin of keras, but
        allows us to feed arguments. The default implementation does nothing.
        :param network: KerasNetwork object for which this callback shall be
                        used
        :return: None
        """
        pass

    def finalize(self, network, *args, **kwargs):
        """
        Counterpart of setup, invoked with the network that used this callback
        during training. The default implementation does nothing.
        :param network: instance of the KerasNetwork which used this callback
                        during training
        :param args: additional positional arguments (unused here)
        :param kwargs: additional keyword arguments (unused here)
        :return: None
        """
        pass

    # Defaults for the special methods the different callbacks need.

    def training_failed(self):
        """
        :return: True if this callback considers the last training failed.
                 The default never does.
        """
        return False

    def shall_reinitialize(self, training_failed=None):
        """
        :param training_failed: boolean informing if the training failed or None
                                if unknown. Convention: Interpret unknown as
                                not failed (= successful) unless there is a
                                good reason for your case.
        :return: True = yes; None = no decision; False = no. Priority order:
                 None < True < False
        """
        return None

    def supports_reinforcement_learning(self):
        """
        :return: whether this callback may be used for reinforcement learning.
                 The default is False.
        """
        return False

    @staticmethod
    def parse(tree, item_cache):
        """
        Resolve a reference to a previously defined callback. The base class
        itself cannot be constructed from a definition.
        :raises ArgumentException: if the lookup yields no object.
        """
        obj = parser.try_lookup_obj(tree, item_cache, BaseKerasCallback, None)
        if obj is None:
            raise ArgumentException(
                "The definition of the base keras callback can only be used "
                "for look up of any previously defined callback via "
                "'keras_callback(id=ID)'")
        return obj


main_register.append_register(BaseKerasCallback, "keras_callback")


class KerasCallbackWrapper(BaseKerasCallback, AbstractBaseClass):
    """
    Exposes a Keras callback object (``_internal_callback``) through the
    BaseKerasCallback interface by forwarding the Keras hooks to it.

    Subclasses create the wrapped callback in ``_setup``, which ``setup``
    invokes before each training. Until then ``_internal_callback`` is None.
    """
    arguments = parset.ClassArguments(
        'KerasCallbackWrapper', BaseKerasCallback.arguments)

    def __init__(self, id=None):
        # Has to exist before BaseKerasCallback.__init__ runs: keras' Callback
        # constructor may assign ``model``/``validation_data``, which are
        # routed through the properties below.
        self._internal_callback = None
        BaseKerasCallback.__init__(self, id=id)

    def setup(self, network, *args, **kwargs):
        """Create the wrapped callback; ``network`` is not passed on."""
        self._setup(*args, **kwargs)

    @abc.abstractmethod
    def _setup(self, *args, **kwargs):
        """Create ``self._internal_callback`` for the upcoming training."""
        pass

    # WRAPPING: attribute access and Keras hooks are forwarded to the
    # internal callback.

    def _wrap_var_get_model(self):
        # Mirror the setter: without a wrapped callback there is no model yet,
        # so report None instead of failing on ``None.model``.
        if self._internal_callback is None:
            return None
        return self._internal_callback.model

    def _wrap_var_set_model(self, value):
        # Assignments made before _setup (possibly by keras' constructor) are
        # dropped.
        if self._internal_callback is not None:
            self._internal_callback.model = value
    model = property(_wrap_var_get_model, _wrap_var_set_model)

    def _wrap_var_get_val_data(self):
        # Same reasoning as for ``model``.
        if self._internal_callback is None:
            return None
        return self._internal_callback.validation_data

    def _wrap_var_set_val_data(self, value):
        if self._internal_callback is not None:
            self._internal_callback.validation_data = value
    validation_data = property(_wrap_var_get_val_data, _wrap_var_set_val_data)

    def set_params(self, params):
        self._internal_callback.set_params(params)

    def set_model(self, model):
        self._internal_callback.set_model(model)

    def on_epoch_begin(self, epoch, logs=None):
        self._internal_callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch, logs=None):
        self._internal_callback.on_epoch_end(epoch, logs)

    def on_batch_begin(self, batch, logs=None):
        self._internal_callback.on_batch_begin(batch, logs)

    def on_batch_end(self, batch, logs=None):
        self._internal_callback.on_batch_end(batch, logs)

    def on_train_begin(self, logs=None):
        self._internal_callback.on_train_begin(logs)

    def on_train_end(self, logs=None):
        self._internal_callback.on_train_end(logs)

    @staticmethod
    def parse(tree, item_cache):
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasCallbackWrapper)


main_register.append_register(KerasCallbackWrapper, "keras_callback_wrapper")


class KerasRestarter(BaseKerasCallback):
    """
    Pleads for re-initializing (restarting) the training a limited or
    unlimited number of times.

    Every training started with this callback counts as one round. While
    fewer than ``restarts`` restarts have happened (or ``restarts`` is -1),
    ``shall_reinitialize`` answers True, otherwise it abstains (None). With
    ``stop_successful`` it also abstains after a successful training.
    """
    arguments = parset.ClassArguments(
        'KerasRestarter', BaseKerasCallback.arguments,
        ("restarts", True, -1, int,
         "How often the training may be restarted according to this callback ("
         "after it run out of iterations it does neither plead for restarts nor"
         " for stopping). "
         "Provide a positive integer or 0 to indicate the number of restarts "
         "or -1 for infinite many restarts (other callbacks who prohibit "
         "restarts have priority)"),
        ("stop_successful", True, False, parser.convert_bool,
         "If the training was successful, it does not plead for a retraining ("
         "neither does it forbid retraining)."),
        order=["restarts", "stop_successful", "id"])

    def __init__(self, restarts=-1, stop_successful=False, id=None):
        BaseKerasCallback.__init__(self, id=id)
        self._restarts = restarts
        self._stop_successful = stop_successful
        # Index of the current training; -1 before the first one started,
        # hence it equals the number of restarts performed so far.
        self._current_round = -1

    def shall_reinitialize(self, training_failed=None):
        # Unknown outcome (None) is treated as success by convention.
        succeeded = training_failed is None or training_failed is False
        if self._stop_successful and succeeded:
            return None

        unlimited = self._restarts == -1
        if unlimited or self._current_round < self._restarts:
            return True
        return None

    def on_train_begin(self, logs=None):
        self._current_round = self._current_round + 1

    @staticmethod
    def parse(tree, item_cache):
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasRestarter)


main_register.append_register(KerasRestarter, "keras_restart", "krestart")


class KerasProgressChecking(BaseKerasCallback):
    """
    Fails and stops a training that does not make enough progress on a
    monitored metric.

    The check runs once the epoch index reaches ``epochs - 1`` (or at the end
    of the training, if that comes first). The score is the latest value of
    ``monitor`` or, with ``ratio``, initial value / latest value. The training
    passes if the score is strictly below (``minimize``) resp. strictly above
    (maximize) ``threshold``; otherwise ``model.stop_training`` is set and
    ``training_failed`` reports True. A training without any measurement
    counts as failed.
    """
    arguments = parset.ClassArguments(
        'KerasProgressChecking', BaseKerasCallback.arguments,
        ("monitor", False, None, str, "Metric of the network to monitor"),
        ("epochs", False, None, int, "After how many epochs the check is done"),
        ("threshold", False, None, float, "Threshold value to satisfy"),
        ("minimize", True, True, parser.convert_bool,
         "True = minimize metric else maximize"),
        ("ratio", True, False, parser.convert_bool,
         "True = if initial metric/final metric is below(minimize) resp. "
         "above(maximize)"),
        order=["monitor", "epochs", "threshold", "minimize", "ratio", "id"])

    def __init__(self, monitor, epochs, threshold, minimize=True,
                 ratio=False, id=None):
        BaseKerasCallback.__init__(self, id=id)
        self._monitor = monitor
        self._epochs = epochs
        self._threshold = threshold
        self._minimize = minimize
        self._ratio = ratio
        # Set to False once the check was done for the current training.
        self._active = True
        self._failed = False

        # First and most recent value of the monitored metric in the current
        # training.
        self._initial_performance = None
        self._last_performance = None

    def _score(self):
        """Return the value that is compared against the threshold."""
        if not self._ratio:
            return self._last_performance
        if self._last_performance == 0:
            # initial/0 is undefined. Instead of crashing the training, treat
            # a non-zero initial value as an infinite ratio (with its sign)
            # and 0/0 as "no change".
            if self._initial_performance == 0:
                return 1.0
            return math.copysign(float("inf"), self._initial_performance)
        return self._initial_performance / self._last_performance

    def do_progress_check(self):
        """
        Compare the measured progress against the threshold. On failure, stop
        the training and mark it as failed.
        """
        assert self._initial_performance is not None
        assert self._last_performance is not None

        # NOTE(review): with ``ratio`` and ``minimize``, a decreasing metric
        # yields a growing score, which counts towards failure. Confirm this
        # direction is intended.
        score = self._score()

        if ((self._minimize and score >= self._threshold) or
                (not self._minimize and score <= self._threshold)):
            self.model.stop_training = True
            self._failed = True
        else:
            self._failed = False

    def training_failed(self):
        return self._failed

    def on_train_begin(self, logs=None):
        # The callback may be reused for several trainings (e.g. restarts).
        # Measurements of a previous training must not leak into this one.
        self._active = True
        self._failed = False
        self._initial_performance = None
        self._last_performance = None

    def on_train_end(self, logs=None):
        if not self._active:
            return

        # Without a measurement, the network was never trained.
        if self._last_performance is None:
            self._failed = True
        else:
            self.do_progress_check()

    def on_epoch_end(self, epoch, logs=None):
        if not self._active:
            return

        value = (logs or {}).get(self._monitor)
        if value is None:
            raise RuntimeError("ProgressCheckingCallback requires %s"
                               " available!" % self._monitor)
        if self._initial_performance is None:
            self._initial_performance = value
        self._last_performance = value

        if epoch >= self._epochs - 1:
            self.do_progress_check()
            self._active = False

    @staticmethod
    def parse(tree, item_cache):
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasProgressChecking)


main_register.append_register(KerasProgressChecking, "keras_progress_checking")


class KerasStopTimer(BaseKerasCallback):
    """
    Stops the training at the end of the first epoch after which the timer
    exceeds ``max_time`` seconds.

    The timer starts at ``start_time`` seconds when a training begins. This
    happens on every training if ``per_training`` is set, otherwise only on
    the first one. The timer can be paused and resumed; paused time is not
    counted. A timeout can optionally forbid reinitialization
    (``prevent_reinit``) or count as a failed training
    (``timeout_as_failure``).
    """
    arguments = parset.ClassArguments(
        'KerasStopTimer', BaseKerasCallback.arguments,
        ("max_time", False, None, int,
         "Maximum time in seconds after which no new epoch shall be started."),
        ("start_time", True, 0, int, "Initial time in seconds of this timer"),
        ("per_training", True, True, parser.convert_bool,
         "Resets the timer to 'start_time' on each new training with this "
         "object"),
        ("prevent_reinit", True, False, parser.convert_bool,
         "If the timer is timeout, then it prevents a reinitialization of the "
         "training."),
        ("timeout_as_failure", True, False, parser.convert_bool,
         "With this options, a timeout is a training failure (by default a "
         "timeout is not a training failure)."),
        order=["max_time", "start_time", "per_training", "prevent_reinit",
               "timeout_as_failure", "id"])

    def __init__(self, max_time, start_time=0, per_training=True,
                 prevent_reinit=False, timeout_as_failure=False, id=None):
        BaseKerasCallback.__init__(self, id=id)

        self._max_time = max_time
        self._start_time = start_time
        if start_time < 0:
            raise ValueError("Timer start time has to be positive or zero.")
        self._per_training = per_training
        self._prevent_reinitialize = prevent_reinit
        self._timeout_as_failure = timeout_as_failure

        # Pause state; _pause_timestamp is the wall-clock time of the pause.
        self._paused = False
        self._pause_timestamp = None

        # Wall-clock time at which the timer showed 0 seconds (shifted forward
        # by every resumed pause).
        self._start_timestamp = None
        # None until the timer was started, then True once it timed out.
        self._stopped = None

    def training_failed(self):
        return False if not self._timeout_as_failure else self._stopped is True

    def shall_reinitialize(self, training_failed=None):
        if not self._prevent_reinitialize or self._stopped is None:
            return None
        else:
            return False if self._stopped is True else None

    def on_train_begin(self, logs=None):
        if self._start_timestamp is None or self._per_training:
            self._start_timestamp = time.time() - self._start_time
            self._pause_timestamp = None
            self._paused = False
            self._stopped = False

    def _elapsed_time(self):
        """Seconds counted by the timer so far, excluding paused time."""
        end = self._pause_timestamp if self._paused else time.time()
        return end - self._start_timestamp

    def on_epoch_end(self, epoch, logs=None):
        # Only checked between epochs: a started epoch is always finished.
        if self._elapsed_time() > self._max_time:
            self.model.stop_training = True
            self._stopped = True

    def pause(self):
        if not self._paused:
            self._pause_timestamp = time.time()
            self._paused = True

    def resume(self):
        if self._paused:
            interval = time.time() - self._pause_timestamp
            # Shift the zero point so that the paused interval is not counted.
            if self._start_timestamp is not None:
                self._start_timestamp += interval
            self._paused = False

    @staticmethod
    def parse(tree, item_cache):
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasStopTimer)


main_register.append_register(KerasStopTimer, "keras_stoptimer", "kstoptimer")


class KerasLearningRateSchedulerExponential(KerasCallbackWrapper):
    """
    Exponential learning rate decay via keras' LearningRateScheduler.

    Each time the schedule is queried, the current learning rate is multiplied
    by ``decay_rate ** (1 / decay_step)``. After ``decay_step`` such
    multiplications it has therefore been scaled by ``decay_rate``.
    """
    arguments = parset.ClassArguments(
        "KerasLearningRateSchedulerExponential", KerasCallbackWrapper.arguments,
        ("decay_rate", False, None, float,
         "after decay_step epochs, the learning rate is old_lr * decay_rate."),
        ("decay_step", False, None, int,
         "after this many steps, the learning rate shall decay with decay_rate"),
        order=["decay_rate", "decay_step", "id"]
    )

    def __init__(self, decay_rate, decay_step, id=None):
        KerasCallbackWrapper.__init__(self, id=id)
        # Per-step factor whose decay_step-th power equals decay_rate.
        self._decay_coefficient = math.pow(decay_rate, 1. / decay_step)

    def _setup(self, *args, **kwargs):
        def decayed_lr(epoch, lr):
            # One decay step per query; the epoch index itself is not used.
            return lr * self._decay_coefficient

        scheduler = keras.callbacks.LearningRateScheduler(decayed_lr)
        self._internal_callback = scheduler

    @staticmethod
    def parse(tree, item_cache):
        return parser.try_whole_obj_parse_process(
            tree, item_cache, KerasLearningRateSchedulerExponential)

    def supports_reinforcement_learning(self):
        return True


main_register.append_register(
    KerasLearningRateSchedulerExponential,
    "keras_learning_rate_scheduler_exponential")


class KerasLearningRateSchedulerCycle(KerasLearningRateSchedulerExponential):
    """
    Triangular cyclic learning rate schedule with optional exponential decay
    of its bounds.

    Within every cycle of ``cycle_length`` epochs, the learning rate rises
    linearly from ``min_lr`` to ``max_lr`` and falls back towards ``min_lr``.
    On every epoch except epoch 0, both bounds are multiplied by the decay
    coefficient computed by KerasLearningRateSchedulerExponential. The decayed
    bounds are stored on the object and therefore carry over to later
    trainings with the same object.
    """
    arguments = parset.ClassArguments(
        "KerasLearningRateSchedulerCycle",
        KerasLearningRateSchedulerExponential.arguments,
        ("min_lr", False, None, float, "minimum start learning rate"),
        ("max_lr", False, None, float, "maximum start learning rate"),
        ("cycle_length", False, None, int,
         "number of steps from min to max to min lr again"),
        ("decay_rate", True, 1, float,
         "after decay_step epochs, the learning rate is old_lr * decay_rate."),
        ("decay_step", True, 1, int,
         "after this many steps, the learning rate shall decay with decay_rate"),
        order=["min_lr", "max_lr", "cycle_length",
               "decay_rate", "decay_step", "id"]
    )

    def __init__(self, min_lr, max_lr, cycle_length, decay_rate=1, decay_step=1,
                 id=None):
        KerasLearningRateSchedulerExponential.__init__(
            self, decay_rate=decay_rate, decay_step=decay_step, id=id)
        if cycle_length <= 0:
            # The schedule computes ``epoch % cycle_length``; fail early
            # instead of in the middle of a training.
            raise ValueError("cycle_length has to be a positive integer.")
        self._min_lr = min_lr
        self._max_lr = max_lr
        self._cycle_length = cycle_length
        self._half_cycle = self._cycle_length / 2.0

    def _setup(self, *args, **kwargs):
        def schedule(epoch, lr):
            if epoch != 0:
                self._min_lr *= self._decay_coefficient
                self._max_lr *= self._decay_coefficient

            # Distance (in epochs) from the start of the cycle, mirrored in
            # the second half so that it falls back to 0 at the cycle's end.
            phase = epoch % self._cycle_length
            if phase > self._half_cycle:
                phase = self._cycle_length - phase

            return (self._min_lr +
                    (self._max_lr - self._min_lr) * (phase / self._half_cycle))

        self._internal_callback = keras.callbacks.LearningRateScheduler(
            schedule)

    @staticmethod
    def parse(tree, item_cache):
        return parser.try_whole_obj_parse_process(
            tree, item_cache, KerasLearningRateSchedulerCycle)

    def supports_reinforcement_learning(self):
        return True


main_register.append_register(KerasLearningRateSchedulerCycle,
                              "keras_learning_rate_scheduler_cycle")


class KerasModelCheckpoint(KerasCallbackWrapper):
    """
    Wraps keras' ModelCheckpoint. During training the (best) model is saved
    to ``filepath``; ``finalize`` loads it back into the network and deletes
    the checkpoint.

    Without an explicit ``filepath``, a path derived from the network's store
    path plus a random suffix is used.
    """
    arguments = parset.ClassArguments(
        'KerasModelCheckpoint', KerasCallbackWrapper.arguments,
        ("monitor", False, None, str, "Metric of the network to monitor"),
        ("filepath", True, None, str, "Path to the temporary file for the "
                                      "model checkpoint (or None)"),
        ("mode", True, "auto", str, "Mode of the ModelCheckpoint Callback"),
        ("period", True, 1, int, "In which periods the model shall be checked"),
        ("save_best_only", True, True, parser.convert_bool,
         "Store only the best network"),
        ("verbose", True, 0, int, "Verbosity level"),
        order=["monitor", "filepath", "mode", "period", "save_best_only",
               "verbose", "id"])

    def __init__(self, monitor, filepath=None, mode="auto", period=1,
                 save_best_only=True, verbose=0, id=None):
        KerasCallbackWrapper.__init__(self, id=id)
        self._monitor = monitor
        self._filepath = filepath
        self._mode = mode
        self._period = period
        self._save_best_only = save_best_only
        self._verbose = verbose

    @staticmethod
    def _remove_checkpoint(path):
        """Delete ``path``, which is a file or a directory depending on the
        keras storage format."""
        if os.path.isfile(path):
            os.remove(path)
        else:
            shutil.rmtree(path)

    def setup(self, network, *args, **kwargs):
        if self._filepath is None:
            self._filepath = (network.get_store_path() + "."
                              + str(misc.get_rnd_suffix()) + ".ckp")

        # Remove stale checkpoints so that finalize only loads a model written
        # during the upcoming training.
        if os.path.exists(self._filepath):
            self._remove_checkpoint(self._filepath)
        KerasCallbackWrapper.setup(self, network, *args, **kwargs)

    def _setup(self, *args, **kwargs):
        self._internal_callback = ModelCheckpoint(
            filepath=self._filepath, monitor=self._monitor,
            verbose=self._verbose,
            save_best_only=self._save_best_only, save_weights_only=False,
            mode=self._mode,
            period=self._period)

    def finalize(self, network, *args, **kwargs):
        """
        Replace the network's model by the checkpointed one (if a checkpoint
        was written) and delete the checkpoint.
        """
        # _filepath is still None if setup was never called.
        if self._filepath is not None and os.path.exists(self._filepath):
            network._model = keras.models.load_model(
                self._filepath, custom_objects=keras_layers.CUSTOM_LAYERS)
            self._remove_checkpoint(self._filepath)

    @staticmethod
    def parse(tree, item_cache):
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasModelCheckpoint)


main_register.append_register(KerasModelCheckpoint, "keras_model_checkpoint")


class KerasEarlyStopping(KerasCallbackWrapper):
    """
    Wraps keras' EarlyStopping callback. A fresh EarlyStopping instance is
    created by ``_setup`` before every training.
    """
    arguments = parset.ClassArguments(
        'KerasEarlyStopping', KerasCallbackWrapper.arguments,
        ("monitor", False, None, str, "Metric of the network to monitor"),
        ("min_delta", True, 0, float, "Minimum required improvement"),
        ("patience", True, 0, int,
         "Maximum number of epochs to wait without improvements"),
        ("mode", True, "auto", str, "min/max/auto the monitored value"),
        ("verbose", True, 0, int, "Verbosity level"),
        order=["monitor", "min_delta", "patience", "mode", "verbose", "id"])

    def __init__(self, monitor, min_delta=0, patience=0, mode="auto",
                 verbose=0, id=None):
        KerasCallbackWrapper.__init__(self, id=id)
        self._internal_callback = None
        # Settings forwarded unchanged to keras' EarlyStopping.
        self._monitor = monitor
        self._min_delta = min_delta
        self._patience = patience
        self._mode = mode
        self._verbose = verbose

    def _setup(self, *args, **kwargs):
        settings = dict(monitor=self._monitor, min_delta=self._min_delta,
                        patience=self._patience, verbose=self._verbose,
                        mode=self._mode)
        self._internal_callback = EarlyStopping(**settings)

    def finalize(self, network, *args, **kwargs):
        """Nothing to clean up after training."""
        pass

    @staticmethod
    def parse(tree, item_cache):
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasEarlyStopping)


main_register.append_register(KerasEarlyStopping, "keras_early_stopping")


class BaseKerasConditionExecutor(BaseKerasCallback):
    """
    Base class for callbacks that check a set of conditions at a configurable
    training ``step``. Depending on the outcome, ``_execute_satisfied`` or
    ``_execute_unsatisfied`` is run.

    All configured conditions (time, iteration, top-x and threshold) must hold
    for a check to be satisfied; conditions left at None are ignored. While
    satisfied, the configured flag names are exposed via ``flags``.

    Time based conditions are measured from the construction of the object.
    The "epoch" based conditions count checks (one per executed ``step``),
    which correspond to epochs only if ``step`` is an epoch step.
    """
    arguments = parset.ClassArguments(
        'BaseKerasConditionExecutor', BaseKerasCallback.arguments,
        ("metric", True, "loss", str, "Metric of the network to monitor"),
        ("min_time", True, None, int, "Enable store after at least min_time "
                                    "seconds have passed."),
        ("every_x_time", True, None, int, "Enable storing earliest every x time"),
        ("max_time", True, None, int, "Enable store after at most max_time "
                                   "seconds have passed."),
        ("min_epochs", True, None, int, "Enable store after at least min_epochs "
                                   "epochs have passed."),
        ("every_x_epochs", True, None, int, "Enable storing earliest every x epochs"),
        ("max_epochs", True, None, int, "Enable store after at most max_epochs "
                                   "epochs have passed."),
        ("top_x", True, None, int, "Enable storing only if metric is in top X "
                                 "of seen models."),
        ("threshold", True, None, float, "Enable storing if metric is better "
                                         "than threshold. Default: fewer is "
                                         "better."),
        ("minimize_metric", True, True, parser.convert_bool,
         "Metric has to be minimized (if false it has to be maximized)"),
        ("step", True, "batch_end", str, "When to execute the check"),
        ("flags", True, None, str,
         "Flags which will be set once the condition is met and unset on the "
         "next check."),
        order=["metric", "min_time", "every_x_time", "max_time", "min_epochs",
               "every_x_epochs", "max_epochs",
               "top_x", "threshold",
               "minimize_metric", "step", "flags", "id"])

    # Flag names are only reported while the last check was satisfied.
    flags = property(lambda self: (set(self._flag_names)
                                   if self._flags_active else set()))
    active = property(lambda self: self._flags_active)
    # Number of checks performed so far.
    iter = property(lambda self: self._iter)

    def __init__(self, metric="loss", min_time=None, every_x_time=None,
                 max_time=None, min_epochs=None, every_x_epochs=None,
                 max_epochs=None, top_x=None,
                 threshold=None, minimize_metric=True, step="batch_end",
                 flags=None, id=None):
        BaseKerasCallback.__init__(self, id)

        self._network = None
        self._metric = metric
        self._min_time = min_time
        self._every_x_time = every_x_time
        self._max_time = max_time
        self._min_epochs = min_epochs
        self._every_x_epochs = every_x_epochs
        self._max_epochs = max_epochs
        self._top_x = top_x
        self._threshold = threshold
        self._minimize = minimize_metric
        self._step = step
        assert step in ["batch_begin", "batch_end", "epoch_begin", "epoch_end",
                        "train_begin", "train_end"]
        self._flag_names = ([] if flags is None else
                            (flags if isinstance(flags, list) else [flags]))
        self._flags_active = False
        # Best ``top_x`` metric values seen when the condition was satisfied,
        # ordered best first (worst last).
        self._tops = []
        self._initial_time = time.time()
        # Pretend the last execution was every_x_epochs checks ago, so the
        # very first check may already satisfy the every_x_epochs condition.
        self._last_execute_epoch = (0 if self._every_x_epochs is None
                                    else -self._every_x_epochs)
        # NOTE(review): unlike the epoch counterpart above, this offset uses
        # min_time rather than every_x_time. Confirm which one is intended.
        self._last_execute_time = (self._initial_time if self._min_time is None
                                   else self._initial_time - self._min_time)
        self._iter = 0

        self._last_epoch_end_log = None
        self.check_condition()

    def setup(self, network, *args, **kwargs):
        self._network = network

    def _top_x_satisfied(self, m):
        """
        True if no top_x is configured, fewer than top_x values are recorded,
        or ``m`` beats the worst recorded value (respecting minimize_metric).
        """
        if self._top_x is None or len(self._tops) < self._top_x:
            return True
        if m is None or not self._tops:
            return False
        worst = self._tops[-1]
        return m < worst if self._minimize else m > worst

    def check_condition(self, logs=None):
        """
        Evaluate all configured conditions and update ``active``/``flags``.
        :param logs: keras log dict providing the monitored metric (or None)
        :return: True if all conditions are satisfied
        """
        logs = {} if logs is None else logs
        m = logs.get(self._metric)
        curr_time = time.time()
        x = (  # Time based conditions
             (self._min_time is None or
              curr_time - self._initial_time >= self._min_time) and
             (self._every_x_time is None or
              curr_time - self._last_execute_time >= self._every_x_time) and
             (self._max_time is None or
              curr_time - self._initial_time <= self._max_time) and
             # Iteration based conditions
             (self._min_epochs is None or
              self._iter >= self._min_epochs) and
             (self._every_x_epochs is None or
              self._iter - self._last_execute_epoch >= self._every_x_epochs) and
             (self._max_epochs is None or
              self._iter <= self._max_epochs) and
             # Other conditions
             self._top_x_satisfied(m) and
             (self._threshold is None or
              (m is not None and self._minimize and m <= self._threshold) or
              (m is not None and not self._minimize and m >= self._threshold)))
        self._flags_active = x
        return x

    def _update_state(self, logs):
        """Record the outcome of the current check and advance the counter."""
        m = None if logs is None else logs.get(self._metric)
        if self._flags_active:
            self._last_execute_epoch = self._iter
            self._last_execute_time = time.time()
            # A missing metric cannot be ranked; recording None would break
            # the ordering of the recorded values.
            if self._top_x is not None and m is not None:
                if len(self._tops) < self._top_x:
                    self._tops.append(m)
                elif self._tops:
                    # m beat the worst entry (see _top_x_satisfied).
                    self._tops[-1] = m
                # Best value first, worst value last.
                self._tops.sort(reverse=not self._minimize)
        self._iter += 1

    def _execute_satisfied(self, logs):
        """Hook run after a satisfied check. Does nothing by default."""
        pass

    def _execute_unsatisfied(self, logs):
        """Hook run after an unsatisfied check. Does nothing by default."""
        pass

    def _run(self, logs):
        self.check_condition(logs)
        self._update_state(logs)

        if self._flags_active:
            self._execute_satisfied(logs)
        else:
            self._execute_unsatisfied(logs)

    def on_epoch_begin(self, epoch, logs=None):
        if self._step == "epoch_begin":
            self._run(logs)

    def on_epoch_end(self, epoch, logs=None):
        if self._step == "epoch_end":
            self._run(logs)
        self._last_epoch_end_log = logs

    def on_batch_begin(self, batch, logs=None):
        if self._step == "batch_begin":
            self._run(logs)

    def on_batch_end(self, batch, logs=None):
        if self._step == "batch_end":
            self._run(logs)

    def on_train_begin(self, logs=None):
        if self._step == "train_begin":
            self._run(logs)

    def on_train_end(self, logs=None):
        if self._step == "train_end":
            # Because Keras does not provide any log. This is None if no
            # epoch finished.
            self._run(self._last_epoch_end_log)


class KerasModelSaver(BaseKerasConditionExecutor):
    """
    Stores the network (via ``network.store``) every time the condition is
    satisfied.

    With ``add_model_indices``, the n-th store (counting from 0) passes
    ``str(n)`` as ``path_suffix``; otherwise ``path_suffix=None`` is passed.
    """
    arguments = parset.ClassArguments(
        'KerasModelSaver', BaseKerasConditionExecutor.arguments,
        ("add_model_indices", True, False, parser.convert_bool,
         "Stores every model in a new file with an index appended to its name"),
        order=["metric", "min_time", "every_x_time", "max_time", "min_epochs",
               "every_x_epochs", "max_epochs",
               "top_x", "threshold",
               "minimize_metric", "step", "flags", "add_model_indices", "id"]
    )

    def __init__(self, metric="loss", min_time=None, every_x_time=None,
                 max_time=None, min_epochs=None, every_x_epochs=None,
                 max_epochs=None, top_x=None,
                 threshold=None, minimize_metric=True, step="batch_end",
                 flags=None, add_model_indices=None, id=None):
        BaseKerasConditionExecutor.__init__(
            self, metric=metric, min_time=min_time, every_x_time=every_x_time,
            max_time=max_time, min_epochs=min_epochs,
            every_x_epochs=every_x_epochs, max_epochs=max_epochs,
            top_x=top_x, threshold=threshold,
            minimize_metric=minimize_metric, step=step, flags=flags, id=id)
        self.add_model_indices = add_model_indices
        # Number of stores performed so far (also the next index).
        self._counter = 0

    def _execute_satisfied(self, logs):
        suffix = str(self._counter) if self.add_model_indices else None
        self._network.store(path_suffix=suffix)
        self._counter += 1

    def supports_reinforcement_learning(self):
        return True

    @staticmethod
    def parse(tree, item_cache):
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasModelSaver)


main_register.append_register(KerasModelSaver, "keras_model_saver")


class KerasConditionCounter(BaseKerasConditionExecutor):
    """
    Counts how often the condition was satisfied (in total and in a row) and
    reports every check to an optional ``callback``.

    ``callback`` is invoked after each check performed during training with
    the keyword arguments ``keras_callback``, ``total_satisfied``,
    ``consecutively_satisfied``, ``is_satisfied`` and ``logs``.
    """
    arguments = parset.ClassArguments(
        'KerasConditionCounter', BaseKerasConditionExecutor.arguments)

    def __init__(self, metric="loss", min_time=None, every_x_time=None,
                 max_time=None, min_epochs=None, every_x_epochs=None,
                 max_epochs=None, top_x=None,
                 threshold=None, minimize_metric=True, step="batch_end",
                 flags=None, callback=None, id=None):
        BaseKerasConditionExecutor.__init__(
            self, metric=metric, min_time=min_time, every_x_time=every_x_time,
            max_time=max_time, min_epochs=min_epochs,
            every_x_epochs=every_x_epochs, max_epochs=max_epochs,
            top_x=top_x, threshold=threshold,
            minimize_metric=minimize_metric, step=step, flags=flags, id=id)
        self.callback = callback

        # Statistics over all checks performed so far.
        self._total_satisfied = 0
        self._consecutively_satisfied = 0
        self._last_time_satisfied = False

    def call_callback(self, logs):
        """Report the current counters to ``callback`` if one is set."""
        if self.callback is None:
            return
        self.callback(
            keras_callback=self,
            total_satisfied=self._total_satisfied,
            consecutively_satisfied=self._consecutively_satisfied,
            is_satisfied=self._last_time_satisfied,
            logs=logs,
        )

    def _record_check(self, satisfied, logs):
        """Update the counters for one check and notify the callback."""
        if satisfied:
            self._total_satisfied += 1
            self._consecutively_satisfied += 1
        else:
            self._consecutively_satisfied = 0
        self._last_time_satisfied = satisfied
        self.call_callback(logs)

    def _execute_satisfied(self, logs):
        self._record_check(True, logs)

    def _execute_unsatisfied(self, logs):
        self._record_check(False, logs)

    def supports_reinforcement_learning(self):
        return True

    @staticmethod
    def parse(tree, item_cache):
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasConditionCounter)


main_register.append_register(KerasConditionCounter, "keras_condition_counter")
