#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
import glob
import datetime as dt

def linear_interpolation(x1, x2, y1, y2, xe):
    '''
    Evaluate the straight line through (x1, y1) and (x2, y2) at xe.

    Used to estimate a percentile or a value inside a bin from the values at
    the bin limits. The arguments are combined with plain arithmetic
    operators, so scalars and NumPy arrays of compatible shapes can be mixed.

    Parameters:
        x1: Array of lower x-values (predictor variable)
        x2: Array of upper x-values (predictor variable)
        y1: Array of lower y-values (target variable)
        y2: Array of upper y-values (target variable)
        xe: Array or scalar of x-values at which the y-values are estimated
    Returns
        ye: Array or scalar of estimated y-values as function of xe
    '''
    # rise over run of the line between the two support points
    slope = (y2 - y1)/(x2 - x1)
    return y1 + slope*(xe - x1)


def get_frequency(bins, data, frequency = None, axis = 0, verbose = False):
    '''
    Get frequency of occurrences of values in certain bins along an axis (usually time).

    A value x is counted in bin i if bins[i] <= x < bins[i+1]. Values outside of
    [bins[0], bins[-1]) are not counted at all.

    Parameters:
        bins: list or array of bin boundaries (strictly increasing)
        data: array-like of data to be binned
        frequency: array of already counted occurrences, used as starting value
            (default: zeros). Useful to process multiple datasets serially.
            If a NumPy array is passed, it is updated in place, so no extra
            copy of the counts is needed while accumulating.
        axis: axis along which the frequency is counted (default: 0)
        verbose: print the bin currently processed (default: False)

    Returns
        frequency_out: Array with the number of occurrences per bin. The shape is
            like data, except that the length of axis is len(bins) - 1.

    Raises
        ValueError: if bins are not 1-D or not strictly increasing, or if
            frequency does not have the required shape.
    '''

    # accept any array-like; asanyarray keeps ndarray subclasses untouched
    bins = np.array(bins)
    data = np.asanyarray(data)

    # check bin dimension
    if bins.ndim != 1:
        raise ValueError("Error: bins must be of dimension 1, but are of dimension {}!".format(bins.ndim))

    # check if bins are strictly increasing
    if np.any(np.diff(bins)<=0):
        raise ValueError("Error: bins must be stricly increasing!")

    # get dimensions of frequency array
    dims_out = list(data.shape)
    dims_out[axis] = len(bins) - 1
    dims_out = tuple(dims_out)

    # initialize frequency array
    if frequency is None:
        frequency = np.zeros(dims_out, dtype=int)
    else:
        frequency = np.asanyarray(frequency)

    # check dimensions of frequency array
    if frequency.shape != dims_out:
        raise ValueError("Error: frequency requires shape {}, but is of shape {}!".format(dims_out, frequency.shape))

    # views with the counting axis in front; created once instead of per bin
    freq_binfirst = frequency.swapaxes(0, axis)
    data_binfirst = data.swapaxes(0, axis)

    # get frequency for each bin
    for ib, (lower, upper) in enumerate(zip(bins[:-1], bins[1:])):
        if verbose:
            print("Current bin:", lower, upper, end="\r")

        nb = np.sum((data_binfirst >= lower) & (data_binfirst < upper), axis = 0)
        # writes through the view into frequency
        freq_binfirst[ib,...] = freq_binfirst[ib,...] + nb

    # return frequency with the original axis order
    return freq_binfirst.swapaxes(0, axis)


def get_frequency_multi_netcdf(flist, bins, variable, fname_out = None, aggdim = "time", verbose = True):
    '''
    Get frequency of occurrences of values in certain bins for multiple netCDF files.

    The files are processed one after the other and the counts are accumulated,
    so only one input file has to be held in memory at a time. The layout of
    the output (dimensions, coordinates, additional variables) is taken from
    the first file.

    Parameters:
        flist: list of input netCDF files (must not be empty)
        bins: list or array of bin boundaries (strictly increasing)
        variable: variable in netCDF files to be binned
        fname_out: name of output netCDF file (default: no file is saved, but xarray dataset is returned)
        aggdim: dimension along which the frequency is counted (default: "time")
        verbose: verbose mode (default: True)

    Returns
        xarray dataset with the data variable <variable>_freq, which holds the
        number of occurrences per bin. Its dimensions are those of variable with
        aggdim replaced by the new dimension "bin". The coordinate variables
        bin_ll and bin_ul hold the lower and upper bin limits. Data variables of
        the first file that do not depend on aggdim are copied.

    Raises
        ValueError: if bins are not 1-D or not strictly increasing, if flist is
            empty, or if variable has no dimension aggdim.
    '''
    bins = np.array(bins)

    if bins.ndim != 1:
        raise ValueError("Error: bins must be of dimension 1, but are of dimension {}!".format(bins.ndim))

    if np.any(np.diff(bins)<=0):
        raise ValueError("Error: bins must be stricly increasing!")

    # materialize the file list: allows iterators and an explicit emptiness check
    flist = list(flist)
    if not flist:
        raise ValueError("Error: flist must contain at least one file!")

    bins_ll = bins[:-1]
    bins_ul = bins[1:]

    freq_name = variable + "_freq"
    data_out = None
    pos_aggdim = None

    t0 = dt.datetime.now()

    for ff in flist:
        if verbose:
            print("Processing:", ff, "Time elapsed:", dt.datetime.now() - t0)

        with xr.open_dataset(ff) as dat:
            # on first file create the output dataset
            if data_out is None:
                if verbose:
                    print("Creating Outfile {} ...".format(fname_out))

                var_dims = list(dat[variable].dims)
                if aggdim not in var_dims:
                    raise ValueError("Error: variable {} has no dimension {} (dimensions: {})!".format(variable, aggdim, var_dims))
                pos_aggdim = var_dims.index(aggdim)

                # the aggregation dimension is replaced by the bin dimension
                dim_out_names = list(var_dims)
                dim_out_names[pos_aggdim] = "bin"

                dims_out = list(dat[variable].shape)
                dims_out[pos_aggdim] = len(bins_ll)
                dims_out = tuple(dims_out)

                data_vars_out = {freq_name: (dim_out_names, np.zeros(dims_out, dtype=int), dat[variable].attrs)}

                # keep all variables which are independent of the aggregation dimension
                data_vars_out.update({vv: (dat.variables[vv].dims, dat.variables[vv].values, dat.variables[vv].attrs) for vv in dat.data_vars if aggdim not in dat.variables[vv].dims})

                coords_out = {"bin":("bin",np.arange(len(bins_ll))),
                              "bin_ll":("bin", bins_ll, {"long_name": "lower limit of bins (inclusive)"}),
                              "bin_ul":("bin", bins_ul, {"long_name": "upper limit of bins (not inclusive)"})
                              }

                coords_out.update({dd:dat[dd] for dd in dat.dims if dd != aggdim})

                data_out = xr.Dataset(data_vars = data_vars_out,
                                      coords = coords_out)

            # add the counts of this file to the counts collected so far
            freq_tmp = get_frequency(bins, dat[variable].values, data_out[freq_name].values, axis = pos_aggdim, verbose = verbose)
            data_out[freq_name].values = freq_tmp

    if fname_out is not None:
        # write outfile
        data_out.to_netcdf(fname_out)

    return data_out


def percentile_of_score(frequency, bins, score, axis = 0):
    '''
    Estimate the percentile rank of scores from binned frequencies.

    The cumulative percentage at the lower and upper limit of every bin is
    computed from the counts. The percentile of a score is interpolated
    linearly between the two limits of the bin containing the score. A score
    lying exactly on a bin boundary is assigned to the lower of the two bins.
    Where the total count is zero the result is not a number (division by zero).

    Parameters:
        frequency: array-like of counts per bin, e.g. from get_frequency
        bins: list or array of bin boundaries (strictly increasing); its length
            must be the length of frequency along axis plus one
        score: scalar or array-like of scores within [bins[0], bins[-1]]
        axis: axis of frequency which holds the bins (default: 0)

    Returns
        Array of percentiles (0 to 100). The shape is like frequency, except
        that the length of axis is the number of scores.

    Raises
        ValueError: if bins are not 1-D or not strictly increasing, if frequency
            does not match bins, or if a score lies outside of the bins.
    '''
    # coerce first, so that lists, tuples and scalars are validated like arrays
    frequency = np.asanyarray(frequency)
    bins = np.asarray(bins)
    score = np.array([score]).ravel()

    # check bin dimension
    if bins.ndim != 1:
        raise ValueError("Error: bins must be of dimension 1, but are of dimension {}!".format(bins.ndim))

    # check if bins are strictly increasing
    if np.any(np.diff(bins)<=0):
        raise ValueError("Error: bins must be stricly increasing!")

    # check that frequency and bins belong together
    if frequency.shape[axis] != len(bins) - 1:
        raise ValueError("Error: frequency requires length {} along axis {}, but is of length {}!".format(len(bins) - 1, axis, frequency.shape[axis]))

    # check values of scores
    if np.any(score < bins[0]) | np.any(score > bins[-1]):
        raise ValueError("Error: Scores must be in the interval [{},{}]!".format(bins[0], bins[-1]))

    # get bin limits
    bins_ll = bins[:-1]
    bins_ul = bins[1:]

    # make sure first axis is along bins
    fqcy = frequency.swapaxes(axis,0)

    cum = np.cumsum(fqcy, axis = 0)
    total = cum[-1,...]

    # cumulative percentage at the upper and lower limit of each bin
    perc_ul = cum/total*100.
    perc_ll = (cum - fqcy)/total*100.

    # initialize output: one entry per score instead of one per bin
    perc_out = np.zeros((len(score),) + fqcy.shape[1:])

    for si,sc in enumerate(score):
        # first bin whose closed interval contains the score
        ind_score = np.nonzero((bins_ll <= sc) & (bins_ul >= sc))[0][0]

        perc_out[si, ...] = linear_interpolation(bins_ll[ind_score],
                                                 bins_ul[ind_score],
                                                 perc_ll[ind_score,...],
                                                 perc_ul[ind_score,...],
                                                 sc)

    return perc_out.swapaxes(axis,0)


def score_at_percentile(frequency, bins, percentile, axis = 0):
    '''
    Estimate the scores at given percentiles from binned frequencies.

    The cumulative percentage at the lower and upper limit of every bin is
    computed from the counts. For each percentile the bin containing it is
    looked up and the score is interpolated linearly between the limits of
    this bin. Percentile 0 yields the lower limit of the first non-empty bin,
    percentile 100 the upper limit of the last non-empty bin. Where the total
    count is zero the result is not a number (division by zero).

    Parameters:
        frequency: array-like of counts per bin, e.g. from get_frequency
        bins: list or array of bin boundaries (strictly increasing); its length
            must be the length of frequency along axis plus one
        percentile: scalar or array-like of percentiles within [0, 100]
        axis: axis of frequency which holds the bins (default: 0)

    Returns
        Array of scores. The shape is like frequency, except that the length
        of axis is the number of percentiles.

    Raises
        ValueError: if bins are not 1-D or not strictly increasing, if frequency
            does not match bins, or if a percentile lies outside of [0, 100].
    '''
    # accept lists and tuples like get_frequency does
    frequency = np.asanyarray(frequency)
    bins = np.asarray(bins)

    # check bin dimension
    if bins.ndim != 1:
        raise ValueError("Error: bins must be of dimension 1, but are of dimension {}!".format(bins.ndim))

    # check if bins are strictly increasing
    if np.any(np.diff(bins)<=0):
        raise ValueError("Error: bins must be stricly increasing!")

    # check that frequency and bins belong together
    if frequency.shape[axis] != len(bins) - 1:
        raise ValueError("Error: frequency requires length {} along axis {}, but is of length {}!".format(len(bins) - 1, axis, frequency.shape[axis]))

    # make sure percentiles are an array
    percentile = np.array([percentile]).ravel()

    # check values of percentiles
    if np.any(percentile < 0.0) | np.any(percentile > 100.0):
        raise ValueError("Error: Percentile must be in the interval [0.0,100.0]!")

    # get bin limits
    bins_ll = bins[:-1]
    bins_ul = bins[1:]

    # make sure first axis is along bins
    fqcy = frequency.swapaxes(axis,0)

    # calc cumsum
    cum = np.cumsum(fqcy, axis = 0)
    total = cum[-1,...]

    # get percentage of bins (lower and upper limit)
    perc_ul = cum/total*100.
    perc_ll = (cum - fqcy)/total*100.

    # flatten all non-bin dimensions, so that one bin per location can be picked
    perc_ul_r = perc_ul.reshape(perc_ul.shape[0],-1)
    perc_ll_r = perc_ll.reshape(perc_ll.shape[0],-1)
    loc_r = np.arange(perc_ll_r.shape[1])

    # initialize output: one entry per percentile instead of one per bin
    perc_out = np.zeros((len(percentile),) + fqcy.shape[1:])

    # iterate over percentile values
    for ip, pc in enumerate(percentile):
        # Mark all bins up to and including the one containing pc. As the
        # cumulative percentages never decrease along the bins, the number of
        # marked bins minus one is the index of the relevant bin.
        if pc == 0.0:
            # empty bins in the beginning have perc_ll == 0 too, thus the
            # first non-empty bin is selected
            upto = (perc_ll == pc) & (perc_ul >= pc)
        elif pc == 100.0:
            # empty bins at the end have perc_ll == 100 and are ignored
            upto = (perc_ll < pc) & (perc_ul <= pc)
        else:
            # second condition only drops locations without valid percentages
            upto = (perc_ll < pc) & ~np.isnan(perc_ul)

        ind_perc = upto.sum(axis = 0) - 1

        # flatten indices like the percentages
        ind_perc_r = ind_perc.reshape(-1)

        # get percentages at the limits of the selected bin
        pp_l = perc_ll_r[ind_perc_r,loc_r].reshape(ind_perc.shape)
        pp_u = perc_ul_r[ind_perc_r,loc_r].reshape(ind_perc.shape)

        # get bin limits
        b_l = bins_ll[ind_perc]
        b_u = bins_ul[ind_perc]

        # interpolate between lower and upper limit of bin
        perc_out[ip, ...] = linear_interpolation(pp_l,
                                                 pp_u,
                                                 b_l,
                                                 b_u,
                                                 pc)

    # return output with the original axis order
    return perc_out.swapaxes(0, axis)


def score_at_percentile2(frequency, percentile, bins, time_axis = None, bin_axis = 0, split_time = 1):
    '''
    Estimate scores at percentiles which differ per location (and optionally per time step).

    In contrast to score_at_percentile, which evaluates the same percentiles
    at every location, percentile is an array here. Apart from the time axis
    of percentile and the bin axis of frequency, both arrays must have the
    same dimensions in the same order.

    For each percentile value the score is interpolated linearly between
        - the lower limit of the first bin which shares its lower cumulative
          percentage with the last bin starting below the percentile (a run of
          empty bins directly below this bin is thus included), and
        - the upper limit of the first bin whose upper cumulative percentage
          reaches the percentile.

    Parameters:
        frequency: array-like of counts per bin
        percentile: array-like of percentiles (expected within [0, 100], not checked)
        bins: list or array of bin boundaries; its length must be the length
            of frequency along bin_axis plus one
        time_axis: axis of percentile which has no counterpart in frequency
            (default: None, percentile has no such axis)
        bin_axis: axis of frequency which holds the bins (default: 0)
        split_time: number of chunks the time axis is split into. The lookup
            builds temporary arrays of size bins x time x locations, more
            chunks reduce their size (default: 1).

    Returns
        Array of scores with the shape of percentile.

    Raises
        ValueError: if the shapes of frequency, percentile and bins do not match.
    '''
    frequency = np.asanyarray(frequency)
    percentile = np.asanyarray(percentile)
    bins = np.asarray(bins)

    # check dimension match
    f_shp_tmp = list(frequency.shape)
    f_shp_tmp.pop(bin_axis)
    p_shp_tmp = list(percentile.shape)
    if time_axis is not None:
        p_shp_tmp.pop(time_axis)

    if f_shp_tmp != p_shp_tmp:
        raise ValueError("Spatial Dimensions, exept time or bin, of frequency and percentile respectively do not match")

    if bins.ndim != 1:
        raise ValueError("bins are not 1-D")

    if frequency.shape[bin_axis] != len(bins) - 1:
        raise ValueError("len(bins) not equal to bin-dimension of frequency")

    # get bin limits
    bins_ll = bins[:-1]
    bins_ul = bins[1:]

    # make sure first axis is along bins; moveaxis keeps the order of the
    # remaining axes, so they line up with those of the percentiles
    fqcy = np.moveaxis(frequency, bin_axis, 0)

    # make sure first axis is along time; without a time axis a dummy axis of
    # length one is used and removed again at the end
    if time_axis is None:
        pctl = percentile[np.newaxis, ...]
    else:
        pctl = np.moveaxis(percentile, time_axis, 0)

    # calc cumsum
    cum = np.cumsum(fqcy, axis = 0)

    # get percentage of bins (lower and upper limit)
    perc_ul = cum/cum[-1,...]*100.
    perc_ll = (cum - fqcy)/cum[-1,...]*100.

    # set empty bins in the beginning (0 at both limits) to nan
    not_all_null = ~((perc_ll==0) & (perc_ul==0))
    perc_ul = np.where(not_all_null, perc_ul, np.nan)
    perc_ll = np.where(not_all_null, perc_ll, np.nan)
    del not_all_null

    # mark the remaining lower limit of 0 with -1, so that a percentile of 0
    # fulfills "perc_ll < percentile" as well
    perc_ll[perc_ll==0] = -1


    def bin_at_percentile(pctl_chunk):
        # broadcastable index arrays of all non-time dimensions
        idx_rest = np.meshgrid(*[np.arange(s) for s in pctl_chunk.shape], indexing = "ij", sparse = True)[1:]

        pctl_b = np.expand_dims(pctl_chunk,0)

        # get last bin, where pctl > perc_ll, thus argmax of reversed array and calc original index
        below = np.expand_dims(perc_ll,1) < pctl_b
        idx_l_tmp = perc_ll.shape[0] - np.argmax(np.flip(below,0),axis=0) - 1
        del below

        # get first bin with the same perc_ll, in case of multiple occurrences
        ll_at_tmp = perc_ll[(idx_l_tmp, *idx_rest,)]
        idx_l = np.argmax(np.expand_dims(perc_ll,1) == np.expand_dims(ll_at_tmp,0),axis=0)

        # get first bin, where pctl <= perc_ul
        idx_u = np.argmax(np.expand_dims(perc_ul,1) >= pctl_b,axis=0)

        # Undo the -1 marker on the picked values only. The shared arrays stay
        # untouched, otherwise later chunks would be looked up differently.
        x_l = perc_ll[(idx_l, *idx_rest,)]
        x_l = np.where(x_l < 0, 0., x_l)
        x_u = perc_ul[(idx_u, *idx_rest,)]
        x_u = np.where(x_u < 0, 0., x_u)

        return linear_interpolation(x_l,
                                    x_u,
                                    bins_ll[idx_l],
                                    bins_ul[idx_u],
                                    pctl_chunk)

    # initialize output
    score_out = np.zeros(pctl.shape)

    # splitting only makes sense along a real time axis
    if time_axis is None or split_time is None:
        n_chunks = 1
    else:
        n_chunks = split_time

    for ind_time in np.array_split(np.arange(pctl.shape[0]), n_chunks, axis = 0):
        # more chunks than time steps lead to empty chunks
        if ind_time.size == 0:
            continue
        score_out[ind_time, ...] = bin_at_percentile(pctl[ind_time, ...])

    if time_axis is None:
        # remove dummy time axis
        return score_out[0, ...]

    return np.moveaxis(score_out, 0, time_axis)


def frequency_bins_from_netcdf(freq_ds, bindim = "bin"):
    '''
    Helper function to obtain frequency and bins from file written by get_frequency_multi_netcdf

    The frequency is read from the first data variable which has the
    dimension bindim. The bins are rebuilt from the variables bin_ll and
    bin_ul. A dataset passed by the caller is left open, a file opened here is
    closed again.

    Parameters:
        freq_ds: file name of netCDF file or xr.Dataset
        bindim: name of bin dimension (default: "bin")
    Returns:
        frequency: array of frequencies, required by percentile_of_score and score_at_percentile
        bins: array of bins required by percentile_of_score and score_at_percentile
        pos_bindim: position of bin dimension
    Raises:
        ValueError: if no data variable has the dimension bindim
    '''
    opened_here = not isinstance(freq_ds, xr.Dataset)
    ds = xr.open_dataset(freq_ds) if opened_here else freq_ds

    try:
        candidates = [dd for dd in ds.data_vars if bindim in ds[dd].dims]
        if not candidates:
            raise ValueError("Error: no data variable with dimension {} found!".format(bindim))
        variable = candidates[0]

        frequency = ds[variable].load().values
        pos_bindim = list(ds[variable].dims).index(bindim)

        # all lower limits plus the upper limit of the last bin
        bins = np.array(list(ds["bin_ll"].values) + [ds["bin_ul"].values[-1]])
    finally:
        # only close what was opened here, the caller may still need its dataset
        if opened_here:
            ds.close()

    return frequency, bins, pos_bindim


def score_at_percentile_nc(freq_ds, percentile, fname_out = None, bindim = "bin"):
    '''
    Get scores at percentiles from a frequency dataset and return them as xarray dataset.

    Parameters:
        freq_ds: file name of netCDF file or xr.Dataset, as written by
            get_frequency_multi_netcdf
        percentile: scalar or list of percentiles within [0, 100]
        fname_out: name of output netCDF file (default: no file is saved, but xarray dataset is returned)
        bindim: name of bin dimension (default: "bin")

    Returns
        xarray dataset in which the bin dimension is replaced by the dimension
        "perc" with the coordinate variable "percentile". The scores are stored
        under the name of the frequency variable with "_freq" replaced by
        "_sco". Data variables which do not depend on bindim are copied.
    '''
    opened_here = not isinstance(freq_ds, xr.Dataset)
    ds = xr.open_dataset(freq_ds) if opened_here else freq_ds

    try:
        # Collect everything needed from the dataset before the frequencies
        # are read, so nothing has to be read from ds afterwards.
        variable = [dd for dd in ds.data_vars if bindim in ds[dd].dims][0]

        dim_out_names = list(ds[variable].dims)
        dim_out_names[dim_out_names.index(bindim)] = "perc"

        var_attrs = ds[variable].attrs

        other_vars = {vv: (ds.variables[vv].dims, ds.variables[vv].values, ds.variables[vv].attrs)
                      for vv in ds.data_vars if bindim not in ds.variables[vv].dims}

        dim_coords = {dd:ds[dd].load() for dd in ds.dims if dd != bindim}

        # reuse the open dataset instead of opening the file a second time
        freq, bins, pos_bindim = frequency_bins_from_netcdf(ds, bindim = bindim)
    finally:
        # a file opened here is not needed any more
        if opened_here:
            ds.close()

    # make sure percentiles are an array
    percentile = np.array([percentile]).ravel()

    # calculate scores at percentiles
    perc_out = score_at_percentile(freq, bins, percentile, axis = pos_bindim)

    # create output dataset
    data_vars_out = {variable.replace("_freq", "_sco"): (dim_out_names, perc_out, var_attrs)}
    data_vars_out.update(other_vars)

    coords_out = {"perc":("perc", np.arange(len(percentile))),
                  "percentile":("perc", percentile, {"long_name":"percentile"})}
    coords_out.update(dim_coords)

    data_out = xr.Dataset(data_vars = data_vars_out,
                            coords = coords_out)

    if fname_out is not None:
        # write outfile
        data_out.to_netcdf(fname_out)

    return data_out


def score_at_percentile_nc2(freq_ds, freq_var, perc_ds, perc_var, fname_out = None, bindim = "bin", time_axis="time", split_time = 1):
    '''
    Get scores at percentiles which are read from a second dataset.

    Frequencies and bins are read with frequency_bins_from_netcdf, i.e. from
    the first data variable of freq_ds which has the dimension bindim.
    freq_var is only used to look up the position of the bin dimension
    (NOTE(review): presumably both are the same variable - confirm).

    Parameters:
        freq_ds: file name of netCDF file or xr.Dataset with the frequencies
        freq_var: name of the frequency variable in freq_ds
        perc_ds: file name of netCDF file or xr.Dataset with the percentiles
        perc_var: name of the percentile variable in perc_ds
        fname_out: currently ignored, no file is written (see TODO below)
        bindim: name of bin dimension in freq_ds (default: "bin")
        time_axis: name of the time dimension of perc_var (default: "time").
            None is passed on to score_at_percentile2 as "no time axis".
        split_time: number of chunks the time axis is split into (default: 1)

    Returns
        Array of scores as returned by score_at_percentile2.

    Raises
        ValueError: if freq_var has no dimension bindim or perc_var has no
            dimension time_axis.
    '''
    opened = []

    try:
        # open infiles
        if not isinstance(freq_ds, xr.Dataset):
            freq_ds = xr.open_dataset(freq_ds)
            opened.append(freq_ds)

        if not isinstance(perc_ds, xr.Dataset):
            perc_ds = xr.open_dataset(perc_ds)
            opened.append(perc_ds)

        # look up the axis positions by name; a missing name is an error
        # instead of silently falling back to the first axis
        freq_dims = list(freq_ds[freq_var].dims)
        if bindim not in freq_dims:
            raise ValueError("Error: variable {} has no dimension {} (dimensions: {})!".format(freq_var, bindim, freq_dims))
        bin_pos = freq_dims.index(bindim)

        perc_dims = list(perc_ds[perc_var].dims)
        if time_axis is None:
            time_pos = None
        elif time_axis in perc_dims:
            time_pos = perc_dims.index(time_axis)
        else:
            raise ValueError("Error: variable {} has no dimension {} (dimensions: {})!".format(perc_var, time_axis, perc_dims))

        print("get frequency")
        freq, bins, pos_bindim = frequency_bins_from_netcdf(freq_ds, bindim = bindim)

        print("read percentiles")
        percentile = perc_ds[perc_var].data

        # calculate scores at percentiles
        print("calc score")
        perc_out = score_at_percentile2(freq, percentile, bins, time_axis = time_pos, bin_axis = bin_pos, split_time = split_time)
    finally:
        # close only the files opened here
        for ds_tmp in opened:
            ds_tmp.close()

    # TODO: wrap the scores into an xr.Dataset and write it to fname_out, like
    # score_at_percentile_nc does. Until then the plain array is returned.
    return perc_out


def percentile_of_score_nc(freq_ds, score, fname_out = None, bindim = "bin"):
    '''
    Get percentiles of scores from a frequency dataset and return them as xarray dataset.

    Parameters:
        freq_ds: file name of netCDF file or xr.Dataset, as written by
            get_frequency_multi_netcdf
        score: scalar or list of scores within the range of the bins
        fname_out: name of output netCDF file (default: no file is saved, but xarray dataset is returned)
        bindim: name of bin dimension (default: "bin")

    Returns
        xarray dataset in which the bin dimension is replaced by the dimension
        "sco" with the coordinate variable "score". The percentiles are stored
        under the name of the frequency variable with "_freq" replaced by
        "_perc". Data variables which do not depend on bindim are copied.
    '''
    opened_here = not isinstance(freq_ds, xr.Dataset)
    ds = xr.open_dataset(freq_ds) if opened_here else freq_ds

    try:
        # Collect everything needed from the dataset before the frequencies
        # are read, so nothing has to be read from ds afterwards.
        variable = [dd for dd in ds.data_vars if bindim in ds[dd].dims][0]

        dim_out_names = list(ds[variable].dims)
        dim_out_names[dim_out_names.index(bindim)] = "sco"

        var_attrs = ds[variable].attrs

        other_vars = {vv: (ds.variables[vv].dims, ds.variables[vv].values, ds.variables[vv].attrs)
                      for vv in ds.data_vars if bindim not in ds.variables[vv].dims}

        dim_coords = {dd:ds[dd].load() for dd in ds.dims if dd != bindim}

        # reuse the open dataset instead of opening the file a second time
        freq, bins, pos_bindim = frequency_bins_from_netcdf(ds, bindim = bindim)
    finally:
        # a file opened here is not needed any more
        if opened_here:
            ds.close()

    # make sure scores are an array
    score = np.array([score]).ravel()

    # calculate percentile of score
    perc_out = percentile_of_score(freq, bins, score, axis = pos_bindim)

    # create output dataset
    data_vars_out = {variable.replace("_freq", "_perc"): (dim_out_names, perc_out, var_attrs)}
    data_vars_out.update(other_vars)

    coords_out = {"sco":("sco", np.arange(len(score))),
                  "score":("sco", score, {"long_name":"score"})}
    coords_out.update(dim_coords)

    data_out = xr.Dataset(data_vars = data_vars_out,
                            coords = coords_out)

    if fname_out is not None:
        # write outfile
        data_out.to_netcdf(fname_out)

    return data_out



    
        
if __name__ == "__main__":

    # Example: count daily precipitation per bin for the Spartacus files of
    # 1981 - 1983, then derive scores at percentiles and percentiles of scores.
    flist = sorted(glob.glob("/nas4/Observations/Spartakus/RR_1961-2015/Prec_??8[1-3].nc"))
    print("Files to read:", flist)

    # bin boundaries; np.arange does not include high_lim itself
    low_lim = 0.
    high_lim = 300.
    interval = 1.
    bin_edges = np.arange(low_lim, high_lim, interval)

    # file which holds the bin counts
    freq_file = "/tmp/Bin_Count_Spartacus.nc"

    # calculate the bin count
    ds_out = get_frequency_multi_netcdf(flist, bin_edges, "pr",
                                        freq_file, aggdim = "time", verbose = True)

    # scores at percentiles, one map per percentile
    percentiles = [50., 70., 80., 90., 95., 99., 99.9, 99.95]
    sop = score_at_percentile_nc(freq_file, percentiles)

    for pp, _ in enumerate(percentiles):
        plt.figure()
        sop.pr_sco[pp].plot()
        plt.show()

    # percentiles of scores, one map per score
    scores = [1., 10., 50., 75., 100]
    pos = percentile_of_score_nc(freq_file, scores)

    for pp, _ in enumerate(scores):
        plt.figure()
        pos.pr_perc[pp].plot()
        plt.show()





