"""
Created on Tue Feb  1 12:43:06 2022

@author: mphilipp
"""

import pandas as pd
import xarray as xr
import numpy as np
import glob
import time

# Root directory of the hourly ERA5-Land input; must end with "/" because
# "Rsds/..." file names are appended to it directly.
ERA5L_path = "/path/to/ERA5-Land/data/"
# Output directory for the extrema/fraction files; must end with "/" and is
# expected to contain a "sliced/" subfolder for the per-latitude files.
rsds_fractions_path = "/path/to/radiation/fractions/"

def _open_rsds_dataset(test):
    '''
    Lazily open the hourly ERA5-Land rsds files as one multi-file Dataset.

    In test mode only the 1981 and 2020 files are opened; otherwise every
    ``Rsds/Era5l_eu_rsds_hourly_*.nc`` file below ``ERA5L_path`` is used.
    The caller is responsible for closing the returned Dataset.
    '''
    if test:
        files = [ERA5L_path + "Rsds/Era5l_eu_rsds_hourly_1981.nc",
                 ERA5L_path + "Rsds/Era5l_eu_rsds_hourly_2020.nc"]
        return xr.open_mfdataset(files)
    files = sorted(glob.glob(ERA5L_path + "Rsds/Era5l_eu_rsds_hourly_*.nc"))
    return xr.open_mfdataset(files, parallel=True, compat='override', coords='minimal')


def _daily_extrema(rsds):
    '''
    Return the day-of-year mean, maximum and minimum of ``rsds``.

    The statistics are taken over all time steps that share a day of year
    (across all years in ``rsds``). The result is a Dataset with the variables
    ``avg_rsds``, ``max_rsds`` and ``min_rsds``.
    '''
    by_day_of_year = rsds.groupby("time.dayofyear")  # build the grouping once
    return xr.Dataset({
        "avg_rsds": by_day_of_year.mean(),
        "max_rsds": by_day_of_year.max(),
        "min_rsds": by_day_of_year.min(),
    })


def _hourly_fractions(rsds, rolling_window_size=7):
    '''
    Return the smoothed day-of-year climatology of hourly rsds fractions.

    Each hourly value is divided by the mean of its calendar day. The
    fractions are then averaged per (day of year, hour). Finally a centred
    rolling mean over ``rolling_window_size`` days is applied, with the
    day-of-year axis treated as circular.

    Parameters
    ----------
    rsds : xarray.DataArray
        Hourly radiation with dimensions ``time``, ``lat`` and ``lon``
        (already loaded into memory).
    rolling_window_size : int, optional
        Width of the centred rolling window in days. The default is 7.

    Returns
    -------
    xarray.Dataset
        Dataset with the variable ``frac_rsds`` and dimensions
        ``(dayofyear, hour, lat, lon)``.
    '''
    # Key every time step by its real calendar date and hour, taken from the
    # timestamps. This avoids assuming a fixed first/last hour, and the
    # numerator uses exactly the days that resample("1D") uses below.
    time_index = rsds.indexes["time"]
    date_hour = pd.MultiIndex.from_arrays(
        (time_index.normalize(), time_index.hour.astype("int64")),
        names=('date', 'hour'),
    )
    hourly = rsds.rename("frac_rsds").to_dataset().assign(time=date_hour).unstack('time')["frac_rsds"]

    # Divide each hour by its day's mean. xarray aligns on date/lat/lon by
    # label, so the result does not depend on the input's dimension order.
    # Hours of a day whose mean is 0 become NaN (0/0).
    daily_mean = rsds.resample(time="1D").mean().rename({"time": "date"})
    fractions = (hourly / daily_mean).transpose("date", "hour", "lat", "lon")
    fractions = fractions.groupby("date.dayofyear").mean()

    # A plain centred rolling mean is NaN at both ends of the year. Recompute
    # it on a circularly shifted copy (whose ends now lie mid-array), shift it
    # back, and use those values wherever the plain result is NaN.
    shift = 2 * rolling_window_size
    smoothed = fractions.rolling(dayofyear=rolling_window_size, center=True).mean()
    wrapped = (
        fractions.roll(dayofyear=shift)
        .rolling(dayofyear=rolling_window_size, center=True)
        .mean()
        .roll(dayofyear=-shift)
    )
    smoothed = smoothed.fillna(wrapped.values)
    return smoothed.rename("frac_rsds").to_dataset()


def _merge_lat_slices(prefix, output_name):
    '''
    Combine all ``sliced/<prefix>_lat*.nc`` files into one full-domain file.

    NOTE: every matching file in ``sliced/`` is picked up, including leftovers
    from earlier runs with a different latitude count.
    '''
    slice_files = sorted(glob.glob(rsds_fractions_path + "sliced/" + prefix + "_lat*.nc"))
    with xr.open_mfdataset(slice_files) as merged:
        merged.to_netcdf(rsds_fractions_path + output_name)


def calc_rsds_fractions_lat_sliced(test=False, rolling_window_size=7):
    '''
    Build hourly rsds fraction and daily extrema climatologies from ERA5-Land.

    For every latitude row separately (so only one row is held in RAM), this
    function:

    * computes the day-of-year mean/max/min of rsds
      (``sliced/rsds_extrema_latNNNN.nc``);
    * divides every hourly value by its calendar day's mean, averages these
      fractions per (day of year, hour) and applies a circular centred rolling
      mean over ``rolling_window_size`` days
      (``sliced/rsds_fractions_latNNNN.nc``).

    Finally it concatenates the slices along ``lat`` into
    ``rsds_extrema_full_domain.nc`` and ``rsds_fractions_full_domain.nc``.

    Parameters
    ----------
    test : bool, optional
        Debug mode. Only the 1981 and 2020 input files are processed and
        nothing is written, so the final concatenation step is also skipped.
    rolling_window_size : int, optional
        Width of the rolling window in days. The default is 7.

    Returns
    -------
    None.
    '''
    start_time = time.time()
    with _open_rsds_dataset(test) as rsds_dataset:
        rsds_full = rsds_dataset["rsds"]
        print("done opening")

        for latnr in range(rsds_full.sizes["lat"]):
            print(latnr)
            # Load only this latitude row into RAM (keeps a length-1 lat dim).
            rsds = rsds_full.isel(lat=slice(latnr, latnr + 1)).load()
            lat_suffix = "_lat" + str(latnr).zfill(4) + ".nc"

            rsds_extrema = _daily_extrema(rsds)
            if not test:
                rsds_extrema.to_netcdf(rsds_fractions_path + "sliced/rsds_extrema" + lat_suffix,
                                       unlimited_dims="lat")
            del rsds_extrema
            print("done building extrema")

            frac_rsds = _hourly_fractions(rsds, rolling_window_size)
            if not test:
                frac_rsds.to_netcdf(rsds_fractions_path + "sliced/rsds_fractions" + lat_suffix,
                                    unlimited_dims="lat")
            print("done building fractions")

    # Test mode writes no slices. Merging would either fail (no files) or
    # overwrite the full-domain outputs with stale slices, so skip it.
    if not test:
        _merge_lat_slices("rsds_extrema", "rsds_extrema_full_domain.nc")
        _merge_lat_slices("rsds_fractions", "rsds_fractions_full_domain.nc")
    print("required:" + str(time.time() - start_time))

if __name__ == "__main__":
    # Script entry point: run the full (non-debug) processing.
    calc_rsds_fractions_lat_sliced(test=False)
