import datetime as dt
import os
import sys
from optparse import OptionParser

import numpy as np
import pandas as pd
import xarray as xa

sys.path.append("/gpfs/data/fs72140/fiona_fix/ADs_longterm")
import functions as func

# get input from arguments
# All options are read as strings and converted explicitly below: optparse's own
# "int" type treats a leading "0" as octal and would reject zero-padded values
# such as "08", whereas int("08") == 8.
parser = OptionParser()
parser.add_option("-Y", dest = "YEAR", type = "str")
parser.add_option("-M", dest = "MONTH", type = "str")
parser.add_option("-D", dest = "DAY", type = "str")
parser.add_option("-T", dest = "TIME", type = "str")
parser.add_option("-L", dest = "LENGTH", type = "str")
parser.add_option("-B", dest = "BLH_THRESHOLD", type = "str")
parser.add_option("-O", dest = "OUTDIR", type = "str")
(options,args) = parser.parse_args()

# Every option is mandatory. Report missing ones with the usage message instead of
# failing later with an opaque TypeError from int(None) / float(None).
required_options = {"-Y": "YEAR", "-M": "MONTH", "-D": "DAY", "-T": "TIME",
                    "-L": "LENGTH", "-B": "BLH_THRESHOLD", "-O": "OUTDIR"}
missing_options = [flag for flag, dest in required_options.items()
                   if getattr(options, dest) is None]
if missing_options:
    parser.error("missing required option(s): " + ", ".join(missing_options))

YEAR       = int(options.YEAR)
MONTH      = int(options.MONTH)
DAY        = int(options.DAY)
TIME       = int(options.TIME)      # hour of the day
LENGTH     = int(options.LENGTH)
BLH_threshold = float(options.BLH_THRESHOLD)
outdir     = options.OUTDIR

# Requested analysis time (UTC) assembled from the command-line arguments.
# NOTE(review): `date` is not referenced again in this file.
date      = dt.datetime(YEAR, MONTH, DAY, TIME, 0, tzinfo = dt.timezone.utc)
# Path of the AD selection file to read. os.path.join accepts OUTDIR with or without
# a trailing separator; plain string concatenation built a wrong path without one.
file_prop = os.path.join(outdir, f"AD_select_{YEAR:04}{MONTH:02}{DAY:02}T{TIME:02}00_length{LENGTH}_tmp.nc")

# SPECIFY according to domain and needs
# Target grid the AD cells are aggregated onto:
#   dx -> horizontal spacing of the lon/lat axes
#   dz -> vertical spacing of the height-above-mean-sea-level (hamsl) axis
dx, dz = 0.25, 500
# np.arange excludes its stop value, so one extra step is added to keep the upper bound.
lon_grid, lat_grid, hamsl_grid = (
    np.arange(-30, 60 + dx, dx),
    np.arange(15, 73 + dx, dx),
    np.arange(0, 16000 + dz, dz),
)


# SPECIFY location of BLH and geopotential ERA5 data
# The BLH data come in one file per year; the surface geopotential is a single file.
era5_dir         = "/gpfs/data/fs72140/c4031046_vsc/winter_lightning/ERA5-nAf"
infile_BLH       = f"{era5_dir}/single-level/boundary_layer_height/ERA5_sfc_{YEAR}_boundary_layer_height.nc"
infile_sfcgeopot = f"{era5_dir}/geopotential-surface/ERA5_sfc_geopotential.nc"

#----------------------
########################################################################################################################
#---------------
# open dataset
# 9.96920997e+36 is the netCDF default fill value for floats. It is masked here in
# case it was not already decoded to NaN when the file was opened.
fill_value = 9.96920997e+36
ds_AD = xa.open_dataset(file_prop, decode_times=False)
# Match the sentinel with a relative tolerance instead of exact equality. Its float32
# form widened to float64 (9.969209968386869e+36) differs from the literal above, so an
# exact comparison misses it in float64 variables holding the float32 fill value.
# NaN entries stay NaN (the comparison is False for them).
ds_AD = ds_AD.where(abs(ds_AD - fill_value) > 1e-6 * fill_value, other=np.nan)
# -999 is treated as a second missing-value marker
ds_AD = ds_AD.where(ds_AD != -999, other=np.nan)

# Grid the AD points onto the dx/dz grid, keeping the variables named below.
# NOTE(review): presumably these are the fields carried into the gridded output by
# func.group_and_count_AD; confirm in functions.py.
variables = [
    'lon_r', 'lat_r', 'hamsl_r', 'hamsl',
    'TH_diff', 'TH_E_diff', 'Q_diff', 'time_diff', 'hamsl_diff',
    'TH_diffN37', 'TH_E_diffN37', 'Q_diffN37', 'time_diffN37', 'hamsl_diffN37',
]
AD_cells = func.group_and_count_AD(ds_AD, variables, dz=dz, dx=dx)

# Template dataset spanning the whole target grid, with every cell set to NaN.
# Merging the grouped data into it guarantees every grid cell exists in the result,
# because cells without any AD point may be absent after the groupby.
grid_dims = ['lon', 'lat', 'hamsl']
grid_shape = (len(lon_grid), len(lat_grid), len(hamsl_grid))
empty_gridded_ds = xa.Dataset(
    data_vars={
        'counts': (grid_dims, np.full(grid_shape, np.nan)),
        'hamsl_r2': (grid_dims, np.full(grid_shape, np.nan)),
    },
    coords={'lon': lon_grid, 'lat': lat_grid, 'hamsl': hamsl_grid},
)
AD_cells = empty_gridded_ds.merge(AD_cells)
# Prepend a length-1 (naive) time dimension for later merging; the final dim
# order is important for cdo.
time = dt.datetime(YEAR, MONTH, DAY, TIME, 0)
AD_cells = AD_cells.expand_dims(time=[time]).transpose("time", "hamsl", "lat", "lon")


#-----------------------------
# calculate some more features
print("calculating columns")
# Total AD count per column. xarray's sum skips NaN cells, so an empty column sums to 0.
# The result already carries the time/lat/lon coordinates of `counts`.
AD_cells['AD_colcounts'] = AD_cells.counts.sum(dim='hamsl')
# number of occupied (non-NaN) AD cells in each column
AD_cells['Num_ADcellspercol'] = AD_cells.counts.count(dim='hamsl')
# boolean, columns with at least one AD cell
AD_cells['AD_colYes'] = AD_cells['AD_colcounts'] != 0

# lowest, highest AD cell and number of continuous AD layers in each column
AD_cells['lowest_AD_cell']  = func.get_lowest_AD_cell(AD_cells.hamsl_r2)
AD_cells['highest_AD_cell'] = func.get_highest_AD_cell(AD_cells.hamsl_r2)
# calculate_num_streaks receives the array plus the index of the 'hamsl' axis.
# Label its result with the remaining dims in their actual order rather than a
# hard-coded list.
# NOTE(review): assumes calculate_num_streaks removes the 'hamsl' axis and keeps the
# other axes in order; confirm in functions.py.
col_dims = [d for d in AD_cells.counts.dims if d != 'hamsl']
AD_cells['Num_layers'] = xa.DataArray(
    func.calculate_num_streaks(AD_cells.counts,
                               axis=AD_cells.counts.get_axis_num('hamsl')),
    dims=col_dims,
    coords={d: AD_cells[d] for d in col_dims},
)

#%%======================================================================
# identify earliest and latest date covered by the gridded AD data
early_date, late_date = AD_cells.time.min(), AD_cells.time.max()
# year of the earliest time step, re-derived from the data's time axis
YEAR = int(AD_cells.time.dt.year.min())

print("getting BLH data")
data_BLH = xa.open_dataset(infile_BLH)
data_BLH = data_BLH.rename({"longitude": "lon", "latitude": "lat"})
# Drop the ERA5 "expver" coordinate whenever it is a plain (non-index) coordinate,
# instead of keying on YEAR == 2024. Its presence presumably depends on how the file
# was downloaded rather than on the year. An "expver" dimension is left untouched,
# because reset_coords cannot remove index coordinates.
if "expver" in data_BLH.coords and "expver" not in data_BLH.sizes:
    data_BLH = data_BLH.reset_coords("expver", drop=True)
data_BLH = data_BLH.sortby('lat')
# slice data to same size as AD data
data_BLH = data_BLH.sel(time=slice(early_date, late_date),
                        lon=AD_cells.lon, lat=AD_cells.lat)

data_sfc_geopot = xa.open_dataset(infile_sfcgeopot).rename({"longitude": "lon", "latitude": "lat"})
# Subset the surface geopotential to exactly the BLH grid. Its values are attached to
# data_BLH positionally (via .values) below, so extent and ordering must match.
data_sfc_geopot = data_sfc_geopot.sortby('lat').sel(lon=data_BLH.lon, lat=data_BLH.lat)
surf_height = func.geopot_to_hamsl(data_sfc_geopot.z)
# repeat the surface height for every BLH time step
surf_height = surf_height.expand_dims(dim={'time': data_BLH.sizes['time']}, axis=0)
data_BLH['topo'] = (['time', 'lat', 'lon'], surf_height.values)
# boundary layer height referenced to mean sea level (see long_name below)
BLH_amsl = func.hag_to_hamsl(data_BLH['blh'], data_BLH['topo'])
data_BLH['BLH_a'] = (['time', 'lat', 'lon'], BLH_amsl.values)
data_BLH.BLH_a.attrs = dict(long_name="boundary_layer_height_above_mean_sealevel", units="m")
#%%======================================================================
print("calculating diff lowest AD and BLH, flag lid")
# signed distance of the lowest AD cell from the BLH (positive: AD cell above the BLH)
AD_cells['diff_lowestAD_BLH'] = AD_cells.lowest_AD_cell - data_BLH.BLH_a
# Flag columns whose lowest AD cell lies strictly within +/- BLH_threshold of the BLH.
# Columns with a NaN difference are flagged False.
AD_cells['lid_presentYes'] = abs(AD_cells.diff_lowestAD_BLH) < BLH_threshold

#%%======================================================================
print("saving file")
# Store integer-valued fields as int32 to save space. NaN is written as the -999
# _FillValue in every encoded variable.
encoding = {'hamsl_r2':{'dtype':'int32','_FillValue':-999},
            'counts':{'dtype':'int32','_FillValue':-999},
            'AD_colcounts':{'dtype':'int32','_FillValue':-999},
            'Num_ADcellspercol':{'dtype':'int32', '_FillValue':-999},
            'lowest_AD_cell':{'dtype':'int32','_FillValue':-999},
            'highest_AD_cell':{'dtype':'int32','_FillValue':-999},
            'Num_layers':{'dtype':'int32', '_FillValue':-999},
            'diff_lowestAD_BLH':{'dtype':'float32','_FillValue':-999}
           }

# os.path.join accepts OUTDIR with or without a trailing separator.
# NOTE(review): unlike the input file name ("..._length{LENGTH}_tmp.nc"), there is no
# underscore before "tmp" here. It is kept unchanged in case downstream scripts rely on it.
outfile = os.path.join(outdir, f"AD_gridded_{YEAR:04}{MONTH:02}{DAY:02}T{TIME:02}00_length{LENGTH}tmp.nc")
AD_cells.to_netcdf(outfile, encoding=encoding)

