#!/usr/bin/env python3
# -*- coding: utf-8 -*-
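"""
Temporal disaggregation of bias-corrected daily CORDEX tmax/tmin into hourly
temperature (tas), using the sun times created with
1_Calculate_SunTimes_WholeDomain.py.

Created on Tue Jan 11 10:53:05 2022

@author: mphilipp
"""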

import pandas as pd
import xarray as xr
import numpy as np
import glob

# Path information.
# The two directory paths are concatenated directly with file names / glob
# patterns, so they must end with a "/".
# Directory holding the bias-corrected daily CORDEX files (tasmax / tasmin).
bias_corrected_path = "/path/to/bias/corrected/data/"
# Output directory for the hourly (temporally disaggregated) "tas" files.
disaggregated_path = "/path/to/temporaly/disaggregated/temperature/"
# NetCDF file with sunrise / solar noon / day length, created by
# 1_Calculate_SunTimes_WholeDomain.py.
sun_times_file = "/path/to/output/sun_times.nc"


# Parameters of the disaggregation scheme.
_HOURS_PER_DAY = 24        # hourly output resolution
_DEFAULT_SHIFT_HOURS = 2   # tmax is placed this many hours after solar noon
_DAYLENGTH_THRES = 3       # day length below which a day counts as polar night (presumably hours)
_MIN_LOC_POLAR = 6         # hour of tmin assumed during polar night
_MAX_LOC_POLAR = 18        # hour of tmax assumed during polar night


def _prepare_sun_times(sun_times, leap_year):
    '''
    Return ``(sunrise, sunnoon)`` for one year with polar days and nights treated.

    Parameters
    ----------
    sun_times : Dataset
        Sun times with ``sunrise``, ``sunnoon`` and ``daylength`` indexed by
        ``dayofyear``. Assumed to follow a 366-day (leap-year) calendar, as
        implied by the 29 February handling below.
    leap_year : bool
        Whether the year being disaggregated has 366 days.

    Returns
    -------
    sunrise, sunnoon : DataArray
        New arrays labelled by ``dayofyear`` of the target year. ``sun_times``
        itself is never written to.
    '''
    sunrise = sun_times.sunrise
    sunnoon = sun_times.sunnoon
    daylength = sun_times.daylength

    if not leap_year:
        # Drop 29 February (day 60) and renumber 1..365 so that the labels
        # match the day of year of a common year.
        common_year_days = np.arange(1, 366)
        sunrise = sunrise.where(sunrise.dayofyear != 60, drop=True).assign_coords({"dayofyear": common_year_days})
        sunnoon = sunnoon.where(sunnoon.dayofyear != 60, drop=True).assign_coords({"dayofyear": common_year_days})
        daylength = daylength.where(daylength.dayofyear != 60, drop=True).assign_coords({"dayofyear": common_year_days})

    # np.where always builds new arrays, so the caller's data stays untouched.
    polar_night = (daylength < _DAYLENGTH_THRES).values
    # Polar day (solar noon before sunrise): put the minimum at 01:00.
    sunrise_vals = np.where((sunnoon < sunrise).values, 1, sunrise.values)
    # Polar night: fixed hours for min and max. The shift is subtracted here
    # because it is added back when the hour of the maximum is computed.
    sunrise_vals = np.where(polar_night, _MIN_LOC_POLAR, sunrise_vals)
    sunnoon_vals = np.where(polar_night, _MAX_LOC_POLAR - _DEFAULT_SHIFT_HOURS, sunnoon.values)

    return sunrise.copy(data=sunrise_vals), sunnoon.copy(data=sunnoon_vals)


def _disaggregate_days(tmin_cur, tmax_cur, tmin_next, tmax_before, sunrise, sunnoon):
    '''
    Disaggregate a run of consecutive days into hourly temperatures.

    All inputs are arrays of shape (days, lat, lon), aligned day by day.
    ``tmin_next`` holds the following day's minimum and ``tmax_before`` the
    previous day's maximum. ``sunrise`` / ``sunnoon`` are hours of the day,
    already treated for polar days and nights.

    Returns an array of shape (24 * days, lat, lon) starting at 00:00 of the
    first day.
    '''
    def to_hourly(daily):
        # Repeat every daily field for each hour of that day.
        return np.repeat(daily, _HOURS_PER_DAY, axis=0)

    # Hours at which tmin / tmax occur.
    min_loc = to_hourly(np.round(sunrise).astype(int))
    max_loc = to_hourly(np.round(sunnoon).astype(int) + _DEFAULT_SHIFT_HOURS)

    tmin_cur_h = to_hourly(tmin_cur)
    tmax_cur_h = to_hourly(tmax_cur)

    # Hour of day (0..23) for every hourly step, broadcast over lat/lon.
    hour = (np.arange(min_loc.shape[0]) % _HOURS_PER_DAY)[:, np.newaxis, np.newaxis]

    # Until tmax the curve refers to today's minimum; afterwards it heads
    # towards tomorrow's minimum.
    min_val = np.where(hour < max_loc, tmin_cur_h, to_hourly(tmin_next))
    # Before tmin the curve still descends from yesterday's maximum.
    max_val = np.where(hour < min_loc, to_hourly(tmax_before), tmax_cur_h)

    delta_val = max_val - min_val
    v_trans = min_val + delta_val / 2.

    before_min = hour <= min_loc
    between_min_max = (hour > min_loc) & (hour < max_loc)
    after_max = hour >= max_loc

    # Three-segment cosine curve. It appears to follow MELODIST's
    # 'sine_min_max' temperature method (confirm).
    night_length = _HOURS_PER_DAY - (max_loc - min_loc)
    values = (
        before_min * (v_trans + delta_val / 2. * np.cos(np.pi / night_length * (_HOURS_PER_DAY - max_loc + hour)))
        + between_min_max * (v_trans + delta_val / 2. * np.cos(1.25 * np.pi + 0.75 * np.pi / (max_loc - min_loc) * (hour - min_loc)))
        + after_max * (v_trans + delta_val / 2. * np.cos(np.pi / night_length * (hour - max_loc)))
    )

    # Set the extremes exactly. The rising segment does not reach tmin at
    # min_loc, and this also prevents rounding errors.
    values = np.where(hour == min_loc, tmin_cur_h, values)
    values = np.where(hour == max_loc, tmax_cur_h, values)
    return values


def disaggegrate_temperature_array_like(temp_daily_max_total, temp_daily_min_total, sun_times):
    '''
    Disaggregate one year of daily tmax/tmin into hourly temperature.

    The daily minimum is placed at (rounded) sunrise. The maximum is placed
    ``_DEFAULT_SHIFT_HOURS`` after (rounded) solar noon. The hours in between
    are filled with a piecewise cosine curve.

    Polar days put the minimum at 01:00. Polar nights (day length below
    ``_DAYLENGTH_THRES``) use the fixed hours ``_MIN_LOC_POLAR`` /
    ``_MAX_LOC_POLAR``.

    The previous-day maximum and next-day minimum come from the adjacent day,
    also across month boundaries. At the first/last day of the input, the day
    itself is used.

    The work is done month by month to limit peak memory.

    Parameters
    ----------
    temp_daily_max_total : DataArray
        Daily CORDEX tmax of one year. Assumed ordered (time, lat, lon) with
        1-D ``lat`` / ``lon`` coordinates.
    temp_daily_min_total : DataArray
        Daily CORDEX tmin on the same days and grid as ``temp_daily_max_total``.
    sun_times : Dataset
        Sun times calculated with 1_Calculate_SunTimes_WholeDomain.py
        (``sunrise``, ``sunnoon``, ``daylength`` indexed by ``dayofyear``).
        Not modified.

    Returns
    -------
    temp_hourly : DataArray
        Hourly disaggregated temperature named "tas" (float32, time x lat x lon).

    Raises
    ------
    ValueError
        If the input has no time steps, or tmin and tmax differ in length.
    '''
    day_times = pd.DatetimeIndex(temp_daily_max_total.time.values).normalize()
    n_days = len(day_times)
    if n_days == 0:
        raise ValueError("temp_daily_max_total contains no time steps")
    if temp_daily_min_total.time.size != n_days:
        raise ValueError(
            "temp_daily_min_total and temp_daily_max_total must cover the same days "
            f"({temp_daily_min_total.time.size} vs {n_days} time steps)")

    # Hourly axis from 00:00 of the first day to 23:00 of the last day.
    date_start = day_times[0]
    one_hour = pd.Timedelta(hours=1)
    time_hourly = pd.date_range(start=date_start, end=day_times[-1] + 23 * one_hour, freq=one_hour)

    lat = temp_daily_max_total.lat.values
    lon = temp_daily_max_total.lon.values
    hourly = np.zeros((len(time_hourly), len(lat), len(lon)), dtype=np.float32)

    sunrise, sunnoon = _prepare_sun_times(sun_times, leap_year=date_start.is_leap_year)

    tmax_all = temp_daily_max_total.values
    tmin_all = temp_daily_min_total.values
    day_of_year = day_times.dayofyear

    for month in day_times.month.unique():
        days = np.flatnonzero(day_times.month == month)
        first, last = days[0], days[-1]
        yday_slice = slice(day_of_year[first], day_of_year[last])

        hourly_month = _disaggregate_days(
            tmin_cur=tmin_all[days],
            tmax_cur=tmax_all[days],
            # Neighbouring days come from the whole year; clamp at its ends.
            tmin_next=tmin_all[np.minimum(days + 1, n_days - 1)],
            tmax_before=tmax_all[np.maximum(days - 1, 0)],
            sunrise=sunrise.sel(dayofyear=yday_slice).values,
            sunnoon=sunnoon.sel(dayofyear=yday_slice).values,
        )

        # Position of the month on the hourly axis (works for partial years too).
        h_start = (day_times[first] - date_start) // one_hour
        h_end = (day_times[last] - date_start) // one_hour + _HOURS_PER_DAY
        hourly[h_start:h_end] = hourly_month

    temp_hourly = xr.DataArray(
        data=hourly,
        dims=["time", "lat", "lon"],
        coords=dict(lon=lon, lat=lat, time=time_hourly),
        attrs=dict(description="Hourly Disaggregated Temperature", units="K"),
    )
    return temp_hourly.rename("tas")

def _bias_corrected_file(variable, model, decade):
    '''
    Return the path of the bias-corrected file holding *variable* for *decade*.

    The decades starting in 1961 and 2091 have no file of their own. They are
    read from the file labelled ten years earlier, which must therefore also
    contain them.

    Raises
    ------
    FileNotFoundError
        If no file matches the glob pattern.
    '''
    file_year = decade - 10 if decade in (1961, 2091) else decade
    pattern = bias_corrected_path + "*" + variable + "*" + model + "*" + str(file_year) + "*.nc"
    # Sort so the choice is deterministic if several files match.
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"no bias-corrected file matches {pattern!r}")
    return matches[0]


def _save_monthly_netcdfs(temp_hourly, parameter, nc_metadata, model, year):
    '''
    Write *temp_hourly* to one NetCDF file per calendar month.

    NaNs are not replaced in memory. The netCDF encoding of *parameter*
    declares _FillValue = 1e+20, which xarray uses for missing values when
    writing. Coordinates are written without a fill value.
    '''
    fill_value = 1e+20
    for month, da_month in temp_hourly.groupby("time.month"):
        ds_month = (
            da_month
            .assign_attrs({"cell_methods": "time: disaggregated"})
            .to_dataset()
            .assign_attrs(nc_metadata)
        )

        encoding = {cname: {"_FillValue": None} for cname in ds_month.coords}
        for cname in ("time", "lat", "lon"):
            encoding[cname] = {"_FillValue": None}
        encoding[parameter] = {"_FillValue": fill_value}

        out_file = (disaggregated_path + model + "_" + str(year) + "_"
                    + f"{month:02d}" + "_" + parameter + "_hourly.nc")
        ds_month.to_netcdf(out_file, unlimited_dims='time', encoding=encoding)


def generate_disaggregated_temperature_CORDEX_ncs():
    '''
    Generate hourly disaggregated CORDEX temperature files for 1951-2100.

    For each model and each year, daily tasmax/tasmin are read from the
    bias-corrected decade files and disaggregated with
    ``disaggegrate_temperature_array_like``. The result is written as one
    NetCDF file per month to ``disaggregated_path``. All opened input files
    are closed once they are no longer needed.

    Returns
    -------
    None.
    '''
    models = ["ICHEC-EC-EARTH_KNMI-RACMO22E_rcp45", "ICHEC-EC-EARTH_KNMI-RACMO22E_rcp85"]
    start_year = 1951
    end_year = 2100

    with xr.open_dataset(sun_times_file) as sun_times:
        for model in models:
            nc_metadata = {'Conventions': 'CF-1.4', 'contact': 'philipp.maier@boku.ac.at', 'experiment': model, 'realization': '1', 'driving_experiment': 'ICHEC-EC-EARTH,r12i1p1', 'driving_model_id': 'ICHEC-EC-EARTH', 'driving_model_ensemble_member': 'r12i1p1', 'model_id': 'KNMI-RACMO22E', 'rcm_version_id': 'v1','project_id': 'CORDEX', 'CORDEX_domain': 'EUR-11', 'frequency': 'hour'}

            # Loop over decades (the output files of the bias correction).
            for decade in range(start_year, end_year + 1, 10):
                max_file = _bias_corrected_file("tasmax", model, decade)
                min_file = _bias_corrected_file("tasmin", model, decade)
                with xr.open_dataset(max_file) as ds_max, xr.open_dataset(min_file) as ds_min:
                    temp_max = ds_max["tasmax"]
                    temp_min = ds_min["tasmin"]

                    for year in range(decade, decade + 10):
                        temp_max_yearly = temp_max.sel(time=str(year))
                        temp_min_yearly = temp_min.sel(time=str(year))

                        temp_hourly = disaggegrate_temperature_array_like(temp_max_yearly, temp_min_yearly, sun_times)

                        # Carry over the input metadata, then save month by month.
                        temp_hourly = temp_hourly.assign_attrs(temp_max_yearly.attrs)
                        _save_monthly_netcdfs(temp_hourly, "tas", nc_metadata, model, year)
            
if __name__ == "__main__":
    # Script entry point: disaggregate all configured models and years.
    generate_disaggregated_temperature_CORDEX_ncs()
    
