from abc import ABCMeta

from . import keras_callbacks, keras_layers, keras_losses
from .keras_tools import KerasDataGenerator, store_keras_model_as_protobuf, \
    QueueDataGenerator, ReplayQueueDataGenerator,\
    PrioritizedReplayQueueDataGenerator

from .. import Network, LearnerFormat, TrainingOutcome

from ... import parser_tools as parset
from ... import parser
from ... import main_register

from ...bridges import StateFormat
from ...misc import similarities, InvalidModuleImplementation
from ...parser_tools import ArgumentException

from collections import defaultdict
from enum import Enum
import json
import keras
from keras import layers
import math
import matplotlib as mpl
mpl.use('agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import re
from sklearn.metrics import mean_squared_error, mean_absolute_error
import time

# List of formats in which the matplotlib figures shall be stored
MATPLOTLIB_OUTPUT_FORMATS = ["png"]
# Plot styling constants (a color code, an alpha value and a colormap name).
# NOTE(review): their consumers are not part of this excerpt - presumably the
# plotting helpers of KerasNetwork; confirm before changing them.
COLOR_DATA_MEAN = "g"
ALPHA_DATA_MEAN = 0.7
COLORMAP = "viridis"

# NOTE(review): looks like an enumeration of batch normalization placements
# (off / before the activation / after the activation). The constants are not
# used in this excerpt - confirm against the concrete network implementations.
BN_OFF = 0
BN_PRE_ACTIVATION = 1
BN_POST_ACTIVATION = 2


class StopTraining(Exception):
    """Raised to request that a running training loop stops early."""


class RLBatchGenerator(Enum):
    """
    Strategies for building the training batches in reinforcement learning.

    The members are mapped to data generator classes by
    KerasNetwork._rl_get_batch_generator.
    """
    # Batches are produced by a QueueDataGenerator
    Plain = "plain"
    # Batches are produced by a ReplayQueueDataGenerator
    Replay = "replay"
    # Batches are produced by a PrioritizedReplayQueueDataGenerator which
    # gets None instead of a beta generator
    PrioritizedReplay = "prioritized_replay"
    # Batches are produced by a PrioritizedReplayQueueDataGenerator which
    # gets a generator of beta values growing from 0 to 1
    WeightedPrioritizedReplay = "weighted_prioritized_replay"


def type_optimizer(arg):
    """
    Convert the textual description of an optimizer into an optimizer.

    Accepted descriptions:
      - "adam" or "sgd": the name is returned unchanged
      - "sgd(KEY=VALUE, ...)": a keras SGD optimizer is constructed with the
        given keyword arguments. A value is either a boolean ("true"/"false",
        case insensitive), an integer or a float (decimal or scientific
        notation, e.g. "0.01" or "1e-3"). Whitespace around keys and values
        is ignored.

    :param arg: description of the optimizer
    :return: the optimizer name or a keras SGD optimizer object
    :raises ValueError: if the description or one of its options is malformed
    """
    if arg in ["adam", "sgd"]:
        return arg
    sgd = re.match(r"^sgd\((.*)\)$", arg)
    if sgd is None:
        # An explicit exception instead of an assertion, because assertions
        # are stripped when Python runs with -O
        raise ValueError("Invalid optimizer: %s" % arg)

    opts = {}
    for item in sgd.group(1).split(","):
        item = item.strip()
        if item == "":
            continue
        key, separator, value = item.partition("=")
        key, value = key.strip(), value.strip()
        if separator == "" or key == "":
            raise ValueError(
                "Invalid optimizer option (expected KEY=VALUE): %s" % item)

        lowered = value.lower()
        if lowered == "true":
            opts[key] = True
        elif lowered == "false":
            opts[key] = False
        elif re.match(r"^-?\d+$", value):
            opts[key] = int(value)
        elif re.match(r"^-?(?:\d+\.\d*|\.\d+|\d+)(?:[eE][-+]?\d+)?$", value):
            opts[key] = float(value)
        else:
            raise ValueError(
                "Invalid value for optimizer option '%s': %s" % (key, value))
    # The class name SGD is used instead of the lowercase alias 'sgd',
    # because the alias is not provided by all keras versions
    return keras.optimizers.SGD(**opts)

class KerasTrainingParameters(object):
    """
    Container for the settings with which a keras network is trained.

    The class attribute 'arguments' describes the settings for the argument
    parser of the project, the constructor stores them on the object.
    """
    arguments = parset.ClassArguments(
        "KerasTrainingParameters", None,
        ("epochs", True, 1000, int, "Number of training epochs"),
        ("loss", True, "mean_squared_error", keras_losses.get,
         "name of the loss function to use"),
        ("batch", True, 100, int, "Batch size to use"),
        ("balance", True, False, parser.convert_bool, "Balance classes"),
        ("optimizer", True, "adam", type_optimizer, "Name of the optimizer to use"),
        ("metrics", True, None, str, "Single or list of metrics to track"),
        ("callbacks", True, None,
         parset.main_register.get_register(keras_callbacks.BaseKerasCallback),
         "Definition of a single or a list of callback functions for keras."),
        ("epoch_verbosity", True, 1, int,
         "RL: Every epoch_verbosity output will be given"),
        # The trailing space after "replay" is required, because the adjacent
        # string literals are concatenated without a separator
        ("rl_batch_generator", True, RLBatchGenerator.Plain, RLBatchGenerator,
         "RL: how to generate the batches (take data as is, use replay "
         "experience)"),
        ('id', True, None, str))

    def __init__(self, epochs=1000, loss="mean_squared_error", batch=100,
                 balance=False, optimizer="adam", metrics=None,
                 callbacks=None, epoch_verbosity=1,
                 rl_batch_generator=RLBatchGenerator.Plain, id=None):
        """
        :param epochs: number of training epochs
        :param loss: loss function to use
        :param batch: batch size to use
        :param balance: if True, class weights are calculated for training
        :param optimizer: optimizer to use (name or optimizer object)
        :param metrics: None, a single metric or a list of metrics. The
                        value is always stored as a list
        :param callbacks: None, a single callback or a list of callbacks.
                          None is kept, everything else is stored as a list
        :param epoch_verbosity: RL: every epoch_verbosity-th epoch an output
                                is printed
        :param rl_batch_generator: RL: RLBatchGenerator member choosing how
                                   batches are generated
        :param id: identifier of this object
        """
        self.epochs = epochs
        self.loss = loss
        self.batch = batch
        self.balance = balance
        self.optimizer = optimizer

        # Normalize the metrics to a list
        if metrics is None:
            metrics = []
        elif not isinstance(metrics, list):
            metrics = [metrics]
        self.metrics = metrics

        # Normalize the callbacks to a list, but keep None as "no callbacks"
        if callbacks is not None and not isinstance(callbacks, list):
            callbacks = [callbacks]
        self.callbacks = callbacks

        self.epoch_verbosity = epoch_verbosity
        self.rl_batch_generator = rl_batch_generator
        self.id = id

    @staticmethod
    def parse(tree, item_cache):
        """
        Construct an object of this class via the parser of the project.

        :param tree: forwarded to parser.try_whole_obj_parse_process
        :param item_cache: forwarded to parser.try_whole_obj_parse_process
        :return: result of parser.try_whole_obj_parse_process
        """
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasTrainingParameters)


# Make the training parameters known to the main register of the project.
# NOTE(review): presumably the two strings are the names under which the class
# can be referenced - confirm against main_register.append_register.
main_register.append_register(KerasTrainingParameters,
                              "keras_training_parameters", "ktparams")


class KerasNetwork(Network):
    __metaclass__ = ABCMeta
    # Description of the constructor arguments for the argument parser of the
    # project (extends the arguments of Network).
    # NOTE(review): 'order' lists "learner_formats" while the constructor
    # parameter is called "formats" - presumably "learner_formats" is the
    # argument name defined by Network.arguments; confirm.
    arguments = parset.ClassArguments(
        'KerasNetwork', Network.arguments,
        ("tparams", True, None,
         parset.main_register.get_register(KerasTrainingParameters),
         "Training parameters"),
        ("count_samples", True, False, parser.convert_bool,
         "Counts how many and which samples where used during training"
         " (This increases the runtime)."),
        # The trailing space after "kept in" is required, because the
        # adjacent string literals are concatenated without a separator
        ("test_similarity", True, None, similarities.get_similarity,
         "Estimates for every sample in the evaluation data how close it is"
         " to the trainings data. For this to work ALL training data is kept in "
         "memory, this could need a lot of memory"
         " (provide name of a similarity measure)"),
        # The second entry is the boolean True like for all other arguments
        # (previously the string "True" was given here)
        ("graphdef", True, None, str,
         "Name for an ASCII GraphDef file of the stored Protobuf model (only "
         "applicable if Protobuf model is stored)"),

        order=["tparams", "load", "store", "learner_formats",
               "out", "count_samples",
               "test_similarity", "graphdef", "variables", "id"])

    def __init__(self, tparams=None, load=None, store=None, formats=None,
                 out=".",
                 count_samples=False, test_similarity=None,
                 graphdef=None,
                 variables=None, id=None):
        """
        Set up the state shared by all keras networks.

        :param tparams: KerasTrainingParameters object. If None, an object
                        with default values is created
        :param load: forwarded to Network.__init__
        :param store: forwarded to Network.__init__
        :param formats: forwarded to Network.__init__
        :param out: forwarded to Network.__init__
        :param count_samples: if True, hashes of the samples used for
                              training are collected
        :param test_similarity: None, a similarity measure or the name of a
                                similarity measure (the name is resolved via
                                similarities.get_similarity)
        :param graphdef: name for a GraphDef file of the Protobuf model
        :param variables: forwarded to Network.__init__
        :param id: forwarded to Network.__init__
        """
        Network.__init__(self, load, store, formats, out, variables, id)
        if tparams is None:
            tparams = KerasTrainingParameters()
        self.training_params = tparams
        self._model = None

        # ---------------------------- Analysis data --------------------------
        # Outcome of the last training resp. of all trainings
        self._history = None
        self._histories = []
        # Outcome of the last evaluation resp. of all evaluations
        self._evaluation = None
        self._evaluations = []
        self._count_samples = count_samples
        # Only filled if _count_samples is True
        self._count_samples_hashes = set()
        if isinstance(test_similarity, str):
            test_similarity = similarities.get_similarity(test_similarity)
        self._test_similarity = test_similarity
        # Training data is added if _test_similarity is not None
        self._training_data = set()

        # ------- Concrete networks might need to set the values below -------
        # ----------------- (check existing implementations) -----------------
        # Callables which get a SampleBatchData object and return an ordered
        # list of fields (indices) to use for the X resp. Y data
        self._x_fields_extractor = None
        self._y_fields_extractor = None
        # Callables which convert an extracted batch of X resp. Y fields to
        # the format needed by the network
        self._x_encoder = None
        self._y_encoder = None
        # Callable which converts the y format of the network into a
        # meaningful format (e.g. one hot to number)
        self._y_decoder = None

        # Needed if _count_samples is True. Defines how to calculate a hash
        # from a sample given as x and y value (where X and Y are the fields
        # extracted by the KerasDataGenerator). Example: lambda x,y: str((x,y))
        self._count_samples_hasher = None
        # Only used for test_similarity. Defines for every field extracted by
        # _x_fields_extractor how this field of two instances shall be
        # compared. Take a look at misc.data_similarities. E.g. provide a
        # callable for every field, or use the string names of comparators a
        # measure already defines (the Hamming Measure has "equal" and
        # "iterable"). None means the measure uses its default comparator for
        # all fields.
        self._x_fields_comparators = None

        # Variables for storing Protobuf files (ignore them if storing
        # Protobuf files is not supported, but in general keras networks can
        # be converted)
        # Tensorflow quantize feature
        self._quantize = False
        # The model used theano as backend
        self._theano = False
        # Number of output PATHS of the network (not output nodes)
        self._num_outputs = 1
        # Prefix before every output path
        self._prefix_outputs = ""
        # Store GraphDef of Tensorflow Graph if file name is given
        self._store_graphdef = graphdef

    def initialize(self, *args, **kwargs):
        """
        Initialize the network and validate the setup of the implementation.

        All arguments are forwarded to Network.initialize.

        :raises InvalidModuleImplementation: if samples shall be counted, but
            the concrete network did not define '_count_samples_hasher'
        """
        Network.initialize(self, *args, **kwargs)
        # Check for a valid initialization of all required fields
        if self._count_samples and self._count_samples_hasher is None:
            # The trailing space after "parameter" is required, because the
            # adjacent string literals are concatenated without a separator
            raise InvalidModuleImplementation(
                "If allowing the 'count_samples' option in a KerasNetwork, "
                "the network implementation has to set the parameter "
                "'_count_samples_hasher'")

    def _compile(self):
        """Compile the model with optimizer, loss and metrics of tparams."""
        tp = self.training_params
        compile_options = {
            "optimizer": tp.optimizer,
            "loss": tp.loss,
            "metrics": tp.metrics,
        }
        self._model.compile(**compile_options)

    def get_default_format(self):
        """Return the default learner format of keras networks (HDF5)."""
        default_format = LearnerFormat.hdf5
        return default_format

    def _get_store_formats(self):
        """Return the set of formats in which the network can be stored."""
        storable = (LearnerFormat.hdf5, LearnerFormat.protobuf,
                    LearnerFormat.flag)
        return set(storable)

    def _get_load_formats(self):
        """Return the set of formats from which the network can be loaded."""
        loadable = set()
        loadable.add(LearnerFormat.hdf5)
        return loadable

    def get_preferred_state_formats(self):
        """Return the list of preferred state formats (only All_A_01)."""
        preferred = []
        preferred.append(StateFormat.All_A_01)
        return preferred

    def _load(self, path, learner_format, *args, **kwargs):
        """
        Load the keras model from the given file.

        :param path: path to the stored model
        :param learner_format: format of the stored model (only
                               LearnerFormat.hdf5 can be loaded)
        :raises ValueError: if the format is not LearnerFormat.hdf5
        """
        if not (learner_format == LearnerFormat.hdf5):
            raise ValueError("Keras cannot load a network model from the "
                             "format: " + str(learner_format))
        # The custom layers of the project have to be announced to keras
        self._model = keras.models.load_model(
            path, custom_objects=keras_layers.CUSTOM_LAYERS)

    def _store(self, path, learner_formats):
        """
        Store the model once for every requested format.

        :param path: path of the output file without suffix. The first
                     suffix of the format is appended
        :param learner_formats: iterable of formats. None entries are
                                replaced by the default network format
        """
        for requested_format in learner_formats:
            # NOTE(review): _get_default_network_format is not defined in
            # this excerpt - presumably provided by Network; confirm.
            store_format = (self._get_default_network_format()
                            if requested_format is None
                            else requested_format)
            target = ".".join((path, store_format.suffix[0]))

            if store_format == LearnerFormat.hdf5:
                self._model.save(target)
                continue

            if store_format == LearnerFormat.protobuf:
                # The GraphDef file is always stored next to the model,
                # self._store_graphdef is currently not used for its name:
                # graphdef = (None if self._store_graphdef is None else
                #            os.path.join(self.path_out, self._store_graphdef))
                store_keras_model_as_protobuf(
                    self._model,
                    os.path.dirname(path),
                    os.path.basename(target),
                    quantize=self._quantize,
                    theano=self._theano,
                    num_outputs=self._num_outputs,
                    prefix_outputs=self._prefix_outputs,
                    store_graphdef=path + ".graphdef")
                continue

            if store_format == LearnerFormat.flag:
                # A flag is just an empty file
                with open(target, "w"):
                    pass

    def _callbacks_setup(self):
        """Call setup(self) on every callback of the training parameters."""
        callbacks = self.training_params.callbacks
        if callbacks is None:
            return
        for callback in callbacks:
            callback.setup(self)

    def _callbacks_finalize(self):
        """Call finalize(self) on every callback of the training parameters."""
        callbacks = self.training_params.callbacks
        if callbacks is None:
            return
        for callback in callbacks:
            callback.finalize(self)

    def _callbacks_check(self, extractor, merger, init):
        """
        Fold a property of all callbacks into a single value.

        :param extractor: callable which extracts the property from a single
                          callback
        :param merger: callable which gets the value merged so far and the
                       newly extracted property and returns the new merged
                       value
        :param init: start value of the merge process (returned unchanged
                     if no callbacks are defined)
        :return: merged value
        """
        callbacks = self.training_params.callbacks
        if callbacks is None:
            return init
        merged = init
        for callback in callbacks:
            merged = merger(merged, extractor(callback))
        return merged

    def _calculate_class_weights(self, dtrain):
        """
        Calculate the class weights for balancing the training.

        The weight of a class is 1 divided by the number of entries with
        this label. The weights are printed.

        :param dtrain: iterable of training data sets
        :return: dictionary {label: weight} or None if balancing is disabled
        :raises ValueError: if a data set has not exactly one y field
        """
        if not self.training_params.balance:
            print("Class Weights: ", None)
            return None

        label_counts = {}
        for data in dtrain:
            fields = self._y_fields_extractor(data)
            if len(fields) != 1:
                raise ValueError("Cannot extract class weights for keras "
                                 "network with more than 1 output label")

            def count_label(entry):
                label = entry[fields[0]]
                label_counts[label] = label_counts.get(label, 0) + 1

            data.over_all(count_label)

        class_weights = {label: 1.0 / count
                         for label, count in label_counts.items()}
        print("Class Weights: ", class_weights)
        return class_weights

    def train(self, dtrain, dvalid=None):
        """
        Train the model on the given data.

        The given data is first converted into the format needed for this
        network and then the SampleBatchData objects are finalized. If your
        KerasNetwork subclass needs a different conversion than the default
        given by this class, override the method _convert_data(DATA) in your
        subclass.

        The training is repeated with a reinitialized network as long as the
        callbacks request it via shall_reinitialize.

        :param dtrain: List of SampleBatchData for training
        :param dvalid: List of SampleBatchData for testing
        :return: dictionary with the keras history object ("history") and the
                 TrainingOutcome ("training_outcome")
        """

        def extract_fields(extractor, data_sets):
            # The first data set is used as representative for all data sets
            return None if extractor is None else extractor(data_sets[0])

        dtrain = self._convert_data(dtrain)
        class_weights = self._calculate_class_weights(dtrain)

        kdg_train = KerasDataGenerator(
            dtrain,
            batch_size=self.training_params.batch,
            x_fields=extract_fields(self._x_fields_extractor, dtrain),
            y_fields=extract_fields(self._y_fields_extractor, dtrain),
            x_converter=self._x_encoder,
            y_converter=self._y_encoder,
            shuffle=True,
            count_diff_samples=(self._count_samples_hasher
                                if self._count_samples
                                else None),
            class_weights=class_weights
        )

        kdg_valid = None
        if dvalid is not None:
            dvalid = self._convert_data(dvalid)
            kdg_valid = KerasDataGenerator(
                dvalid,
                x_fields=extract_fields(self._x_fields_extractor, dvalid),
                y_fields=extract_fields(self._y_fields_extractor, dvalid),
                x_converter=self._x_encoder,
                y_converter=self._y_encoder,
                shuffle=True)

        while True:
            self._callbacks_setup()
            kdg_train.reset()

            history = self._model.fit_generator(
                kdg_train,
                epochs=self.training_params.epochs,
                verbose=2, callbacks=self.training_params.callbacks,
                validation_data=kdg_valid,
                validation_steps=None,
                max_queue_size=10, workers=1,
                use_multiprocessing=False,
                shuffle=True, initial_epoch=0)

            # The training failed if any callback reports a failure
            training_failed = self._callbacks_check(
                lambda x: x.training_failed(),
                lambda x, y: x or y,
                False
            )
            # None votes are ignored, all other votes have to agree
            shall_reinitialize = self._callbacks_check(
                lambda x: x.shall_reinitialize(training_failed),
                lambda x, y: (y if x is None
                              else (x if y is None else (x and y))),
                None)

            self._callbacks_finalize()

            if shall_reinitialize:
                self.reinitialize()
                continue
            else:
                self._count_samples_hashes.update(
                    kdg_train.generated_sample_hashes)
                break

        self._history = history
        self._histories.append(history)
        if self._test_similarity is not None:
            self._training_data.update(set(dtrain))

        if not training_failed:
            training_outcome = TrainingOutcome.Finished
        else:
            # Depending on the keras version the validation accuracy is
            # stored as "val_acc" or as "val_accuracy" (static_analyse
            # accepts both names, too)
            records = history.history
            val_accuracy = records.get(
                "val_acc", records.get("val_accuracy", [0]))
            training_outcome = (TrainingOutcome.Failed
                                if val_accuracy[-1] == 0
                                else TrainingOutcome.Aborted)

        return {"history": history, "training_outcome": training_outcome}

    @staticmethod
    def _rl_get_batch_generator(
            data_queue, tp, max_time, fetch_timeout=120, buffer_size_factor=20):
        """
        Construct the batch generator for reinforcement learning.

        :param data_queue: queue from which the generator fetches the data
        :param tp: training parameters (batch, epochs and rl_batch_generator
                   are read)
        :param max_time: passed to the generator as time_limit
        :param fetch_timeout: passed to the generator as fetch_timeout
        :param buffer_size_factor: the buffer of the replay generators has
                                   the size tp.batch * buffer_size_factor
        :return: the batch generator chosen by tp.rl_batch_generator
        :raises ValueError: if tp.rl_batch_generator is not supported
        """
        buffer_size = tp.batch * buffer_size_factor
        kind = tp.rl_batch_generator

        if kind == RLBatchGenerator.Plain:
            return QueueDataGenerator(data_queue, tp.batch,
                                      fetch_timeout=fetch_timeout,
                                      time_limit=max_time)

        if kind == RLBatchGenerator.Replay:
            return ReplayQueueDataGenerator(data_queue, tp.batch, buffer_size,
                                            fetch_timeout=fetch_timeout,
                                            time_limit=max_time)

        if kind == RLBatchGenerator.PrioritizedReplay:
            return PrioritizedReplayQueueDataGenerator(
                data_queue, tp.batch, buffer_size, None,
                fetch_timeout=fetch_timeout,
                time_limit=max_time)

        if kind == RLBatchGenerator.WeightedPrioritizedReplay:
            def beta_generator():
                # Beta grows linearly from 0 and is saturated at 1 after 75%
                # of the epochs
                i = 0
                sat = tp.epochs * 0.75
                while True:
                    # Without epochs beta is saturated from the start (this
                    # also prevents a division by zero)
                    yield 1 if sat <= 0 else min(i/sat, 1)
                    i += 1
            return PrioritizedReplayQueueDataGenerator(
                data_queue, tp.batch, buffer_size, beta_generator(),
                fetch_timeout=fetch_timeout,
                time_limit=max_time)

        # An explicit exception instead of an assertion, because assertions
        # are stripped when Python runs with -O
        raise ValueError("Invalid batch generator: %s" % kind)

    def train_reinforcement(
            self, data_queue, verbose_epoch=1,
            max_time=-1, additional_callbacks=None, callback_predictions=None,
            callback_flags=None,
            data_postprocessor=None,
            replay_buffer_size_factor=20,
            start_time=0, start_epoch=0,
    ):
        """
        Train the model on batches which are continuously fetched from a queue.

        Every fetched batch is trained for one keras epoch. The training ends
        if the batch generator is exhausted, the epochs of the training
        parameters are exceeded, a flag callback raises StopTraining or the
        time limit is reached.

        :param data_queue: queue from which the batch generator fetches data
        :param verbose_epoch: currently not used by this method
        :param max_time: time limit in seconds for the training (-1 disables
                         the limit of this method). The value is also given
                         to the batch generator as time_limit
        :param additional_callbacks: list of keras callbacks which are used
                                     in addition to the callbacks of the
                                     training parameters
        :param callback_predictions: iterable of callables which are called
                                     with the y values and the predictions
                                     of every batch
        :param callback_flags: iterable of pairs (flag, callable). After
                               every iteration all callables whose flag is
                               active are called without arguments. A
                               callable may raise StopTraining to end the
                               training. The flags "always" and (if a batch
                               was trained) "epoch" are always active,
                               further flags are read from the 'flags'
                               attribute of the keras callbacks
        :param data_postprocessor: callable which gets x data, y data and
                                   sample weights of a batch and returns
                                   their replacements
        :param replay_buffer_size_factor: buffer size factor for the batch
                                          generator
        :param start_time: seconds by which the start time of the training
                           is shifted into the past
        :param start_epoch: offset added to the iteration counter
        :return: dictionary with the loss ("loss"), the maximum y value
                 ("max_inputs") and the maximum prediction ("max_predicted")
                 of every trained batch and for every flag the iterations in
                 which it was active ("flags")
        """
        tp = self.training_params
        # tp.callbacks is None if no callbacks are configured (this is the
        # default), thus, both sources are normalized to lists before they
        # are concatenated
        callbacks = (list(additional_callbacks or []) +
                     list(tp.callbacks or []))
        time_start_training = time.time() - start_time
        # Set every callback up exactly once (previously a second
        # unconditional setup loop followed the generator construction)
        for cb in callbacks:
            if hasattr(cb, "setup"):
                cb.setup(self)

        assert not tp.balance, "balancing not supported for RL"
        assert all(cb.supports_reinforcement_learning() for cb in callbacks), \
          "At least one callback does not support reinforcement learning"

        def summarize_y_values(_y, precision=5):
            # Summarizes the rounded values as "value:count" pairs or, if
            # there are more than 'precision' different values, as range
            _y = _y.flatten()
            _y = list(zip(*np.unique(
                [int(round(x)) for x in _y], return_counts=True)))
            return ("%i-%i" % (_y[0][0], _y[-1][0]) if len(_y) > precision else
                    ", ".join("%s:%s" % (p, c) for p, c in _y))

        batch_generator = self._rl_get_batch_generator(
            data_queue, tp, max_time,
            buffer_size_factor=replay_buffer_size_factor)

        losses = []
        max_input_values = []
        max_predicted_values = []
        iteration_flags = defaultdict(list)
        time_avg_wait = 0
        time_iter = time.time()
        # Empty batches increase the shift such that they are not counted as
        # iteration
        iteration_shift = -start_epoch

        for iteration_batch, batch in enumerate(batch_generator, 1):
            if batch is None:
                print("Empty data batch")
                iteration_shift += 1
                time.sleep(5)
            else:
                iteration = iteration_batch - iteration_shift

                if batch_generator.has_importance_weights():
                    (x_data, y_data, u_data), sample_weights = batch
                else:
                    x_data, y_data, u_data = batch
                    sample_weights = None
                if data_postprocessor is not None:
                    x_data, y_data, sample_weights = data_postprocessor(
                        x_data, y_data, sample_weights)

                time_train = time.time()
                # HACK: Could not get the losses of the single samples from
                # keras, thus, the batch is additionally predicted. It would
                # be better to use the losses as calculated by keras.
                yp_data = self._model.predict(x_data)
                yp_data = yp_data if self._y_decoder is None else \
                    self._y_decoder(yp_data)
                yp_data = yp_data.reshape(len(y_data))
                if callback_predictions is not None:
                    for cb in callback_predictions:
                        cb(y_data, yp_data)

                max_input_values.append(np.max(y_data))
                max_predicted_values.append(np.max(yp_data))

                y_data = (y_data if self._y_encoder is None else
                          self._y_encoder(y_data))
                # Train exactly one keras epoch, which gets the number of the
                # current iteration
                hist = self._model.fit(
                    x_data, y_data,
                    batch_size=tp.batch,
                    epochs=iteration + 1,
                    verbose=0,
                    sample_weight=sample_weights,
                    callbacks=callbacks,
                    initial_epoch=iteration
                )
                loss = hist.history["loss"][-1]
                losses.append(loss)
                if batch_generator.has_priorities():
                    # NOTE(review): y_data is already encoded here while
                    # yp_data is decoded - confirm that this is intended if
                    # a _y_encoder is defined.
                    batch_generator.update_priorities((y_data - yp_data) ** 2)
                time_train = time.time() - time_train

                # Everything of the iteration which was not training counts
                # as waiting time
                time_iter = time.time() - time_iter
                time_wait = time_iter - time_train
                time_avg_wait = (time_avg_wait * (iteration - 1) + time_wait) / iteration
                time_curr_training = time.time() - time_start_training
                if iteration % tp.epoch_verbosity == 0:
                    input_values = summarize_y_values(y_data)
                    predicted_values = summarize_y_values(yp_data)
                    print("epoch: {iteration} - loss: {loss} - "
                          "inputs: {input_values} - "
                          "predictions: {predicted_values} - "
                          "time(total, train, wait, avg wait, total): "
                          "{time_iter:.2f}s, {time_train:.2f}s, {time_wait:.2f}s,"
                          "{time_avg_wait:.2f}s, {time_curr_training:.2f}s".
                          format(iteration=iteration, loss=loss,
                                 input_values=input_values,
                                 predicted_values=predicted_values,
                                 time_iter=time_iter, time_train=time_train,
                                 time_wait=time_wait,
                                 time_avg_wait=time_avg_wait,
                                 time_curr_training=time_curr_training))

                time_iter = time.time()

                # NOTE(review): the check happens after the training of the
                # iteration, thus, with start_epoch=0 tp.epochs + 1 batches
                # are trained - confirm whether '>=' is intended.
                if iteration > tp.epochs:
                    print("Training epochs finished.")
                    break

            flag_stop_training = False
            if callback_flags is not None:
                all_flags = set(["always"] +
                                ([] if batch is None else ["epoch"]))
                for cb in callbacks:
                    all_flags.update(cb.flags)
                for req_flag, cb in callback_flags:
                    if req_flag in all_flags:
                        try:
                            cb()
                        except StopTraining:
                            flag_stop_training = True
                if batch is not None:  # == epoch in all_flags:
                    for flag in all_flags:
                        if flag not in ["always", "epoch"]:
                            iteration_flags[flag].append(iteration)

            if flag_stop_training:
                print("Training stopped by flag.")
                break

            if max_time != -1 and time.time() - time_start_training > max_time:
                print("Training timeout.")
                break

        total_training_time = time.time() - time_start_training
        print("Total Training Time: {time_start_training:.2f}s".format(
            time_start_training=total_training_time))
        print("Sample stats: %i (generated), %i (trained on)" %
              (batch_generator.samples_load, batch_generator.samples_requested))
        return {
            "loss": losses,
            "max_inputs": max_input_values,
            "max_predicted": max_predicted_values,
            "flags": iteration_flags
        }

    def evaluate(self, data):
        """
        Predict the given data with the model and remember the outcome.

        The outcome is stored as last evaluation and appended to the list
        of all evaluations.

        :param data: SampleBatchData object or list of those objects
        :return: triple (predictions, original y values, similarities of the
                 samples to the training data or None if no similarity
                 measure is defined)
        """
        data = self._convert_data(data)
        x_extractor = self._x_fields_extractor
        y_extractor = self._y_fields_extractor

        # The original y values for the predictions will be added to this
        # list
        remembered_labels = []

        # Triple for comparing the similarities between a sample to predict
        # and the used training data or None if no similarity shall be
        # measured
        similarity_setup = None
        if self._test_similarity is not None:
            # Any training data set serves as example for the field
            # extraction
            reference = next(iter(self._training_data), None)
            assert reference is not None

            set_similarity = similarities.get_wrapper_similarity_on_set(
                self._test_similarity,
                None if x_extractor is None else x_extractor(reference),
                self._x_fields_comparators,
                merge=max, init_measure_value=0,
                early_stopping=lambda x: x == 1
            )
            similarity_setup = ([], self._training_data, set_similarity)

        generator = KerasDataGenerator(
            data,
            x_fields=None if x_extractor is None else x_extractor(data[0]),
            y_fields=None if y_extractor is None else y_extractor(data[0]),
            x_converter=self._x_encoder,
            y_converter=self._y_encoder,
            y_remember=remembered_labels,
            similarity=similarity_setup)

        predictions = self._model.predict_generator(
            generator, max_queue_size=10, workers=1,
            use_multiprocessing=False)

        labels = np.concatenate(remembered_labels)
        decoder = self._y_decoder
        if decoder is not None:
            labels = decoder(labels)
            predictions = decoder(predictions)

        measured_similarities = None
        if similarity_setup is not None:
            measured_similarities = np.concatenate(similarity_setup[0])

        # Drop a superfluous second dimension of size 1
        if len(labels.shape) == 2 and labels.shape[1] == 1:
            labels = labels.squeeze(axis=1)
            predictions = predictions.squeeze(axis=1)

        # HACK: Sometimes, keras takes more batches than it uses
        outcome = (predictions, labels[:len(predictions)],
                   measured_similarities)

        self._evaluation = outcome
        self._evaluations.append(outcome)
        return outcome

    """----------------------DATA PARSING METHODS----------------------------"""
    def _convert_data(self, data):
        """
        Prepare the given data for the usage with this network.

        A single data set is wrapped into a list and every data set which
        is not finalized yet is finalized. For data sets which are already
        finalized a warning is printed. Subclasses which need a different
        conversion can override this method.

        :param data: a single data set or a list of data sets
        :return: list of the data sets
        """
        if not isinstance(data, list):
            data = [data]
        for data_set in data:
            if not data_set.is_finalized:
                data_set.finalize()
            else:
                print("Warning: Data set previously finalized. Skipping now.")
        return data

    """-------------------------ANALYSE PREDICTIONS -------------------------"""
    def _analyse(self, directory, prefix):
        """
        Analyse the trainings and evaluations of this network.

        :param directory: forwarded to static_analyse
        :param prefix: forwarded to static_analyse
        """
        history_records = [record.history for record in self._histories]
        counted_samples = len(self._count_samples_hashes)
        KerasNetwork.static_analyse(
            directory, prefix, self._model, self._evaluations,
            history_records, self._test_similarity, counted_samples)

    @staticmethod
    def analyse_from_paths(trg_directory, prefix, paths):
        """
        Analyse the merged content of previously stored analysis files.

        Of every file the predictions and labels of its last evaluation are
        taken and the sample counts are summed up. The model description is
        taken from the first file.

        NOTE(review): the files need an "evaluations" entry, but the writing
        of this entry is currently commented out in static_analyse - confirm
        which files are expected here.

        :param trg_directory: directory for the analysis output
        :param prefix: prefix for the names of the output files
        :param paths: iterable of paths to JSON analysis files
        """
        history = None
        predictions = []
        labels = []
        count_samples = 0
        model = None
        for path in paths:
            with open(path, "r") as f:
                data = json.load(f)

            if model is None:
                model = data["model"]

            # Keep only the last evaluation of every file
            last_evaluation = data["evaluations"][-1]
            predictions.extend(last_evaluation[0])
            labels.extend(last_evaluation[1])

            # static_analyse stores the placeholder "NA" if no samples were
            # counted. Such entries cannot be added.
            stored_count = data["count_samples"]
            if isinstance(stored_count, (int, float)):
                count_samples += stored_count

        evaluations = [[np.array(predictions), np.array(labels), None]]

        KerasNetwork.static_analyse(
            trg_directory, prefix, model, evaluations, history, None,
            count_samples)

    @staticmethod
    def static_analyse(directory, prefix, model, evaluations, histories,
                       test_similarity, sample_count):
        """
        Analyse the last evaluation and the last training history.

        The MSE and MAE of the evaluation are printed, multiple plots are
        created via the _analyse_from_* helper methods and some meta data is
        written as JSON to '<prefix>analysis.meta' in the given directory.

        :param directory: directory for the output files
        :param prefix: prefix for the names of the output files
        :param model: keras model or a string describing the model
        :param evaluations: list of triples (predictions, true labels,
                            similarities). Only the last triple is analysed
        :param histories: None or list of keras history dictionaries. Only
                          the last one is plotted
        :param test_similarity: if this is true, the deviations depending on
                                the similarities are plotted
        :param sample_count: number of samples seen during the training
        :raises ValueError: if the true labels are neither 1 nor 2
                            dimensional
        """
        pred_value, true_label, sample_similarity = evaluations[-1]
        history = histories[-1] if histories else None

        # The metrics of scikit-learn expect the true values first
        print("Evaluation MSE:", mean_squared_error(true_label, pred_value))
        print("Evaluation MAE:", mean_absolute_error(true_label, pred_value))

        dimensions = len(true_label.shape)
        if dimensions == 2:
            # Use the index of the maximum as value of every sample
            pred_value = np.argmax(pred_value, axis=1)
            true_label = np.argmax(true_label, axis=1)
        elif dimensions != 1:
            # An explicit exception instead of an assertion, because
            # assertions are stripped when Python runs with -O
            raise ValueError(
                "Cannot analyse labels with %i dimensions" % dimensions)

        # Number of samples for which the rounded mean of all labels is the
        # correct label
        mean_true_label = true_label.mean()
        predict_mean_correct = (np.round(mean_true_label) == true_label).sum()

        if history is not None:
            # The names of the accuracy entries depend on the keras version
            accuracy = "acc" if "acc" in history else "accuracy"
            val_accuracy = "val_acc" if "val_acc" in history else "val_accuracy"
            KerasNetwork._analyse_from_history_plot(
                history, [accuracy, val_accuracy], "Model Accuracy", "accuracy",
                "epoch", ['train', 'test'], prefix + "evolution_accuracy",
                directory,
                hline=(float(predict_mean_correct)/true_label.size))

            KerasNetwork._analyse_from_history_plot(
                history, ['loss', 'val_loss'], "Model Loss", "loss", "epoch",
                ['train', 'val'], prefix + "evolution_loss", directory)

        KerasNetwork._analyse_from_predictions_scatter(
            pred_value, true_label, "Predictions",
            "original h", "predicted h", prefix + "predictions_scatter",
            directory)

        KerasNetwork._analyse_from_predictions_scatter_tiles(
            pred_value, true_label,
            "Prediction Probabilities with resp. to the correct prediction",
            "original h", "predicted h", prefix + "predictions_tiles",
            directory)

        KerasNetwork._analyse_from_predictions_deviation(
            pred_value, true_label, "Prediction Deviations",
            "count", "deviation", prefix + "deviations", directory,
            diff_mean_to_prediction=True)

        KerasNetwork._analyse_from_predictions_deviation_dep_on_h(
            pred_value, true_label,
            "Prediction Deviations depending on original", "deviation",
            "original", prefix + "deviations_dep_h", directory,
            diff_mean_to_prediction=True)

        if test_similarity:
            KerasNetwork._analyse_from_predictions_deviation_dep_on_similarity(
                pred_value, true_label, sample_similarity,
                "Prediction Deviations depending on the similarity to training "
                "samples", "deviation", "similarity",
                prefix + "deviations_dep_sim", directory)

        # The domain properties are currently not available, thus, the sizes
        # of the state spaces are unknown and _analyse_misc is not called
        state_space_sizes = None
        reachable_ssss = None

        # The evaluations are currently not stored in the meta data
        analysis_data = {
            "histories": histories,
            "count_samples": sample_count if sample_count else "NA",
            "state_space_size": state_space_sizes,
            "upper_reachable_state_space_bound": reachable_ssss}

        if isinstance(model, str):
            analysis_data["model"] = model
        else:
            # Collect the lines of the model summary and join them once
            summary_lines = []
            model.summary(print_fn=summary_lines.append)
            analysis_data["model"] = "".join(
                line + "\n" for line in summary_lines)

        with open(os.path.join(directory, prefix + "analysis.meta"), "w") as f:
            json.dump(analysis_data, f)

    """----------------------ANALYSIS METHODS--------------------------------"""
    @staticmethod
    def _analyse_misc(samples_seen, state_space_sizes, reachable_ssss,
                      file_basename, directory=".",
                      image_formats=None):
        """Plot the number of seen samples against the state space sizes.

        If `state_space_sizes` is given, a log-scaled bar chart is drawn with
        a green bar for the full state space and, if `reachable_ssss` is
        given, a second one for the reachable state space. A red bar of
        height `samples_seen` is drawn at each of those positions. Every bar
        is annotated with its value; the red bars additionally show the
        ratio of `samples_seen` to the respective state space size.
        If `state_space_sizes` is None, an empty figure is stored.

        :param samples_seen: number of samples seen
        :param state_space_sizes: size of the full state space or None
        :param reachable_ssss: size of the reachable state space or None
        :param file_basename: name of the output file(s) without extension
        :param directory: directory in which the images are stored
        :param image_formats: file extensions to store the figure in
                              (default: MATPLOTLIB_OUTPUT_FORMATS)
        """
        if image_formats is None:
            image_formats = MATPLOTLIB_OUTPUT_FORMATS

        figure = plt.figure()
        if state_space_sizes is not None:
            sizes = [state_space_sizes]
            # The leading None is the label of the tick left of the first bar
            tick_labels = [None, "Full State Space"]
            if reachable_ssss is not None:
                sizes.append(reachable_ssss)
                tick_labels.append("Reachable State Space")

            # First the annotations of the state space bars, then those of
            # the sample bars (same order in which the bars are iterated)
            annotations = [str(size) for size in sizes]
            annotations.extend(
                "%d (%.2f)" % (samples_seen, float(samples_seen)/size)
                for size in sizes)

            axes = figure.add_subplot(1, 1, 1)
            bars_full = axes.bar(np.arange(len(sizes)), sizes,
                                 color='g', align='center')
            bars_seen = axes.bar(np.arange(len(sizes)),
                                 [samples_seen] * len(sizes),
                                 color='r', align='center',
                                 label="Training Samples Count")

            for rect, annotation in zip(bars_full + bars_seen, annotations):
                plt.text(rect.get_x() + rect.get_width() / 2.0,
                         rect.get_height(),
                         annotation, ha='center', va='bottom')

            axes.set_title("Seen Parts of State Spaces of all problems")
            axes.set_ylabel("samples")
            axes.set_yscale("log")
            axes.xaxis.set_major_locator(plt.MultipleLocator(1))
            axes.set_xticklabels(tick_labels)
            axes.legend()

        figure.tight_layout()
        for extension in image_formats:
            target = os.path.join(directory, file_basename + "." + extension)
            figure.savefig(target)
        plt.close(figure)

    @staticmethod
    def _analyse_from_history_plot(history, measures, title, ylabel, xlabel,
                                   legend, file_basename, directory=".",
                                   image_formats=None,
                                   hline=None):
        """Plot the evolution of some measures of a training history.

        Every entry `history[m]` for `m` in `measures` is drawn as one line.
        Optionally a horizontal reference line is added.

        :param history: object indexable by the measure names; each entry is
                        passed to `Axes.plot` (presumably a Keras history
                        dictionary - confirm against the caller)
        :param measures: names of the measures to plot
        :param title: title of the plot
        :param ylabel: label of the y axis
        :param xlabel: label of the x axis
        :param legend: legend entries, one per measure. The given list is
                       NOT modified.
        :param file_basename: name of the output file(s) without extension
        :param directory: directory in which the images are stored
        :param image_formats: file extensions to store the figure in
                              (default: MATPLOTLIB_OUTPUT_FORMATS)
        :param hline: if not None, y value of a horizontal line labelled
                      "Predicting Data Mean"
        """
        image_formats = (MATPLOTLIB_OUTPUT_FORMATS if image_formats is None
                         else image_formats)
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        for m in measures:
            ax.plot(history[m])
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel(xlabel)
        if hline is not None:
            ax.axhline(hline, color=COLOR_DATA_MEAN, alpha=ALPHA_DATA_MEAN)
            # Build a new list instead of appending in place: the caller's
            # list must not grow with every call.
            legend = list(legend) + ["Predicting Data Mean"]
        ax.legend(legend, loc='upper left')
        fig.tight_layout()
        for image_format in image_formats:
            fig.savefig(os.path.join(directory,
                                     file_basename + "." + image_format))
        plt.close(fig)

    @staticmethod
    def _analyse_from_predictions_scatter(
            predicted, original, title, pred_label, orig_label, file_basename,
            directory=".",
            image_formats=None):
        """Store a scatter plot of the predicted over the original values.

        The original values are drawn on the x axis, the predictions on the
        y axis.

        :param predicted: predicted values
        :param original: original values (same length as `predicted`)
        :param title: title of the plot
        :param pred_label: label of the y axis (the predictions)
        :param orig_label: label of the x axis (the original values)
        :param file_basename: name of the output file(s) without extension
        :param directory: directory in which the images are stored
        :param image_formats: file extensions to store the figure in
                              (default: MATPLOTLIB_OUTPUT_FORMATS)
        """
        image_formats = (MATPLOTLIB_OUTPUT_FORMATS if image_formats is None
                         else image_formats)
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.scatter(original, predicted, s=80, c='maroon', alpha=0.1)

        # If all predictions are identical, blank every y tick label except
        # the middle one.
        if len(np.unique(predicted)) == 1:
            yticks = ax.get_yticks().tolist()
            pivot = int(len(yticks)/2)
            mid = yticks[pivot]
            yticks = ["" for _ in yticks]
            yticks[pivot] = mid
            ax.set_yticklabels(yticks)

        ax.set_xlabel(orig_label)
        ax.set_ylabel(pred_label)
        ax.set_title(title)
        fig.tight_layout()
        for image_format in image_formats:
            fig.savefig(os.path.join(directory,
                                     file_basename + "." + image_format))
        plt.close(fig)

    @staticmethod
    def _analyse_from_predictions_scatter_tiles(
            predicted, original, title, pred_label, orig_label, file_basename,
            directory=".",
            image_formats=None):
        """Store a heat map of the binned predictions per original value.

        The map has one row per integer between the smallest and the largest
        original value and one column per prediction bin (bin width 0.5).
        A cell holds the fraction of the samples with that original value
        whose prediction was rounded into that bin. Rows of original values
        without any sample are NaN. The original values are used in `range`
        and as array shape, hence they have to be integers.

        NOTE(review): the columns (x axis) hold the prediction bins, but the
        x axis is labelled with `orig_label` and the y axis with
        `pred_label` - confirm that callers pass the labels accordingly.

        :param predicted: predicted values
        :param original: original (integer) values
        :param title: title of the plot
        :param pred_label: label set on the y axis
        :param orig_label: label set on the x axis
        :param file_basename: name of the output file(s) without extension
        :param directory: directory in which the images are stored
        :param image_formats: file extensions to store the figure in
                              (default: MATPLOTLIB_OUTPUT_FORMATS)
        """
        if image_formats is None:
            image_formats = MATPLOTLIB_OUTPUT_FORMATS

        min_o_h, max_o_h = min(original), max(original)
        min_p_h = math.floor(min(predicted))
        max_p_h = math.ceil(max(predicted))

        # Group the predictions by the original value of their sample
        by_h = {}
        for idx in range(len(predicted)):
            by_h.setdefault(original[idx], []).append(predicted[idx])

        # Predictions are rounded to multiples of 1/power
        exponent = 1
        power = 2**exponent
        h_p_bins = [float(i)/power
                    for i in range(int(min_p_h*power), int(max_p_h*power + 1))]

        # Start with NaN everywhere; rows with samples are overwritten below
        tiles = np.full((max_o_h - min_o_h + 1, len(h_p_bins)), float("NaN"),
                        dtype=float)
        for row, h_o in enumerate(range(min_o_h, max_o_h + 1)):
            if h_o not in by_h:
                continue
            rounded = np.around(np.array(by_h[h_o]) * power) / power
            total = float(len(rounded))
            values, counts = np.unique(rounded, return_counts=True)
            # Convert the numpy values to floats, otherwise they do not hash
            # as needed
            occurrences = {float(v): c for v, c in zip(values, counts)}
            for col, h_p in enumerate(h_p_bins):
                if h_p in occurrences:
                    tiles[row, col] = occurrences[h_p] / total
                else:
                    tiles[row, col] = 0

        n_rows, n_cols = tiles.shape
        figure = plt.figure()
        axes = figure.add_subplot(1, 1, 1)
        image = axes.matshow(tiles, cmap=COLORMAP,
                             aspect=n_cols / float(n_rows),
                             vmin=0.0, vmax=1.0)
        axes.set_xlabel(orig_label)
        axes.set_ylabel(pred_label)
        axes.set_title(title)
        colorbar_axes = figure.add_axes([0.9, 0.1, 0.02, 0.8])
        figure.colorbar(image, colorbar_axes, orientation='vertical')

        # Replace the tick indices by the values they stand for; the first
        # and last tick are left blank
        inner_xticks = axes.get_xticks().tolist()[1:-1]
        axes.set_xticklabels(
            [''] + [h_p_bins[int(tick)] for tick in inner_xticks] + [''])
        inner_yticks = axes.get_yticks().tolist()[1:-1]
        axes.set_yticklabels(
            [''] + [min_o_h + tick for tick in inner_yticks] + [''])

        figure.tight_layout()
        for extension in image_formats:
            target = os.path.join(directory, file_basename + "." + extension)
            figure.savefig(target)
        plt.close(figure)

    @staticmethod
    def _analyse_from_predictions_deviation(
            predicted, original, title, ylabel, xlabel, file_basename,
            directory=".",
            image_formats=None,
            diff_mean_to_prediction=False):
        """Store a bar chart counting the rounded prediction deviations.

        The deviation of a sample is `predicted - original`, rounded to the
        nearest integer. Optionally the deviations of always predicting the
        data mean (`original.mean() - original`) are drawn as a second
        series next to them.

        :param predicted: numpy array of the predicted values
        :param original: numpy array of the original values
        :param title: title of the plot
        :param ylabel: label of the y axis
        :param xlabel: label of the x axis
        :param file_basename: name of the output file(s) without extension
        :param directory: directory in which the images are stored
        :param image_formats: file extensions to store the figure in
                              (default: MATPLOTLIB_OUTPUT_FORMATS)
        :param diff_mean_to_prediction: if True, additionally plot the
                                        deviations of the data mean
        """
        image_formats = (MATPLOTLIB_OUTPUT_FORMATS if image_formats is None
                         else image_formats)
        # (Legend, Color, Alpha, Data)
        deviations = [('Deviation of Predictions', 'b', 1,
                       predicted - original)]
        if diff_mean_to_prediction:
            mean = original.mean()
            deviations.insert(0, ("Deviation of Data Mean", COLOR_DATA_MEAN,
                                  ALPHA_DATA_MEAN, mean - original))
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)

        width = 0.4
        for iteration, (label, color, alpha, data) in enumerate(deviations):
            # The builtin float replaces the alias np.float, which was
            # removed from NumPy (AttributeError in current releases).
            dev = np.round(data.astype(float))
            unique, counts = np.unique(dev, return_counts=True)
            # Otherwise they are numpy values and do not hash as needed
            unique = [float(i) for i in unique]
            occurrences = dict(zip(unique, counts))
            min_d, max_d = min(unique), max(unique)

            # Shift the bars such that the series stand side by side and
            # are centered as a group around the integer deviation
            bars = (np.arange(min_d, max_d + 1) -
                    width*len(deviations)/2 +
                    iteration * width +
                    width/2)
            heights = [occurrences.get(i, 0)
                       for i in range(int(math.floor(min_d)),
                                      int(math.ceil(max_d) + 1))]
            ax.bar(bars, heights, width=width,
                   color=color, alpha=alpha, align='center', label=label)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if len(deviations) > 1:
            ax.legend()
        fig.tight_layout()
        for image_format in image_formats:
            fig.savefig(os.path.join(directory,
                                     file_basename + "." + image_format))
        plt.close(fig)

    @staticmethod
    def _analyse_from_predictions_deviation_dep_on_h(
            predicted, original, title, pred_label, orig_label,
            file_basename, directory=".",
            image_formats=None,
            diff_mean_to_prediction=False):
        """Store box plots of the deviations grouped by the original value.

        For every integer between the smallest and the largest original
        value one box is drawn over the deviations (`predicted - original`)
        of the samples with that original value. Values without samples get
        an empty box. The original values are used in `range`, hence they
        have to be integers.

        :param predicted: predicted values
        :param original: numpy array of the original (integer) values
        :param title: title of the plot
        :param pred_label: label of the y axis
        :param orig_label: label of the x axis
        :param file_basename: name of the output file(s) without extension
        :param directory: directory in which the images are stored
        :param image_formats: file extensions to store the figure in
                              (default: MATPLOTLIB_OUTPUT_FORMATS)
        :param diff_mean_to_prediction: if True, additionally mark for every
                                        original value the deviation of the
                                        data mean (`original.mean() - value`)
        """
        image_formats = (MATPLOTLIB_OUTPUT_FORMATS if image_formats is None
                         else image_formats)
        min_h = min(original)
        max_h = max(original)
        by_h = {}
        for idx in range(len(original)):
            h = original[idx]
            by_h.setdefault(h, []).append(predicted[idx] - h)

        new_x_ticks = []
        data = []
        means = []
        mean = original.mean() if diff_mean_to_prediction else None
        for i in range(min_h, max_h + 1):
            if diff_mean_to_prediction:
                means.append(mean - i)
            if i in by_h:
                new_x_ticks.append("%d\n$n=%d$" % (i, len(by_h[i])))
                data.append(by_h[i])
            else:
                new_x_ticks.append("%d" % i)
                data.append([float('nan')])
        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        ax.boxplot(data)
        ax.set_xticklabels(new_x_ticks)
        if diff_mean_to_prediction:
            # The boxes are placed at the positions 1, 2, ..., len(data)
            ax.scatter(np.arange(len(data)) + 1, means, color='r', alpha=0.6,
                       label="Deviation to Data Mean")
        ax.set_xlabel(orig_label)
        ax.set_ylabel(pred_label)
        ax.set_title(title)
        if diff_mean_to_prediction:
            # Only the mean markers carry a label. Without them a legend
            # would be empty and matplotlib would emit a warning.
            ax.legend()
        fig.tight_layout()
        for image_format in image_formats:
            fig.savefig(os.path.join(directory,
                                     file_basename + "." + image_format))
        plt.close(fig)

    @staticmethod
    def _analyse_from_predictions_deviation_dep_on_similarity(
            predicted, original, similarity,
            title, pred_label, orig_label,
            file_basename, directory=".",
            image_formats=None,
            steps=10, precision="%.2f"):
        """Store box plots of the deviations grouped by sample similarity.

        The range between the smallest and the largest similarity is split
        into `steps` equally wide bins. For every non-empty bin one box is
        drawn over the deviations (`original - predicted`) of its samples.

        NOTE(review): the deviation is `original - predicted` here, whereas
        the other deviation plots of this class use `predicted - original`.
        Confirm which sign is intended.

        :param predicted: predicted values
        :param original: original values
        :param similarity: similarity value of every sample
        :param title: title of the plot
        :param pred_label: label of the y axis
        :param orig_label: label of the x axis
        :param file_basename: name of the output file(s) without extension
        :param directory: directory in which the images are stored
        :param image_formats: file extensions to store the figure in
                              (default: MATPLOTLIB_OUTPUT_FORMATS)
        :param steps: number of similarity bins
        :param precision: %-format used to print the bin boundaries
        """
        image_formats = (MATPLOTLIB_OUTPUT_FORMATS if image_formats is None
                         else image_formats)
        min_sim, max_sim = min(similarity), max(similarity)
        step_sim = (max_sim-min_sim)/float(steps)

        def get_bin(x):
            # If all similarities are identical, the bins have zero width
            # and the division below is undefined: use the first bin.
            if step_sim == 0:
                return 0
            # Clamp, such that the maximum falls into the last bin
            return max(0, min(steps - 1, int((x - min_sim) / step_sim)))

        by_sim = {}
        for idx in range(len(original)):
            d = original[idx] - predicted[idx]
            by_sim.setdefault(get_bin(similarity[idx]), []).append(d)

        fig = plt.figure()
        ax = fig.add_subplot(1, 1, 1)
        new_x_ticks = []
        data = []
        for i in range(steps):
            if i in by_sim:
                # Half-open intervals, only the last one is closed
                sim_range = ("[%s,%s" % (precision, precision)
                             + ("]" if i == steps - 1 else "["))
                sim_range = sim_range % (min_sim + i * step_sim,
                                         min_sim + (i + 1) * step_sim)

                new_x_ticks.append("%s\n$n=%d$" % (sim_range, len(by_sim[i])))
                data.append(by_sim[i])

        ax.boxplot(data)
        ax.set_xticklabels(new_x_ticks)

        ax.set_xlabel(orig_label)
        ax.set_ylabel(pred_label)
        ax.set_title(title)
        fig.tight_layout()
        for image_format in image_formats:
            fig.savefig(os.path.join(
                directory, file_basename + "." + image_format))
        plt.close(fig)

    """----------------------LAYER HELP METHODS------------------------------"""

    @staticmethod
    def next_dense(prev, neurons, activation=None, dropout=None,
                   kernel_regularizer=None, batch_normalization=None):
        """Append a dense layer with its optional companion layers to `prev`.

        The layers are applied in the order: Dense, BatchNormalization (if
        BN_PRE_ACTIVATION), Activation (if given), BatchNormalization (if
        BN_POST_ACTIVATION), Dropout (if given).

        :param prev: output of the previous layer
        :param neurons: number of units of the dense layer
        :param activation: activation passed to the Activation layer or None
        :param dropout: rate passed to the Dropout layer or None
        :param kernel_regularizer: kernel regularizer of the dense layer
        :param batch_normalization: None, BN_OFF, BN_PRE_ACTIVATION or
                                    BN_POST_ACTIVATION
        :return: output of the last layer added
        """
        assert batch_normalization in (
            None, BN_OFF, BN_PRE_ACTIVATION,
            BN_POST_ACTIVATION), batch_normalization

        dense = layers.Dense(neurons, kernel_regularizer=kernel_regularizer)
        output = dense(prev)

        if batch_normalization == BN_PRE_ACTIVATION:
            output = layers.BatchNormalization()(output)

        if activation is not None:
            output = layers.core.Activation(activation)(output)

        if batch_normalization == BN_POST_ACTIVATION:
            output = layers.BatchNormalization()(output)

        if dropout is not None:
            output = layers.Dropout(dropout)(output)

        return output

    @staticmethod
    def calculate_hidden_layer_size(size, input_units):
        """Calculate the number of units of a hidden layer.

        :param size: if positive, the layer size itself (rounded to the
                     nearest integer). Otherwise its absolute value is a
                     factor which is multiplied with `input_units`.
        :param input_units: number of input units; required unless `size`
                            is positive
        :return: layer size as positive int
        :raises AssertionError: if `input_units` is needed but None, or if
                                the resulting size is not positive
        """
        # size == 0 also takes the relative branch and thus needs input_units
        assert size > 0 or input_units is not None, size
        x = int(round(size if size > 0 else (abs(size) * input_units)))
        assert x > 0
        return x

    """-------------------------OTHER METHODS--------------------------------"""

    @staticmethod
    def parse(tree, item_cache):
        """Look up a previously defined KerasNetwork.

        :param tree: parse tree, forwarded to `parser.try_lookup_obj`
        :param item_cache: item cache, forwarded to `parser.try_lookup_obj`
        :return: the object found by the look up
        :raises ArgumentException: if the look up returns None
        """
        found = parser.try_lookup_obj(tree, item_cache, KerasNetwork, None)
        if found is None:
            raise ArgumentException("The definition of KerasNetwork can "
                                    "only be used for look up of any previously"
                                    " defined KerasNetwork via "
                                    "'keras_networks(id=ID)'")
        return found


# Add KerasNetwork to the main register with the key "keras_networks"
# (presumably the name referenced by 'keras_networks(id=ID)' look-ups).
main_register.append_register(KerasNetwork, "keras_networks")
