import xarray as xa
import numpy as np
import datetime as dt

import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import seaborn as sns
from matplotlib.colors import BoundaryNorm, LogNorm, Normalize, ListedColormap

import os
import sys
import functions as func
import glob
import dask
import dask.array as da


# Keep xarray's printed representations compact: the attrs section of
# Dataset/DataArray reprs is shown collapsed instead of expanded.
xa.set_options(display_expand_attrs=False)

#################
# INPUT  SPECIFY
################
# LENGTH is embedded in the input file-name pattern below and is also handed
# to func.set_plotting_details later in the script.
LENGTH = 120
indir = "/data_dir/"
# Glob pattern for the gridded AD input files (six wildcard characters
# followed by the LENGTH tag).
infiles = f"{indir}/AD_gridded_all_??????_length{LENGTH}.nc"  # change name if needed
outpath = "outpath"

# Selection switches forwarded to func.get_AD_data / func.set_plotting_details.
landonly = False
region = "N37"
daytime_selection = False
early_date = dt.datetime(year=2022, month=5, day=1, hour=0)
late_date = dt.datetime(year=2024, month=4, day=30, hour=23)
date = dt.datetime(year=2023, month=9, day=4, hour=12)
season = "allyear"


######################################################
morning = 10  # >= this time for daytime selection
afternoon = 15  # <= this time for daytime selection

# First and last calendar year touched by the analysis window.
lowest_year = early_date.year
highest_year = late_date.year
#----------------------------------------------------------------------

dx = 0.25  # horizontal grid spacing; could also be derived from the lon and lat data
dz = 500


#----------------------------------------------------------------------
# Map projection used for plotting.
proj = ccrs.Mercator()
# Output resolution; has to be 300 for WCD.
dpi = 300
# Figure width in inch (min 8cm=3.2in); DIN A4: <8in.
fig_width = 6
# Seaborn plotting context and the font scale applied on top of it.
context = 'paper'
scale_fonts = 0.7
sns.set_context(context, font_scale=scale_fonts)
#---------------------------------------------------------------------------
# box to select data from
# Bounding box per supported region code, ordered [south, north, west, east]
# (the order in which the box is unpacked below). A None entry is forwarded
# unchanged to func.get_AD_data -- presumably meaning "no limit on that side";
# confirm against that helper.
_REGION_BOXES = {
    'GB': [49, 59, -10, 2],
    'SC': [57, 65, 5, 20],
    'RU': [48, 58, 34, 46],
    'MT': [36, 44, 10, 25],
    'IB': [36, 44, -10, 3],
    'CE': [45, 55, 4, 16],
    'N37': [37, None, None, None],
}

try:
    box = _REGION_BOXES[region]
except KeyError:
    # Passing the message to sys.exit writes it to stderr and terminates with
    # exit status 1, so callers (shell scripts, job schedulers) can detect
    # the failure; a bare sys.exit() would report success (status 0).
    sys.exit("region not implemented. Interrupting script.")

south, north, west, east = box


# Calendar years covered by the analysis window, both end years included.
year_range = range(lowest_year, highest_year + 1)


def _expand_patterns(patterns, label):
    """Expand glob patterns into one flat list of matching file paths.

    Patterns are processed in the given order and each one is announced on
    stdout. ``glob.glob`` returns matches in arbitrary order, so the matches
    of each pattern are sorted to make the resulting list reproducible.

    Parameters
    ----------
    patterns : iterable of str
        Glob patterns to expand.
    label : str
        Short description inserted into the progress message.

    Returns
    -------
    list of str
        All matching paths; empty if nothing matched.
    """
    matches = []
    for pattern in patterns:
        print(f"adding {label} files: {pattern}")
        matches.extend(sorted(glob.glob(pattern)))
    return matches


#SPECIFY
# One land-sea-mask pattern per year.
patterns = [f"/gpfs/data/fs72140/c4031046_vsc/winter_lightning/ERA5-nAf/single-level/land-sea_mask/ERA5_sfc_{y}_land-sea_mask.nc" for y in year_range]
infiles_landmask = _expand_patterns(patterns, "land mask")
# One geopotential pattern per year; the "??" wildcard matches the two
# characters after the year in the file name.
patterns = [f"/gpfs/data/fs72140/c4031046_vsc/winter_lightning/ERA5-nAf/pressure-level/geopotential/ERA5_pl_{y}-??_geopotential.nc" for y in year_range]
infiles_geopot = _expand_patterns(patterns, "geopotential")
    
    
    
################################
# get AD data 
################################

# Load the AD dataset restricted to the configured box, time window and
# selection switches (all filtering is delegated to func.get_AD_data).
data = func.get_AD_data(
    infiles=infiles,
    dropvars=[],
    south=south,
    north=north,
    west=west,
    east=east,
    early_date=early_date,
    late_date=late_date,
    daytime_selection=daytime_selection,
    season=season,
    landonly=landonly,
    infiles_landmask=infiles_landmask,
)

# plotting details: output directory, map extent and date strings
(
    outdir,
    map_extent,
    early_date_str,
    late_date_str,
) = func.set_plotting_details(
    data,
    early_date,
    late_date,
    context,
    region,
    season,
    landonly,
    daytime_selection,
    LENGTH,
    outpath,
)

# Five colours and five labels of equal length -- presumably paired
# index-by-index when plotting; verify against the plotting code.
colors = ['grey', 'cornflowerblue', 'green', 'orange', 'saddlebrown']
labels = ['allyear', 'DJF', 'MAM', 'JJA', 'SON']

##########################################

# Global minimum and maximum of each analysed variable, as plain floats.
# All eight reductions are submitted in ONE dask.compute call so that tasks
# shared between them (e.g. loading a variable for both its min and its max)
# are evaluated once instead of once per separate .compute() call.
# NOTE(review): the histogram calls further down use fixed bin edges, so in
# the visible script these extrema are informational only.
(
    TH_min,
    TH_max,
    TH_E_min,
    TH_E_max,
    Q_min,
    Q_max,
    age_min,
    age_max,
) = (
    float(extreme)
    for extreme in dask.compute(
        data.TH_diff.min(),
        data.TH_diff.max(),
        data.TH_E_diff.min(),
        data.TH_E_diff.max(),
        data.Q_diff.min(),
        data.Q_diff.max(),
        data.time_diff.min(),
        data.time_diff.max(),
    )
)

# Bin edges for the histograms: T_bins is used for TH_diff and TH_E_diff,
# t_bins for time_diff and q_bins for Q_diff.
T_bins = np.arange(-25, 27, 2)
t_bins = np.arange(0, 121, 6)
q_bins = np.arange(-7, 7.5, 0.5)

# Bin centres: midpoint of every pair of neighbouring edges.
T_bin_centers, t_bin_centers, q_bin_centers = (
    (edges[:-1] + edges[1:]) / 2 for edges in (T_bins, t_bins, q_bins)
)

# Values of the hamsl coordinate, in the order the levels are processed.
hamsl_list = list(data.hamsl.values)


def _level_histograms(variable, edges):
    """Return one lazy dask histogram of *variable* per hamsl level.

    For every index along ``hamsl`` the level is selected, flattened to 1-D
    and binned with the fixed bin ``edges``; only the counts are kept.
    """
    return [
        da.histogram(variable.isel(hamsl=level).data.ravel(), bins=edges)[0]
        for level in range(len(hamsl_list))
    ]


# Build the (still lazy) histograms with dask's split_large_chunks option
# switched off while the per-level slices are taken.
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    hist_THs = _level_histograms(data.TH_diff, T_bins)
    hist_THEs = _level_histograms(data.TH_E_diff, T_bins)
    hist_Qs = _level_histograms(data.Q_diff, q_bins)
    hist_ages = _level_histograms(data.time_diff, t_bins)

# Stack the per-level counts along a new leading axis (level, bin) and
# evaluate all four stacks in a single dask.compute call.
hist_THs, hist_THEs, hist_Qs, hist_ages = dask.compute(
    da.stack(hist_THs, axis=0),
    da.stack(hist_THEs, axis=0),
    da.stack(hist_Qs, axis=0),
    da.stack(hist_ages, axis=0),
)

# Save to disk: one .npy file per histogram stack. outdir is concatenated
# directly with the file name, so it is expected to already end with a path
# separator.
for target, counts in (
    (f"{outdir}hist_THs.npy", hist_THs),
    (f"{outdir}hist_THEs.npy", hist_THEs),
    (f"{outdir}hist_Qs.npy", hist_Qs),
    (f"{outdir}hist_ages.npy", hist_ages),
):
    np.save(target, counts)

