
import json
import time
import os
import math
import argparse
import logging
import pickle
import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPRegressor
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_validate
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR


def process_line_arguments():
    """Read the command line arguments from the console.

    The numeric options (``--regressor``, ``--iterations``, ``--seed`` and
    ``--plot``) are converted to ``int`` by argparse.  Without the
    conversion a value given on the command line stays a string, so
    ``--plot 0`` would be the truthy string ``'0'`` and the plot would be
    shown anyway.

    Returns:
        argparse.Namespace: the parsed options (``regressor``,
        ``finetuning``, ``train_folder``, ``test_folder``, ``test_size``,
        ``train_size``, ``response``, ``iterations``, ``activation``,
        ``seed`` and ``plot``).
    """
    parser = argparse.ArgumentParser(description='Nanowire predictor.')
    parser.add_argument('-r', '--regressor', help='MLP (0), LinearRegression (1), DTree (2), RForest (3) and SVM (4)',
                        type=int, default=0, required=False)
    parser.add_argument('-f', '--finetuning', help='fine-tuning an existent model', action='store_true', required=False)
    parser.add_argument('-i', '--train_folder', help='folder with train json files', required=True)
    parser.add_argument('-t', '--test_folder', help='folder with test json files', required=False)
    parser.add_argument('-ts', '--test_size', help='percentage of test examples (from 0 to 1)',
                        type=float, default=0.2, required=False)
    parser.add_argument('-trs', '--train_size', help='percentage of train examples (from 0 to 1)',
                        type=float, required=False)
    parser.add_argument('-y', '--response', help='name of the response variable in json files',
                        default='ioff_VG', required=False)
    parser.add_argument('-nit', '--iterations', help='max. number of iterations',
                        type=int, default=2000, required=False)
    parser.add_argument('-a', '--activation', help='activation function', default='tanh', required=False)
    parser.add_argument('-s', '--seed', help='random seed', type=int, default=905, required=False)
    parser.add_argument('-p', '--plot', help='plot true-predicted graph', type=int, default=0, required=False)

    return parser.parse_args()


#####################################################

def read_train_files(args):
    """Load the training examples from the ``*.json`` files under ``args.train_folder``.

    The folder is walked recursively.  Every JSON file is read as an array of
    records; for each record whose ``args.response`` entry is not ``None``
    the ``'ler_profile'`` entry is stored as the feature row and the response
    as the target.  When the response name contains ``'VG'`` the natural
    logarithm of the target is stored instead of the raw value.

    Args:
        args: object exposing ``train_folder`` (root folder to walk) and
            ``response`` (key of the target variable in each record).

    Returns:
        tuple: ``(X, Y)`` lists with one feature row and one target per
        accepted record, in the order they were read.

    Raises:
        ValueError: if a log-scaled response value is not strictly positive.
    """
    X = []
    Y = []
    response = args.response
    log_scale = 'VG' in response
    print('----------------------------------')
    for subdir, _dirs, files in os.walk(args.train_folder):
        for file in files:
            if not file.endswith(".json"):
                continue
            print('Processing input train file:', file)
            path = os.path.join(subdir, file)
            # The context manager closes the file even when json.load fails.
            with open(path) as f:
                data = json.load(f)
            for record in data:
                value = record[response]
                if value is None:
                    continue
                profile = record['ler_profile']
                if log_scale:
                    if value <= 0:
                        # math.log would fail with an opaque domain error;
                        # report which file and value caused it instead.
                        raise ValueError(
                            "Non-positive value for '%s' in %s: %r (cannot apply the logarithm)"
                            % (response, path, value))
                    value = math.log(value)
                X.append(profile)
                Y.append(value)
    return X, Y

#####################################################

def read_test_files(args):
    """Load the test examples from the ``*.json`` files under ``args.test_folder``.

    The folder is walked recursively and the records of all JSON files are
    gathered first; they are then filtered in that same order.  Records whose
    ``args.response`` entry is ``None`` are skipped with a logged warning.
    When the response name contains ``'VG'`` the natural logarithm of the
    target is stored instead of the raw value.

    Args:
        args: object exposing ``test_folder`` (root folder to walk) and
            ``response`` (key of the target variable in each record).

    Returns:
        tuple: ``(X_test, Y_test)`` lists with one ``'ler_profile'`` row and
        one target per accepted record.

    Raises:
        ValueError: if a log-scaled response value is not strictly positive.
    """
    X_test = []
    Y_test = []
    all_files_test = []
    response = args.response
    log_scale = 'VG' in response

    print('----------------------------------')
    # The walk order is deliberately left as os.walk yields it:
    # print_JSON_output walks the same folder again and pairs the predictions
    # with the records by position.
    for subdir, _dirs, files in os.walk(args.test_folder):
        for file in files:
            if not file.endswith(".json"):
                continue
            print('Processing input test file:', file)
            # The context manager closes the file even when json.load fails.
            with open(os.path.join(subdir, file)) as f:
                all_files_test.extend(json.load(f))

    for record in all_files_test:
        value = record[response]
        if value is None:
            logging.warning('None value detected!')
            continue
        profile = record['ler_profile']
        if log_scale:
            if value <= 0:
                # math.log would fail with an opaque domain error; report
                # the offending value instead.
                raise ValueError(
                    "Non-positive value for '%s': %r (cannot apply the logarithm)"
                    % (response, value))
            value = math.log(value)
        X_test.append(profile)
        Y_test.append(value)

    return X_test, Y_test

#####################################################

def train_model(X_train, X_test, Y_train, Y_test, args):
    """Fit the selected regressor, print training metrics and pickle the model.

    The features are standardised with a ``StandardScaler`` fitted on
    ``X_train`` only; the same scaler is applied to ``X_test``.  The
    regressor is chosen from ``int(args.regressor)``: 1 linear regression,
    2 decision tree, 3 random forest, 4 SVR, any other value an
    ``MLPRegressor`` with three hidden layers of 80 units.

    Training metrics are reported on the original scale of the response:
    the targets and predictions are passed through ``np.exp`` only when the
    response name contains ``'VG'``, which is the condition under which the
    file readers store the logarithm of the target.

    Args:
        X_train: training feature rows.
        X_test: test feature rows.
        Y_train: training targets.
        Y_test: accepted but not used by this function.
        args: parsed options; ``regressor``, ``seed``, ``iterations``,
            ``activation`` and ``response`` are read.

    Returns:
        tuple: ``(t, Y_pred)`` with the fit time in seconds and the
        predictions for ``X_test`` on the same scale as ``Y_train``.
    """
    print('----------------------------------')
    print('Number of train examples:', len(X_train))
    print('Number of test examples:', len(X_test))
    reg_type = int(args.regressor)

    # The scaler statistics come from the training rows only.
    scaler = StandardScaler().fit(X_train)
    X_trainscaled = scaler.transform(X_train)
    X_testscaled = scaler.transform(X_test)

    if reg_type == 1:
        print('*********** LINEAR REGRESSION ************')
        reg = LinearRegression()
    elif reg_type == 2:
        print('*********** DECISSION TREE ***********')
        reg = DecisionTreeRegressor(random_state=int(args.seed))
    elif reg_type == 3:
        print('*********** RANDOM FOREST ***********')
        reg = RandomForestRegressor(n_estimators=100, random_state=int(args.seed))
    elif reg_type == 4:
        print('*********** SVM ***********')
        reg = SVR(kernel="rbf", epsilon=0.05)
    else:
        print('*********** MLP REGRESSOR ***********')
        hl_sizes = (80, 80, 80)
        iters = int(args.iterations)
        print('Layers: ', hl_sizes)
        print('Max. number of iterations: ', iters)
        reg = MLPRegressor(hidden_layer_sizes=hl_sizes, solver='lbfgs', random_state=int(args.seed),
                           alpha=0.1, tol=1e-10, activation=args.activation, max_iter=iters, verbose=False)

    start = time.time()
    reg.fit(X_trainscaled, Y_train)
    t = time.time() - start
    # Only some estimators expose the number of iterations they ran.
    if hasattr(reg, 'n_iter_'):
        print('Iterations executed:', reg.n_iter_)
    print('Training time (seconds):', t)
    Y_pred = reg.predict(X_testscaled)

    # Training metrics.  Undo the log transform only when it was applied;
    # exponentiating a response that was never log-scaled would report
    # metrics for exp(y) instead of y.
    Y_train_pred = reg.predict(X_trainscaled)
    if 'VG' in args.response:
        Y_train_true = np.exp(Y_train)
        Y_train_pred = np.exp(Y_train_pred)
    else:
        Y_train_true = np.asarray(Y_train)
    r2 = metrics.r2_score(Y_train_true, Y_train_pred)
    variance = metrics.explained_variance_score(Y_train_true, Y_train_pred)
    MAE = metrics.mean_absolute_error(Y_train_true, Y_train_pred)
    MSE = metrics.mean_squared_error(Y_train_true, Y_train_pred)
    RMSE = np.sqrt(MSE)
    MAPE = metrics.mean_absolute_percentage_error(Y_train_true, Y_train_pred)
    print('----------------------------------')
    print('Metrics Training -----------------')
    print('R2 score:', r2)
    print('Explained variance score:', variance)
    print('Mean Absolute Error (MAE):', MAE)
    print('Mean Squared Error (MSE):', MSE)
    print('Root Mean Squared Error (RMSE):', RMSE)
    print('Mean Absolute Percentage Error (MAPE):', MAPE)

    # Save the model to disk.
    # NOTE(review): only the regressor is pickled; the fitted scaler is not
    # stored, so code that reloads the model has to rebuild the scaling.
    filename = 'trained_model.sav'
    with open(filename, 'wb') as model_fp:
        pickle.dump(reg, model_fp)

    return t, Y_pred

#####################################################

def finetune_model(X_train, X_test, Y_train, Y_test, args):
    """Continue training the model stored in ``trained_model.sav``.

    The pickled estimator is loaded, its ``max_iter`` is set from
    ``args.iterations`` and ``warm_start`` is switched on before calling
    ``fit`` on the new training data.  The features are standardised with a
    ``StandardScaler`` fitted on ``X_train``; the same scaler is applied to
    ``X_test``.  The fine-tuned model is not written back to disk.

    Args:
        X_train: fine-tuning feature rows.
        X_test: test feature rows.
        Y_train: fine-tuning targets.
        Y_test: accepted but not used by this function.
        args: parsed options; ``iterations`` and ``activation`` are read.

    Returns:
        tuple: ``(t, Y_pred)`` with the fit time in seconds and the
        predictions for ``X_test`` on the same scale as ``Y_train``.
    """
    print('----------------------------------')
    # Load the model from disk.
    print('Loading existent model')
    filename = 'trained_model.sav'
    # NOTE: unpickling executes code stored in the file; only load model
    # files that this program produced.
    with open(filename, 'rb') as model_fp:
        loaded_model = pickle.load(model_fp)

    print('Number of finetuning train examples:', len(X_train))
    print('Number of test examples:', len(X_test))

    # NOTE(review): a fresh scaler is fitted on the fine-tuning data; the
    # scaler used for the original training is not stored in the model file,
    # so the two scalings may differ -- confirm this is intended.
    scaler = StandardScaler().fit(X_train)
    X_trainscaled = scaler.transform(X_train)
    X_testscaled = scaler.transform(X_test)

    loaded_model.max_iter = int(args.iterations)
    # NOTE(review): the MLPRegressor parameter is named 'activation', so this
    # assignment to 'activations' presumably has no effect on training.  It
    # is left unchanged on purpose: switching the activation of already
    # trained weights would alter the results -- confirm the intent.
    loaded_model.activations = args.activation
    # Ask the estimator to start from its previous solution instead of
    # re-initialising (for estimators that support warm_start).
    loaded_model.warm_start = True
    print('Warm start: ', loaded_model.warm_start)
    start = time.time()
    loaded_model.fit(X_trainscaled, Y_train)
    t = time.time() - start
    print('Training time (seconds):', t)
    Y_pred = loaded_model.predict(X_testscaled)

    return t, Y_pred

#####################################################

def print_JSON_output(args, Y_pred, folder):
    """Write one ``<name>_pred.json`` file per JSON file found under ``folder``.

    ``folder`` is walked recursively.  For every record whose
    ``args.response`` entry is not ``None`` the next value of ``Y_pred`` is
    consumed and stored under ``<response>_pred`` next to the record's
    ``'device'``, ``'id'``, response and ``'ler_profile'`` entries.  Records
    with a ``None`` response are skipped with a logged warning and consume no
    prediction.  The output files are created in the current working
    directory.

    Args:
        args: object exposing ``response`` (key of the target variable).
        Y_pred: sized sequence of predictions, ordered like the accepted
            records met while walking ``folder``.
        folder: root folder holding the input JSON files.

    Raises:
        StopIteration: if ``Y_pred`` holds fewer values than there are
            accepted records.
    """
    response = args.response
    predictions = iter(Y_pred)
    print('----------------------------------')
    print('Number of predicted examples:', len(Y_pred))
    written = 0  # counts the predictions actually attached to a record
    for subdir, _dirs, files in os.walk(folder):
        for file in files:
            if not file.endswith(".json"):
                continue
            # The context manager closes the file even when json.load fails.
            with open(os.path.join(subdir, file)) as f:
                file_test = json.load(f)

            output = []
            for record in file_test:
                if record[response] is None:
                    logging.warning('None value detected!')
                    continue
                prediction = next(predictions)
                written += 1
                output.append({
                    'device': record['device'],
                    'id': record['id'],
                    response: record[response],
                    response + '_pred': prediction,
                    'ler_profile': record['ler_profile'],
                })

            name_file = os.path.splitext(file)[0] + '_pred.json'
            print('Generating output test file:', name_file)
            with open(name_file, 'w', encoding='utf-8') as json_fp:
                json.dump(output, json_fp, ensure_ascii=False, indent=4)
    print('Number of written predicted examples (both should be equal):', written)


def calculate_metrics(plot, Y_test, Y_pred, args, tr_time):
    """Compute the test metrics, log them and draw the true-vs-predicted plot.

    When there are no test values a warning is logged and nothing else
    happens.  Otherwise the function appends one space-separated line
    (train size, response, iterations, rounded training time, R2, explained
    variance, MAE, MSE, RMSE, MAPE) to ``performance_metrics.txt``, prints
    the metrics, and saves a scatter plot to ``output_prediction.pdf``.

    Args:
        plot: when truthy the figure is also shown with ``plt.show()``.
        Y_test: true response values.
        Y_pred: predicted response values.
        args: parsed options; ``train_size``, ``response`` and
            ``iterations`` are read.
        tr_time: training time in seconds.
    """
    Y_test = np.asarray(Y_test)
    Y_pred = np.asarray(Y_pred)
    # The original guard tested the bound method ``Y_test.any`` (always
    # truthy); check for an empty test set explicitly instead.
    if Y_test.size == 0:
        logging.warning('No test values available: metrics were not computed.')
        return

    r2 = metrics.r2_score(Y_test, Y_pred)
    variance = metrics.explained_variance_score(Y_test, Y_pred)
    MAE = metrics.mean_absolute_error(Y_test, Y_pred)
    MSE = metrics.mean_squared_error(Y_test, Y_pred)
    RMSE = np.sqrt(MSE)
    MAPE = metrics.mean_absolute_percentage_error(Y_test, Y_pred)

    fields = (args.train_size, args.response, args.iterations, round(tr_time, 3),
              r2, variance, MAE, MSE, RMSE, MAPE)
    with open('performance_metrics.txt', 'a') as fp:
        fp.write(' '.join(str(field) for field in fields) + '\n')

    print('----------------------------------')
    print('Metrics --------------------------')
    print('R2 score:', r2)
    print('Explained variance score:', variance)
    print('Mean Absolute Error (MAE):', MAE)
    print('Mean Squared Error (MSE):', MSE)
    print('Root Mean Squared Error (RMSE):', RMSE)
    print('Mean Absolute Percentage Error (MAPE):', MAPE)

    dict_var = {'ioff_VG': 'I$_{off}$ [A]', 'ion_mc_VG': 'I$_{on}$ [A]', 'ss_VGI': '$SS$ [mV/dec]', 'vth_LE': 'V$_{th}$ [V]'}
    # The response name is a free-form option: fall back to the raw name
    # rather than failing with KeyError after the metrics were computed.
    latex_rep = dict_var.get(args.response, args.response)

    RMSE_f = '{:.3e}'.format(RMSE)

    # Set the font size for x tick labels
    plt.rc('xtick', labelsize=20)
    # Set the font size for y tick labels
    plt.rc('ytick', labelsize=20)
    plt.rcParams.update({'font.sans-serif': 'Arial'})
    plt.style.use('bmh')
    plt.figure(figsize=(8, 8))
    plt.scatter(Y_test, Y_pred, c='crimson', alpha=0.4)
    plt.grid(linestyle='--', linewidth=0.5)
    # Identity line spanning the range of the true values.
    p1 = Y_test.max()
    p2 = Y_test.min()
    plt.plot([p1, p2], [p1, p2], 'b-')
    plt.title(latex_rep, fontweight='bold', fontsize=28, pad=20)
    plt.xlabel('True Values ', fontweight='bold', fontsize=24)
    plt.ylabel('Predictions ', fontweight='bold', fontsize=24)
    plt.text(0.55, 0.1, 'R$^{2}$ = ' + str(round(r2, 4)), fontfamily='Georgia', fontweight='bold', fontsize=20, transform=plt.gca().transAxes)
    plt.text(0.55, 0.05, 'RMSE = ' + str(RMSE_f), fontfamily='Georgia', fontweight='bold', fontsize=20, transform=plt.gca().transAxes)
    plt.axis('equal')
    plt.savefig('output_prediction.pdf')

    if plot:
        plt.show()

#------------------------------------------------------------------------------
# main()
#------------------------------------------------------------------------------
if __name__ == '__main__':
    # Parse the options and load every training example.
    args = process_line_arguments()
    X, Y = read_train_files(args)
    print('----------------------------------')

    has_test_folder = args.test_folder is not None
    if not has_test_folder:
        # No separate test folder: carve the test set out of the input files.
        print('WARNING: training size parameter not considered.')
        print('Using %.2f percent of the input files as training dataset.' % ((1-args.test_size) * 100))
        print('Using %.2f percent of the input files as test dataset.' % (args.test_size * 100))
        X_train, X_test, Y_train, Y_test = train_test_split(
            X, Y,
            random_state=int(args.seed),
            train_size=1-args.test_size,
            test_size=args.test_size,
        )
    else:
        # A test folder was given: optionally subsample the training data,
        # then read the whole test folder.
        use_everything = args.train_size is None or args.train_size >= 1
        if use_everything:
            print('Using the complete training dataset (all the files).')
            X_train, Y_train = X, Y
        else:
            print('Using %.2f percent of the training dataset.' % (args.train_size * 100))
            X_train, X_val, Y_train, Y_val = train_test_split(
                X, Y,
                random_state=int(args.seed),
                train_size=args.train_size,
                test_size=1-args.train_size,
            )
        print('Using all the input files as test dataset.')
        X_test, Y_test = read_test_files(args)

    # Train from scratch unless fine-tuning of a stored model was requested.
    fit_function = finetune_model if args.finetuning else train_model
    tr_time, Y_pred = fit_function(X_train, X_test, Y_train, Y_test, args)

    # 'VG' responses were stored as logarithms: map them back with exp.
    # Otherwise only convert to arrays, as calculate_metrics expects arrays.
    to_output_scale = np.exp if 'VG' in args.response else np.array
    Y_test = to_output_scale(Y_test)
    Y_pred = to_output_scale(Y_pred)

    if has_test_folder:
        print_JSON_output(args, Y_pred, args.test_folder)

    calculate_metrics(args.plot, Y_test, Y_pred, args, tr_time)
