import constants
from utils import file_utils
from utils.shell_args import SHELL_ARGS


def _update_dict_by_seqs(_dict, _seqs):
    for seq in _seqs:
        for word in seq.split(' '):
            if word not in _dict:
                _dict[word] = 1
            else:
                _dict[word] += 1


def _filter_rate_vocab(_dict, vocab_rate):
    freq_count = 0
    for _, count in _dict.items():
        freq_count += count

    items = sorted(_dict.items(), key=lambda item: -item[1])

    words = list()
    cur_freq_count = 0
    for idx, (word, count) in enumerate(items):
        cur_freq_count += count
        if cur_freq_count / freq_count > vocab_rate:
            print('words_total', idx)
            break
        words.append(word)
    return words


def get_vocab():
    """Build and write the source and target vocabularies.

    Reads the train and eval corpora for both sides from paths formed by
    ``SHELL_ARGS.prefix`` plus the suffixes in ``constants``, counts words
    per side with ``_update_dict_by_seqs`` (train first, then eval), trims
    each side with ``_filter_rate_vocab`` using the side's configured rate,
    and writes for each side:

    * a newline-joined word list at ``prefix + SUFFIX_VOCAB_*``;
    * a word -> index JSON at the same path plus ``'.json'``.

    In the JSON, kept words are numbered from 2 in list order, ``'eos'`` is
    0 and ``'UNK'`` is 1. The reserved entries are assigned last, so a
    corpus word spelled ``'eos'`` or ``'UNK'`` has its index overwritten.
    """
    base = SHELL_ARGS.prefix

    def load_seqs(suffix):
        # One sequence per line of the corpus file.
        return file_utils.read_file_to_string(base + suffix).split('\n')

    def to_index(words):
        # Indices 0 and 1 are reserved, so real words start at 2.
        mapping = {word: number for number, word in enumerate(words, 2)}
        mapping['eos'] = 0
        mapping['UNK'] = 1
        return mapping

    train_src = load_seqs(constants.SUFFIX_TRAIN_SOURCE)
    train_tgt = load_seqs(constants.SUFFIX_TRAIN_TARGET)
    eval_src = load_seqs(constants.SUFFIX_EVAL_SOURCE)
    eval_tgt = load_seqs(constants.SUFFIX_EVAL_TARGET)

    # Count source side fully, then target side; train precedes eval.
    src_counts, tgt_counts = {}, {}
    for counts, corpora in ((src_counts, (train_src, eval_src)),
                            (tgt_counts, (train_tgt, eval_tgt))):
        for seqs in corpora:
            _update_dict_by_seqs(counts, seqs)

    src_words = _filter_rate_vocab(src_counts, SHELL_ARGS.source_vocab_rate)
    tgt_words = _filter_rate_vocab(tgt_counts, SHELL_ARGS.target_vocab_rate)

    # Plain-text word lists (reserved tokens are not included here).
    file_utils.write_string_to_file(
        base + constants.SUFFIX_VOCAB_SOURCE, '\n'.join(src_words))
    file_utils.write_string_to_file(
        base + constants.SUFFIX_VOCAB_TARGET, '\n'.join(tgt_words))

    # Word -> index lookups, source first.
    src_index = to_index(src_words)
    file_utils.write_json_to_file(
        base + constants.SUFFIX_VOCAB_SOURCE + '.json', src_index)

    tgt_index = to_index(tgt_words)
    file_utils.write_json_to_file(
        base + constants.SUFFIX_VOCAB_TARGET + '.json', tgt_index)


# Script entry point: generate the vocabulary files when run directly.
if __name__ == '__main__':
    get_vocab()
