import xarray as xa
import numpy as np
import datetime as dt

import glob


# Inclusive time window of Lagranto runs to process.
early_date = dt.datetime.strptime('20220426_00', "%Y%m%d_%H")
late_date = dt.datetime.strptime('20240430_23', "%Y%m%d_%H")

# The glob pattern below is built by plain string concatenation, so this
# path has to end with '/'.
indir = '/indata_dir/'  # SPECIFY input data dir


def parse_directory_date(directory):
    """Return the date encoded in a ``<prefix>_YYYYMMDD_HH`` directory name.

    The date is read from everything after the first underscore of the last
    path component.  Returns ``None`` when that part is not a valid
    ``YYYYMMDD_HH`` stamp, so stray entries can be skipped by the caller
    instead of aborting the whole run.
    """
    stamp = directory.rstrip('/').rsplit('/', 1)[-1].split('_', 1)[-1]
    try:
        return dt.datetime.strptime(stamp, "%Y%m%d_%H")
    except ValueError:
        return None


def select_directories(candidates, first_date, last_date):
    """Return the entries of *candidates* dated within [first_date, last_date].

    Both bounds are inclusive and the input order is preserved.  Entries whose
    name carries no parsable date are reported on stdout and left out.
    """
    selected = []
    for directory in candidates:
        run_date = parse_directory_date(directory)
        if run_date is None:
            print(f"skipping {directory}: name does not end in a YYYYMMDD_HH date")
            continue
        if first_date <= run_date <= last_date:
            selected.append(directory)
    return selected


# All trajectory directories whose start date lies inside the window.
directories = select_directories(glob.glob(f"{indir}Lagranto_*"), early_date, late_date)

print(len(directories))

# Collect every trajectory file inside the selected directories.
traj_files = []
for directory in directories:
    traj_files.extend(glob.glob(directory + '/All_trajtrace_????????_??.4'))   #SPECIFY filename, if changed

traj_files = sorted(traj_files)
num_times = len(traj_files)

wrongs = []          # per file: number of trajectories with at least one affected time step
trajs = []           # per file: total number of trajectories (size of dimx_lon)
max_wrongs_per = []  # per file: array with the affected-time-step count of each trajectory

print('looking at each file now')
for file in traj_files:
    print(file)
    # Strip only the trailing extension. Splitting on the first '.' would
    # truncate the path as soon as any parent directory name contains a dot.
    dirname = file.rsplit('.', 1)[0]
    print(dirname)

    # The context manager closes the underlying netCDF file once the corrected
    # copy has been written; everything that reads the data stays inside it.
    with xa.open_dataset(file, engine='netcdf4') as dataset:
        traj = dataset.isel(dimz_lon=0, dimy_lon=0)

        # Where TH_E has a missing value (-999) but TH does not.
        # Missing values in TH_E come from negative q when TH is fine,
        # which is a numeric inaccuracy at very low q.
        condition = (traj.TH_E == -999) & (traj.TH != -999)
        wrong_per_traj = condition.sum(dim='time')

        # Store plain Python ints so the totals printed later are numbers
        # rather than 0-d DataArray objects.
        wrongs.append(int((wrong_per_traj != 0).sum()))
        trajs.append(traj.sizes['dimx_lon'])
        max_wrongs_per.append(wrong_per_traj.values)

        # Write TH into TH_E at the affected positions.
        traj_new = traj.copy()
        traj_new['TH_E'] = traj_new.TH_E.where(~condition, other=traj_new.TH)

        # Sanity check: no trajectory may still match the condition.
        remaining = (traj_new.TH_E == -999) & (traj_new.TH != -999)
        num_wrong_new = int((remaining.sum(dim='time') != 0).sum())
        if num_wrong_new != 0:
            print('!!! filtering did not work correctly !!!')

        traj_new.to_netcdf(f"{dirname}_corr_new.4")

# Flatten the per-file counts into one array. np.concatenate raises on an
# empty list, so the "no files found" case is handled explicitly.
if len(max_wrongs_per):
    max_wrongs_per = np.concatenate([np.atleast_1d(counts) for counts in max_wrongs_per])
else:
    max_wrongs_per = np.array([], dtype=int)
# Statistics below only consider trajectories with at least one affected step.
max_wrongs_per = max_wrongs_per[max_wrongs_per != 0]

# Reduce to plain ints: the per-file entries may be 0-d array-like objects,
# which would otherwise be interpolated with their full repr.
total_wrong = int(sum(wrongs))
total_trajs = int(sum(trajs))

print('============================================================')
print(f"done looking through all {num_times} files.")
print(f" the number of trajectories that have -999 in TH_E but not in TH is: {total_wrong}")
print(f" the number of all trajectroies is {total_trajs}")
if total_trajs:
    print(f" therefore, {total_wrong/total_trajs * 100}% have this problem")
else:
    # Avoid a ZeroDivisionError when nothing was read.
    print(" no trajectories were read, so no percentage can be given")
if max_wrongs_per.size:
    print(f" the maximum number of instances within one trajectory is {np.max((max_wrongs_per))}")
    print(f" the mean number of instances within one trajectory that has at least one is {np.mean((max_wrongs_per))}")
    print(f" the media number of instances within one trajectory that has at least one is {np.median(np.array(max_wrongs_per))}")
else:
    # np.max raises on an empty array; a clean data set must not crash here.
    print(" no trajectory has -999 in TH_E but not in TH, so there are no per-trajectory statistics")
