#! /usr/bin/env python3
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import glob
import geopandas as gpd
import os
from joblib import Parallel, delayed


def apply_regression2(in_file, par_list_gen, nuts2, num_cores=32):
    """Turn one modelled-radiation CSV into a BNI CSV via percentile mapping.

    The output path is ``in_file`` with ``'Rsds'`` replaced by ``'BNI'`` and
    ``'.csv'`` replaced by ``'_v3.1.csv'``. If that file already exists the
    input is skipped. Otherwise every column of the input is handed, together
    with the parameter pickle selected for it, to ``percentile_correction``;
    the corrected columns are joined, limited to ``[0, 1361]`` and written
    with two decimals and a ``YYYY-MM-DD-HH`` index.

    Parameters
    ----------
    in_file : str
        CSV whose first column is a datetime-parsable index and whose other
        columns hold one series per region code.
    par_list_gen : str
        Glob pattern of the parameter pickles. It is used both to discover
        which region codes have parameters and, with ``'*'`` replaced by a
        region code, as the path of the pickle used for a column.
    nuts2 : GeoDataFrame
        Region geometries indexed by region code. Only used (via
        ``get_nearest``) when ``'NUTS2_Europe'`` occurs in ``in_file``.
    num_cores : int, optional
        Number of joblib workers (``n_jobs``). Defaults to 32.

    Returns
    -------
    int
        ``1`` if the output already existed and nothing was done,
        ``0`` after the output file has been written.
    """
    print(in_file)
    out_file = in_file.replace('Rsds', 'BNI').replace('.csv', '_v3.1.csv')
    if os.path.isfile(out_file):
        return 1

    par_list = sorted(glob.glob(par_list_gen))

    # Create the target directory up front. exist_ok=True tolerates another
    # process creating it concurrently, while any real failure is raised here
    # instead of only surfacing in to_csv after all the parallel work.
    out_dir = os.path.dirname(out_file)
    if out_dir:
        if not os.path.isdir(out_dir):
            print("making dir ", out_dir)
        os.makedirs(out_dir, exist_ok=True)

    rad_df = pd.read_csv(in_file, index_col=0)
    rad_df.index = pd.to_datetime(rad_df.index)

    rad_nuts = set(rad_df.columns)
    # Region code = first four characters of the last '_'-separated token
    # of each parameter file path.
    par_nuts = {pp.split('_')[-1][:4] for pp in par_list}

    if 'NUTS2_Europe' in in_file:
        # Columns without their own parameter file borrow the one of the
        # nearest region that has one.
        nearest_nuts = get_nearest(nuts2, par_nuts, rad_nuts)
        nearest_nuts = nearest_nuts.loc[rad_df.columns]
    else:
        # Otherwise the first four characters of each column name select
        # the parameter file.
        nearest_nuts = pd.DataFrame(
            rad_df.columns.str[:4], index=rad_df.columns, columns=['nearest']
        )

    parallel_input = [
        (rad_df[nuts], par_list_gen.replace('*', nearest_nuts.loc[nuts].item()))
        for nuts in nearest_nuts.index
    ]
    bni_df_list = Parallel(n_jobs=num_cores, verbose=10)(
        delayed(percentile_correction)(*args) for args in parallel_input
    )

    bni = pd.concat(bni_df_list, axis=1)

    # Limit the result to the range [0, solar constant].
    solar_const = 1361
    bni = bni.clip(lower=0, upper=solar_const)

    bni.index = bni.index.strftime('%Y-%m-%d-%H')
    bni.to_csv(out_file, float_format='%.2f')

    return 0

def percentile_correction(ghi_mod, bni_pkl):
    """Replace modelled values by observed BNI values of equal rank.

    Each time step of ``ghi_mod`` is assigned to a bin keyed by
    ``(dayofyear // 8, hour)``. Within a bin, the percentile rank of the
    modelled value selects a position in the sorted observations loaded
    from ``bni_pkl`` for the same key, and that observation becomes the
    output value.

    Time steps whose modelled value is not above 10, or whose bin has fewer
    than two observations, keep the modelled value unchanged.

    NOTE(review): this presumes the pickled series is indexed by the same
    ``(dayofyear // 8, hour)`` keys - confirm against the code that writes
    the pickles.

    Parameters
    ----------
    ghi_mod : pandas.Series
        Modelled series with a DatetimeIndex.
    bni_pkl : str
        Path of the pickle read through ``read_bni_pickle``.

    Returns
    -------
    pandas.Series
        Float series with the index and name of ``ghi_mod``.
    """
    # Only used for progress output.
    region = bni_pkl.split('_')[-1][:-4]
    print(region)

    observed = read_bni_pickle(bni_pkl)

    # One (dayofyear // 8, hour) key per modelled time step.
    timestamps = ghi_mod.index
    period_idx = pd.MultiIndex.from_tuples(
        list(zip(timestamps.dayofyear // 8, timestamps.hour))
    )

    # Sorted observations per key, and the last valid position in each
    # sorted array aligned to the modelled time steps (0 where the key has
    # no observations).
    sorted_obs = observed.groupby(observed.index).apply(np.sort)
    last_pos = (sorted_obs.apply(len) - 1).reindex(period_idx).fillna(0)

    # Percentile rank within the bin, scaled to a position in the sorted
    # observations of that bin.
    rank_frac = ghi_mod.groupby(period_idx).rank(pct=True)
    positions = (rank_frac * last_pos.values).round().astype(int).values

    corrected = pd.Series(index=ghi_mod.index, dtype=float, name=ghi_mod.name)

    low_threshold = 10
    # Part of the selection that does not depend on the bin.
    usable = (ghi_mod.values > low_threshold) & (last_pos.values > 0)

    for key in sorted(set(sorted_obs.index)):
        selected = (period_idx == key) & usable
        corrected.loc[selected] = sorted_obs[key][positions[selected]]

    # Everything that was not mapped falls back to the modelled value.
    return corrected.fillna(ghi_mod)
    

def get_nearest(nuts2, par_nuts, rad_nuts):
    """Assign to every radiation region the region whose parameters it uses.

    Regions in ``rad_nuts`` that also appear in ``par_nuts`` map to
    themselves. The remaining regions map to the region with the closest
    centroid among those that are in both ``rad_nuts`` and ``par_nuts``.

    Parameters
    ----------
    nuts2 : GeoDataFrame
        Region geometries indexed by region code; only ``nuts2.centroid``
        is used.
    par_nuts : iterable of str
        Region codes for which a parameter file exists.
    rad_nuts : iterable of str
        Region codes present in the radiation data.

    Returns
    -------
    pandas.DataFrame
        Indexed by the codes in ``rad_nuts`` with a single object column
        ``'nearest'``. Rows for regions without own parameters come first.

    Raises
    ------
    KeyError
        If a code from ``rad_nuts`` is missing from the index of ``nuts2``.
    ValueError
        If some region lacks parameters and no candidate region exists.
    """
    rad_nuts = set(rad_nuts)
    par_nuts = set(par_nuts)

    # .loc does not accept sets as indexers in current pandas, so use
    # (sorted, hence reproducible) lists of labels.
    missing = sorted(rad_nuts - par_nuts)
    available = sorted(rad_nuts & par_nuts)

    # Work on the centroid series directly; no frame or geometry-column
    # renaming is needed for label lookup and distance computation.
    centroids = nuts2.centroid
    missing_cent = centroids.loc[missing]
    avail_cent = centroids.loc[available]

    if missing and not available:
        raise ValueError(
            "no region with parameters available to substitute for: "
            + ", ".join(str(m) for m in missing)
        )

    nearest = {}
    for nm in missing:
        nearest[nm] = avail_cent.distance(missing_cent.loc[nm]).idxmin()
    for na in available:
        nearest[na] = na

    return pd.Series(nearest, name='nearest', dtype=object).to_frame()
        

def read_bni_pickle(fname):
    """Load the ``'BNI'`` entry of a pickled pandas object, multiplied by 4.

    The factor converts quarter-hourly energy (Wh/m**2 per 0.25 h) into
    power (W/m**2).
    """
    bni_quarter_hourly = pd.read_pickle(fname)['BNI']
    return bni_quarter_hourly * 4
    
#     
# Driver: process the NUTS2 (Europe) radiation files first, then the
# NUTS3 (Austria) ones, with the same parameter pickles and geometries.
in_dir_par = "../Rad_Regression2/"
par_list_gen = in_dir_par + "Radiation_Regression_*.pkl"

# Region geometries indexed by region code and reprojected to EPSG:3035.
# Loaded once because both runs use the same shapefile.
nuts2 = gpd.read_file("../GIS_SECURES/NUTS2_clipped.shp")
nuts2_p = nuts2.set_index('NUTS').to_crs('epsg:3035')

for in_dir_rad in (
    "/path/to/Radiation/CSV_Data/NUTS2_Europe/",
    "/path/to/Radiation/CSV_Data/NUTS3_Austria/",
):
    rad_files = sorted(glob.glob(in_dir_rad + '*/Rsds*csv'))

    for ff in rad_files:
        apply_regression2(ff, par_list_gen, nuts2_p)






