#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import glob
import datetime as dt
import percentile_nc as pnc
from joblib import Parallel, delayed
import wind_powercurves as wpc
import os

def score_at_percentile2(frequency, percentile, bins, time_axis = None, bin_axis = 0, split_time = 1, n_workers = 8, parallel = True):
    """Evaluate a binned frequency distribution at given percentiles.

    For every grid point the histogram ``frequency`` (counts per bin along
    ``bin_axis``) is converted into cumulative percentages (0-100) at the
    lower and upper limit of each bin. Each requested percentile is located
    between those cumulative percentages, and the value is obtained with
    ``pnc.linear_interpolation`` between the corresponding bin limits.

    Parameters
    ----------
    frequency : np.ndarray
        Counts per bin; axis ``bin_axis`` has length ``len(bins) - 1``.
    percentile : np.ndarray
        Percentiles to evaluate, on the same 0-100 scale as the cumulative
        percentages. Without its optional ``time_axis`` its shape must equal
        the shape of ``frequency`` without ``bin_axis`` (same axis order).
    bins : np.ndarray
        1-D array of bin edges.
    time_axis : int or None
        Axis of ``percentile`` holding several percentile fields that share
        the same distribution (e.g. time steps). ``None`` if there is none.
    bin_axis : int
        Bin axis of ``frequency``.
    split_time : int or None
        Number of chunks ``time_axis`` is split into (bounds the size of the
        broadcast comparisons; also the unit of work for joblib). ``None``
        means 1.
    n_workers : int
        Number of joblib workers when ``parallel`` is True.
    parallel : bool
        Process the time chunks with joblib instead of sequentially.

    Returns
    -------
    np.ndarray
        Values at the requested percentiles, laid out like ``percentile``
        (assuming ``pnc.linear_interpolation`` works elementwise — TODO confirm).
    """
    # check dimension match: all axes except the bin axis (frequency) and the
    # optional time axis (percentile) must agree, in the same order
    f_shp_tmp = list(frequency.shape)
    f_shp_tmp.pop(bin_axis)
    p_shp_tmp = list(percentile.shape)
    if time_axis is not None:
        p_shp_tmp.pop(time_axis)
    print(f_shp_tmp, p_shp_tmp)
    assert f_shp_tmp == p_shp_tmp, "Spatial Dimensions, exept time or bin, of frequency and percentile respectively do not match"

    assert bins.ndim == 1, "bins are not 1-D"

    assert frequency.shape[bin_axis] == len(bins) - 1, "len(bins) not equal to bin-dimension of frequency"

    # lower and upper limit of every bin
    bins_ll = bins[:-1]
    bins_ul = bins[1:]

    # Move the bin axis to the front. np.moveaxis (unlike swapaxes) keeps the
    # remaining axes in their original order, i.e. the order the shape check
    # above compares against the percentile array.
    fqcy = np.moveaxis(frequency, bin_axis, 0)

    # Same for the time axis of the percentiles. Without a time axis a
    # leading length-1 axis is added, so bin_at_percentile always receives
    # an array of shape (time, *spatial).
    if time_axis is not None:
        pctl = np.moveaxis(percentile, time_axis, 0)
    else:
        pctl = np.expand_dims(percentile, 0)

    # cumulative counts along the bins
    cum = np.cumsum(fqcy, axis = 0)

    # cumulative percentage at the upper and lower limit of every bin
    perc_ul = cum/cum[-1,...]*100.
    perc_ll = (cum - fqcy)/cum[-1,...]*100.

    # leading bins that contain nothing (both limits at 0 %) are set to NaN
    not_all_null = ~((perc_ll==0) & (perc_ul==0))
    perc_ul = np.where(not_all_null, perc_ul, np.nan)
    perc_ll = np.where(not_all_null, perc_ll, np.nan)
    # mark the first populated bin with -1 so that the strict comparison
    # "perc_ll < pctl" below also holds for a requested percentile of 0
    perc_ll[perc_ll==0] = -1

    del not_all_null


    def bin_at_percentile(bins_ll, bins_ul, perc_ll, perc_ul, pctl, cnt=None):
        """Interpolate values for one chunk ``pctl`` of shape (time, *spatial)."""
        if cnt is not None:
            print(cnt)
        # broadcastable (sparse) index arrays for all non-leading axes of pctl
        idx2 = np.indices(pctl.shape, sparse=True)[1:]

        # last bin whose lower cumulative percentage is below pctl
        # (argmax on the reversed bin axis, converted back to the original index)
        idx_l_tmp = perc_ll.shape[0] - np.argmax(np.flip(np.expand_dims(perc_ll,1) < np.expand_dims(pctl,0),0),axis=0) - 1
        # first bin sharing that lower cumulative percentage
        # (runs of empty bins have equal cumulative values)
        idx_l = np.argmax(np.expand_dims(perc_ll,1) == np.expand_dims(perc_ll[(idx_l_tmp, *idx2)],0),axis=0)
        # first bin whose upper cumulative percentage reaches pctl
        idx_u = np.argmax(np.expand_dims(perc_ul,1) >= np.expand_dims(pctl,0),axis=0)

        # the -1 marker is only needed for the comparisons; interpolate from 0 %
        perc_ll2 = perc_ll.copy()
        perc_ll2[perc_ll2 < 0] = 0

        perc_ul2 = perc_ul.copy()
        perc_ul2[perc_ul2 < 0] = 0

        score_out = pnc.linear_interpolation(perc_ll2[(idx_l, *idx2)],
                                        perc_ul2[(idx_u, *idx2)],bins_ll[idx_l],
                                        bins_ul[idx_u],
                                        pctl)

        return score_out


    if time_axis is None:
        # drop the helper leading axis added above
        return bin_at_percentile(bins_ll, bins_ul, perc_ll, perc_ul, pctl)[0]

    if split_time is None:
        split_time = 1

    time_chunks = np.array_split(np.arange(pctl.shape[0]), split_time, axis = 0)

    if parallel:
        parallel_input = [(bins_ll, bins_ul, perc_ll, perc_ul, pctl[ind_time, ...], cnt) for cnt, ind_time in enumerate(time_chunks)]
        score_list = Parallel(n_jobs=n_workers, verbose=1)(delayed(bin_at_percentile)(*args) for args in parallel_input)
        score_out = np.concatenate(score_list, axis=0)
    else:
        # initialize output
        score_out = np.zeros(pctl.shape)
        for ind_time in time_chunks:
            print(ind_time)
            score_out[ind_time, ...] = bin_at_percentile(bins_ll, bins_ul, perc_ll, perc_ul, pctl[ind_time, ...])

    # move the time axis back to its original position
    return np.moveaxis(score_out, 0, time_axis)


def score_at_percentile_nc2(freq_ds, freq_var, perc_ds, perc_var, score_var = "score", fname_score = None, fname_powercurve = None, bindim = "bin", time_axis="time", split_time = 1, n_workers = 8, parallel = True):
    """Compute values at percentiles from netCDF frequency/percentile data.

    Parameters
    ----------
    freq_ds : xr.Dataset or path
        Binned frequency data. Paths are opened here and closed again.
    freq_var : str
        Frequency variable in ``freq_ds``; its dims are used to locate
        ``bindim``. The frequencies and bin edges are read by
        ``pnc.frequency_bins_from_netcdf``.
    perc_ds : xr.Dataset or path
        Percentile data. Paths are opened here and closed again.
    perc_var : str
        Percentile variable in ``perc_ds``.
    score_var : str
        Name of the output variable.
    fname_score : str or None
        Output file. If it already exists nothing is computed.
    fname_powercurve : any
        Unused.
    bindim : str
        Name of the bin dimension of ``freq_var``.
    time_axis : str
        Name of the time dimension of ``perc_var``.
    split_time, n_workers, parallel
        Passed on to ``score_at_percentile2``.

    Returns
    -------
    1 if ``fname_score`` already exists; None after writing ``fname_score``;
    otherwise an ``xr.Dataset`` holding ``score_var``.

    Raises
    ------
    ValueError
        If ``time_axis`` or ``bindim`` is not a dimension of the respective
        variable.
    """
    import contextlib

    # skip existing output; `and` short-circuits, so fname_score=None is fine
    if fname_score is not None and os.path.isfile(fname_score):
        print("Outfile exists", fname_score)
        return 1

    # Datasets opened here (from paths) are closed when leaving the block,
    # also on errors; Datasets passed in by the caller are left open.
    with contextlib.ExitStack() as stack:
        if not isinstance(freq_ds, xr.Dataset):
            freq_ds = stack.enter_context(xr.open_dataset(freq_ds))
        if not isinstance(perc_ds, xr.Dataset):
            perc_ds = stack.enter_context(xr.open_dataset(perc_ds))

        print("get frequency")
        freq, bins, _pos_bindim = pnc.frequency_bins_from_netcdf(freq_ds, bindim = bindim)

        # make sure percentiles are an array
        print("read percentiles")
        perc_da = perc_ds[perc_var]
        percentile = perc_da.data

        perc_dims = tuple(perc_da.dims)
        if time_axis not in perc_dims:
            raise ValueError(f"time dimension {time_axis!r} not in dims {perc_dims} of {perc_var!r}")
        time_pos = perc_dims.index(time_axis)

        freq_dims = tuple(freq_ds[freq_var].dims)
        if bindim not in freq_dims:
            raise ValueError(f"bin dimension {bindim!r} not in dims {freq_dims} of {freq_var!r}")
        bin_pos = freq_dims.index(bindim)

        # calculate scores percentile
        print("calc score")
        score_out = score_at_percentile2(freq, percentile, bins, time_axis = time_pos, bin_axis = bin_pos, split_time = split_time, n_workers = n_workers, parallel = parallel)

        # pass dims explicitly: inferring them from the coords alone breaks
        # with non-dimension coordinates or a different coordinate order
        score_arr = xr.DataArray(data = score_out, coords = perc_da.coords, dims = perc_da.dims)
        score_ds = xr.Dataset()
        score_ds[score_var] = score_arr

    if fname_score is not None:
        # write outfile
        score_ds.to_netcdf(fname_score)
        score_ds.close()
    else:
        return score_ds
    
if __name__ == "__main__":

    # input / output locations (placeholders, adjust to the local file system)
    freq_file = "/path/to/Frequency/COSMO_REA6/windspeed/COSMO_REA6_WS150m_frequency_all.nc4"
    perc_files = sorted(glob.glob("/path/to/percentiles/ERA5/ERA5_Land_*_Wspeed_hourly_percentiles.nc"))

    out_dir = "/path/to/ERA5/ERA5_wspd_150m/"

    # number of chunks the time axis of each percentile file is split into;
    # 8784 = 366 * 24 (presumably the hourly steps of a leap year), so each
    # chunk holds about 8 time steps
    split_time = 8784 // 8

    # the shared frequency dataset is closed once all files are processed
    with xr.open_dataset(freq_file) as freq_ds:
        jobs = []
        for perc_file in perc_files:
            # ".../ERA5_Land_<tag>_Wspeed_hourly_percentiles.nc" -> "<tag>"
            # (assuming the wildcard part contains no underscore)
            tag = perc_file.split("_")[-4]
            jobs.append(
                delayed(score_at_percentile_nc2)(
                    freq_ds, "wind_speed_freq", perc_file, "percentile",
                    fname_score = out_dir + "ERA5_wspd_150m_" + tag + ".nc",
                    split_time = split_time, parallel = True, n_workers = 16,
                )
            )

        # n_jobs=1: files are processed one after another; the parallelism
        # happens inside each call (n_workers)
        Parallel(n_jobs=1, verbose=5)(jobs)


