from sklearn import metrics
import os
import subprocess
import csv
import dataset_mapper
import classifier_mapper
import svm
import exec
import shutil
import numpy as np

# Dataset split paths relative to a dataset folder; callers build
# f"./{data_folder}/{test_name}" and f"./{data_folder}/{training_name}".
test_name, training_name = "dataset/test-set.csv", "dataset/training-set.csv"

def LIMEtrend(svm,data_folder,columns):
	"""Rank features by their mean absolute LIME weight over the test set.

	Every test-set row is explained with LIME (label 1, all features) and the
	absolute weight LIME assigns to each feature is averaged across rows.

	Args:
		svm: fitted classifier exposing ``predict_proba``; passed to LIME as the
			prediction function. (Inside this function the name shadows the
			module-level ``svm`` import.)
		data_folder: dataset directory; the test set is read from
			``./{data_folder}/{test_name}``.
		columns: column names; the first entry is skipped so that
			``columns[k + 1]`` names feature ``k`` (NOTE(review): presumably the
			first entry is the label column -- confirm).

	Returns:
		list of ``(feature_name, mean_abs_weight)`` tuples in ascending order
		of importance, excluding one-hot columns (names containing ``'='``).
		A feature LIME never reported gets a mean of 0.
	"""
	from lime.lime_tabular import LimeTabularExplainer

	dataset_path = f"./{data_folder}/{test_name}"
	dataset_mapper1 = dataset_mapper.DatasetMapper()
	x, _ = dataset_mapper1.read(dataset_path)

	feature_names = columns[1:]
	# Rows are converted to floats (once) before being handed to LIME.
	x = np.array([list(map(float, xi)) for xi in x])
	n_features = len(x[0])
	explainer = LimeTabularExplainer(x, feature_names = feature_names, class_names=['0', '1'], discretize_continuous=False)

	# Accumulate |weight| per feature id in a single pass over the
	# explanations: O(rows * features) instead of rescanning every
	# explanation once per feature.
	totals = np.zeros(n_features)
	counts = np.zeros(n_features, dtype=int)
	for i, row in enumerate(x):
		exp = explainer.explain_instance(np.array(row), svm.predict_proba, num_features = n_features)
		for feature_id, weight in exp.as_map()[1]:
			totals[feature_id] += abs(weight)
			counts[feature_id] += 1
		print(i)  # progress: explanations are computed one row at a time

	# Mean |weight| per feature; guarded so an unreported feature yields 0
	# rather than a ZeroDivisionError.
	result = np.divide(totals, counts, out=np.zeros(n_features), where=counts > 0)
	feature_imp = [
		(feature_names[k], result[k])
		for k in np.argsort(result)
		if '=' not in feature_names[k]
	]
	print("Lime Feature Rankings:", feature_imp)

	return feature_imp

def PFItrend(svm, data_folder, columns):
	"""Return sklearn permutation importance per numeric feature on the test set.

	Args:
		svm: fitted estimator passed to ``permutation_importance``.
		data_folder: dataset directory; the test set is read from
			``./{data_folder}/{test_name}``.
		columns: column names; ``importances_mean[k]`` is paired with
			``columns[k + 1]`` (the first name is skipped).

	Returns:
		dict mapping column name -> mean importance over 10 repeats
		(``random_state=0``), omitting one-hot columns whose names contain
		``'='``.
	"""
	from sklearn.inspection import permutation_importance
	mapper = dataset_mapper.DatasetMapper()
	features, labels = mapper.read(f"./{data_folder}/{test_name}")
	pfi = permutation_importance(svm, features, labels, n_repeats = 10, random_state = 0)
	mean_importance = pfi.importances_mean
	return {
		name: mean_importance[k]
		for k, name in enumerate(columns[1:])
		if '=' not in name
	}

def test_SVM(model,data_folder, exclude = None):
	"""Evaluate *model* on the dataset's test split and return balanced accuracy.

	Despite the ``test_`` prefix this is a helper, not a pytest test case.

	Args:
		model: fitted estimator exposing ``predict``.
		data_folder: dataset directory; the test set is read from
			``./{data_folder}/{test_name}``.
		exclude: forwarded to ``DatasetMapper.read`` (callers pass a column
			index to drop a feature).

	Returns:
		The balanced accuracy score. Plain and balanced accuracy are printed.
	"""
	from sklearn import metrics
	reader = dataset_mapper.DatasetMapper()
	features, truth = reader.read(f"./{data_folder}/{test_name}", exclude = exclude)
	predicted = model.predict(features)
	balanced = metrics.balanced_accuracy_score(truth, predicted)
	print("\tAccuracy:", metrics.accuracy_score(truth, predicted))
	print("\tBalanced Accuracy:", balanced)
	return balanced

def AFItrend(data_folder, columns):
	"""Read raw feature weights from file and return |weight| per numeric feature.

	Args:
		data_folder: dataset directory; weights are the whitespace-separated
			tokens on the first line of
			``./{data_folder}/{data_folder}-feature_score_raw.txt``.
		columns: column names; ``columns[i]`` (for ``i >= 1``) is paired with
			token ``i`` of that line. NOTE(review): token 0 is never used --
			presumably a bias/label slot in the saver's output; confirm.

	Returns:
		dict mapping column name -> ``abs(float(token))`` for every column
		after the first whose name does not contain ``'='`` (one-hot columns
		are skipped).
	"""
	# Opened read-only (nothing is written) and closed deterministically.
	with open(f"./{data_folder}/{data_folder}-feature_score_raw.txt", "r") as fileR:
		weights = fileR.readline().split()
	return {
		columns[col_i]: abs(float(weights[col_i]))
		for col_i in range(1, len(columns))
		if '=' not in columns[col_i]
	}

def create_retrain_model(original_ba, kernel_name, col, feature, reg_param = 1, gamma = 1, degree = 1, coef0 = 0, data_folder = ""):
	"""Retrain without one feature and return the change in balanced accuracy.

	Args:
		original_ba: balanced accuracy of the model trained on all features.
		kernel_name: kernel passed to ``svm.SVM``.
		col: column names; ``col[feature]`` is only used for log messages.
		feature: index of the column excluded from both training and test data.
		reg_param, gamma, degree, coef0: hyperparameters passed to ``svm.SVM``.
		data_folder: dataset directory holding ``training_name``/``test_name``.

	Returns:
		``abs(retrained_ba - original_ba)``.
	"""
	svm_name = f"{data_folder}_without-{col[feature]}-svm_{kernel_name}_g{gamma}_d{degree}_c{coef0}_C{reg_param}"
	print(f"Creating SVM without {col[feature]}: {svm_name}")

	loader = dataset_mapper.DatasetMapper()
	train_x, train_y = loader.read(f"./{data_folder}/{training_name}", exclude = feature)
	retrained = svm.SVM(kernel_name, gamma, degree, coef0, reg_param).train(train_x, train_y)

	retrained_ba = test_SVM(retrained, data_folder, exclude = feature)
	return abs(retrained_ba - original_ba)

def execute_setting(choice, choice2, fileW):
	"""Train one dataset/kernel configuration and print four importance rankings.

	Fetches the dataset by running ``{data_folder}-get.py`` inside the dataset
	directory and trains the reference model via ``exec.create_model``. It then
	prints, for every non-one-hot feature (name without ``'='``):
	  * AFI: absolute raw weights produced by ``exec.run_saver``,
	  * Retrain: |balanced-accuracy change| when the feature is dropped and the
	    model retrained (``create_retrain_model``),
	  * PFI: sklearn permutation importance (``PFItrend``),
	  * LIME: mean absolute LIME weight (``LIMEtrend``).

	Args:
		choice: dataset selector: 1=adult, 2=compas, 3=crime, 4=german,
			5=health.
		choice2: kernel selector: 1=linear, 2=rbf, 3=poly.
		fileW: unused; kept for interface compatibility.

	Raises:
		ValueError: if ``choice`` or ``choice2`` is not a known selector.
		subprocess.CalledProcessError: if the dataset fetch script fails.
	"""
	datasets = {1: "adult", 2: "compas", 3: "crime", 4: "german", 5: "health"}
	kernels = {1: "linear", 2: "rbf", 3: "poly"}
	# Validate before doing any work; unknown selectors previously fell
	# through to an empty folder name / a poly kernel with C=0.
	if choice not in datasets:
		raise ValueError(f"unknown dataset choice: {choice!r}")
	if choice2 not in kernels:
		raise ValueError(f"unknown kernel choice: {choice2!r}")
	data_folder = datasets[choice]
	kernel_name = kernels[choice2]

	# Tuned hyperparameters per dataset for each kernel; parameters a kernel
	# does not tune keep these defaults. Literal types (int vs float) are kept
	# exactly, since they end up in generated model names (e.g. "C1").
	gamma, deg, coef0 = 0.1, 0, 0
	if kernel_name == "linear":
		linear_c = {1: 1, 2: 1, 3: 1, 4: 1, 5: 0.01}
		reg_param = linear_c[choice]
	elif kernel_name == "rbf":
		rbf_params = {1: (0.05, 0.01), 2: (0.05, 0.01), 3: (1, 0.001), 4: (10, 0.05), 5: (0.1, 0.01)}
		reg_param, gamma = rbf_params[choice]
	else:
		poly_params = {1: (0.01, 3, 3), 2: (0.01, 3, 3), 3: (1, 9, 0), 4: (0.01, 6, 6), 5: (0.01, 3, 0.01)}
		reg_param, deg, coef0 = poly_params[choice]

	# Run the fetch script with the dataset directory as the child's working
	# directory; unlike os.chdir, this cannot leave our own cwd changed if
	# the script fails.
	subprocess.check_call(["python3", f"{data_folder}-get.py"], cwd = f"./{data_folder}")

	model,svm_addr = exec.create_model(kernel_name,reg_param,gamma,deg, coef0,data_folder,[], ifmlx = False, get_model = True)
	original_ba = test_SVM(model,data_folder)
	with open(f"./{data_folder}/dataset/columns.csv", 'r') as f:
		columns = next(csv.reader(f))  # header row only

	retrain_ba = dict()
	for i in range(1, len(columns)):
		if '=' not in columns[i]:
			retrain_ba[columns[i]] = create_retrain_model(original_ba, kernel_name, columns, i, reg_param, gamma, deg, coef0, data_folder)

	exec.run_saver(svm_addr,"raf","top", data_folder, is_OH = 1, get_CE = 1, if_part = 0)
	shutil.move("../saver/feature_score_raw.txt", f"./{data_folder}/{data_folder}-feature_score_raw.txt")
	afi_trend = AFItrend(data_folder, columns)

	pfi_trend = PFItrend(model, data_folder, columns)

	lime_trend = LIMEtrend(model,data_folder,columns)

	print(f"\n\n\nAFI: {sorted(afi_trend.items(), key = lambda x:x[1])} \n\n\n Retrain: {sorted(retrain_ba.items(), key = lambda x:x[1])} \n\n\n PFI: {sorted(pfi_trend.items(), key = lambda x:abs(x[1]))} \n\n\n LIME: {lime_trend}")


if __name__ == '__main__':
	# Default experiment: dataset choice 4 (german) with kernel choice 2 (rbf).
	execute_setting(choice = 4, choice2 = 2, fileW = "")
	