
import torch
import argparse
import numpy as np
from torch.autograd import Variable
from matplotlib import pyplot as plt
from sklearn.metrics import auc, roc_curve
from networks import DVIC_net as DVIC_DNNnet
from torch.utils.data import DataLoader as dl
from DVIC_misc import load_DVICres, save_DVICres
from main_DVIC_Qestimates import Hfunction, get_eps_good, fpr_at_fixed_tpr
import pickle
from torch.utils.data import TensorDataset as tds
from networks import get_sim_device, load_reg_model, DeepNetBN
from getdata import Dset, split_dset, load_reg_result, load_Qnet
import os
from tqdm import tqdm
from gen_DVIC_results_and_tables import get_allowed_eps
from knife_samples_generation import restore_KNIFEnet
from conditional_gaussian import restore_cond_gaussian


def create_noisy_data(data, magnitude, test_pos_samps, device, DVICnet, Qnet, reg_network, sim_conf, **kwargs):
    """Perturb each sample by one signed-gradient step on the H score.

    Samples are processed one at a time. For a positive ``magnitude`` every
    sample ``x`` is replaced by ``x + magnitude * sign(dH/dx)``, where ``H`` is
    the value returned by ``Hfunction`` for that sample. For a non-positive
    ``magnitude`` the samples are returned unchanged.

    Args:
        data: Tensor of input samples, one sample per row.
        magnitude: Step size of the perturbation.
        test_pos_samps: Per-sample labels. Only used to build the paired
            dataset (so it must have as many entries as ``data``); the label
            values are not read.
        device: Device on which the gradient is computed.
        DVICnet, Qnet, reg_network, sim_conf: Forwarded unchanged to
            ``Hfunction``.
        **kwargs: ``verbose=True`` prints progress and the resulting tensor.

    Returns:
        numpy.ndarray with the (possibly perturbed) samples stacked row-wise.

    Raises:
        AssertionError: if ``Hfunction`` does not return a 1-D tensor, or if
            no gradient reached the input sample.
    """
    verbose = bool(kwargs.get("verbose", False))
    # Normalise numpy scalars / 0-d tensors so the arithmetic below keeps the
    # dtype of the samples.
    step = float(magnitude)

    tensor_ds = tds(data, torch.tensor(test_pos_samps, dtype=torch.float32))
    tensor_dl = dl(tensor_ds, batch_size=1, shuffle=False)

    new_data = []
    for idx, (x, _) in enumerate(tensor_dl):
        if step > 0:
            if verbose:
                print(f"Processing batch {idx} with magnitude: {magnitude}")

            # Move to the target device FIRST and only then enable gradients:
            # a tensor produced by `.to(device)` from a grad-requiring CPU
            # tensor is a non-leaf copy whose `.grad` is never populated.
            sample = x.detach().to(device).requires_grad_(True)

            H = Hfunction(DVICnet, Qnet, reg_network, sample,
                          sim_conf, device, gen_type='cat', requires_grad=True)

            # H must be a flat vector of scores
            assert len(H.shape) == 1

            # backward() needs a scalar; a single-element H already is one.
            # NOTE: this also accumulates gradients in the networks'
            # parameters; they are not used here.
            (H.mean() if H.shape[0] > 1 else H).backward()

            assert sample.grad is not None

            # x - m * sign(-g) == x + m * sign(g). Detach so the list does not
            # keep one autograd graph alive per sample.
            new_data.append((sample + step * torch.sign(sample.grad)).detach())
        else:
            new_data.append(x.to(device))

    new_data = torch.vstack(new_data)

    if verbose:
        print(f"new_data: {new_data}")

    return new_data.detach().cpu().numpy()


# def main():
if __name__ == '__main__':
    # Noise-robustness experiment: for every dataset, epsilon and seed, the
    # test inputs are perturbed with create_noisy_data() at increasing
    # magnitudes, and the AUROC and the FPR at 90% TPR of the H score are
    # recorded, plotted and pickled.

    parser = argparse.ArgumentParser()
    parser.add_argument('--dset_name', type=str, default='yacht')
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--reg_dir', type=str, default='runs_reg_900')
    parser.add_argument('--algQ', type=str, default='conditional_gaussian')  # SQR, KNIFE or conditional_gaussian
    parser.add_argument('--conf_type', type=str, default='absE_abs',
                        help='Possibilities are outY_abs, outY_rel, absErel_rel, absE_abs')
    parser.add_argument('--magnitudes', type=float, nargs='+',
                        default=np.array(np.linspace(0, 4e-3, 31)),
                        help='Possible noise magnitudes to add')

    seeds = np.array([1, 2, 3])

    # basic setup do not change
    dset_reuseQ = 'regtrain'
    algD = 'DVICmat'
    save_results = True

    args = parser.parse_args()
    device_t = args.device
    device = get_sim_device(device_t)  # e.g. torch.device('cpu')
    conf_type = args.conf_type
    algQ = args.algQ

    # Magnitudes passed on the command line arrive as a plain list, the
    # default is an ndarray; normalise so `.size` and array maths always work.
    magnitudes = np.asarray(args.magnitudes, dtype=float)

    if args.dset_name == 'all':
        dsets = ['yacht', 'boston', 'energy', 'concrete', 'wine',
                 'kin8nm', 'power', 'naval']
    else:
        dsets = [args.dset_name]

    reg_dir = args.reg_dir
    out_folder = 'paper3'  # where the baselines are stored
    outfile = 'noise_results'  # where we save the output of noise files
    DVIC_results = 'paper_results'
    out_dict = {}
    rng_split = True

    # Folders holding the trained Q models, per algorithm.
    q_folders = {
        'SQR': 'runs_SQR_900_regtrain',
        'KNIFE': 'runs_KNIFE_900_regtrain4_256',
        'conditional_gaussian': 'runs_ensembles_900_regtrain3',
    }
    # Folders holding the trained DVIC models, per algorithm and notion.
    dvic_folders = {
        'SQR': {'abs': 'DVIC_regtrain90_SQRabs',
                'rel': 'DVIC_regtrain90_SQRrel'},
        'KNIFE': {'abs': 'DVIC_regtrain90_KNIFEabs',
                  'rel': 'DVIC_regtrain90_KNIFErel'},
        'conditional_gaussian': {'abs': 'DVIC_regtrain90_ENSabs',
                                 'rel': 'DVIC_regtrain90_ENSrel'},
    }
    valid_conf_types = ('absE_abs', 'outY_abs', 'absErel_rel', 'outY_rel')

    # Fail early and explicitly: previously an unknown conf_type only surfaced
    # much later as a NameError on DVICfolder.
    if algQ not in q_folders or conf_type not in valid_conf_types:
        raise ValueError('Unknown algorithm or sim type')

    (sim_type, notion) = conf_type.split('_')
    qfolder = q_folders[algQ]
    DVICfolder = dvic_folders[algQ][notion]

    nunif = 20000
    batch_size = 5000

    # Load the current results in the paper.
    # NOTE(review): `data_dict` is not read anywhere below; the load is kept
    # unchanged - confirm whether it can be dropped.
    with open(os.path.join(out_folder, DVIC_results+'.pickle'), "rb") as f:
        data_dict = pickle.load(f)

    for dset_name in dsets:
        eps_vals = list(get_allowed_eps(dset_name, notion=notion))

        ##### SETUP #####
        plt.close('all')

        fig, axs1 = plt.subplots(2, 1)  # to plot averages
        # One row per epsilon (was hard-coded to 3 rows); squeeze=False keeps
        # the axes array 2-D even for a single row.
        # NOTE: fig2 is only built in memory, this script never saves it.
        fig2, axs2 = plt.subplots(max(len(eps_vals), 1), 1, squeeze=False)

        axs1[0].set_xlabel('Noise magnitude')
        axs1[0].set_ylabel('AUROC')
        axs1[1].set_xlabel('Noise magnitude')
        axs1[1].set_ylabel('FPR')
        axs1[0].grid(True)
        axs1[1].grid(True)

        # Created once per dataset so that the results of every epsilon are
        # kept (resetting it inside the eps loop discarded all but the last).
        out_dict[dset_name] = {}

        for cnt_eps, eps in enumerate(eps_vals):
            FPRs = np.zeros((seeds.size, magnitudes.size))
            AUROCs = np.zeros((seeds.size, magnitudes.size))

            for cnt, seed in enumerate(seeds):

                (reg_res, reg_par) = load_reg_result(dset_name, seed,
                                                     reg_type='DNN',
                                                     basedir=reg_dir,
                                                     rng_split=rng_split)

                train_x = reg_res['dataset']['train x']
                train_y = reg_res['dataset']['train y']
                x_test = reg_res['dataset']['test x']
                y_test = reg_res['dataset']['test y']

                reg_network = load_reg_model(n_in=train_x.shape[1], n_out=train_y.shape[1],
                                             reg_res=reg_res,
                                             reg_par=reg_par,
                                             device=device)  # instantiate model, and architecture, load parameters internally

                reg_network.to(device)

                ##### GET QNET: the model selected by --algQ #####

                (Q_res, Q_pars) = load_Qnet(basefolder=qfolder,
                                            dset_name=dset_name,
                                            alg=algQ, seed=seed,
                                            sim_type=sim_type,
                                            reuse=dset_reuseQ,
                                            split_rng=rng_split)

                if algQ == 'SQR':
                    # Load the quantile network to generate the samples
                    Qnet = DeepNetBN(n_in=x_test.shape[1] + 1,
                                     n_out=1,
                                     n_hlayers=Q_res['conf']['n_hlayers'],
                                     n_inner_neurons=Q_res['conf']['n_inner_neurons'])

                    Qnet.to(device)

                    Qnet.load_state_dict(torch.load(Q_pars, map_location=device))
                elif algQ == 'KNIFE':
                    Qnet = restore_KNIFEnet(device, Q_res,
                                            Q_pars,
                                            in_dim=x_test.shape[1])

                elif algQ == 'conditional_gaussian':
                    Qnet = restore_cond_gaussian(device, Q_res,
                                                 Q_pars,
                                                 in_dim=x_test.shape[1])

                Qnet.to(device)

                # NOTE(review): 'rng_split' is False here while the regressor
                # and Q model above are loaded with True - confirm intended.
                load_conf = {'alg': algD, 'notion': notion,
                             'sim_type': sim_type, 'reuse': dset_reuseQ, 'rng_split': False,
                             'folder': DVICfolder}

                DVres = load_DVICres(load_conf['folder'], dset_name, load_conf['alg'],
                                     seed, load_conf['notion'], load_conf['sim_type'],
                                     load_conf['reuse'], load_conf['rng_split'])

                # Locate the stored run trained for this epsilon.
                for ind in range(len(DVres)):
                    if DVres[ind]['conf']['eps'] == eps:
                        break
                else:
                    raise ValueError('No DVIC result found for eps={0} (dataset {1}, seed {2})'.format(
                        eps, dset_name, seed))

                dv_conf = DVres[ind]['conf']

                DVICnet = DVIC_DNNnet(n_in=2,
                                      n_out=1,
                                      n_hlayers=dv_conf['n_hlayers'],
                                      n_inner_neurons=dv_conf['n_inner_neurons'])

                DVICnet.load_state_dict(DVres[ind]['Test AUC']['best params auroc'])

                DVICnet.to(device)

                musig = (reg_res['scale']['mu_y_train'] /
                         reg_res['scale']['scale_y_train'])[0]

                with torch.no_grad():
                    reg_network.eval()
                    test_y_hat = reg_network(torch.tensor(x_test, dtype=torch.float32,
                                                          device=device)).cpu().numpy()

                    test_pos_samps = get_eps_good(y_test, test_y_hat, dv_conf['eps'],
                                                  notion=notion, musig=musig)  # =1 Eps bad, =0 Eps good

                data = torch.tensor(x_test, dtype=torch.float32).to(device)
                sim_conf = {'nunif': nunif,
                            'zoom_func': dv_conf['zoom_func'],
                            'notion': dv_conf['notion'],
                            'gamma': dv_conf['gamma'],
                            'musig': musig, 'qtype': args.algQ,
                            'type': sim_type}

                pbar = tqdm(magnitudes)
                for cnt2, magnitude in enumerate(pbar):
                    noisy_x = create_noisy_data(data, magnitude, test_pos_samps, device,
                                                DVICnet, Qnet, reg_network, sim_conf, verbose=False)

                    DVICnet.eval()
                    Qnet.eval()
                    testDS = Dset(noisy_x, test_pos_samps)
                    with torch.no_grad():
                        if sim_conf['nunif'] <= 10000:
                            # score the whole perturbed test set in one call
                            scores = Hfunction(DVICnet, Qnet, reg_network,
                                               torch.tensor(noisy_x,
                                                            dtype=torch.float32,
                                                            device=device),
                                               sim_conf, device, requires_grad=False).cpu().numpy()
                        else:
                            # score in mini-batches of `batch_size` samples
                            scores = []
                            test_loader = torch.utils.data.DataLoader(testDS, batch_size=batch_size,
                                                                      shuffle=False, num_workers=0,
                                                                      pin_memory=True)

                            for data_b, labels in test_loader:
                                scores.append(Hfunction(DVICnet, Qnet, reg_network,
                                                        data_b.to(device), sim_conf,
                                                        device, requires_grad=False).cpu().numpy())
                            scores = np.concatenate(scores)

                        tmp_nc = testDS.out.sum()
                        labels = testDS.out.flatten()

                        # A ROC curve needs both classes to be present.
                        if not ((tmp_nc > 0) and (tmp_nc < labels.size()[0])):
                            raise Exception('No positive or negative labels')

                        (fprs, tprs, thrs) = roc_curve(labels, scores)
                        roc_auc = auc(fprs, tprs)
                        AUROCs[cnt, cnt2] = roc_auc
                        (fpr, _, _) = fpr_at_fixed_tpr(fprs, tprs, thrs, 0.9)
                        FPRs[cnt, cnt2] = fpr
                        pbar.set_description('Mag: {3:.2f} - eps={2} - AUROC = {0:.1f} - FPR at TPR 90%= {1:.2f}'.format(roc_auc*100,
                                                                                           fpr,
                                                                                           eps,
                                                                                           magnitude))

                # per-seed AUROC curve for this epsilon
                axs2[cnt_eps, 0].plot(magnitudes, AUROCs[cnt])

            # statistics over the seeds
            mu_auc = AUROCs.mean(axis=0)
            mu_fpr = FPRs.mean(axis=0)
            std_auc = AUROCs.std(axis=0)
            std_fpr = FPRs.std(axis=0)

            out_dict[dset_name][eps] = {
                'AUROCs': AUROCs,
                'FPRs': FPRs,
                'mu AUC': mu_auc,
                'mu FPR': mu_fpr,
                'std AUC': std_auc,
                'std FPR': std_fpr,
                'seeds': seeds,
                'mag': magnitudes,
            }

            # Mean curve with a +/- one standard deviation band (each curve is
            # drawn once; it used to be plotted twice).
            axs1[0].plot(magnitudes, mu_auc, label='eps = {0}%'.format(eps))
            axs1[1].plot(magnitudes, mu_fpr, label='eps = {0}%'.format(eps))
            axs1[0].fill_between(magnitudes, mu_auc-std_auc, mu_auc+std_auc, alpha=0.3)
            axs1[1].fill_between(magnitudes, mu_fpr-std_fpr, mu_fpr+std_fpr, alpha=0.3)

            # Lay out the figure that is actually saved; plt.tight_layout()
            # would act on the current figure, which is fig2.
            fig.tight_layout()
            fig.savefig(os.path.join(out_folder, dset_name+'.pdf'), format='pdf')

            # Written after every epsilon, so partial results survive a crash.
            if save_results:
                with open(os.path.join(out_folder, outfile+'.pickle'), "wb") as f:
                    pickle.dump(out_dict, f, protocol=pickle.HIGHEST_PROTOCOL)
