# -*- coding: utf-8 -*-


import numpy as np
from matplotlib import pyplot as plt
import torch
from sklearn.preprocessing import scale, minmax_scale
from sklearn.preprocessing import StandardScaler
import pickle
import os
import torch
          
            
def get_rng_dict(seed, rng):
    """Snapshot the random number generator states.

    Parameters
    ----------
    seed: seed value recorded alongside the generator states.
    rng: Numpy random generator whose ``bit_generator.state`` is stored.

    Returns
    -------
    dict with the keys 'seed', 'rng_state' and 'torch_state', plus
    'cuda_state' when CUDA is available.
    """
    snapshot = dict(seed=seed)
    snapshot['rng_state'] = rng.bit_generator.state
    snapshot['torch_state'] = torch.get_rng_state()

    # The CUDA generator state is only queried when a GPU is usable.
    if not torch.cuda.is_available():
        return snapshot

    snapshot['cuda_state'] = torch.cuda.get_rng_state()
    return snapshot


def set_seeds(seed):
    """Seed PyTorch globally and build a seeded Numpy random generator.

    Calls torch.manual_seed and torch.cuda.manual_seed with 'seed', switches
    cuDNN to deterministic mode and returns a Numpy random generator created
    from the same seed, to be used in simulations.
    """
    # Seed the PyTorch generators in turn (CPU call first, then CUDA call).
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True

    return np.random.default_rng(seed)


def get_uci_data(dset_name):
    """Load a UCI dataset stored as a whitespace-separated text file.

    The file is read from ``UCI_Datasets/<dset_name>.txt``. The last column
    is taken as the target and all the previous columns as the inputs.

    Parameters
    ----------
    dset_name (str): name of the dataset (file name without '.txt').

    Returns
    -------
    x_al (np.ndarray): inputs of shape (n_samples, n_columns - 1).
    y_al (np.ndarray): targets of shape (n_samples, 1).
    """
    fname = os.path.join("UCI_Datasets", "{}.txt".format(dset_name))
    # ndmin=2 keeps the array two-dimensional even for a single-row file,
    # which would otherwise come back 1-D and break the column slicing.
    data = np.loadtxt(fname, ndmin=2)
    x_al = data[:, :-1]
    y_al = data[:, -1].reshape(-1, 1)
    return x_al, y_al

def get_toy_data(Nsamps, mode='data', rng=None, scale=None):
    """Generate the 1D toy regression problem or its noiseless ground truth.

    The inputs are drawn from a two-component Gaussian mixture and the
    targets follow ``0.5 * sin(5 x) + 0.5`` plus Gaussian noise.

    Parameters
    ----------
    Nsamps (int): number of samples to draw (only used when mode='data').
    mode (str): one of
        'data'   -> noisy samples (X, Y), each of shape (Nsamps, 1);
        'teo'    -> the noiseless function evaluated on 100 points of
                    [-0.75, 2], with x and y passed through
                    scale[0].transform and scale[1].transform;
        'phiteo' -> the tuple (true regression function, noise std).
    rng: seed or Numpy random generator, forwarded to np.random.default_rng.
    scale: pair (input transformer, output transformer), each exposing a
        'transform' method; required when mode='teo'.

    Returns
    -------
    Depends on 'mode', see above.

    Raises
    ------
    ValueError: if 'mode' is not recognised, or if mode='teo' is requested
        without providing 'scale'.
    """
    mu0, sigma0 = 0, 0.25     # first mixture component of the inputs
    mu1, sigma1 = 1.2, 0.25   # second mixture component of the inputs
    p1 = 0.6                  # probability of the second component
    mun, sigman = 0, 0.05     # additive noise on the targets

    def phiteo(x):
        """True (noiseless) regression function."""
        return np.sin(5*x)*0.5 + 0.5

    if mode == 'data':
        rng = np.random.default_rng(rng)
        # The draw order (mask, component 0, component 1, noise) fixes the
        # random stream for a given seed, so it must not be rearranged.
        M = rng.uniform(0, 1, Nsamps) < p1
        x0 = rng.normal(mu0, sigma0, Nsamps)
        x1 = rng.normal(mu1, sigma1, Nsamps)
        X = np.where(M, x1, x0)

        # Noisy training targets.
        Y = phiteo(X) + rng.normal(mun, sigman, X.size)
        return X.reshape(-1, 1), Y.reshape(-1, 1)

    if mode == 'teo':
        if scale is None:
            raise ValueError("mode='teo' requires 'scale' = "
                             "(input transformer, output transformer)")
        x = np.linspace(-0.75, 2, 100)
        y = phiteo(x)
        x = scale[0].transform(x.reshape(-1, 1))
        y = scale[1].transform(y.reshape(-1, 1))
        return x, y

    if mode == 'phiteo':
        return (phiteo, sigman)

    raise ValueError("Unknown mode {0!r}; expected 'data', 'teo' or "
                     "'phiteo'".format(mode))


def split_dset(x, y, test_frac, rng, ret_inds=False, ndraw=1):
    """Randomly split (x, y) into training and test subsets.

    A permutation of the sample indexes is drawn from 'rng'. Its first
    round(n_samples * test_frac) entries index the test set and the
    remaining ones the training set. When ndraw > 1, ndraw + 1 permutations
    are drawn in total and only the last one is used; otherwise a single
    permutation is drawn.

    Returns
    -------
    x_train, x_test, y_train, y_test and, when ret_inds is True, also the
    tuple (training indexes, test indexes).
    """
    n_samples = x.shape[0]

    # Number of permutations pulled from the generator; only the final one
    # is kept. ndraw > 1 costs one extra draw on top of the ndraw requested.
    n_perms = ndraw + 1 if ndraw > 1 else 1
    for _ in range(n_perms):
        shuffled = rng.permutation(n_samples)

    # The leading test_frac fraction of the permutation is the test set.
    n_test = int(np.round(n_samples * test_frac))
    test_idx, train_idx = shuffled[:n_test], shuffled[n_test:]

    parts = (x[train_idx], x[test_idx], y[train_idx], y[test_idx])
    if not ret_inds:
        return parts
    return parts + ((train_idx, test_idx),)


class Kfold_shuffler():
    """Random K-fold (or single hold-out) index splitter.

    Parameters
    ----------
    n_splits (int): number of folds. With n_splits == 1 a single
        train/validation split is built from 'val_frac' instead.
    dsize (int): number of samples to split.
    rng: Numpy random generator used to draw the permutation.
    val_frac (float): fraction of the samples held out for validation. Only
        used when n_splits == 1, where it must be > 0.

    Attributes
    ----------
    self.n_splits (int): number of folds.
    self.splits (list of np.ndarray): for n_splits > 1, the indexes of each
        fold; for n_splits == 1, [training indexes, validation indexes].

    Raises
    ------
    ValueError: if n_splits < 1, or if n_splits == 1 and val_frac is not > 0.
    """
    def __init__(self, n_splits, dsize, rng, val_frac=0):

        ind_perm = rng.permutation(dsize)  # generate a permutation for the splits

        if n_splits > 1:
            # array_split hands the dsize % n_splits leftover samples to the
            # first folds, one extra sample each.
            self.splits = np.array_split(ind_perm, n_splits)
        elif n_splits == 1:
            if not val_frac > 0:
                raise ValueError('If n_splits=1, val_frac > 0')
            cut_ind = int(dsize * val_frac)
            # first are training indexes, second are the validation.
            self.splits = [ind_perm[cut_ind:], ind_perm[:cut_ind]]
        else:
            # Fail here rather than leaving the object half-initialised.
            raise ValueError('n_splits must be >= 1, got {0}'.format(n_splits))

        self.n_splits = n_splits

    def get_split_inds(self, n_split):
        """Return (train indexes, test indexes) for fold 'n_split'.

        With n_splits == 1 the stored train/validation split is returned and
        'n_split' is ignored. Negative fold numbers count from the last fold.

        Raises
        ------
        IndexError: if 'n_split' is outside [-n_splits, n_splits).
        """
        if self.n_splits == 1:
            train_idx, test_idx = self.splits
            return train_idx, test_idx

        if not -self.n_splits <= n_split < self.n_splits:
            raise IndexError('fold {0} out of range for {1} splits'.format(
                n_split, self.n_splits))
        # Normalise negative fold numbers so that the held-out fold is
        # really left out of the training indexes below.
        n_split %= self.n_splits

        test_idx = self.splits[n_split]
        train_idx = np.concatenate([fold for cnt, fold in enumerate(self.splits)
                                    if cnt != n_split])
        return train_idx, test_idx


class Dset():
    """Basic dataset class to train a model with"""
    def __init__(self, XY, z):
        """Returns a dataset used to train a model. If data is 1D it will be
        reshaped to have 2 dimensions. Inputs with more than 2 dimensions are
        flattened to (n_samples, -1).

        Parameters
        ----------
        XY (np.ndarray, torch.Tensor or sequence): input data of shape
            (n_samples, n_features)
        z (np.ndarray, torch.Tensor or sequence): output data of shape
            (n_samples, n_features)

        Attributes
        ----------
        self.dset (torch.tensor): input data of shape (n_samples, n_features)
        self.out (torch.tensor): output data of shape (n_samples, n_features).
        """
        # Tensors are stored as given; anything else is converted (copied).
        self.dset = XY if isinstance(XY, torch.Tensor) else torch.tensor(XY)
        self.out = z if isinstance(z, torch.Tensor) else torch.tensor(z)

        # The rank checks use the converted tensors so that plain sequences,
        # which have no 'ndim' attribute, are accepted as well.
        if self.dset.ndim == 1:
            self.dset = self.dset.reshape((-1, 1))
        elif self.dset.ndim > 2:
            self.dset = self.dset.reshape((self.dset.shape[0], -1))

        if self.out.ndim == 1:
            self.out = self.out.reshape((-1, 1))

    def __getitem__(self, ind):
        """Return the (input, output) pair at 'ind', both cast to float."""
        data = self.dset[ind].float()
        lab = self.out[ind].float()
        return (data, lab)

    def __len__(self):
        """Number of samples, taken from the output tensor."""
        return len(self.out)
    

def load_regression_results(fname, rng_split=False):
    """Load a results dictionary stored with torch.save.

    Parameters
    ----------
    fname (str): path of the results file. '.pickle' is appended when the
        path has no extension.
    rng_split (bool): when True, the companion '<path>_rngstate.pickle' file
        (path without extension) is loaded as well, always onto the CPU.

    Returns
    -------
    res_dict, or the tuple (res_dict, rng_dict) when rng_split is True.
    """
    path, ext = os.path.splitext(fname)
    if not ext:
        fname = fname + '.pickle'

    # Storage placement: CPU without CUDA, the only GPU when exactly one is
    # visible, and the locations recorded in the file otherwise.
    if not torch.cuda.is_available():
        target = 'cpu'
    elif torch.cuda.device_count() == 1:
        target = 'cuda:0'
    else:
        target = None
    res_dict = torch.load(fname, map_location=target)

    if not rng_split:
        return res_dict

    rng_dict = torch.load(path + '_rngstate.pickle', map_location='cpu')
    return res_dict, rng_dict


def load_reg_result(dset_name, seed, reg_type='DNN', basedir='runs',
                    rng_split=True):
    """Load the stored results of a regressor.

    Parameters
    ----------
    dset_name (str): name of the dataset the regressor was trained on.
    seed: seed used in the run (part of the folder and file names).
    reg_type (str): regressor type (part of the folder name).
    basedir (str): base folder where the runs are stored.
    rng_split (bool): True when the random state was stored in a separate
        '_rngstate.pickle' file; it is then loaded and put back under the
        'Random' key of the results dictionary.

    Returns
    -------
    (res_dict, reg_par) when reg_type == 'DNN', where reg_par is the path of
    the '.pt' file with the network parameters; res_dict alone otherwise.
    """
    # Directory to find the regression models
    model_folders = basedir + '/' + dset_name + '_regressor_seed{0}_{1}'.format(seed, reg_type)
    # regressor configuration; the extension is given explicitly so that a
    # dot inside 'dset_name' is not mistaken for a file extension.
    fname = model_folders + '/' + dset_name + '_reg_seed{0}.pickle'.format(seed)

    if rng_split:
        res_dict, rng_dict = load_regression_results(fname, rng_split)
        res_dict['Random'] = rng_dict
    else:
        # Without the split the loader returns the dictionary alone.
        res_dict = load_regression_results(fname, rng_split)

    # regressor best model parameters
    if reg_type == 'DNN':
        reg_par = model_folders + '/' + dset_name + '_regNet_seed{0}.pt'.format(seed)
        return res_dict, reg_par
    return res_dict

def save_reg_result(model, basefolder, dset_name, seed, data_dict, figs,
                    reg_type='DNN', split_rng=True):
    """Store a trained regressor, its results, figures and a text summary.

    Everything is written to
    '<basefolder>/<dset_name>_regressor_seed<seed>_<reg_type>'.

    Parameters
    ----------
    model: trained model; its state_dict() is stored in a '.pt' file.
    basefolder (str): base folder where all the runs are stored.
    dset_name (str): name of the dataset the regressor was trained on.
    seed: seed used in the run (part of the folder and file names).
    data_dict (dict): results of the training. The keys 'conf', 'Train' and
        'Test MAE' are read for the text summary, and 'Random' when
        split_rng is True. The dictionary itself is left untouched.
    figs (dict): figures to store as pdf, keyed by the file name suffix.
    reg_type (str): regressor type (part of the folder name).
    split_rng (bool): when True, data_dict['Random'] is written to a separate
        '_rngstate.pickle' file and left out of the main results file.

    Returns
    -------
    folder (str): folder where the files were written.
    """
    folder = basefolder + '/' + dset_name + '_regressor_seed{0}_{1}'.format(seed, reg_type)

    # makedirs also creates 'basefolder' when it does not exist yet.
    os.makedirs(folder, exist_ok=True)

    torch.save(model.state_dict(), folder + '/' + dset_name + '_regNet_seed{0}.pt'.format(seed))

    fname = folder + '/' + dset_name + '_reg_seed{0}.pickle'.format(seed)

    # Store the random number generator state separately
    to_save = data_dict
    if split_rng:
        # Shallow copy, so that the caller's dictionary keeps its 'Random' entry.
        to_save = data_dict.copy()
        rng_state = to_save.pop('Random')
        rng_fname = folder + '/' + dset_name + '_reg_seed{0}_rngstate.pickle'.format(seed)
        torch.save(rng_state, f=rng_fname,
                   pickle_protocol=pickle.HIGHEST_PROTOCOL)

    torch.save(to_save, f=fname, pickle_protocol=pickle.HIGHEST_PROTOCOL)

    for key, fig in figs.items():
        fig.savefig(os.path.join(folder, dset_name + '_' + key + '.pdf'),
                    format='pdf')

    conf = data_dict['conf']
    txt_dump = "Model type: {0}\nData set: {1}\n".format(reg_type, dset_name)
    txt_dump += "Seed: {0} - Epochs: {1}\n".format(seed, conf['nepochs'])
    txt_dump += 'Splits: Train: {0}, Stop: {1}, Sup {2}\n'.format(
        conf['train_split'], conf['stop_split'], conf['sup_split'])
    txt_dump += "Model parameters: {0}\n".format(data_dict['Train']['best pars'])
    # NOTE(review): '{0:4f}' is a minimum width of 4, not 4 decimals
    # ('.4f'); kept as is so the log format does not change.
    txt_dump += "Test set MAE: {0:4f}".format(data_dict['Test MAE'])

    with open(os.path.join(folder, "output_log.txt"), "w") as f:
        f.write(txt_dump)

    return folder

def load_Qnet(basefolder, dset_name, alg, seed, sim_type, reuse, split_rng):
    """Load the stored results of a quantile network.

    The results are read from the folder
    '<basefolder>/<dset_name>_<alg>_<sim_type>_<reuse>_seed<seed>'.

    Parameters
    ----------
    basefolder (str): base folder where all the Q networks are stored.
    dset_name (str): name of the dataset where the Q network was trained.
    alg (str): algorithm name (part of the folder name).
    seed: seed used in the run (part of the folder and file names).
    sim_type (str): simulation type (part of the folder and file names).
    reuse: split-use tag (part of the folder name).
    split_rng (bool): True when the random state was stored in a separate
        '_rngstate.pickle' file; it is then loaded and put back under the
        'Random' key of the results dictionary.

    Returns
    -------
    Q_res (dict): results of the simulation.
    Q_pars (str): path of the '.pt' file with the network parameters.
    """
    fname = "{0}_{1}_{4}_{2}_seed{3}".format(dset_name, alg, reuse, seed, sim_type)
    folder = os.path.join(basefolder, fname)

    # The extension is given explicitly so that a dot inside 'dset_name' or
    # 'sim_type' is not mistaken for a file extension by the loader.
    qres_file = os.path.join(
        folder, dset_name + '_{1}_seed{0}.pickle'.format(seed, sim_type))

    if split_rng:
        Q_res, rng_dict = load_regression_results(qres_file, split_rng)
        Q_res['Random'] = rng_dict
    else:
        # Without the split the loader returns the dictionary alone.
        Q_res = load_regression_results(qres_file, split_rng)

    Q_pars = os.path.join(folder, dset_name + '_QNet_seed{0}_{1}.pt'.format(seed, sim_type))

    return Q_res, Q_pars




def save_Qnet(model, data_dict, basefolder, dset_name, alg, seed, sim_type,
              figs, split_rng=True):
    """Saves the result of a quantile network which is used to estimate the probability of being epsilon good.

    Parameters
    ----------
    model: the Quantile network estimated (a NN)
    data_dict: dictionary with results of the training. The keys 'conf',
        'Train' and 'reg' are read for the text summary, and 'Random' when
        split_rng is True. The dictionary itself is left untouched.
    basefolder: base folder where all the Q networks are stored
    dset_name: name of the dataset where the Q network was trained
    alg: algorithm used to estimate the quantiles, for example SQR
    seed: to generate all the data
    sim_type: over what the quantiles were computed, for example, the input Y or the absolute error |y-f(x)|
    figs: a dictionary of figures to store
    split_rng: when True, data_dict['Random'] is written to a separate
        '_rngstate.pickle' file and left out of the main results file.

    Returns
    -------
    Nothing but stores in a pickle file the results of the simulations and in a .pt file the parameters of the network.
    This information can be used to estimate Pg afterwards."""

    # folder name where to store the results.
    fname = "{0}_{1}_{4}_{2}_seed{3}".format(dset_name, alg,
                                             data_dict['conf']['split_use'],
                                             seed, sim_type)
    folder = os.path.join(basefolder, fname)

    # makedirs also creates 'basefolder' when it does not exist yet.
    os.makedirs(folder, exist_ok=True)

    # Pytorch model
    model_fname = dset_name + '_QNet_seed{0}_{1}.pt'.format(seed, sim_type)
    torch.save(model.state_dict(), os.path.join(folder, model_fname))

    # Simulation results
    # Store the random number generator state separately
    to_save = data_dict
    if split_rng:
        # Shallow copy, so that the caller's dictionary keeps its 'Random' entry.
        to_save = data_dict.copy()
        rng_state = to_save.pop('Random')
        rng_fname = os.path.join(
            folder,
            dset_name + '_{0}_seed{1}'.format(sim_type, seed) + '_rngstate.pickle')
        torch.save(rng_state, f=rng_fname, pickle_protocol=pickle.HIGHEST_PROTOCOL)

    tmp_name = os.path.join(
        folder, dset_name + '_{1}_seed{0}.pickle'.format(seed, sim_type))
    torch.save(to_save, tmp_name, pickle_protocol=pickle.HIGHEST_PROTOCOL)

    # Figures
    for key, fig in figs.items():
        fig.savefig(os.path.join(folder, dset_name + '_' + key + '.pdf'), format='pdf')

    txt_dump = "Model type: {0}\nData set: {1}\n".format(alg, dset_name)
    txt_dump += "Seed: {0} - Epochs: {1}\n".format(seed,
                                                 data_dict['conf']['nepochs'])
    txt_dump += "Model parameters: {0}\n".format(data_dict['Train']['best pars'])

    reg_res = data_dict['reg']['res']
    txt_dump += '\n\n---- Regressor parameters ----\n'
    txt_dump += 'Splits: Train: {0}, Stop: {1}, Sup {2}\n'.format(
        reg_res['conf']['train_split'],
        reg_res['conf']['stop_split'],
        reg_res['conf']['sup_split'])
    txt_dump += "Regressor parameters: {0}\n".format(reg_res['Train']['best pars'])
    txt_dump += "Epochs: {0}\n".format(reg_res['conf']['nepochs'])
    # NOTE(review): '{0:4f}' is a minimum width of 4, not 4 decimals
    # ('.4f'); kept as is so the log format does not change.
    txt_dump += "Test set MAE: {0:4f}".format(reg_res['Test MAE'])

    with open(os.path.join(folder, "output_log.txt"), "w") as f:
        f.write(txt_dump)

def split_rngstate(fname):
    """Move the random state of a stored results dictionary to its own file.

    The dictionary is read with pickle from '<path>.pickle' and re-written
    unchanged, as a backup, to '<path>_bkp.pickle'. Its 'Random' entry is
    then removed and stored with torch.save in '<path>_rngstate.pickle', and
    the remaining dictionary is stored with torch.save over '<path>.pickle'.

    Parameters:
    ---------------
    fname: full path of the pickle file with or without .pickle as extension.
        Whatever extension is given, the file read is '<path>.pickle'.

    """
    stem = os.path.splitext(fname)[0]
    main_file = stem + '.pickle'

    # Read the results as they were originally stored.
    with open(main_file, "rb") as src:
        contents = pickle.load(src)

    # Keep a copy of the complete dictionary before touching it.
    with open(stem + '_bkp.pickle', "wb") as bkp:
        pickle.dump(contents, bkp, protocol=pickle.HIGHEST_PROTOCOL)

    # Take the random state out of the dictionary and store it on its own.
    random_state = contents.pop('Random')
    torch.save(random_state, f=stem + '_rngstate.pickle',
               pickle_protocol=pickle.HIGHEST_PROTOCOL)

    # Overwrite the original file with what is left of the dictionary.
    torch.save(contents, f=main_file, pickle_protocol=pickle.HIGHEST_PROTOCOL)
    
