from mikatools import *
from glob import glob
import random, unicodedata
import nltk

# Re-seed the global RNG with no explicit seed (OS randomness or the clock),
# so each run of make_splits() shuffles differently unless a seed is given.
random.seed(a=None)

def _check_sentence(sentence):
	"""Return True if *sentence* is a usable compound-error example.

	*sentence* is a list of word dicts with at least the keys "correct",
	"error" and "pos".  The sentence is accepted only when all hold:

	* concatenating every "correct" string yields exactly the same text as
	  concatenating every piece of every "error" value (so the two sides
	  differ only in segmentation);
	* no word's "pos" contains "?", "err_in_pos" or "err_in_sent";
	* at least one word's "error" value has length greater than one.

	Otherwise (including for an empty sentence) it returns False.
	"""
	corrected_text = "".join(word["correct"] for word in sentence)
	erroneous_text = "".join("".join(word["error"]) for word in sentence)
	if corrected_text != erroneous_text:
		return False

	rejected_tags = ("?", "err_in_pos", "err_in_sent")
	has_split_word = False
	for word in sentence:
		tags = word["pos"]
		# Any uncertain or error-marked POS annotation disqualifies the sentence.
		if any(tag in tags for tag in rejected_tags):
			return False
		if len(word["error"]) > 1:
			has_split_word = True
	return has_split_word

def make_splits(train_ratio=0.7, valid_ratio=0.15, seed=None):
	"""Build shuffled train/valid/test splits of compound-error sentences.

	Sentences come from two places:

	* every data/pos_tagged/*.json file, a flat list of word dicts that is
	  cut into sentences after each word whose "pos" contains "CLB" or
	  "PUNCT" (words after a file's last boundary are discarded);
	* data/ud_generated.json, which is already a list of sentences.

	Only sentences accepted by _check_sentence are kept.  The pool is
	shuffled, its size is printed, and it is written to
	data/nmt_splits/{train,valid,test}.json.

	Args:
		train_ratio: fraction of sentences placed in the training split.
		valid_ratio: fraction placed in the validation split; everything
			after the training and validation parts goes to the test split.
		seed: if not None, the shuffle uses a private random.Random(seed)
			so the split is reproducible; if None the module-level random
			state is used.

	Raises:
		ValueError: if a ratio is negative or the two ratios sum above 1.
	"""
	if train_ratio < 0 or valid_ratio < 0 or train_ratio + valid_ratio > 1:
		raise ValueError("train_ratio and valid_ratio must be >= 0 and sum to at most 1")

	sentences = []
	# Sorted so that, with a fixed seed, the split does not depend on the
	# filesystem's directory listing order.
	for path in sorted(glob("data/pos_tagged/*.json")):
		sentence = []
		for word in json_load(path):
			# Normalise a bare error string to a one-element list so every
			# word's "error" value has the same shape downstream.
			if isinstance(word["error"], str):
				word["error"] = [word["error"]]
			sentence.append(word)
			if "CLB" in word["pos"] or "PUNCT" in word["pos"]:
				if _check_sentence(sentence):
					sentences.append(sentence)
				sentence = []
	for sentence in json_load("data/ud_generated.json"):
		if _check_sentence(sentence):
			sentences.append(sentence)

	total = len(sentences)
	train_end = int(total * train_ratio)
	valid_end = train_end + int(total * valid_ratio)
	if seed is None:
		random.shuffle(sentences)
	else:
		random.Random(seed).shuffle(sentences)
	print(total)
	json_dump(sentences[:train_end], "data/nmt_splits/train.json")
	json_dump(sentences[train_end:valid_end], "data/nmt_splits/valid.json")
	json_dump(sentences[valid_end:], "data/nmt_splits/test.json")

def make_splits_extended_test():
	"""Collect every sentence of the restricted corpus into the extended test set.

	Each data/restricted/*.json file holds a flat list of word dicts.  It is
	cut into sentences after each word whose "pos" contains "CLB" or
	"PUNCT"; any words after a file's last boundary form a final sentence
	so that no test words are lost.  Unlike make_splits, sentences are not
	filtered with _check_sentence and not shuffled.  The sentence count is
	printed and all sentences are written to
	data/nmt_splits/test_extended.json.
	"""
	sentences = []
	# Sorted so the output order does not depend on directory listing order.
	for path in sorted(glob("data/restricted/*.json")):
		sentence = []
		for word in json_load(path):
			sentence.append(word)
			if "CLB" in word["pos"] or "PUNCT" in word["pos"]:
				sentences.append(sentence)
				sentence = []
		# Keep a trailing sentence that has no closing boundary token.
		if sentence:
			sentences.append(sentence)
	print(len(sentences))
	json_dump(sentences, "data/nmt_splits/test_extended.json")

def _write(stream, words, character_level, pos_tag):
	"""Write one chunk of *words* to *stream* as a single line.

	*words* is a sequence of pairs whose first item is the text to emit and
	whose second item is a POS tag or None.  Nothing is written for an
	empty chunk.

	With *character_level*, each text is spelled out as space-separated
	characters and consecutive texts are joined with " _ ".  When *pos_tag*
	is also true, a text whose tag T is not None is wrapped as "T> ... <T",
	with CR/LF characters removed from T.  Without *character_level* the
	texts are joined with single spaces and tags are ignored.
	"""
	if not words:
		return
	if not character_level:
		line = " ".join([item[0] for item in words])
	else:
		rendered = []
		for item in words:
			spelled = " ".join(item[0])
			tag = item[1] if pos_tag else None
			if tag is not None:
				clean_tag = tag.replace("\n", "").replace("\r", "")
				spelled = clean_tag + "> " + spelled + " <" + clean_tag
			rendered.append(spelled)
		line = " _ ".join(rendered)
	stream.write(line + "\n")

def make_nmt_files(num_words=2, character_level=True, pos_tag=False):
	"""Write parallel source/target text files for NMT training.

	For each of the train/valid/test splits in data/nmt_splits/, every
	sentence is cut into chunks of up to *num_words* error-side tokens, and
	a word's tokens are never split across chunks (a single word with more
	tokens than that forms its own oversized chunk).  Each chunk becomes
	one line of <split>_source.txt, holding the "error" tokens paired with
	their "pos" entries.  It also becomes the aligned line of
	<split>_target.txt, holding the "correct" strings.  Output goes to
	data/words_<num_words>_char_<character_level>_pos_<pos_tag>/, together
	with a slurm.sh generated from slurm_template.sh.

	Args:
		num_words: chunk size limit, counted in error-side tokens.
		character_level: passed to _write; spell tokens out as characters.
		pos_tag: passed to _write; wrap source tokens in their POS tags.
	"""
	import os
	from contextlib import ExitStack

	folder = "data/words_" + str(num_words) + "_char_" + str(character_level) + "_pos_" + str(pos_tag)
	# exist_ok tolerates a folder left by an earlier run.  Any other failure
	# (e.g. permissions) is raised here instead of being silently swallowed.
	os.makedirs(folder, exist_ok=True)

	def emit(source, target, src_chunk, tgt_chunk):
		# One source line and its aligned target line.
		_write(source, src_chunk, character_level, pos_tag)
		_write(target, tgt_chunk, character_level, pos_tag)

	for data_type in ("train", "valid", "test"):
		data = json_load("data/nmt_splits/" + data_type + ".json")
		with ExitStack() as stack:
			# Register each close() as soon as the file is open, so neither
			# handle leaks if writing fails part-way.
			source = open_write(folder + "/" + data_type + "_source.txt")
			stack.callback(source.close)
			target = open_write(folder + "/" + data_type + "_target.txt")
			stack.callback(target.close)

			for sentence in data:
				src_chunk, tgt_chunk = [], []
				for word in sentence:
					# Emit the pending chunk first if this word's error tokens
					# would push it past num_words.
					if len(word["error"]) + len(src_chunk) > num_words:
						emit(source, target, src_chunk, tgt_chunk)
						src_chunk, tgt_chunk = [], []
					src_chunk.extend(zip(word["error"], word["pos"]))
					tgt_chunk.append([word["correct"], None])
					if len(src_chunk) >= num_words:
						emit(source, target, src_chunk, tgt_chunk)
						src_chunk, tgt_chunk = [], []
				# Flush whatever remains at the end of the sentence.
				emit(source, target, src_chunk, tgt_chunk)

	# Generate slurm.sh by filling the DATAFOLDER and MODELTYPE placeholders.
	with ExitStack() as stack:
		template = open_read("slurm_template.sh")
		stack.callback(template.close)
		slurm = template.read()
	slurm = slurm.replace("DATAFOLDER", folder).replace("MODELTYPE", "brnn")
	with ExitStack() as stack:
		job_file = open_write(folder + "/slurm.sh")
		stack.callback(job_file.close)
		job_file.write(slurm)

def make_nmt_extended_test(num_words=2, character_level=True, pos_tag=False):
	"""Write the source side of the extended test set.

	Reads data/nmt_splits/test_extended.json (a list of sentences of word
	dicts) and writes test_extended_source.txt into
	data/words_<num_words>_char_<character_level>_pos_<pos_tag>/.  Each word
	contributes one entry: its "correct" string paired with the first
	element of its "pos" value, or None when "pos" is empty.  Entries are
	grouped into lines of *num_words* per sentence; the last line of a
	sentence may be shorter.  No target file is written here.

	Args:
		num_words: number of words per output line.
		character_level: passed to _write; spell words out as characters.
		pos_tag: passed to _write; wrap words in their POS tags.
	"""
	import os

	folder = "data/words_" + str(num_words) + "_char_" + str(character_level) + "_pos_" + str(pos_tag)
	# exist_ok tolerates an existing folder; other failures are raised.
	os.makedirs(folder, exist_ok=True)

	data = json_load("data/nmt_splits/test_extended.json")
	source = open_write(folder + "/test_extended_source.txt")
	try:
		for sentence in data:
			chunk = []
			for word in sentence:
				pos = word["pos"]
				# Take the first tag explicitly.  A word with an empty "pos"
				# is kept untagged instead of silently dropped (zip would
				# have produced nothing for it).
				chunk.append((word["correct"], pos[0] if pos else None))
				# Every word adds exactly one entry, so flushing once the
				# chunk reaches num_words keeps lines at num_words words.
				if len(chunk) >= num_words:
					_write(source, chunk, character_level, pos_tag)
					chunk = []
			_write(source, chunk, character_level, pos_tag)
	finally:
		source.close()

def tokenize_extended_test_target():
	"""Tokenize the extended-test reference text into one sentence per line.

	Sentence-ending tokens are every "correct" string in
	data/restricted/src-boundcorpus-errcmp.txt.json whose word's "pos"
	contains "CLB" or "PUNCT".  The text of
	data/restricted/tgt-boundcorpus-errcmp-only.text is flattened onto one
	line and tokenized with nltk.word_tokenize.  The tokens are written,
	space-joined, to data/restricted/tgt-boundcorpus-errcmp-tokenized.txt,
	with a line break after each sentence-ending token plus a final line
	for any trailing tokens.
	"""
	source = json_load("data/restricted/src-boundcorpus-errcmp.txt.json")
	puncts = {
		item["correct"]
		for item in source
		if "CLB" in item["pos"] or "PUNCT" in item["pos"]
	}

	reader = open_read("data/restricted/tgt-boundcorpus-errcmp-only.text")
	try:
		raw_text = reader.read()
	finally:
		reader.close()
	# Newlines become spaces and CRs are dropped, so line breaks in the
	# output come only from the sentence-ending tokens above.
	target = raw_text.replace("\n", " ").replace("\r", "")
	tokens = nltk.word_tokenize(target)

	out = open_write("data/restricted/tgt-boundcorpus-errcmp-tokenized.txt")
	try:
		sentence = []
		for token in tokens:
			sentence.append(token)
			if token in puncts:
				out.write(" ".join(sentence) + "\n")
				sentence = []
		if sentence:
			out.write(" ".join(sentence) + "\n")
	finally:
		out.close()

# Script entry: only the target-tokenization step is currently enabled.
# NOTE(review): this runs at import time too; consider an
# `if __name__ == "__main__":` guard if the module is ever imported.
tokenize_extended_test_target()

# Other pipeline stages, currently disabled; uncomment the one to run.
# make_splits()
# make_splits_extended_test()
#
# for use_pos in (True, False):
# 	for n_words in range(2, 6):
# 		make_nmt_extended_test(n_words, pos_tag=use_pos)
#
# for n_words in range(2, 6):
# 	make_nmt_files(n_words, pos_tag=True)

