from .keras_aux_networks import KerasFieldsNetwork
from .keras_layers import AdjacencyLayer #, Stack
from .keras_network import KerasNetwork
from .keras_network import BN_OFF, BN_PRE_ACTIVATION, BN_POST_ACTIVATION
from .keras_residual_block import KResidualBlock

from ... import parser_tools as parset
from ... import parser
from ... import main_register

from ...misc import hasher, similarities

from ....translate.pddl import Atom, TypedObject

import collections
import itertools
import keras
import numpy as np
from tensorflow.python.framework.ops import Tensor as TFTensor


def clip(value, min_value, max_value):
    """Bound ``value`` to the interval [min_value, max_value].

    The upper bound is applied first and the lower bound second, so if the
    bounds are inverted (``min_value > max_value``) the result is
    ``min_value``.
    """
    bounded_above = max_value if max_value < value else value
    return bounded_above if bounded_above > min_value else min_value


def parse_hidden_layer_size(tree, item_cache):
    try:
        return float(tree.data[0])
    except ValueError:
        assert tree.data[0] in parset.main_register.get_register(
            KResidualBlock), tree.data[0]
        return parset.main_register.get_register(KResidualBlock).get_reference(
            tree.data[0]).parse(tree, item_cache)


class KerasMLP(KerasFieldsNetwork):
    """Fully connected Keras network over one or more input fields.

    Every input field gets its own ``Input`` layer of ``state_size`` units.
    With several fields the inputs are concatenated, then passed through the
    configured hidden layers (dense layers or residual blocks). Exactly one
    output field is predicted:

    * ``output_units == -1``: regression with a single ReLU output unit,
    * ``output_units > 0``: classification, one-hot encoded with a softmax
      output or ordinal encoded with a sigmoid output,
    * ``output_units == -2``: classification whose number of output units is
      determined on initialization (``output_size`` argument or the largest
      label found in ``data``).
    """
    arguments = parset.ClassArguments(
        'KerasMLP',
        KerasFieldsNetwork.arguments,
        ("hidden_layer_size", False, None, parse_hidden_layer_size,
         "List of neurons to put into the hidden layers OR definition of "
         "a residual layer. Each list entry causes "
         "a new hidden layer. An entry with a positive values creates a layer "
         "with the entries number of neurons. For a negative value "
         "abs(value)*#input_units neurons are in the layer. If no list, but an "
         "integer is given, this is interpreted as list with the integer as "
         "value."),
        ("output_units", False, True, int,
         "Classification network with output_units output units or regression "
         "network if -1. If -2 is set, then the output units have to be set "
         "in the initialization routine via the parameter 'output_size'."),
        ("ordinal_classification", True, False, parser.convert_bool,
         "Performs ordinal classification (if classification network)"),
        ("ordinal_classification_threshold", True, 0.01, float,
         "Threshold for interpreting prediction as 1."),
        ("bin_size", True, 1, int,
         "Puts the output values into bins of this size (divides their value "
         "by bin size and converts back to int)"),
        ("activation", True, "sigmoid", str,
         "Activation function of hidden layers"),
        ("batch_normalization", True, None, int,
         "0 = no batch normalization, %i batch normalization pre activation, "
         "%i batch normalization post activation" %
         (BN_PRE_ACTIVATION, BN_POST_ACTIVATION)),
        ("dropout", True, None, float,
         "Dropout probability or None if no dropout"),
        ("l2", True, None, float, "L2 regularization weight"),
        ("domain_properties", True, None, None,
         "DomainProperties object. Currently this cannot be parsed from the "
         "command line. Either construct the network yourself or let your code "
         "after construct set the DomainProperty object."),
        order=["x_fields", "y_fields", "hidden_layer_size",
               "output_units", "ordinal_classification",
               "ordinal_classification_threshold", "bin_size",
               "activation", "batch_normalization", "dropout", "l2",
               "tparams",
               "load", "store", "learner_formats", "out", "domain_properties",
               "count_samples", "test_similarity",
               "graphdef", "variables", "id"]
    )
    output_units = property(lambda self: self._output_units)
    ordinal_classification = property(lambda self: self._ordinal_classification)
    ordinal_classification_threshold = property(
        lambda self: self._ordinal_classification_threshold)
    bin_size = property(lambda self: self._bin_size)

    def __init__(self, x_fields, y_fields, hidden_layer_size=None,
                 output_units=-1, ordinal_classification=False,
                 ordinal_classification_threshold=0.01, bin_size=1,
                 activation="sigmoid", batch_normalization=None,
                 dropout=None, l2=None, tparams=None,
                 load=None, store=None, learner_formats=None, out=".",
                 domain_properties=None,
                 count_samples=False, test_similarity=None, graphdef=None,
                 variables=None, id=None):
        """Store the network configuration.

        The Keras model itself is built later in ``_initialize_model``.

        ``hidden_layer_size`` may be a list of layer definitions, a single
        layer size, or a single ``KResidualBlock``. A single value is wrapped
        into a one-element list. ``None`` means no hidden layer. The
        remaining arguments are described in ``arguments``.
        """
        KerasFieldsNetwork.__init__(
            self, x_fields, y_fields, tparams, load, store, learner_formats,
            out, count_samples, test_similarity, graphdef, variables, id)

        assert len(self._x_field_names) > 0, \
            "Input Field Count: %i" % len(self._x_field_names)
        # we predict heuristic only
        assert len(self._y_field_names) == 1, \
            "Output Field Count: %i" % len(self._y_field_names)

        # Normalize to an iterable of layer definitions. The parser returns
        # single sizes as float, and a lone residual block or None cannot be
        # iterated either.
        if hidden_layer_size is None:
            hidden_layer_size = []
        elif isinstance(hidden_layer_size,
                        (int, float, np.number, KResidualBlock)):
            hidden_layer_size = [hidden_layer_size]
        self._hidden_layer_size = hidden_layer_size
        assert all([hls != 0 for hls in self._hidden_layer_size])

        self._output_units = output_units
        self._ordinal_classification = ordinal_classification
        self._ordinal_classification_threshold = \
            ordinal_classification_threshold
        assert not self._ordinal_classification or self._output_units != -1
        self._bin_size = bin_size
        assert self._bin_size > 0 and self._bin_size == int(self._bin_size)
        self._activation = activation
        assert batch_normalization in [None, BN_OFF, BN_PRE_ACTIVATION,
                                       BN_POST_ACTIVATION]
        self._batch_normalization = batch_normalization
        # A dropout of 0 is treated like "no dropout".
        self._dropout = None if dropout == 0 else dropout
        assert self._dropout is None or 0 < self._dropout <= 1.0
        assert l2 is None or l2 >= 0
        self._l2 = (None if (l2 is None or l2 == 0)
                    else keras.regularizers.l2(l2))

        # One input array per x field: column i of x is stacked into an array.
        self._x_encoder = lambda x: [np.stack(x[:, i], axis=0)
                                     for i in range(len(self._x_field_names))]
        self._x_fields_comparators = [
            similarities.hamming_measure_cmp_iterable_equal
            for _ in range(len(self._x_field_names))
        ]

        self._count_samples_hasher = lambda x, y: hasher.tuplify(x, y)
        # Either self._domain_properties will be used to determine the state
        # size or on initialization the state size has to be given
        # If both is given, the DomainProperties will be preferred
        self._state_size = None

        for metric in ["accuracy", "mean_absolute_error"]:
            if metric not in self.training_params.metrics:
                self.training_params.metrics.append(metric)

        self._domain_properties = domain_properties

    def _get_final_activation(self):
        """Return the output layer activation.

        The activation is relu for regression, sigmoid for ordinal
        classification and softmax for one-hot classification.
        """
        return ("relu" if self._output_units == -1 else
                ("sigmoid" if self._ordinal_classification else "softmax"))

    def __get_state_size(self, state_size, data):
        """Return ``state_size`` if given, else derive it from ``data``.

        Derivation uses the x field entry of the first sample of ``data``.
        Returns None if neither is available (e.g. when loading a network).
        """
        if state_size is not None:
            return state_size
        else:
            if data is not None and len(data) > 0:
                for key in data[0].data:
                    return len(data[0].data[key][0][0][
                                   self._x_fields_extractor(data[0])[0]])
        # It might happen that the output number is not defined...
        # e.g. when loading a network
        return None

    def __get_output_units(self, output_size, data):
        """Resolve the number of output units.

        Returns the configured ``output_units`` unless it is -2. In that case
        ``output_size`` is used if given. Otherwise the number is derived
        from the largest label in ``data`` after binning.

        Raises:
            ValueError: ``output_units`` is -2, but neither ``output_size``
                nor ``data`` is given.
        """
        if self._output_units != -2:
            return self._output_units
        if output_size is not None:
            return output_size
        if data is None:
            raise ValueError(
                "The number of output units is determined on initialization. "
                "Provide 'output_size'=<number of output units> or "
                "'data'=<SampleBatchData object which could be used>.")

        assert len(self._y_field_names) == 1, \
            "currently nothing else supported"

        def update(entry):
            update.maximum = max(update.maximum, entry[fields[0]])
        update.maximum = -1
        for d in data:
            fields = self._y_fields_extractor(d)
            d.over_all(update)

        return int(update.maximum/self._bin_size) + 1

    @staticmethod
    def _first_y_column(y):
        """Return the labels in ``y``.

        This is ``y`` itself if it is 1D, or its first column if it is 2D.

        Raises:
            ValueError: ``y`` has neither one nor two dimensions.
        """
        if len(y.shape) == 1:
            return y
        if len(y.shape) == 2:
            return y[:, 0]
        raise ValueError("Unsupported label shape: %s" % (y.shape,))

    def _encode_y_one_hot(self, y):
        """One-hot encode the labels of ``y``.

        Labels are truncated to int and clipped to [0, output_units - 1].
        """
        labels = self._first_y_column(y).astype(int).clip(
            0, self._output_units - 1)
        yy = np.zeros((y.shape[0], self._output_units))
        yy[np.arange(y.shape[0]), labels] = 1
        return yy

    @staticmethod
    def _decode_y_one_hot(y):
        """Return the index of the largest output (per row if ``y`` is 2D)."""
        return np.argmax(y) if len(y.shape) == 1 else np.argmax(y, axis=1)

    def _encode_y_ordinal(self, y):
        """Ordinal encode the labels of ``y``.

        Class k sets output units 0..k to 1. Labels are truncated to int and
        clipped to [0, output_units - 1].
        """
        labels = self._first_y_column(y).astype(int).clip(
            0, self._output_units - 1)
        # Row i gets ones in every column j <= labels[i].
        return (np.arange(self._output_units)[np.newaxis, :]
                <= labels[:, np.newaxis]).astype(float)

    def _decode_y_ordinal(self, y):
        """Decode ordinal predictions (inverse of ``_encode_y_ordinal``).

        A unit counts as active if it exceeds the ordinal classification
        threshold. The class is the number of leading active units minus
        one. If no unit is active, the result is class 0 (the lowest class).
        Works on a single prediction (1D) or row-wise (2D).
        """
        active = y > self._ordinal_classification_threshold
        # argmin finds the first inactive unit == number of leading active
        # units, but it also returns 0 if *all* units are active.
        leading = np.argmin(active, axis=-1)
        leading = np.where(np.all(active, axis=-1), active.shape[-1], leading)
        # Class k activates units 0..k. Clip at 0 for predictions without any
        # active unit (previously decoded as the highest class).
        return np.maximum(leading - 1, 0)

    def _initialize_general(self, *args, **kwargs):
        """Resolve state size and output units, then select the y coders.

        Optional keyword arguments: ``state_size``, ``data`` and
        ``output_size``.
        """
        arg_state_size = kwargs.pop("state_size", None)
        arg_data = kwargs.pop("data", None)
        arg_output_size = kwargs.pop("output_size", None)
        self._state_size = self.__get_state_size(arg_state_size, arg_data)
        self._output_units = self.__get_output_units(arg_output_size, arg_data)

        regression = self._output_units == -1

        print("Keras MLP State Size: %s" % str(self._state_size))
        print("Keras MLP Output Units: %i" % (
            1 if regression else self._output_units))

        if regression:
            self._y_encoder = None
            self._y_decoder = None
        else:
            assert len(self._y_field_names) == 1
            if self._ordinal_classification:
                self._y_encoder = self._encode_y_ordinal
                self._y_decoder = self._decode_y_ordinal
            else:
                self._y_encoder = self._encode_y_one_hot
                self._y_decoder = self._decode_y_one_hot

        if self._bin_size != 1:
            # Binning is not supported yet: there is no matching decoder. The
            # encoder below is only installed when asserts are disabled.
            assert False, "y_decoder missing"
            previous_y_encoder = self._y_encoder

            def new_y_encoder(y):
                y = (y / self._bin_size).astype(int)
                return (y if previous_y_encoder is None else
                        previous_y_encoder(y))
            self._y_encoder = new_y_encoder

    def _construct_model(
            self, input_units, output_units, regression):
        """Build the (uncompiled) Keras model.

        Args:
            input_units: Number of units per input field.
            output_units: Number of units of the output layer.
            regression: Whether this is a regression network (unused here).

        Returns:
            The ``keras.Model``.
        """
        total_input_units = input_units * len(self._x_field_names)

        ins = [keras.layers.Input(shape=(input_units,))
               for _ in range(len(self._x_field_names))]
        # concatenate requires at least two tensors
        next_layer = (ins[0] if len(ins) == 1
                      else keras.layers.concatenate(ins, axis=-1))

        for hidden_layer_size in self._hidden_layer_size:
            if isinstance(hidden_layer_size, KResidualBlock):
                next_layer = hidden_layer_size(
                    activation=self._activation,
                    batch_normalization=self._batch_normalization,
                    dropout=self._dropout,
                    kernel_regularizer=self._l2,
                    input_size=total_input_units,
                )(next_layer)
            else:
                hidden_layer_size = KerasNetwork.calculate_hidden_layer_size(
                    hidden_layer_size, total_input_units)

                next_layer = KerasNetwork.next_dense(
                    next_layer, hidden_layer_size, self._activation,
                    self._dropout,
                    kernel_regularizer=self._l2,
                    batch_normalization=self._batch_normalization)

        next_layer = KerasNetwork.next_dense(
            next_layer, output_units, self._get_final_activation(), None,
            kernel_regularizer=self._l2,
            batch_normalization=None)
        return keras.Model(inputs=ins, outputs=next_layer)

    def _initialize_model(self, *args, **kwargs):
        """Build and compile a fresh model.

        Raises:
            ValueError: The state size is unknown or the number of output
                units is invalid.
        """
        if self._state_size is None:
            raise ValueError(
                "This networks needs to know the size of the states fed as "
                "input. Provide 'state_size'=<state size> for the "
                "initialization of the network or provide "
                "'data'=<SampleBatchData object which could be used>.")
        if self._output_units < -1 or self._output_units == 0:
            raise ValueError("The number of output units has to be positive or "
                             "-1 (for regression).")

        regression = self._output_units == -1
        self._model = self._construct_model(
            self._state_size,
            1 if regression else self._output_units,
            regression)
        self._compile()

    def _load(self, path, learner_format, *args, **kwargs):
        """Load a stored model.

        Afterwards the state size and output units are recovered from the
        model's input and output tensor shapes.
        """
        KerasFieldsNetwork._load(self, path, learner_format, *args, **kwargs)

        all_input_tensors = ([self._model.input]
                             if isinstance(self._model.input, TFTensor) else
                             self._model.input)
        # Every input must be (batch, state_size) with the same state size.
        assert all(all([
            len(tensor.shape) == 2,
            tensor.shape[0].value is None,
            tensor.shape[1].value == all_input_tensors[0].shape[1].value])
                   for tensor in all_input_tensors)

        self._state_size = all_input_tensors[0].shape[1].value
        self._output_units = self._model.output.shape[1].value

    def reinitialize(self, *args, **kwargs):
        """Clear the Keras session, then load or build the model.

        The stored model is loaded if a load path is set and
        ``skip_loading`` is not requested. Otherwise a fresh model is built.
        """
        skip_loading = kwargs.get("skip_loading", False)
        keras.backend.clear_session()
        if self.path_load is not None and not skip_loading:
            self.load(**kwargs)
        else:
            self._initialize_model(*args, **kwargs)

    def _finalize(self, *args, **kwargs):
        """Nothing to clean up."""
        pass

    @staticmethod
    def parse(tree, item_cache):
        """Parse a ``KerasMLP`` from a parse tree."""
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasMLP)


# Register KerasMLP in the main register under the key "keras_mlp"
# (presumably the name used to select it in parsed definitions).
main_register.append_register(KerasMLP, "keras_mlp")


class KerasAdaptiveMLP(KerasMLP):
    """MLP whose hidden layer sizes shrink by a constant step.

    ``hidden`` dense layers are built, followed by the given residual
    blocks. Starting at the total number of input units, every layer has
    ``step`` units fewer. ``step`` is chosen so that the sizes approach the
    (pseudo) output units.
    """
    arguments = parset.ClassArguments(
        'KerasAdaptiveMLP',
        KerasMLP.arguments,
        ("hidden", False, None, int, "Number of hidden layers"),
        ("residual_layers", True, None,
         parset.main_register.get_register(KResidualBlock),
         "Single or list of residual blocks to add after the hidden layers."),
        ("pseudo_output_units", True, -1, int,
         "builds the hidden layer for the pseudo output units if given a "
         "positive value."),
        delete=["hidden_layer_size"],
        order=["hidden",
               "x_fields", "y_fields",
               "residual_layers",
               "output_units", "pseudo_output_units", "ordinal_classification",
               "ordinal_classification_threshold",
               "bin_size",
               "activation", "batch_normalization", "dropout", "l2",
               "tparams",
               "load", "store", "learner_formats", "out", "domain_properties",
               "count_samples", "test_similarity",
               "graphdef", "variables", "id"]
        )

    def __init__(self, hidden, x_fields, y_fields,
                 residual_layers=None,
                 output_units=-1, pseudo_output_units=-1,
                 ordinal_classification=False,
                 ordinal_classification_threshold=0.01, bin_size=1,
                 activation="sigmoid", batch_normalization=None,
                 dropout=None, l2=None, tparams=None,
                 load=None, store=None, learner_formats=None, out=".",
                 domain_properties=None,
                 count_samples=False, test_similarity=None, graphdef=None,
                 variables=None, id=None):
        """Store the configuration.

        ``residual_layers`` may be None, a single residual block or a list
        of blocks. ``pseudo_output_units == -1`` means the real output
        units are used to compute the layer sizes.
        """
        # hidden_layer_size is unused by this subclass; -1 is a placeholder.
        KerasMLP.__init__(
            self, x_fields, y_fields, -1,
            output_units, ordinal_classification,
            ordinal_classification_threshold, bin_size, activation,
            batch_normalization, dropout, l2,
            tparams, load, store, learner_formats,
            out, domain_properties, count_samples, test_similarity, graphdef,
            variables, id)

        self._hidden = hidden
        if residual_layers is None:
            self._residual_layers = []
        elif isinstance(residual_layers, list):
            self._residual_layers = residual_layers
        else:
            self._residual_layers = [residual_layers]
        self._pseudo_output_units = pseudo_output_units

    def _construct_model(
            self, input_units, output_units, regression):
        """Build the Keras model with shrinking hidden layers.

        Args:
            input_units: Number of units per input field.
            output_units: Number of units of the output layer.
            regression: Whether this is a regression network (unused here).

        Returns:
            The (uncompiled) ``keras.Model``.
        """
        pseudo_output_units = (output_units
                               if self._pseudo_output_units == -1 else
                               self._pseudo_output_units)
        total_input_units = input_units * len(self._x_field_names)
        unit_diff = total_input_units - pseudo_output_units
        nb_layers = self._hidden + 1 + len(self._residual_layers)
        step = int(unit_diff / nb_layers)
        units = total_input_units

        ins = [keras.layers.Input(shape=(input_units,))
               for _ in range(len(self._x_field_names))]
        # keras.layers.concatenate rejects fewer than two tensors, so a
        # single input field is used directly.
        next_layer = (ins[0] if len(ins) == 1
                      else keras.layers.concatenate(ins, axis=-1))
        for _ in range(self._hidden):
            units -= step
            next_layer = KerasNetwork.next_dense(
                next_layer, units, self._activation, self._dropout,
                kernel_regularizer=self._l2,
                batch_normalization=self._batch_normalization)

        for residual_layer in self._residual_layers:
            units -= step
            next_layer = residual_layer(
                hidden_layer_size=units,
                activation=self._activation,
                batch_normalization=self._batch_normalization,
                dropout=self._dropout,
                kernel_regularizer=self._l2,
                input_size=total_input_units,
            )(next_layer)
        next_layer = KerasNetwork.next_dense(
            next_layer, output_units, self._get_final_activation(), None,
            kernel_regularizer=self._l2,
            batch_normalization=None)
        return keras.Model(inputs=ins, outputs=next_layer)

    @staticmethod
    def parse(tree, item_cache):
        """Parse a ``KerasAdaptiveMLP`` from a parse tree."""
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasAdaptiveMLP)


# Register KerasAdaptiveMLP in the main register under the key
# "keras_adp_mlp" (presumably the name used to select it in parsed
# definitions).
main_register.append_register(KerasAdaptiveMLP, "keras_adp_mlp")


class KerasDependencyMLP(KerasMLP):
    """MLP whose first layers follow dependencies between state atoms.

    For the ``"pre_post"`` schema, atom A is connected to atom B iff some
    action has A in its precondition (or in an effect condition) and B as
    the atom of one of its effects. The resulting dependency matrix drives
    ``hidden`` ``AdjacencyLayer``s. Afterwards, ``dense_layers`` dense layers
    map to the output.
    """
    arguments = parset.ClassArguments(
        'KerasDependencyMLP',
        KerasMLP.arguments,
        ("hidden", False, None, int, "Number of hidden layers"),
        ("dependency", False, None, str,
         "Name of the dependency schema to use. Available are:\n"
         "\t\"pre_post\": Dependency between Atom A & B iff there exists an "
         "action with A in its precondition and B in its post condition."),
        ("dense_layers", True, 1, int,
         "Number of additional dense layers to add after the dependency "
         "based layers. At least one might be required to map the output size "
         "from the dependencies layers to the output layer. By default 1 layer "
         "is added"),
        delete=["hidden_layer_size"],
        order=["hidden", "dependency", "x_fields", "y_fields",
               "dense_layers", "output_units", "ordinal_classification",
               "ordinal_classification_threshold",
               "bin_size",
               "activation", "batch_normalization", "dropout", "l2",
               "tparams",
               "load", "store", "learner_formats", "out", "domain_properties",
               "count_samples", "test_similarity",
               "graphdef", "variables", "id"]
        )

    def __init__(self, hidden, dependency, x_fields, y_fields, dense_layers=1,
                 output_units=-1, ordinal_classification=False,
                 ordinal_classification_threshold=0.01, bin_size=1,
                 activation="sigmoid", batch_normalization=None,
                 dropout=None, l2=None, tparams=None,
                 load=None, store=None, learner_formats=None, out=".",
                 domain_properties=None,
                 count_samples=False, test_similarity=None, graphdef=None,
                 variables=None, id=None):
        """Store the configuration.

        ``dependency`` is matched case-insensitively and currently only
        ``"pre_post"`` is accepted.
        """
        # hidden_layer_size is unused by this subclass; -1 is a placeholder.
        KerasMLP.__init__(
            self, x_fields, y_fields, -1,
            output_units, ordinal_classification,
            ordinal_classification_threshold, bin_size,
            activation, batch_normalization, dropout, l2, tparams, load,
            store, learner_formats,
            out, domain_properties, count_samples, test_similarity, graphdef,
            variables, id)

        self._hidden = hidden
        assert dependency is not None
        self._dependency = dependency.lower()
        assert self._dependency in ["pre_post"]
        self._dense_layers = dense_layers

    @staticmethod
    def __canonicalize_atom_arg(p, parameters, counters, known):
        """Replace variable ``p`` by a fresh canonical ``TypedObject``.

        The object is named ``<type><counter>``. It is remembered in
        ``known`` and returned.
        """
        assert p in parameters, "%s not found in %s" % (p, str(parameters))
        type_name = parameters[p]
        canon = TypedObject(type_name + str(counters[type_name]), type_name)
        counters[type_name] += 1
        known[p] = canon
        return canon

    @staticmethod
    def __canonicalize_atoms(parameters, *atoms):
        """Rename the variables of ``atoms`` jointly to canonical names.

        A variable shared between the atoms maps to the same canonical
        object.
        """
        known = {}
        type_counter = collections.defaultdict(int)
        return [Atom(atom.predicate,
                     [known[p] if p in known else
                      KerasDependencyMLP.__canonicalize_atom_arg(
                          p, parameters, type_counter, known)
                      for p in atom.args])
                for atom in atoms]

    @staticmethod
    def __get_atom_dependency_relation(domain):
        """Map canonical precondition atoms to the effect atoms they relate to.

        For every effect of every action, each ``Atom`` of the action's
        precondition and of the effect's condition is related to the
        effect's atom. Negated effect literals are turned positive.
        """
        def get_atoms_from_condition(condition):
            """Collect every ``Atom`` in ``condition``, including the root.

            Checking the root keeps conditions that consist of a single atom.
            """
            templates = set()
            todo = [condition]
            while len(todo) > 0:
                node = todo.pop()
                if isinstance(node, Atom):
                    templates.add(node)
                else:
                    todo.extend(node.parts)
            return templates

        relations = collections.defaultdict(set)
        for action in domain.actions:
            precondition = get_atoms_from_condition(action.precondition)
            base_parameters = {to.name: to.type_name for to in
                               action.parameters}

            for effect in action.effects:
                if len(effect.parameters) == 0:
                    parameters = base_parameters
                else:
                    # action parameters take precedence on name clashes
                    parameters = {to.name: to.type_name for to in
                                  effect.parameters}
                    parameters.update(base_parameters)

                eff_atom = (
                    effect.literal.negate() if effect.literal.negated else
                    effect.literal)

                for pre_atom in precondition | get_atoms_from_condition(
                        effect.condition):
                    c_pre_atom, c_eff_atom = \
                        KerasDependencyMLP.__canonicalize_atoms(
                            parameters, pre_atom, eff_atom)
                    relations[c_pre_atom].add(c_eff_atom)
        return relations

    @staticmethod
    def __get_matrix_from_atom_dependencies(domain_properties, atoms,
                                            dependencies):
        """Ground ``dependencies`` into a 0/1 matrix over the sorted ``atoms``.

        Entry (i, j) is 1 iff a grounding over the fixed objects relates
        atom i (precondition) to atom j (effect).
        """
        assert domain_properties.fixed_world, "Otherwise not implemented"

        def atom_grounding_generator(_pre, _posts, _relevant_predicates):
            """Yield all (pre, post) atom groundings over the fixed objects."""
            if _pre.predicate not in _relevant_predicates:
                return

            for post in _posts:
                if post.predicate not in _relevant_predicates:
                    continue

                all_vars = sorted(set(post.args) | set(_pre.args), key=str)
                choices = [domain_properties.fixed_objects_by_type[v.type_name]
                           for v in all_vars]
                var_idx = {var: idx for idx, var in enumerate(all_vars)}
                pre_idx = [var_idx[v] for v in _pre.args]
                post_idx = [var_idx[v] for v in post.args]

                # The last variable varies fastest. A type without objects
                # yields no grounding.
                for grounding in itertools.product(*choices):
                    yield (Atom(_pre.predicate,
                                tuple(grounding[idx].name for idx in pre_idx)),
                           Atom(post.predicate,
                                tuple(grounding[idx].name
                                      for idx in post_idx)))

        atoms = sorted(atoms)
        atoms_indices = {atom: no for no, atom in enumerate(atoms)}
        available_predicates = set(atom.predicate for atom in atoms)
        matrix = np.zeros(shape=(len(atoms), len(atoms)))

        for pre, posts in dependencies.items():
            groundings = atom_grounding_generator(pre, posts,
                                                  available_predicates)
            for atom_pre, atom_post in groundings:
                # dict lookup instead of a linear scan of the atom list
                if (atom_pre not in atoms_indices
                        or atom_post not in atoms_indices):
                    continue
                matrix[atoms_indices[atom_pre], atoms_indices[atom_post]] = 1
        return matrix

    def _construct_model(
            self, input_units, output_units, regression):
        """Build the dependency-based Keras model.

        Args:
            input_units: Number of units per input field. For ``pre_post``
                this must equal the flexible state size.
            output_units: Number of units of the output layer.
            regression: Whether this is a regression network (unused here).

        Returns:
            The (uncompiled) ``keras.Model``.
        """
        assert self._domain_properties is not None
        total_input_units = input_units * len(self._x_field_names)

        if self._dependency == "pre_post":
            # NOTE(review): Stack's module-level import is commented out,
            # which left the name undefined here. It is imported lazily;
            # confirm that keras_layers still provides Stack.
            from .keras_layers import Stack

            assert (input_units ==
                    self._domain_properties.get_flexible_state_size()
                    or input_units ==
                    self._domain_properties.get_full_state_size())
            non_static_case = (
                    input_units ==
                    self._domain_properties.get_flexible_state_size())
            assert non_static_case, "Do not support full state case"
            dependencies = self.__get_atom_dependency_relation(
                self._domain_properties.domain)
            atoms = (self._domain_properties.get_flexible_state_atoms()
                     if non_static_case else
                     self._domain_properties.get_full_state_atoms())

            dependency_matrix = self.__get_matrix_from_atom_dependencies(
                self._domain_properties, atoms, dependencies)
            assert dependency_matrix.shape == (input_units, input_units)
            # every input field is connected to every other input field
            state_goal_matrix = np.ones(shape=(len(self._x_field_names),
                                               len(self._x_field_names)))

            # Create the model
            input_layers = [keras.layers.Input(shape=(input_units,))
                            for _ in range(len(self._x_field_names))]
            curr = Stack(axis=0)(input_layers)
            for _ in range(self._hidden):
                curr = AdjacencyLayer(
                    dependency_matrix,
                    adjacency_axis=-1,
                    activation=self._activation,
                    kernel_regularizer=self._l2)(curr)
                if len(self._x_field_names) > 1:
                    curr = AdjacencyLayer(
                        state_goal_matrix,
                        adjacency_axis=0,
                        activation=self._activation,
                        kernel_regularizer=self._l2)(curr)
        else:
            assert False, ("Internal error: Unknown dependency should have "
                           "been caught in constructor")

        curr = keras.layers.Flatten()(curr)
        for i in range(self._dense_layers):
            if i == self._dense_layers - 1:
                # Output layer: task-specific activation, no dropout/BN.
                curr = KerasNetwork.next_dense(
                    curr, output_units, self._get_final_activation(), None,
                    kernel_regularizer=self._l2,
                    batch_normalization=None)
            else:
                # Intermediate layer: hidden activation and dropout instead
                # of the output activation (e.g. softmax).
                curr = KerasNetwork.next_dense(
                    curr, total_input_units, self._activation, self._dropout,
                    kernel_regularizer=self._l2,
                    batch_normalization=self._batch_normalization)
        return keras.Model(inputs=input_layers, outputs=curr)

    @staticmethod
    def parse(tree, item_cache):
        """Parse a ``KerasDependencyMLP`` from a parse tree."""
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KerasDependencyMLP)


# Register KerasDependencyMLP in the main register under the key
# "keras_dep_mlp" (presumably the name used to select it in parsed
# definitions).
main_register.append_register(KerasDependencyMLP, "keras_dep_mlp")
