import numpy as np
import tensorflow as tf

import constants
from data import Data, BatchData
from model import HyperParameters, MultiSeq2seqModel
from postprocess import bleu
from utils.shell_args import SHELL_ARGS


def pad_seqs(seqs, pad_id):
    """Right-pad every sequence in ``seqs`` with ``pad_id`` up to the longest length.

    :param seqs: iterable of id sequences (lists, tuples or 1-D numpy rows).
    :param pad_id: value appended to the shorter sequences.
    :return: a new list of lists, all of the same length; ``[]`` for empty input.
    """
    # Materialize each sequence as a list: ``seq + [pad_id] * k`` is list
    # concatenation only for lists. On a numpy row (which is what np.asarray
    # produces when all sequences happen to share one length) it would be an
    # element-wise add / broadcast error, and on a tuple a TypeError.
    seqs = [list(seq) for seq in seqs]
    max_seq_len = max((len(seq) for seq in seqs), default=0)
    return [seq + [pad_id] * (max_seq_len - len(seq)) for seq in seqs]


class Train(object):
    """Train a ``MultiSeq2seqModel`` and evaluate it after every epoch.

    Each epoch runs one shuffled pass over the training set, then one pass over
    the evaluation set. The eval pass computes the average loss and a BLEU score
    of ``infer_batch`` predictions against the targets. A checkpoint is saved
    whenever the early-stopping metric (eval loss or BLEU, chosen by
    ``shell_args.early_stop_option``) is at least as good as the best so far.
    Training stops after 10 non-best epochs. With the ``sgd`` optimizer:
    - a best epoch multiplies the learning rate by 1.05;
    - a non-best epoch halves it and retries the epoch from the last saved
      checkpoint.
    """

    # Print a progress line every DISPLAY_STEP batches.
    DISPLAY_STEP = 50

    def __init__(self, data: Data, hyper_params: HyperParameters, shell_args):
        """
        :param data: dataset providing vocabularies and train/eval id sequences.
        :param hyper_params: model hyper parameters (batch_size, epochs_sum,
            optimizer and learning_rate are read by this class).
        :param shell_args: parsed command-line options. This class reads
            restore, global_step, epoch_i, early_stop_init and early_stop_option.
        """
        self.hyper_params = hyper_params
        self.data = data
        self.shell_args = shell_args

        self.seq2seq_model = MultiSeq2seqModel(hyper_params, data)

    def get_batches(self, targets, sources, source_extend_tokens, source_oov_words):
        """Generator yielding shuffled, padded ``BatchData`` batches.

        All four arrays are shuffled with one shared permutation, so their rows
        stay aligned. Only full batches are produced
        (``len(sources) // batch_size`` of them); trailing samples are dropped.

        :param targets: numpy array of target id sequences.
        :param sources: numpy array of source id sequences.
        :param source_extend_tokens: numpy array of source sequences in the
            extended (vocab + per-sample OOV) id space.
        :param source_oov_words: numpy array of per-sample OOV word lists.
        """
        # Shuffle the data order (same permutation for every parallel array).
        print('Reshuffle train data')
        index_list = np.arange(targets.shape[0])
        np.random.shuffle(index_list)
        targets = targets[index_list]
        sources = sources[index_list]
        source_extend_tokens = source_extend_tokens[index_list]
        source_oov_words = source_oov_words[index_list]

        batch_size = self.hyper_params.batch_size
        # Loop-invariant: the pad ids are the same for every batch.
        source_pad_id = self.data.source_word2id[constants.SpecialWord.PAD]
        target_pad_id = self.data.target_word2id[constants.SpecialWord.PAD]

        for batch_i in range(len(sources) // batch_size):
            start_i = batch_i * batch_size
            end_i = start_i + batch_size
            sources_batch = sources[start_i:end_i]
            targets_batch = targets[start_i:end_i]
            source_extend_tokens_batch = source_extend_tokens[start_i:end_i]
            source_oov_words_batch = source_oov_words[start_i:end_i]

            # Pad every sequence to the longest one in this batch.
            pad_sources_batch = np.array(pad_seqs(sources_batch, source_pad_id))
            pad_targets_batch = np.array(pad_seqs(targets_batch, target_pad_id))
            pad_source_extend_tokens_batch = np.array(pad_seqs(source_extend_tokens_batch, source_pad_id))

            # Record the unpadded length of each sequence.
            targets_lengths = [len(target) for target in targets_batch]
            source_lengths = [len(source) for source in sources_batch]

            yield BatchData(source_batch=pad_sources_batch,
                            target_batch=pad_targets_batch,
                            source_lengths=source_lengths,
                            target_lengths=targets_lengths,
                            source_extend_tokens=pad_source_extend_tokens_batch,
                            source_oov_words_length=max(len(oov_words) for oov_words in source_oov_words_batch),
                            source_oov_words=source_oov_words_batch)

    def ids2sentence(self, ids, source_oov_words):
        """Convert target-side ids into a space-joined sentence.

        Ids found in ``target_id2word`` map to vocabulary words. Any other id is
        treated as an extended-vocabulary id and resolved through
        ``source_oov_words`` at offset ``target_vocab_size()``.
        EOS/PAD/START tokens are dropped, and UNK is rendered as ``'<UNK>'``.
        """
        skipped = (constants.SpecialWord.EOS, constants.SpecialWord.PAD, constants.SpecialWord.START)
        words = []
        for idx in ids:
            if idx in self.data.target_id2word:
                words.append(self.data.target_id2word[idx])
            else:
                # Ids past the target vocabulary index this sample's OOV words.
                words.append(source_oov_words[idx - self.data.target_vocab_size()])
        return ' '.join('<UNK>' if word == constants.SpecialWord.UNK else word
                        for word in words if word not in skipped)

    def _train_epoch(self, sess, arrays, lr, epoch_i, max_epoch, batch_list_len):
        """Run one shuffled pass over the training set and record its average loss.

        :param arrays: (targets, sources, source_extend_tokens, source_oov_words).
        :return: the global step read just before the epoch's last training batch.
        """
        cur_step = 0
        loss_sum = 0
        for batch_i, batch_data in enumerate(self.get_batches(*arrays)):
            cur_step = tf.train.global_step(sess, self.seq2seq_model.global_step)

            loss = self.seq2seq_model.train(sess, batch_data, lr)
            loss_sum += loss

            if batch_i % self.DISPLAY_STEP == 0:
                print('Epoch {:>3}/{} Batch {:>4}/{} - Training Loss: {:>6.3f}'.format(
                    epoch_i,
                    max_epoch,
                    batch_i,
                    batch_list_len,
                    loss))
        self.seq2seq_model.train_average(sess, loss_sum / batch_list_len, epoch_i)
        return cur_step

    def _eval_epoch(self, sess, arrays, epoch_i, max_epoch, batch_list_len):
        """Run one pass over the eval set and record its loss and BLEU summaries.

        :param arrays: (targets, sources, source_extend_tokens, source_oov_words).
        :return: (average eval loss, BLEU of the predictions against the targets).
        """
        loss_sum = 0
        target_sens = []
        predict_sens = []
        for batch_i, batch_data in enumerate(self.get_batches(*arrays)):
            loss = self.seq2seq_model.eval(sess, batch_data)
            loss_sum += loss

            # Index of this batch's first sentence. Used for the sample printed
            # below, independent of how many predictions each batch returns.
            batch_start = len(target_sens)
            predictions = self.seq2seq_model.infer_batch(sess, batch_data)
            for idx, predict_ids in enumerate(predictions):
                oov_words = batch_data.source_oov_words[idx]
                predict_sens.append(self.ids2sentence(predict_ids, oov_words))
                target_sens.append(self.ids2sentence(batch_data.target_batch[idx], oov_words))

            if batch_i % self.DISPLAY_STEP == 0:
                print('Epoch {:>3}/{} Batch {:>4}/{} - Evaluation Loss: {:>6.3f}, target: {}, predict: {}.'.format(
                    epoch_i,
                    max_epoch,
                    batch_i,
                    batch_list_len,
                    loss,
                    target_sens[batch_start],
                    predict_sens[batch_start]))

        avg_eval_loss = loss_sum / batch_list_len
        self.seq2seq_model.eval_average(sess, avg_eval_loss, epoch_i)

        eval_bleu = bleu.bleu('\n'.join(target_sens), '\n'.join(predict_sens))
        self.seq2seq_model.summary_eval_bleu(sess, eval_bleu, epoch_i)
        return avg_eval_loss, eval_bleu

    def train(self):
        """Run the training loop with per-epoch evaluation, checkpointing and early stopping.

        :raises ValueError: if the train or eval set has fewer samples than one
            batch. ``get_batches`` would yield nothing, and the epoch averages
            would divide by zero.
        """
        train_arrays = (np.asarray(self.data.train_target_seqs_ids),
                        np.asarray(self.data.train_source_seqs_ids),
                        np.asarray(self.data.train_source_extend_tokens),
                        np.asarray(self.data.train_source_oov_words))
        eval_arrays = (np.asarray(self.data.eval_target_seqs_ids),
                       np.asarray(self.data.eval_source_seqs_ids),
                       np.asarray(self.data.eval_source_extend_tokens),
                       np.asarray(self.data.eval_source_oov_words))

        batch_size = self.hyper_params.batch_size
        train_size = len(train_arrays[1])
        eval_size = len(eval_arrays[1])
        train_batch_list_len = train_size // batch_size
        eval_batch_list_len = eval_size // batch_size
        if train_batch_list_len == 0 or eval_batch_list_len == 0:
            raise ValueError('Need at least batch_size=%d samples in both the train (%d) and eval (%d) sets'
                             % (batch_size, train_size, eval_size))

        early_stop_option = self.shell_args.early_stop_option
        best_eval_value = self.shell_args.early_stop_init
        eval_value_increase_times = 0

        gpu_options = tf.GPUOptions(allow_growth=True)
        with tf.Session(graph=self.seq2seq_model.graph, config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            sess.run(tf.global_variables_initializer())

            # Step of the newest checkpoint that the SGD rollback may restore.
            # None until a checkpoint is known to exist: this class never saves
            # step 0, so restoring "<MODEL_FILENAME>-0" would fail.
            last_step = None
            if self.shell_args.restore and self.shell_args.global_step:
                last_step = self.shell_args.global_step
                self.seq2seq_model.saver.restore(sess, "%s-%d" % (constants.MODEL_FILENAME, last_step))
            sess.graph.finalize()

            lr = self.seq2seq_model.hyper_params.learning_rate

            cur_step = 0
            epoch_i = self.shell_args.epoch_i
            max_epoch = self.hyper_params.epochs_sum + self.shell_args.epoch_i

            while epoch_i < max_epoch:
                print('Epoch %d, learning rate: %f' % (epoch_i, lr))
                self.seq2seq_model.summary_learning_rate(sess, lr, epoch_i)

                cur_step = self._train_epoch(sess, train_arrays, lr, epoch_i, max_epoch, train_batch_list_len)
                avg_eval_loss, eval_bleu = self._eval_epoch(sess, eval_arrays, epoch_i, max_epoch,
                                                            eval_batch_list_len)

                epoch_i += 1
                # Pick the metric named by early_stop_option. Lower loss is
                # better; higher BLEU is better.
                if early_stop_option == 'loss':
                    eval_value = avg_eval_loss
                    not_best = eval_value - best_eval_value > 1e-6
                else:
                    eval_value = eval_bleu
                    not_best = best_eval_value - eval_value > 1e-6

                if not_best:
                    eval_value_increase_times += 1
                    print('Not best {}: {}/{}, global_step: {}, increasing time: {}'.format(early_stop_option,
                                                                                            eval_value,
                                                                                            best_eval_value,
                                                                                            cur_step,
                                                                                            eval_value_increase_times))
                    if eval_value_increase_times >= 10:
                        print('Not best {} in latest 10 epoch! Early stop!'.format(early_stop_option))
                        break
                    if self.hyper_params.optimizer == 'sgd':
                        # Retry this epoch with a halved learning rate, starting
                        # from the last good checkpoint when one exists.
                        lr *= 0.5
                        epoch_i -= 1
                        if last_step is not None:
                            self.seq2seq_model.saver.restore(sess, "%s-%d" % (constants.MODEL_FILENAME, last_step))
                else:
                    # Read the step the checkpoint will be tagged with instead of
                    # assuming cur_step + 1. Assumes seq2seq_model.save names the
                    # file "<MODEL_FILENAME>-<global step value>", as
                    # tf.train.Saver.save does; the restore paths above rely on that.
                    saved_step = tf.train.global_step(sess, self.seq2seq_model.global_step)
                    print('Save model global_step: %d' % saved_step)
                    self.seq2seq_model.save(sess, constants.MODEL_FILENAME,
                                            global_step=self.seq2seq_model.global_step)
                    best_eval_value = eval_value
                    eval_value_increase_times = 0

                    if self.hyper_params.optimizer == 'sgd':
                        lr *= 1.05
                        last_step = saved_step

            cur_step = tf.train.global_step(sess, self.seq2seq_model.global_step)
            print('Save last model, global_step: %d' % cur_step)
            self.seq2seq_model.save(sess, constants.MODEL_FILENAME, global_step=self.seq2seq_model.global_step)


def main():
    """Script entry point: load the dataset, build hyper parameters, and train.

    The dataset prefix and optimizer come from the command-line arguments
    (``SHELL_ARGS``); every other hyper parameter comes from ``constants``.
    """
    data = Data(SHELL_ARGS.prefix)
    hyper_param_kwargs = {
        'embedding_size': constants.EMBEDDING_SIZE,
        'rnn_size': constants.RNN_SIZE,
        'num_layers': constants.NUM_LAYERS,
        'batch_size': constants.BATCH_SIZE,
        'learning_rate': constants.LEARNING_RATE,
        'clip_norm': constants.CLIP_NORM,
        'epochs_sum': constants.EPOCHS_SUM,
        'drop_out': constants.DROP_OUT,
        'optimizer': SHELL_ARGS.optimizer,
    }
    Train(data, HyperParameters(**hyper_param_kwargs), SHELL_ARGS).train()


if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()
