#!/usr/bin/env python

__authors__ = ('Fabio Cumbo (fabio.cumbo@unitn.it)')
__version__ = '0.1.0'
__date__ = 'Jan 17, 2023'

import os
import re
import time
import itertools
import statistics
import numpy as np
import pandas as pd

from numpy import sort
from sklearn.model_selection import KFold
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from xgboost import XGBClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.feature_selection import SelectFromModel

# Datasets folder path
basepath = "./datasets/"
if not os.path.exists(basepath):
    raise Exception("No \"datasets\" folder found!")

# List of datasets
datasets = [
    "BIN__ThomasAM__species.csv",
    "BIN_adult__ThomasAM__species.csv",
    "BIN_senior__ThomasAM__species.csv",
    "BIN_male__ThomasAM__species.csv",
    "BIN_female__ThomasAM__species.csv",
    "RA__ThomasAM__species.csv",
    "RA_adult__ThomasAM__species.csv",
    "RA_senior__ThomasAM__species.csv",
    "RA_male__ThomasAM__species.csv",
    "RA_female__ThomasAM__species.csv"
]

# List of classifiers
classifiers = {
    "XGBoost": XGBClassifier,
    "DecisionTree": DecisionTreeClassifier,
    "SVM": SVC,
    "LogisticRegression": LogisticRegression,
    "RandomForest": RandomForestClassifier,
}

# Cross validation folds
kfolds = 5
# Selection threshold on folds
selection_thresh = 4

# Iterate over classifiers
for classifier in classifiers:
    print("Classifier: {}".format(classifier))
    # Iterate over datasets
    for dataset in datasets:
        t0 = time.time()

        outfolder = os.path.join(basepath, classifier)
        if not os.path.exists(outfolder):
            os.makedirs(outfolder)

        print("\t{}".format(dataset))
        # Dataset file path
        filepath = os.path.join(basepath, dataset)

        # Load the dataset into a pandas DataFrame
        X = pd.read_csv(filepath, sep=",", header=0, index_col=0)
        X[list(X.columns)[-1]] = (X[list(X.columns)[-1]] == "CRC").astype(int)
        # Extract classes from the last column
        y = X.iloc[:,-1]

        # Remove the last column "Class" from the dataset
        X.drop("Class", axis=1, inplace=True)
        # Extract the feature names from the dataset
        feature_names = [re.sub('\W+','', s) for s in list(X.columns)]
        X.columns = feature_names

        # Init folds
        kf = KFold(n_splits=kfolds)
        # Init the classifier
        if classifier == "SVM":
            cfr = classifiers[classifier](kernel="linear", random_state=0)
        elif classifier == "LogisticRegression":
            cfr = classifiers[classifier](solver='liblinear', random_state=0)
        else:
            cfr = classifiers[classifier](random_state=0)

        # Iterate over folds
        fold_count = 1
        accuracies = list()
        selected_features = list()
        for train_index, test_index in kf.split(X, y):
            # Get train dataset
            X_train, y_train = X.iloc[train_index,:], y.iloc[train_index]
            # Get test dataset
            X_test, y_test = X.iloc[test_index,:], y.iloc[test_index]

            # Fit
            cfr.fit(X_train, y_train)

            """
            Feature importance based on mean decrease in impurity
            """

            # Get feature importances
            if classifier == "LogisticRegression" or classifier == "SVM":
                importances = pd.DataFrame([abs(i) for i in cfr.coef_[0]], index=feature_names)
            else:
                importances = pd.DataFrame(cfr.feature_importances_, index=feature_names)
            importances.sort_values(importances.columns[0], axis=0, ascending=False, inplace=True)

            # Define output file path
            out_filepath = os.path.join(outfolder, "{}__{}__fold_{}.tsv".format(os.path.splitext(os.path.basename(filepath))[0], classifier, fold_count))

            # Dump result
            importances.to_csv(out_filepath, sep="\t", header=False, index=True)

            # Fit model using each importance as a threshold
            thresholds = sorted(set([abs(i) for i in cfr.coef_[0]])) if classifier == "LogisticRegression" or classifier == "SVM" else sorted(set(cfr.feature_importances_))
            best_in_fold_threshold = None
            best_in_fold_features = None
            best_in_fold_accuracy = 0.0
            for thresh in thresholds:
                # select features using threshold
                selection = SelectFromModel(cfr, threshold=thresh, prefit=True)
                select_X_train = selection.transform(X_train)
                # train model
                if classifier == "SVM":
                    selection_model = classifiers[classifier](kernel="linear", random_state=0)
                elif classifier == "LogisticRegression":
                    selection_model = classifiers[classifier](solver='liblinear', random_state=0)
                else:
                    selection_model = classifiers[classifier](random_state=0)
                selection_model.fit(select_X_train, y_train)
                # eval model
                select_X_test = selection.transform(X_test)
                predictions = selection_model.predict(select_X_test)
                accuracy = accuracy_score(y_test, predictions)
                #print("\t\t\tThreshold=%f, n=%d, Accuracy: %.2f%%" % (thresh, select_X_train.shape[1], accuracy*100.0))
                if best_in_fold_features == None:
                    best_in_fold_threshold = thresh
                    best_in_fold_features = select_X_train.shape[1]
                    best_in_fold_accuracy = accuracy*100.0
                else:
                    if accuracy*100.0 >= best_in_fold_accuracy:
                        if select_X_train.shape[1] < best_in_fold_features:
                            best_in_fold_threshold = thresh
                            best_in_fold_features = select_X_train.shape[1]
                            best_in_fold_accuracy = accuracy*100.0

            # Report best result in current fold
            print("\t\tFold=%d, Threshold=%f, n=%d, Accuracy: %.2f%%" % (fold_count, best_in_fold_threshold, best_in_fold_features, best_in_fold_accuracy))
            with open(out_filepath, "a+") as out:
                out.write("\n# thresh: {}\n# features: {}\n# accuracy: {}\n".format(best_in_fold_threshold, best_in_fold_features, best_in_fold_accuracy))
        
            accuracies.append(best_in_fold_accuracy)

            selected_features.append(list(importances[importances[importances.columns[0]].ge(best_in_fold_threshold)].index))

            # Increment fold count
            fold_count += 1

        # Get features selected in at least 4 out of 5 folds
        selected_features = list(itertools.chain.from_iterable(selected_features))
        hist = dict()
        for s in selected_features:
            if s not in hist:
                hist[s] = 0
            hist[s] += 1
        in4 = [s for s in hist if hist[s]>=selection_thresh]

        t1 = time.time()
        print("\tTotal elapsed time {}s, Accuracy {}, Selected features {}".format(int(t1 - t0), round(statistics.mean(accuracies), 4), len(in4)))
