#!/usr/bin/env python3
'''split compounds in sme raw texts.'''

import json
import multiprocessing
import os
import re
import shutil
import sys
from sys import argv

from uralicNLP import uralicApi
from corpustools import argparse_version, ccat, corpusxmlfile, modes, util


def levenshteinDistance(s1, s2):
    '''Return the Levenshtein (edit) distance between sequences *s1* and *s2*.

    Two-row dynamic programme over the shorter sequence, adapted from
    https://stackoverflow.com/questions/2460177/edit-distance-in-python'''
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)

    prev = list(range(len(shorter) + 1))
    for row, ch_long in enumerate(longer, start=1):
        cur = [row]
        for col, ch_short in enumerate(shorter):
            if ch_short == ch_long:
                # Characters match: no edit needed, take the diagonal.
                cur.append(prev[col])
            else:
                # Cheapest of substitution, deletion, insertion, plus one.
                cur.append(1 + min(prev[col], prev[col + 1], cur[-1]))
        prev = cur
    return prev[-1]


def _clean_morphology(m):
    if "Prop" in m:
        return m
    m = m.replace("+Cmp", "")
    m = m.replace("/SgGen", "+Sg+Nom")
    m = m.replace("/Sg", "+Sg+")
    return m


def _cutup_morph(parts, original_word):
    if len(parts) != 2:
        return parts
    new_parts = [original_word[0:len(original_word) - len(parts[1])], parts[1]]
    return new_parts


def generate(morph, original_word):
    '''Turn a '#'-separated compound analysis into a list of surface strings.

    Each part is normalised with _clean_morphology and generated with the sme
    generator; the first generated form is used.  When nothing is generated for
    a part, the text before its first "+" is used instead.  If that fallback
    was needed for a non-final part and the final part did generate, the
    result is passed through _cutup_morph(…, original_word).'''
    pieces = morph.split("#")
    last = len(pieces) - 1
    forms = []
    cutup = False
    for idx, piece in enumerate(pieces):
        tags = _clean_morphology(piece)
        generated = uralicApi.generate(tags, "sme")
        if generated:
            forms.append(generated[0][0])
            continue
        forms.append(tags.split("+")[0])
        # A failure on the last part cancels the cut-up; earlier failures
        # request it.
        cutup = idx != last
    return _cutup_morph(forms, original_word) if cutup else forms


def split(word):
    '''Split *word* into its compound parts using the sme analyser.

    Every analysis containing a compound boundary ('#') is turned into a
    candidate list of part strings by generate().  The candidate with the most
    parts wins.  Among candidates with the same number of parts, the one whose
    concatenation has the smallest Levenshtein distance to *word* wins, and the
    earliest analysis wins any remaining tie.

    Returns [word] when no analysis contains '#'.
    '''
    best = [word]
    best_len = 0
    best_dist = float("inf")
    for analysis in uralicApi.analyze(word, "sme"):
        morph = analysis[0]
        if "#" not in morph:
            continue
        candidate = generate(morph, word)
        if not candidate:
            continue
        # Tie-break on closeness to the input word.  NOTE(review): assumed
        # intent of the distance check; confirm.
        dist = levenshteinDistance(word, ''.join(candidate))
        longer = len(candidate) > best_len
        same_len_closer = len(candidate) == best_len and dist < best_dist
        if longer or same_len_closer:
            best, best_len, best_dist = candidate, len(candidate), dist
    return best


def analyses(modename, lang, input):
    '''Run the text *input* through the corpustools *modename* pipeline for *lang*.

    The pipeline is sanity-checked first; the text is then UTF-8 encoded and
    the pipeline's result is returned unchanged.'''
    pipeline = modes.Pipeline(modename, lang)
    pipeline.sanity_check()
    encoded = input.encode('utf8')
    return pipeline.run(encoded)


# A tag containing any of these patterns is skipped when picking the first
# tag of an analysis.
_SKIP_TAG_RE = re.compile(r"Der/*|Sem/*|Ex/*|Gram/*|Err/*|TV|IV")


def _tag_tokens(fst_output):
    '''Parse pipeline output into a list of [token, pos] pairs.

    The output is cut on '"<'; a chunk's token is the text before its first
    '>"'.  pos is a list holding the first tag, after the first '" ' of the
    chunk's second tab-separated field, that does not match _SKIP_TAG_RE.
    pos is [""] when that field cannot be located and [] when every tag was
    skipped.  Chunks with an empty token (such as the text before the first
    '"<') are dropped.'''
    tokens = []
    for chunk in fst_output.split('"<'):
        token = chunk.split('>"')[0]
        pos = []
        try:
            analysis = chunk.split('\t')[1]
            tags = analysis.split('" ')[1].split(' ')
        except IndexError:
            pos.append("")
        else:
            for tag in tags:
                if not _SKIP_TAG_RE.search(tag):
                    pos.append(tag)
                    break
        if token:
            tokens.append([token, pos])
    return tokens


def _align_words(sen_ar, sen_err_ar):
    '''Pair each original word with the tagged token(s) it was split into.

    sen_ar holds [word, n_parts] per original word; sen_err_ar holds the
    [token, pos] pairs of the split sentence in order.  A word split into n
    parts consumes n tokens, so every later word is offset by the running
    count of extra tokens (shift).  Missing tokens are recorded as
    "err_in_pos" / "err_in_sent" placeholders.'''
    records = []
    shift = 0
    for i, (word, n_parts) in enumerate(sen_ar):
        idx = i + shift
        if n_parts == 1:
            try:
                pos = sen_err_ar[idx][1]
            except IndexError:
                pos = "err_in_pos"
            records.append({"correct": word, "error": word, "pos": pos})
            continue
        err, err_pos = [], []
        for k in range(n_parts):
            # Read token and first tag together so a failure never leaves
            # err and err_pos with different lengths.
            try:
                token, tags = sen_err_ar[idx + k]
                first_tag = tags[0]
            except IndexError:
                err.append("err_in_sent")
                err_pos.append("err_in_pos")
            else:
                err.append(token)
                err_pos.append(first_tag)
        records.append({"correct": word, "error": err, "pos": err_pos})
        shift += n_parts - 1
    return records


def split_compounds(data_file):
    '''Split compounds in every line of *data_file* and write the results.

    For each line (treated as one sentence), every whitespace-separated word
    is split with split().  The split parts are written to
    data/multi_all_<argv[1]>/<basename>.split.  The split sentence is tagged
    with the "hfst" sme pipeline, and one record per original word
    ({"correct", "error", "pos"}) is collected into
    data/multi_all_<argv[1]>/<basename>.json.  Afterwards data_file is moved
    into data/multi_all_<argv[1]>/done/ (best effort: failures are reported
    on stderr).  All paths are relative to the current directory.'''
    out_name = os.path.basename(data_file)
    out_dir = "data/multi_all_" + argv[1]
    tot_data = []
    with open(data_file, encoding="utf-8") as f, \
            open(out_dir + "/" + out_name + ".split", "w", encoding="utf-8") as f2, \
            open(out_dir + "/" + out_name + ".json", "w", encoding="utf-8") as f3:
        print("Processing file: ", data_file)
        print("...")
        for sentence in f:
            sentence = sentence.replace(". ", " . ")  # XXX
            sentence = sentence.replace(".\n", " .\n")  # XXX
            # "Truecasing": lowercase a capital first letter followed by a
            # lowercase one.
            if len(sentence) > 1 and \
                    sentence[0].isupper() and sentence[1].islower():
                sentence = sentence[0].lower() + sentence[1:]

            words = sentence.split()
            res_sent = [split(word) for word in words]
            sen_ar = [[word, len(parts)] for word, parts in zip(words, res_sent)]
            for r in res_sent:
                print(" ".join(r), file=f2)
                print(" ".join(r).replace("- ", " "), end=' ', file=f2)

            # Rebuild word by word; str.replace would also rewrite matches
            # inside other words.
            sentence_err = " ".join(" ".join(parts) for parts in res_sent)
            if sentence.endswith("\n"):
                sentence_err += "\n"

            fst_out = analyses("hfst", "sme", sentence_err)
            if isinstance(fst_out, bytes):
                # NOTE(review): the pipeline is fed UTF-8 bytes; decode in case
                # it also returns bytes — confirm against corpustools.
                fst_out = fst_out.decode("utf8")
            tot_data.extend(_align_words(sen_ar, _tag_tokens(fst_out)))
            print(file=f2)

        json.dump(tot_data, f3, ensure_ascii=False, indent=4)
    print("Done processing file: ", data_file)

    done_dir = os.path.join(out_dir, "done")
    try:
        os.makedirs(done_dir, exist_ok=True)
        shutil.move(data_file, os.path.join(done_dir, out_name))
    except OSError as err:
        # Report and continue so one failed move does not abort pool.map.
        print("Could not move", data_file, "to", done_dir + ":", err,
              file=sys.stderr)
    print()


def append_files(folder_path, out=None):
    '''Collect the paths of all files ending in ".xml" below *folder_path*.

    Paths (os.path.join(root, name), in os.walk order) are appended to *out*.
    When *out* is None the module-level files_list is extended, which is what
    the script entry point relies on.  A missing folder yields no paths.

    Returns the list that was extended.'''
    if out is None:
        out = files_list
    for root, _dirs, names in os.walk(folder_path):
        out.extend(os.path.join(root, name)
                   for name in names if name.endswith(".xml"))
    return out


def split_in_parallel():
    """Run split_compounds over every path in the module-level files_list.

    Uses a process pool of twice the CPU count.  The context manager
    terminates the pool if a worker raises (the exception then propagates
    from pool.map), so no worker processes are left behind.
    """
    pool_size = multiprocessing.cpu_count() * 2
    with multiprocessing.Pool(processes=pool_size) as pool:
        # chunksize=1: each file is a long task, so hand them out one at a
        # time to keep workers evenly loaded.
        pool.map(split_compounds, files_list, chunksize=1)
        pool.close()  # no more tasks
        pool.join()  # wait for workers to exit cleanly


# Absolute paths of the input .xml files, filled by append_files() and
# consumed by split_in_parallel().
files_list = []

if __name__ == "__main__":
    # Guarded so that importing this module (including the re-import done by
    # multiprocessing workers under the "spawn" start method) does not re-run
    # file collection or start a nested pool.
    if len(argv) < 2:
        sys.exit("usage: " + argv[0] + " NAME   (processes "
                 "~/compound-errors/data/multi_all_NAME/files)")

    home = os.path.expanduser("~")
    # modify path_to_ct according to your path to corpustools
    # NOTE(review): corpustools is imported at the top of this file, before
    # this line runs, so this sys.path entry does not affect that import.
    path_to_ct = home + "/main/tools/CorpusTools/corpustools/"
    sys.path.insert(1, path_to_ct)

    print("Collecting files for ", argv[1])
    # modify path according to git repository
    append_files(home + "/compound-errors/data/multi_all_" + argv[1] + "/files")
    print("Done collecting files")

    split_in_parallel()
