#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jul 31 2024

@author: Fiona Fix

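This script processes the ERA boundary-layer height (BLH) to create a smoothed version of it: at each grid point the daily maxima are kept at the time step where they occur, and the values in between are filled by linear interpolation in time.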

"""
import numpy as np
import pandas as pd
import xarray as xa
from optparse import OptionParser

# get input and output files from arguments
parser = OptionParser()
parser.add_option("--INFILE", dest = "INFILE", type = "str",
                  help = "input file (readable by xarray.open_dataset) containing 'blh'")
parser.add_option("--OUTFILE", dest = "OUTFILE", type = "str",
                  help = "output netCDF file for the interpolated 'blh_interp'")
(options,args) = parser.parse_args()
# Both paths are required. Check them up front so a missing option fails
# immediately with a usage message, instead of an obscure open_dataset(None)
# error or a failure only at the very end, after all the processing is done.
if options.INFILE is None or options.OUTFILE is None:
    parser.error("both --INFILE and --OUTFILE must be given")
filename = options.INFILE

# open infile
ds = xa.open_dataset(filename)

# Daily maximum of BLH: resample to calendar days and reduce over time only,
# so every grid point keeps its own daily maximum
daily_max = ds.blh.resample(time='1D').max(dim='time')

# NaN-filled DataArray on the same grid and time axis as ds.blh; it is later
# filled (by find_daily_max_times) with the daily maxima at their time steps
blh_max = xa.full_like(ds.blh, np.nan)

# Function to find the exact times of daily maxima
def find_daily_max_times(ds, daily_max, blh_max):
    """
    Place each day's maximum 'blh' value at the time step where it occurs.

    For every day listed in `daily_max`, the 'blh' data of that calendar day is
    taken from `ds`. At each grid point the time of the maximum (``idxmax``) and
    the maximum value itself are found, and the value is written into `blh_max`
    at that time step. All other entries of `blh_max` are left untouched.

    Args:
        ds (xarray.Dataset):
            The dataset containing time-series data, including the 'blh' variable.
        daily_max (xarray.DataArray):
            Daily-resampled 'blh' (e.g. ``ds.blh.resample(time='1D').max()``).
            Only its 'time' coordinate is used here, to enumerate the days.
        blh_max (xarray.DataArray):
            Array on the same time axis and grid as ``ds['blh']`` (typically
            NaN-filled). It is modified in place.

    Returns:
        xarray.DataArray:
            The same `blh_max` object, now holding each day's maximum 'blh'
            value at the time step where that maximum occurred.
    """
    blh = ds['blh']                                           # only 'blh' is needed; select it once
    for day in daily_max['time'].values:                      # loop through all days
        # partial-string date selection picks every time step of that calendar day
        day_blh = blh.sel(time=str(pd.Timestamp(day).date()))
        # skip days without any valid (non-NaN) value, including days with no time steps
        if not day_blh.notnull().any():
            continue
        # NOTE(review): a grid point that is all-NaN on an otherwise valid day gets
        # NaT from idxmax, which the .loc assignment below presumably cannot look up.
        max_time  = day_blh.idxmax(dim='time')                # time of the max, per grid point
        max_value = day_blh.max(dim='time')                   # the max value, per grid point
        blh_max.loc[{'time': max_time}] = max_value           # write each max at its own time step
    return blh_max


blh_max = find_daily_max_times(ds, daily_max, blh_max)

# Linearly interpolate the NaN gaps between the daily maxima along time.
# interpolate_na does not extrapolate, so time steps before the first and after
# the last daily maximum remain NaN.
blh_s = blh_max.interpolate_na(dim='time', method='linear')

# Build the output dataset, reusing coordinate info from the original blh data.
# The dimension names/order come from blh_s itself rather than a hard-coded
# (time, latitude, longitude), so the labels always match the data layout.
ds_new = xa.Dataset(
    data_vars = dict(blh_interp=(blh_s.dims, blh_s.data)),
    coords    = dict(time=ds.time, longitude=ds.longitude, latitude=ds.latitude))

#%%
# Save to netcdf
ofilename = options.OUTFILE
ds_new.to_netcdf(ofilename, format='NETCDF4')

