import tensorflow as tf
from tensorflow.python.layers.core import Dense

import constants
from data import BatchData, Data
from pointer_generator_helper import PointerGeneratorBahdanauAttention, PointerGeneratorDecoder, compute_loss
from utils.shell_args import SHELL_ARGS


class HyperParameters(object):
    """Plain value holder for the settings used to build and train the model.

    Every constructor argument is stored, unchanged, on an attribute of the
    same name; no validation or conversion is performed.
    """

    def __init__(self,
                 embedding_size,
                 rnn_size,
                 num_layers,
                 batch_size,
                 clip_norm,
                 learning_rate,
                 drop_out,
                 epochs_sum,
                 optimizer):
        # Network shape.
        self.embedding_size, self.rnn_size, self.num_layers = (
            embedding_size, rnn_size, num_layers)

        # Batch layout and regularisation.
        self.batch_size, self.clip_norm, self.drop_out = (
            batch_size, clip_norm, drop_out)

        # Training schedule.
        self.learning_rate, self.epochs_sum = learning_rate, epochs_sum

        # Optimizer name, e.g. 'adam'.
        self.optimizer = optimizer


class MultiSeq2seqModel(object):
    def __init__(self, hyper_params: HyperParameters, data: Data, predict=False):
        """Create the model object and, unless ``predict`` is set, build its graph.

        Args:
            hyper_params: settings consulted while building and running the graph.
            data: vocabulary provider; vocabulary sizes and word-to-id lookups
                are read from it during graph construction.
            predict: when true, graph construction is skipped and no saver is
                created; the tensor attributes stay ``None`` until ``load``
                fills in the ones needed for inference.
        """
        self.hyper_params, self.data = hyper_params, data

        # Graph inputs shared by training and inference.
        # The sequence tensors are [batch_size, seq_length].
        self.source_seqs = self.target_seqs = None
        self.learning_rate = self.drop_out = None
        self.source_sequence_length = self.target_sequence_length = None
        self.max_target_sequence_length = None
        self.source_extend_tokens = self.source_oov_words = None

        # Tensors/ops used for training.
        self.cost = self.global_step = self.train_op = None

        # Tensors used for inference.
        self.logits = self.predictions = None

        # Summary ops and their writers, one pair per reported quantity.
        self.train_summary_op = self.train_summary_writer = None
        self.eval_summary_op = self.eval_summary_writer = None
        self.train_average_summary_op = self.train_average_summary_writer = None
        self.eval_average_summary_op = self.eval_average_summary_writer = None
        self.eval_bleu_summary_op = self.eval_bleu_summary_writer = None
        self.lr_summary_op = self.lr_summary_writer = None

        # Placeholders through which externally computed values are logged.
        self.average_loss = self.eval_bleu = None

        # Every op of this model lives in its own graph.
        self.graph = tf.Graph()

        if predict:
            # Prediction-only instances get their tensors from ``load``.
            return

        self.build_graph()
        with self.graph.as_default():
            self.saver = tf.train.Saver()

    def _set_inputs(self):
        """Create the named graph inputs and store each on the attribute of the same name."""

        def declare(name, dtype, shape=None):
            # The placeholder's graph name doubles as the attribute name.
            placeholder = tf.placeholder(dtype, shape, name=name)
            setattr(self, name, placeholder)
            return placeholder

        # Token-id matrices; both dimensions are left unspecified.
        declare('source_seqs', tf.int32, [None, None])
        declare('target_seqs', tf.int32, [None, None])

        # One length per row of the corresponding matrix.
        declare('source_sequence_length', tf.int32, (None,))
        target_lengths = declare('target_sequence_length', tf.int32, (None,))

        # Derived tensor (not fed): the longest target length in the batch.
        self.max_target_sequence_length = tf.reduce_max(
            target_lengths, name='max_target_sequence_length')

        # Scalars fed on every run; no static shape is declared for them.
        declare('learning_rate', tf.float32)
        declare('drop_out', tf.float32)

        # Inputs consumed by the pointer-generator decoder.
        declare('source_oov_words', tf.int32, [])
        declare('source_extend_tokens', tf.int32, [None, None])

    def _get_attention_mechanism(self, rnn_size, encoder_output, memory_sequence_length):
        """Create the attention mechanism selected by the shell arguments.

        When ``SHELL_ARGS.train_type`` is ``'pointer'`` the pointer-generator
        Bahdanau attention is always used; otherwise
        ``SHELL_ARGS.attention_option`` picks one of the ``tf.contrib.seq2seq``
        mechanisms.

        Args:
            rnn_size: first positional argument of the mechanism (its num_units).
            encoder_output: memory the mechanism attends over.
            memory_sequence_length: per-example lengths of the memory.

        Returns:
            The constructed attention mechanism.

        Raises:
            ValueError: if ``SHELL_ARGS.attention_option`` is not one of
                ``luong``, ``scaled_luong``, ``bahdanau``, ``normed_bahdanau``.
        """
        if SHELL_ARGS.train_type == 'pointer':
            return PointerGeneratorBahdanauAttention(rnn_size, encoder_output, memory_sequence_length)

        option = SHELL_ARGS.attention_option
        if option == 'luong':
            return tf.contrib.seq2seq.LuongAttention(
                rnn_size, encoder_output, memory_sequence_length)
        if option == 'scaled_luong':
            return tf.contrib.seq2seq.LuongAttention(
                rnn_size, encoder_output, memory_sequence_length, scale=True)
        if option == 'bahdanau':
            return tf.contrib.seq2seq.BahdanauAttention(
                rnn_size, encoder_output, memory_sequence_length)
        if option == 'normed_bahdanau':
            return tf.contrib.seq2seq.BahdanauAttention(
                rnn_size, encoder_output, memory_sequence_length, normalize=True)

        # Spelling fixed ("Unkown" -> "Unknown") to match the init_op error message.
        raise ValueError('Unknown attention option %s' % option)

    def _get_decoder_net(self, cell, helper, initial_state, output_layer):
        """Build the decoder object that ``dynamic_decode`` will drive.

        Returns a ``PointerGeneratorDecoder`` (additionally given the
        extended source tokens and OOV count inputs) when
        ``SHELL_ARGS.train_type`` is ``'pointer'``, otherwise a plain
        ``BasicDecoder``.
        """
        shared_kwargs = dict(
            cell=cell,
            helper=helper,
            initial_state=initial_state,
            output_layer=output_layer)

        if SHELL_ARGS.train_type != 'pointer':
            return tf.contrib.seq2seq.BasicDecoder(**shared_kwargs)

        return PointerGeneratorDecoder(
            source_extend_tokens=self.source_extend_tokens,
            source_oov_words=self.source_oov_words,
            **shared_kwargs)

    def _get_cell(self, rnn_size, is_last=False):
        """Create one recurrent cell of ``rnn_size`` units.

        A Xavier-initialised GRU cell is used when ``SHELL_ARGS.train_type``
        is ``'pointer'``; otherwise a Xavier-initialised LSTM cell with
        ``forget_bias=0.2``. When ``is_last`` is true the cell is wrapped so
        that its inputs are kept with probability ``1 - drop_out``.
        """
        xavier = tf.contrib.layers.xavier_initializer

        if SHELL_ARGS.train_type == 'pointer':
            cell = tf.contrib.rnn.GRUCell(
                rnn_size, kernel_initializer=xavier(), bias_initializer=xavier())
        else:
            cell = tf.contrib.rnn.LSTMCell(
                rnn_size, initializer=xavier(), forget_bias=0.2)

        if not is_last:
            return cell

        # ``drop_out`` is a drop probability; the wrapper expects a keep probability.
        keep_prob = 1.0 - self.drop_out
        return tf.contrib.rnn.DropoutWrapper(cell=cell, input_keep_prob=keep_prob)

    def _get_optimizer(self):
        """Create the optimizer named by ``hyper_params.optimizer``.

        The optimizer is driven by the ``learning_rate`` placeholder, so the
        rate can be changed from one run to the next.

        Returns:
            An Adam, gradient-descent or Adadelta optimizer.

        Raises:
            ValueError: if the name is not ``adam``, ``sgd`` or ``adadelta``.
        """
        if self.hyper_params.optimizer == 'adam':
            return tf.train.AdamOptimizer(self.learning_rate)
        elif self.hyper_params.optimizer == 'sgd':
            return tf.train.GradientDescentOptimizer(self.learning_rate)
        elif self.hyper_params.optimizer == 'adadelta':
            return tf.train.AdadeltaOptimizer(self.learning_rate)
        # The message now lists every accepted value (adadelta was missing).
        raise ValueError('--optimizer must be <adam>, <sgd>, <adadelta>')

    def _get_loss(self, logits, target_seqs, masks):
        """Return the training loss for ``logits`` against ``target_seqs``.

        In pointer mode the targets are first cut to the batch's longest
        target length and handed to the project's ``compute_loss``; otherwise
        the masked ``sequence_loss`` with sparse softmax cross-entropy is used.
        """
        if SHELL_ARGS.train_type != 'pointer':
            return tf.contrib.seq2seq.sequence_loss(
                logits,
                target_seqs,
                masks,
                softmax_loss_function=tf.nn.sparse_softmax_cross_entropy_with_logits)

        longest_target = self.max_target_sequence_length
        # Keep every row, but only the first ``longest_target`` columns.
        targets = tf.slice(target_seqs, [0, 0], [-1, longest_target], 'targets')
        return compute_loss(
            logits=logits,
            targets=targets,
            masks=masks,
            batch_size=self.hyper_params.batch_size,
            max_target_sequence_length=longest_target)

    def _get_initializer(self, seed=3, init_weight=0.1):
        """Return the variable initializer selected by ``SHELL_ARGS.init_op``.

        ``init_weight`` bounds the range of the ``uniform`` initializer and is
        ignored by the Glorot variants. Raises ``ValueError`` for any other
        ``init_op`` value.
        """
        init_op = SHELL_ARGS.init_op

        if init_op == "uniform":
            assert init_weight
            return tf.random_uniform_initializer(
                minval=-init_weight, maxval=init_weight, seed=seed)

        if init_op == "glorot_normal":
            return tf.keras.initializers.glorot_normal(seed=seed)

        if init_op == "glorot_uniform":
            return tf.keras.initializers.glorot_uniform(seed=seed)

        raise ValueError("Unknown init_op %s" % init_op)

    def _get_encoder(self, encoder_input_seqs):
        """Embed the source ids and run them through the stacked encoder RNN.

        Returns:
            ``(encoder_output, encoder_state)`` as produced by
            ``tf.nn.dynamic_rnn``.
        """
        params = self.hyper_params

        embedded_source = tf.contrib.layers.embed_sequence(
            encoder_input_seqs,
            vocab_size=self.data.source_vocab_size(),
            embed_dim=params.embedding_size)

        # One dropout-wrapped cell per layer.
        layers = []
        for _ in range(params.num_layers):
            layers.append(self._get_cell(params.rnn_size, True))
        stacked_cell = tf.contrib.rnn.MultiRNNCell(layers)

        encoder_output, encoder_state = tf.nn.dynamic_rnn(
            stacked_cell,
            embedded_source,
            sequence_length=self.source_sequence_length,
            dtype=tf.float32)
        return encoder_output, encoder_state

    def _get_decoder(self, encoder_output, encoder_state, decoder_input_seqs):
        """Build the attention decoder in both its training and greedy-inference form.

        Both decoders use the same cell, embeddings and output layer; the
        inference decoder is created under the ``'decode'`` variable scope
        with ``reuse=True`` so it shares the training decoder's variables.

        Args:
            encoder_output: encoder outputs, used as attention memory.
            encoder_state: final encoder state, used as the decoder's initial
                cell state.
            decoder_input_seqs: target ids fed to the decoder during training.

        Returns:
            ``(training_decoder_output, predicting_decoder_output)``, the first
            element of each ``dynamic_decode`` result.
        """
        if SHELL_ARGS.train_type == 'pointer':
            # Ids at or beyond the target vocabulary size have no embedding row,
            # so they are replaced by UNK before the lookup.
            # NOTE(review): presumably these are extended-vocabulary ids of
            # copied source OOV words - confirm against the data pipeline.
            condition = tf.less(decoder_input_seqs, self.data.target_vocab_size())
            target_unk_id = self.data.target_word2id[constants.SpecialWord.UNK]
            decoder_inputs = tf.where(condition, decoder_input_seqs, tf.ones_like(decoder_input_seqs) * target_unk_id)
        else:
            decoder_inputs = decoder_input_seqs

        decoder_embeddings = tf.Variable(
            tf.random_uniform([self.data.target_vocab_size(), self.hyper_params.embedding_size]))
        decoder_embedding_input = tf.nn.embedding_lookup(decoder_embeddings, decoder_inputs)

        # Projects decoder outputs to one logit per target vocabulary entry.
        output_layer = Dense(self.data.target_vocab_size(), use_bias=False)
        cell_list = [self._get_cell(self.hyper_params.rnn_size, True)
                     for _ in range(self.hyper_params.num_layers)]
        multi_cell = tf.contrib.rnn.MultiRNNCell(cell_list)

        attention_mechanism = self._get_attention_mechanism(
            self.hyper_params.rnn_size, encoder_output, memory_sequence_length=self.source_sequence_length)
        attention_multi_cell = tf.contrib.seq2seq.AttentionWrapper(
            multi_cell,
            attention_mechanism,
            # Floor division: under Python 3 `/` would hand a float to a
            # parameter that is used as a layer width.
            attention_layer_size=self.hyper_params.rnn_size // 2,
            alignment_history=True)

        # Start from the wrapper's zero state, with the cell state replaced by
        # the encoder's final state. The batch size is fixed at graph-build time.
        initial_state = attention_multi_cell.zero_state(
            batch_size=self.hyper_params.batch_size, dtype=tf.float32).clone(cell_state=encoder_state)

        with tf.variable_scope('decode'):
            # Training: the decoder is fed the ground-truth inputs at each step.
            training_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=decoder_embedding_input,
                sequence_length=self.target_sequence_length,
                time_major=False)

            training_decoder = self._get_decoder_net(
                cell=attention_multi_cell,
                helper=training_helper,
                initial_state=initial_state,
                output_layer=output_layer)
            training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                training_decoder, impute_finished=True, maximum_iterations=self.max_target_sequence_length)

        with tf.variable_scope('decode', reuse=True):
            # Inference: start every row with START and feed back the argmax
            # token until EOS or the iteration limit.
            start_tokens = tf.tile(
                tf.constant([self.data.target_word2id[constants.SpecialWord.START]], dtype=tf.int32),
                [self.hyper_params.batch_size],
                name='start_tokens')
            predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                decoder_embeddings,
                start_tokens,
                self.data.target_word2id[constants.SpecialWord.EOS])

            predicting_decoder = self._get_decoder_net(
                cell=attention_multi_cell,
                helper=predicting_helper,
                initial_state=initial_state,
                output_layer=output_layer)
            predicting_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                predicting_decoder, impute_finished=True, maximum_iterations=self.max_target_sequence_length)

        return training_decoder_output, predicting_decoder_output

    def _get_processed_target_seqs(self):
        """Build the decoder inputs: drop the last column of the targets and prepend <START>."""
        batch_size = self.hyper_params.batch_size
        start_id = self.data.target_word2id[constants.SpecialWord.START]

        # All rows of the batch, every column except the final one.
        without_last = tf.strided_slice(self.target_seqs, [0, 0], [batch_size, -1], [1, 1])
        start_column = tf.fill([batch_size, 1], start_id)

        return tf.concat([start_column, without_last], 1)

    def _get_model(self):
        """Wire encoder and decoder together.

        Returns:
            ``(training_decoder_output, predicting_decoder_output)`` from
            ``_get_decoder``.
        """
        memory, final_state = self._get_encoder(self.source_seqs)
        decoder_inputs = self._get_processed_target_seqs()

        training_decoder_output, predicting_decoder_output = self._get_decoder(
            memory, final_state, decoder_inputs)
        return training_decoder_output, predicting_decoder_output

    def _build_summary(self):
        """Create the summary ops and one file writer per reported quantity.

        The average loss and the evaluation BLEU score are computed outside
        the graph, so each gets a float placeholder through which the value
        is fed when it is logged.
        """

        def merged_with_writer(summary, log_dir):
            # Merge op first, then its writer, both bound to this model's graph.
            merged_op = tf.summary.merge([summary])
            return merged_op, tf.summary.FileWriter(log_dir, self.graph)

        # Per-step loss, written separately for training and evaluation.
        loss_summary = tf.summary.scalar('loss', self.cost)
        self.train_summary_op, self.train_summary_writer = merged_with_writer(
            loss_summary, constants.SUMMARY_TRAIN_FILENAME)
        self.eval_summary_op, self.eval_summary_writer = merged_with_writer(
            loss_summary, constants.SUMMARY_EVAL_FILENAME)

        # Externally averaged loss.
        self.average_loss = tf.placeholder(tf.float32)
        average_loss_summary = tf.summary.scalar('average_loss', self.average_loss)
        self.train_average_summary_op, self.train_average_summary_writer = merged_with_writer(
            average_loss_summary, constants.SUMMARY_AVERAGE_TRAIN_FILENAME)
        self.eval_average_summary_op, self.eval_average_summary_writer = merged_with_writer(
            average_loss_summary, constants.SUMMARY_AVERAGE_EVAL_FILENAME)

        # Learning rate currently fed to the graph.
        learning_rate_summary = tf.summary.scalar('learning_rate', self.learning_rate)
        self.lr_summary_op, self.lr_summary_writer = merged_with_writer(
            learning_rate_summary, constants.SUMMARY_LR_FILENAME)

        # Externally computed evaluation BLEU.
        self.eval_bleu = tf.placeholder(tf.float32)
        eval_bleu_summary = tf.summary.scalar('eval_bleu', self.eval_bleu)
        self.eval_bleu_summary_op, self.eval_bleu_summary_writer = merged_with_writer(
            eval_bleu_summary, constants.SUMMARY_EVAL_BLEU_FILENAME)

    def build_graph(self):
        """Construct the whole training/inference graph inside ``self.graph``.

        Creates the inputs, the encoder/decoder model, the masked loss, the
        gradient-clipped training op and the summaries.

        Returns:
            ``(cost, train_op)``.
        """
        with self.graph.as_default():
            self._set_inputs()

            # Default initializer for variables created without an explicit one.
            tf.get_variable_scope().set_initializer(self._get_initializer())

            train_output, predict_output = self._get_model()

            # Named so that ``load`` can find the predictions by name later.
            self.logits = tf.identity(train_output.rnn_output, name='logits')
            self.predictions = tf.identity(predict_output.sample_id, name='predictions')

            with tf.name_scope("optimization"):
                # 1.0 for real target positions, 0.0 for padding.
                loss_mask = tf.sequence_mask(
                    self.target_sequence_length,
                    self.max_target_sequence_length,
                    dtype=tf.float32,
                    name='masks')

                self.cost = self._get_loss(
                    logits=self.logits, target_seqs=self.target_seqs, masks=loss_mask)

                self.global_step = tf.Variable(0, name='global_step', trainable=False)

                optimizer = self._get_optimizer()
                trainable = tf.trainable_variables()
                # Clip by global norm before applying the gradients.
                clipped, _ = tf.clip_by_global_norm(
                    tf.gradients(self.cost, trainable),
                    clip_norm=self.hyper_params.clip_norm)
                self.train_op = optimizer.apply_gradients(
                    zip(clipped, trainable), global_step=self.global_step)

                self._build_summary()

        return self.cost, self.train_op

    def train(self, sess, batch_data: BatchData, learning_rate):
        """Run one optimisation step on ``batch_data`` and log its loss summary.

        Args:
            sess: session used to run the graph.
            batch_data: the batch to train on.
            learning_rate: rate for this step; any falsy value (``None``, ``0``)
                falls back to ``hyper_params.learning_rate``.

        Returns:
            The loss value of this step.
        """
        effective_lr = learning_rate if learning_rate else self.hyper_params.learning_rate

        feed = {
            self.source_seqs: batch_data.source_batch,
            self.target_seqs: batch_data.target_batch,
            self.learning_rate: effective_lr,
            self.target_sequence_length: batch_data.target_lengths,
            self.source_sequence_length: batch_data.source_lengths,
            self.source_extend_tokens: batch_data.source_extend_tokens,
            self.source_oov_words: batch_data.source_oov_words_length,
            self.drop_out: self.hyper_params.drop_out,
        }
        fetches = [self.train_op, self.global_step, self.train_summary_op, self.cost]

        _, step, summaries, loss = sess.run(fetches, feed)
        self.train_summary_writer.add_summary(summaries, step)
        return loss

    def eval(self, sess, batch_data: BatchData):
        """Compute the loss on ``batch_data`` without training and log its summary.

        Dropout is disabled by feeding a drop probability of 0.

        Returns:
            The loss value of the batch.
        """
        feed = {
            self.source_seqs: batch_data.source_batch,
            self.target_seqs: batch_data.target_batch,
            self.learning_rate: self.hyper_params.learning_rate,
            self.target_sequence_length: batch_data.target_lengths,
            self.source_sequence_length: batch_data.source_lengths,
            self.source_extend_tokens: batch_data.source_extend_tokens,
            self.source_oov_words: batch_data.source_oov_words_length,
            self.drop_out: 0,
        }
        fetches = [self.global_step, self.eval_summary_op, self.cost]

        step, summaries, loss = sess.run(fetches, feed)
        self.eval_summary_writer.add_summary(summaries, step)
        return loss

    def infer(self, sess, input_seq, source_extend_tokens, source_oov_words_size):
        """Decode a single source sequence and return its predicted ids.

        The input is replicated ``hyper_params.batch_size`` times to fill a
        batch, and only the first row of the result is returned. The source
        length is also fed as the target length, which (through the graph's
        maximum-iterations setting) caps the decoded length at the input length.
        """
        seq_len = len(input_seq)
        copies = self.hyper_params.batch_size

        feed = {
            self.source_seqs: [input_seq] * copies,
            self.target_sequence_length: [seq_len] * copies,
            self.source_sequence_length: [seq_len] * copies,
            self.source_extend_tokens: [source_extend_tokens] * copies,
            self.source_oov_words: source_oov_words_size,
            self.drop_out: 0,
        }
        batch_predictions = sess.run(self.predictions, feed)
        return batch_predictions[0]

    def infer_batch(self, sess, batch_data: BatchData):
        """Decode a whole batch with dropout disabled and return the predicted ids."""
        feed = {
            self.source_seqs: batch_data.source_batch,
            self.target_sequence_length: batch_data.target_lengths,
            self.source_sequence_length: batch_data.source_lengths,
            self.source_extend_tokens: batch_data.source_extend_tokens,
            self.source_oov_words: batch_data.source_oov_words_length,
            self.drop_out: 0,
        }
        return sess.run(self.predictions, feed)

    def train_average(self, sess, average_loss, step):
        """Write ``average_loss`` to the training average-loss summary at ``step``."""
        feed = {self.average_loss: average_loss}
        serialized = sess.run(self.train_average_summary_op, feed)
        self.train_average_summary_writer.add_summary(serialized, step)

    def eval_average(self, sess, average_loss, step):
        """Write ``average_loss`` to the evaluation average-loss summary at ``step``."""
        feed = {self.average_loss: average_loss}
        serialized = sess.run(self.eval_average_summary_op, feed)
        self.eval_average_summary_writer.add_summary(serialized, step)

    def summary_learning_rate(self, sess, learning_rate, step):
        """Write ``learning_rate`` to the learning-rate summary at ``step``."""
        feed = {self.learning_rate: learning_rate}
        serialized = sess.run(self.lr_summary_op, feed)
        self.lr_summary_writer.add_summary(serialized, step)

    def summary_eval_bleu(self, sess, eval_bleu, step):
        """Write ``eval_bleu`` to the evaluation BLEU summary at ``step``."""
        feed = {self.eval_bleu: eval_bleu}
        serialized = sess.run(self.eval_bleu_summary_op, feed)
        self.eval_bleu_summary_writer.add_summary(serialized, step)

    def save(self, sess, model_filename, global_step=None):
        """Checkpoint the session's variables to ``model_filename``.

        ``global_step`` is forwarded to the saver unchanged. The saver only
        exists on instances that were created with ``predict=False``.
        """
        saver = self.saver
        saver.save(sess, model_filename, global_step=global_step)

    def load(self, sess, model_filename):
        """Restore a saved model and rebind the tensors needed for inference.

        The meta graph stored next to the checkpoint (``model_filename`` +
        ``'.meta'``) is imported, the variables are restored into ``sess``,
        and the prediction tensor and its input placeholders are looked up by
        name in ``self.graph``.

        Args:
            sess: session to restore the variables into; it must run
                ``self.graph``.
            model_filename: checkpoint path without the ``.meta`` suffix.
        """
        # Import explicitly into this model's graph. Without this the nodes go
        # into whatever graph happens to be the default, and the by-name
        # lookups below (which read ``self.graph``) would not find them.
        with self.graph.as_default():
            loader = tf.train.import_meta_graph(model_filename + '.meta')
        loader.restore(sess, model_filename)

        self.predictions = self.graph.get_tensor_by_name('predictions:0')

        self.source_seqs = self.graph.get_tensor_by_name('source_seqs:0')
        self.source_sequence_length = self.graph.get_tensor_by_name('source_sequence_length:0')
        self.target_sequence_length = self.graph.get_tensor_by_name('target_sequence_length:0')

        self.source_extend_tokens = self.graph.get_tensor_by_name('source_extend_tokens:0')
        self.source_oov_words = self.graph.get_tensor_by_name('source_oov_words:0')
        self.drop_out = self.graph.get_tensor_by_name('drop_out:0')
