# -*- coding: utf-8 -*-
"""


Trains different regressors to then evaluate their reliability with the epsilon good notion.

"""
import argparse
import numpy as np
from matplotlib import pyplot as plt
from getdata import get_uci_data, Dset, get_toy_data, save_reg_result
from getdata import split_dset, set_seeds, get_rng_dict
from networks import train_regressor, get_sim_device
import torch
from torch import nn
from sklearn.model_selection import KFold, ParameterGrid
from sklearn.preprocessing import StandardScaler


def gen_dataset_splits(dset_name, sim_conf, rng, verb=True, get_stop=False,
                       scale_data=False):
    """Load a data set and split it into train/test/supplementary(/stop) parts.

    Parameters
    ----------
    dset_name : str
        ``'toy'`` generates the data with ``get_toy_data``; any other name is
        forwarded to ``get_uci_data``.
    sim_conf : dict
        Must contain ``'train_split'``, ``'sup_split'`` and ``'stop_split'``
        (all given in percent; they are divided by 100 here) and, for the toy
        data set, ``'toy_size'`` (number of samples to generate).
    rng
        Random generator forwarded to ``get_toy_data`` and ``split_dset``.
    verb : bool
        If True, print the sizes and fractions of the resulting splits.
    get_stop : bool
        If True, additionally carve a stopping/validation set out of the
        training data.
    scale_data : bool
        If True, standardize every split with ``StandardScaler`` objects that
        are fitted on the final training split only.

    Returns
    -------
    dict or (dict, dict)
        ``data_splits`` with keys ``'train x'``, ``'train y'``, ``'test x'``,
        ``'test y'``, ``'sup x'``, ``'sup y'`` (plus ``'stop x'``/``'stop y'``
        when ``get_stop`` is set). When ``scale_data`` is True a tuple
        ``(data_splits, scale_dict)`` is returned instead, where
        ``scale_dict`` stores the mean and scale of the x and y scalers.

    Raises
    ------
    ValueError
        If the implied test fraction ``1 - train - sup`` is not in (0, 1).
    """
    train_split = sim_conf['train_split']/100  # fraction of data for training
    sup_split = sim_conf['sup_split']/100  # fraction of data left out
    stop_split = sim_conf['stop_split']/100  # fraction used as stopping set
    test_split = 1-train_split-sup_split  # fraction of data used for testing

    if (test_split <= 0) or (test_split >= 1):
        raise ValueError('Test split has to be in the range (0, 1)')

    # load the data set as numpy arrays
    if dset_name == 'toy':
        Nsamps = sim_conf['toy_size']
        (x, y) = get_toy_data(Nsamps, mode='data', rng=rng)
    else:
        (x, y) = get_uci_data(dset_name)

    # separate the train and test data sets randomly using the given rng
    (x_train, x_test, y_train, y_test) = split_dset(x, y, test_frac=test_split,
                                                    rng=rng)

    # Separate the supplementary data set
    if sup_split > 0:
        # sup_split refers to the whole data set, so rescale it to the size
        # of what is left after removing the test set
        frac = sup_split/(1-test_split)
        (x_train, x_sup, y_train, y_sup) = split_dset(x_train, y_train,
                                                      test_frac=frac,
                                                      rng=rng)
    else:
        x_sup = np.array([])
        y_sup = np.array([])

    # generate a stopping set or validation set for the training if necessary
    if get_stop:
        # same rescaling as above: stop_split is relative to the whole data
        # set, while x_train now holds a train_split fraction of it
        frac = stop_split/train_split
        (x_train, x_stop, y_train, y_stop) = split_dset(x_train,
                                                        y_train,
                                                        test_frac=frac,
                                                        rng=rng)

    if scale_data:
        # The scalers are fitted on the final training split only and then
        # applied to all the other splits.
        scaler_train_x = StandardScaler().fit(x_train)
        scaler_train_y = StandardScaler().fit(y_train)

        # train data
        x_train = scaler_train_x.transform(x_train)
        y_train = scaler_train_y.transform(y_train)

        # test data
        x_test = scaler_train_x.transform(x_test)
        y_test = scaler_train_y.transform(y_test)

        if sup_split > 0:
            x_sup = scaler_train_x.transform(x_sup)
            y_sup = scaler_train_y.transform(y_sup)

        if get_stop:
            x_stop = scaler_train_x.transform(x_stop)
            y_stop = scaler_train_y.transform(y_stop)

        # Store plain arrays instead of the scaler objects to avoid
        # incompatibilities between versions
        scale_dict = {}
        scale_dict['mu_x_train'] = scaler_train_x.mean_
        scale_dict['scale_x_train'] = scaler_train_x.scale_
        scale_dict['mu_y_train'] = scaler_train_y.mean_
        scale_dict['scale_y_train'] = scaler_train_y.scale_

    if verb:
        n_total = x.shape[0]
        print('Training size: {0} - Test size: {1} - '.format(x_train.shape[0],
                                                       x_test.shape[0])+
              'Supplementary data set size: {0}'.format(x_sup.shape[0]))

        if get_stop:
            print('Stopping/validation set size: {0}'.format(x_stop.shape[0]))

        print('Fractions are: {0:.2f}, {1:.2f}, {2:.2f}'.format(x_train.shape[0]/n_total,
                                                               x_test.shape[0]/n_total,
                                                               x_sup.shape[0]/n_total))

        if get_stop:
            # '.2f' (two decimals); the former ':2f' only set a field width
            print('Fraction of stopping/validation set: {:.2f}'.format(x_stop.shape[0]/n_total))

        if scale_data:
            print('Data have been normalized with the standard scaler')

    data_splits = {'train x': x_train, 'train y': y_train,
                   'test x': x_test, 'test y': y_test,
                   'sup x': x_sup, 'sup y': y_sup}
    if get_stop:
        data_splits['stop x'] = x_stop
        data_splits['stop y'] = y_stop

    if scale_data:
        return data_splits, scale_dict
    else:
        return data_splits




def main_loop(dset_name, seed, sim_conf):
    """Cross-validate one regressor configuration with K-fold validation.

    The random generators are seeded with ``seed``, the data set
    ``dset_name`` is split with ``gen_dataset_splits`` and the resulting
    training split is divided into ``sim_conf['fold_splits']`` consecutive
    folds. For each fold the inputs and targets are standardized with
    scalers fitted on that fold's training part, and a regressor is trained
    with ``train_regressor`` using ``sim_conf`` as network configuration.

    Returns
    -------
    (avg_train_loss, avg_val_loss)
        Per-epoch training and validation losses averaged over the folds;
        each array has ``sim_conf['nepochs']`` entries.
    """
    n_folds = sim_conf['fold_splits']  # number of folds
    n_epochs = sim_conf['nepochs']

    rng = set_seeds(seed)  # generator handed to the data splitting
    device = get_sim_device(sim_conf['device'], verb=True)

    splits = gen_dataset_splits(dset_name, sim_conf, rng)
    x_all = splits['train x']
    y_all = splits['train y']

    # one row of per-epoch losses for every fold
    fold_train_losses = np.zeros((n_folds, n_epochs))
    fold_val_losses = np.zeros((n_folds, n_epochs))

    # no shuffling here: the data were already shuffled by the split above
    folds = KFold(n_splits=n_folds, shuffle=False).split(x_all)
    for k, (idx_fit, idx_val) in enumerate(folds):
        x_fit = x_all[idx_fit]
        y_fit = y_all[idx_fit]

        # scalers see only the fitting part of the fold
        x_scaler = StandardScaler().fit(x_fit)
        y_scaler = StandardScaler().fit(y_fit)

        fit_ds = Dset(x_scaler.transform(x_fit), y_scaler.transform(y_fit))
        val_ds = Dset(x_scaler.transform(x_all[idx_val]),
                      y_scaler.transform(y_all[idx_val]))

        # the trained network itself is not needed, only its loss curves
        (_, loss_fit, loss_val) = train_regressor(fit_ds, val_ds, device,
                                                  netconf=sim_conf)
        fold_train_losses[k] = loss_fit.cpu().numpy()
        fold_val_losses[k] = loss_val.cpu().numpy()

    return np.mean(fold_train_losses, axis=0), np.mean(fold_val_losses, axis=0)


if __name__ == '__main__':

    def _str2bool(value):
        """Parse a command-line boolean such as 'True', 'false', '1' or 'no'.

        ``type=bool`` cannot be used with argparse because ``bool('False')``
        is True: any non-empty string would enable the flag.
        """
        text = str(value).strip().lower()
        if text in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            'Expected a boolean value, got {0!r}'.format(value))

    parser = argparse.ArgumentParser(description="Main script to run regressor",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--seed", type=int, default=11,
                        help='Seed to be used for simulation')

    parser.add_argument("--dset", type=str, default='all',
                        help='Name of the dataset or "all" for all')

    parser.add_argument("--save", type=_str2bool, default=True,
                        help="Save results to disk")

    parser.add_argument("--dir", type=str, default='tmp_REG',
                        help='base directory where to store results')

    parser.add_argument("--nepochs", type=int, default=250,
                        help='Number of epochs to run the training')

    parser.add_argument("--device", type=str, default='cuda',
                        help='Possibilities: auto, cpu, cuda, cuda:0, cuda:1')

    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%'.
    parser.add_argument("--splits", type=int, nargs=2, default=[90, 0],
                        help='Two ints (x1, x2) where x1 is the %% of the '
                        'dataset used for training and x2 the %% used as '
                        'supplementary set; the remaining 100-x1-x2 %% is '
                        'used as test set')

    parser.add_argument("--stopfrac", type=int, default=20,
                        help='Percentage of the whole dataset, taken out of '
                        'the training split, used as a stopping set after '
                        'K-fold validation')

    parser.add_argument("--nfolds", type=int, default=5,
                        help='Number of folds used for training.')

    parser.add_argument('--batchsize', type=int, default=64,
                        help='Batch size for training')

    # Read the arguments
    show_plots = False
    args = parser.parse_args()
    seed = args.seed
    enable_save = args.save
    nepochs = args.nepochs
    (train_split, sup_split) = args.splits
    bsize = args.batchsize
    nfolds = args.nfolds
    stop_split = args.stopfrac

    # '--dset' accepts a single data set name or one of the group aliases
    if args.dset == 'all':
        dsets = ['yacht', 'boston', 'toy', 'energy', 'concrete', 'wine',
                 'kin8nm', 'power', 'naval']
    elif args.dset == 'small':
        dsets = ['yacht', 'boston', 'toy', 'energy', 'concrete']
    elif args.dset == 'large':
        dsets = ['wine', 'kin8nm', 'power', 'naval']
    else:
        dsets = [args.dset]

    folder = args.dir  # base directory where things are stored
    device_t = args.device  # device name as given on the command line
    device = get_sim_device(device_t)
    
    
    # Basic parameter grid for the regressor
    param_grid = ParameterGrid({'l_rate': [1e-3, 1e-4],
                                'wd': [0, 0.025, 0.05, 0.075, 0.1],
                                'n_hlayers': [3],
                                'n_inner_neurons': [64]})

    print('Using seed {0}'.format(seed))
    # General simulation setup: every data set is processed independently
    for dset_name in dsets:
        print('Data set is: ' + dset_name)

        sim_conf = {'batch_size': bsize, 'nepochs': nepochs,
                    'fold_splits': nfolds, 'device': device_t,
                    'train_split': train_split, 'sup_split': sup_split,
                    'stop_split': stop_split,
                    'toy_size': 560}

        # Grid search over all parameters. main_loop returns the per-epoch
        # training and validation losses averaged over the folds.
        test_losses = []
        train_losses = []
        results = {}
        for grid_cnt, item in enumerate(param_grid):
            sim_conf.update(item)

            (train_loss, test_loss) = main_loop(dset_name, seed, sim_conf)

            test_losses.append(test_loss)
            train_losses.append(train_loss)
            results[grid_cnt] = item

        # Find the best parameters
        test_loss_mat = np.array(test_losses)
        # epoch index of the best loss of each model:
        best_losses = np.argmin(test_loss_mat, axis=1)
        # index of the model with the smallest loss:
        best_model_index = np.argmin(test_loss_mat[np.arange(test_loss_mat.shape[0]),
                                                   best_losses])
        # epoch index at which the best model reached its smallest loss
        nepochs_optimal = best_losses[best_model_index]

        # Train the chosen model using the whole training dataset with the
        # best hyperparameters that were chosen
        print('Best parameters are: ')
        print(results[best_model_index])

        # Reload the best parameters
        sim_conf.update(results[best_model_index])

        rng = set_seeds(seed)  # reset the random number generators
        # generate the data splits again, now with a stopping set and scaling
        (data_splits, scale_dict) = gen_dataset_splits(dset_name, sim_conf, rng,
                                                       get_stop=True,
                                                       scale_data=True)

        trainDS = Dset(data_splits['train x'], data_splits['train y'])
        stopDS = Dset(data_splits['stop x'], data_splits['stop y'])
        testDS = Dset(data_splits['test x'], data_splits['test y'])

        (best_net, train_loss, stop_loss, ind_best) = train_regressor(trainDS,
                                                                      stopDS,
                                                                      device,
                                                                      netconf=sim_conf,
                                                                      get_best=True)

        print('Best Avg Training loss {0:.4f} '.format(train_loss[ind_best]) +
              'Best Avg. Val Loss {0:.4f}, '.format(stop_loss[ind_best]) +
              'Best epoch {0}'.format(ind_best))

        plt.close('all')
        # Plot of the losses during cross-validation
        fig1 = plt.figure(1, figsize=(7, 6))
        ax = plt.subplot(2, 1, 1)
        for item in train_losses:
            ax.plot(item)
        ax.grid(True)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Avg. Training loss')
        ax.set_title('Average training loss in cross-validation')

        ax2 = plt.subplot(2, 1, 2)
        for item in test_losses:
            ax2.plot(item)
        # mark the best (model, epoch) pair found by the grid search
        ax2.plot(nepochs_optimal, test_loss_mat[best_model_index, nepochs_optimal], 'd')
        ax2.grid(True)
        ax2.set_xlabel('Epochs')
        ax2.set_ylabel('Avg. Val. loss')
        ax2.set_title('Average validation loss in cross-validation')
        # called after all labels/titles exist so the layout accounts for them
        plt.tight_layout()

        # Plot of the losses of the final training
        fig2 = plt.figure(2, figsize=(7, 3))
        ax = plt.subplot(1, 1, 1)
        ax.plot(train_loss.cpu().numpy(), label='Training loss')
        ax.plot(stop_loss.cpu().numpy(), label='Stopping set loss')
        ax.set_xlabel('Epochs')
        ax.legend()
        ax.grid(True)
        plt.tight_layout()

        # Compute the MAE over the test set
        with torch.no_grad():
            best_net.eval()
            mae = nn.L1Loss()
            # single forward pass, reused for the MAE and per-sample errors
            y_true = testDS.out.to(device)
            y_pred = best_net(testDS.dset.float().to(device))
            testMAE = mae(y_true, y_pred)
            error = abs((y_true - y_pred).cpu().numpy())
            fig4 = plt.figure(4)
            plt.plot(error)
            plt.xlabel('Sample index')
            plt.ylabel('L1 error')
            plt.grid(True)
            plt.title('L1 loss of the test set normalized to '+
                      '0 mean and unit variance')

            if dset_name == 'toy':
                # regressor output on a dense grid for the fit plot below
                Xplot = torch.linspace(-2, 2, 300).reshape(-1, 1).to(device)
                Yhatplot = best_net(Xplot)

        print("MAE loss: {0}".format(testMAE))

        if dset_name == 'toy':
            fig3 = plt.figure(3)
            plt.plot(data_splits['train x'], data_splits['train y'], 'o',
                     label='Train', color='tab:blue')
            plt.plot(Xplot.cpu(), Yhatplot.cpu(), label='Regressor', color='tab:orange')
            plt.xlabel('x')
            plt.ylabel('y')
            plt.grid(True)
            plt.legend()

        # Collect the numerical results that are saved to hard disk
        data_dict = {}
        # General:
        data_dict['conf'] = sim_conf
        data_dict['dset_name'] = dset_name
        data_dict['type'] = 'regressor'

        # Results of the Kfold training:
        data_dict['Kfold'] = {}
        data_dict['Kfold']['desc'] = 'Results of the K-fold training'
        data_dict['Kfold']['train_loss'] = train_losses  # for K-fold
        data_dict['Kfold']['val_loss'] = test_losses  # for Kfold
        data_dict['Kfold']['params'] = results  # each of the parameters
        data_dict['dataset'] = data_splits
        data_dict['scale'] = scale_dict

        # Results of the training with the best parameters
        data_dict['Train'] = {}
        data_dict['Train']['desc'] = "Result of training with the best hyperpars"
        data_dict['Train']['best pars'] = results[best_model_index]
        data_dict['Train']['train_loss'] = train_loss
        data_dict['Train']['stop_loss'] = stop_loss
        data_dict['Train']['best_index'] = best_model_index
        # NOTE: this is the best epoch index found during cross-validation,
        # not ind_best of the final training run
        data_dict['Train']['best_epoch'] = nepochs_optimal

        data_dict['Test MAE'] = testMAE

        # Random number state
        data_dict['Random'] = get_rng_dict(seed, rng)

        if enable_save:

            fig_dict = {'crossvalLoss': fig1, 'trainLoss': fig2, 'MAE': fig4}
            if dset_name == 'toy':
                fig_dict['fit'] = fig3

            basefolder = save_reg_result(model=best_net, basefolder=folder,
                                         dset_name=dset_name, seed=seed,
                                         data_dict=data_dict, reg_type='DNN',
                                         figs=fig_dict, split_rng=True)

        if show_plots:
            plt.show()
        # END of the per-data-set processing

            
