import constants
from predict import Predict
from utils import file_utils
from utils.shell_args import SHELL_ARGS


def get_predicted_targets():
    """Generate ``prefix.test.predicted.target`` from ``prefix.test.source`` with the model.

    Each line of the source file is one whitespace-tokenized source sequence.
    The model's prediction for line *i* is written on line *i* of the output,
    so the output stays line-aligned with the source. (Presumably it must also
    align with the reference target file used for evaluation; confirm against
    the evaluator.) In each prediction, EOS tokens are dropped and the UNK
    special word is rendered as the literal ``<UNK>``.

    shell_args:
        --prefix: path prefix of the test files.
        --global_step: checkpoint step; the model is loaded from
            ``'<constants.MODEL_FILENAME>-<global_step>'``.

        load model needs:
        --train_type
        --optimizer
        --batch_size
    """
    global_step = SHELL_ARGS.global_step
    predict = Predict('%s-%s' % (constants.MODEL_FILENAME, global_step))

    prefix = SHELL_ARGS.prefix
    test_source_seqs = file_utils.read_file_to_string(prefix + constants.SUFFIX_TEST_SOURCE).split('\n')

    # Trailing blank lines (e.g. the '' produced by a final newline) carry no
    # sequence; drop them so the output does not end with spurious empty lines.
    while test_source_seqs and not test_source_seqs[-1].strip():
        test_source_seqs.pop()

    source_seqs_count = len(test_source_seqs)
    predict_targets = []

    for idx, source_seq in enumerate(test_source_seqs):
        # split() with no separator discards empty tokens and also strips '\r'
        # and tabs, so CRLF input does not attach '\r' to the last word.
        source_words = source_seq.split()

        if source_words:
            predict_words = predict.predict(source_words)
            predict_words = [word if word != constants.SpecialWord.UNK else '<UNK>' for word in predict_words
                             if word != constants.SpecialWord.EOS]
            predict_targets.append(' '.join(predict_words))
        else:
            # An empty interior source line still gets an (empty) output line;
            # skipping it would shift every later prediction up by one line.
            predict_targets.append('')

        if idx % 1000 == 0:
            print('%d/%d' % (idx, source_seqs_count))

    file_utils.write_string_to_file(prefix + constants.SUFFIX_TEST_PREDICTED_TARGET, '\n'.join(predict_targets))


if __name__ == "__main__":
    # Run prediction only when executed as a script, not when imported.
    get_predicted_targets()
