import pickle
from getdata import load_regression_results
import os
import torch

def load_DVICres(basefolder, dset_name, alg, seed, notion, sim_type, reuse,
                 rng_split):
    """Load the results of a DVIC simulation run.

    The run folder is
    ``<basefolder>/<dset_name>_<alg>_<sim_type>_<notion>_<reuse>_seed<seed>``
    and the results file inside it is
    ``<dset_name>_<sim_type>_seed<seed>.pickle``.

    Parameters
    ----------
    basefolder: base folder containing the run folders
    dset_name: name of the dataset (prefix of both folder and file name)
    alg: algorithm tag used in the folder name, for example SQR
    seed: seed of the run
    notion: tag used in the folder name
    sim_type: simulation-type tag used in both the folder and file name
    reuse: tag in the folder-name slot that ``save_DVICres`` fills with
        ``data_dict[0]['conf']['DVIC_dset']``
    rng_split: forwarded unchanged to ``load_regression_results``

    Returns
    -------
    Whatever ``getdata.load_regression_results`` returns for that file.
    """
    run_name = "{0}_{1}_{2}_{3}_{4}_seed{5}".format(
        dset_name, alg, sim_type, notion, reuse, seed)
    results_file = dset_name + "_{0}_seed{1}.pickle".format(sim_type, seed)
    path_q = os.path.join(basefolder, run_name, results_file)
    return load_regression_results(path_q, rng_split)
    

def load_DVICres_old(basefolder, dset_name, alg, seed, sim_type, reuse):
    """Load DVIC results stored with the older folder layout (no ``notion`` tag).

    Folder: ``<basefolder>/<dset_name>_<alg>_<sim_type>_<reuse>_seed<seed>``;
    file inside it: ``<dset_name>_<sim_type>_seed<seed>.pickle``.
    Unlike ``load_DVICres``, no ``rng_split`` argument is passed on to
    ``load_regression_results``.

    Returns
    -------
    Whatever ``getdata.load_regression_results`` returns for that file.
    """
    run_name = "{0}_{1}_{2}_{3}_seed{4}".format(
        dset_name, alg, sim_type, reuse, seed)
    results_file = dset_name + "_{0}_seed{1}.pickle".format(sim_type, seed)
    return load_regression_results(
        os.path.join(basefolder, run_name, results_file))
    

def save_DVICres(data_dict, basefolder, dset_name, alg, seed, notion, sim_type, figs, figsonly=False):
    """Save the results (and figures) of a DVIC simulation run.

    All outputs go to the folder
    ``<basefolder>/<dset_name>_<alg>_<sim_type>_<notion>_<DVIC_dset>_seed<seed>``,
    which is created (together with ``basefolder``) if it does not exist.
    ``DVIC_dset`` is read from ``data_dict[0]['conf']['DVIC_dset']``.

    Parameters
    ----------
    data_dict: results of the training. Must support
        ``data_dict[0]['conf']['DVIC_dset']``; the whole object is serialized.
    basefolder: base folder under which the run folder is created
    dset_name: name of the dataset; prefix of every written file name
    alg: algorithm used to estimate the quantiles, for example SQR
    seed: seed used to generate all the data
    notion: tag included in the run folder name
    sim_type: what the quantiles were computed over, for example the target Y
        or the absolute error |y-f(x)|
    figs: mapping from a name to an object with a ``savefig`` method; each one
        is written as ``<dset_name>_<name>.pdf``. May be empty or None.
    figsonly: if True, only the figures are written and the results file is skipped.

    Returns
    -------
    None. Unless ``figsonly`` is True, writes
    ``<dset_name>_<sim_type>_seed<seed>.pickle`` with ``torch.save``. Also
    writes one PDF per entry of ``figs``.
    """
    dvic_dset = data_dict[0]['conf']['DVIC_dset']
    folder = os.path.join(
        basefolder,
        "{0}_{1}_{2}_{3}_{4}_seed{5}".format(dset_name, alg, sim_type, notion,
                                             dvic_dset, seed),
    )
    # makedirs also creates basefolder as an intermediate directory.
    os.makedirs(folder, exist_ok=True)

    if not figsonly:
        outname = os.path.join(folder, dset_name + '_{0}_seed{1}.pickle'.format(sim_type, seed))
        # NOTE: written with torch.save despite the .pickle extension, so it has
        # to be read back with torch.load, not pickle.load.
        torch.save(data_dict, f=outname, pickle_protocol=pickle.HIGHEST_PROTOCOL)

    for key, fig in (figs or {}).items():
        fig.savefig(os.path.join(folder, dset_name + '_' + key + '.pdf'), format='pdf')

         
