
from keras import backend as K
from keras.engine import InputSpec
from keras.layers import Dense
# from keras.layers.merge import _Merge

import numpy as np

# Registry mapping a custom layer's class name to the class itself; entries
# are added below each layer definition. Presumably handed to Keras as
# `custom_objects` when deserialising models -- confirm with callers.
CUSTOM_LAYERS = dict()


# class Stack(_Merge):
#     """Layer that stacks a list of inputs.
#     It takes as input a list of tensors,
#     all of the same shape, and returns a
#     single tensor, stacked at the given
#     axis.
#     # Arguments
#         axis: Axis along which to stack.
#         **kwargs: standard layer keyword arguments.
#     """
#
#     def __init__(self, axis=-1, **kwargs):
#         super(Stack, self).__init__(**kwargs)
#         self.axis = axis
#         self.adapted_axis = (self.axis + 1 if self.axis >= 0 else self.axis)
#         self.supports_masking = True
#         self._reshape_required = False
#
#     def build(self, input_shape):
#         # Used purely for shape validation.
#         if not isinstance(input_shape, list) or len(input_shape) < 2:
#             raise ValueError('A `Stack` layer should be called '
#                              'on a list of at least 2 inputs')
#         if all([shape is None for shape in input_shape]):
#             return
#         shape_set = set([tuple(shape) for shape in input_shape])
#         if len(shape_set) > 1:
#             raise ValueError('A `Stack` layer requires inputs with matching '
#                              'shapes. Got inputs shapes: %s' % input_shape)
#         super(Stack, self).build(input_shape)
#
#     def _merge_function(self, inputs):
#         return K.stack(inputs, axis=self.adapted_axis)
#
#     def compute_output_shape(self, input_shape):
#         if not isinstance(input_shape, list):
#             raise ValueError('A `Stack` layer should be called '
#                              'on a list of inputs.')
#         output_shape = list(input_shape[0])
#         output_shape.insert(self.adapted_axis, len(input_shape))
#         return tuple(output_shape)
#
#     def compute_mask(self, inputs, mask=None):
#         if mask is None:
#             return None
#         if not isinstance(mask, list):
#             raise ValueError('`mask` should be a list.')
#         if not isinstance(inputs, list):
#             raise ValueError('`inputs` should be a list.')
#         if len(mask) != len(inputs):
#             raise ValueError('The lists `inputs` and `mask` '
#                              'should have the same length.')
#         if all([m is None for m in mask]):
#             return None
#         # Make a list of masks while making sure
#         # the dimensionality of each mask
#         # is the same as the corresponding input.
#         masks = []
#         for input_i, mask_i in zip(inputs, mask):
#             if mask_i is None:
#                 # Input is unmasked. Append all 1s to masks,
#                 masks.append(K.ones_like(input_i, dtype='bool'))
#             elif K.ndim(mask_i) < K.ndim(input_i):
#                 # Mask is smaller than the input, expand it
#                 masks.append(K.expand_dims(mask_i))
#             else:
#                 masks.append(mask_i)
#         stacked = K.stack(masks, axis=self.adapted_axis)
#         return K.all(stacked, axis=-1, keepdims=False)
#
#     def get_config(self):
#         config = {'axis': self.axis}
#         config.update(super(Stack, self).get_config())
#         return config
#
#
# CUSTOM_LAYERS[Stack.__name__] = Stack


class AdjacencyLayer(Dense):
    """Dense layer whose kernel is masked by a fixed adjacency matrix.

    The trainable kernel has the same shape as ``adjacency_matrix`` and is
    multiplied element-wise by it before use, so kernel entries where the
    adjacency matrix is zero never contribute to the output. Unlike
    ``Dense``, the contraction happens along an arbitrary input axis
    (``adjacency_axis``) instead of always the last one; in the output that
    axis is replaced by ``units == adjacency_matrix.shape[1]`` and all other
    axes keep their order.

    # Arguments
        adjacency_matrix: 2-D array-like of shape ``(input_dim, units)``.
            Stored as a ``numpy.ndarray`` and as a backend constant.
        adjacency_axis: Input axis (not counting the batch axis) that is
            contracted against the kernel. Non-negative values are shifted
            by one to skip the batch axis; negative values count from the
            end of the full input shape.
        **kwargs: Other ``Dense`` keyword arguments (``activation``,
            ``use_bias``, initializers, regularizers, constraints, ...).
            ``units`` must not be passed: it is taken from the matrix.

    # Raises
        ValueError: if ``adjacency_matrix`` is not 2-D.
    """

    def __init__(self, adjacency_matrix, adjacency_axis, **kwargs):
        adjacency_matrix = np.asarray(adjacency_matrix)
        if adjacency_matrix.ndim != 2:
            raise ValueError('`adjacency_matrix` must be 2-D, got shape %s'
                             % (adjacency_matrix.shape,))
        # Initialise the base layer before assigning attributes (the usual
        # Keras subclassing order); `units` is the matrix's column count.
        super(AdjacencyLayer, self).__init__(adjacency_matrix.shape[1],
                                             **kwargs)
        self.adjacency_matrix = adjacency_matrix
        self.adjacency_axis = adjacency_axis

        # Fixed (non-trainable) mask multiplied into the kernel in `call`.
        self.k_adjacency_matrix = K.constant(self.adjacency_matrix)
        # Index 0 of the full input shape is the batch axis, so a
        # non-negative axis given relative to one sample is shifted by one.
        # Negative axes already count from the end and are kept; `build`
        # resolves them against the actual input rank.
        self.adapted_adjacency_axis = (self.adjacency_axis + 1
                                       if self.adjacency_axis >= 0 else
                                       self.adjacency_axis)

    def build(self, input_shape):
        """Create the kernel (and bias) and pin the adjacency input axis.

        # Raises
            ValueError: if the input rank is below 2, the adjacency axis is
                out of range, the size of that axis differs from the number
                of rows of ``adjacency_matrix``, or ``units`` differs from
                its number of columns.
        """
        ndim = len(input_shape)
        if ndim < 2:
            raise ValueError('An `AdjacencyLayer` expects an input of rank '
                             '>= 2, got input shape %s' % (input_shape,))
        # Resolve a negative axis against the real rank: `call` relies on the
        # axis being non-negative when it builds its permutation.
        if self.adapted_adjacency_axis < 0:
            self.adapted_adjacency_axis += ndim
        axis = self.adapted_adjacency_axis
        if not 0 <= axis < ndim:
            raise ValueError('`adjacency_axis`=%s is out of range for input '
                             'shape %s' % (self.adjacency_axis, input_shape))

        input_dim = input_shape[axis]
        n_rows, n_cols = self.adjacency_matrix.shape
        if input_dim != n_rows:
            raise ValueError('Size %s of input axis %d does not match the %d '
                             'rows of `adjacency_matrix`'
                             % (input_dim, axis, n_rows))
        if self.units != n_cols:
            raise ValueError('`units` (%s) must equal the number of columns '
                             'of `adjacency_matrix` (%d)'
                             % (self.units, n_cols))

        self.kernel = self.add_weight(shape=(input_dim, self.units),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.input_spec = InputSpec(min_ndim=2, axes={axis: input_dim})
        self.built = True

    def call(self, inputs):
        """Apply the masked kernel along the adjacency axis of ``inputs``.

        The adjacency axis is swapped into the last position so ``K.dot``
        contracts it against the kernel (exactly as ``Dense`` does). Then the
        same swap is applied to the result. A swap of two axes is its own
        inverse, so this restores the original axis order.
        """
        ndim = K.ndim(inputs)
        axis = self.adapted_adjacency_axis
        last = ndim - 1

        perm = None
        if axis != last:
            perm = list(range(ndim))
            perm[axis], perm[last] = last, axis
            inputs = K.permute_dimensions(inputs, perm)

        output = K.dot(inputs, self.kernel * self.k_adjacency_matrix)
        if self.use_bias:
            # The bias belongs to the (currently last) units axis. Pin
            # 'channels_last' so that a global `image_data_format` of
            # 'channels_first' cannot apply it to axis 1 for rank >= 3.
            output = K.bias_add(output, self.bias,
                                data_format='channels_last')
        if self.activation is not None:
            output = self.activation(output)
        if perm is not None:
            output = K.permute_dimensions(output, perm)
        return output

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` with the adjacency axis set to ``units``.

        # Raises
            ValueError: if the input rank is below 2 or the adjacency axis
                has no defined (truthy) size.
        """
        if not input_shape or len(input_shape) < 2:
            raise ValueError('An `AdjacencyLayer` expects an input of rank '
                             '>= 2, got input shape %s' % (input_shape,))
        if not input_shape[self.adapted_adjacency_axis]:
            raise ValueError('The adjacency axis of input shape %s must have '
                             'a defined size' % (input_shape,))
        output_shape = list(input_shape)
        output_shape[self.adapted_adjacency_axis] = self.units
        return tuple(output_shape)

    def get_config(self):
        """Return the layer config, carrying the adjacency data instead of units."""
        config = {
            "adjacency_matrix": self.adjacency_matrix.tolist(),
            "adjacency_axis": self.adjacency_axis
        }
        config.update(super(AdjacencyLayer, self).get_config())
        # `units` is derived from `adjacency_matrix` in `__init__`, and passing
        # it back as a keyword would collide with that positional argument.
        config.pop("units", None)
        return config


# Make the layer discoverable in the registry under its class name.
CUSTOM_LAYERS.update({AdjacencyLayer.__name__: AdjacencyLayer})
