from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.inspection import PartialDependenceDisplay, partial_dependence
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Random seed handed to the classifier so runs are reproducible.
SEED = 123

# Input columns, in the order they appear in the feature matrix.
FEATURE_NAMES = [
    "patient_freewill",
    "patient_age",
    "patient_health",
    "doctor_x",
    "doctor_y",
    "robot_speed",
    "robot_charge",
]

# Column holding the label the classifier is trained to predict.
OUTCOME_NAME = "success"

# Integer codes for the string-valued patient columns. The PDP figures label
# free-will codes 0/1/2 as 'foc'/'nom'/'inatt' and health codes 0/1/2 as
# 'heal'/'sick'/'unst'. NOTE(review): 'y'/'e' presumably mean young/elderly.
freewill_mapping = {"foc": 0, "free": 1, "distr": 2}
age_mapping = {"y": 0, "e": 1}
health_mapping = {"h": 0, "s": 1, "u": 2}

def adjustFigAspect(fig, aspect=1):
    '''
    Shrink the subplot area of *fig* to a centred box of the given aspect.

    The box starts as a square whose side is 80% of the figure's shorter
    dimension (0.4 on each side of the centre). One side is then shrunk so
    that width / height, measured in inches, equals *aspect*. The width is
    shrunk when ``aspect < 1``; otherwise the height is. The result is
    applied through ``fig.subplots_adjust`` and nothing is returned.
    '''
    width_in, height_in = fig.get_size_inches()
    shorter_in = min(width_in, height_in)
    # Half-extents expressed as fractions of the figure width / height.
    half_width = .4 * shorter_in / width_in
    half_height = .4 * shorter_in / height_in
    if not aspect < 1:
        half_height = half_height / aspect
    else:
        half_width = half_width * aspect
    box = {
        'left': .5 - half_width,
        'right': .5 + half_width,
        'bottom': .5 - half_height,
        'top': .5 + half_height,
    }
    fig.subplots_adjust(**box)


def _encode_features(dataset: pd.DataFrame) -> np.ndarray:
    """Return the FEATURE_NAMES columns of *dataset* as a float NumPy array.

    The string-coded patient columns are translated to integers through the
    module-level ``*_mapping`` dicts. ``DataFrame.replace`` leaves values that
    are not in a mapping unchanged, so an unknown code surfaces as a failed
    float conversion here instead of silently becoming NaN.
    """
    encoded = dataset[FEATURE_NAMES].replace({
        'patient_freewill': freewill_mapping,
        'patient_age': age_mapping,
        'patient_health': health_mapping,
    })
    # Cast explicitly instead of relying on replace() downcasting the former
    # object columns, a behaviour newer pandas versions deprecate.
    return encoded.astype(float).to_numpy()


def _save_pdp_plot(classifier, X, features, path, *, categorical_features,
                   aspect, yticks=None, xticklabels=None, yticklabels=None):
    """Draw one PartialDependenceDisplay panel and save it as *path*.

    Args:
        classifier: fitted estimator passed to ``from_estimator``.
        X: feature matrix the partial dependence is averaged over.
        features: a FEATURE_NAMES index for a one-way plot, or a 2-tuple of
            indices for a joint plot.
        path: output file name handed to ``savefig``.
        categorical_features: boolean mask over FEATURE_NAMES.
        aspect: value passed to ``Axes.set_aspect`` on the PDP axes.
        yticks: optional explicit y-tick positions.
        xticklabels / yticklabels: optional labels placed at positions
            0, 1, 2, ... on the respective axis.
    """
    matplotlib.rcParams.update({'font.size': 16})
    fig, bounding_ax = plt.subplots(ncols=1)
    display = PartialDependenceDisplay.from_estimator(
        classifier,
        X,
        [features],
        kind="average",
        feature_names=FEATURE_NAMES,
        categorical_features=categorical_features,
        # NOTE(review): scikit-learn documents that `target` is ignored for
        # binary classifiers (the positive class is used) -- confirm intent.
        target=0,
        ax=bounding_ax,
    )
    display.figure_.subplots_adjust(wspace=0.8, hspace=0.3)
    # A single Axes passed as `ax` is used only as a bounding box; the plot
    # itself lives in the Axes the display creates, exposed as `axes_`.
    pd_ax = display.axes_[0, 0]
    if xticklabels is not None:
        pd_ax.set_xticks(list(range(len(xticklabels))))
        pd_ax.set_xticklabels(xticklabels, rotation=0)
    if yticks is not None:
        pd_ax.set_yticks(yticks)
    if yticklabels is not None:
        pd_ax.set_yticks(list(range(len(yticklabels))))
        pd_ax.set_yticklabels(yticklabels, rotation=0)
    pd_ax.set_aspect(aspect)
    fig.savefig(path, bbox_inches='tight', pad_inches=0.1)
    # Release the figure; otherwise every call keeps one open in pyplot.
    plt.close(fig)


def _save_pdp_surface(classifier, X, features, path, *, axis_labels,
                      grid_resolution=15):
    """Save a 3-D surface of the joint partial dependence on two features.

    Args:
        classifier: fitted estimator passed to ``partial_dependence``.
        X: feature matrix the partial dependence is averaged over.
        features: 2-tuple of FEATURE_NAMES indices.
        path: output file name handed to ``savefig``.
        axis_labels: (x label, y label) for the surface plot.
        grid_resolution: grid points per feature for ``partial_dependence``.
    """
    matplotlib.rcParams.update({'font.size': 12})
    fig = plt.figure()
    result = partial_dependence(classifier, X, features=features,
                                kind="average", grid_resolution=grid_resolution)
    # scikit-learn >= 1.3 names the grid "grid_values"; the old "values" key
    # was removed in 1.5, so support both.
    grid = result["grid_values"] if "grid_values" in result else result["values"]
    XX, YY = np.meshgrid(grid[0], grid[1])
    # average[0] is indexed [i, j] over (grid[0], grid[1]), whereas meshgrid
    # output is indexed [j, i]; transpose to line them up.
    Z = result["average"][0].T
    ax = fig.add_subplot(projection="3d")
    surf = ax.plot_surface(XX, YY, Z, rstride=1, cstride=1, cmap='viridis', edgecolor="k")
    ax.set_xlabel(axis_labels[0])
    ax.set_ylabel(axis_labels[1])
    fig.colorbar(surf, ax=ax, pad=0.08, shrink=0.6, aspect=10)
    ax.view_init(elev=30, azim=220)
    fig.savefig(path, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)


def generate_pdpExp(log_file: str, output_dir: str = "plots") -> None:
    """Train a random forest on *log_file* and save partial-dependence plots.

    Reads the CSV at *log_file* and encodes the string-coded patient
    columns. It then fits a ``RandomForestClassifier`` mapping FEATURE_NAMES
    to OUTCOME_NAME and writes six PDFs into *output_dir*:
    one-way PDPs for free will, health, robot charge and robot speed; a
    joint free will x health PDP; and a 3-D surface of the joint
    charge x speed PDP.

    Args:
        log_file: path of the CSV dataset. It must contain every column in
            FEATURE_NAMES plus OUTCOME_NAME.
        output_dir: directory that receives the PDF files. It is created if
            missing. The default ``"plots"`` matches the original
            hard-coded location.
    """
    import os  # local import: only needed for the output directory/paths

    print("Reading dataset...")
    dataset = pd.read_csv(log_file)

    X = _encode_features(dataset)
    y = dataset[OUTCOME_NAME].to_numpy()

    print("Loading classifier...")
    classifier = RandomForestClassifier(random_state=SEED, n_estimators=10).fit(X, y)

    os.makedirs(output_dir, exist_ok=True)

    def _out(file_name):
        return os.path.join(output_dir, file_name)

    hmt_features = ['patient_freewill', 'patient_health', 'robot_charge', 'robot_speed']
    freewill, health, charge, speed = (FEATURE_NAMES.index(name) for name in hmt_features)
    # The columns encoded by _encode_features are the categorical ones.
    categorical = [name in ('patient_freewill', 'patient_age', 'patient_health')
                   for name in FEATURE_NAMES]
    freewill_labels = ['foc', 'nom', 'inatt']
    health_labels = ['heal', 'sick', 'unst']

    # == free will + health
    print("Building PDP explanations (free will)...")
    _save_pdp_plot(classifier, X, freewill, _out('pdp_freewill.pdf'),
                   categorical_features=categorical, aspect=5,
                   xticklabels=freewill_labels, yticks=np.arange(0.0, 0.7, 0.1))

    print("Building PDP explanations (health)...")
    _save_pdp_plot(classifier, X, health, _out('pdp_health.pdf'),
                   categorical_features=categorical, aspect=5,
                   xticklabels=health_labels, yticks=np.arange(0.0, 0.7, 0.1))

    print("Building PDP explanations (joint free_will-health)...")
    # For the joint plot the first feature (free will) is on the y axis and
    # the second (health) on the x axis.
    _save_pdp_plot(classifier, X, (freewill, health), _out('pdp_freewill+health.pdf'),
                   categorical_features=categorical, aspect=1.0,
                   xticklabels=health_labels, yticklabels=freewill_labels)

    # == battery + speed
    print("Building PDP explanations (charge)...")
    _save_pdp_plot(classifier, X, charge, _out('pdp_charge.pdf'),
                   categorical_features=categorical, aspect=2,
                   yticks=np.arange(0.0, 0.8, 0.1))

    print("Building PDP explanations (speed)...")
    _save_pdp_plot(classifier, X, speed, _out('pdp_speed.pdf'),
                   categorical_features=categorical, aspect=100,
                   yticks=np.arange(0.0, 0.8, 0.1))

    print("Building PDP explanations (joint charge-speed)...")
    _save_pdp_surface(classifier, X, (charge, speed), _out('pdp_charge+speed_3d.pdf'),
                      axis_labels=(hmt_features[2], hmt_features[3]))
    


def main():
    """Script entry point: build all PDP figures for the default dataset CSV."""
    dataset_path = "data/formalise2023_dataset1000.csv"
    generate_pdpExp(dataset_path)


if __name__ == "__main__":
    main()