import random

import constants
from utils import file_utils
from utils.shell_args import SHELL_ARGS


def split_train_eval_test(prefix):
    """Re-split the existing train/eval corpus into train, eval and test files.

    The four files ``prefix + SUFFIX_{TRAIN,EVAL}_{SOURCE,TARGET}`` are read and
    split on ``'\\n'``. Train lines are placed before eval lines, and each
    source line is paired with the target line at the same index. One
    ``random.random()`` draw per pair then selects its destination:

    * ``draw < constants.TEST_RATE``  -> test
    * ``draw < constants.EVAL_RATE``  -> eval
    * otherwise                       -> train

    The six resulting line lists are joined with ``'\\n'`` and written to
    ``prefix + SUFFIX_{TEST,EVAL,TRAIN}_{SOURCE,TARGET}``; the train and eval
    paths are the same ones that were read.

    NOTE(review): ``EVAL_RATE`` is used as a cumulative threshold, so the eval
    branch is only reachable for draws in ``[TEST_RATE, EVAL_RATE)`` -- confirm
    that ``constants`` defines it with that meaning.

    Args:
        prefix: String prepended to every suffix constant to form a file path.

    Raises:
        IndexError: If the combined target files hold fewer lines than the
            combined source files. Surplus target lines are ignored.
    """

    def load_lines(suffix):
        # Every '\n'-separated piece counts as a line, including empty ones.
        return file_utils.read_file_to_string(prefix + suffix).split('\n')

    def dump_lines(suffix, lines):
        file_utils.write_string_to_file(prefix + suffix, '\n'.join(lines))

    previous_eval_sources = load_lines(constants.SUFFIX_EVAL_SOURCE)
    previous_train_sources = load_lines(constants.SUFFIX_TRAIN_SOURCE)
    previous_eval_targets = load_lines(constants.SUFFIX_EVAL_TARGET)
    previous_train_targets = load_lines(constants.SUFFIX_TRAIN_TARGET)

    # Pool everything, train first, so the old split has no influence on the
    # new one beyond line order.
    all_sources = previous_train_sources + previous_eval_sources
    all_targets = previous_train_targets + previous_eval_targets

    test_pairs, eval_pairs, train_pairs = [], [], []
    for position, source_line in enumerate(all_sources):
        # Index lookup (rather than zip) keeps the IndexError on short targets.
        pair = (source_line, all_targets[position])
        draw = random.random()
        if draw < constants.TEST_RATE:
            bucket = test_pairs
        elif draw < constants.EVAL_RATE:
            bucket = eval_pairs
        else:
            bucket = train_pairs
        bucket.append(pair)

    dump_lines(constants.SUFFIX_TEST_SOURCE, [src for src, _ in test_pairs])
    dump_lines(constants.SUFFIX_TEST_TARGET, [tgt for _, tgt in test_pairs])
    dump_lines(constants.SUFFIX_EVAL_SOURCE, [src for src, _ in eval_pairs])
    dump_lines(constants.SUFFIX_EVAL_TARGET, [tgt for _, tgt in eval_pairs])
    dump_lines(constants.SUFFIX_TRAIN_SOURCE, [src for src, _ in train_pairs])
    dump_lines(constants.SUFFIX_TRAIN_TARGET, [tgt for _, tgt in train_pairs])


# Script entry point: re-split the corpus whose file-path prefix is given by
# SHELL_ARGS.prefix.
if __name__ == '__main__':
    split_train_eval_test(SHELL_ARGS.prefix)
