import xarray as xa
import numpy as np
import datetime as dt

import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import seaborn as sns
from matplotlib.colors import BoundaryNorm, LogNorm, Normalize, ListedColormap

import os
import sys
import functions as func
import glob
import dask
import dask.array as da


# Keep printed xarray objects compact: do not expand the attribute listing in reprs.
xa.set_options(display_expand_attrs=False)

#%%--------------------------------
# SPECIFY
LENGTH           = 24        # length of trajectories for file name
landonly         = False     # if land mask should be applied
region           = 'N37'     # for which region analysis should be done
daytime_selection= False     # if only daytime hours should be used
early_date       = dt.datetime(2022,5,1,0)   # start of the analysis period
late_date        = dt.datetime(2024,4,30,23) # end of the analysis period
season           = 'allyear' # season for which analysis should be done

morning       = 10 # >= this hour for daytime selection
afternoon     = 15 # <= this hour for daytime selection


indir            = "/data_dir"
# The '??????' wildcard presumably matches a YYYYMM stamp per input file — verify.
infiles          = f"{indir}/AD_gridded_all_??????_length{LENGTH}.nc"

indir_landmask = "/indir_land-sea_mask/"

outdir  = indir
###################################################################
# Fail early on an inverted period. Otherwise year_range would be empty and no
# land-sea mask files would be selected further down.
if late_date < early_date:
    raise ValueError(
        f"late_date ({late_date}) must not be earlier than early_date ({early_date})"
    )

# Calendar years touched by the analysis period (inclusive). They are used to
# pick the yearly land-sea mask files.
lowest_year  = early_date.year
highest_year = late_date.year
year_range   = range(lowest_year, highest_year + 1)
#----------------------------------------------------------------------
#----------------------------------------------------------------------

# Horizontal grid spacing of the gridded data (presumably degrees lon/lat).
# It could also be derived from the lon and lat coordinates.
dx = 0.25

#----------------------------------------------------------------------
# Figure settings
proj = ccrs.Mercator()   # map projection used for plotting
dpi = 300                # resolution required for WCD
fig_width = 6            # figure width in inches (min 8 cm = 3.2 in; DIN A4: < 8 in)

# Seaborn plotting context and font scaling applied globally for all figures.
context, scale_fonts = 'paper', 0.7
sns.set_context(context, font_scale=scale_fonts)
#---------------------------------------------------------------------------
# Box to select data from: [south, north, west, east] as lat/lon bounds.
# A None entry is passed through unchanged to func.get_AD_data. It presumably
# leaves that side unbounded (e.g. 'N37' = everything north of 37N) — verify there.
REGION_BOXES = {
    'GB':  [49, 59, -10, 2],
    'SC':  [57, 65, 5, 20],
    'RU':  [48, 58, 34, 46],
    'MT':  [36, 44, 10, 25],
    'IB':  [36, 44, -10, 3],
    'CE':  [45, 55, 4, 16],
    'N37': [37, None, None, None],
}

try:
    box = REGION_BOXES[region]
except KeyError:
    # sys.exit() with a string writes it to stderr and exits with status 1.
    # A bare sys.exit() would exit with status 0 and make this misconfiguration
    # look like a successful run.
    sys.exit(f"region '{region}' not implemented. Interrupting script.")

south, north, west, east = box




# Collect one ERA5 land-sea mask file per calendar year of the analysis period.
patterns = [f"{indir_landmask}ERA5_sfc_{y}_land-sea_mask.nc" for y in year_range]
infiles_landmask = []
for pattern in patterns:
    print(f"adding land mask files: {pattern}")
    matches = glob.glob(pattern)
    if not matches:
        # Report a missing year instead of skipping it silently; otherwise
        # get_AD_data would receive an incomplete mask list with no hint why.
        print(f"WARNING: no land mask file found for: {pattern}", file=sys.stderr)
    # glob returns matches in arbitrary order; sort for reproducible input order.
    infiles_landmask.extend(sorted(matches))

#%%------------------------------------------------------

#%%-----------------

# Load the gridded trajectory data. Selection by box, period, season and land
# mask is delegated to func.get_AD_data (behaviour defined there — not visible here).
data = func.get_AD_data(
    infiles=infiles,
    dropvars=[],
    south=south,
    north=north,
    west=west,
    east=east,
    early_date=early_date,
    late_date=late_date,
    daytime_selection=daytime_selection,
    season=season,
    landonly=landonly,
    infiles_landmask=infiles_landmask,
)

# Subset to time steps whose hour lies in [morning, afternoon] (both inclusive).
daytime_hours = range(morning, afternoon + 1)
is_daytime = data["time"].dt.hour.isin(daytime_hours)
data_daytime = data.sel(time=is_daytime)

#%%-------------------

# Time-mean of the selected change variables at a few altitude levels.
diff_vars = ['time_diff', 'TH_diff', 'Q_diff', 'TH_E_diff']
hamsl_levels = [500, 1500, 3000, 4500, 9000]  # presumably metres above mean sea level — verify
dat_reduced = (
    data[diff_vars]
    .sel(hamsl=hamsl_levels)
    .mean(dim='time')
)
# NOTE(review): the output name does not encode region/season/LENGTH, so runs
# with different settings overwrite each other's file.
dat_reduced.to_netcdf(f"{outdir}/AD_AverageChangesAltitude.nc")
