#!/usr/bin/env python
# coding: utf-8

# # Cluster trajectories


print("I am in the python file")
from optparse import OptionParser
import glob
import datetime as dt
import numpy as np
import pandas as pd
import xarray as xa

import dask

import matplotlib.pyplot as plt
import seaborn as sns
import cartopy.crs as ccrs

from sklearn.cluster import KMeans
from pca import pca
from functools import partial
import time
time1 = time.time()
 

def _preprocess(ds, align):
    """Prepare a single trajectory dataset before the files are concatenated.

    The start date is built from the first row of ``BASEDATE`` (year, month,
    day, hour, minute, in that column order) and treated as UTC.  The incoming
    ``time`` values are interpreted as hours relative to that start date: they
    are preserved as the ``timestamp`` coordinate, while ``time`` itself is
    replaced by absolute datetimes.

    Parameters
    ----------
    ds : dataset
        Dataset of one trajectory file; must provide ``BASEDATE``, ``time``
        and the ``dimx_BASEDATE`` dimension.
    align : str
        ``'relative'`` makes ``timestamp`` the dimension instead of ``time``;
        any other value keeps the absolute ``time`` dimension.

    Returns
    -------
    dataset
        The modified dataset, with a scalar ``startdate`` variable (POSIX
        timestamp of the start date) and every ``-999`` replaced by NaN.
    """
    year, month, day, hour, minute = (int(v) for v in ds.BASEDATE.values[0, :5])
    startdate = dt.datetime(year, month, day, hour, minute, tzinfo=dt.timezone.utc)

    # lower-case 'h': the upper-case hour alias is deprecated in recent pandas
    time_deltas = pd.to_timedelta(ds['time'], unit='h')
    ds['startdate'] = startdate.timestamp()                    # make new variable with startdate

    # keep the relative hours before 'time' is overwritten with absolute dates
    ds = ds.assign_coords(timestamp=("time", ds['time'].data))
    ds['time'] = pd.to_datetime(startdate + time_deltas)       # use actual time as time dimension
    ds = ds.sel(dimx_BASEDATE=0)                               # remove pseudo dimension
    ds = ds.where(ds != -999, np.nan)                          # replace missing values with nan
    if align == 'relative':
        ds = ds.swap_dims({'time': 'timestamp'})
    return ds


def _parse_flag(value):
    """Return the boolean meaning of a command-line switch value.

    ``None`` (option not given), the empty string and the usual spellings of
    "off" (``0``, ``false``, ``f``, ``no``, ``n``, ``off``, ``none``; case
    insensitive) are False.  Every other string is True.
    """
    if value is None:
        return False
    return value.strip().lower() not in {'', '0', 'false', 'f', 'no', 'n', 'off', 'none'}


parser = OptionParser()
# customise input date
parser.add_option("-Y", dest="YEAR", type="str",
                  help="year of the trajectory start date (integer)")
parser.add_option("-M", dest="MONTH", type="str",
                  help="month of the trajectory start date (integer)")
parser.add_option("-S", dest="STARTTIME", type="str",
                  help="start day of the trajectories (integer)")
parser.add_option("-W", dest="TIMEWINDOW", type="str",
                  help="time window that is part of the KennNumbers input file name (integer)")
parser.add_option("-N", dest="NCLUSTERS", type="str",
                  help="number of KMeans clusters (integer)")
parser.add_option("-L", dest="LOGCWC", type="str",
                  help="switch for the log-cwc variant; 0/false/no/off disable it")
parser.add_option("-A", dest="ALIGN", type="str",
                  help="'relative' indexes trajectories by their relative timestamp, "
                       "anything else keeps absolute time")
parser.add_option("-I", dest="INDIR", type="str",
                  help="prefix of the Lagranto_* directories (plain string prefix, "
                       "include the trailing '/')")
parser.add_option("-O", dest="OUTDIR", type="str",
                  help="prefix for the KennNumbers input and all outputs (plain string "
                       "prefix, include the trailing '/')")


(options, args) = parser.parse_args()

# Fail with a usage message instead of a bare TypeError further down when an
# option that the script cannot run without was not given.
_required_flags = {"YEAR": "-Y", "MONTH": "-M", "STARTTIME": "-S",
                   "TIMEWINDOW": "-W", "NCLUSTERS": "-N",
                   "INDIR": "-I", "OUTDIR": "-O"}
_missing = [flag for dest, flag in _required_flags.items()
            if getattr(options, dest) is None]
if _missing:
    parser.error("missing required option(s): " + ", ".join(_missing))

try:
    timewindow = int(options.TIMEWINDOW)  # 6
    Nclusters  = int(options.NCLUSTERS)   # 4
    YEAR       = int(options.YEAR)
    MONTH      = int(options.MONTH)
    STARTDAY   = int(options.STARTTIME)
except ValueError as err:
    parser.error(f"-W, -N, -Y, -M and -S need integer values: {err}")

# bool() on the raw string would be True for "False" and "0" as well
logcwc    = _parse_flag(options.LOGCWC)
align     = options.ALIGN
indir     = options.INDIR
outdir    = options.OUTDIR

# SPECIFY
# Variables handed to the clustering.  The same list is used with and without
# logcwc; the longer candidate lists only differ in the log suffix:
#   logcwc: ['corr_THdiffcpwclog','corr_THdiffccwclog','diff_TH','diff_hamsl_max', 'diff_hamsl' ,'diff_hamsl_subs','corr_THQ','diff_Q', 'diff_lon','diff_lat', 'diff_cwclog']
#   else:   ['corr_THdiffcpwc','corr_THdiffccwc','diff_TH','diff_hamsl_max', 'diff_hamsl' ,'diff_hamsl_subs','corr_THQ','diff_Q', 'diff_lon','diff_lat', 'diff_cwc']
myvars = ['diff_lon', 'diff_lat', 'diff_hamsl', 'diff_Q', 'diff_TH']

# echo the settings so they end up in the job log
print(
    "i managed to assign the options:",
    f"timewindow {timewindow}",
    f"Nclusters {Nclusters}",
    f"YEAR {YEAR}",
    f"MONTH {MONTH}",
    f"STARTDAY {STARTDAY}",
    f"myvars: {myvars}",
    sep="\n",
)

random   = 42      # seed passed to KMeans
saveplot = True

# common prefix of every output file; the log variant gets its own suffix
name = f"{outdir}Clustering_random{random}_nclusters{Nclusters}_nVariables{len(myvars)}_startday{YEAR:04}{MONTH:02}{STARTDAY:02}"
if logcwc:
    name += "_logcwc"

# the KennNumbers file is read from the output directory
infile = outdir + f"AD_KennNumbers_timewindow{timewindow}_{YEAR:04}{MONTH:02}{STARTDAY:02}.nc"

print(f"infile: {infile}")

# ### plotting specifics

# figure output
dpi       = 150
fig_width = 6.5                      # in inch

# map set-up
proj       = ccrs.Mercator()         # what kind of map projection to use
map_extent = [-20, 50, 20, 60]       # defines extent of drawn map

# seaborn styling
context     = 'paper'
scale_fonts = 0.8
sns.set_context(context, font_scale=scale_fonts)


# ### list of files - you have to make sure these are the same that were used to create the infile!!!

# one Lagranto_<startday>_* directory per trajectory run
directories = glob.glob(indir + f"Lagranto_{YEAR:04}{MONTH:02}{STARTDAY:02}_*")
print(directories)
for directory in directories:
    print(directory)

# gather the trajectory files of all directories in one sorted list
traj_files = sorted(
    traj_path
    for traj_dir in directories
    for traj_path in glob.glob(traj_dir + '/All_trajtrace_????????_??_corr_new.4')
)

print(f"using trajectories staring on {YEAR:04}{MONTH:02}{STARTDAY:02}")
print(traj_files)

# ## read in KennNumbers file and cluster

ds = xa.open_dataset(infile)

# add the pairwise sums of the correlation variables as two new variables
if logcwc:
    ds['corr_THdiffcpwclog'] = ds.corr_THdiffcrwclog + ds.corr_THdiffcswclog
    ds['corr_THdiffccwclog'] = ds.corr_THdiffclwclog + ds.corr_THdiffciwclog
else:
    ds['corr_THdiffcpwc'] = ds.corr_THdiffcrwc + ds.corr_THdiffcswc
    ds['corr_THdiffccwc'] = ds.corr_THdiffclwc + ds.corr_THdiffciwc

# .sizes is the dimension-name -> length mapping (mapping access via .dims is deprecated)
total_num_trajs = ds.sizes['traj']

# earliest and latest start, floored to the full hour
# (lower-case 'h': the upper-case alias is deprecated in recent pandas)
_first_start, _last_start = (
    pd.to_datetime(stamp.values).floor('h').strftime('%Y-%m-%d %H:%M')
    for stamp in (ds.init.min(), ds.init.max())
)
startdatestring = f"starts {_first_start} to {_last_start}"

df_in = ds.to_dataframe()
df_in = df_in.drop('init', axis=1)  # drop the column with startdates, because they pose problems when normalising
df_in = df_in.dropna(how='any')     # drop trajs that have nan anywhere (they either were filtered or left domain by date
df_in.reset_index(inplace=True)
df_in.set_index(["traj", "traj_number"], inplace=True)
df = df_in[myvars].copy()

df_orig = df.copy()
df_mean = df.mean()
df_std  = df.std()

# A zero or undefined standard deviation would turn the whole column into NaN
# and only fail later inside KMeans; report the offending columns right away.
_degenerate = list(df_std.index[~(df_std > 0)])
if _degenerate:
    raise ValueError(
        f"cannot normalise {_degenerate}: standard deviation is zero or undefined "
        f"({len(df)} trajectories left after dropping NaN)"
    )

df = (df - df_mean) / df_std  # normalise

with open(f"{name}_mean_std.txt", 'w') as stats_file:
    # so that I have written down what I normalised by
    stats_file.write("Mean:\n")
    stats_file.write(df_mean.to_string() + '\n\n')
    stats_file.write("Standard Deviation:\n")
    stats_file.write(df_std.to_string() + '\n')

# ### cluster
print("clustering now")

tick_labels = df.columns

# random initialisation, restarted 100 times with a fixed seed
km = KMeans(
    n_clusters=Nclusters,
    init='random',
    random_state=random,
    n_init=100,
    max_iter=300,
)
km.fit(df)

euclid_dist     = km.transform(df)       # distance of every sample to every centroid
km_params       = km.get_params()
km_labels       = km.predict(df)
num_per_cluster = np.bincount(km_labels)
cluster_centers = km.cluster_centers_

print(f"number of iterations: {km.n_iter_}")
# print(f"cluster_centers     : \n {cluster_centers}")
print(f"total (within) sum of squares (ss): {km.inertia_}")
print(f"numbe rof elements in the clusters: {num_per_cluster}")

# Attach the cluster label and a running row number ('ind') to the normalised
# frame; 'ind' addresses the matching row of euclid_dist later on.
row_positions = np.arange(df.shape[0])
df = (
    df.reset_index()
      .assign(Labels=km.labels_, ind=row_positions)
      .set_index(["traj", "traj_number", "Labels", "ind"])
)

# the un-normalised copy only gets the label
df_orig = (
    df_orig.reset_index()
           .assign(Labels=km.labels_)
           .set_index(["traj", "traj_number", "Labels"])
)


# ## plot cluster properties
cluster_means = df.groupby('Labels').mean()
# cluster_medians = df.groupby('Labels').median()

# every cluster starts out as cyan / C3 and is overridden below where a rule applies
color_list    = ['cyan'] * Nclusters
cluster_names = ['C3'] * Nclusters

# largest mean diff_hamsl -> blue / C1, smallest -> red / C2 (red wins if both coincide)
hamsl_means  = cluster_means['diff_hamsl']
maxhamsldiff = hamsl_means.argmax()
minhamsldiff = hamsl_means.argmin()
color_list[maxhamsldiff], cluster_names[maxhamsldiff] = 'blue', 'C1'
color_list[minhamsldiff], cluster_names[minhamsldiff] = 'red', 'C2'

# Walk through the clusters in ascending order of mean diff_lon; the first one
# that is neither the blue nor the red cluster becomes orange / C4.
orderdifflon = cluster_means["diff_lon"].argsort()
already_assigned = (maxhamsldiff, minhamsldiff)
for i in range(Nclusters):
    candidate = orderdifflon[i]
    if candidate in already_assigned:
        # already coloured, report it and try the next one
        print(f"{i},{orderdifflon[i]}")
        continue
    color_list[candidate]    = 'orange'
    cluster_names[candidate] = 'C4'
    break



# ## read in trajectory data and save cluster wise

# Stop early with a message that names the search location; opening an empty
# file list would fail anyway.  FileNotFoundError is an OSError subclass.
if not traj_files:
    raise FileNotFoundError(
        f"no trajectory files found below {indir}Lagranto_{YEAR:04}{MONTH:02}{STARTDAY:02}_*"
    )

dask.config.set(**{'array.slicing.split_large_chunks': False})
time1     = time.time()
prep_func = partial(_preprocess, align=align)   # open_mfdataset calls preprocess with the dataset only
dat_all   = xa.open_mfdataset(
    traj_files,
    combine='nested',
    concat_dim='dimx_lon',
    parallel=True,
    chunks={'dimx_lon': 100_000_000, 'time': 121},   # chunk sizes must be integers, not 1e8
    preprocess=prep_func,
)
dat_all   = dat_all.set_index(dimx_lon='traj_number')   # makes trajectories selectable by traj_number
time2     = time.time()
print(f"took {time2-time1}s to read in traj data")


# Index levels of the labelled frame, looked up once instead of per cluster.
_labels_level = df.index.get_level_values('Labels')
_traj_level   = df.index.get_level_values('traj_number')
_ind_level    = df.index.get_level_values('ind')

for n in range(Nclusters):
    time3            = time.time()
    in_cluster       = _labels_level == n
    cluster_indices  = _traj_level[in_cluster]               # trajectories belonging to cluster n
    distance_indices = _ind_level[in_cluster].astype(int)    # their rows in euclid_dist

    print('selecting cluster')
    dat_cluster                     = dat_all.sel(dimx_lon=cluster_indices)
    dat_cluster['dist_to_centroid'] = ('dimx_lon', euclid_dist[distance_indices, n])

    # Files are named after the cluster colour.  A colour shared by several
    # clusters (more than four clusters leave several 'cyan' ones) would make
    # them overwrite each other, so such clusters also get their index in the
    # name.  Names of uniquely coloured clusters are unchanged.
    cluster_color = color_list[n]
    if color_list.count(cluster_color) == 1:
        cluster_tag = cluster_color
    else:
        cluster_tag = f"{cluster_color}{n}"
    outfile_nc = f"{name}_Cluster_{cluster_tag}_trajectories_{align}.nc"

    print(f"writing to {outfile_nc}")
    dat_cluster.to_netcdf(outfile_nc, compute=True)
    time4 = time.time()

    print(f" writing cluster {n} to file took {(time4-time3)/60:6.2f}min")



