from mikatools import *

def read_to_sentences(num_words=1, path_suffix="_char_character_level/"):
	"""Regroup chunk-level test files into one line per sentence.

	For every sentence in ``data/nmt_splits/test.json`` the chunking over each
	word's ``"error"`` tokens is re-computed: tokens are packed into a chunk
	that is flushed once it reaches ``num_words``, and a word's tokens are
	never split, so a chunk only exceeds ``num_words`` when a single word has
	more tokens than that. For each non-empty chunk one line is consumed from
	``test_source.txt``, ``test_target.txt`` and ``brnn_pred.txt``, cleaned,
	and written with a ``"!!"`` terminator. Each sentence ends with a newline
	in ``source_sent.txt``, ``target_sent.txt`` and ``pred_sent.txt``.

	Args:
		num_words: Chunk size; also selects the ``data/words_<num_words>...``
			directory.
		path_suffix: Directory suffix appended after ``num_words``. It is
			concatenated directly, so it should end with ``/``.
	"""
	from contextlib import ExitStack

	base = "data/words_" + str(num_words) + path_suffix
	data = json_load("data/nmt_splits/test.json")

	def clean(line):
		# Drop line endings and the character-level spaces, then turn "_"
		# back into a real space.
		return line.replace("\n", "").replace("\r", "").replace(" ", "").replace("_", " ")

	def count_chunks(sentence):
		# NOTE(review): presumably must match the chunking used when the
		# chunk files were built; if it drifts, the line reads misalign.
		chunks = []
		current = []
		for word in sentence:
			if len(word["error"]) + len(current) > num_words:
				chunks.append(current)
				current = []
			current.extend(word["error"])
			if len(current) >= num_words:
				chunks.append(current)
				current = []
		chunks.append(current)
		return sum(1 for chunk in chunks if chunk)

	with ExitStack() as stack:
		def opened(opener, name):
			# Register every handle so it is closed even if processing fails.
			handle = opener(base + name)
			stack.callback(handle.close)
			return handle

		source = opened(open_read, "test_source.txt")
		target = opened(open_read, "test_target.txt")
		pred = opened(open_read, "brnn_pred.txt")
		source_s = opened(open_write, "source_sent.txt")
		target_s = opened(open_write, "target_sent.txt")
		pred_s = opened(open_write, "pred_sent.txt")

		for sentence in data:
			for _ in range(count_chunks(sentence)):
				source_s.write(clean(source.readline()) + "!!")
				target_s.write(clean(target.readline()) + "!!")
				pred_s.write(clean(pred.readline()) + "!!")
			source_s.write("\n")
			target_s.write("\n")
			pred_s.write("\n")


def read_to_sentences_extended_test(num_words=1, path_suffix="_char_character_level/"):
	"""Regroup chunk-level extended-test files into one line per sentence.

	Each sentence in ``data/nmt_splits/test_extended.json`` is split into
	chunks of ``num_words`` words (the last may be shorter). For each chunk
	one line is consumed from ``test_extended_source.txt`` and
	``test_extended_pred.txt``, cleaned, and written followed by a space.
	Each sentence ends with a newline in ``source_sent_extended_pred.txt``
	and ``pred_sent_extended_pred.txt``.

	Args:
		num_words: Chunk size in words; also selects the
			``data/words_<num_words>...`` directory.
		path_suffix: Directory suffix appended after ``num_words``. It is
			concatenated directly, so it should end with ``/``.
	"""
	from contextlib import ExitStack

	base = "data/words_" + str(num_words) + path_suffix
	data = json_load("data/nmt_splits/test_extended.json")
	# One word is added at a time and the chunk is flushed on reaching
	# num_words, so the chunk count is ceil(len / num_words). Non-positive
	# sizes degrade to one word per chunk, as in the previous loop.
	step = max(num_words, 1)

	def clean(line):
		# Drop line endings and the character-level spaces, then turn "_"
		# back into a real space.
		return line.replace("\n", "").replace("\r", "").replace(" ", "").replace("_", " ")

	with ExitStack() as stack:
		def opened(opener, name):
			# Register every handle so it is closed even if processing fails.
			handle = opener(base + name)
			stack.callback(handle.close)
			return handle

		source = opened(open_read, "test_extended_source.txt")
		pred = opened(open_read, "test_extended_pred.txt")
		source_s = opened(open_write, "source_sent_extended_pred.txt")
		pred_s = opened(open_write, "pred_sent_extended_pred.txt")

		for sentence in data:
			n_chunks = -(-len(sentence) // step)  # ceiling division
			for _ in range(n_chunks):
				source_s.write(clean(source.readline()) + " ")
				pred_s.write(clean(pred.readline()) + " ")
			source_s.write("\n")
			pred_s.write("\n")

def _process_s(s):
	s = s.replace("!!", " ")
	s = s.strip()
	return ' '.join(s.split())

def fix_lines(num_words=1, path_suffix="_char_character_level/"):
	"""Rewrite ``pred_sent_extended_pred.txt`` in place so "¶ " marks line breaks.

	Every existing newline is replaced by a space, and every "¶ " is then
	replaced by a newline.

	Args:
		num_words: Selects the ``data/words_<num_words>...`` directory.
		path_suffix: Directory suffix appended after ``num_words``. It is
			concatenated directly, so it should end with ``/``.
	"""
	path = "data/words_" + str(num_words) + path_suffix + "pred_sent_extended_pred.txt"
	# Close the reader before reopening the same file for writing.
	reader = open_read(path)
	try:
		text = reader.read()
	finally:
		reader.close()
	text = text.replace("\n", " ").replace("¶ ", "\n")
	writer = open_write(path)
	try:
		writer.write(text)
	finally:
		writer.close()

def sentence_accuracy(num_words=1, path_suffix="_char_character_level/"):
	"""Print exact-match sentence accuracy of ``pred_sent.txt`` vs ``target_sent.txt``.

	Lines are paired positionally and normalized with ``_process_s`` before
	comparison. ``num_words`` is printed first, then ``correct total
	accuracy``. Accuracy is reported as 0.0 when there are no lines.

	Args:
		num_words: Selects the ``data/words_<num_words>...`` directory.
		path_suffix: Directory suffix appended after ``num_words``. It is
			concatenated directly, so it should end with ``/``.
	"""
	base = "data/words_" + str(num_words) + path_suffix

	def read_lines(name):
		handle = open_read(base + name)
		try:
			return handle.readlines()
		finally:
			handle.close()

	target_s = read_lines("target_sent.txt")
	pred_s = read_lines("pred_sent.txt")
	print(num_words)
	correct = 0.0
	total = 0.0
	# NOTE(review): zip stops at the shorter file, so a length mismatch
	# between target and prediction goes unreported.
	for t, p in zip(target_s, pred_s):
		if _process_s(t) == _process_s(p):
			correct += 1
		total += 1
	print(correct, total, correct / total if total else 0.0)

def sentence_accuracy_extended(num_words=1, path_suffix="_char_character_level/"):
	"""Print sentence accuracy; currently identical to ``sentence_accuracy``.

	NOTE(review): this was a verbatim copy of ``sentence_accuracy`` and reads
	the same ``target_sent.txt`` / ``pred_sent.txt`` files, not the
	``*_extended_pred.txt`` outputs written by
	``read_to_sentences_extended_test``. That function writes no target
	file, so the intended inputs are unclear. Behavior is kept; confirm.
	"""
	sentence_accuracy(num_words, path_suffix)




if __name__ == "__main__":
	# Driver for chunk sizes 2..5: uncomment the steps to run.
	# read_to_sentences writes pred_sent.txt / target_sent.txt, which
	# sentence_accuracy reads. read_to_sentences_extended_test writes
	# pred_sent_extended_pred.txt, which fix_lines rewrites.
	for x in range(2, 6):
		#path_suffix = "_char_True_pos_True/"
		path_suffix = "_char_True_pos_False/"
		#read_to_sentences(x, path_suffix)
		#read_to_sentences_extended_test(x, path_suffix)
		#fix_lines(x, path_suffix)
		#sentence_accuracy(x, path_suffix)
