# -*- coding: utf-8 -*-


import numpy as np
from matplotlib import pyplot as plt
import torch
from tqdm import tqdm
from getdata import Dset, split_dset, Kfold_shuffler, set_seeds, load_reg_result, load_Qnet
from Quantile_loss import augment
from sklearn.metrics import auc, roc_curve
import argparse
from sklearn.model_selection import ParameterGrid
from networks import get_sim_device,load_reg_model,DeepNetBN
from networks import DVIC_net as DVIC_DNNnet
import copy
from DVIC_misc import save_DVICres
from exp_trans import train_exptrans
from Quantile_loss import find_Pg
from sklearn.preprocessing import StandardScaler
from knife_samples_generation import restore_KNIFEnet, sampleKNIFE
from conditional_gaussian import samplecond_gaussian,restore_cond_gaussian

def fpr_at_fixed_tpr(fprs, tprs, thresholds, tpr_level: float = 0.95):
    """Return the first ROC operating point whose TPR reaches ``tpr_level``.

    Parameters
    ----------
    fprs, tprs, thresholds : array-like
        Parallel sequences describing a ROC curve. Numpy arrays, lists and
        tuples are all accepted.
    tpr_level : float
        Minimum true-positive rate that has to be reached.

    Returns
    -------
    tuple
        ``(fpr, tpr, threshold)`` read at the smallest index ``i`` for which
        ``tprs[i] >= tpr_level``.

    Raises
    ------
    ValueError
        If no entry of ``tprs`` reaches ``tpr_level`` (this includes empty
        input).
    """
    # np.asarray makes the comparison work for plain Python sequences too.
    reached = np.asarray(tprs) >= tpr_level
    if not reached.any():
        raise ValueError(f"No threshold allows for TPR at least {tpr_level}.")
    # argmax over a boolean mask gives the index of the first True entry.
    idx = int(np.argmax(reached))
    return fprs[idx], tprs[idx], thresholds[idx]

def get_eps_good(y_true, y_hat, thres, notion, musig=0):
    """Flag the samples whose estimation error exceeds the threshold.

    Despite the function name, the returned mask is True for the
    epsilon-*bad* samples (error larger than the threshold).

    Parameters
    ----------
    y_true, y_hat : array-like
        True values and estimates; they are subtracted element-wise.
    thres : float
        Error threshold (epsilon).
    notion : str
        'abs' compares ``|y_true - y_hat|`` with ``thres``; 'rel' compares it
        with ``thres * |y_hat + musig|``.
    musig : float
        Shift added to ``y_hat`` in the relative notion; unused for 'abs'.

    Returns
    -------
    Boolean mask, True where the sample is epsilon bad.

    Raises
    ------
    ValueError
        If ``notion`` is neither 'abs' nor 'rel'.
    """
    abs_err = np.abs(y_true - y_hat)
    if notion == 'abs':
        return abs_err > thres  # epsilon bad samples
    if notion == 'rel':
        return abs_err > thres * np.abs(y_hat + musig)  # epsilon bad samples
    # The message reports the argument that was actually received.
    raise ValueError('Notion has to be "abs" or "rel", '
                     'not {0}'.format(notion))


# (qx_train[val_index], 
                               # qy_train[val_index],
                               # model_data)

def check_good_bad_frac(train_x, train_y, model_data, sim_conf, label):
    """Tell whether a data set holds both epsilon-good and epsilon-bad samples.

    The regressor described by ``model_data['reg']`` is loaded, evaluated on
    ``train_x`` and its outputs are compared with ``train_y`` through
    ``get_eps_good`` using ``sim_conf['eps']`` and ``sim_conf['notion']``.
    The fraction of epsilon-bad samples is printed, tagged with ``label``.

    NOTE(review): ``device`` is not a parameter; the function reads a
    module-level name that is defined outside this block.

    Returns
    -------
    bool
        True when the fraction of bad samples lies strictly between 0 and 1.
    """
    reg_info = model_data['reg']
    regressor = load_reg_model(n_in=train_x.shape[1], n_out=train_y.shape[1],
                               reg_res=reg_info['res'],
                               reg_par=reg_info['par'],
                               device=device)
    regressor.eval()

    # Shift used by the relative notion: first entry of mean / scale.
    scaling = reg_info['res']['scale']
    musig = (scaling['mu_y_train'] / scaling['scale_y_train'])[0]

    inputs = torch.tensor(train_x, dtype=torch.float32, device=device)
    with torch.no_grad():
        predictions = regressor(inputs).cpu().numpy()

    # True marks the epsilon-bad samples.
    bad_mask = get_eps_good(train_y, predictions, sim_conf['eps'],
                            notion=sim_conf['notion'], musig=musig)

    frac_bad = np.count_nonzero(bad_mask) / train_y.size
    print('Fraction of bad in the {1}: {0:.3f}'.format(frac_bad, label))
    return bool((frac_bad > 0) and (frac_bad < 1))
    

def chooseParameters(loss_mat, opt_type='max'):
    '''Locate the best entry of a (runs x epochs) metric matrix.

    Each row of ``loss_mat`` is one training run (one hyperparameter set) and
    each column one epoch. NaN entries are ignored.

    Parameters
    ----------
    loss_mat : 2D np.ndarray
    opt_type : str
        'max' to look for the largest value, 'min' for the smallest one.

    Returns
    -------
    best_model_index : index of the row holding the best value
    nepochs_optimal : index of the column (epoch) where that row reaches it

    Raises
    ------
    ValueError
        If ``opt_type`` is neither 'max' nor 'min'. The numpy nan-aware
        argmax/argmin also raise ValueError for an all-NaN row.
    '''
    if opt_type == 'max':
        arg_best = np.nanargmax
    elif opt_type == 'min':
        arg_best = np.nanargmin
    else:
        raise ValueError(
            'opt_type has to be "max" or "min", not {0}'.format(opt_type))

    best_epochs = arg_best(loss_mat, axis=1)  # best epoch of every run
    # value each run attains at its own best epoch
    best_per_run = loss_mat[np.arange(loss_mat.shape[0]), best_epochs]
    best_model_index = arg_best(best_per_run)  # run with the best value

    nepochs_optimal = best_epochs[best_model_index]  # epoch of the best run

    return best_model_index, nepochs_optimal



# def chooseParameters(loss_mat, opt_type='max', train_losses=0):
#     '''Given a matrix loss_mat, such that each row is a training run with
#     different hyperparameters, the method returns the row and column with the
#     largest (if opt_type=='max') or smallest value (opt_type=='min')
#     '''

#     if opt_type == 'max':
#         best_losses = np.nanargmax(loss_mat, axis=1)  # index of epoch
#         # check_max = train_losses[np.arange(loss_mat.shape[0]), best_losses]
#         maxval = np.max(loss_mat[np.arange(loss_mat.shape[0]), best_losses])
        
       
#         if np.any(np.isclose(maxval, loss_mat[np.arange(loss_mat.shape[0]), best_losses])):
            
#             # get the index of the one that have the same AUROC 
#             close_inds = np.isclose(maxval, loss_mat[np.arange(loss_mat.shape[0]), best_losses])
#             inds = np.nonzero(close_inds)[0] # absolute index of potential best models
            
#             # index of epoch for the close 
#             best_epochs = np.argmin(train_losses[close_inds], axis=1)  
            
#             # best model relative to close_inds
#             best_model_rel = np.argmin(train_losses[close_inds, best_epochs])
#             # best model relative to all the models
#             best_model_index = inds[best_model_rel]
            
#             nepochs_optimal = best_epochs[best_model_rel] # number of epochs of the best model

#         else:
#             best_model_index = np.argmax(
#                 loss_mat[np.arange(loss_mat.shape[0]), best_losses])  # index of the model with the smallest loss
            
#             nepochs_optimal = best_losses[best_model_index]  # number of epochs of the best model
            
#     elif opt_type == 'min':
#         best_losses = np.nanargmin(loss_mat, axis=1)  # best epoch for each row
#         minval = np.min(loss_mat[np.arange(loss_mat.shape[0]), best_losses])
        
#         if np.any(np.isclose(minval, loss_mat[np.arange(loss_mat.shape[0]), best_losses])):
            
#             close_inds = np.isclose(minval, loss_mat[np.arange(loss_mat.shape[0]), best_losses])
#             inds = np.nonzero(close_inds)[0] # absolute index of potential best models
            
#             best_epochs = np.nanargmin(train_losses[close_inds], axis=1)  # index of epoch
#             # best model relative to close_inds
#             best_model_rel = np.argmin(train_losses[close_inds, best_epochs])
#             # best model relative to all the models
#             best_model_index = inds[best_model_rel]
#             nepochs_optimal = best_epochs[best_model_rel] # number of epochs of the best model
            
#         else:
#             best_model_index = np.nanargmin(
#                 loss_mat[np.arange(loss_mat.shape[0]), best_losses])

#             nepochs_optimal = best_losses[best_model_index]  # number of epochs of the best model

#     return best_model_index, nepochs_optimal


def getGamma(Qnet, reg_network, data, device, sim_conf, gen_type='cat'):
    '''Estimate the gamma of the exponential zoom function and store it.

    ``sim_conf['nunif_gamma']`` samples are drawn from ``Qnet`` for every row
    of ``data`` (on the CPU). Depending on ``sim_conf['samps_d']``:

    * 'Yonly': the flattened samples are returned and nothing else is done.
    * 'Err': the samples are turned into absolute
      (``sim_conf['notion'] == 'abs'``) or relative ('rel') errors with
      respect to the regressor output, reshaped to one row per input and
      handed to ``train_exptrans``. The fitted ``model.gamma`` and the
      convergence flag are written to ``sim_conf['gamma']`` and
      ``sim_conf['gamma_conv']``; the function then returns None.

    Both networks are moved to the CPU while working and moved back to
    ``device`` before leaving, also on the early return and on errors.

    ``gen_type`` is accepted for signature compatibility and is not used.

    Raises
    ------
    ValueError
        If ``sim_conf['samps_d']`` is neither 'Yonly' nor 'Err'.
    Exception
        If ``sim_conf['notion']`` is neither 'abs' nor 'rel' in 'Err' mode.
    '''
    nunif = sim_conf['nunif_gamma']
    samps = sim_conf['samps_d']  # 'Yonly' or 'Err'
    musig = sim_conf['musig']

    if samps not in ('Yonly', 'Err'):
        # Without this check the code below would hit an undefined variable.
        raise ValueError(
            'samps_d has to be "Yonly" or "Err", not {0}'.format(samps))

    data = data.float()
    # stack vertically interleaved copies of the data
    data_stack = data.repeat_interleave(repeats=nunif, dim=0)

    reg_on_cpu = False
    Qnet.eval()
    Qnet.to('cpu')
    try:
        with torch.no_grad():
            unif = torch.rand(data_stack.shape[0], 1, device='cpu',
                              dtype=torch.float32)
            Y = Qnet(augment(data_stack, unif)).flatten()

        if samps == 'Yonly':
            return Y

        with torch.no_grad():
            reg_network.eval()
            reg_network.to('cpu')
            reg_on_cpu = True
            yhat = reg_network(data)  # compute the estimation of each x
        # interleave as the Y samples
        yhat_stack = yhat.repeat_interleave(repeats=nunif, dim=0).flatten()

        if sim_conf['notion'] == 'rel':  # relative error
            out = (Y - yhat_stack).abs() / (yhat_stack + musig).abs()
        elif sim_conf['notion'] == 'abs':  # absolute error
            out = (Y - yhat_stack).abs()
        else:
            raise Exception("Unknown output samples simulation")

        # each row becomes the samples of E|X=x for one x
        out = torch.reshape(out, (-1, nunif))

        (model, train_loss, gamma_vals, conv) = train_exptrans(
            out, lr=sim_conf['lr_gamma'], verb=True, nepochs=500, thres=0.005)
        sim_conf['gamma'] = model.gamma.item()
        sim_conf['gamma_conv'] = conv

        if sim_conf['enable plots']:
            plt.figure()
            ax = plt.subplot(2, 1, 1)
            ax.plot(train_loss)
            ax2 = plt.subplot(2, 1, 2)
            ax2.plot(gamma_vals)
    finally:
        # Hand the networks back on the device the caller expects.
        if reg_on_cpu:
            reg_network.to(device)
        Qnet.to(device)

def getYsamps(Qnet, reg_network, data, device, sim_conf, gen_type='cat', **kwargs):
    '''Generate the pairs of samples that are fed to the DVIC network.

    For every row x of the 2D tensor ``data``, ``sim_conf['nunif']`` pairs of
    samples are drawn from the sampler ``Qnet``. The result has
    ``nunif * data.shape[0]`` rows and 2 columns; the rows belonging to the
    same x are contiguous.

    sim_conf['qtype'] selects the sampler: 'SQR' (quantile network driven by
    uniform noise), 'KNIFE' or 'conditional_gaussian'.

    sim_conf['type'] selects what is returned:

    * 'absE': the raw samples for notion 'abs'; for notion 'rel' they are
      combined with the regressor output as |Y - f(x)| / |f(x) + musig|.
    * 'absErel': the raw samples; only valid together with notion 'rel'.
    * 'outY': the samples are outputs Y and are converted to absolute or
      relative errors with respect to f(x), then passed through the zoom
      function sim_conf['zoom_func'] ('id', or 'exp' which uses
      sim_conf['gamma']).

    gen_type has to be 'cat' (all samples generated in a single batch).
    'for' is an outdated sequential mode that is no longer supported and
    raises immediately.

    kwargs: if ``requires_grad`` is present and truthy, the forward passes of
    the 'SQR' sampler and of the regressor in 'outY' mode are not wrapped in
    ``torch.no_grad()``. All kwargs are forwarded to the 'KNIFE' and
    'conditional_gaussian' samplers.

    Raises Exception/ValueError for unknown or invalid configuration values.
    '''

    print(f"Calling the function getYsamps with gen_type={gen_type}, qtype={sim_conf['qtype']}, type={sim_conf['type']}")

    nunif = sim_conf['nunif']
    musig = sim_conf['musig']
    requires_grad = bool(kwargs.get("requires_grad", False))

    if gen_type == 'for':
        # Fail fast instead of computing samples that are then thrown away.
        raise Exception('This is outdated and does not work for all algorithms')
    if gen_type != 'cat':
        raise Exception('Unknown generation mode {0}'.format(gen_type))

    # stack vertically interleaved copies of the data
    data_stack = data.repeat_interleave(repeats=nunif, dim=0)

    qtype = sim_conf['qtype']
    if qtype == 'SQR':
        Y = torch.zeros(data_stack.shape[0], 2, device=device,
                        dtype=torch.float32)  # output samples
        Qnet.eval()
        # Gradients are tracked only if requested AND already enabled in the
        # calling context; otherwise this behaves like torch.no_grad().
        with torch.set_grad_enabled(requires_grad and torch.is_grad_enabled()):
            for col in range(2):
                unif = torch.rand(data_stack.shape[0], 1, device=device,
                                  dtype=torch.float32)
                Y[:, col] = Qnet(augment(data_stack, unif)).flatten()
    elif qtype == 'KNIFE':
        Y = sampleKNIFE(Qnet, data_stack, nsamps=2, **kwargs)
    elif qtype == 'conditional_gaussian':
        Y = samplecond_gaussian(Qnet, data_stack, 2, **kwargs)
    else:
        raise ValueError('Unknown sampler type {0}'.format(qtype))

    samp_type = sim_conf['type']

    if samp_type == 'absE':
        notion = sim_conf['notion']
        if notion == 'abs':
            # The Qnet in this case is directly the error samples
            return Y
        if notion == 'rel':
            # NOTE(review): the samples are said to be errors already, yet
            # f(x) is subtracted before normalizing -- confirm this is meant.
            with torch.no_grad():
                reg_network.eval()
                yhat = reg_network(data)  # compute the estimation of each x
            # interleave as the Y samples
            yhat_stack = yhat.repeat_interleave(repeats=nunif, dim=0)
            return (Y - yhat_stack).abs() / (yhat_stack + musig).abs()
        raise ValueError('Unknown notion {0} for type absE'.format(notion))

    if samp_type == 'absErel':
        notion = sim_conf['notion']
        if notion == 'rel':
            # The Qnet in this case is directly the relative error samples
            return Y
        if notion == 'abs':
            raise Exception('Invalid combination absErel/abs')
        raise ValueError('Unknown notion {0} for type absErel'.format(notion))

    if samp_type == 'outY':
        reg_network.eval()
        with torch.set_grad_enabled(requires_grad and torch.is_grad_enabled()):
            yhat = reg_network(data)  # compute the estimation of each x
        # interleave as the Y samples
        yhat_stack = yhat.repeat_interleave(repeats=nunif, dim=0)

        if sim_conf['notion'] == 'rel':  # relative error
            out = (Y - yhat_stack).abs() / (yhat_stack + musig).abs()
        elif sim_conf['notion'] == 'abs':  # absolute error
            out = (Y - yhat_stack).abs()
        else:
            raise Exception("Unknown output samples simulation")

        zoom_func = sim_conf['zoom_func']
        if zoom_func == 'id':
            return out
        if zoom_func == 'exp':
            gamma = sim_conf['gamma']
            return (torch.exp(gamma * out) - 1) / gamma
        raise ValueError('Unknown zoom function {0}'.format(zoom_func))

    raise Exception('Unknown samples generation type for the d function')


def Hfunction(DVICnet, Qnet, reg_network, data, sim_conf, device, gen_type='cat', **kwargs):
    '''Computes the diversity metric H which is (16) in the current paper
    version.

    For every row of ``data``, ``sim_conf['nunif']`` sample pairs are obtained
    from ``getYsamps``, passed through ``DVICnet`` and averaged, giving one
    value per row. The samples are processed in chunks of at most 5000 per
    row to bound memory use.

    ``sim_conf['nunif']`` is overwritten with the chunk size while each chunk
    is generated and is always restored before returning, also when a chunk
    raises. Extra keyword arguments are forwarded to ``getYsamps``.

    Raises
    ------
    ValueError
        If ``sim_conf['nunif']`` is not positive.
    '''
    nunif = sim_conf['nunif']
    if nunif <= 0:
        raise ValueError('nunif has to be positive, not {0}'.format(nunif))

    nmax = 5000  # largest number of samples per row generated at once
    n_full, n_rest = divmod(nunif, nmax)
    chunk_sizes = [nmax for _ in range(n_full)]
    if n_rest > 0:
        chunk_sizes.append(n_rest)

    total = None
    try:
        for ns in chunk_sizes:
            # getYsamps reads the number of samples from sim_conf
            sim_conf['nunif'] = ns
            Y = getYsamps(Qnet, reg_network, data, device, sim_conf,
                          gen_type=gen_type, **kwargs)

            # h(d(y1, f(x)), d(y2, f(x))) summed over the samples of each x
            dy1y2 = DVICnet(Y, device)
            chunk_sum = torch.reshape(dy1y2, (-1, ns)).sum(dim=1)

            total = chunk_sum if total is None else total + chunk_sum
    finally:
        sim_conf['nunif'] = nunif

    return total / nunif

def compute_DVICloss(dy1y2, labels, lbd, device):
    """Weighted DVIC loss from the per-sample diversity values.

    Parameters
    ----------
    dy1y2 : per-sample diversity values
    labels : per-sample labels, 1 for epsilon-bad and 0 for epsilon-good
    lbd : weight of the epsilon-good term (the bad term gets ``1 - lbd``)
    device : not used by this function

    Returns
    -------
    (loss, loss_pos, loss_neg) where ``loss_pos`` is the mean over the
    epsilon-good samples, ``loss_neg`` the mean over the epsilon-bad ones
    (a zero tensor replaces the mean of an empty group) and
    ``loss = lbd * loss_pos - (1 - lbd) * loss_neg``.
    """
    n_bad = labels.sum()
    has_bad = n_bad > 0
    has_good = n_bad < labels.numel()

    # mean over the epsilon bad samples (this term should be maximized)
    loss_neg = dy1y2[labels == 1].mean() if has_bad else torch.tensor(0)
    # mean over the epsilon good samples
    loss_pos = dy1y2[labels == 0].mean() if has_good else torch.tensor(0)

    loss = lbd * loss_pos - (1 - lbd) * loss_neg
    return loss, loss_pos, loss_neg


def train_DVIC(trainDS, valDS, device, sim_conf, model_data, testDS=None, get_best=False):
    """Train the diversity coefficient (DVIC) network.

    A sampler network (selected by ``sim_conf['qtype']``) and the regressor
    are restored from ``model_data``; a fresh DVIC network is then trained
    with Adam for ``sim_conf['nepochs']`` epochs. After every epoch the
    validation set is scored with ``Hfunction`` and the AUROC and the FPR at
    TPR ``sim_conf['tpr_opt']`` are computed; the DVIC weights of the epoch
    with the best validation AUROC and of the epoch with the best validation
    FPR are kept.

    Parameters
    ----------
    trainDS, valDS : datasets exposing ``.dset`` (inputs) and ``.out``
        (labels, 1 = epsilon bad).
    device : device on which the networks and batches are placed.
    sim_conf : dict with the simulation configuration. Keys read here:
        'nepochs', 'batch_size', 'lambda', 'l_rate', 'wd', 'qtype',
        'zoom_func', 'n_hlayers', 'n_inner_neurons', 'nunif', 'tpr_opt'.
        ``sim_conf['gamma']`` is set to -1 here when the zoom function is not
        'exp'; otherwise ``getGamma`` is called to set it.
    model_data : dict with the 'reg' and 'qmod' sub-dictionaries ('res',
        'par') describing the trained regressor and sampler.
    testDS : test dataset; it is accessed when ``get_best`` is True.
    get_best : if True, the test AUROC/FPR are also computed every epoch.

    Returns
    -------
    ``data_out`` (dict with the histories and best parameters) if
    ``get_best`` is False, otherwise ``(data_out, Qnet)``.
    """
    epochs = sim_conf['nepochs']
    batch_size = sim_conf['batch_size']
    lbd = sim_conf['lambda']
    lr = sim_conf['l_rate']
    wd = sim_conf['wd']

    # Restore the network that generates the samples.
    # NOTE(review): for any other qtype, Qnet stays undefined and the
    # Qnet.to(device) call below fails.
    if sim_conf['qtype'] == 'SQR':
        # Load the quantile network to generate the samples
        # (one extra input is reserved for the quantile level)
        Qnet = DeepNetBN(n_in=trainDS.dset.shape[1] + 1,
                         n_out=1,
                         n_hlayers=model_data['qmod']['res']['conf']['n_hlayers'],
                         n_inner_neurons=model_data['qmod']['res']['conf']['n_inner_neurons'])
    
        Qnet.load_state_dict(torch.load(model_data['qmod']['par'], map_location=device))
    elif sim_conf['qtype'] == 'KNIFE':
        Qnet = restore_KNIFEnet(device, model_data['qmod']['res'], 
                                model_data['qmod']['par'], 
                                in_dim=trainDS.dset.shape[1]) 
        
    elif sim_conf['qtype'] == 'conditional_gaussian':
        Qnet = restore_cond_gaussian(device, model_data['qmod']['res'],
                                     model_data['qmod']['par'], 
                                     in_dim=trainDS.dset.shape[1])
        
    Qnet.to(device)

    # Load the regressor
    reg_network = load_reg_model(n_in=trainDS.dset.shape[1],
                                 n_out=1,
                                 reg_res=model_data['reg']['res'],
                                 reg_par=model_data['reg']['par'],
                                 device=device)
    reg_network.eval()
    reg_network.to(device)

    # The exponential zoom function needs gamma, which getGamma stores in
    # sim_conf; otherwise a placeholder value is stored.
    if sim_conf['zoom_func'] == 'exp':
        getGamma(Qnet, reg_network, trainDS.dset, device, sim_conf, gen_type='cat')
    else:
        sim_conf['gamma'] = -1
    
    # Create the DVIC network (takes a pair of samples, outputs one value)
    DVICnet = DVIC_DNNnet(n_in=2,
                      n_out=1,
                      n_hlayers=sim_conf['n_hlayers'],
                      n_inner_neurons=sim_conf['n_inner_neurons'])

    DVICnet.to(device)
    optimizer = torch.optim.Adam(params=DVICnet.parameters(),
                                 lr=lr, weight_decay=wd)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    
    # Pg = torch.tensor(sim_conf['Pg']) # load Pg values
    # train_pos_samps = torch.bernoulli(1-Pg)
    # trainDS = Dset(trainDS.dset, train_pos_samps)
    
    train_loader = torch.utils.data.DataLoader(trainDS, batch_size=batch_size,
                                            shuffle=True, num_workers=0,
                                            pin_memory=True)
        
    # Containers for the per-epoch metrics
    loss_history = [] # losses during training
    loss_history_pos = []
    loss_history_neg = []
    true_test_auc = [] # auc for testing
    true_test_fpr = [] # fpr at tpr 95% for testing
    auc_history = [] # to store the auc of the validation loss
    val_loss_history = [] # to store the validation loss
    fpr_history = [] # to store the fpr at tpr 0.95 for validation
    best_auc = -0.1 # initialized at the worst possible value
    best_fpr = 1.1
    test_roc = [] # to store the test ROC
    pbar = tqdm(range(epochs))
    # NOTE(review): the best_* variables are only assigned inside the loop,
    # so with epochs == 0 building data_out below fails.
    for epoch in pbar:

        
        # These two tensors are overwritten by compute_DVICloss in the batch
        # loop below.
        loss_pos = torch.zeros(1, device=device)
        loss_neg = torch.zeros(1, device=device)
        # Sums over the batches of the epoch
        total_pos = 0
        total_neg = 0
        total_loss = 0
        
        DVICnet.train()
        Qnet.eval()
        


        for data, labels in train_loader:
            optimizer.zero_grad()

            data = data.to(device)  # x samples
            labels = labels.flatten() # =1 eps bad

            # Compute the diversity metric
            dy1y2 = Hfunction(DVICnet, Qnet, reg_network, data, sim_conf, device)

            # compute the loss of the good and bad terms or 0 if there are no
            # elements in one of the sets
            (loss, loss_pos, loss_neg) = compute_DVICloss(dy1y2, labels, lbd, device)
            
            # loss += loss_tmp
            
            total_loss += loss.item()
            total_pos += loss_pos.item() #pos samples are epsilon good ones.
            total_neg += loss_neg.item()

            loss.backward()
            optimizer.step()

        # scheduler.step()
        loss_history.append(total_loss)
        loss_history_pos.append(total_pos)
        loss_history_neg.append(total_neg)

        # del batch_x, batch_y
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Validation
        DVICnet.eval()
        Qnet.eval()
        with torch.no_grad():
            # With up to 5000 samples per input the whole validation set is
            # scored in one call; above that it is scored in mini-batches.
            if sim_conf['nunif']<= 5000:              
                # get score of the validation set
                scores = Hfunction(DVICnet, Qnet, reg_network,
                                   valDS.dset.to(device).float(), sim_conf,
                                   device).cpu().numpy()
            else:
                scores = []
                val_loader = torch.utils.data.DataLoader(valDS, batch_size=batch_size,
                                                        shuffle=False, num_workers=0,
                                                        pin_memory=True)
                
                for data, labels in val_loader:
                    scores.append(Hfunction(DVICnet, Qnet, reg_network,
                                       data.to(device), sim_conf,
                                       device).cpu().numpy())
                scores = np.concatenate(scores)


            tmp_nc = torch.count_nonzero(valDS.out)  # number of bad in val dataset

            #  check that there are both negative and positive samples
            if (tmp_nc.item() > 0) and (tmp_nc.item() < valDS.dset.shape[0]):

                labs = valDS.out.flatten().cpu().numpy()
                (fprs, tprs, thrs) = roc_curve(labs, scores)
                roc_auc = auc(fprs, tprs)

                # get the best fpr for a certain tpr
                (fpr, _, _) = fpr_at_fixed_tpr(fprs, tprs, thrs, sim_conf['tpr_opt'])
            else:
                # Only one class present: store the worst possible values
                roc_auc = 0
                fpr = 1

            # Compute the validation loss for reference
            # (scores is a numpy array at this point)
            (loss_val, _, _) = compute_DVICloss(scores, valDS.out.flatten(),
                                                lbd, device)

        auc_history.append(roc_auc)
        val_loss_history.append(loss_val.item())
        fpr_history.append(fpr)

        # Keep the parameters of the epoch with the best validation AUC.
        # `loss` is the loss of the last training batch of this epoch.
        if best_auc < roc_auc:
            best_ind = epoch
            best_auc = roc_auc
            best_params = copy.deepcopy(DVICnet.state_dict())
            best_loss = loss

        # Same bookkeeping for the best (smallest) validation FPR
        if fpr < best_fpr:
            best_ind_fpr = epoch
            best_fpr = fpr
            best_params_fpr = copy.deepcopy(DVICnet.state_dict())
            best_loss_fpr = loss

        # # If get_best=True, the AUROC of the test set is also computed
        if get_best:
            DVICnet.eval()
            Qnet.eval()

            with torch.no_grad():
                if sim_conf['nunif']<= 5000:              
                    # get score of the test set
                    scores = Hfunction(DVICnet, Qnet, reg_network,
                                       testDS.dset.to(device).float(),
                                       sim_conf, device).cpu().numpy()
                else:
                    scores = []
                    test_loader = torch.utils.data.DataLoader(testDS, batch_size=batch_size,
                                                            shuffle=False, num_workers=0,
                                                            pin_memory=True)
                    
                    for data, labels in test_loader:
                        scores.append(Hfunction(DVICnet, Qnet, reg_network,
                                           data.to(device), sim_conf,
                                           device).cpu().numpy())
                    scores = np.concatenate(scores)


                tmp_nc = testDS.out.sum()  # number of bad in the test set
                labels = testDS.out.flatten()

                # ROC metrics need both classes; otherwise store the worst
                # possible values
                if (tmp_nc>0) and (tmp_nc < labels.size()[0]):
                    (fprs, tprs, thrs) = roc_curve(labels, scores)
                    roc_auc = auc(fprs, tprs)
                    true_test_auc.append(roc_auc)
                    test_roc.append((fprs, tprs))
                    (fpr, _, _) = fpr_at_fixed_tpr(fprs, tprs, thrs, sim_conf['tpr_opt'])
                    true_test_fpr.append(fpr)
                else:
                    true_test_auc.append(0)
                    true_test_fpr.append(1)
                    test_roc.append((0, 0))

        # NOTE(review): when get_best is True, `fpr` has been overwritten by
        # the test value above whenever the test set holds both classes, so
        # the 'Val FPR' shown below is then the test FPR.
        if get_best:
            pbar.set_description('Total Loss  {0:.3f} - '.format(total_loss) +
                                 'Pos Loss {0:.3f} - '.format(total_pos) +
                                 'Neg Loss {0:.3f} - '.format(total_neg) +
                                 'Val AUC: {0:.3f} '.format(auc_history[-1])+
                                 'Val FPR: {0:.3f} '.format(fpr)+
                                 'Test AUC: {0:.3f}'.format(true_test_auc[-1]))
        else:
            pbar.set_description('Total Loss  {0:.3f} - '.format(total_loss) +
                             'Pos Loss {0:.3f} - '.format(total_pos) +
                             'Neg Loss {0:.3f} - '.format(total_neg) +
                             'Val AUC: {0:.3f} '.format(auc_history[-1])+
                             'Val FPR: {0:.3f} '.format(fpr))
                             # 'Test AUC: {0:.3f}'.format(true_test_auc[-1]))

    data_out = {'loss history': loss_history,  # training losses
                'loss history pos': loss_history_pos,
                'loss history neg': loss_history_neg,
                'val loss history': val_loss_history,  # validation loss
                'auc history': auc_history, # auc and related auc metrics
                'best params auroc': best_params,
                'best loss auroc': best_loss,
                'epoch best auc': best_ind,
                'best params fpr': best_params_fpr, # validation fpr and metrics
                'best loss fpr': best_loss_fpr,
                'fpr history': fpr_history,
                'epoch best fpr': best_ind_fpr,
                'test rocs': test_roc
                }

    if get_best:
        # The best parameters are returned in data_out; they are not loaded
        # back into DVICnet here.
        #DVICnet.load_state_dict(best_params)
        data_out['test auc'] = true_test_auc
        data_out['test fpr'] = true_test_fpr
        return data_out, Qnet
    return data_out


def train_single(train_x, train_y, val_x, val_y, test_x, test_y, sim_conf,
                 model_data, device, get_best=False):
    """Label the data sets as epsilon good/bad and run the DVIC training.

    The regressor in ``model_data['reg']`` is evaluated on every set and
    ``get_eps_good`` turns its errors into labels (True = epsilon bad) using
    ``sim_conf['eps']`` and ``sim_conf['notion']``. The labelled sets are
    wrapped in ``Dset`` objects and passed to ``train_DVIC``.

    Side effects on ``sim_conf``: 'musig' is set to the first entry of
    mu_y_train / scale_y_train of the regressor scaling and 'lambda' is set
    to 0.5.

    Parameters:
    ----------
    train_x, train_y: (np.ndarray) training data set
    val_x, val_y: (np.ndarray) validation set
    test_x, test_y: (np.ndarray) test set, only used if get_best is True
    sim_conf: dictionary with the configuration of the simulation
    model_data: dictionary which contains subdictionaries with the models
    required for training, in this case 'reg' and 'qmod'
    device: "cuda", "cuda:0", "cpu", etc
    get_best: (bool) if True the test set is labelled too and passed to the
    training routine

    Returns:
    --------
    If get_best is False: (data_out, 0, 0)
    If get_best is True: (data_out, testDS, Qnet)
    where data_out is the dictionary returned by train_DVIC.
    """

    # Load the regressor
    reg_network = load_reg_model(n_in=train_x.shape[1], n_out=train_y.shape[1],
                                 reg_res=model_data['reg']['res'],
                                 reg_par=model_data['reg']['par'],
                                 device=device)
    reg_network.eval()

    # Shift factor for the relative notion, not used for abs
    scaling = model_data['reg']['res']['scale']
    musig = (scaling['mu_y_train'] / scaling['scale_y_train'])[0]
    sim_conf['musig'] = musig

    def _eps_bad_labels(x, y):
        # Run the regressor on x and flag the epsilon bad samples.
        with torch.no_grad():
            y_hat = reg_network(torch.tensor(x, dtype=torch.float32,
                                             device=device)).cpu().numpy()
        return get_eps_good(y, y_hat, sim_conf['eps'],
                            notion=sim_conf['notion'], musig=musig)

    train_pos_samps = _eps_bad_labels(train_x, train_y)

    # Relative weight of each term in the loss (fixed, not data dependent)
    sim_conf['lambda'] = 0.5

    val_pos_samps = _eps_bad_labels(val_x, val_y)

    trainDS = Dset(train_x, train_pos_samps)
    valDS = Dset(val_x, val_pos_samps)

    if not get_best:
        data_out = train_DVIC(trainDS, valDS, device, sim_conf, model_data)
        return data_out, 0, 0

    # Final simulation: the test set is labelled with the same notion as the
    # training and validation sets.
    test_pos_samps = _eps_bad_labels(test_x, test_y)
    testDS = Dset(test_x, test_pos_samps)

    (data_out, Qnet) = train_DVIC(trainDS, valDS, device, sim_conf, model_data,
                                  testDS, get_best)
    return data_out, testDS, Qnet
        # return data_out['val loss history'], data_out['auc history'], data_out['fpr history']


def main_loop(model_data, sim_conf, device, Kfold):
    """Build the DVIC train/validation/test splits and launch the training.

    The regressor data splits stored in ``model_data`` are reloaded, the
    random number generator states saved with the quantile model are
    restored, the DVIC training pool is assembled according to
    ``sim_conf['DVIC_dset']`` and the resulting splits are handed to
    ``train_single``.

    Parameters:
    -----------
    model_data: dictionary which contains subdictionaries with the models
    required for training, in this case 'reg' and 'qmod'. The keys read here
    are model_data['reg']['res']['dataset'] (regressor data splits),
    model_data['qmod']['res']['Random'] (saved RNG states) and, when
    sim_conf['type'] == 'absErel', model_data['qmod']['res']['conf'].
    sim_conf: dictionary with the configuration used for the simulation.
    The keys read here are 'fold_splits', 'nepochs', 'stop_split' (a
    percentage, it is divided by 100), 'DVIC_dset' and 'type'.
    NOTE: when sim_conf['type'] == 'absErel' this dictionary is modified in
    place: the keys 'scaler_err' and 'musig' are added.
    device: 'cuda', 'cpu', 'cuda:0', 'cuda:1' etc
    Kfold (bool): if True performs K-fold validation
    (K is in sim_conf['fold_splits'] parameter). If False, performs a
    single run to train the final model with the passed hyperparameters.

    Returns:
    --------
    If Kfold=True and sim_conf['fold_splits'] > 1:
        (avg_val_loss, avg_val_auc, avg_val_fpr), the per-epoch validation
        loss, AUROC and FPR histories averaged over the folds.
    If Kfold=True and sim_conf['fold_splits'] == 1, or if Kfold=False:
        (data_out, dset, testDS, Qnet), where data_out, testDS and Qnet are
        the outputs of train_single (called with get_best=True) and dset is
        a dictionary with the arrays of the train/val/test splits that were
        used ('train x', 'val x', 'train y', 'val y', 'test x', 'test y').

    Raises:
    -------
    ValueError: if Kfold=True and sim_conf['fold_splits'] < 1.
    Exception: if sim_conf['DVIC_dset'] is not a known dataset use.

    Notes:
    ------
    The module-level variable ``seed`` (the loop variable of the script
    entry point) is read by this function, so it has to be defined before
    the function is called.
    """

    nfold_splits = sim_conf['fold_splits']
    nepochs = sim_conf['nepochs']
    stop_frac = sim_conf['stop_split']/100

    # Reload the splits used to train the regressor. The test part is kept
    # apart from the rest and is always used as the test set below.
    reg_dataset = model_data['reg']['res']['dataset']
    x_train = reg_dataset['train x']
    y_train = reg_dataset['train y']
    x_test = reg_dataset['test x']
    y_test = reg_dataset['test y']
    x_val = reg_dataset['stop x']
    y_val = reg_dataset['stop y']
    x_sup = reg_dataset['sup x']
    y_sup = reg_dataset['sup y']

    # Load the random number states that existed after the quantiles were
    # trained (`seed` is the module-level variable, see Notes).
    saved_random = model_data['qmod']['res']['Random']
    rng = set_seeds(seed)  # rng is a numpy random generator
    rng.bit_generator.state = saved_random['rng_state']
    torch.set_rng_state(saved_random['torch_state'].to('cpu'))
    if torch.cuda.is_available():
        torch.cuda.set_rng_state(saved_random['cuda_state'].to('cpu'))

    # Now define the training/test sets for the DVIC coefficients.
    dvic_dset = sim_conf['DVIC_dset']
    if dvic_dset == "sup":
        # use only the supplementary data to train DVIC
        qx_train = x_sup
        qy_train = y_sup
    elif dvic_dset == "sup_stop":
        # use the supplementary data and the stopping set
        qx_train = np.vstack((x_sup, x_val))
        qy_train = np.vstack((y_sup, y_val))
    elif dvic_dset == "all":
        # regressor training set, supplementary data and stopping set
        qx_train = np.vstack((x_train, x_sup, x_val))
        qy_train = np.vstack((y_train, y_sup, y_val))
    elif dvic_dset == 'stop':
        # use only the stopping set
        qx_train = x_val
        qy_train = y_val
    elif dvic_dset == 'regtrain':
        # regressor training set and stopping set
        qx_train = np.vstack((x_train, x_val))
        qy_train = np.vstack((y_train, y_val))
    else:
        raise Exception('Unknown dataset use')

    if sim_conf['type'] == 'absErel':
        # The error scaler is not re-fitted here: it is rebuilt from the
        # mean/scale stored in the configuration of the quantile model.
        qmod_conf = model_data['qmod']['res']['conf']
        scaler_err = StandardScaler()
        scaler_err.mean_ = qmod_conf['err_dict']['mu_err_train']
        scaler_err.scale_ = qmod_conf['err_dict']['scale_err_train']
        sim_conf['scaler_err'] = scaler_err
        sim_conf['musig'] = qmod_conf['musig']

    # The test set is always the test set of the regressor
    qx_test = x_test
    qy_test = y_test

    check_good_bad_frac(qx_train, qy_train, model_data, sim_conf, 'train set')

    if Kfold:
        # Perform cross-validation over the selected hyperparameters
        if nfold_splits < 1:
            # Without this check the loop below never runs and the function
            # fails later with an unrelated NameError.
            raise ValueError("sim_conf['fold_splits'] must be at least 1 "
                             "when Kfold=True, got {0}".format(nfold_splits))

        single_split = nfold_splits == 1
        if single_split:
            # If only one fold is drawn, the split is re-drawn until
            # check_good_bad_frac accepts the validation part (presumably
            # it has to contain both good and bad points).
            good_val = False
            while not good_val:
                kf = Kfold_shuffler(n_splits=nfold_splits,
                                    dsize=qx_train.shape[0],
                                    rng=rng, val_frac=stop_frac)
                val_index = kf.splits[1]
                good_val = check_good_bad_frac(qx_train[val_index],
                                               qy_train[val_index],
                                               model_data, sim_conf,
                                               'val set')
        else:
            kf = Kfold_shuffler(n_splits=nfold_splits,
                                dsize=qx_train.shape[0],
                                rng=rng)
            # Per-fold validation histories, only needed for the averages
            val_loss_matrix = np.zeros((nfold_splits, nepochs))
            val_auc_matrix = np.zeros((nfold_splits, nepochs))
            val_fpr_matrix = np.zeros((nfold_splits, nepochs))

        # With a single split the trained model is the final one, so the
        # test set is requested from train_single as well.
        get_best = single_split

        for fold_cnt in range(nfold_splits):
            (train_index, val_index) = kf.get_split_inds(fold_cnt)

            # Separate the train for this split (data is already normalized
            # with the training set of the regressor)
            qx_train_fold = qx_train[train_index]
            qy_train_fold = qy_train[train_index]

            qx_val_fold = qx_train[val_index]
            qy_val_fold = qy_train[val_index]

            if fold_cnt == 0:
                print('Train size: {0}, '.format(qx_train_fold.shape[0]) +
                      'Val size {0},'.format(qx_val_fold.shape[0]) +
                      'Test size {0}\n'.format(qx_test.shape[0]))

            (data_out, testDS, Qnet) = train_single(qx_train_fold,
                                                    qy_train_fold,
                                                    qx_val_fold,
                                                    qy_val_fold,
                                                    qx_test, qy_test,
                                                    sim_conf, model_data,
                                                    device, get_best)

            if not single_split:
                val_loss_matrix[fold_cnt] = data_out['val loss history']
                val_auc_matrix[fold_cnt] = data_out['auc history']
                val_fpr_matrix[fold_cnt] = data_out['fpr history']

        if not single_split:
            # Average the histories over the folds (one value per epoch)
            avg_val_loss = np.mean(val_loss_matrix, axis=0)
            avg_val_auc = np.mean(val_auc_matrix, axis=0)
            avg_val_fpr = np.mean(val_fpr_matrix, axis=0)
            return avg_val_loss, avg_val_auc, avg_val_fpr

        dset = {'train x': qx_train_fold, 'val x': qx_val_fold,
                'train y': qy_train_fold, 'val y': qy_val_fold,
                'test x': qx_test, 'test y': qy_test}

        return data_out, dset, testDS, Qnet

    # For a single run we separate the training set to have a stopping set.
    # The split is re-drawn until check_good_bad_frac accepts the
    # validation part.
    good_val = False
    while not good_val:
        (qx_train_tmp, qx_val,
         qy_train_tmp, qy_val) = split_dset(qx_train, qy_train,
                                            test_frac=stop_frac,
                                            rng=rng, ndraw=nfold_splits)

        good_val = check_good_bad_frac(qx_val,
                                       qy_val,
                                       model_data, sim_conf, 'val set')

    qx_train = qx_train_tmp
    qy_train = qy_train_tmp

    print('Train size: {0}, Val size {2}, Test size {1},'.format(qx_train.shape[0],
                                                                 qx_test.shape[0],
                                                                 qx_val.shape[0]))

    (data_out, testDS, Qnet) = train_single(qx_train, qy_train,
                                            qx_val, qy_val,
                                            qx_test, qy_test,
                                            sim_conf, model_data,
                                            device, get_best=True)

    dset = {'train x': qx_train, 'val x': qx_val, 'train y': qy_train,
            'val y': qy_val, 'test x': qx_test,
            'test y': qy_test}

    return data_out, dset, testDS, Qnet


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--seed", type=int, nargs='*', default=[1], #np.arange(1,16),
                        help='Seed to be used for simulation')

    parser.add_argument("--dset", type=str, default='boston',
                        help='Name of the dataset or "all" for all of them '+
                        'or small (<1000) or large (>1000)')

    parser.add_argument("--save", type=bool, default=True,
                        help="Generate dump files to disk")

    parser.add_argument("--dir", type=str, default="tmp_DVIC5",
                        help='Base directory where to store results')

    parser.add_argument("--nepochs", type=int, default=20,
                        help='Number of epochs to run the quantile algorithm')

    parser.add_argument("--device", type=str, default='cuda',
                        help='Possibilities: auto, cpu, cuda, cuda:0, cuda:1,etc')

    parser.add_argument("--notion", type=str, default='rel',
                        help='abs or rel')

    parser.add_argument("--DVIC_dset", type=str, default="regtrain",
                        help='sup: train DVIC with the supplementary set'+
                        'sup_stop: use the supplementary set and the stopping set')

    parser.add_argument("--stopfrac", type=int, nargs='*', default=[20],
                        help='Fraction of the training set used a stopping'+
                        'set after K-fold validation')

    parser.add_argument("--nfolds", type=int, default=1, help='Number of '+
                        'folds used for training.')

    parser.add_argument('--batchsize', type=int, default=64, help='Batch size'+
                        'for training')

    # parser.add_argument("--sampsD", type=str, default='Err',
    #                     help="Yonly means that the d function will be trained with samples of Y|X"+\
    #                         'Err means that it will be trained with samples of Y-f(x)|X=x')
    
    parser.add_argument("--type", type=str, default='absErel',
                        help='Type of simulation "outY" or "absE" or "absErel"')

    parser.add_argument("--zoom_func", type=str, default='id',
                        help="If zoom_func='id' then the algorithm of the paper"+
                        "is implemented. If zoom_func='exp' then before learning"+
                        "the h function in the paper, we apply 1-exp{-alpha d}"+
                        "where alpha is an hyperparameter")
    
    parser.add_argument("--nunifs", type=int, default=10000, 
                        help="Number of uniforms used for the Monte Carlo"+
                        "estimation of the H function")
    
    parser.add_argument("--wd", type=float, nargs='*', default=[0],
                        help="Weight decay")
    parser.add_argument("--lr", type=float, nargs='*', default=[1e-3])#0.0005, 0.001, 0.005 ])

    parser.add_argument("--nlayers", type=int, default=4,
                        help="Number of internal layers for the DVIC network")


    # Regressor parameters
    parser.add_argument("--regdir", type=str, default='runs_reg_900',
                        help="Directory where the regressor results are kept")
    # Qfunctions
    parser.add_argument("--qdir", type=str, default='tmp_Test')#'runs_ensembles_900_regtrain3') #, runs_KNIFE_900_regtrain4_256")
                        #"runs_SQR_900_regtrain")
    
    parser.add_argument("--lr_gamma", type=float, default=0.5)

    parser.add_argument('--tpr_opt', type=float, default=0.9,
                        help="Used for the system which optimized the FPR at"+
                        " tpr_opt value of TPR")
    
    parser.add_argument("--qtype", type=str, default='KNIFE',
                        help="SQR: will use pinball loss," +
                        "Interval will use the interval loss to estimate only a CI" +
                        "KNIFE: will estimate the CDF as gaussian mixture"+
                        "conditional_gaussian, is the ensembles method")
    
    dset_reuseQ = "regtrain"  # data set reuse for the Qnetwork
    # sim_type = 'outY'

    

    args = parser.parse_args()
    DVIC_dset = args.DVIC_dset  # "regtest_val"  # data set use for the estimation of DVIC
    sim_type = args.type
    enable_save = args.save
    nepochs = args.nepochs
    reg_dir = args.regdir
    qfolder = args.qdir
    notion = args.notion
    device_t = args.device
    bsize = args.batchsize
    nfolds = args.nfolds
    stop_split = args.stopfrac
    alg = args.qtype
    zoom_func = args.zoom_func
    lr = args.lr
    wd = args.wd
    lr_gamma = args.lr_gamma
    seeds = args.seed
    nunif = args.nunifs 
    nlayers = args.nlayers
    
    if args.dset == 'all':
        dsets = ['yacht', 'boston', 'energy', 'concrete','wine',
                 'kin8nm', 'power', 'naval']
    elif args.dset == 'small':
        dsets = ['yacht', 'boston', 'energy', 'concrete']
    elif args.dset == 'large':
        dsets = ['wine', 'kin8nm', 'power', 'naval']
    else: # single dataset
        dsets = [args.dset]

    base_folder = args.dir  # base directory where things are stored

    '''Define parameter grid for hyperparameter tuning'''
    param_for_grid = {'l_rate': lr, 
                      'n_hlayers': [nlayers],
                      'n_inner_neurons': [64],
                      'stop_split': stop_split,
                      'wd': wd}
    
    if zoom_func == 'exp':
        param_for_grid['zoom_func'] = ['exp']
    elif zoom_func == 'id':
        param_for_grid['zoom_func'] = ['id']
    elif zoom_func == 'both':
        param_for_grid['zoom_func']= ['exp', 'id']

    param_grid = ParameterGrid(param_for_grid)

    device = get_sim_device(device_t) # get device cuda or cpu for sim.

    for seed in seeds:
        print('Using seed {0}'.format(seed))
        print('Q training set: {0}'.format(dset_reuseQ))
        print('DVIC training set: {0}'.format(DVIC_dset))
        print('Using device: {0}'.format(device))
        for dset_name in dsets:
            if notion == 'rel':
                if dset_name=='energy':
                    eps_vals = np.array([0.02,0.035, 0.05])
                elif dset_name =='naval':
                    eps_vals = np.array([0.0015,0.00175, 0.002])
                elif dset_name == "wine":
                    eps_vals = np.array([0.05, 0.075, 0.1])
                elif dset_name == 'power':
                    eps_vals = np.array([0.01,  0.0125, 0.015])
                elif dset_name == 'yacht':
                    eps_vals = np.array([0.1, 0.15, 0.2])
                elif dset_name == 'kin8nm':
                    eps_vals = np.array([0.1,  0.2,  0.3])
                elif dset_name == 'boston':
                    eps_vals = np.array([0.1, 0.15, 0.2])
                elif dset_name == 'concrete':
                    eps_vals = np.array([.1, .15, 0.2])
                else:
                    eps_vals = np.array([0.1, .15, 0.2, 0.25])
            elif notion =='abs':
                if dset_name=='yacht':
                    eps_vals = np.array([0.0250, 0.05, 0.075])
                elif dset_name =='boston':
                    eps_vals = np.array([0.25, 0.3, 0.35])
                elif dset_name == 'concrete':
                    eps_vals = np.array([0.2, 0.5, 0.6])
                elif dset_name == 'energy':
                    eps_vals = np.array([0.025, 0.0625, 0.1])
                elif dset_name == 'wine':
                    eps_vals = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
                elif dset_name == 'kin8nm':
                    eps_vals = np.array([0.6, 0.7, 0.8])
                elif dset_name == 'power':
                    eps_vals = np.array([.3, .35, 0.4])
                elif dset_name =='naval':
                    eps_vals = np.array([0.175, 0.2, 0.225])


            print('Data set is: ' + dset_name)
    
            full_dict_list = []
    
            for eps in eps_vals:
                print('\n\nUsing eps={0:.4f}'.format(eps))
                sim_conf = {'batch_size': bsize, 'nepochs': nepochs,
                            'device': device_t, 'type': sim_type,
                            'fold_splits': nfolds, 'dset_reuseQ': dset_reuseQ,
                            'eps': eps, 'DVIC_dset': DVIC_dset, 'notion': notion,
                            'qtype': alg, #'stop_split': stop_split,
                            #'samps_d': sampsD,
                            'nunif_gamma': 2000,
                            'enable plots': False,
                            'nunif': nunif,
                            'lr_gamma': lr_gamma,
                            'tpr_opt': args.tpr_opt}
    
                # Setup variable to store the results    
                data_dict = {'conf': sim_conf, 'dset_name': dset_name,
                             'type': ''}  # stores the output result to be dumped to the hard drive
    
    
                '''Hyperparameter training of the network'''
                val_aucs = []
                val_fprs = []
                train_losses = []
                results = {}  # holds the training hyperparameter configurations
                data_out_list = [] # only used if sim_conf['fold_splits']=1
                if len(param_grid) > 1:
                    for grid_cnt, item in enumerate(param_grid):
                        for key in item.keys():
                            sim_conf[key] = item[key]
        
                        # restore data just in case
                        (reg_res, reg_par) = load_reg_result(dset_name, seed,
                                                             reg_type='DNN',
                                                             basedir=reg_dir,
                                                             rng_split=True)
        
                        # restore the quantile functions
                        (Q_res, Q_pars) = load_Qnet(basefolder=qfolder,
                                                    dset_name=dset_name,
                                                    alg=alg, seed=seed,
                                                    sim_type=sim_type,
                                                    reuse=dset_reuseQ,
                                                    split_rng=True)
        
                        # get the models required for training DVIC
                        model_data = {}
                        model_data['reg'] = {'res': reg_res, 'par': reg_par}
                        model_data['qmod'] = {'res': Q_res, 'par': Q_pars}
        
                        # give a round of training
                        if sim_conf['fold_splits'] > 1:
                            (train_loss, val_auc, val_fpr) = main_loop(model_data=model_data,
                                                              sim_conf=sim_conf,
                                                              device=device,
                                                              Kfold=True)
            
                            val_aucs.append(val_auc)
                            val_fprs.append(val_fpr)
                            train_losses.append(train_loss)
                            results[grid_cnt] = item
                        elif sim_conf['fold_splits'] == 1:
                            (data_out, dset, testDS, Qnet) = main_loop(model_data=model_data,
                                                              sim_conf=sim_conf,
                                                              device=device,
                                                              Kfold=True)
                            
                            val_aucs.append(data_out['auc history'])
                            val_fprs.append(data_out['fpr history'])
                            train_losses.append(data_out['val loss history'])
                            results[grid_cnt] = item.copy()
                            data_out_list.append(data_out.copy()) # backup the model
                            # data_out = {'loss history': loss_history,  # training losses
                            #             'loss history pos': loss_history_pos,
                            #             'loss history neg': loss_history_neg,
                            #             'val loss history': val_loss_history,  # validation loss
                            #             'auc history': auc_history, # auc and related auc metrics
                            #             'best params auroc': best_params,
                            #             'best loss auroc': best_loss,
                            #             'epoch best auc': best_ind,
                            #             'best params fpr': best_params_fpr, # validation fpr and metrics
                            #             'best loss fpr': best_loss_fpr,
                            #             'fpr history': fpr_history,
                            #             'epoch best fpr': best_ind_fpr,
                            #             'test rocs': test_roc
                            #             }
                            # data_out['test auc'] = true_test_auc
                            # data_out['test fpr'] = true_test_fpr
                            
                    # Store the results of training:
                    data_dict['Train'] = {}
                    data_dict['Train']['train_loss'] = train_losses
                    data_dict['Train']['val_aucs'] = val_aucs
                    data_dict['Train']['val_fprs'] = val_fprs
                    data_dict['Train']['params'] = results  # each of the parameters
    
    
                    '''Train the best model for the AUC and for FPR '''
        
                    ''' First test the AUC '''
                    # Find the best parameters for testing the AUC
                    (best_par_ind_auc, nepochs_opt_auc) = chooseParameters(np.array(val_aucs),
                                                                           opt_type='max')
                                                                           # train_losses=np.array(train_losses))
        
                    # Find the best parameters for testing the FPR
                    (best_par_ind_fpr, nepochs_opt_fpr) = chooseParameters(np.array(val_fprs),
                                                                           opt_type='min')
                                                                           # train_losses=np.array(train_losses))
                        
                    # Find the best parameters for the validation loss
                    (best_par_ind_loss, nepochs_opt_loss) = chooseParameters(np.array(train_losses),
                                                                           opt_type='min')
                                                                           # train_losses=np.array(train_losses))
                    
                    # Now train the model again with the best hyperparameters
                    print('Best parameters for AUROC are: ')
                    print(results[best_par_ind_auc])
        
                    print('Best parameters for FPR are:')
                    print(results[best_par_ind_fpr])
                    
                    print('Best parameters for validation loss are:')
                    print(results[best_par_ind_loss])
        
                    
                    # Train the best model for AUC
                    for key in results[best_par_ind_auc].keys():
                        sim_conf[key] = results[best_par_ind_auc][key]


                    # Store the retraining results.
                    data_dict['Test AUC'] = {}
                    data_dict['Test AUC']['sim_conf'] = sim_conf.copy()
                    
                    if sim_conf['fold_splits'] > 1: # need to retrain
                        (data_out_AUC, dset, testDS, Qnet) = main_loop(model_data = model_data,
                                                                   sim_conf=sim_conf,
                                                                   device=device,
                                                                   Kfold=False)
                    elif sim_conf['fold_splits'] == 1: # no need to retrain
                        data_out_AUC = data_out_list[best_par_ind_auc].copy()
        
                    # Train the best model for FPR
                    for key in results[best_par_ind_fpr].keys():
                        sim_conf[key] = results[best_par_ind_fpr][key]
        
                    data_dict['Test FPR'] = {}
                    data_dict['Test FPR']['sim_conf'] = sim_conf.copy()
                    
                    if sim_conf['fold_splits'] > 1: # need to retrain
                        (data_out_fpr, _, _, _) = main_loop(model_data = model_data,
                                                                       sim_conf=sim_conf,
                                                                       device=device,
                                                                       Kfold=False)
                    elif sim_conf['fold_splits'] == 1: # no need to retrain
                        data_out_fpr = data_out_list[best_par_ind_fpr].copy()

                        
                else:
                    
                
                    for key in param_grid[0].keys():
                        sim_conf[key] = param_grid[0][key]
                    
                    # restore data just in case
                    (reg_res, reg_par) = load_reg_result(dset_name, seed,
                                                         reg_type='DNN',
                                                         basedir=reg_dir,
                                                         rng_split=True)
    
                    # restore the quantile functions
                    (Q_res, Q_pars) = load_Qnet(basefolder=qfolder,
                                                dset_name=dset_name,
                                                alg=alg, seed=seed,
                                                sim_type=sim_type,
                                                reuse=dset_reuseQ,
                                                split_rng=True)
    
                    # get the models required for training DVIC
                    model_data = {}
                    model_data['reg'] = {'res': reg_res, 'par': reg_par}
                    model_data['qmod'] = {'res': Q_res, 'par': Q_pars}
                    
                    data_dict['Test AUC'] = {}
                    data_dict['Test AUC']['sim_conf'] = sim_conf.copy()
                    
                    data_dict['Test FPR'] = {}
                    data_dict['Test FPR']['sim_conf'] = sim_conf.copy()
                    
                    (data_out_AUC, dset, testDS, Qnet) = main_loop(model_data = model_data,
                                                                   sim_conf=sim_conf,
                                                                   device=device,
                                                                   Kfold=False)

                    data_out_fpr = data_out_AUC
                    best_par_ind_fpr = 0
                    best_par_ind_auc = 0
                    nepochs_opt_auc = -1
                    nepochs_opt_fpr = -1
                    results = [param_grid[0]]
                            
                print('\n')
                print('Best epoch por AUROC {0} -  '.format(data_out_AUC['epoch best auc'])+
                      'Best validation AUC {0:.4f} - '.format(data_out_AUC['auc history'][data_out_AUC['epoch best auc']])+
                      'Best Test AUC {0:.4f}'.format(data_out_AUC['test auc'][data_out_AUC['epoch best auc']]))
                      # 'Best Test AUC val loss {0:.4f}'.format(data_out_AUC['test auc'][np.argmin(data_out_AUC['val loss history'])]))
                # print('Best epoch val loss', np.argmin(data_out_AUC['val loss history']))
                print('Best epoch auc', data_out_AUC['epoch best auc'])
                
                print('Best epoch por FPR {0} - '.format(data_out_fpr['epoch best fpr'])+
                      'Best validation FPR {0:.4f} - '.format(data_out_fpr['fpr history'][data_out_fpr['epoch best fpr']])+
                      'Best Test FPR {0:.4f}'.format(data_out_fpr['test fpr'][data_out_fpr['epoch best fpr']]))
    
                # Save the results for the test AUC:
                data_dict['Test AUC']['best'] = results[best_par_ind_auc].copy()
                for key in data_out_AUC.keys():
                    data_dict['Test AUC'][key] = data_out_AUC[key]
                data_dict['Test AUC']['best_index par'] = best_par_ind_auc
                data_dict['Test AUC']['best_epoch train'] = nepochs_opt_auc
                data_dict['Test AUC']['splits'] = dset
    
               
                fig4 = plt.figure(figsize=(5, 5))
                ax = plt.subplot(1, 1, 1)
                best_ep = data_dict['Test AUC']['epoch best auc']
                (fprs, tprs) = data_dict['Test AUC']['test rocs'][best_ep]
                ax.plot(fprs, tprs)
                ax.grid(True)
                ax.set_xlabel('FPR')
                ax.set_ylabel('TPR')
                if type(fprs)!=int:
                    ax.set_title('AUC = {0:.4f}'.format(auc(fprs, tprs)))
    
    
                data_dict['Test FPR']['best'] = results[best_par_ind_fpr]
                for key in data_out_fpr.keys():
                    data_dict['Test FPR'][key] = data_out_fpr[key]
    
                data_dict['Test FPR']['best_index par'] = best_par_ind_fpr
                data_dict['Test FPR']['best_epoch'] = nepochs_opt_fpr
                data_dict['Test FPR']['splits'] = dset
    
                # Drop every figure registered with pyplot so the fixed figure
                # numbers used below (1, 2, 3) refer to fresh figures.
                # NOTE(review): this also closes fig4, which is created just before
                # and still handed to save_DVICres afterwards -- confirm intended.
                plt.close('all')

                # Figure 1: per-epoch curves from the hyper-parameter search
                # (presumably one curve per grid configuration -- confirm).
                fig1 = plt.figure(1)
                loss_ax, auc_ax, fpr_ax = (fig1.add_subplot(3, 1, row)
                                           for row in (1, 2, 3))
                several_configs = len(param_grid) > 1

                # Top panel: training losses, drawn only when the grid has more
                # than one entry.
                if several_configs:
                    for loss_curve in train_losses:
                        loss_ax.plot(loss_curve)
                loss_ax.grid(True)
                loss_ax.set_xlabel('Epochs')
                loss_ax.set_ylabel('Avg. Training Losses')

                # Middle panel: validation AUROC curves plus a diamond marker at
                # the selected (configuration, epoch) pair.
                if several_configs:
                    for auc_curve in val_aucs:
                        auc_ax.plot(auc_curve)
                    auc_ax.plot(nepochs_opt_auc,
                                val_aucs[best_par_ind_auc][nepochs_opt_auc], 'd')
                auc_ax.grid(True)
                auc_ax.set_xlabel('Epochs')
                auc_ax.set_ylabel('Avg. Validation AUROC')

                # Bottom panel: validation FPR curves.
                # NOTE(review): unlike the two panels above this one is drawn even
                # for a single-entry grid -- confirm val_fprs always exists.
                for fpr_curve in val_fprs:
                    fpr_ax.plot(fpr_curve)
                fpr_ax.grid(True)
                fpr_ax.set_xlabel('Epochs')
                fpr_ax.set_ylabel('Avg. Validation FPR at 95% TPR')
                plt.tight_layout()
    
                # Figure 2: histories recorded while testing the model selected
                # for best AUROC, laid out on a 2x2 grid.
                fig2 = plt.figure(2)
                # (grid position, ((data_out_AUC key, legend label), ...))
                auc_panels = (
                    (1, (('loss history', 'Train Loss total'),
                         ('val loss history', 'Val Loss total'))),
                    (2, (('loss history pos', 'Train Loss pos'),
                         ('loss history neg', 'Train Loss neg'))),
                    (3, (('auc history', 'AUROC for validation'),
                         ('test auc', 'AUROC for test'))),
                    (4, (('fpr history', 'FPR at 95% TPR for validation'),
                         ('test fpr', 'FPR at 95% TPR for test'))),
                )
                for position, curves in auc_panels:
                    panel = fig2.add_subplot(2, 2, position)
                    for history_key, legend_label in curves:
                        panel.plot(data_out_AUC[history_key], label=legend_label)
                    if position == 3:
                        # Only the AUROC panel gets a fixed [0, 1] y-range.
                        panel.set_ylim([0, 1])
                    panel.set_xlabel('Epochs')
                    panel.legend()
                    panel.grid(True)
                plt.suptitle('Testing for model trained for best AUROC')
                plt.tight_layout()
    
                # Figure 3: histories recorded while testing the model selected
                # for best FPR, one panel per metric (losses, AUROC, FPR).
                fig3 = plt.figure(3)

                # Top panel: total / positive / negative training losses.
                ax = plt.subplot(3, 1, 1)
                for history_key, legend_label in (
                        ('loss history', 'Train Loss total'),
                        ('loss history pos', 'Train Loss pos'),
                        ('loss history neg', 'Train Loss neg')):
                    ax.plot(data_out_fpr[history_key], label=legend_label)
                ax.set_xlabel('Epochs')
                ax.legend()
                ax.grid(True)

                # Middle panel: AUROC on validation and test.
                ax2 = plt.subplot(3, 1, 2)
                ax2.plot(data_out_fpr['auc history'], label='AUROC for validation')
                ax2.plot(data_out_fpr['test auc'], label='AUROC for test')
                ax2.set_xlabel('Epochs')
                ax2.legend()
                ax2.grid(True)

                # Bottom panel: FPR on validation and test.
                # The labels previously read "90% TPR"; every other label for the
                # same 'fpr history' / 'test fpr' fields in this file says 95%,
                # which is also the default level of fpr_at_fixed_tpr.
                # NOTE(review): confirm 95% is the level actually used upstream.
                ax3 = plt.subplot(3, 1, 3)
                ax3.plot(data_out_fpr['fpr history'], label='FPR at 95% TPR for validation')
                ax3.plot(data_out_fpr['test fpr'], label='FPR at 95% TPR for test')
                ax3.set_xlabel('Epochs')
                ax3.legend()
                ax3.grid(True)
                plt.suptitle('Testing for model trained for best FPR')
                plt.tight_layout()
    
                # Append a snapshot of this run's results to the running list.
                # NOTE(review): .copy() is shallow, so the nested 'Test AUC' /
                # 'Test FPR' entries are shared with data_dict -- confirm data_dict
                # is rebuilt for every run.
                full_dict_list.append(data_dict.copy())

                if enable_save:
                    # Figures are written only here, keyed by the current eps
                    # value (4 decimals) plus a per-figure suffix.
                    eps_tag = 'eps{0:.4f}'.format(eps)
                    fig_dict = {
                        eps_tag + '_training': fig1,
                        eps_tag + '_val_auc': fig2,
                        eps_tag + '_val_fpr': fig3,
                        eps_tag + '_test_roc': fig4,
                    }
                    # The whole accumulated list is passed on every call, not just
                    # the newest entry.
                    save_DVICres(data_dict=full_dict_list,
                                 basefolder=base_folder,
                                 dset_name=dset_name,
                                 alg='DVICmat',  # nbins_str+'_'+str(args.nbins_exp),
                                 seed=seed,
                                 notion=notion,
                                 sim_type=sim_type,
                                 figs=fig_dict,
                                 figsonly=False)

    # Figure 5: overlay every stored test ROC curve of the AUROC-selected model.
    # NOTE(review): data_dict is read after the loops above have finished, so
    # presumably only the last processed run is shown -- confirm.
    plt.figure(5)
    # Each entry is indexed as [0] -> FPR values, [1] -> TPR values.
    # Assumes 'test rocs' is an index-ordered sequence -- TODO confirm.
    for roc_entry in data_dict['Test AUC']['test rocs']:
        plt.plot(roc_entry[0], roc_entry[1])
