import constants
from utils import file_utils
from utils.shell_args import SHELL_ARGS


class Data(object):
    """Vocabularies and id-encoded train/eval sequences loaded from prefixed files.

    Every file name is built as ``filename_prefix`` plus one of the
    ``constants.SUFFIX_*`` values. The source and target vocabularies are
    always loaded; the train/eval sequences and their id encodings are
    loaded unless ``only_vocab_set`` is true.

    Attributes set for each of the ``train`` and ``eval`` splits (shown for
    ``train``):
        train_source_seqs / train_target_seqs: raw lines of the data files.
        train_source_seqs_ids: source words mapped through the source
            vocabulary (UNK id for unknown words).
        train_source_extend_tokens: source words mapped through the target
            vocabulary, where a word missing from it gets the id
            ``len(target_word2id) + <position in that sequence's OOV list>``.
        train_source_oov_words: per sequence, the source words missing from
            the target vocabulary, in order of first appearance.
        train_target_seqs_ids: target words mapped to ids, each sequence
            terminated by the EOS id (see ``_map_target_words_to_ids``).
    """

    def __init__(self, filename_prefix, only_vocab_set=False):
        """Load the vocabularies and, unless ``only_vocab_set``, encode all data.

        Args:
            filename_prefix: common prefix of the vocabulary and data files.
            only_vocab_set: when true, stop after loading the two
                vocabularies; no sequence attribute is created.
        """
        self.filename_prefix = filename_prefix

        self.source_id2word, self.source_word2id = self._get_vocab_set_from_file(
            self.filename_prefix + constants.SUFFIX_VOCAB_SOURCE)
        self.target_id2word, self.target_word2id = self._get_vocab_set_from_file(
            self.filename_prefix + constants.SUFFIX_VOCAB_TARGET)

        if only_vocab_set:
            return

        (self.train_source_seqs,
         self.train_target_seqs,
         self.eval_source_seqs,
         self.eval_target_seqs) = self._get_data_from_file()

        (self.train_source_seqs_ids,
         self.train_source_extend_tokens,
         self.train_source_oov_words) = self._encode_source_seqs(self.train_source_seqs)
        self.train_target_seqs_ids = self._encode_target_seqs(
            self.train_target_seqs, self.train_source_oov_words)

        (self.eval_source_seqs_ids,
         self.eval_source_extend_tokens,
         self.eval_source_oov_words) = self._encode_source_seqs(self.eval_source_seqs)
        self.eval_target_seqs_ids = self._encode_target_seqs(
            self.eval_target_seqs, self.eval_source_oov_words)

        # Debug dump of the first few encoded training examples.
        print('train_source_seqs', self.train_source_seqs[:10])
        print('train_source_seqs_ids', self.train_source_seqs_ids[:10])
        print('train_source_extend_tokens', self.train_source_extend_tokens[:10])
        print('train_source_oov_words', self.train_source_oov_words[:10])
        print('=====================')
        print('train_target_seqs_ids', self.train_target_seqs_ids[:10])
        print('train_target_seqs', self.train_target_seqs[:10])
        print('REVERSE', SHELL_ARGS.reverse)

    def _encode_source_seqs(self, seqs):
        """Encode every space-separated source sequence in ``seqs``.

        Returns:
            Three parallel lists (one entry per sequence): vocabulary ids,
            extended tokens and OOV words, as produced by
            ``_map_source_words_to_ids``.
        """
        seqs_ids = []
        seqs_extend_tokens = []
        seqs_oov_words = []
        for seq in seqs:
            ids, extend_tokens, oov_words = self._map_source_words_to_ids(seq.split(' '))
            seqs_ids.append(ids)
            seqs_extend_tokens.append(extend_tokens)
            seqs_oov_words.append(oov_words)
        return seqs_ids, seqs_extend_tokens, seqs_oov_words

    def _encode_target_seqs(self, seqs, source_oov_words):
        """Encode every space-separated target sequence in ``seqs``.

        ``source_oov_words[idx]`` must be the OOV list of the source sequence
        paired with ``seqs[idx]``; an IndexError is raised if there are more
        target sequences than OOV lists.
        """
        return [self._map_target_words_to_ids(seq.split(' '), source_oov_words[idx])
                for idx, seq in enumerate(seqs)]

    def _map_source_words_to_ids(self, words):
        """Map source words to vocabulary ids and extended-vocabulary tokens.

        When ``SHELL_ARGS.reverse`` is set the words are processed in
        reversed order. The caller's list is never modified.

        Args:
            words: list of source words.

        Returns:
            A tuple ``(ids, extend_tokens, oov_words)``:
            ids: source-vocabulary id of each word (UNK id if unknown).
            extend_tokens: target-vocabulary id of each word; a word missing
                from the target vocabulary gets
                ``len(target_word2id) + oov_words.index(word)``.
            oov_words: list of distinct words missing from the target
                vocabulary, in order of first appearance.
        """
        if SHELL_ARGS.reverse:
            # Reverse a copy so the argument passed by the caller is untouched.
            words = list(reversed(words))

        source_unk_id = self.source_word2id[constants.SpecialWord.UNK]
        ids = [self.source_word2id.get(word, source_unk_id) for word in words]

        target_vocab_size = len(self.target_word2id)
        extend_tokens = []
        oov_words = []
        # word -> position in oov_words, to avoid repeated linear list scans.
        oov_positions = {}

        for word in words:
            if word in self.target_word2id:
                extend_tokens.append(self.target_word2id[word])
                continue
            if word not in oov_positions:
                oov_positions[word] = len(oov_words)
                oov_words.append(word)
            extend_tokens.append(target_vocab_size + oov_positions[word])

        return ids, extend_tokens, oov_words

    def _map_target_words_to_ids(self, words, source_oov_words):
        """Map target words to ids and append the EOS id.

        If ``SHELL_ARGS.train_type`` is not ``'pointer'``, every word missing
        from the target vocabulary becomes the UNK id. Otherwise such a word
        becomes ``len(target_word2id) + <first index in source_oov_words>``
        when it occurs in ``source_oov_words``, and the UNK id when it does
        not.

        Args:
            words: list of target words.
            source_oov_words: OOV word list of the paired source sequence.

        Returns:
            List of ids, always ending with the EOS id.
        """
        target_unk_id = self.target_word2id[constants.SpecialWord.UNK]
        target_eos_id = self.target_word2id[constants.SpecialWord.EOS]

        if SHELL_ARGS.train_type != 'pointer':
            return [self.target_word2id.get(word, target_unk_id) for word in words] + [target_eos_id]

        target_vocab_size = len(self.target_word2id)

        # First-occurrence index of each OOV word (same result as list.index).
        oov_positions = {}
        for position, oov_word in enumerate(source_oov_words):
            oov_positions.setdefault(oov_word, position)

        ids = []
        for word in words:
            if word in self.target_word2id:
                ids.append(self.target_word2id[word])
            elif word in oov_positions:
                ids.append(target_vocab_size + oov_positions[word])
            else:
                ids.append(target_unk_id)

        ids.append(target_eos_id)

        return ids

    def _get_data_from_file(self):
        """Read the four data files and split each into lines.

        Returns:
            ``(train_source, train_target, eval_source, eval_target)``, each
            a list of strings obtained by splitting the file content on
            ``'\\n'``.
        """
        suffixes = (constants.SUFFIX_TRAIN_SOURCE,
                    constants.SUFFIX_TRAIN_TARGET,
                    constants.SUFFIX_EVAL_SOURCE,
                    constants.SUFFIX_EVAL_TARGET)

        # NOTE(review): if read_file_to_string returns the raw content and it
        # ends with a newline, split('\n') yields a trailing empty string,
        # i.e. one extra empty sequence per file - confirm this is intended.
        train_source_data, train_target_data, eval_source_data, eval_target_data = [
            file_utils.read_file_to_string(self.filename_prefix + suffix).split('\n')
            for suffix in suffixes]

        return train_source_data, train_target_data, eval_source_data, eval_target_data

    def _get_vocab_set_from_file(self, filename):
        """Build id<->word mappings from a vocabulary file with one word per line.

        File words get ids ``0..n-1`` in file order; the special words START,
        EOS, PAD and UNK are then appended with ids ``n..n+3``. If a file
        word equals a special word, ``word2id`` keeps the special word's id.

        Args:
            filename: path of the vocabulary file.

        Returns:
            A tuple ``(id2word, word2id)`` of dicts.
        """
        words = file_utils.read_file_to_string(filename).split('\n')
        special_words = [constants.SpecialWord.START,
                         constants.SpecialWord.EOS,
                         constants.SpecialWord.PAD,
                         constants.SpecialWord.UNK]

        id2word = {}
        word2id = {}
        # Special words come last, so their ids follow the file words.
        for idx, word in enumerate(words + special_words):
            id2word[idx] = word
            word2id[word] = idx

        return id2word, word2id

    def map_source_words2ids(self, words, sequence_length=7):
        """Map source words to ids and right-pad with the PAD id.

        Args:
            words: list of source words.
            sequence_length: length to pad to (defaults to 7). Inputs longer
                than this are returned unpadded and are not truncated.

        Returns:
            List of ids: UNK id for unknown words, followed by PAD ids.
        """
        source_unk_id = self.source_word2id[constants.SpecialWord.UNK]
        source_pad_id = self.source_word2id[constants.SpecialWord.PAD]

        id_list = [self.source_word2id.get(word, source_unk_id) for word in words]
        # A non-positive multiplier yields an empty padding list.
        pad_list = [source_pad_id] * (sequence_length - len(words))

        return id_list + pad_list

    def source_vocab_size(self):
        """Return the number of entries in the source vocabulary."""
        return len(self.source_id2word)

    def target_vocab_size(self):
        """Return the number of entries in the target vocabulary."""
        return len(self.target_id2word)
