import dataset_mapper
import svm
import statistics
import Perturbation
import matplotlib.pyplot as plt
import numpy as np
import exec
import os
import subprocess
import crime.crime_adversarial_region
import adult.adult_adversarial_region
import compas.compas_adversarial_region
import german.german_adversarial_region
import health.health_adversarial_region
import csv
import shutil
import math
import time



#kernel_name = 'poly'
#reg_param = 0.01
#gamma = 0.01
#degree = 6 
#coef0 = 3


# Their
#featRank2 = {'credit_history=A31': 5, 'other_debtors=A101': 5, 'other_debtors=A102': 5, 'housing=A151': 5, 'housing=A152': 5, 'skill_level=A172': 5, 'people_liable_for': 6, 'foreign_worker_A202': 6, 'status=A12': 6, 'status=A13': 6, 'credit_history=A30': 6, 'credit_history=A34': 6, 'purpose=A410': 6, 'purpose=A42': 6, 'purpose=A44': 6, 'purpose=A45': 6, 'purpose=A46': 6, 'purpose=A48': 6, 'purpose=A49': 6, 'savings=A63': 6, 'savings=A64': 6, 'employment=A72': 6, 'employment=A75': 6, 'other_debtors=A103': 6, 'property=A122': 6, 'property=A123': 6, 'installment_plans=A141': 6, 'installment_plans=A142': 6, 'installment_plans=A143': 6, 'housing=A153': 6, 'skill_level=A171': 6, 'skill_level=A174': 6, 'residence_since': 7, 'age': 7, 'number_of_credits': 7, 'sex_male': 7, 'credit_history=A33': 7, 'purpose=A41': 7, 'purpose=A43': 7, 'savings=A61': 7, 'savings=A62': 7, 'savings=A65': 7, 'employment=A71': 7, 'employment=A73': 7, 'employment=A74': 7, 'property=A121': 7, 'property=A124': 7, 'credit_amount': 8, 'investment_as_income_percentage': 8, 'telephone_A192': 8, 'status=A14': 8, 'credit_history=A32': 8, 'purpose=A40': 8, 'skill_level=A173': 8, 'status=A11': 9, 'months': 10} 

# Our
#featRank1 = {'people_liable_for': 6, 'status=A11': 6, 'credit_history=A30': 6, 'purpose=A40': 6, 'savings=A61': 6, 'employment=A71': 6, 'other_debtors=A101': 6, 'property=A121': 6, 'installment_plans=A141': 6, 'housing=A151': 6, 'skill_level=A171': 6, 'residence_since': 7, 'telephone_A192': 7, 'sex_male': 7, 'number_of_credits': 8, 'foreign_worker_A202': 8, 'investment_as_income_percentage': 9, 'months': 10, 'credit_amount': 10, 'age': 10} 

# Plot colour per feature name; alloutcomeCurve looks features up here when
# drawing their perturbation curves.
featColor = {
	'residence_since': 'b',
	'people_liable_for': 'g',
	'telephone_A192': 'y',
	'sex_male': 'c',
	'investment_as_income_percentage': 'm',
	'number_of_credits': 'r',
	'foreign_worker_A202': 'orange',
	'months': 'cyan',
	'age': 'pink',
	'credit_amount': 'peru',
}

#kernel_name = 'rbf'
#reg_param = 10
#gamma = 0.01
#degree = 6
#coef0 = 3

#data_folder = "german"
# Dataset files, relative to a dataset folder; the functions below build
# paths of the form f"./{data_folder}/{test_name}".
training_name = "dataset/training-set.csv"
test_name = "dataset/test-set.csv"

def LIMEtrend(svm, data_folder, data_point, columns):
	"""Grade the features of one test row by their absolute LIME weight.

	Args:
		svm: Classifier whose ``predict_proba`` is handed to LIME. Inside this
			function the name shadows the module-level ``svm`` import.
		data_folder: Dataset folder; the test set is read from
			``./{data_folder}/{test_name}``.
		data_point: Index of the test row to explain.
		columns: Column names. The first entry is dropped before the names are
			used as feature names (presumably it is the label column - confirm
			against the dataset files).

	Returns:
		``(feature_grade, row)``: the first value returned by
		``exec.score_to_grade`` and the explained row as a list of floats.
	"""
	from lime.lime_tabular import LimeTabularExplainer

	rows, _labels = dataset_mapper.DatasetMapper().read(f"./{data_folder}/{test_name}")
	feature_names = columns[1:]
	samples = [[float(value) for value in row] for row in rows]

	explainer = LimeTabularExplainer(
		np.array(samples),
		feature_names=feature_names,
		class_names=['0', '1'],
		discretize_continuous=False,
	)
	target = samples[data_point]
	explanation = explainer.explain_instance(
		np.array(target),
		svm.predict_proba,
		num_features=len(samples[0]),
	)

	# as_map()[1] holds the (feature index, weight) pairs for label 1; only the
	# magnitude of each weight is kept.
	importance = {
		feature_names[idx]: abs(weight)
		for idx, weight in explanation.as_map()[1]
	}
	grades, _scores = exec.score_to_grade(importance, only_num=True)
	return grades, target

def AFItrend(data_folder, columns):
	"""Grade features from the raw AFI score file of a dataset.

	Reads the first line of
	``./{data_folder}/{data_folder}-feature_score_raw.txt``, splits it on
	whitespace and pairs token ``i`` with ``columns[i]`` for every ``i >= 1``
	(index 0 of both sequences is skipped). The absolute value of each token
	is the feature's score.

	Args:
		data_folder: Dataset folder name, also used as the file-name prefix.
		columns: Column names; ``columns[0]`` is ignored.

	Returns:
		The first value returned by ``exec.score_to_grade(..., only_num=True)``.

	Raises:
		IndexError: If the first line has fewer tokens than ``columns``.
		ValueError: If a token is not a valid float.
	"""
	score_path = f"./{data_folder}/{data_folder}-feature_score_raw.txt"
	# Read-only access is enough, and the context manager releases the handle
	# (the file is replaced via shutil.move between calls).
	with open(score_path, "r") as score_file:
		weights = score_file.readline().split()

	feature_score = {
		columns[col_i]: abs(float(weights[col_i]))
		for col_i in range(1, len(columns))
	}
	feature_grade, _ = exec.score_to_grade(feature_score, only_num=True)
	return feature_grade

def test_SVM(model, data_folder="german"):
	"""Print accuracy and balanced accuracy of ``model`` on a dataset's test set.

	Args:
		model: Fitted estimator exposing ``predict``.
		data_folder: Dataset folder whose ``{test_name}`` file is evaluated.
			Previously this was read from a module-level ``data_folder`` that
			is commented out, so every call raised ``NameError``; the default
			mirrors that commented-out value.
	"""
	from sklearn import metrics

	x, y = dataset_mapper.DatasetMapper().read(f"./{data_folder}/{test_name}")
	y_pred = model.predict(x)
	print("Accuracy:",metrics.accuracy_score(y, y_pred))
	print("Balanced Accuracy:",metrics.balanced_accuracy_score(y, y_pred))


def outcomeCurve(model,feat,input_mid, data_folder):
	"""Probe how the decision value reacts to small changes of one feature.

	The feature is shifted in eleven steps of 0.002 (offsets -0.01 .. 0.01)
	around its current value and ``model.decision_function`` is evaluated at
	each step.

	Args:
		model: Estimator exposing ``decision_function``.
		feat: Feature name; its position is looked up in the column list
			returned by ``Perturbation.readColumns`` for ``data_folder``.
		input_mid: Mutable input row. It is perturbed in place during the
			probe and always restored before returning, even on error.
		data_folder: Dataset folder containing ``dataset/columns.csv``.

	Returns:
		``(outcomes, slope)``: ``outcomes`` maps each offset to the absolute
		change of the decision value relative to offset 0.0; ``slope`` is the
		mean over all eleven offsets of ``|change / (offset + 1e-5)|``.

	Raises:
		ValueError: If ``feat`` is not in the column list.
	"""
	feat_idx = Perturbation.readColumns(f'./{data_folder}/dataset/columns.csv').index(feat)
	original_value = input_mid[feat_idx]
	raw = dict()
	try:
		for step in range(-5, 6):
			offset = step / 500
			# NOTE(review): the input moves by -offset while the result is
			# stored under +offset, so the curve is mirrored about 0. Confirm
			# whether that is intended.
			input_mid[feat_idx] = original_value - offset
			raw[offset] = list(model.decision_function([input_mid]))[0]
	finally:
		# Never hand the caller back a perturbed row, even if the model raised.
		input_mid[feat_idx] = original_value

	baseline = raw[0.0]
	outcomes = dict()
	slope = 0
	for offset, value in raw.items():
		# The 1e-5 term keeps the offset-0 entry from dividing by zero.
		slope += abs((value - baseline) / (offset + 0.00001))
		outcomes[offset] = abs(value - baseline)
	slope = slope / len(outcomes)

	return outcomes, slope

def baslineTrend(model, input_mid, data_folder):
	"""Grade features by the slope of their perturbation curve.

	Every column listed for ``data_folder`` whose name contains no ``'='`` is
	probed with ``outcomeCurve``; the resulting slopes are converted to grades
	by ``exec.score_to_grade``.

	Args:
		model: Estimator passed through to ``outcomeCurve``.
		input_mid: Input row around which each feature is perturbed.
		data_folder: Dataset folder containing ``dataset/columns.csv``.

	Returns:
		The first value returned by ``exec.score_to_grade(..., only_num=True)``.
	"""
	column_names = Perturbation.readColumns(f'./{data_folder}/dataset/columns.csv')
	slopes = {
		name: outcomeCurve(model, name, input_mid, data_folder)[1]
		for name in column_names
		if '=' not in name
	}
	grades, _scores = exec.score_to_grade(slopes, only_num = True)
	return grades

def euclid_dist_of_trends(trend1, trend2):
	"""Return the Euclidean distance between two trends.

	Only the keys of ``trend1`` are compared; extra keys in ``trend2`` are
	ignored, and a key missing from ``trend2`` raises ``KeyError``.
	"""
	squared_total = 0
	for feature in trend1:
		gap = trend1[feature] - trend2[feature]
		squared_total = squared_total + gap ** 2
	return math.sqrt(squared_total)


def alloutcomeCurve(model, limeRank, AFIRank, data_folder, kernel_name, input_mid = None ):
	"""Plot the perturbation curves of the numeric features with their ranks.

	One dashed curve is drawn per key of ``AFIRank`` that contains no ``'='``
	and is not ``'sex_male'``. The right end of each curve is labelled with the
	feature's LIME rank and the left end with its AFI rank.

	Args:
		model: Estimator passed through to ``outcomeCurve``.
		limeRank: Mapping feature name -> LIME grade; must contain every
			plotted feature.
		AFIRank: Mapping feature name -> AFI grade; its keys select the
			features to plot.
		data_folder: Dataset folder, also used in the figure title.
		kernel_name: Kernel name, only used in the figure title.
		input_mid: Input row to perturb. When ``None`` a synthetic row is
			built: 0.5 for columns without ``'='``; one-hot columns get 1.0 if
			listed in a module-level ``featRank1`` and 0.0 otherwise.
	"""
	# `is None` rather than `== None`: comparing an array-like row with `==`
	# yields an element-wise result that cannot be used as a condition.
	if input_mid is None:
		cols = Perturbation.readColumns(f'./{data_folder}/dataset/columns.csv')
		# featRank1 only exists as a commented-out literal in this module, so a
		# plain reference raised NameError. Without it every one-hot column
		# stays at 0.0.
		try:
			active_one_hot = featRank1
		except NameError:
			active_one_hot = {}
		input_mid = [
			(1.0 if name in active_one_hot else 0.0) if '=' in name else 0.5
			for name in cols
		]

	allOutcomes = {
		feat: outcomeCurve(model, feat, input_mid, data_folder)[0]
		for feat in AFIRank
		if '=' not in feat and feat != 'sex_male'
	}

	f = plt.figure()
	f.set_figwidth(2)
	f.set_figheight(2)
	for legend, data in allOutcomes.items():
		x = list(data.keys())
		y = list(data.values())
		# The format string carries no colour: the explicit colour already took
		# precedence over the 'b' in '--bo' and only caused a matplotlib
		# warning. Features missing from featColor use the default colour
		# cycle instead of raising KeyError.
		colour_kw = {'color': featColor[legend]} if legend in featColor else {}
		plt.plot(x, y, '--o', **colour_kw)

		# LIME rank at the right end of the curve, AFI rank at the left end.
		plt.text(x[-1], y[-1], f'{limeRank[legend]}', fontsize = 30.0)
		plt.text(x[0], y[0], f'{AFIRank[legend]}', fontsize = 30.0)
	plt.text(-0.0037,0.0035, 'AFI',fontsize = 30.0)
	plt.text(0.0027,0.0035, 'LIME',fontsize = 30.0)
	plt.xlabel('Perturbation of feature Input')
	plt.ylabel('Absolute change in outcome')
	plt.title(f'{data_folder}-{kernel_name}')
	plt.show()


def _reset_saver_outputs():
	"""Leave the three ../saver scratch files present and empty."""
	for name in ("result1.txt", "feature_score_raw.txt", "result_raw.txt"):
		# Opening in "w" mode creates or truncates the file, which is the end
		# state of the former `rm` + `touch` shell pair, without spawning a
		# shell or printing an error when the file is already gone.
		with open(f"../saver/{name}", "w"):
			pass


def _afi_trend_from_saver(svm_addr, region, data_folder, get_CE, columns):
	"""Run the saver for one region kind and grade its raw feature scores."""
	exec.run_saver(svm_addr,"raf",region, data_folder, is_OH = 1, get_CE = get_CE, if_part = 0)
	shutil.move("../saver/feature_score_raw.txt", f"./{data_folder}/{data_folder}-feature_score_raw.txt")
	return AFItrend(data_folder, columns)


def compare_trends(model,svm_addr, kernel_name,reg_param,gamma,degree, coef0,data_folder, get_CE, data_point, draw = False, epsilon = 0.3):
	"""Compare AFI, LIME and baseline feature trends for one data point.

	Steps, in order: prepare the dataset folder, regenerate the dataset via
	``{data_folder}-get.py``, build the "neighbour" and "top" adversarial
	regions, run the saver for each region to obtain the local and global AFI
	trends, compute the LIME trend and the slope-based baseline trend, then
	measure pairwise Euclidean distances between the trends.

	Args:
		model: Fitted estimator used by LIME and by the baseline probe.
		svm_addr: Model address/handle forwarded to ``exec.run_saver``.
		kernel_name: Only used for the figure title when ``draw`` is true.
		reg_param, gamma, degree, coef0: Accepted for call compatibility; not
			used by this function.
		data_folder: One of "adult", "compas", "crime", "german", "health".
			Any other value skips the adversarial-region step.
		get_CE: Forwarded to ``exec.run_saver``.
		data_point: Index of the test row under study.
		draw: When true, plot the perturbation curves at the end.
		epsilon: Forwarded to the adversarial-region builders.

	Returns:
		A 9-tuple: distances LIME-baseline, local AFI-baseline, global
		AFI-baseline, local AFI-LIME, global AFI-LIME, followed by the elapsed
		seconds for local AFI, global AFI, LIME and the baseline.

	Raises:
		subprocess.CalledProcessError: If the dataset script exits non-zero.
	"""
	exec.createDir(data_folder)
	_reset_saver_outputs()
	# Run the dataset script inside its folder via `cwd` instead of chdir-ing
	# the whole process, so a failing script cannot leave this process in the
	# wrong directory. An argv list avoids shell parsing of the folder name.
	subprocess.check_call(["python3", f"{data_folder}-get.py"], cwd=f"./{data_folder}")

	if data_folder == "adult":
		build_region = adult.adult_adversarial_region.execute
	elif data_folder == "compas":
		build_region = compas.compas_adversarial_region.execute
	elif data_folder == "crime":
		build_region = crime.crime_adversarial_region.execute
	elif data_folder == "german":
		build_region = german.german_adversarial_region.execute
	elif data_folder == "health":
		# The health module exposes its entry point under a different name.
		build_region = health.health_adversarial_region.execute2
	else:
		build_region = None
	if build_region is not None:
		for region in ("neighbour", "top"):
			build_region(region, [], epsilon, data_point)

	with open(f"./{data_folder}/dataset/columns.csv", 'r') as f:
		columns = list(csv.reader(f))[0]

	# perf_counter is monotonic, so the measured durations cannot go negative
	# when the wall clock is adjusted.
	t1 = time.perf_counter()
	afitrend_L = _afi_trend_from_saver(svm_addr, "neighbour", data_folder, get_CE, columns)
	t2 = time.perf_counter()

	_reset_saver_outputs()

	t3 = time.perf_counter()
	afitrend_G = _afi_trend_from_saver(svm_addr, "top", data_folder, get_CE, columns)
	t4 = time.perf_counter()

	limetrend, original_point = LIMEtrend(model, data_folder, data_point, columns)
	t5 = time.perf_counter()

	baseline_trend = baslineTrend(model, original_point, data_folder)
	t6 = time.perf_counter()

	dist_lime_base = euclid_dist_of_trends(limetrend, baseline_trend)
	dist_afiL_base = euclid_dist_of_trends(afitrend_L, baseline_trend)
	dist_afiG_base = euclid_dist_of_trends(afitrend_G, baseline_trend)
	dist_afiL_lime = euclid_dist_of_trends(afitrend_L, limetrend)
	dist_afiG_lime = euclid_dist_of_trends(afitrend_G, limetrend)
	if draw:
		alloutcomeCurve(model,limetrend,afitrend_L, data_folder, kernel_name, original_point)

	return dist_lime_base, dist_afiL_base, dist_afiG_base, dist_afiL_lime, dist_afiG_lime, (t2 - t1), (t4 - t3), (t5 - t4), (t6 - t5)

def execute_setting(choice, choice2, choice3, get_CE, fileW):
	"""Train one dataset/kernel combination and report AFI-vs-LIME distances.

	Args:
		choice: Dataset selector: 1 = adult, 2 = compas, 3 = german.
		choice2: Kernel selector: 1 = linear, 2 = rbf, anything else = poly.
		choice3: Run mode. 3 asks for a single data point and epsilon, plots
			that neighbourhood and exits the process; any other value runs the
			batch over epsilon 0.1, 0.2 and 0.3.
		get_CE: Forwarded to ``compare_trends``.
		fileW: Writable text file that receives the per-epsilon summaries.

	Raises:
		ValueError: If ``choice`` is not a known dataset or the test set is
			empty.
		subprocess.CalledProcessError: If the dataset script exits non-zero.
		SystemExit: After the interactive single-neighbourhood run.
	"""
	datasets = {1: "adult", 2: "compas", 3: "german"}
	if choice not in datasets:
		# An unknown choice used to leave data_folder empty and only failed
		# later with an obscure error from running "python3 -get.py".
		raise ValueError(f"unknown dataset choice {choice!r}; expected 1 (adult), 2 (compas) or 3 (german)")
	data_folder = datasets[choice]
	kernel_name = "linear" if choice2 == 1 else "rbf" if choice2 == 2 else "poly"

	# Defaults, overridden per (dataset, kernel) pair below.
	params = {"reg_param": 0, "gamma": 0.1, "deg": 0, "coef0": 0}
	overrides = {
		(1, 1): {"reg_param": 1},
		(2, 1): {"reg_param": 1},
		(3, 1): {"reg_param": 1},
		(1, 2): {"reg_param": 0.05, "gamma": 0.01},
		(2, 2): {"reg_param": 0.05, "gamma": 0.01},
		(3, 2): {"reg_param": 10, "gamma": 0.05},
		(1, 3): {"reg_param": 0.01, "deg": 3, "coef0": 3},
		(2, 3): {"reg_param": 0.01, "deg": 3, "coef0": 3},
		(3, 3): {"reg_param": 0.01, "deg": 6, "coef0": 6},
	}
	params.update(overrides.get((choice, choice2), {}))
	reg_param, gamma, deg, coef0 = params["reg_param"], params["gamma"], params["deg"], params["coef0"]

	# `cwd` runs the script inside the dataset folder without changing this
	# process's working directory, so a failing script cannot strand it there.
	subprocess.check_call(["python3", f"{data_folder}-get.py"], cwd=f"./{data_folder}")

	model,svm_addr = exec.create_model(kernel_name,reg_param,gamma,deg, coef0,data_folder,[], ifmlx = False, get_model = True)

	if choice3 == 3:
		data_point = int(input("Enter Data Point Index: "))
		epsilon = float(input("Enter epsilon value: "))
		compare_trends(model,svm_addr,kernel_name,reg_param,gamma,deg, coef0,data_folder, get_CE, data_point, True, epsilon = epsilon)
		# Same effect as exit(0), without relying on the `site` builtin.
		raise SystemExit(0)

	dataset_path = f"./{data_folder}/{test_name}"
	for epsilon in [0.1,0.2,0.3]:
		dist_afiL_lime_s, dist_afiG_lime_s = [], []
		time_afi_L, time_afi_G, time_lime = [], [], []
		x, y = dataset_mapper.DatasetMapper().read(dataset_path)
		# Evaluate at most 200 rows, but never index past the end of a
		# smaller test set.
		size = min(200, len(x))
		if size == 0:
			raise ValueError(f"no test rows found in {dataset_path}")
		fileW.write(f"\n\n\n \t\t\tDataset: {data_folder} ; Kernel: {kernel_name} Epsilon: {epsilon}\n \t\treg_param({reg_param}), gamma({gamma}), degree({deg}), coef0({coef0})\n\n")
		print(f"\n\n\n \t\t\tDataset: {data_folder} Kernel: {kernel_name} Epsilon: {epsilon}\n \t\treg_param({reg_param}), gamma({gamma}), degree({deg}), coef0({coef0})\n")
		for i in range(size):
			# compare_trends also returns the baseline distances and timing;
			# they are not part of the current report.
			_, _, _, aL_l, aG_l, taL, taG, tl, _ = compare_trends(model,svm_addr,kernel_name,reg_param,gamma,deg, coef0,data_folder, get_CE, i, False, epsilon = epsilon)
			dist_afiL_lime_s.append(aL_l)
			dist_afiG_lime_s.append(aG_l)
			time_afi_L.append(taL)
			time_afi_G.append(taG)
			time_lime.append(tl)
			str1 = f"[{i}/{size} [e = {epsilon}]AFI-lime: {aL_l} [G]AFI-lime: {aG_l} Time: [L]AFI: {taL} [G]AFI: {taG}  LIME: {tl}"
			print(str1+"\n\n")

		mean_aL_l = sum(dist_afiL_lime_s)/size
		mean_aG_l = sum(dist_afiG_lime_s)/size
		mean_taL = sum(time_afi_L)/size
		mean_taG = sum(time_afi_G)/size
		mean_tl = sum(time_lime)/size
		print(f"\n\n\nFinal --> {data_folder}-{kernel_name} -> [e = {epsilon}]AFI-lime: {mean_aL_l}  [Global]AFI-lime: {mean_aG_l} Time: [L]AFI: {mean_taL} [G]AFI: {mean_taG} LIME: {mean_tl}")
		fileW.write(f"\nFINAL --> {data_folder}-{kernel_name} -> [e = {epsilon}]AFI-lime: {mean_aL_l} [Global]AFI-lime: {mean_aG_l} Time: [L]AFI: {mean_taL} [G]AFI: {mean_taG} LIME: {mean_tl}\n")

if __name__ == '__main__':
	get_CE = 1

	# The context manager flushes and closes the report on every exit path,
	# including the SystemExit raised by the single-neighbourhood mode.
	with open("./local_compare_trends.txt","w+") as fileW:
		print("1) All combinations  2) One Dataset and Kernel  3) One Neighbourhood")
		choice3 = int(input("Enter choice: "))

		if choice3 == 1:
			# Dataset choices are 1 = Adult, 2 = Compas, 3 = German. The loop
			# used to iterate [1, 2, 4], and 4 maps to no dataset.
			for choice1 in [1,2,3]:
				for choice2 in range(1,4):
					execute_setting(choice1, choice2, choice3, get_CE, fileW)
		else:
			print(" 1) Adult \n 2) Compas \n 3) German")
			choice1 = int(input("Enter choice: "))

			print(" 1) Linear \n 2) RBF \n 3) Poly")
			choice2 = int(input("Enter choice: "))

			execute_setting(choice1, choice2, choice3, get_CE, fileW)
	

#adult-linear -> Lime-Baseline: 1.3582921541974997 [L]AFI-Baseline: 0.0 [G]AFI-Baseline: 0.0 
#	Time: [L]AFI: 0.4769662165641785 [G]AFI: 0.4624751996994019 LIME: 6.279986448287964 Base: 0.14116583824157714
#adult-rbf -> Lime-Baseline: 0.35692130429902463 [L]AFI-Baseline: 0.3341421356237309 [G]AFI-Baseline: 0.3324264068711929 
#	Time: [L]AFI: 0.9156425428390503 [G]AFI: 0.813996229171753 LIME: 18.325791442394255 Base: 0.34733598947525024
#adult-poly -> Lime-Baseline: 0.7662424407670203 [L]AFI-Baseline: 2.4296379342324634 [G]AFI-Baseline: 1.3033441156718468 
#	Time: [L]AFI: 0.6721958041191101 [G]AFI: 0.6106457996368408 LIME: 7.644788334369659 Base: 0.16184167861938475
#compas-linear -> Lime-Baseline: 1.6098711427676775 [L]AFI-Baseline: 0.0 [G]AFI-Baseline: 0.0 
#	Time: [L]AFI: 0.295796160697937 [G]AFI: 0.28781139373779296 LIME: 5.016060974597931 Base: 0.17614589452743531
#compas-rbf -> Lime-Baseline: 0.18414213562373097 [L]AFI-Baseline: 0.07 [G]AFI-Baseline: 0.09 
#	Time: [L]AFI: 0.4910068345069885 [G]AFI: 0.4656415629386902 LIME: 12.433901107311248 Base: 0.34461902141571044
#compas-poly -> Lime-Baseline: 1.3179607616892266 [L]AFI-Baseline: 2.5539815014695613 [G]AFI-Baseline: 3.55420461438044 
#	Time: [L]AFI: 0.36212203741073606 [G]AFI: 0.36239205360412596 LIME: 5.3497550630569455 Base: 0.18888885021209717
#crime-linear -> Lime-Baseline: 3.417461726351986 [L]AFI-Baseline: 0.0 [G]AFI-Baseline: 0.0 
#	Time: [L]AFI: 0.05586108446121216 [G]AFI: 0.055214536190032956 LIME: 0.589758837223053 Base: 0.5151235270500183
#crime-rbf -> Lime-Baseline: 3.229918740001143 [L]AFI-Baseline: 0.09 [G]AFI-Baseline: 2.731645589694046 
#	Time: [L]AFI: 0.07820971250534058 [G]AFI: 0.07671674013137818 LIME: 1.5693612837791442 Base: 0.7863988137245178
#crime-poly -> Lime-Baseline: 1.0037776597079087 [L]AFI-Baseline: 5.496927299335007 [G]AFI-Baseline: 13.344533506513738 
#	Time: [L]AFI: 0.1297285032272339 [G]AFI: 0.12927788972854615 LIME: 0.8044147324562073 Base: 0.5876604056358338
#german-linear -> Lime-Baseline: 1.9366807221176114 [L]AFI-Baseline: 0.0 [G]AFI-Baseline: 0.0 
#	Time: [L]AFI: 0.01484534740447998 [G]AFI: 0.013740637302398682 LIME: 0.28549861431121826 Base: 0.0951015830039978
#german-rbf -> Lime-Baseline: 2.5498907904436647 [L]AFI-Baseline: 2.133044661211691 [G]AFI-Baseline: 2.2267794748375436 
#	Time: [L]AFI: 0.017930498123168947 [G]AFI: 0.017587919235229493 LIME: 0.4924344801902771 Base: 0.10769207715988159
#german-poly -> Lime-Baseline: 3.69523373849353 [L]AFI-Baseline: 4.608042418256227 [G]AFI-Baseline: 4.672275812249791 
#	Time: [L]AFI: 0.01861497402191162 [G]AFI: 0.018133800029754638 LIME: 0.3039626836776733 Base: 0.09699675798416138
