import tensorflow as tf
from tensorflow.contrib.rnn.python.ops.core_rnn_cell import _linear
from tensorflow.python.framework import ops
from tensorflow.python.util import nest


class PointerGeneratorDecoder(tf.contrib.seq2seq.BasicDecoder):
    """Pointer Generator sampling decoder.

    Extends `BasicDecoder` so that every step emits a distribution over an
    extended vocabulary: the softmax over the (optionally projected) cell
    output, weighted by `p_gen`, plus the attention alignments weighted by
    `1 - p_gen` and scattered onto the ids in `source_extend_tokens`.
    The extended vocabulary is the regular vocabulary followed by
    `source_oov_words` extra slots.
    """

    def __init__(self, source_extend_tokens, source_oov_words, cell, helper, initial_state,
                 output_layer=None):
        """Store the pointer inputs and initialize the underlying BasicDecoder.

        Args:
        source_extend_tokens: ids in the extended vocabulary for each source
            position; used as scatter targets for the attention weights.
            Assumed to be an integer tensor of shape `[batch_size, attn_len]`
            -- TODO confirm against caller.
        source_oov_words: number of extra vocabulary slots appended after the
            regular vocabulary (presumably a Python int or scalar int tensor;
            verify against caller).
        cell: forwarded to `BasicDecoder`. `step` reads `.attention`,
            `.cell_state` and `.alignments` from the state it returns.
        helper: forwarded to `BasicDecoder`.
        initial_state: forwarded to `BasicDecoder`.
        output_layer: (optional) forwarded to `BasicDecoder`; applied to the
            cell output before the softmax.
        """
        self.source_oov_words = source_oov_words
        self.source_extend_tokens = source_extend_tokens
        super(PointerGeneratorDecoder, self).__init__(cell, helper, initial_state, output_layer)

    @property
    def output_size(self):
        """Output structure: rnn output size grown by the OOV slots, plus sample ids."""
        # Return the cell output and the id
        return tf.contrib.seq2seq.BasicDecoderOutput(
            rnn_output=self._rnn_output_size() + self.source_oov_words,
            sample_id=self._helper.sample_ids_shape)

    @property
    def output_dtype(self):
        """Output dtypes, structured like `output_size`."""
        # Assume the dtype of the cell is the output_size structure
        # containing the input_state's first component's dtype.
        # Return that structure and the sample_ids_dtype from the helper.
        dtype = nest.flatten(self._initial_state)[0].dtype
        return tf.contrib.seq2seq.BasicDecoderOutput(
            nest.map_structure(lambda _: dtype, self._rnn_output_size() + self.source_oov_words),
            self._helper.sample_ids_dtype)

    def step(self, time, inputs, state, name=None):
        """Perform a decoding step.

        The emitted `rnn_output` is the final mixed distribution over the
        extended vocabulary (probabilities, not logits). The helper samples
        from that distribution, but `next_inputs` is given the projected
        cell output rather than the mixed distribution.

        Args:
        time: scalar `int32` tensor.
        inputs: A (structure of) input tensors.
        state: A (structure of) state tensors and TensorArrays.
        name: Name scope for any created operations.
        Returns:
        `(outputs, next_state, next_inputs, finished)`.
        """
        with ops.name_scope(name, "PGDecoderStep", (time, inputs, state)):
            cell_outputs, cell_state = self._cell(inputs, state)
            # the last cell state contains attention, which is the context
            last_cell_state = cell_state
            attention = last_cell_state.attention
            att_cell_state = last_cell_state.cell_state
            alignments = last_cell_state.alignments

            # Generation probability: sigmoid of a biased linear projection of
            # the context, the decoder input and the last element of the
            # wrapped cell state down to a single unit.
            with tf.variable_scope('calculate_pgen'):
                p_gen = _linear([attention, inputs, att_cell_state[-1]], 1, True)
                p_gen = tf.sigmoid(p_gen)

            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)

            # Generation part of the mixture.
            vocab_dist = tf.nn.softmax(cell_outputs) * p_gen

            # Copy part of the mixture.
            alignments = alignments * (1 - p_gen)

            # Since we have OOV words, we need to expand the vocab dist.
            vocab_size = tf.shape(vocab_dist)[-1]
            extended_vsize = vocab_size + self.source_oov_words
            batch_size = tf.shape(vocab_dist)[0]
            # Match the dtype of vocab_dist: tf.zeros defaults to float32, which
            # would make the concat below fail for any other floating dtype.
            extra_zeros = tf.zeros((batch_size, self.source_oov_words),
                                   dtype=vocab_dist.dtype)
            # batch * extended vocab size
            vocab_dists_extended = tf.concat(axis=-1, values=[vocab_dist, extra_zeros])

            # Build (batch index, extended token id) pairs for every source
            # position so the attention weights can be scattered into the
            # extended vocabulary.
            batch_nums = tf.range(0, limit=batch_size)  # shape (batch_size)
            batch_nums = tf.expand_dims(batch_nums, 1)  # shape (batch_size, 1)
            attn_len = tf.shape(self.source_extend_tokens)[1]  # number of states we attend over
            batch_nums = tf.tile(batch_nums, [1, attn_len])  # shape (batch_size, attn_len)
            indices = tf.stack((batch_nums, self.source_extend_tokens), axis=2)  # shape (batch_size, enc_t, 2)
            shape = [batch_size, extended_vsize]

            attn_dists_projected = tf.scatter_nd(indices, alignments, shape)

            final_dists = attn_dists_projected + vocab_dists_extended

            # note: sample_ids may contain ids of OOV words
            sample_ids = self._helper.sample(
                time=time, outputs=final_dists, state=cell_state)

            finished, next_inputs, next_state = self._helper.next_inputs(
                time=time,
                outputs=cell_outputs,
                state=cell_state,
                sample_ids=sample_ids)

            outputs = tf.contrib.seq2seq.BasicDecoderOutput(final_dists, sample_ids)
            return outputs, next_state, next_inputs, finished


def _pg_bahdanau_score(processed_query, keys):
    """Compute additive (Bahdanau-style) attention scores with a bias term.

    Creates (or reuses, depending on the enclosing variable scope) two
    variables of size `num_units`: `attention_v` and the zero-initialized
    `attention_b`.

    Args:
        processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys.
        keys: Processed memory, shape `[batch_size, max_time, num_units]`.
    Returns:
        A `[batch_size, max_time]` tensor of unnormalized score values.
    """
    param_dtype = processed_query.dtype
    # Prefer the statically known depth of `keys`; otherwise read it at run time.
    depth = keys.shape[2].value
    if not depth:
        depth = tf.shape(keys)[2]
    # Insert a time axis so the query broadcasts against every key position.
    query_per_step = tf.expand_dims(processed_query, 1)
    attention_v = tf.get_variable("attention_v", [depth], dtype=param_dtype)
    attention_b = tf.get_variable(
        "attention_b", [depth], dtype=param_dtype,
        initializer=tf.zeros_initializer())

    activated = tf.tanh(keys + query_per_step + attention_b)
    # Collapse the depth axis to get one score per memory position.
    return tf.reduce_sum(attention_v * activated, [2])


class PointerGeneratorBahdanauAttention(tf.contrib.seq2seq.BahdanauAttention):
    """Bahdanau attention whose scores come from `_pg_bahdanau_score`."""

    def __init__(self,
                 num_units,
                 memory,
                 memory_sequence_length=None,
                 normalize=False,
                 probability_fn=None,
                 score_mask_value=float("-inf"),
                 name="PointerGeneratorBahdanauAttention",
                 coverage=False):
        """Build the attention mechanism.

        Every argument except `coverage` is forwarded unchanged to
        `tf.contrib.seq2seq.BahdanauAttention`.

        Args:
            num_units: The depth of the query mechanism.
            memory: The memory to query; usually the output of an RNN encoder.
                This tensor should be shaped `[batch_size, max_time, ...]`.
            memory_sequence_length (optional): Sequence lengths for the batch
                entries in memory.  If provided, the memory tensor rows are
                masked with zeros for values past the respective sequence
                lengths.
            normalize: Python boolean.  Whether to normalize the energy term.
                Forwarded to the base class; the score computed in `__call__`
                does not take it as an argument.
            probability_fn: (optional) A `callable` that converts the score to
                probabilities.  The default is @{tf.nn.softmax}. Other options
                include @{tf.contrib.seq2seq.hardmax} and
                @{tf.contrib.sparsemax.sparsemax}.
                Its signature should be: `probabilities = probability_fn(score)`.
            score_mask_value: (optional): The mask value for score before
                passing into `probability_fn`. The default is -inf. Only used
                if `memory_sequence_length` is not None.
            name: Name to use when creating ops.
            coverage: whether to use coverage mode. Only stored on the
                instance; nothing else in this class reads it.
        """
        base_kwargs = dict(
            num_units=num_units,
            memory=memory,
            memory_sequence_length=memory_sequence_length,
            normalize=normalize,
            probability_fn=probability_fn,
            score_mask_value=score_mask_value,
            name=name,
        )
        super(PointerGeneratorBahdanauAttention, self).__init__(**base_kwargs)
        self.coverage = coverage

    def __call__(self, query, state):
        """Score the query against the keys and convert the scores to alignments.

        Args:
            query: Tensor of dtype matching `self.values` and shape
                `[batch_size, query_depth]`.
            state: Handed to the probability function together with the
                score. Presumably the previous alignments, shaped
                `[batch_size, alignments_size]` where `alignments_size` is
                memory's `max_time` -- confirm against the caller.
        Returns:
            A pair `(alignments, next_state)`; both entries are the same
            tensor, produced by the probability function.
        """
        with tf.variable_scope(None, "pointer_generator_bahdanau_attention", [query]):
            # Project the query only when a query layer is configured.
            if self.query_layer:
                processed_query = self.query_layer(query)
            else:
                processed_query = query
            score = _pg_bahdanau_score(processed_query, self._keys)

        alignments = self._probability_fn(score, state)
        # The new alignments double as the next attention state.
        return alignments, alignments


def compute_loss(logits, targets, masks, batch_size, max_target_sequence_length):
    """Masked negative log-likelihood of the target ids, divided by the batch size.

    Args:
        logits: tensor indexed by `(batch, step, token id)`. Despite the name,
            the gathered values are passed straight to `tf.log`, so they are
            treated as probabilities rather than unnormalized scores.
        targets: target token ids; stacked with the `[batch_size,
            max_target_sequence_length]` index grids, so it needs that shape
            and the integer dtype of `tf.range`.
        masks: weights multiplied elementwise with the per-position loss
            (presumably 1 for real tokens and 0 for padding -- confirm).
        batch_size: number of sequences in the batch; also the divisor of the
            summed loss.
        max_target_sequence_length: number of decoding steps to index.
    Returns:
        Scalar tensor: sum of the masked losses divided by `batch_size`.
    """
    # Index grids in "ij" layout: row index is the batch entry, column the step.
    batch_idx, step_idx = tf.meshgrid(
        tf.range(batch_size),
        tf.range(max_target_sequence_length),
        indexing="ij")
    gather_idx = tf.stack((batch_idx, step_idx, targets), axis=2)
    target_probs = tf.gather_nd(logits, gather_idx)

    # Swap non-positive entries for a tiny constant so the log stays finite.
    is_non_positive = tf.less_equal(target_probs, 0)
    floor = tf.ones_like(target_probs) * 1e-10
    safe_probs = tf.where(is_non_positive, floor, target_probs)

    neg_log_likelihood = -tf.log(safe_probs)
    masked_total = tf.reduce_sum(neg_log_likelihood * masks)
    return masked_total / tf.to_float(batch_size)
