import tensorflow as tf

import constants
from data import Data
from model import HyperParameters, MultiSeq2seqModel
from utils import word_utils
from utils.shell_args import SHELL_ARGS


class Predict(object):
    """Inference wrapper around a trained ``MultiSeq2seqModel``.

    The vocabulary and a saved checkpoint are loaded once in ``__init__``;
    :meth:`predict` then maps a list of source words to a list of predicted
    target words. Source words that are not in the target vocabulary receive
    temporary "extended" ids (``target_vocab_size + k``, where ``k`` is the
    word's first-occurrence order among such words), presumably so a
    copy/pointer mechanism in the model can emit them -- confirm against
    ``MultiSeq2seqModel.infer``.

    The instance owns a ``tf.Session``; call :meth:`close` (or use the
    instance as a context manager) to release it.
    """

    def __init__(self, model_filename):
        """Build the inference graph and restore weights from a checkpoint.

        :param model_filename: checkpoint path handed to
            ``MultiSeq2seqModel.load``.
        """
        self.model_filename = model_filename

        self.data = Data(SHELL_ARGS.prefix, only_vocab_set=True)
        self.hyper_params = HyperParameters(embedding_size=constants.EMBEDDING_SIZE,
                                            rnn_size=constants.RNN_SIZE,
                                            num_layers=constants.NUM_LAYERS,
                                            batch_size=constants.BATCH_SIZE,
                                            learning_rate=constants.LEARNING_RATE,
                                            clip_norm=constants.CLIP_NORM,
                                            epochs_sum=constants.EPOCHS_SUM,
                                            drop_out=constants.DROP_OUT,
                                            optimizer=SHELL_ARGS.optimizer)

        self.seq2seq_model = MultiSeq2seqModel(self.hyper_params, self.data, predict=True)

        self.sess = tf.Session(graph=self.seq2seq_model.graph)
        with self.seq2seq_model.graph.as_default():
            self.seq2seq_model.load(self.sess, self.model_filename)

    def close(self):
        """Release the TensorFlow session. Calling it more than once is a no-op."""
        sess = getattr(self, 'sess', None)
        if sess is not None:
            sess.close()
            self.sess = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False  # never swallow exceptions raised inside the with-block

    @staticmethod
    def _extend_source_tokens(words, target_word2id, target_vocab_size):
        """Encode source words with target-vocabulary ids, extending for OOV words.

        Words present in ``target_word2id`` get their vocabulary id. Every
        other word gets ``target_vocab_size + k``, where ``k`` is the order in
        which that distinct word first appears among the out-of-vocabulary words.

        :param words: source words, in the order fed to the model.
        :param target_word2id: mapping of target-vocabulary word -> id.
        :param target_vocab_size: size of the target vocabulary; first extended id.
        :return: ``(extend_tokens, oov_words)``, where ``oov_words[k]`` is the
            word that received extended id ``target_vocab_size + k``.
        """
        oov_words = []
        oov_index = {}  # word -> position in oov_words; O(1) instead of list.index
        extend_tokens = []
        for word in words:
            if word in target_word2id:
                extend_tokens.append(target_word2id[word])
                continue
            if word not in oov_index:
                oov_index[word] = len(oov_words)
                oov_words.append(word)
            extend_tokens.append(target_vocab_size + oov_index[word])
        return extend_tokens, oov_words

    @staticmethod
    def _decode_ids(ids, target_id2word, oov_words, target_vocab_size):
        """Turn predicted ids back into words, resolving extended OOV ids.

        :param ids: predicted token ids.
        :param target_id2word: mapping of target-vocabulary id -> word.
        :param oov_words: OOV words as returned by ``_extend_source_tokens``.
        :param target_vocab_size: first extended id.
        :return: list of decoded words.
        :raises IndexError: if an id is neither in the vocabulary nor a valid
            extended id for this input.
        """
        decoded = []
        for idx in ids:
            if idx in target_id2word:
                decoded.append(target_id2word[idx])
                continue
            offset = idx - target_vocab_size
            # Guard explicitly: a negative offset would otherwise silently pick
            # an unrelated OOV word from the end of the list.
            if not 0 <= offset < len(oov_words):
                raise IndexError('predicted id %r is neither in the target vocabulary '
                                 'nor a valid extended OOV id' % (idx,))
            decoded.append(oov_words[offset])
        return decoded

    def predict(self, input_word_list):
        """Predict the target word sequence for one source sentence.

        :param input_word_list: source words. The caller's list is not
            modified; when ``SHELL_ARGS.reverse`` is set, a reversed copy is fed
            to the model.
        :return: list of predicted target words.
        """
        # Work on a copy so the caller's list is never reordered in place.
        words = list(input_word_list)
        if SHELL_ARGS.reverse:
            words.reverse()

        source_unk_id = self.data.source_word2id[constants.SpecialWord.UNK]
        input_seq = [self.data.source_word2id.get(word, source_unk_id) for word in words]

        target_vocab_size = self.data.target_vocab_size()
        source_extend_tokens, source_oov_words = self._extend_source_tokens(
            words, self.data.target_word2id, target_vocab_size)

        predicted_ids = self.seq2seq_model.infer(self.sess, input_seq, source_extend_tokens,
                                                 len(source_oov_words))

        return self._decode_ids(predicted_ids, self.data.target_id2word,
                                source_oov_words, target_vocab_size)


def tokenize_line(line, lem_words=False):
    """Split one line of user input into the word list fed to ``Predict.predict``.

    Splits on any run of whitespace (spaces, tabs), so repeated separators never
    produce empty tokens. When ``lem_words`` is true the words are lower-cased
    and passed through ``word_utils.lemmatize_word_list``.

    :param line: raw input line.
    :param lem_words: whether to lower-case and lemmatize the words.
    :return: list of words.
    """
    words = line.split()
    if lem_words:
        words = word_utils.lemmatize_word_list([word.lower() for word in words])
    return words


if __name__ == '__main__':
    global_step = SHELL_ARGS.global_step
    predictor = Predict('%s-%s' % (constants.MODEL_FILENAME, global_step))

    try:
        # One sentence per line; stop on "exit" or at end of input (EOF / Ctrl-D).
        while True:
            try:
                source_seq = input()
            except EOFError:
                break
            if source_seq.strip() == 'exit':
                break
            target_seq = predictor.predict(tokenize_line(source_seq, SHELL_ARGS.lem_words))
            print(target_seq)
    finally:
        # Release the TensorFlow session even if prediction raises or the user interrupts.
        predictor.sess.close()
