#!/usr/bin/env python3
# -*- coding: utf-8 -*-



from DVIC_misc import load_DVICres
import torch
from networks import DeepNetBN
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import cm
from networks import DVIC_net as DVIC_DNNnet
from networks import load_reg_model
from getdata import load_reg_result, load_Qnet
import pandas as pd
import pickle 
import os


# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
save_results = False  # when True, `out_dict` is pickled into `out_folder`
reg_folder = 'runs_reg_900'  # NOTE(review): not referenced by the live code in this script

reuse = 'regtrain'

# Result folders handed to load_DVICres; basefolder2 is only tried when
# loading from basefolder1 raises.
basefolder1 = 'DVIC_regtrain90_5fold_r5_fpr_r3'
basefolder2 = 'DVIC_regtrain90_3fold_r5_fpr_r2expr2'

# Datasets to summarise. Previously two live assignments existed and only the
# last one took effect; alternatives from earlier runs are kept as comments.
# dsets = ['yacht']  # , 'energy', 'concrete'
# dsets = ['yacht', 'boston', 'concrete', 'energy']  # , 'energy', 'concrete', 'wine', 'kin8nm', 'power', 'naval'
# dsets = ['yacht', 'boston', 'toy', 'energy', 'concrete']  # , 'wine', 'kin8nm', 'naval', 'power'
# dsets = ['yacht', 'boston', 'toy', 'energy', 'concrete', 'wine', 'kin8nm']
dsets = ['wine', 'kin8nm', 'power', 'naval']

alg = 'DVICmat'
out_folder = 'paper'
notion = 'rel'
sim_type = 'outY'

rng_split = False
device = 'cpu'  # NOTE(review): not referenced by the live code in this script

seeds = list(range(1, 11))  # seeds 1..10
# seeds = [4]

# Per-dataset summaries filled by the aggregation loop, keyed by dataset name.
out_dict = {}   # AUROC from item['Test AUC']['Test roc auc']
out_dict2 = {}  # test AUC at the epoch with minimum validation loss


# Allowed epsilon values per dataset; datasets not listed use
# _DEFAULT_EPS_GRID. The str() of each value is the DataFrame column label and
# is matched against str(item['conf']['eps']) from the loaded results.
_EPS_GRID = {
    'energy': [0.01, 0.02, 0.03, 0.04, 0.05],  # earlier grid: [0.025, 0.05, 0.075, 0.1]
    'naval': [0.0005, 0.001, 0.0015],
    'power': [0.005, 0.01, 0.025, 0.05],
    'yacht': [0.05, 0.075, 0.1, 0.15, 0.2, 0.25, 0.3],
}
_DEFAULT_EPS_GRID = [0.05, 0.075, 0.1, 0.15, 0.2]


def _eps_columns(dset_name):
    """Return the column labels (str of each allowed epsilon) for ``dset_name``."""
    eps_vals = np.array(_EPS_GRID.get(dset_name, _DEFAULT_EPS_GRID))
    return [str(eps) for eps in eps_vals]


def _load_seed_results(dset_name, seed):
    """Load the DVIC results for one (dataset, seed) pair.

    Tries ``basefolder1`` first and falls back to ``basefolder2`` if that load
    raises. ``except Exception`` (instead of a bare ``except``) keeps Ctrl-C and
    SystemExit from being swallowed into the fallback load.
    """
    try:
        return load_DVICres(basefolder1, dset_name, alg, seed, notion,
                            sim_type, reuse, rng_split)
    except Exception:
        return load_DVICres(basefolder2, dset_name, alg, seed, notion,
                            sim_type, reuse, rng_split)


# For every dataset, build two seed-by-epsilon AUROC tables:
#   out_dict[d]['AUROC']  <- item['Test AUC']['Test roc auc']
#   out_dict2[d]['AUROC'] <- item['Test AUC']['test auc'] at argmin of
#                            item['Test AUC']['val loss history']
# Rows are indexed by seed; columns are the allowed epsilons (as strings).
# Epsilons missing for a seed are NaN; results for epsilons outside the grid
# are ignored.
for dset_name in dsets:
    print(dset_name)
    eps_cols = _eps_columns(dset_name)

    rows_auc = []      # one dict per seed: {'seed': s, eps_label: auc, ...}
    rows_auc_val = []  # same layout, AUC picked by minimum validation loss
    # 'conf' of the first result of the last seed that returned results
    # (previously Q_res could leak over from the preceding dataset).
    conf = None
    for seed in seeds:
        Q_res = _load_seed_results(dset_name, seed)
        if len(Q_res) > 0:
            conf = Q_res[0]['conf']

        row_auc = {'seed': int(seed)}
        row_auc_val = {'seed': int(seed)}
        for item in Q_res:  # one entry per epsilon value
            eps_key = str(item['conf']['eps'])
            if eps_key not in eps_cols:
                continue
            test_res = item['Test AUC']
            row_auc[eps_key] = test_res['Test roc auc']
            ind_opt = np.argmin(test_res['val loss history'])
            row_auc_val[eps_key] = test_res['test auc'][ind_opt]
            print(test_res['best']['exp_error'], test_res['best']['wd'])
        rows_auc.append(row_auc)
        rows_auc_val.append(row_auc_val)

    # Build each table once instead of concatenating onto an empty-column frame
    # per seed (deprecated in pandas 2.x and prone to object dtype).
    columns = ['seed'] + eps_cols
    df = pd.DataFrame(rows_auc, columns=columns).set_index('seed')
    df2 = pd.DataFrame(rows_auc_val, columns=columns).set_index('seed')

    out_dict[dset_name] = {
        'conf': conf,
        'AUROC': df,  # AUROC matrix obtained
    }
    out_dict2[dset_name] = {
        'conf': conf,
        'AUROC': df2,  # AUROC matrix obtained
    }

# Optionally persist the 'Test roc auc' summaries (out_dict) as a pickle.
# NOTE(review): out_dict2 (AUC at the minimum-validation-loss epoch) is only
# printed, never saved -- confirm that is intended.
if save_results:
    # Create the folder first so a long run does not die at the very end with
    # FileNotFoundError when `out_folder` does not exist yet.
    os.makedirs(out_folder, exist_ok=True)
    with open(os.path.join(out_folder, 'out_resultsDVIC.pickle'), "wb") as output_file:
        pickle.dump(out_dict, output_file)


# Report, per dataset, the seed-averaged AUROC for each epsilon column:
# first the 'Test roc auc' table, then the min-validation-loss table.
for dset_key, summary in out_dict.items():
    print(dset_key)
    for table_source in (summary, out_dict2[dset_key]):
        print(table_source['AUROC'].mean().values)
    # print(summary['AUROC'].std().values)

# mat = np.array(totout)
# print(np.mean(mat, axis=0))

# # Check with waiting 10 epochs
# seeds = [1,2,3,4,5,6,7,8,9,10]
# out = []
# totout = []
# for dset_name in dsets:
#     # print('\n')
#     print(dset_name)
#     eps=''
#     for seed in seeds:
#         Q_res = load_DVICres(basefolder, dset_name, alg, seed, notion, sim_type, reuse, rng_split)
        
#         for item in Q_res:
#             # Q_res = Q_res[0]
            
#             val_auc = item['Test']['auc history']
#             test_auc = item['Test']['test auc']
            
#             ind = np.argmax(val_auc[10:])
            
            
#             out.append(test_auc[10:][ind])
#             # out.append(item['Test']['Test roc auc'])
#             eps += '{0} '.format(item['conf']['eps'])
#             # print(item['Test']['best']['exp_error'])
            
#         # print('Esp={0}'.format(item['conf']['eps']))
#         print('Seed={0}'.format(seed))
#         print(eps)
#         for x in out:
#             print(x)
            
#         totout.append(out)
#         out = []
#         eps=''
        
# mat = np.array(totout)

# print(np.mean(mat, axis=0))
        
        
        


