#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 11 10:53:05 2022

@author: mphilipp
"""

import pandas as pd
import xarray as xr
import numpy as np
import glob

# Path information (placeholder directories -- adapt before running).
# Directory searched for the bias-corrected daily "*sfcWind*" input files.
bias_corrected_path = "/path/to/bias/corrected/data/"
# Directory receiving the monthly hourly netCDF output files.
disaggregated_path = "/path/to/temporaly/disaggregated/windspeed/"
# Directory holding "wspd_fractions_full_domain.nc".
wspd_fractions_path = "/path/to/windspeed/fractions/"

def apply_rolling_window(wspd_hourly, iy, ix, window=5):
    '''
    Smooth one grid cell's hourly windspeed with a centred rolling mean.

    Building hourly values by repeating daily values produces steps at every
    day change; a centred moving average over ``window`` hours smooths them.
    Wherever the rolling mean is undefined (NaN) -- near both ends of the
    series, where the window is incomplete, and for every window that
    contains a NaN -- the original value is kept.

    Parameters
    ----------
    wspd_hourly : array-like
        Hourly windspeed of a single grid cell; it is flattened to 1-D.
    iy : int
        Index of lat. Returned unchanged.
    ix : int
        Index of lon. Returned unchanged.
    window : int, optional
        Width of the centred rolling window in hours (default 5).

    Returns
    -------
    iy : int
        Index of lat.
    ix : int
        Index of lon.
    wspd_hourly : numpy.ndarray
        1-D floating-point array of the smoothed hourly windspeed.

    '''
    values = np.asarray(wspd_hourly).flatten()
    # pandas needs a full window (min_periods defaults to the window size),
    # so incomplete or NaN-containing windows yield NaN here.
    rolling_mean = pd.Series(values).rolling(window, center=True).mean().to_numpy()
    smoothed = np.where(np.isnan(rolling_mean), values, rolling_mean)
    return iy, ix, smoothed
    
def disaggegrate_windspeed(wspd_daily, frac, window=5):
    '''
    Generate hourly windspeed from daily windspeed.

    Each daily value is repeated for the 24 hours of its day. The resulting
    step function is smoothed with a centred rolling mean of ``window`` hours,
    using the same rule as ``apply_rolling_window``: hours where the rolling
    mean is undefined keep their repeated daily value. The smoothed series is
    then multiplied with the fractions calculated with
    1_Calculate_WspdFractions_WholeDomain.py.

    Parameters
    ----------
    wspd_daily : xarray.DataArray
        Daily windspeed with ``time``, ``lat`` and ``lon`` coordinates. Its
        values are assumed to be ordered (time, lat, lon), with one entry per
        calendar day and no gaps -- TODO confirm against the input files.
    frac : xarray.DataArray
        Fractions with which to multiply. ``frac.values`` must broadcast
        against the hourly (time, lat, lon) array, i.e. one entry per hour.
    window : int, optional
        Width of the centred smoothing window in hours (default 5).

    Returns
    -------
    wspd_hourly : xarray.DataArray
        Hourly windspeed named ``sfcWind`` with dims (time, lat, lon).

    '''
    # Hourly time axis from 00:00 of the first day to 23:00 of the last day.
    date_start = pd.to_datetime(wspd_daily.time.values[0]).normalize()
    date_end = pd.to_datetime(wspd_daily.time.values[-1]).normalize()
    time_hourly = pd.date_range(start=date_start, end=date_end + pd.DateOffset(1), freq="H")[:-1]

    # Repeat every daily value for the 24 hours of its day.
    wspd_hourly = xr.DataArray(
        data=np.repeat(wspd_daily.values, 24, axis=0),
        dims=["time", "lat", "lon"],
        coords=dict(lon=wspd_daily.lon, lat=wspd_daily.lat, time=time_hourly),
        attrs=dict(description="Hourly Disaggregated Windspeed", units="ms-1"),
        name="sfcWind",
    )

    # Smooth the day transitions for all grid cells at once. This is the
    # vectorised equivalent of calling apply_rolling_window on every cell.
    # The rolling mean needs a full window, so the series ends and windows
    # containing NaN fall back to the unsmoothed value. The result is cast
    # back to the data dtype, as the former per-cell assignment did.
    rolling_mean = wspd_hourly.rolling(time=window, center=True).mean().values
    wspd_hourly.values = np.where(
        np.isnan(rolling_mean), wspd_hourly.values, rolling_mean
    ).astype(wspd_hourly.dtype)

    # Multiply with fractions
    return wspd_hourly * frac.values
    
def _replace_nans_and_save(DS, parameter, nc_metadata, model, year):
    '''
    Write one netCDF file per calendar month contained in ``DS``.

    NaNs are not replaced in memory. Instead, the netCDF encoding of
    ``parameter`` uses ``_FillValue`` = 1e+20, so missing values are written
    as that fill value. Coordinates are written without a ``_FillValue``.

    Parameters
    ----------
    DS : xarray.DataArray
        Hourly data named ``parameter`` with a ``time`` coordinate.
    parameter : str
        Variable name, used as the data variable and in the file name.
    nc_metadata : dict
        Global attributes added to every file.
    model : str
        Model identifier used in the file name.
    year : int
        Year used in the file name.
    '''
    fill_value = 1e+20

    for month, DS_month in DS.groupby("time.month"):
        DS_month = DS_month.assign_attrs({"cell_methods": "time: disaggregated"})
        DS_month = DS_month.to_dataset().assign_attrs(nc_metadata)

        # No fill value on coordinates; 1e+20 for the data variable.
        encoding = {cname: {"_FillValue": None} for cname in DS_month.coords}
        encoding[parameter] = {"_FillValue": fill_value}

        out_file = f"{disaggregated_path}{model}_{year}_{int(month):02d}_{parameter}_hourly.nc"
        DS_month.to_netcdf(out_file, unlimited_dims='time', encoding=encoding)


def _find_bias_corrected_file(model, decade):
    '''
    Return the first (sorted) bias-corrected sfcWind file whose name contains
    both ``model`` and ``decade``.

    Raises
    ------
    FileNotFoundError
        If no file matches the search pattern.
    '''
    pattern = bias_corrected_path + "*sfcWind*" + model + "*" + str(decade) + "*.nc"
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError("No bias-corrected sfcWind file matches " + pattern)
    return matches[0]


def _load_fractions():
    '''
    Load the windspeed fractions and stack them into hourly time series.

    Returns
    -------
    frac_leap_year : xarray.DataArray
        Fractions with dims (time, lat, lon). ``time`` stacks
        (dayofyear, hour) over all days in the file.
    frac_normal_year : xarray.DataArray
        The same fractions without dayofyear 60. Presumably that is Feb 29
        of a 366-day climatology -- confirm against the fractions script.
    '''
    with xr.open_dataset(wspd_fractions_path + "wspd_fractions_full_domain.nc") as frac_ds:
        # The fractions file may hold an unnamed DataArray. Select it
        # explicitly: attribute access such as ``da.frac`` never resolves the
        # array's own name, so it would fail on a DataArray.
        if "__xarray_dataarray_variable__" in frac_ds:
            frac = frac_ds["__xarray_dataarray_variable__"].rename("frac")
        else:
            frac = frac_ds["frac"]

        # .load() so both arrays stay usable after the file is closed.
        frac_leap_year = (
            frac.stack(time=("dayofyear", "hour"))
            .transpose("time", "lat", "lon")
            .load()
        )
        frac_normal_year = (
            frac.where(frac.dayofyear != 60, drop=True)
            .stack(time=("dayofyear", "hour"))
            .transpose("time", "lat", "lon")
            .load()
        )
    return frac_leap_year, frac_normal_year


def generate_disaggregated__wspd_CORDEX_ncs():
    '''
    This function generates disaggregated CORDEX wspd ncs of the years 1951-2100.

    For every model and year, the daily bias-corrected sfcWind is
    disaggregated to hourly values with ``disaggegrate_windspeed`` and written
    as one netCDF file per month.

    Parameters
    ----------

    Returns
    -------
    None.

    Raises
    ------
    FileNotFoundError
        If no bias-corrected input file matches a model/decade.

    '''
    models = ["ICHEC-EC-EARTH_KNMI-RACMO22E_rcp45", "ICHEC-EC-EARTH_KNMI-RACMO22E_rcp85"]
    start_year = 1951
    end_year = 2100

    # Two fraction sets: one for leap years and one for normal years.
    frac_leap_year, frac_normal_year = _load_fractions()

    for model in models:
        nc_metadata = {'Conventions': 'CF-1.4', 'contact': 'philipp.maier@boku.ac.at', 'experiment': model,
                       'realization': '1', 'driving_experiment': 'ICHEC-EC-EARTH,r12i1p1',
                       'driving_model_id': 'ICHEC-EC-EARTH', 'driving_model_ensemble_member': 'r12i1p1',
                       'model_id': 'KNMI-RACMO22E', 'rcm_version_id': 'v1', 'project_id': 'CORDEX',
                       'CORDEX_domain': 'EUR-11', 'frequency': 'hour'}

        # Loops over decades (output files of the bias correction).
        for decade in range(start_year, end_year + 1, 10):
            # The decades 1961 and 2091 are read from the file matched by the
            # previous decade (1951 / 2081). Presumably those bias-corrected
            # files span 20 years -- confirm against the bias-correction output.
            file_decade = decade - 10 if decade in (1961, 2091) else decade

            with xr.open_dataset(_find_bias_corrected_file(model, file_decade)) as wspd_ds:
                wspd = wspd_ds["sfcWind"]

                for year in range(decade, decade + 10):
                    wspd_yearly = wspd.sel(time=str(year))

                    if len(wspd_yearly.time) == 366:
                        frac_year = frac_leap_year
                    else:
                        frac_year = frac_normal_year
                    wspd_hourly = disaggegrate_windspeed(wspd_yearly, frac_year)

                    # Add metadata of the input file and save.
                    wspd_hourly = wspd_hourly.assign_attrs(wspd.attrs)
                    _replace_nans_and_save(wspd_hourly, "sfcWind", nc_metadata, model, year)
             
if __name__ == "__main__":
    generate_disaggregated__wspd_CORDEX_ncs()
    
