import os
import random
import re

from pymongo import MongoClient

import constants
from preprocess.get_commit_diff_seq import get_commit_diff_seq
from preprocess.get_commit_msg import get_commit_msg
from utils.shell_args import SHELL_ARGS


def _get_repo_name(repo_item):
    items = repo_item.split('.')
    return "%s/%s" % (items[1], items[2])


def write_data_to_files(prefix,
                        train_source_seqs,
                        train_target_seqs,
                        eval_source_seqs,
                        eval_target_seqs,
                        is_append=False):
    """Write the four train/eval splits to ``prefix`` + a per-split suffix.

    The original implementation called ``file_utils.write_string_to_file``,
    but ``file_utils`` is never imported in this module, so every call failed
    with ``NameError``. The files are now written directly with ``open()``.

    :param prefix: path prefix; ``constants.SUFFIX_*`` is appended per split.
    :param train_source_seqs: text written verbatim to the train-source file.
    :param train_target_seqs: text written verbatim to the train-target file.
    :param eval_source_seqs: text written verbatim to the eval-source file.
    :param eval_target_seqs: text written verbatim to the eval-target file.
    :param is_append: append to existing files instead of truncating them.
    """
    mode = 'a' if is_append else 'w'
    outputs = (
        (constants.SUFFIX_TRAIN_SOURCE, train_source_seqs),
        (constants.SUFFIX_TRAIN_TARGET, train_target_seqs),
        (constants.SUFFIX_EVAL_SOURCE, eval_source_seqs),
        (constants.SUFFIX_EVAL_TARGET, eval_target_seqs),
    )
    for suffix, content in outputs:
        path = prefix + suffix
        parent = os.path.dirname(path)
        if parent:
            # The default output prefix lives in a nested directory that may
            # not exist on a fresh checkout.
            os.makedirs(parent, exist_ok=True)
        with open(path, mode, encoding='utf-8') as out_file:
            out_file.write(content)


# Anything that is not whitespace, an ASCII letter/digit or one of the listed
# ASCII punctuation marks is stripped from diffs and commit messages.
# NOTE(review): the backslash is not in the kept set, so it is stripped too —
# confirm that is intended for code diffs.
_NON_ENGLISH_RE = re.compile(r'[^\sa-zA-Z0-9.!"#$%&\'()*+,-./:;<=>?@[\]^_`{|}~]')
_WHITESPACE_RE = re.compile(r'\s+')


def _clean_text(text):
    """Strip non-English characters and collapse whitespace to single spaces.

    The result has no line breaks and no leading/trailing whitespace, so it
    can be written as exactly one line of the line-aligned output files.
    """
    text = _NON_ENGLISH_RE.sub('', text)
    return _WHITESPACE_RE.sub(' ', text).strip()


def _needs_preprocessing(db_repo_item):
    """Return True if the repo record is fetched successfully but not yet preprocessed."""
    if db_repo_item is None:
        return False
    success = constants.REPO_STATUS['SUCCESS']
    return (db_repo_item.get('status') == success
            and db_repo_item.get('preprocess_status') != success)


def _load_commit_pair(commit_item_path):
    """Return the cleaned ``(diff_seq, msg)`` for one commit directory, or None.

    None is returned when either part is missing or empty after cleaning.
    The commit message is only read once the diff is known to be usable.
    """
    commit_diff_seq = get_commit_diff_seq(commit_item_path)
    if commit_diff_seq is None:
        return None
    commit_diff_seq = _clean_text(commit_diff_seq)
    if not commit_diff_seq:
        return None

    commit_msg = get_commit_msg(commit_item_path)
    if commit_msg is None:
        return None
    commit_msg = _clean_text(commit_msg)
    if not commit_msg:
        return None
    return commit_diff_seq, commit_msg


def preprocess_repo_list_with_commits(root_path):
    """Collect (diff, commit message) pairs from every eligible repo under ``root_path``.

    Each sub-directory of ``root_path`` is treated as a repo whose
    sub-directories are commits. A repo is processed only if its record in
    the BRIEF_REPOS collection has ``status == SUCCESS`` and is not already
    marked ``preprocess_status == SUCCESS``. After its commits are read it is
    marked as preprocessed. Each usable commit goes to the eval split with
    probability ``constants.EVAL_RATE`` and to the train split otherwise.

    :param root_path: directory containing the per-repo directories.
    :return: ``(train_source_seqs, train_target_seqs, eval_source_seqs,
        eval_target_seqs)``. Source entries are cleaned single-line diffs;
        target entries are the matching cleaned single-line commit messages.
    """
    train_source_seqs = []
    train_target_seqs = []
    eval_source_seqs = []
    eval_target_seqs = []

    repos = os.listdir(root_path)
    repos_count = len(repos)

    client = MongoClient(constants.DATABASE['HOST'], constants.DATABASE['PORT'])
    try:
        database = client[constants.DATABASE['DATABASE']]
        db_col = database[constants.DB_COLLECTIONS['BRIEF_REPOS']]

        for idx, repo_item in enumerate(repos):
            repo_item_path = os.path.join(root_path, repo_item)
            if not os.path.isdir(repo_item_path):
                continue

            db_repo_item = db_col.find_one({'repo_name': _get_repo_name(repo_item)})
            if not _needs_preprocessing(db_repo_item):
                print('SKIP: %s' % repo_item)
                continue

            print_repo_item = 'pos: %d/%d repo: %s' % (idx, repos_count, repo_item)
            print(print_repo_item)

            for commit_item in os.listdir(repo_item_path):
                commit_item_path = os.path.join(repo_item_path, commit_item)
                if not os.path.isdir(commit_item_path):
                    continue
                print('%s, commit: %s' % (print_repo_item, commit_item))

                pair = _load_commit_pair(commit_item_path)
                if pair is None:
                    continue
                commit_diff_seq, commit_msg = pair
                print('commit_msg: %s' % commit_msg)

                if random.random() <= constants.EVAL_RATE:
                    eval_source_seqs.append(commit_diff_seq)
                    eval_target_seqs.append(commit_msg)
                else:
                    train_source_seqs.append(commit_diff_seq)
                    train_target_seqs.append(commit_msg)

            # NOTE(review): the repo is marked done here, but its samples are
            # only written to disk by the caller after the whole run finishes;
            # a crash mid-run loses them while the flag stays set.
            # Collection.update() was removed in PyMongo 4; update_one is the
            # supported single-document equivalent.
            db_col.update_one({'_id': db_repo_item['_id']}, {'$set': {
                'preprocess_status': constants.REPO_STATUS['SUCCESS']
            }})
    finally:
        client.close()

    return train_source_seqs, train_target_seqs, eval_source_seqs, eval_target_seqs


def _join_lines(seqs):
    """Return ``seqs`` as newline-terminated lines ('' when ``seqs`` is empty).

    The last line is terminated too. The output files are appended to across
    runs, and without the trailing newline the next run's first sample would
    be glued onto the previous run's last one.
    """
    return ''.join(seq + '\n' for seq in seqs)


def write_data(prefix='static/data/text_commits/text_commits'):
    """Preprocess ``SHELL_ARGS.raw_data_path`` and append the splits to disk.

    :param prefix: path prefix that the per-split suffixes are appended to.
    """
    train_source_seqs, train_target_seqs, eval_source_seqs, eval_target_seqs = \
        preprocess_repo_list_with_commits(SHELL_ARGS.raw_data_path)

    write_data_to_files(prefix,
                        _join_lines(train_source_seqs),
                        _join_lines(train_target_seqs),
                        _join_lines(eval_source_seqs),
                        _join_lines(eval_target_seqs),
                        is_append=True)


# Script entry point: run the full pipeline (preprocess raw repos, then append
# the train/eval splits to the output files) via write_data().
if __name__ == '__main__':
    write_data()
