from sklearn.ensemble import GradientBoostingClassifier # Gradient Boosting Machine (GBM)
import pandas as pd
import numpy as np
import time
import os
import psutil
from joblib import load
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Dataset columns used as classifier inputs; the feature matrix is built by
# selecting these columns in exactly this order.
feature_names = [
    "patient_freewill",
    "patient_age",
    "patient_health",
    "doctor_x",
    "doctor_y",
    "robot_speed",
    "robot_charge",
]

# Integer codes substituted for the string values of "patient_freewill".
# NOTE(review): presumably focused / distracted / free -- confirm with the
# dataset description.
freewill_mapping = {"foc": 0, "distr": 1, "free": 2}

# Integer codes substituted for the string values of "patient_age".
age_mapping = {"y": 0, "e": 1}

# Integer codes substituted for the string values of "patient_health".
health_mapping = {"h": 0, "s": 1, "u": 2}

def process_memory() -> int:
    """Return the resident set size (RSS) of the current process.

    The value is read through psutil for the PID returned by ``os.getpid()``
    and is returned unchanged (psutil reports RSS in bytes).
    """
    return psutil.Process(os.getpid()).memory_info().rss


print("Loading best classifier...")

dataset2 = pd.read_csv("data/formalise2023_dataset500_2.csv")

# Per-column value substitutions: each categorical patient column has its
# string values replaced by the integer codes defined at the top of the file.
_categorical_codes = {
    'patient_freewill': freewill_mapping,
    'patient_age': age_mapping,
    'patient_health': health_mapping,
}

# Feature matrix: selected columns, categorical values encoded, as an ndarray.
X_test = dataset2[feature_names].replace(_categorical_codes).to_numpy()
# Double brackets keep the labels two-dimensional: shape (n_rows, 1).
y_test = dataset2[["success"]].to_numpy()

# NOTE(review): joblib.load unpickles arbitrary Python objects -- only load
# model files that come from a trusted source.
classifier = load('data/best_classifier.joblib')

print("Computing cost...")

# One measurement per row of X_test.
n_instances = X_test.shape[0]
runtimes = np.zeros(n_instances, dtype=float)  # seconds spent in predict()
mems = np.zeros(n_instances, dtype=float)      # process RSS after predict()
for i in range(n_instances):
    # List indexing keeps the row two-dimensional: shape (1, n_features).
    instance = X_test[[i], :]
    # perf_counter() is a monotonic, high-resolution clock meant for timing
    # short intervals; time.time() can be too coarse for a single prediction
    # and may jump if the system clock is adjusted.
    start = time.perf_counter()
    classifier.predict(instance)
    runtimes[i] = time.perf_counter() - start
    # Absolute resident set size after the prediction (not a delta), taken
    # outside the timed region.
    mems[i] = process_memory()

# Gather the per-prediction measurements into a single frame.
pred_cost = pd.DataFrame(data={'time': runtimes, 'mem': mems})
# RSS is reported by psutil in bytes. Convert to KiB (1024 bytes) so the
# values share the unit of the SMC 'RESIDENT_MEMORY_KiB' column they are
# compared against; dividing by 1000 would yield decimal kilobytes instead.
pred_cost['mem'] = pred_cost['mem'] / 1024
pred_cost['mode'] = 'prediction'
# Optional export of the raw measurements:
# pred_cost.to_csv('data/formalise2023_prediction_cost_020123.csv')

# Load the SMC cost measurements. skiprows drops file lines 1..499; line 0
# (the header) is kept, so the first 499 data rows are discarded.
# NOTE(review): presumably the remaining rows are the ones matching the
# dataset evaluated here -- confirm against how the CSV was produced.
smc_cost = pd.read_csv(
    "data/formalise2023_smc_cost_020123.csv",
    skiprows=range(1, 500),
)
# Milliseconds -> seconds, to match the unit of the prediction timings.
smc_cost['TIME_MS'] = smc_cost['TIME_MS'] / 1000

# Side-by-side frames (one column per approach) for the box plots; the
# columns are aligned on the row index of the two source frames.
smc_memory = smc_cost['RESIDENT_MEMORY_KiB']
gbm_memory = pred_cost['mem']
cost_mem = pd.DataFrame(data={'SMC': smc_memory, 'GBM': gbm_memory})

smc_seconds = smc_cost['TIME_MS']
gbm_seconds = pred_cost['time']
cost_time = pd.DataFrame(data={'SMC': smc_seconds, 'GBM': gbm_seconds})

print('Plotting...')

# savefig() does not create missing directories, so make sure the output
# folder exists instead of failing after all measurements were taken.
os.makedirs('plots', exist_ok=True)

matplotlib.rcParams.update({'font.size': 14})

# Run time comparison (SMC vs. GBM) as a log-scaled box plot with mean markers.
ax = cost_time.plot.box(ylabel='time (seconds)', showmeans=True)
ax.set_aspect(0.25)
ax.set_yscale('log')
ax.figure.savefig('plots/cost_time.pdf')

# Memory comparison (SMC vs. GBM) on a fixed, log-scaled y-range.
ax = cost_mem.plot.box(ylim=(10000, 1000000), ylabel='memory (KBytes)', showmeans=True)
ax.set_aspect(1.0)
ax.set_yscale('log')
ax.figure.savefig('plots/cost_mem.pdf')

# Uncomment to inspect the figures interactively:
# plt.show()

print('Done.')