from .keras_network import KerasNetwork
from .keras_network import BN_OFF, BN_PRE_ACTIVATION, BN_POST_ACTIVATION

from ... import parser_tools as parset
from ... import parser
from ... import main_register

from keras import layers


class KResidualBlock(object):
    """Description of a dense residual block for a Keras network.

    An instance stores the block-level settings. Calling the instance with
    the settings of the surrounding network returns a function which, applied
    to a Keras tensor, appends ``hidden_layer_count`` dense layers and merges
    their output with a skip connection from the block input via
    ``layers.add``.

    Settings given to the constructor take precedence over the values passed
    to ``__call__``.
    """

    arguments = parset.ClassArguments(
        'KResidualBlock',
        None,
        ("hidden_layer_count", False, None, int,
         "Number of hidden layers within the residual block."),
        ("hidden_layer_size", True, None, float,
         "Optional. If provided, then the given value takes precedence over "
         "a possible value provided by a NN. Size of the hidden layers within "
         "this residual block. Positive values represent fix number of "
         "neurons, negative values will be converted as "
         "abs(value)*#input_units."),
        ("batch_normalization", True, None, int,
         "Optional. If provided, then the given value takes precedence over "
         "a possible value provided by a NN. 0 = no batch normalization, "
         "%i batch normalization pre activation, "
         "%i batch normalization post activation" %
         (BN_PRE_ACTIVATION, BN_POST_ACTIVATION)),
        ("activation", True, None, str,
         "Optional. If provided, then the given value takes precedence over "
         "a possible value provided by a NN. Activation function for the "
         "hidden layers"),
        ("add_skip_projection", True, None, parser.convert_bool,
         "if not set, then a linear projection for the skip block to match "
         "the residual block is only added if the residual block has changed "
         "the tensor dimensions. If True: Always add a skip projection, if "
         "False: Never add a skip projection (if dimensions were changed this "
         "causes an error)"
         ),

        order=["hidden_layer_count",
               "hidden_layer_size",
               "activation",
               "batch_normalization",
               "add_skip_projection"]
    )

    def __init__(self, hidden_layer_count,
                 hidden_layer_size=None,
                 activation=None,
                 batch_normalization=None,
                 add_skip_projection=None):
        """Store the block-level settings.

        :param hidden_layer_count: number of dense layers in the residual
            path; has to be greater than 0.
        :param hidden_layer_size: size of the hidden layers or None to use
            the value passed to ``__call__``; must not be 0.
        :param activation: activation of the hidden layers or None to use
            the value passed to ``__call__``.
        :param batch_normalization: None, BN_OFF, BN_PRE_ACTIVATION or
            BN_POST_ACTIVATION; None defers to the value passed to
            ``__call__``.
        :param add_skip_projection: True/False to force/forbid a projection
            on the skip path, None to decide by comparing tensor shapes.
        """
        assert hidden_layer_count > 0, hidden_layer_count
        self._hidden_layer_count = hidden_layer_count
        # A size of None means "defer to the network"; only an explicit 0 is
        # invalid.
        assert hidden_layer_size is None or hidden_layer_size != 0, \
            hidden_layer_size
        self._hidden_layer_size = hidden_layer_size
        self._activation = activation
        assert batch_normalization in [None, BN_OFF, BN_PRE_ACTIVATION,
                                       BN_POST_ACTIVATION]
        self._batch_normalization = batch_normalization
        self._add_skip_projection = add_skip_projection

    @staticmethod
    def __get_parameter_value(from_self, from_nn, permit_none=False):
        """Return `from_self` unless it is None, otherwise `from_nn`.

        :param from_self: value configured on the block (takes precedence).
        :param from_nn: fallback value provided by the caller.
        :param permit_none: if False, assert that the result is not None.
        :return: the selected value.
        """
        v = from_nn if from_self is None else from_self
        assert permit_none or v is not None
        return v

    def __call__(self,
                 hidden_layer_size=None,
                 activation=None,
                 batch_normalization=None,
                 dropout=None,
                 kernel_regularizer=None,
                 add_skip_projection=None,
                 input_size=None):
        """Resolve the effective settings and return the layer builder.

        For `hidden_layer_size`, `activation`, `batch_normalization` and
        `add_skip_projection` the value stored on the block is used if it is
        not None, otherwise the argument given here. A hidden layer size and
        an activation have to be available from one of the two sources.
        `dropout` and `kernel_regularizer` have no block-level setting and
        are used as given.

        :param input_size: forwarded together with the hidden layer size to
            KerasNetwork.calculate_hidden_layer_size, which yields the number
            of neurons used for every hidden layer.
        :return: function taking the previous layer (a tensor) and returning
            the sum of the residual path and the skip path.
        """
        hidden_layer_size = self.__get_parameter_value(
            self._hidden_layer_size, hidden_layer_size)
        hidden_layer_size = KerasNetwork.calculate_hidden_layer_size(
            hidden_layer_size, input_size)

        activation = self.__get_parameter_value(
            self._activation, activation)

        batch_normalization = self.__get_parameter_value(
            self._batch_normalization, batch_normalization, permit_none=True)

        add_skip_projection = self.__get_parameter_value(
            self._add_skip_projection, add_skip_projection, permit_none=True
        )

        def func(previous_layer):
            # Residual path: stack the configured number of dense layers.
            residual_block = previous_layer
            for _ in range(self._hidden_layer_count):
                residual_block = KerasNetwork.next_dense(
                    prev=residual_block,
                    neurons=hidden_layer_size,
                    activation=activation,
                    dropout=dropout,
                    kernel_regularizer=kernel_regularizer,
                    batch_normalization=batch_normalization,
                )

            # Skip path: without an explicit setting, project only if the
            # residual path changed the tensor shape.
            if add_skip_projection is None:
                project_skip = not previous_layer.shape.is_compatible_with(
                    residual_block.shape)
            else:
                project_skip = add_skip_projection

            skip_block = previous_layer
            if project_skip:
                # Linear (no activation) projection to the hidden layer size.
                skip_block = layers.Dense(hidden_layer_size)(skip_block)

            # Merge both paths by element-wise addition.
            return layers.add([residual_block, skip_block])

        return func

    @staticmethod
    def parse(tree, item_cache):
        """Build a KResidualBlock from `tree` via the project parser.

        Delegates to parser.try_whole_obj_parse_process with this class.
        """
        return parser.try_whole_obj_parse_process(tree, item_cache,
                                                  KResidualBlock)


# Register KResidualBlock with the main register under the name
# "keras_residual_block" (presumably the key used to look the class up when
# parsing configurations; confirm against main_register).
main_register.append_register(KResidualBlock, "keras_residual_block")
