#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 11 10:53:05 2022

@author: mphilipp
"""

import pandas as pd
import xarray as xr
import numpy as np
import glob

# Path information
# NOTE: file names are appended to these paths by plain string concatenation,
# so every path must end with a path separator.
# Directory searched (via glob) for the bias-corrected "rsds" NetCDF input files.
bias_corrected_path = "/path/to/bias/corrected/data/"
# Output directory for the monthly "<model>_<year>_<month>_<parameter>_hourly.nc" files.
disaggregated_path = "/path/to/temporaly/disaggregated/radiaton/"
# Directory that contains "rsds_fractions_full_domain.nc".
rsds_fractions_path = "/path/to/radiation/fractions/"

def apply_rolling_window(rsds_hourly, iy,ix):
    '''
    Smooths the hourly series of one grid cell with a centred 5-step rolling mean.

    The input is flattened to one dimension first. Wherever the rolling mean is
    NaN (the first two and last two steps, and every window that contains a NaN)
    the original value is kept, so the output has the same length as the
    flattened input.

    Parameters
    ----------
    rsds_hourly : array-like
        Hourly radiation of a single grid cell.
    iy : int
        Index of lat (passed through unchanged).
    ix : int
        Index of lon (passed through unchanged).

    Returns
    -------
    iy : int
        Index of lat.
    ix : int
        Index of lon.
    rsds_hourly : numpy.ndarray
        1-D hourly radiation with the rolling mean applied.

    '''

    original_values = np.array(rsds_hourly).flatten()

    # Centred window: each step is averaged with its two neighbours on either side.
    windowed = pd.Series(original_values).rolling(window=5, center=True)
    centred_mean = windowed.mean().to_numpy()

    # Fall back to the unsmoothed value where no complete window is available.
    no_full_window = np.isnan(centred_mean)
    smoothed = np.where(no_full_window, original_values, centred_mean)
    return iy, ix, smoothed
    
def disaggegrate_radiation(rsds_daily,frac):
    '''
    Generates hourly radiation by repeating every daily value 24 times, smoothing
    each grid cell's series with apply_rolling_window and multiplying the result
    with the fractions calculated with 1_Calculate_GlobFractions_WholeDomain.py.

    Parameters
    ----------
    rsds_daily : xarray.DataArray
        Daily radiation. Its values are repeated along the first axis and written
        into a (time, lat, lon) array, and it must provide "time", "lat" and "lon".
    frac : xarray.DataArray
        Fractions with which to multiply. Indexed by position as frac[i, :, :],
        where i is the hourly time step, so it needs at least one entry per hour.

    Returns
    -------
    rsds_hourly : xarray.DataArray
        Hourly radiation named "rsds" with dimensions (time, lat, lon) and all
        negative values set to 0.

    '''

    # Hourly time axis from 00:00 of the first day to 23:00 of the last day.
    date_start = pd.to_datetime(rsds_daily.time.values[0]).normalize()
    date_end = pd.to_datetime(rsds_daily.time.values[-1]).normalize()
    # "h" is used because the upper-case "H" alias is deprecated in pandas.
    # The range ends at midnight after the last day; that extra stamp is dropped.
    time_hourly = pd.date_range(start=date_start, end=pd.Timestamp(date_end)+pd.DateOffset(1), freq="h")[:-1]

    # Empty container with the dimensions hourly_time x lat x lon and some metadata.
    # NOTE(review): "Gloabl" in the description is a typo for "Global"; it is left
    # unchanged here because the string ends up in the output metadata.
    rsds_hourly = xr.DataArray(data=np.zeros((len(time_hourly),len(rsds_daily.lat),len(rsds_daily.lon)),dtype=np.float32),
                             dims=["time", "lat", "lon"],
                             coords=dict(lon=rsds_daily.lon,lat=rsds_daily.lat,time=time_hourly,),
                             attrs=dict(description="Hourly Disaggregated Gloabl Radiation",units="Wm-2"))
    rsds_hourly = rsds_hourly.rename("rsds")

    # Repeats daily values 24 times
    rsds_hourly.values = np.repeat(rsds_daily.values,24,axis =0)

    # Applies rolling mean per grid cell to ensure smooth transitions during polar days.
    # The helper also hands back the indices, which are not needed here.
    for iy, ix in np.ndindex(rsds_hourly[0,:,:].shape):
        _, _, rsds_hourly[:,iy,ix] = apply_rolling_window(rsds_hourly[:,iy,ix],iy,ix)

    # Multiply with fractions, one time step at a time: a single whole-array
    # multiplication of the values reportedly hit an internal bug.
    for i in range(len(time_hourly)):
        rsds_hourly[i,:,:] = rsds_hourly[i,:,:]*frac[i,:,:]

    # Ensure, that no negative values occur
    rsds_hourly.values = np.where(rsds_hourly<0,0,rsds_hourly)
    return rsds_hourly
    
def generate_disaggregated_rsds_CORDEX_ncs():
    '''
    Generates disaggregated (hourly) CORDEX radiation ncs of the years 1951-2100.

    For every model in the hard-coded model list the bias-corrected "rsds" files
    are opened decade by decade, each year is disaggregated with
    disaggegrate_radiation using the leap-year or normal-year fractions, and the
    result is written as one NetCDF file per month into disaggregated_path.

    Parameters
    ----------

    Returns
    -------
    None.

    Raises
    ------
    FileNotFoundError
        If no bias-corrected file matches the pattern built for a decade.

    '''

    def save_monthly_files(DS, parameter, nc_metadata, model, year):
        # Splits one year of hourly data by calendar month and writes one file per month.
        # NaNs are not replaced explicitly; the fill value is only declared through
        # the encoding of the data variable.
        _FillValue = 1e+20

        for month, DS_month in DS.groupby("time.month"):
            # Two-digit month for the file name
            str_month = f"{int(month):02d}"

            # Adds metadata: cell_methods on the variable, global attributes on the dataset
            DS_month = DS_month.assign_attrs({"cell_methods": "time: disaggregated"})
            DS_month = DS_month.to_dataset()
            DS_month = DS_month.assign_attrs(nc_metadata)

            # Coordinates get no fill value, only the data variable does
            encoding = {cname: {"_FillValue": None} for cname in (*DS_month.coords, "time", "lat", "lon")}
            encoding[parameter] = {"_FillValue": _FillValue}

            file_name = f"{disaggregated_path}{model}_{year}_{str_month}_{parameter}_hourly.nc"
            DS_month.to_netcdf(file_name, unlimited_dims='time', encoding=encoding)

    models = ["ICHEC-EC-EARTH_KNMI-RACMO22E_rcp45","ICHEC-EC-EARTH_KNMI-RACMO22E_rcp85"]
    Start_year = 1951
    End_year = 2100

    # Open fractions nc file; the context manager closes it once all models are done
    with xr.open_dataset(rsds_fractions_path+"rsds_fractions_full_domain.nc") as frac_file:
        # The fractions may be stored under xarray's default variable name (when the
        # file was written from an unnamed DataArray) or directly as "frac_rsds".
        # Previously the renamed DataArray was accessed as ".frac_rsds" afterwards,
        # which fails on a DataArray, so the default-name case never worked.
        if "__xarray_dataarray_variable__" in frac_file:
            frac = frac_file["__xarray_dataarray_variable__"].rename("frac_rsds")
        else:
            frac = frac_file["frac_rsds"]

        # Creates two fractions, one for normal and one for leap year.
        # The normal-year version drops day of year 60 before stacking.
        frac_leap_year = frac.stack(time=("dayofyear","hour")).transpose("time","lat","lon")
        frac_normal_year = frac.where(frac.dayofyear!=60,drop = True).stack(time=("dayofyear","hour")).transpose("time","lat","lon")

        # Loops over models
        for model in models:

            nc_metadata = {'Conventions': 'CF-1.4', 'contact': 'philipp.maier@boku.ac.at', 'experiment': model, 'realization': '1', 'driving_experiment': 'ICHEC-EC-EARTH,r12i1p1', 'driving_model_id': 'ICHEC-EC-EARTH', 'driving_model_ensemble_member': 'r12i1p1', 'model_id': 'KNMI-RACMO22E', 'rcm_version_id': 'v1','project_id': 'CORDEX', 'CORDEX_domain': 'EUR-11', 'frequency': 'hour'}

            # Loops over decades (Output files of the Bias Correction)
            for decade in range(Start_year,End_year+1,10):
                # NOTE(review): for 1961 and 2091 the file of the preceding decade is
                # opened - presumably those files span two decades; confirm against
                # the contents of bias_corrected_path.
                file_decade = decade-10 if decade in (1961, 2091) else decade
                pattern = bias_corrected_path+"*rsds*"+model+"*"+str(file_decade)+"*.nc"
                matches = glob.glob(pattern)
                if not matches:
                    raise FileNotFoundError("No bias-corrected rsds file matches "+pattern)

                # Only the first match is used; the file is closed after its ten years
                with xr.open_dataset(matches[0]) as decade_file:
                    rsds = decade_file["rsds"]

                    for year in range(decade,decade+10,1):
                        # Select year
                        rsds_yearly = rsds.sel(time = str(year))

                        # disaggregate; 366 daily entries mark a leap year
                        if len(rsds_yearly.time.values) == 366:
                            rsds_hourly = disaggegrate_radiation(rsds_yearly, frac_leap_year)
                        else:
                            rsds_hourly = disaggegrate_radiation(rsds_yearly, frac_normal_year)

                        # Assign attributes and save
                        rsds_hourly = rsds_hourly.assign_attrs(rsds.attrs)
                        save_monthly_files(rsds_hourly, "rsds", nc_metadata, model, year)
            
# Run the complete disaggregation for all models when executed as a script.
if __name__ == "__main__":
    generate_disaggregated_rsds_CORDEX_ncs()
    
