#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Feb  1 12:43:06 2022

@author: mphilipp
"""

import pandas as pd
import xarray as xr
import numpy as np
import glob
import time

# Input directory holding the merged ERA5 / ERA5-Land hourly wind-speed files
# ("ERA5_Land_<year>_Wspeed_hourly.nc"). File names are appended by plain string
# concatenation, so the path must end with a separator.
ERA5_ERA5L_merged_path: str = "/path/to/ERA5/and/ERA5-Land/merged/windspeed/"
# Output directory for the extrema / fraction files; the per-latitude files of the
# sliced variant are written to its "sliced/" subfolder.
wspd_fractions_path: str = "/path/to/windspeed/fractions/"

def calc_wspd_fractions_final(test = False, rolling_window_size = 7):
    '''
    Calculate the hourly wind-speed fractions for the whole domain in one go.

    For every hour the wind speed is divided by the mean wind speed of its calendar
    day. The fractions are averaged over all years per day of year and hour, and
    then smoothed with a centred rolling mean over the day of year. The rolling
    window wraps around the turn of the year.

    Loads the whole domain into RAM, so this needs a lot of memory.

    Unless ``test`` is True, two files are written to ``wspd_fractions_path``:

    - ``wspd_extrema_full_domain.nc``: variables avg_wspd, max_wspd and min_wspd
      per day of year, taken over all hours of all years.
    - ``wspd_fractions_full_domain.nc``: variable frac with dims
      (dayofyear, hour, lat, lon).

    NOTE: pandas' dayofyear maps calendar dates differently in leap and non-leap
    years (e.g. 1 March is day 60 or 61).

    Parameters
    ----------
    test : bool, optional
        Debug mode: only use 1997 and 1998 and do not write any files.
        The default is False.
    rolling_window_size : int, optional
        Width, in days of year, of the centred rolling mean applied to the
        fractions. The default is 7.

    Returns
    -------
    None.

    '''
    # In debug mode only use 1997 and 1998, otherwise every available ERA5-Land
    # year (starting from 1981).
    if test:
        files = [ERA5_ERA5L_merged_path + "ERA5_Land_1997_Wspeed_hourly.nc",
                 ERA5_ERA5L_merged_path + "ERA5_Land_1998_Wspeed_hourly.nc"]
        wspd_full = xr.open_mfdataset(files)["ws"]
    else:
        files = sorted(glob.glob(ERA5_ERA5L_merged_path + "ERA5_Land_*_Wspeed_hourly.nc"))
        wspd_full = xr.open_mfdataset(files, parallel=True, compat='override', coords='minimal')["ws"]

    # Load the whole domain into memory.
    wspd = wspd_full.load()

    # Mean, max and min wind speed per day of year (over all hours of all years).
    by_doy = wspd.groupby("time.dayofyear")
    wspd_extrema = by_doy.mean().rename("avg_wspd").to_dataset()
    wspd_extrema = wspd_extrema.assign(max_wspd=by_doy.max())
    wspd_extrema = wspd_extrema.assign(min_wspd=by_doy.min())
    if not test:
        wspd_extrema.to_netcdf(wspd_fractions_path + "wspd_extrema_full_domain.nc")
    del wspd_extrema
    print("done building extrema")

    # Reshape the hourly series into (date, hour) by labelling every time step with
    # its day and hour. This assumes gap-free hourly data covering whole days, so
    # that each consecutive block of 24 time steps belongs to one day; otherwise the
    # index length does not match and assign() raises.
    dates = pd.date_range(wspd.time[0].values, wspd.time[-1].values, freq="D")
    ind = pd.MultiIndex.from_product((dates, range(24)), names=("date", "hour"))
    frac_4D = wspd.rename("frac").to_dataset().assign(time=ind).unstack("time")["frac"]

    # Divide every hourly value by the mean wind speed of its calendar day. The
    # daily means are matched by dimension name (their own time labels are
    # dropped), so the result does not depend on the dimension order of "ws".
    daily_mean = wspd.resample(time="1D").mean().rename(time="date").drop_vars("date")
    frac_4D = (frac_4D / daily_mean).transpose("date", "hour", "lat", "lon").rename("frac")

    # Average the fractions over all years for every day of year and hour.
    frac_4D = frac_4D.groupby("date.dayofyear").mean()

    # Centred rolling mean over the day of year. Wherever the plain rolling mean is
    # NaN (in particular the first and last days, where the window is incomplete),
    # take the rolling mean of a copy rolled by 2*rolling_window_size days and
    # rolled back. Its windows wrap around the turn of the year.
    shift = 2 * rolling_window_size
    wrapped = (frac_4D.roll(dayofyear=shift)
               .rolling(dayofyear=rolling_window_size, center=True).mean()
               .roll(dayofyear=-shift)
               .values)
    frac_4D = frac_4D.rolling(dayofyear=rolling_window_size, center=True).mean()
    frac_4D.values = np.where(np.isnan(frac_4D.values), wrapped, frac_4D.values)

    # Rename and save
    frac_4D = frac_4D.rename("frac").to_dataset()
    if not test:
        frac_4D.to_netcdf(wspd_fractions_path + "wspd_fractions_full_domain.nc")
    print("done building fractions")

def calc_wspd_fractions_lat_sliced(test = False, rolling_window_size = 7):
    '''
    Calculate the hourly wind-speed fractions one latitude row at a time.

    For every hour the wind speed is divided by the mean wind speed of its calendar
    day. The fractions are averaged over all years per day of year and hour, and
    then smoothed with a centred rolling mean over the day of year. The rolling
    window wraps around the turn of the year.

    To save RAM, only one latitude row is loaded at a time. For each row, two files
    are written to ``wspd_fractions_path + "sliced/"``:

    - ``wspd_extrema_lat<NNNN>.nc``: avg/max/min wind speed per day of year.
    - ``wspd_fractions_lat<NNNN>.nc``: frac with dims (dayofyear, hour, lat, lon).

    Afterwards the rows are merged into ``sliced/wspd_extrema_full_domain.nc`` and
    ``wspd_fractions_full_domain.nc``. Nothing here creates the ``sliced/``
    directory; it has to exist beforehand.

    NOTE: pandas' dayofyear maps calendar dates differently in leap and non-leap
    years (e.g. 1 March is day 60 or 61).

    Parameters
    ----------
    test : bool, optional
        Debug mode: only use 1997 and 1998 and write no files at all, neither the
        per-latitude files nor the merged ones. The default is False.
    rolling_window_size : int, optional
        Width, in days of year, of the centred rolling mean applied to the
        fractions. The default is 7.

    Returns
    -------
    None.

    '''
    start_time = time.time()
    # In debug mode only use 1997 and 1998, otherwise every available ERA5-Land
    # year (starting from 1981).
    if test:
        files = [ERA5_ERA5L_merged_path + "ERA5_Land_1997_Wspeed_hourly.nc",
                 ERA5_ERA5L_merged_path + "ERA5_Land_1998_Wspeed_hourly.nc"]
        wspd_full = xr.open_mfdataset(files)["ws"]
    else:
        files = sorted(glob.glob(ERA5_ERA5L_merged_path + "ERA5_Land_*_Wspeed_hourly.nc"))
        wspd_full = xr.open_mfdataset(files, parallel=True, compat='override', coords='minimal')["ws"]
    print("done opening")

    sliced_path = wspd_fractions_path + "sliced/"
    # Loop over latitude rows
    for latnr in range(wspd_full.sizes["lat"]):
        print(latnr)
        lat_tag = str(latnr).zfill(4)
        # Load only a single latitude row into RAM (lat is kept as a length-1 dim).
        wspd = wspd_full.isel(lat=slice(latnr, latnr + 1)).load()

        # Mean, max and min wind speed per day of year (over all hours of all years).
        by_doy = wspd.groupby("time.dayofyear")
        wspd_extrema = by_doy.mean().rename("avg_wspd").to_dataset()
        wspd_extrema = wspd_extrema.assign(max_wspd=by_doy.max())
        wspd_extrema = wspd_extrema.assign(min_wspd=by_doy.min())
        if not test:
            wspd_extrema.to_netcdf(sliced_path + "wspd_extrema_lat" + lat_tag + ".nc", unlimited_dims="lat")
        del wspd_extrema
        print("done building extrema")

        # Reshape the hourly series into (date, hour). This assumes gap-free hourly
        # data covering whole days; otherwise the index length does not match and
        # assign() raises.
        dates = pd.date_range(wspd.time[0].values, wspd.time[-1].values, freq="D")
        ind = pd.MultiIndex.from_product((dates, range(24)), names=("date", "hour"))
        frac_4D = wspd.rename("frac").to_dataset().assign(time=ind).unstack("time")["frac"]

        # Divide every hourly value by the mean wind speed of its calendar day,
        # matching the daily means by dimension name (their time labels are dropped).
        daily_mean = wspd.resample(time="1D").mean().rename(time="date").drop_vars("date")
        frac_4D = (frac_4D / daily_mean).transpose("date", "hour", "lat", "lon").rename("frac")

        # Average the fractions over all years for every day of year and hour.
        frac_4D = frac_4D.groupby("date.dayofyear").mean()

        # Centred rolling mean over the day of year. NaNs of the plain rolling mean
        # (the incomplete windows at the start and end of the year) are filled from
        # the rolling mean of a copy rolled by 2*rolling_window_size days and rolled
        # back. Its windows wrap around the turn of the year.
        shift = 2 * rolling_window_size
        wrapped = (frac_4D.roll(dayofyear=shift)
                   .rolling(dayofyear=rolling_window_size, center=True).mean()
                   .roll(dayofyear=-shift)
                   .values)
        frac_4D = frac_4D.rolling(dayofyear=rolling_window_size, center=True).mean()
        frac_4D.values = np.where(np.isnan(frac_4D.values), wrapped, frac_4D.values)

        # Rename and save
        frac_4D = frac_4D.rename("frac").to_dataset()
        if not test:
            frac_4D.to_netcdf(sliced_path + "wspd_fractions_lat" + lat_tag + ".nc", unlimited_dims="lat")
        print("done building fractions")

    # Merge the per-latitude files along "lat". This is skipped in test mode: no
    # per-latitude files were written in this run, so open_mfdataset would either
    # fail on an empty file list or merge stale files from an earlier run and
    # overwrite the full-domain outputs.
    # NOTE(review): the merged extrema go into sliced/, while the merged fractions go
    # to wspd_fractions_path itself. Confirm this difference is intended.
    if not test:
        with xr.open_mfdataset(sorted(glob.glob(sliced_path + "wspd_extrema_lat*.nc"))) as wspd_extrema_full:
            wspd_extrema_full.to_netcdf(sliced_path + "wspd_extrema_full_domain.nc")
        with xr.open_mfdataset(sorted(glob.glob(sliced_path + "wspd_fractions_lat*.nc"))) as wspd_fractions_full:
            wspd_fractions_full.to_netcdf(wspd_fractions_path + "wspd_fractions_full_domain.nc")
    print("required:" + str(time.time()-start_time))

def calc_wspd_fractions_dask(test = False, rolling_window_size = 7):
    '''
    Calculate the hourly wind-speed fractions, keeping the input as lazy dask arrays.

    For every hour the wind speed is divided by the mean wind speed of its calendar
    day. The fractions are averaged over all years per day of year and hour, and
    then smoothed with a centred rolling mean over the day of year. The rolling
    window wraps around the turn of the year.

    Tries to use dask efficiently, does not work that well. Only reduced
    intermediate results are computed into memory:

    - the extrema, in one pass;
    - the daily means;
    - the day-of-year mean of the fractions (dayofyear x hour x lat x lon), once.

    Unless ``test`` is True, ``wspd_extrema_full_domain_dask.nc`` (avg/max/min
    wind speed per day of year) and ``wspd_fractions_full_domain_dask.nc``
    (variable frac) are written to ``wspd_fractions_path``.

    NOTE: pandas' dayofyear maps calendar dates differently in leap and non-leap
    years (e.g. 1 March is day 60 or 61).

    Parameters
    ----------
    test : bool, optional
        Debug mode: only use 1997 and 1998 and do not write any files.
        The default is False.
    rolling_window_size : int, optional
        Width, in days of year, of the centred rolling mean applied to the
        fractions. The default is 7.

    Returns
    -------
    None.

    '''
    start_time = time.time()
    # In debug mode only use 1997 and 1998, otherwise every available ERA5-Land
    # year (starting from 1981).
    if test:
        files = [ERA5_ERA5L_merged_path + "ERA5_Land_1997_Wspeed_hourly.nc",
                 ERA5_ERA5L_merged_path + "ERA5_Land_1998_Wspeed_hourly.nc"]
        wspd = xr.open_mfdataset(files)["ws"]
    else:
        files = sorted(glob.glob(ERA5_ERA5L_merged_path + "ERA5_Land_*_Wspeed_hourly.nc"))
        wspd = xr.open_mfdataset(files, parallel=True, compat='override', coords='minimal')["ws"]

    # Mean, max and min wind speed per day of year. All three reductions are put
    # into one Dataset and computed together, so dask reads the input once instead
    # of three times.
    by_doy = wspd.groupby("time.dayofyear")
    wspd_extrema = xr.Dataset({"avg_wspd": by_doy.mean(),
                               "max_wspd": by_doy.max(),
                               "min_wspd": by_doy.min()}).compute()
    if not test:
        wspd_extrema.to_netcdf(wspd_fractions_path + "wspd_extrema_full_domain_dask.nc")
    del wspd_extrema
    print("done building extrema")

    # Reshape the hourly series into (date, hour), still lazily. This assumes
    # gap-free hourly data covering whole days; otherwise the index length does not
    # match and assign() raises.
    dates = pd.date_range(wspd.time[0].values, wspd.time[-1].values, freq="D")
    ind = pd.MultiIndex.from_product((dates, range(24)), names=("date", "hour"))
    frac_4D = wspd.rename("frac").to_dataset().assign(time=ind).unstack("time")["frac"]

    # Divide every hourly value by the (eagerly computed) mean wind speed of its
    # calendar day. The daily means are matched by dimension name; their time
    # labels are dropped.
    daily_mean = wspd.resample(time="1D").mean().compute()  # Compute for Dask
    daily_mean = daily_mean.rename(time="date").drop_vars("date")
    frac_4D = (frac_4D / daily_mean).transpose("date", "hour", "lat", "lon").rename("frac")

    # Average over all years per day of year and hour. Compute this reduced array
    # once; the rolling step below would otherwise re-run the whole dask graph for
    # every .values access.
    frac_4D = frac_4D.groupby("date.dayofyear").mean().compute()

    # Centred rolling mean over the day of year. NaNs of the plain rolling mean
    # (the incomplete windows at the start and end of the year) are filled from the
    # rolling mean of a copy rolled by 2*rolling_window_size days and rolled back.
    # Its windows wrap around the turn of the year.
    shift = 2 * rolling_window_size
    wrapped = (frac_4D.roll(dayofyear=shift)
               .rolling(dayofyear=rolling_window_size, center=True).mean()
               .roll(dayofyear=-shift)
               .values)
    frac_4D = frac_4D.rolling(dayofyear=rolling_window_size, center=True).mean()
    frac_4D.values = np.where(np.isnan(frac_4D.values), wrapped, frac_4D.values)

    # Rename and save
    frac_4D = frac_4D.rename("frac").to_dataset()
    if not test:
        frac_4D.to_netcdf(wspd_fractions_path + "wspd_fractions_full_domain_dask.nc")
    print("done building fractions")
    print("required:" + str(time.time()-start_time))
        
if __name__ == "__main__":
    # Choose the implementation on the command line instead of (un)commenting calls.
    # Without arguments this runs calc_wspd_fractions_lat_sliced(test=False), as before.
    import argparse

    variants = {
        "final": calc_wspd_fractions_final,
        "sliced": calc_wspd_fractions_lat_sliced,
        "dask": calc_wspd_fractions_dask,
    }
    parser = argparse.ArgumentParser(
        description="Compute hourly wind-speed fractions from merged ERA5/ERA5-Land files.")
    parser.add_argument(
        "--variant", choices=sorted(variants), default="sliced",
        help="final: whole domain in RAM; sliced: one latitude row at a time "
             "(default); dask: lazy dask computation")
    parser.add_argument(
        "--test", action="store_true",
        help="debug mode: only use the years 1997 and 1998")
    args = parser.parse_args()
    variants[args.variant](test=args.test)
