## This script opens every file of the Lagranto output once to count how many trajectories there are,
## then extracts the data of all trajectories at the requested date and saves them to a new file
# author: Fiona Fix


##################################################

from optparse import OptionParser
import netCDF4
import os
import glob
import datetime as dt
import numpy as np


# get input from arguments
# NOTE: the numeric options are deliberately read as strings and converted with
# int() below. optparse's own "int" type treats a leading zero as an octal
# prefix, so zero-padded values such as "-M 08" would be rejected.
parser = OptionParser()
parser.add_option("-Y", dest = "YEAR", type = "str", help = "year of the target date")
parser.add_option("-M", dest = "MONTH", type = "str", help = "month of the target date")
parser.add_option("-D", dest = "DAY", type = "str", help = "day of the target date")
parser.add_option("-T", dest = "TIME", type = "str", help = "hour (UTC) of the target date")
parser.add_option("-L", dest = "LENGTH", type="str",
                  help = "how many hours before the target date to look for trajectory starts")
parser.add_option("-F", dest = "FILENAMES", type="str",
                  help = "glob pattern of the trajectory files inside each Lagranto_* directory")
parser.add_option("-I", dest = "INDIR", type="str",
                  help = "path prefix (including trailing separator) of the Lagranto_* directories")
parser.add_option("-O", dest = "OUTDIR", type="str",
                  help = "path prefix (including trailing separator) for the output file")
(options,args) = parser.parse_args()

# Every option is required. Without this check a missing one only surfaces as
# an opaque TypeError from int(None) or as a glob for "NoneLagranto_*".
required = (("-Y", options.YEAR), ("-M", options.MONTH), ("-D", options.DAY),
            ("-T", options.TIME), ("-L", options.LENGTH), ("-F", options.FILENAMES),
            ("-I", options.INDIR), ("-O", options.OUTDIR))
missing = [flag for flag, value in required if value is None]
if missing:
    parser.error("missing required option(s): " + ", ".join(missing))

YEAR  = int(options.YEAR)
MONTH = int(options.MONTH)
DAY   = int(options.DAY)
TIME  = int(options.TIME)
LENGTH= int(options.LENGTH)
FILENAMES=options.FILENAMES
outdir   =options.OUTDIR
indir    =options.INDIR


# target date and the earliest trajectory start that is still considered
date          = dt.datetime(YEAR, MONTH, DAY, TIME, 0, tzinfo = dt.timezone.utc)
date_earliest = date - dt.timedelta(hours=LENGTH)
print(f"look for trajs between {date_earliest} and {date}")
############################

# Output file name; outdir is used as a plain prefix, so it must end with a path separator.
outfile = f"{outdir}AD_select_{YEAR:04}{MONTH:02}{DAY:02}T{TIME:02}00_length{LENGTH}_tmp.nc"
# start from a clean slate: drop any output left over from an earlier run
if os.path.isfile(outfile):
    os.remove(outfile)

# Keep only the Lagranto_<YYYYmmdd_HH> directories whose start date lies in
# [date_earliest, date). The date is parsed from the directory name.
directories = []
for directory in glob.glob(f"{indir}Lagranto_*"):
    stamp = os.path.basename(directory).split('_', 1)[-1]
    try:
        start = dt.datetime.strptime(stamp, "%Y%m%d_%H").replace(tzinfo=dt.timezone.utc)
    except ValueError:
        # a stray entry such as "Lagranto_backup" should not abort the whole run
        print(f"skipping {directory}: name does not match Lagranto_YYYYmmdd_HH")
        continue
    if date_earliest <= start < date:
        directories.append(directory)

# collect the trajectory files of all selected directories
traj_files = sorted(
    path
    for directory in directories
    for path in glob.glob(directory + '/' + FILENAMES)
)
print(directories)
print(traj_files)
print('AD identification for: '+str(date))
#%%==================================================================================================================

# find out how many trajs there will be in total
print("opening every file once to get number of trajectories")
ntraj_total = 0
for file in traj_files:
    # No `format` argument: it only matters when a file is created (the
    # original passed a misspelled value here, which was ignored on read).
    # The with-block closes the file even if the dimension lookup fails.
    with netCDF4.Dataset(file, mode='r') as nc_in:
        ntraj = nc_in.dimensions["dimx_lon"].size
    ntraj_total += ntraj
print(f"number of trajectories in total: {ntraj_total}")

# variables of the output file; '<name>_diff' and '<name>_diffN37' are derived from '<name>'
myvars = [
    'time', 'lon', 'lat', 'p', 'hamsl',
    'TH_diff', 'TH_E_diff', 'Q_diff', 'time_diff', 'hamsl_diff',
    'TH_diffN37', 'TH_E_diffN37', 'Q_diffN37', 'time_diffN37', 'hamsl_diffN37',
]

#%%==================================================================================================================
# create empty nc file
print("--------------------------------------------------------------")
print(f"Creating NetCDF (v4) file: {outfile}")

# mode 'w' always creates a fresh file; the with-block guarantees it is closed
with netCDF4.Dataset(outfile, mode='w', format='NETCDF4_CLASSIC') as nc_out:
    # one entry per trajectory, summed over all input files
    dim_traj = nc_out.createDimension("traj", ntraj_total)
    for var in myvars:
        # trajectory numbers are integers, everything else is stored as float32
        var_dtype = np.int32 if var == 'traj_number' else np.float32
        nc_out.createVariable(var, var_dtype, ("traj",))

    # Start time of each trajectory as epoch seconds. This needs float64:
    # float32 has a 24-bit significand, so values around 1e9 s would be
    # rounded by the order of a minute or more.
    init = nc_out.createVariable("init", np.float64, ("traj",))
    init.units = "seconds since 1970-01-01 00:00:00"


#%%==================================================================================================================
print("--------------------------------------------------------------")
print("opening each file again to extract data")
# [nmin, nmax) is the slice of the output "traj" dimension filled by the current file
nmin = 0
nmax = 0
wrong_files = []
# with-blocks make sure both datasets are closed even when an error is raised
with netCDF4.Dataset(outfile, mode="a", format="NETCDF4_CLASSIC") as nc_out:
    for file in traj_files:
        # `format` is only relevant when creating a file, so it is not passed for reading
        with netCDF4.Dataset(file, mode='r') as nc_in:
            ntraj = nc_in.dimensions["dimx_lon"].size
            nmax  = nmin + ntraj # set max counter

            # BASEDATE holds year, month, day, hour, minute of the trajectory start
            basedate  = nc_in.variables["BASEDATE"][0,:]
            print(basedate)
            startdate = dt.datetime(int(basedate[0]), int(basedate[1]), int(basedate[2]),
                                    int(basedate[3]), int(basedate[4]), tzinfo=dt.timezone.utc)
            print(f"startdate: {startdate}")

            # Index of the target date along the time axis, as an explicit integer.
            # NOTE(review): this equals the number of hours since the start, i.e. it
            # assumes one output step per hour beginning at BASEDATE - confirm.
            time_idx = int((date - startdate).total_seconds() / (60*60))
            if time_idx < 0:
                # a negative index would silently wrap around to the end of the time axis
                raise ValueError(f"target date {date} lies before start date {startdate} of {file}")
            startdate_timestamp = startdate.timestamp()

            # First time index at which each trajectory is north of 37 (degrees
            # latitude, presumably); -1 marks trajectories that never are, because
            # argmax alone would report 0 for them.
            latitude             = nc_in.variables['lat'][:]
            condition_met        = np.ma.filled(latitude > 37, False)
            first_time_index_N37 = np.argmax(condition_met, axis=0)
            first_time_index_N37[~np.any(condition_met, axis=0)] = -1
            # no usable N37 reference: never north of 37, or only after the target date
            no_reference   = (first_time_index_N37 == -1) | (first_time_index_N37 > time_idx)
            # index that is always valid for the lookup; affected entries are set to NaN afterwards
            reference_index = np.where(first_time_index_N37 < 0, 0, first_time_index_N37)

            for var in myvars:
                if var == 'time':
                    # NOTE(review): 'time' is created in the output file but no value
                    # is ever written to it (same as before) - confirm this is intended.
                    continue

                if var == 'traj_number':
                    tmp = nc_in.variables[var][:]
                    # Find the non-unique values
                    unique_values, counts = np.unique(tmp, return_counts=True)
                    non_unique_values = unique_values[counts > 1]
                    if non_unique_values.size != 0:
                        print("NON UNIQUE TRAJ NUMBERS:", non_unique_values)
                        break
                    print("NO DUPLICATES :) ")
                    nc_out[var][nmin:nmax] = tmp

                elif 'diff' not in var:
                    # plain variable: value of every trajectory at the target date
                    tmp = nc_in.variables[var][time_idx,:]
                    nc_out[var][nmin:nmax] = tmp
                    if var == 'lat' and tmp.sum() == 0:
                        wrong_files.append(file)
                        print(f"something wrong with file: {file}")
                        # Raise an exception to signal the error
                        raise ValueError("zeros in lat of input file!!!!")

                else:
                    # '<name>_diff' / '<name>_diffN37' are derived from input variable '<name>'
                    var_orig = var.rpartition('_')[0]
                    source   = nc_in.variables[var_orig]
                    if 'N37' not in var:
                        # difference between the target date and initialisation
                        if var_orig == 'time':
                            tmp = source[time_idx] - source[0]
                        else:
                            tmp = source[time_idx,:] - source[0,:]
                    else:
                        # difference between the target date and the first time the
                        # trajectory was north of 37, looked up per trajectory
                        values = source[:]
                        if var_orig == 'time':
                            # time has no trajectory axis: one reference time per trajectory
                            tmp = values[time_idx] - values[reference_index]
                        else:
                            reference = np.take_along_axis(
                                values, reference_index[np.newaxis, ...], axis=0)[0]
                            tmp = values[time_idx] - reference
                        # float copy so NaN can be stored; masked input values become NaN too
                        tmp = np.ma.filled(tmp.astype(np.float64), np.nan)
                        tmp[no_reference] = np.nan

                    nc_out[var][nmin:nmax] = tmp

            # start time (epoch seconds) of every trajectory of this file
            nc_out['init'][nmin:nmax] = np.full(ntraj, startdate_timestamp)
        nmin = nmax # new min counter is old max counter

print('---------------------------------------------------------')
print(wrong_files)

