# -*- coding: utf-8 -*-
"""

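Using a pretrained regressor, trains quantile/density estimators (SQR, Interval, KNIFE or conditional Gaussian ensembles) of either the target or the regressor's (relative) absolute error, to estimate the probability of the prediction being epsilon-good. Hyperparameters are selected by K-fold cross-validation and the best configuration is then retrained.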

"""
import argparse
import numpy as np
from matplotlib import pyplot as plt
from getdata import Dset, split_dset, Kfold_shuffler
from getdata import set_seeds, get_rng_dict
from networks import get_sim_device, load_reg_model
import torch
from Quantile_loss import trainQuantileNet
from getdata import load_reg_result, save_Qnet
from sklearn.model_selection import ParameterGrid
from sklearn.preprocessing import StandardScaler
from knife_samples_generation import train_KNIFE
from conditional_gaussian import train_conditional_gaussian


def _regressor_abs_errors(reg_network, x, y, device, musig, relative=False):
    """Return the absolute error of ``reg_network`` on the pair (x, y).

    Computes ``|y - y_hat|`` with ``y_hat = reg_network(x)``. When
    ``relative`` is True the error is divided by ``|y_hat + musig|``.
    Assuming y is standardized as (Y - mu) / scale, ``y_hat + mu/scale`` is
    the prediction in original units divided by scale, so the ratio equals
    |Y - Y_hat| / |Y_hat| in original units.
    """
    with torch.no_grad():  # inference only, no autograd graph needed
        y_hat = reg_network(torch.tensor(x, dtype=torch.float32,
                                         device=device)).cpu().numpy()
    abs_err = np.abs(y - y_hat)
    if relative:
        abs_err = abs_err / np.abs(y_hat + musig)
    return abs_err


def train_single(train_x, train_y, val_x, val_y, sim_conf, reg_res, reg_par,
                 device, get_best=False):
    """Performs the training of a Q network for one cycle with predefined
    parameters.

    Parameters:
    ----------
    train_x, train_y: (np.ndarray) training data set
    val_x, val_y: (np.ndarray) validation set
    sim_conf: configuration for the quantile network to be trained. Reads
        'type' ('absE', 'absErel' or anything else for the raw target) and
        'qtype' ('Interval', 'SQR', 'KNIFE' or 'conditional_gaussian').
    reg_res: results of the regressor, used for the absE/absErel simulation
    reg_par: path to the regressor, needed for the absE/absErel simulation
    device: "cuda" or "cpu"
    get_best: (bool) if False, only the train and validation loss are
        returned. If True returns train_loss, val_loss, Q_network, best_ind

    Raises:
    -------
    ValueError: if sim_conf['qtype'] is not a supported algorithm.

    Side effects:
    -------------
    Sets sim_conf['musig'] and, for 'absE'/'absErel', sim_conf['err_dict']
    with the mean/scale used to standardize the error targets.
    """
    qtype = sim_conf['qtype']
    # 'Inteval' is a historical misspelling, still accepted for compatibility.
    supported_qtypes = ('Interval', 'Inteval', 'SQR', 'KNIFE',
                        'conditional_gaussian')
    if qtype not in supported_qtypes:
        raise ValueError('Unknown qtype {0!r}; expected one of {1}'.format(
            qtype, supported_qtypes))

    musig = (reg_res['scale']['mu_y_train'] / reg_res['scale']['scale_y_train'])[0]
    sim_conf['musig'] = musig

    if sim_conf['type'] in ('absE', 'absErel'):
        # The target becomes the (relative) absolute error of the regressor,
        # which is then standardized again with the training-fold statistics.
        relative = sim_conf['type'] == 'absErel'
        reg_network = load_reg_model(n_in=train_x.shape[1],
                                     n_out=train_y.shape[1], reg_res=reg_res,
                                     reg_par=reg_par, device=device)
        reg_network.eval()

        train_y = _regressor_abs_errors(reg_network, train_x, train_y, device,
                                        musig, relative)
        val_y = _regressor_abs_errors(reg_network, val_x, val_y, device,
                                      musig, relative)

        scaler_qtrain_y = StandardScaler().fit(train_y)
        train_y = scaler_qtrain_y.transform(train_y)
        val_y = scaler_qtrain_y.transform(val_y)

        # Store plain arrays (not the scaler) to avoid incompatibilities
        # between sklearn versions.
        sim_conf['err_dict'] = {'mu_err_train': scaler_qtrain_y.mean_,
                                'scale_err_train': scaler_qtrain_y.scale_}

    trainDS = Dset(train_x, train_y)
    valDS = Dset(val_x, val_y)

    if qtype in ('Interval', 'Inteval', 'SQR'):
        trainer = trainQuantileNet
    elif qtype == 'KNIFE':
        trainer = train_KNIFE
    else:  # 'conditional_gaussian', guaranteed by the validation above
        trainer = train_conditional_gaussian

    (Q_network, train_loss_Q, val_loss_Q, best_ind) = trainer(trainDS, valDS,
                                                              device, sim_conf,
                                                              get_best=get_best)

    if get_best:
        return train_loss_Q, val_loss_Q, Q_network, best_ind
    return train_loss_Q, val_loss_Q


def _build_q_training_set(dataset, split_use):
    """Return (qx, qy), the data used to train the quantile estimator.

    dataset: the regressor's dataset dict ('train x', 'stop x', 'sup x', ...)
    split_use: 'regtrain' (train + stop sets), 'regtrainsup' (train + stop +
        supplementary) or 'regsup' (supplementary only).
    """
    if split_use == 'regtrain':
        return (np.vstack((dataset['train x'], dataset['stop x'])),
                np.vstack((dataset['train y'], dataset['stop y'])))
    if split_use == 'regtrainsup':
        return (np.vstack((dataset['train x'], dataset['stop x'],
                           dataset['sup x'])),
                np.vstack((dataset['train y'], dataset['stop y'],
                           dataset['sup y'])))
    if split_use == 'regsup':
        return dataset['sup x'], dataset['sup y']
    raise ValueError("Unknown split_use {0!r}; expected 'regtrain', "
                     "'regtrainsup' or 'regsup'".format(split_use))


def _restore_regressor_rng(reg_res, seed):
    """Recreate the numpy generator and restore the numpy/torch RNG states
    saved in reg_res['Random'] after the regressor was trained."""
    rng = set_seeds(seed)  # rng is a numpy random generator
    rng.bit_generator.state = reg_res['Random']['rng_state']
    torch.set_rng_state(reg_res['Random']['torch_state'])
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(reg_res['Random']['cuda_state'])
    return rng


def main_loop(reg_res, sim_conf, device, reg_par, Kfold, seed=None):
    """Runs a training for a network, either through cross-validation to tune
    the hyperparameters or the final run once K-fold validation has been done.

    Parameters:
    -----------
    reg_res: dictionary with the results from the regression
    sim_conf: dictionary with the configuration used for the quantile simulation
    device: 'cuda' or 'cpu'
    reg_par: string with the path of the file with the regressor DNN parameters
        of the pytorch model.
    Kfold (bool): if True performs K-fold validation (K is in
        sim_conf['fold_splits']). If False, performs a single run to train the
        final model with the passed hyperparameters.
    seed (int, optional): simulation seed. When omitted, the module-level
        ``seed`` (set by the script entry point) is used.

    Returns:
    --------
    If Kfold=True: train_loss, val_loss (np.ndarrays) with the average train
        and validation losses over the K folds, one value per epoch.
    If Kfold=False: train_loss, val_loss, Qnet, best_ind, dset, rng_dict where
        dset is a dict with keys 'train x', 'stop x', 'train y', 'stop y' and
        rng_dict is the RNG state returned by get_rng_dict.

    Raises:
    -------
    ValueError: if no seed is available or sim_conf['split_use'] is unknown.
    """
    if seed is None:
        # Backward compatibility: when run as a script the seed lives in the
        # module-level variable ``seed`` defined in the __main__ section.
        seed = globals().get('seed')
        if seed is None:
            raise ValueError('main_loop requires a seed; pass seed=... '
                             'explicitly')

    nfold_splits = sim_conf['fold_splits']
    nepochs = sim_conf['nepochs']

    # Building the training set consumes no randomness, so doing it before
    # restoring the RNG state keeps the random streams unchanged while an
    # invalid split_use fails early.
    qx_train, qy_train = _build_q_training_set(reg_res['dataset'],
                                               sim_conf['split_use'])

    # Continue from the random number states that existed after the
    # regressor was trained.
    rng = _restore_regressor_rng(reg_res, seed)

    if Kfold:
        # Cross-validation over the selected hyperparameters
        kf = Kfold_shuffler(n_splits=nfold_splits, dsize=qx_train.shape[0],
                            rng=rng)
        train_loss_matrix = np.zeros((nfold_splits, nepochs))
        val_loss_matrix = np.zeros((nfold_splits, nepochs))
        for fold_cnt in range(nfold_splits):
            train_index, val_index = kf.get_split_inds(fold_cnt)
            # Data is already normalized with the training set of the regressor
            (train_loss_matrix[fold_cnt],
             val_loss_matrix[fold_cnt]) = train_single(qx_train[train_index],
                                                       qy_train[train_index],
                                                       qx_train[val_index],
                                                       qy_train[val_index],
                                                       sim_conf,
                                                       reg_res,
                                                       reg_par,
                                                       device)
        return (np.mean(train_loss_matrix, axis=0),
                np.mean(val_loss_matrix, axis=0))

    # Single run: hold out part of the training set as a stopping set
    # (sim_conf['stop_split'] is a percentage).
    stop_frac = sim_conf['stop_split'] / 100
    (qx_train, qx_val, qy_train, qy_val) = split_dset(qx_train, qy_train,
                                                      test_frac=stop_frac,
                                                      rng=rng,
                                                      ndraw=nfold_splits)

    (tr_loss_best, val_loss_best,
     Qnetwork, best_ind) = train_single(qx_train, qy_train, qx_val, qy_val,
                                        sim_conf, reg_res, reg_par, device,
                                        get_best=True)

    # RNG state after finishing the training of the quantile network
    rng_dict = get_rng_dict(seed, rng)

    # Training and stopping sets used for the quantile network
    dset = {'train x': qx_train, 'stop x': qx_val, 'train y': qy_train,
            'stop y': qy_val}

    return tr_loss_best, val_loss_best, Qnetwork, best_ind, dset, rng_dict


def str2bool(value):
    """Convert a command-line string such as 'True', 'false', 'yes' or '0'
    to a bool.

    argparse's ``type=bool`` treats every non-empty string (including
    'False') as True, so boolean options need an explicit converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value, got {0!r}'.format(value))


def build_arg_parser():
    """Return the argparse parser with all the options of this script."""
    parser = argparse.ArgumentParser(
        description="Computes the quantile functions or data pdfs using "
                    "different algorithms",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument("--seed", type=int, default=1,
                        help='Seed to be used for simulation')

    parser.add_argument("--type", type=str, default='absErel',
                        help='Type of simulation "outY" or "absE" or "absErel"')

    parser.add_argument("--dset", type=str, default='boston',
                        help='Name of the dataset or "all" for all of them '
                             'or small (<1000) or large (>1000)')

    parser.add_argument("--save", type=str2bool, default=True,
                        help="Generate dump files to disk")

    parser.add_argument("--dir", type=str, default="tmp_Test",
                        help='Base directory where to store results')

    parser.add_argument("--nepochs", type=int, default=100,
                        help='Number of epochs to run the training')

    parser.add_argument("--split_use", type=str, default='regtrain',
                        choices=['regtrain', 'regtrainsup', 'regsup'],
                        help="If regtrain: all the training set of the "
                             "regressor. If regtrainsup: training set + "
                             "supplementary data. If regsup: only the "
                             "supplementary data")

    parser.add_argument("--device", type=str, default='auto',
                        help='Possibilities: auto, cpu, cuda, cuda:0, '
                             'cuda:1, etc')

    parser.add_argument("--show_plots", type=str2bool, default=True,
                        help="Show the training plots")

    parser.add_argument("--qtype", type=str, default='KNIFE',
                        choices=['SQR', 'Interval', 'KNIFE',
                                 'conditional_gaussian'],
                        help="SQR: will use the pinball loss; "
                             "Interval: will use the interval loss to "
                             "estimate only a CI; "
                             "KNIFE: will estimate the CDF as a gaussian "
                             "mixture; "
                             "conditional_gaussian: the ensembles method")

    parser.add_argument("--signif_level", type=float, default=0.05,
                        help='Only for qtype=Interval, significance level')

    parser.add_argument("--regdir", type=str, default='runs_reg_900',
                        help="Dir where the regressor results are kept")

    parser.add_argument("--stopfrac", type=int, default=20,
                        help='Percentage of the training set used as a '
                             'stopping set after K-fold validation')

    parser.add_argument("--nfolds", type=int, default=5,
                        help='Number of folds used for training.')

    parser.add_argument('--batchsize', type=int, default=64,
                        help='Batch size for training')

    parser.add_argument('--modes_number', type=int, default=64,
                        help='Number of terms of the Gaussian mixture')

    # conditional_gaussian options
    parser.add_argument('--n_ens', type=int, default=10,
                        help='Number of ensemble members')
    parser.add_argument('--alpha', type=float, default=0.05,
                        help='Only for qtype=conditional_gaussian')
    parser.add_argument('--n_hidden_layers', type=int, default=1,
                        help='Only for qtype=conditional_gaussian')
    parser.add_argument('--n_hidden_units', type=int, default=64,
                        help='Only for qtype=conditional_gaussian')

    parser.add_argument("--lr", type=float, nargs='*',
                        default=[1e-3, 5e-4, 1e-4],
                        help="Learning rates explored in the hyperparameter "
                             "grid")
    parser.add_argument("--wd", type=float, nargs='*', default=[0],
                        help="Weight decay values explored in the "
                             "hyperparameter grid")
    return parser


if __name__ == '__main__':

    args = build_arg_parser().parse_args()

    # General parameters. ``seed`` must remain a module-level name because
    # main_loop relies on it when no explicit seed is passed.
    seed = args.seed
    sim_type = args.type  # 'outY', 'absE' or 'absErel'
    enable_save = args.save
    nepochs = args.nepochs
    split_use = args.split_use
    reg_dir = args.regdir
    device_t = args.device
    bsize = args.batchsize
    nfolds = args.nfolds
    stop_split = args.stopfrac
    lr = args.lr
    wd = args.wd

    alg = args.qtype  # chooses the type of algorithm to be implemented
    # Algorithm specific parameters
    alpha = args.signif_level  # only for qtype=Interval
    modes_number = args.modes_number  # only for KNIFE
    n_ens = args.n_ens  # only for conditional_gaussian

    dset_groups = {
        'all': ['yacht', 'boston', 'toy', 'energy', 'concrete', 'wine',
                'kin8nm', 'power', 'naval'],
        'small': ['yacht', 'boston', 'toy', 'energy', 'concrete'],
        'large': ['wine', 'kin8nm', 'power', 'naval'],
    }
    dsets = dset_groups.get(args.dset, [args.dset])
    base_folder = args.dir  # base directory where things are stored

    # Hyperparameter grid explored by K-fold cross-validation
    param_for_grid = {'l_rate': lr, 'wd': wd}
    if alg in ['SQR', 'Interval']:
        param_for_grid['n_hlayers'] = [3]
        param_for_grid['n_inner_neurons'] = [64]

    param_grid = ParameterGrid(param_for_grid)
    device = get_sim_device(device_t)

    print('Using seed {0}'.format(seed))
    print('Simulation type: {0}'.format(sim_type))
    print('Q training set: {0}'.format(split_use))
    print('Using device: {0}'.format(device))
    for dset_name in dsets:
        print('Data set is: ' + dset_name)

        sim_conf = {'batch_size': bsize, 'nepochs': nepochs, 'device': device_t,
                    'qtype': alg, 'type': sim_type,
                    'fold_splits': nfolds, 'stop_split': stop_split,
                    'split_use': split_use}

        if alg == 'Interval':
            sim_conf['alpha'] = alpha  # confidence level of the algorithm
        elif alg == 'KNIFE':
            sim_conf['modes_number'] = modes_number
        elif alg == 'conditional_gaussian':
            sim_conf['n_ens'] = n_ens
            sim_conf['alpha'] = args.alpha
            sim_conf['n_hidden_layers'] = args.n_hidden_layers
            sim_conf['n_hidden_units'] = args.n_hidden_units

        # Stores the output result to be dumped to the hard drive
        data_dict = {'conf': sim_conf, 'dset_name': dset_name,
                     'type': 'Qestimate'}

        # Hyperparameter search: K-fold training of a quantile network of
        # Y or of the error for every grid point.
        val_losses = []
        train_losses = []
        results = {}  # grid index -> hyperparameter configuration
        for grid_cnt, item in enumerate(param_grid):
            sim_conf.update(item)

            # Reload the regressor results so every grid point starts from
            # the same data and random state.
            (reg_res, reg_par) = load_reg_result(dset_name, seed, reg_type='DNN',
                                                 basedir=reg_dir, rng_split=True)

            (train_loss, val_loss) = main_loop(reg_res, sim_conf, device,
                                               reg_par, Kfold=True)

            val_losses.append(val_loss)
            train_losses.append(train_loss)
            results[grid_cnt] = item

        # Store results of the K-fold training
        data_dict['Kfold'] = {}
        data_dict['desc'] = 'Results of the K-fold training'
        data_dict['Kfold']['train_loss'] = train_losses
        data_dict['Kfold']['val_loss'] = val_losses
        data_dict['Kfold']['params'] = results  # each of the parameters

        # Select the best hyperparameters: one row per grid point, one
        # column per epoch of the averaged validation loss.
        val_loss_mat = np.array(val_losses)
        best_losses = np.argmin(val_loss_mat, axis=1)  # best epoch per model
        best_model_index = np.argmin(val_loss_mat.min(axis=1))
        nepochs_optimal = best_losses[best_model_index]

        # Now train the model again with the best hyperparameters
        print('Best parameters are: ')
        print(results[best_model_index])
        sim_conf.update(results[best_model_index])

        (reg_res, reg_par) = load_reg_result(dset_name, seed, reg_type='DNN',
                                             basedir=reg_dir, rng_split=True)

        (train_loss, val_loss, Qnet, ind_best, data_splits,
         rng_dict) = main_loop(reg_res, sim_conf, device, reg_par, Kfold=False)

        print('Best Avg Training loss {0:.4f}, '.format(train_loss[ind_best]) +
              'Best Avg. Val Loss {0:.4f}, '.format(val_loss[ind_best]) +
              'Best epoch {0}'.format(ind_best))

        # Results of the training with the best parameters
        data_dict['Train'] = {}
        data_dict['Train']['desc'] = "Result of training with the best hyperpars"
        data_dict['Train']['best pars'] = results[best_model_index]
        data_dict['Train']['train_loss'] = train_loss
        data_dict['Train']['stop_loss'] = val_loss
        data_dict['Train']['best_index'] = best_model_index
        data_dict['Train']['best_epoch Kfold'] = nepochs_optimal
        data_dict['Train']['best_epoch train'] = ind_best
        data_dict['Train']['splits'] = data_splits
        data_dict['Random'] = rng_dict

        data_dict['reg'] = {'res': reg_res, 'par': reg_par}

        plt.close('all')
        fig1 = plt.figure(1)
        ax = plt.subplot(2, 1, 1)
        for item in train_losses:
            ax.plot(item)
        ax.grid(True)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Avg. Training loss')
        ax.set_title('Average training loss in cross-validation')

        ax2 = plt.subplot(2, 1, 2)
        for item in val_losses:
            ax2.plot(item)
        # Mark the best epoch of the selected model
        ax2.plot(nepochs_optimal, val_loss_mat[best_model_index,
                                               nepochs_optimal], 'd')
        ax2.grid(True)
        ax2.set_xlabel('Epochs')
        ax2.set_ylabel('Avg. Val. loss')
        ax2.set_title('Average validation loss in cross-validation')
        plt.tight_layout()

        fig2 = plt.figure(2)
        ax = plt.subplot(1, 1, 1)
        ax.plot(train_loss, label='Training loss')
        ax.plot(val_loss, label='Stopping set loss')
        ax.set_xlabel('Epochs')
        ax.legend()
        ax.grid(True)
        plt.tight_layout()

        # Save the pytorch model, results and figures
        if enable_save:
            fig_dict = {'Kfold_training': fig1, 'training': fig2}
            save_Qnet(model=Qnet, data_dict=data_dict, basefolder=base_folder,
                      dset_name=dset_name, alg=alg, seed=seed,
                      sim_type=sim_type, figs=fig_dict, split_rng=True)
        if args.show_plots:
            plt.show()
