import random

import constants
from utils import file_utils


def write_data_to_files(prefix,
                        train_source_seqs,
                        train_target_seqs,
                        eval_source_seqs,
                        eval_target_seqs,
                        is_append=False):
    """Write the train/eval source/target payloads to four files sharing ``prefix``.

    Each payload is passed unchanged to ``file_utils.write_string_to_file``.
    The path is ``prefix`` followed by the matching ``constants.SUFFIX_*``
    value. Files are written in the order train-source, train-target,
    eval-source, eval-target. Within this file the caller passes
    ``'\\n'``-joined strings, not lists.

    Args:
        prefix: Path prefix shared by all four output files.
        train_source_seqs: Content for the training-source file.
        train_target_seqs: Content for the training-target file.
        eval_source_seqs: Content for the evaluation-source file.
        eval_target_seqs: Content for the evaluation-target file.
        is_append: Forwarded unchanged to ``file_utils.write_string_to_file``.
    """
    destinations = (
        (constants.SUFFIX_TRAIN_SOURCE, train_source_seqs),
        (constants.SUFFIX_TRAIN_TARGET, train_target_seqs),
        (constants.SUFFIX_EVAL_SOURCE, eval_source_seqs),
        (constants.SUFFIX_EVAL_TARGET, eval_target_seqs),
    )
    for suffix, payload in destinations:
        file_utils.write_string_to_file(prefix + suffix, payload, is_append)


def limit_commits_length(*,
                         source_prefix='static/data/text_commits/text_commits',
                         target_prefix='static/data/text_limit_len_commits/text_limit_len_commits',
                         rng=None):
    """Filter commit pairs by length and randomly re-split them into train/eval.

    Steps:

    1. Read the train and eval source/target files at ``source_prefix`` plus
       the ``constants.SUFFIX_*`` suffixes, one sequence per ``'\\n'``-separated
       line, and pool train and eval together.
    2. Keep a pair only if its source has at most
       ``constants.SOURCE_SEQ_LEN_MAX`` space-separated tokens and its target
       has at most ``constants.TARGET_SEQ_LEN_MAX``.
    3. Send each kept pair to the eval set with probability
       ``constants.EVAL_RATE``; otherwise send it to the train set.
    4. Write the result to ``target_prefix`` via ``write_data_to_files``.

    It also prints the train-set size and the number of kept sources
    containing ``constants.SPECIAL_WORD_JAVA_FILE_START``.

    Args:
        source_prefix: Path prefix of the input files.
        target_prefix: Path prefix of the output files.
        rng: Object with a ``random()`` method used for the train/eval split.
            Defaults to the global ``random`` module, as before. Pass e.g.
            ``random.Random(seed)`` for a reproducible split.

    Raises:
        ValueError: If a source file and its target file have different line
            counts. Pairs are matched by line position, so a mismatch would
            otherwise silently misalign every following pair.
    """
    if rng is None:
        rng = random

    def read_seqs(suffix):
        # One sequence per '\n'-separated line, exactly as the file stores it.
        return file_utils.read_file_to_string(source_prefix + suffix).split('\n')

    train_source_seqs = read_seqs(constants.SUFFIX_TRAIN_SOURCE)
    train_target_seqs = read_seqs(constants.SUFFIX_TRAIN_TARGET)
    eval_source_seqs = read_seqs(constants.SUFFIX_EVAL_SOURCE)
    eval_target_seqs = read_seqs(constants.SUFFIX_EVAL_TARGET)

    # Validate per split. Checking only the pooled totals could hide offsetting
    # mismatches that still misalign pairs.
    for split_name, sources, targets in (('train', train_source_seqs, train_target_seqs),
                                         ('eval', eval_source_seqs, eval_target_seqs)):
        if len(sources) != len(targets):
            raise ValueError(
                '{} source/target line counts differ: {} != {}'.format(
                    split_name, len(sources), len(targets)))

    train_limit_len_source_seqs = []
    train_limit_len_target_seqs = []
    eval_limit_len_source_seqs = []
    eval_limit_len_target_seqs = []
    source_seqs_with_java_count = 0

    for source_seq, target_seq in zip(train_source_seqs + eval_source_seqs,
                                      train_target_seqs + eval_target_seqs):
        if (len(source_seq.split(' ')) > constants.SOURCE_SEQ_LEN_MAX
                or len(target_seq.split(' ')) > constants.TARGET_SEQ_LEN_MAX):
            continue

        if constants.SPECIAL_WORD_JAVA_FILE_START in source_seq:
            source_seqs_with_java_count += 1

        # The split is re-drawn from scratch: a pair's original train/eval
        # membership is discarded.
        if rng.random() < constants.EVAL_RATE:
            eval_limit_len_source_seqs.append(source_seq)
            eval_limit_len_target_seqs.append(target_seq)
        else:
            train_limit_len_source_seqs.append(source_seq)
            train_limit_len_target_seqs.append(target_seq)

    print('len(train_limit_len_source_seqs)', len(train_limit_len_source_seqs))
    print('source_seqs_with_java_count', source_seqs_with_java_count)

    write_data_to_files(target_prefix,
                        '\n'.join(train_limit_len_source_seqs),
                        '\n'.join(train_limit_len_target_seqs),
                        '\n'.join(eval_limit_len_source_seqs),
                        '\n'.join(eval_limit_len_target_seqs))


if __name__ == "__main__":
    # Script entry point: run the commit length-filtering pipeline.
    limit_commits_length()
