#!/usr/bin/env python
# coding: utf-8

# # Calculate Kenn-Numbers for trajectories, by which the trajectories will later be clustered

import netCDF4
import os
import glob
import datetime as dt
import numpy as np
import pandas as pd
import xarray as xa
import dask


import time
from optparse import OptionParser

def logabs(data):
    """Return the natural logarithm of the absolute value of ``data``.

    Works element-wise on anything numpy ufuncs accept (scalars, numpy
    arrays, xarray objects). Zeros map to ``-inf`` and NaNs stay NaN.
    Zeros are treated as valid input, so numpy's divide-by-zero warning for
    ``log(0)`` is suppressed instead of being printed on every call.
    """
    with np.errstate(divide='ignore'):
        return np.log(np.abs(data))

# get input from command-line arguments (every option is required)
parser = OptionParser()
parser.add_option("-Y", dest="YEAR", type="str", help="year of the trajectory start date")
parser.add_option("-M", dest="MONTH", type="str", help="month of the trajectory start date")
parser.add_option("-S", dest="STARTTIME", type="str", help="day of month of the trajectory start date")
parser.add_option("-W", dest="TIMEWINDOW", type="str", help="number of time steps for windowed differences (>= 1)")
parser.add_option("-N", dest="NORTHOF", type="str", help="latitude a trajectory must exceed to be kept")
parser.add_option("-I", dest="INDIR", type="str", help="directory containing the Lagranto_* run directories")
parser.add_option("-O", dest="OUTDIR", type="str", help="directory for the output file")

(options, args) = parser.parse_args()

# Fail early with a usage message rather than an opaque int(None) TypeError.
# (The options stay type="str": optparse's "int" type would read "08" as octal.)
_required_options = {"-Y": "YEAR", "-M": "MONTH", "-S": "STARTTIME", "-W": "TIMEWINDOW",
                     "-N": "NORTHOF", "-I": "INDIR", "-O": "OUTDIR"}
_missing_options = [flag for flag, dest in _required_options.items()
                    if getattr(options, dest) is None]
if _missing_options:
    parser.error("missing required option(s): " + ", ".join(_missing_options))

try:
    timewindow = int(options.TIMEWINDOW)  # e.g. 6
    YEAR       = int(options.YEAR)
    MONTH      = int(options.MONTH)
    STARTDAY   = int(options.STARTTIME)
    north_of   = int(options.NORTHOF)
except ValueError as err:
    parser.error(f"-Y, -M, -S, -W and -N must be integers ({err})")
indir     = options.INDIR
outdir    = options.OUTDIR

# The windowed differences slice with [:-timewindow], which is empty for 0.
if timewindow < 1:
    parser.error("-W (TIMEWINDOW) must be a positive integer")

# Names of the float32 per-trajectory output variables, in creation order.
myvars = [
    # correlations with the windowed TH change, and TH/Q correlation
    'corr_THdiffcrwc', 'corr_THdiffcswc', 'corr_THdiffciwc', 'corr_THdiffclwc',
    'corr_THQ',
    # start-to-end (or windowed) differences
    'diff_TH', 'diff_Q',
    'diff_hamsl_max', 'diff_hamsl', 'diff_hamsl_subs',
    'diff_lon', 'diff_lat',
    'diff_cwc',
    # the same cloud-water quantities computed on log(|value|)
    'corr_THdiffcrwclog', 'corr_THdiffcswclog', 'corr_THdiffciwclog', 'corr_THdiffclwclog',
    'diff_cwclog',
]



# Echo the parsed settings, one "<name> <value>" line each.
print("i managed to assign the options:")
for _label, _value in (
    ("timewindow", timewindow),
    ("north_of", north_of),
    ("YEAR", YEAR),
    ("MONTH", MONTH),
    ("STARTDAY", STARTDAY),
):
    print(f"{_label} {_value}")




# case label of the form Case_YYYY_MM
casestudy = f"Case_{YEAR:04}_{MONTH:02}"

# os.path.join gives the right path whether or not OUTDIR ends in a separator
# (plain concatenation silently glued the file name onto the directory name).
outfile   = os.path.join(outdir, f"AD_KennNumbers_timewindow{timewindow}_{YEAR:04}{MONTH:02}{STARTDAY:02}.nc")
# start from a clean output file
if os.path.isfile(outfile):
    os.remove(outfile)


# ### get list of files to read in

# All Lagranto_<YYYYMMDD>_* run directories for the requested start day.
# os.path.join tolerates an INDIR given with or without a trailing separator.
directories = sorted(glob.glob(os.path.join(indir, f"Lagranto_{YEAR:04}{MONTH:02}{STARTDAY:02}_*")))
print(directories)
traj_files = []
for directory in directories:
    print(directory)
    traj_files.extend(glob.glob(os.path.join(directory, 'All_trajtrace_????????_??_corr_new.4')))

# sorted so the trajectories are always written to the outfile in the same order
traj_files = sorted(traj_files)
print(f"calculate KennNumbers for trajectories starting on  {YEAR:04}{MONTH:02}{STARTDAY:02}")
print(traj_files)
if not traj_files:
    print("WARNING: no trajectory files found")

# ### find out how many trajs there will be in total

print("--------------------------------------------------------------")
print("opening every file once to get the number of trajectories")
ntraj_total = 0
for file in traj_files:
    print(file)
    # Read-only: the input files are only inspected here, so they must not be
    # opened for writing (the old 'r+' mode also fails on read-only data).
    with netCDF4.Dataset(file, mode='r') as nc_in:
        ntraj_total += nc_in.dimensions["dimx_lon"].size
print(f"number of trajectories in total: {ntraj_total}")


# ### create empty nc file

print("--------------------------------------------------------------")
print(f"Creating NetCDF (v4) file: {outfile}")
fillvalue = np.nan

# mode 'w' creates the file explicitly (replacing any existing one);
# the context manager closes it even if variable creation fails.
with netCDF4.Dataset(outfile, mode='w', format='NETCDF4_CLASSIC') as nc_out:
    # NOTE(review): with no input files ntraj_total is 0, which netCDF4 turns
    # into an unlimited dimension.
    nc_out.createDimension("traj", ntraj_total)
    for var in myvars:
        nc_out.createVariable(var, np.float32, ("traj",), fill_value=fillvalue)
    # float64, not float32: float32 resolves only 128 s at present-day epoch
    # seconds, so start times with hours/minutes would be rounded.
    init       = nc_out.createVariable("init", np.float64, ("traj",), fill_value=fillvalue)
    init.units = "seconds since 1970-01-01 00:00:00"
    nc_out.createVariable("traj_number", np.float64, ("traj",), fill_value=fillvalue)
print("done creating outfile, closing it")


# ### read in data and compute the Kenn-Numbers


def _read_startdate(basedate):
    """Return the UTC start time stored in a BASEDATE array.

    ``basedate[0, 0:5]`` is read as year, month, day, hour, minute; each value
    is converted through ``np.int32`` before building the datetime.
    """
    year, month, day, hour, minute = (np.int32(value) for value in basedate[0, :5])
    return dt.datetime(year, month, day, hour, minute, tzinfo=dt.timezone.utc)


def _kenn_numbers(ds, timewindow, north_of):
    """Compute the Kenn-Numbers of every trajectory in one input dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        One trajectory file with the BASEDATE dimension already selected away.
        Variables are assumed to be laid out as ``[time, trajectory]`` (the
        positional indexing below relies on time being the first axis).
    timewindow : int
        Number of time steps over which TH and hamsl changes are taken (>= 1).
    north_of : int
        Trajectories whose maximum latitude stays below this value are set to
        NaN (kept, not dropped).

    Returns
    -------
    dict
        Output variable name -> 1-D numpy array with one value per trajectory,
        including ``traj_number``.
    """
    ntraj = ds.sizes['dimx_lon']

    print(f"replace mising values with nan")
    ds = ds.where(ds != -999, np.nan)
    print(f"masking those that never pass {north_of} lat")
    never_north = ds.lat.max(dim='time') < north_of
    ds = ds.where(~never_north, drop=False)  # set them to NaN but keep them

    # TH change over `timewindow` steps, stored centred in a NaN-padded array
    # of TH's shape. (-timewindow) // 2 floors, so for odd windows the target
    # slice still has exactly len(time) - timewindow rows.
    th = ds.TH.values
    th_diff = np.full_like(th, np.nan)
    th_diff[timewindow // 2:(-timewindow) // 2, :] = th[timewindow:, :] - th[:-timewindow, :]
    th_diff = ds.TH.copy(data=th_diff)

    numbers = {}
    for species in ('ciwc', 'clwc', 'cswc', 'crwc'):
        water = ds[species]
        numbers[f'corr_THdiff{species}'] = xa.corr(th_diff, water, dim="time").values
        numbers[f'corr_THdiff{species}log'] = xa.corr(th_diff, logabs(water), dim="time").values
    numbers['corr_THQ'] = xa.corr(ds.TH, ds.Q, dim="time").values

    # last-minus-first differences
    numbers['diff_TH']         = (ds.TH[-1, :] - ds.TH[0, :]).values
    numbers['diff_Q']          = (ds.Q[-1, :] - ds.Q[0, :]).values
    numbers['diff_hamsl']      = (ds.hamsl[-1, :] - ds.hamsl[0, :]).values
    numbers['diff_hamsl_subs'] = (ds.hamsl.max(dim='time', skipna=True) - ds.hamsl[-1, :]).values
    numbers['diff_lat']        = (ds.lat[-1, :] - ds.lat[0, :]).values
    numbers['diff_lon']        = (ds.lon[-1, :] - ds.lon[0, :]).values

    cwc = ds.ciwc.values + ds.clwc.values + ds.crwc.values + ds.cswc.values
    numbers['diff_cwc'] = cwc[-1, :] - cwc[0, :]
    # NOTE(review): this sums the per-species log magnitudes (i.e. the log of
    # their product), not the log of total cloud water -- confirm intended.
    # A zero species at either end yields -inf and the difference becomes
    # +/-inf or NaN.
    log_cwc = (logabs(ds.ciwc.values) + logabs(ds.clwc.values)
               + logabs(ds.crwc.values) + logabs(ds.cswc.values))
    numbers['diff_cwclog'] = log_cwc[-1, :] - log_cwc[0, :]

    # largest hamsl change over any `timewindow`-step window
    hamsl = ds.hamsl.values
    hamsl_change = hamsl[timewindow:, :] - hamsl[:-timewindow, :]
    if hamsl_change.size == 0:
        # no window fits (time series not longer than timewindow) or no trajectories
        numbers['diff_hamsl_max'] = np.full(ntraj, np.nan)
    else:
        numbers['diff_hamsl_max'] = np.nanmax(hamsl_change, axis=0)

    # NOTE(review): read after masking, so if traj_number is a data variable,
    # trajectories that never pass north_of get a NaN traj_number -- confirm.
    numbers['traj_number'] = np.asarray(ds.traj_number)
    return numbers


print("--------------------------------------------------------------")
print(f"opening each file again to extract data")
print(f"opening {outfile} again")
# context manager: the outfile is closed even if one input file fails
with netCDF4.Dataset(outfile, mode="a", format="NETCDF4_CLASSIC") as nc_out:
    nmin = 0  # first outfile index of the current input file's trajectories
    for file in traj_files:
        print(f"opening {file}")
        # load eagerly and release the file handle right away
        # (open_dataset kept one handle open per input file)
        nc_in = xa.load_dataset(file)

        ntraj = nc_in.sizes['dimx_lon']
        nmax  = nmin + ntraj

        startdate = _read_startdate(nc_in.BASEDATE.values)
        startdate_timestamp = startdate.timestamp()
        nc_in = nc_in.sel(dimx_BASEDATE=0)  # remove the BASEDATE dimension

        for name, values in _kenn_numbers(nc_in, timewindow, north_of).items():
            nc_out[name][nmin:nmax] = values
        nc_out['init'][nmin:nmax] = np.full(ntraj, startdate_timestamp)

        nmin = nmax  # next file continues where this one ended


