### plot clustered trajectories per day

import os
import numpy as np
import xarray as xa
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
import cartopy.crs as ccrs
import datetime as dt
import glob 

import time
time1 = time.time()  # start timestamp; reassigned inside the per-date loop before the elapsed time is printed
 
def logabs(data):
    """Return the natural logarithm of the absolute value of ``data``.

    Both steps are evaluated with NumPy, so scalars and arrays are
    handled element-wise.
    """
    magnitude = np.absolute(data)
    return np.log(magnitude)

# SPECIFY
outdir      = "outdir"      # output location for the figures
indir       = "data_dir"    # directory that is searched for the cluster files


# Glob pattern for the per-cluster trajectory files; the eight '?' stand for
# the start day and '*' for the cluster identifier.  os.path.join inserts the
# path separator, so `indir` works with or without a trailing slash.
file_pattern = os.path.join(
    indir,
    "Clustering_random42_nclusters4_nVariables5_startday????????_logcwc_Cluster_*_trajectories_relative.nc",
)
###################################################################################
# plotting specifics
proj              = ccrs.Mercator()    # what kind of map projection to use
map_extent_large  = [-20, 50, 20, 70]  # defines extent of drawn map
map_extent        = [-20, 50, 20, 60]  # defines extent of drawn map
dpi = 300                              # has to be 300 for WCD

fig_width   = 6.5                      #in inch (min 8cm=3.2in)
context     = 'paper'                  # seaborn context; also part of the figure filename
scale_fonts = 0.8
sns.set_context(context, font_scale=scale_fonts)

clN  = [2,1,4,3]   # cluster number used in the map labels for the n-th file of a date
step = 12          # spacing of the x ticks and of the scatter markers on the map


#==================================================
# Get a list of all files matching the pattern.
# glob returns matches in filesystem-dependent order; sort them so that the
# position of each cluster file inside its date group is reproducible.
# NOTE(review): `clN` maps that position to a cluster number - confirm that the
# mapping matches the sorted (alphabetical) file order.
all_files = sorted(glob.glob(file_pattern))

# Group the files by start day: {"YYYYMMDD": [file, file, ...]}
file_groups = {}
for file in all_files:
    # The fifth '_'-separated token of the filename is "startday<date>";
    # keep only the part after "startday".
    filename = os.path.basename(file)  # filename without the path, on any OS
    date = filename.split("_")[4].split("startday")[1]
    file_groups.setdefault(date, []).append(file)

# Loop through each date and produce one summary figure from its 4 cluster files
for date, infiles in file_groups.items():
    if len(infiles) != 4:
        print(f"Skipping date {date} because it does not have exactly 4 files.")
        continue

    # Process the 4 files for this date
    print(f"Processing files for date {date}:\n {infiles}")

    # Recover per-file metadata from the '_'-separated filename tokens:
    #   token  4 -> "startday<date>"
    #   token -5 -> "logcwc" marker (switches the condensate panels to log scale)
    #   token -3 -> string that is used as the matplotlib colour of the cluster
    # name / startday / logcwc keep the value derived from the last file.
    color_list = []
    for file in infiles:
        filename = os.path.basename(file)
        tokens   = filename.split('_')
        name     = filename.split('logcwc_')[0] + 'logcwc_'
        color_list.append(tokens[-3])
        startday = tokens[4]
        logcwc   = tokens[-5] == 'logcwc'

    nClusters = len(color_list)
    print(name)
    print(filename)
    print(outdir)
    print(color_list)
    print(startday)
    print(logcwc)

    #=====================
    # read in traj data
    time1 = time.time()
    dat_means = []                  # per cluster: mean over all trajectories
    number_of_trajs_clusters = []   # per cluster: number of trajectories

    for file, color in zip(infiles, color_list):
        print(f"Cluster {color}")
        # Open inside a context manager so the file handle is released again;
        # .load() pulls the averaged data into memory before the file closes.
        with xa.open_dataset(file) as dat_raw:
            dat_cluster_in = dat_raw.astype(dtype='float64')
            print("calc mean")
            dat_means.append(dat_cluster_in.mean(dim='dimx_lon').load())
            number_of_trajs_clusters.append(dat_cluster_in.sizes['dimx_lon'])

    time2 = time.time()
    print(f"took {(time2-time1)/60:6.2f}min to read in and average data")
    #=====================
    # PLOTTING

    # Layout: 4 time-series panels in the left column, the map in the upper
    # two rows of the right column and 2 condensate panels below it.
    fig4 = plt.figure(figsize=(fig_width,1.1*fig_width), dpi=dpi)
    spec = fig4.add_gridspec(4,2)
    ax00 = fig4.add_subplot(spec[0,0])
    ax10 = fig4.add_subplot(spec[1,0])
    ax20 = fig4.add_subplot(spec[2,0])
    ax30 = fig4.add_subplot(spec[3,0])
    ax21 = fig4.add_subplot(spec[2,1])
    ax31 = fig4.add_subplot(spec[3,1])
    axmap = fig4.add_subplot(spec[0:2,1],projection=proj)

    # all time-series panels share the x axis of the first one
    for ax in (ax10, ax20, ax30, ax21, ax31):
        ax.sharex(ax00)

    sns.despine(offset=5)
    print(f"clusters based on {sum(number_of_trajs_clusters)/1e3:4.2f} thousand trajectories")

    for n in range(nClusters):
        dat   = dat_means[n]
        color = color_list[n]
        # total water; stored on the dataset but not plotted at the moment
        dat['q_t'] = (dat.Q/1000 + dat.ciwc + dat.clwc + dat.crwc + dat.cswc) * 1000

        # PLOT MEAN TIMESERIES IN ONE
        # solid where the mean position is north of 37 deg latitude, dashed
        # elsewhere.  Only the solid line carries the legend label (value at
        # the first time step), so every panel gets one entry per cluster.
        north = dat.lat > 37
        south = dat.lat <= 37
        ax00.plot(dat.hamsl.where(north), ls='-', marker='', color=color, lw=1,
                  label=f"{dat.hamsl[0].values/1e3:6.2f}")
        ax10.plot(dat.TH.where(north) - dat.TH[0], ls='-', marker='', color=color, lw=1,
                  label=f"{dat.TH[0].values:6.1f}")
        ax30.plot(dat.Q.where(north) - dat.Q[0], ls='-', marker='', color=color, lw=1,
                  label=f"{dat.Q[0].values:6.2f}")

        ax00.plot(dat.hamsl.where(south), ls='--', marker='', color=color, lw=1)
        ax10.plot(dat.TH.where(south) - dat.TH[0], ls='--', marker='', color=color, lw=1)
        ax30.plot(dat.Q.where(south) - dat.Q[0], ls='--', marker='', color=color, lw=1)

        # combine liquid and ice panels
        ax21.plot(dat.ciwc, ls='-', marker='', color=color, lw=1,
                  label=f"{dat.ciwc[0].values/1e3:6.2f}")
        ax21.plot(dat.cswc, ls='--', marker='', color=color, lw=1,
                  label=f"{dat.cswc[0].values/1e3:6.2f}")

        ax31.plot(dat.clwc, ls='-', marker='', color=color, lw=1,
                  label=f"{dat.clwc[0].values/1e3:6.2f}")
        ax31.plot(dat.crwc, ls='--', marker='', color=color, lw=1,
                  label=f"{dat.crwc[0].values/1e3:6.2f}")

    # zero reference line for the panels showing differences to the start value
    # (ax20 currently has no data plotted on it)
    ax10.axhline(y=0, color='gray',lw=0.5)
    ax20.axhline(y=0, color='gray',lw=0.5)
    ax30.axhline(y=0, color='gray',lw=0.5)

    # x axis: one tick every `step` time steps, taken from the last cluster
    n_timesteps = dat_means[-1].sizes['timestamp']
    ax30.set_xlabel("hours since start")
    ax31.set_xlabel("hours since start")
    ax30.set_xticks(range(0, n_timesteps, step))
    ax31.set_xticks(range(0, n_timesteps, step))
    ax30.set_xlim(0, n_timesteps)
    ax31.set_xlim(0, n_timesteps)

    # raw strings keep the LaTeX backslashes without invalid escape sequences
    ax00.set_ylabel('hamsl / km')
    ax00.set_ylim(0,10000)
    ax00.set_yticks(range(0,10000,2000))
    ax00.set_yticklabels(np.arange(0,10,2))
    ax10.set_ylabel(r'$\Delta\theta$ / K')
    ax20.set_ylabel(r'$\Delta T$ / K')
    ax20.set_ylim(-50,5)
    ax30.set_ylabel(r'$\Delta q$ / gkg$^{-1}$')
    ax30.set_ylim(-4,4)
    ax21.set_ylabel('cwc / kgkg$^{-1}$')
    ax31.set_ylabel('cwc / kgkg$^{-1}$')

    # Line-style legend for ax10; it is re-added below because the next
    # ax10.legend() call replaces it.
    # NOTE(review): the dashed lines in this panel are the lat<=37 segments -
    # confirm that the label 'explained by radiation' is still intended.
    dashed = plt.Line2D([0], [0], color='k', linestyle='--',marker='',  lw=0.5, label='explained by radiation')
    solid = plt.Line2D([0], [0], color='k', linestyle='-',marker='', label=r'$\Delta \theta$')
    legend1 = ax10.legend(handles=[solid, dashed],ncols=2, frameon=False, prop={'size': 6}, loc='upper center')

    # Line-style legend for ax30; replaced by the value legend below unless it
    # is re-added with add_artist.
    dashed = plt.Line2D([0], [0], color='k', linestyle='--',marker='',  lw=0.5, label='$q_t$')
    solid = plt.Line2D([0], [0], color='k', linestyle='-',marker='', label='q')
    legend2 = ax30.legend(handles=[solid, dashed],ncols=2, frameon=False, prop={'size': 6}, loc='upper center')

    # Value legends: handle-less text entries, each coloured like its own line
    # ('linecolor' keeps text and line colour in sync for any number of lines).
    ax00.legend(labelcolor='linecolor', handlelength=0, frameon=False, ncols=2, loc='upper left')
    ax10.legend(labelcolor='linecolor', handlelength=0, frameon=False, ncols=2, loc='lower left')
    ax10.add_artist(legend1)
    ax20.legend(labelcolor='linecolor', handlelength=0, frameon=False, ncols=2, loc='lower left')
    ax30.legend(labelcolor='linecolor', handlelength=0, frameon=False, ncols=2, loc='lower left')
    # NOTE(review): the two value legends below are replaced by the line-style
    # legends created further down (no add_artist) - confirm this is intended.
    ax21.legend(labelcolor='linecolor', handlelength=0, frameon=False, ncols=2, loc='lower left')
    ax31.legend(labelcolor='linecolor', handlelength=0, frameon=False, ncols=2, loc='lower left')
    # ax30.add_artist(legend2)

    dashed = plt.Line2D([0], [0], color='k', linestyle='--',marker='', label='cswc')
    solid = plt.Line2D([0], [0], color='k', linestyle='-',marker='', label='ciwc')
    ax21.legend(handles=[solid, dashed],ncols=2, frameon=False, prop={'size': 6}, loc='upper right')

    dashed = plt.Line2D([0], [0], color='k', linestyle='--',marker='', label='crwc')
    solid = plt.Line2D([0], [0], color='k', linestyle='-',marker='', label='clwc')
    ax31.legend(handles=[solid, dashed],ncols=2, frameon=False, prop={'size': 6}, loc='upper right')

    # map -----------------------
    for n in range(nClusters):
        print(f"C{clN[n]}, mean over the {number_of_trajs_clusters[n]/1e3:4.2f} thousand trajs in this cluster")
        # mean track of the cluster
        axmap.plot(dat_means[n].lon, dat_means[n].lat,ls='-',marker='',lw=1,\
                    transform=ccrs.PlateCarree(), color=color_list[n], label=f"C{clN[n]}, mean over cluster")
        # every `step`-th position, grey-shaded by its index along the track
        axmap.scatter(dat_means[n].lon[::step], dat_means[n].lat[::step],c=np.arange(dat_means[n].lon[::step].shape[0]) ,cmap='Greys', edgecolor=color_list[n],\
                transform=ccrs.PlateCarree())
    # latitude that separates the solid from the dashed time-series segments
    axmap.plot([-20,60],[37,37], ls='--', color='k',lw=1.5,  transform=ccrs.PlateCarree())
    gl = axmap.gridlines(draw_labels=True, lw=1, color='gray', alpha=0.5,\
                   linestyle='--')
    gl.right_labels=False
    gl.top_labels=False
    axmap.set_extent(map_extent)
    axmap.coastlines()
    axmap.spines['left'].set_visible(False)
    axmap.spines['bottom'].set_visible(False)
    #get handles and labels
    handles, labels = axmap.get_legend_handles_labels()
    #specify order of items in legend
    order = np.array(clN)-1
    #add legend to plot
    axmap.legend([handles[idx] for idx in order],[labels[idx] for idx in order], loc='lower right')
    #-------------------------------

    if logcwc:
        ax21.set_yscale('log')
        ax31.set_yscale('log')
        ax21.set_ylim(1e-9,1e-3)
        ax31.set_ylim(1e-9,1e-3)

    else:
        ax21.set_title('cluster closest scaled by 10')

    # panel letters in the upper-left corner of each time-series panel
    for ax, letter in ((ax00, '(a)'), (ax10, '(b)'), (ax20, '(c)'),
                       (ax30, '(d)'), (ax21, '(e)'), (ax31, '(f)')):
        ax.text(0, ax.get_ylim()[1], letter, va='top')
    axmap.text(axmap.get_xlim()[1]*0.9,axmap.get_ylim()[1],'(g)', va='top')


    print(f"clusters based on {sum(number_of_trajs_clusters)} trajectories \n"+\
                  f"{startday}" + \
                  f"\n all equally long, \n ")
    fig4.tight_layout()
    # Join directory and filename explicitly so `outdir` works with or without
    # a trailing separator; create the directory on first use.
    if outdir:
        os.makedirs(outdir, exist_ok=True)
    fig4.savefig(os.path.join(outdir, f"Fig_ClusterResults_{name}{context}.png"))
    # plt.show()
    plt.close(fig4)
    
