# -*- coding: utf-8 -*-
"""
Created on Fri Jul 28 09:31:05 2023

@author: Matteo Meli
"""

# --- Plotting stack and global figure style ---------------------------------
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
sns.set_style("ticks")
sns.set_context("paper")
# NOTE(review): plt.yticks / plt.xticks operate on the current pyplot axes and
# no figure has been created at this point; the rcParams line below already
# sets the font family globally - confirm whether these two calls are needed.
plt.yticks(fontname = "Times New Roman")
plt.xticks(fontname = "Times New Roman")
plt.rcParams["font.family"] = "Times New Roman"
# --- Data access and mapping ------------------------------------------------
from netCDF4 import Dataset
from mpl_toolkits.basemap import Basemap
import xarray as xr
import matplotlib.colors as colors
from matplotlib.colors import LinearSegmentedColormap
# Colour stops of the "panoply" map: blues, a white centre, then yellow to red.
GMT_panoply = ['#00009E','#0000D5', '#0000FF', '#0048FF', '#0090FF','#00C4FF', '#00EBFF', '#47FEFF',
               '#FFFFFF',
               '#FFFF00', '#FFDF00', '#FFD300', '#FFAF00','#FF9300', '#FF7200', '#FF3700', '#D80E04']
panoply = LinearSegmentedColormap.from_list("panoply",GMT_panoply)
# Uniform white colormap.
GMT_white = ['#ffffff','#ffffff']
white = LinearSegmentedColormap.from_list("white",GMT_white)
from matplotlib.ticker import FuncFormatter
from PyEMD import EEMD
# Default EEMD instance; the EEMD section further down rebinds `eemd`.
eemd = EEMD()
from scipy.stats import mstats
from matplotlib.colors import ListedColormap
from matplotlib.lines import Line2D
# Colormap used by the salinity panels (three white stops in the middle).
GMT_salinity = ["deepskyblue","cyan","white","white", "white", "orange","orangered"]
gmt_salinity = LinearSegmentedColormap.from_list("gmt_salinity", GMT_salinity)
# Diverging colormap used by the sea-level panels.
GMT_divergent = ["darkblue", "Blue", "cyan", "white", "orange","red", "darkred"]
divergent = LinearSegmentedColormap.from_list("divergent", GMT_divergent)
# Preview the "divergent" colormap as a stand-alone horizontal colorbar.
fig, ax = plt.subplots(figsize=(6, 2), subplot_kw=dict(xticks=[], yticks=[], frame_on=False))
cbar = plt.colorbar(plt.cm.ScalarMappable(cmap=divergent), ax=ax, orientation="horizontal")
plt.show()






# LOADING THE CMEMS L4 PRODUCT (INCLUDING THE LATEST UPDATE) AND APPLYING THE
# SPATIAL MASK OF THE C3S PRODUCT
cmems = xr.open_dataset("path/cmems.nc")
c3s = xr.open_dataset("path/C3S.nc")
# slice(None, None) selects the full extent; narrow the slices to subset.
cmems = cmems.sel(latitude=slice(None, None), longitude=slice(None, None))['sla']
c3s = c3s.sel(longitude=slice(None, None))['sla']
# Put CMEMS on the C3S grid so that the C3S data mask can be applied cell by cell.
cmems_interp = cmems.interp(latitude=c3s.latitude, longitude=c3s.longitude)
# Cells that hold data in the first C3S time step.  Kept under its own name so
# that `mask` below refers to one thing only (the Marmara exclusion mask).
c3s_mask = c3s.isel(time=0).notnull()
sea = cmems_interp.where(c3s_mask)
# MASKING OUT MARMARA
lat_min, lat_max = 40.1875, 41.1875
lon_min, lon_max = 26.9375, 30.1875
# True everywhere outside the Marmara box; this 2-D mask is reused further
# down in the script to mask the IMF fields.
mask = ((sea.latitude < lat_min) | (sea.latitude > lat_max) |
        (sea.longitude < lon_min) | (sea.longitude > lon_max))
sea = sea.where(mask)
# APPLY THE TOPEX/POSEIDON CORRECTION
# sep=r'\s+' is the documented replacement for the deprecated
# delim_whitespace=True and parses the file the same way.
tp = pd.read_csv('path/tp_correction.txt', sep=r'\s+', header=None)
tp = tp[0].to_numpy()
if tp.shape[0] != sea.shape[0]:
    raise ValueError(
        f"TOPEX/Poseidon correction has {tp.shape[0]} values but the sea-level "
        f"series has {sea.shape[0]} time steps"
    )
# One correction value per time step, broadcast over the whole grid.
sea = sea + tp[:, None, None]
# APPLY THE GEOCENTRIC GIA CORRECTION FROM ICE-6G MODEL
# The arrays are read into memory, so the file can be closed immediately.
with Dataset("path/absolute_gia_downscaled.nc") as gia:
    lat_gia = gia.variables['lat'][:]
    lon_gia = gia.variables['lon'][:]
    absolute_gia = gia.variables['z'][:, :]
# NOTE(review): presumably converts mm/yr to m/yr - confirm against the file.
absolute_gia = absolute_gia / 1000
daily_trend = absolute_gia / 365.25
# Number of time steps elapsed since the first one, shaped (time, 1, 1) so it
# broadcasts against the 2-D trend field.
days_passed = np.arange(sea.shape[0])[:, None, None]
cumulative_trend_daily = days_passed * daily_trend
sea = sea - cumulative_trend_daily
# CONVERTING TIME SERIES TO DIFFERENT TEMPORAL RESOLUTIONS
# NOTE(review): recent pandas releases deprecate the 'M'/'Y' aliases in favour
# of 'ME'/'YE'; they are left unchanged here because the replacement aliases
# do not exist in older releases.
sea_m = sea.resample(time='M').mean('time')
sea_y = sea.resample(time='Y').mean('time')
datesea   = pd.date_range('1/1/1993','12/31/2022', freq='D')
datesea_m = pd.date_range('1/1/1993','12/31/2022', freq='M')
datesea_y = pd.date_range('1/1/1993','12/31/2022', freq='Y')
# DESEASONING OF THE MONTHLY MEAN DATASET
# Subtract the mean of each calendar month from the monthly series.
seasonal_cycle = sea_m.groupby('time.month').mean('time')
sea_m_deseas = sea_m.groupby('time.month') - seasonal_cycle
# PLOTTING AN EXAMPLE
# Grid cell (100, 100) at every temporal resolution.
plt.plot(datesea, sea[:,100,100])
plt.plot(datesea_m, sea_m[:,100,100])
plt.plot(datesea_m, sea_m_deseas[:,100,100])
plt.plot(datesea_y, sea_y[:,100,100])
# COUNTING THE AMOUNT OF AVAILABLE DATA
# Number of grid cells holding a value in the first annual mean.
sea_y_flat = sea_y[0,:,:]
non_nan_count = np.count_nonzero(~np.isnan(sea_y_flat.values))
print(f"Non-nan values: {non_nan_count}") # 16404

# SAVING DATASETS
#sea_y.to_netcdf('sea_y.nc')
#sea_m.to_netcdf('sea_m.nc')
#sea_m_deseas.to_netcdf('sea_m_deseas.nc')


# ---------- EEMD
#------------------------------------------------------------------------------
# From here on `sea_y` is the Dataset read back from disk (variable 'sla').
sea_y = xr.open_dataset("path/sea_y.nc")
eemd = EEMD(trials=1) #  - - - - - - - - -  SET TO 10000
lat_dim = sea_y.latitude.size
lon_dim = sea_y.longitude.size
time_dim = sea_y.time.size
max_imfs = 4
# Output cube (imf, time, latitude, longitude), zero wherever nothing is written.
IMFs_array = xr.DataArray(
    np.zeros((max_imfs, time_dim, lat_dim, lon_dim)),
    dims=["imf", "time", "latitude", "longitude"],
    coords=[range(max_imfs), sea_y.time, sea_y.latitude, sea_y.longitude],
)

# Decompose the annual series of every grid cell, row by row, and keep at most
# `max_imfs` modes per cell.
# NOTE(review): cells without data are passed to EEMD as well - confirm how
# PyEMD treats an all-NaN series.
for i, j in np.ndindex(lat_dim, lon_dim):
    sea_series = sea_y['sla'].isel(latitude=i, longitude=j).values
    E_IMFs = eemd.eemd(sea_series)
    num_imfs = min(E_IMFs.shape[0], max_imfs)
    for k, imf in enumerate(E_IMFs[:num_imfs]):
        missing = time_dim - len(imf)
        if missing > 0:
            # Zero-pad a mode that is shorter than the time axis.
            imf = np.concatenate([imf, np.zeros(missing)])
        IMFs_array[k, :, i, j] = imf

# Save the array
#IMFs_array.to_netcdf("IMFs_array.nc")
#------------------------------------------------------------------------------
#------------------------------------------------------------------------------


# MODE2
# Read the stored IMFs.  The variables are loaded into memory inside the
# `with` block, so the netCDF handle is released instead of being left open.
with Dataset("path/IMFs_array.nc") as imfs_array:
    lat2 = imfs_array.variables['latitude'][:]
    lon2 = imfs_array.variables['longitude'][:]
    imf_data = imfs_array.variables['__xarray_dataarray_variable__'][:]
# Second IMF: index 1 along the leading "imf" axis.
mode2 = imf_data[1, :, :, :]

# Running range of the plotted values (kept for inspection only; the colour
# scale below is fixed).
min_val = np.inf
max_val = -np.inf

# The lon/lat mesh does not depend on the panel, so build it once.
lons, lats = np.meshgrid(lon2, lat2)

# One map per year, 1993 + panel index.
fig, axs = plt.subplots(5, 6, figsize=(15, 8.5), dpi=100)
axs = axs.ravel()
for i in range(30):
    ax = axs[i]
    mp = Basemap(projection='merc',
                 llcrnrlon=-7, llcrnrlat=30, urcrnrlat=46.5, urcrnrlon=37,
                 resolution='l', ax=ax)
    x, y = mp(lons, lats)
    # `mask` is the module-level grid mask defined earlier in the script.
    masked_data = np.where(mask, mode2[i,:,:], np.nan)
    min_val = min(np.nanmin(masked_data), min_val)
    max_val = max(np.nanmax(masked_data), max_val)
    # Fixed symmetric colour scale (+/- 0.07, shown as +/- 7 on the colorbar).
    divnorm = colors.TwoSlopeNorm(vmin=-0.07,vcenter=0,vmax=0.07)
    c_scheme = mp.pcolor(x, y, masked_data, cmap=divergent, norm=divnorm)
    mp.fillcontinents(color='black')
    ax.set_title(f'{1993+i}')
    ax.title.set_size(18)
# Remove any axes left unused after the last panel.
for j in range(i+1, len(axs)):
    fig.delaxes(axs[j])
cbar_ax = fig.add_axes([0.35, 0.0005, 0.3, 0.025])
cbar = fig.colorbar(c_scheme, cax=cbar_ax, orientation='horizontal',extend='both')
# Tick labels are the data values multiplied by 100.
cbar.formatter = FuncFormatter(lambda x, pos: f'{x*100:.0f}')
cbar.update_ticks()
cbar.ax.tick_params(labelsize=16)
cbar.set_label('SLA (cm)', size=16, labelpad=-2)
plt.tight_layout()
fig.suptitle('second IMF', fontsize=18, y=1.02)
plt.show()


#MODE 1
# Read the stored IMFs.  The variables are loaded into memory inside the
# `with` block, so the netCDF handle is released instead of being left open.
with Dataset("path/IMFs_array.nc") as imfs_array:
    lat2 = imfs_array.variables['latitude'][:]
    lon2 = imfs_array.variables['longitude'][:]
    imf_data = imfs_array.variables['__xarray_dataarray_variable__'][:]
# First IMF: index 0 along the leading "imf" axis.
mode1 = imf_data[0, :, :, :]
# Running range of the plotted values (kept for inspection only; the colour
# scale below is fixed).
min_val = np.inf
max_val = -np.inf

# The lon/lat mesh does not depend on the panel, so build it once.
lons, lats = np.meshgrid(lon2, lat2)

# One map per year, 1993 + panel index.
fig, axs = plt.subplots(5, 6, figsize=(15, 8.5), dpi=100)
axs = axs.ravel()
for i in range(30):
    ax = axs[i]
    mp = Basemap(projection='merc',
                 llcrnrlon=-7, llcrnrlat=30, urcrnrlat=46.5, urcrnrlon=37,
                 resolution='l', ax=ax)
    x, y = mp(lons, lats)
    # `mask` is the module-level grid mask defined earlier in the script.
    masked_data = np.where(mask, mode1[i,:,:], np.nan)
    min_val = min(np.nanmin(masked_data), min_val)
    max_val = max(np.nanmax(masked_data), max_val)
    # Fixed symmetric colour scale (+/- 0.07, shown as +/- 7 on the colorbar).
    divnorm = colors.TwoSlopeNorm(vmin=-0.07,vcenter=0,vmax=0.07)
    c_scheme = mp.pcolor(x, y, masked_data, cmap=divergent, norm=divnorm)
    mp.fillcontinents(color='black')
    ax.set_title(f'{1993+i}')
    ax.title.set_size(18)
# Remove any axes left unused after the last panel.
for j in range(i+1, len(axs)):
    fig.delaxes(axs[j])
cbar_ax = fig.add_axes([0.35, 0.0005, 0.3, 0.025])
cbar = fig.colorbar(c_scheme, cax=cbar_ax, orientation='horizontal',extend='both')
# Tick labels are the data values multiplied by 100.
cbar.formatter = FuncFormatter(lambda x, pos: f'{x*100:.0f}')
cbar.update_ticks()
cbar.ax.tick_params(labelsize=16)
cbar.set_label('cm', size=16, labelpad=-2)
plt.tight_layout()
fig.suptitle('first IMF', fontsize=18, y=1.02)
plt.show()


# - - - REMOVING IMF 2 FROM LINEARLY DETRENDED TOTAL SEA LEVEL
# Stack the masked second IMF of every year into one (year, lat, lon) array.
n_years = 30
lons, lats = np.meshgrid(lon2, lat2)
masked_data_all = np.empty((n_years, len(lat2), len(lon2)))
min_val, max_val = np.inf, -np.inf
for i in range(n_years):
    # Cells where the module-level `mask` is False become NaN.
    masked_data = np.where(mask, mode2[i, :, :], np.nan)
    masked_data_all[i] = masked_data
    min_val = min(np.nanmin(masked_data), min_val)
    max_val = max(np.nanmax(masked_data), max_val)

# Same values, labelled with the coordinates of the annual sea-level data.
masked_data_all_xr = xr.DataArray(
    masked_data_all,
    dims=["time", "latitude", "longitude"],
    coords={"time": sea_y.time,
            "latitude": sea_y.latitude,
            "longitude": sea_y.longitude},
)

def detrend_pointwise_sens_slope(data):
    """Remove a Theil-Sen linear trend from every grid point of a 3-D field.

    Parameters
    ----------
    data : array-like, shape (time, y, x)
        Field to detrend; the first axis is the axis the trend is fitted
        along.  Missing values are given as NaN.  The input is not modified.

    Returns
    -------
    numpy.ndarray
        Array of the same shape holding ``data - (slope * t + intercept)``
        with ``t = 0, 1, ..., time - 1``.  Positions that are NaN in the
        input stay NaN; series with fewer than two finite values are NaN
        throughout.

    Raises
    ------
    ValueError
        If ``data`` is not three-dimensional.
    """
    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError(f"expected a 3-D (time, y, x) array, got {data.ndim} dimension(s)")
    if not np.issubdtype(data.dtype, np.floating):
        # NaN cannot be stored in integer arrays.
        data = data.astype(float)

    t_len, y_len, x_len = data.shape
    detrended_data = np.full_like(data, np.nan)
    time = np.arange(t_len)

    for j in range(y_len):
        for i in range(x_len):
            y = data[:, j, i]
            # A slope needs at least two valid points.
            if np.count_nonzero(np.isfinite(y)) < 2:
                continue
            # mstats.theilslopes only skips *masked* entries, so NaNs have to
            # be masked explicitly; otherwise one missing value turns the
            # slope, and with it the whole detrended series, into NaN.
            slopes = mstats.theilslopes(np.ma.masked_invalid(y), time, 0.9)
            sen_slope = slopes[0]
            trend = sen_slope * time + slopes[1]
            detrended_data[:, j, i] = y - trend
    return detrended_data
# `sea_y` was reloaded with xr.open_dataset above, which makes it a Dataset;
# Dataset.values is the mapping method, not an array, so the 'sla' variable
# has to be selected before taking the values.  A DataArray is used as is.
sea_y_sla = sea_y['sla'] if isinstance(sea_y, xr.Dataset) else sea_y
sea_y_detrended = detrend_pointwise_sens_slope(sea_y_sla.values)

# The lon/lat mesh does not depend on the panel, so build it once.
lons, lats = np.meshgrid(sea_y.coords['longitude'], sea_y.coords['latitude'])

# One map per year, 1993 + panel index.
fig, axs = plt.subplots(5, 6, figsize=(15, 8.5), dpi=100)
axs = axs.ravel()
for i in range(30):
    ax = axs[i]
    mp = Basemap(projection='merc',
                 llcrnrlon=-7, llcrnrlat=30, urcrnrlat=46.5, urcrnrlon=37,
                 resolution='l', ax=ax)
    x, y = mp(lons, lats)
    # Fixed symmetric colour scale (+/- 0.07, shown as +/- 7 on the colorbar).
    divnorm = colors.TwoSlopeNorm(vmin=-0.07,vcenter=0,vmax=0.07)
    c_scheme = mp.pcolor(x, y, sea_y_detrended[i,:,:], cmap=divergent, norm=divnorm)
    mp.fillcontinents(color='black')
    ax.set_title(f'{1993+i}')
    ax.title.set_size(18)
# Remove any axes left unused after the last panel.
for j in range(i+1, len(axs)):
    fig.delaxes(axs[j])
cbar_ax = fig.add_axes([0.35, 0.0005, 0.3, 0.025])
cbar = fig.colorbar(c_scheme, cax=cbar_ax, orientation='horizontal',extend='both')
# Tick labels are the data values multiplied by 100.
cbar.formatter = FuncFormatter(lambda x, pos: f'{x*100:.0f}')
cbar.update_ticks()
cbar.ax.tick_params(labelsize=16)
cbar.set_label('cm', size=16, labelpad=-2)
plt.tight_layout()
fig.suptitle('Linearly detrended geocentric sea level', fontsize=18, y=1.02)
plt.show()

# Detrended sea level with the (masked) second IMF removed.
result = sea_y_detrended - masked_data_all
result_xr = xr.DataArray(result,
                         coords=[sea_y.time, sea_y.latitude, sea_y.longitude],
                         dims=["time", "latitude", "longitude"])

# The lon/lat mesh does not depend on the panel, so build it once.
lons, lats = np.meshgrid(result_xr.coords['longitude'], result_xr.coords['latitude'])

# One map per year, 1993 + panel index.  The min/max bookkeeping that used to
# sit in this loop read a leftover `masked_data` array from an earlier section
# rather than the field plotted here, and its result was never used.
fig, axs = plt.subplots(5, 6, figsize=(15, 8.5), dpi=100)
axs = axs.ravel()
for i in range(30):
    ax = axs[i]
    mp = Basemap(projection='merc',
                 llcrnrlon=-7, llcrnrlat=30, urcrnrlat=46.5, urcrnrlon=37,
                 resolution='l', ax=ax)
    x, y = mp(lons, lats)
    # Fixed symmetric colour scale (+/- 0.07, shown as +/- 7 on the colorbar).
    divnorm = colors.TwoSlopeNorm(vmin=-0.07,vcenter=0,vmax=0.07)
    c_scheme = mp.pcolor(x, y, result[i,:,:], cmap=divergent, norm=divnorm)
    mp.fillcontinents(color='black')
    ax.set_title(f'{1993+i}')
    ax.title.set_size(18)
# Remove any axes left unused after the last panel.
for j in range(i+1, len(axs)):
    fig.delaxes(axs[j])
cbar_ax = fig.add_axes([0.35, 0.0005, 0.3, 0.025])
cbar = fig.colorbar(c_scheme, cax=cbar_ax, orientation='horizontal',extend='both')
# Tick labels are the data values multiplied by 100.
cbar.formatter = FuncFormatter(lambda x, pos: f'{x*100:.0f}')
cbar.update_ticks()
cbar.ax.tick_params(labelsize=16)
cbar.set_label('cm', size=16, labelpad=-2)
plt.tight_layout()
fig.suptitle('Linearly detrended geocentric sea level - second IMF', fontsize=18, y=1.02)
plt.show()


# - - - - - - - - - - -  ANALYSIS OF SECOND IMFS PERIODICITY
# Grid size comes from the data instead of a hard-coded (128, 344).
n_lat, n_lon = mode2.shape[1:]
dominant_frequencies = np.zeros((n_lat, n_lon))
for i in range(n_lat):
    for j in range(n_lon):
        time_series = mode2[:,i,j]
        if np.all(np.isnan(time_series)):
            dominant_frequencies[i,j] = np.nan
            continue
        time_series = time_series[~np.isnan(time_series)]
        N = len(time_series)
        # Unit sample spacing: frequencies are in cycles per time step.
        frequencies = np.fft.fftfreq(N, 1)
        fft_values = np.fft.fft(time_series)
        # Amplitude of the non-negative-frequency half of the spectrum.
        amp = 2.0/N * np.abs(fft_values[0:N//2])
        dominant_frequencies[i,j] = frequencies[np.argmax(amp)]

# A dominant frequency of exactly 0 has no finite period and is discarded.
dominant_frequencies = np.where(dominant_frequencies == 0, np.nan, dominant_frequencies)
dominant_periodicities = 1.0 / dominant_frequencies
mean_value = np.nanmean(dominant_periodicities)
min_value = np.nanmin(dominant_periodicities)
max_value = np.nanmax(dominant_periodicities)
print(f"Mean: {mean_value}")
print(f"Min: {min_value}")
print(f"Max: {max_value}")

# Share of grid cells falling into each 0.1-wide periodicity bin.
flat_array = dominant_periodicities.flatten()
flat_array = flat_array[~np.isnan(flat_array)]
bins = np.arange(4, 31, 0.1)
hist, bin_edges = np.histogram(flat_array, bins=bins)
percentages = (hist / flat_array.size) * 100
for i in range(len(bin_edges) - 1):
    print(f'Periodicity between {bin_edges[i]:.1f} and {bin_edges[i+1]:.1f} years: {percentages[i]:.1f}%')

# NOTE(review): the two lists below are literals that replace the histogram
# computed above - they look like values copied from a previous printout;
# confirm they still match whenever the input data change.
percentages = [0.3, 6.1, 19.2, 57, 17.1, 0.3]
periodicities = [5.1, 6.1, 7.6, 10.1, 15.1, 30.1]
weights = np.array(percentages) / 100
# Weighted mean and weighted standard deviation of the listed periodicities.
mean = np.average(periodicities, weights=weights)
std_dev = np.sqrt(np.average((np.asarray(periodicities) - mean)**2, weights=weights))
print(f"Mean: {mean:.1f}, Standard Deviation: {std_dev:.1f}")


# - - - - - - - - - - -  ANALYSIS OF FIRST IMFS PERIODICITY
# Grid size comes from the data instead of a hard-coded (128, 344).
n_lat, n_lon = mode1.shape[1:]
dominant_frequencies = np.zeros((n_lat, n_lon))
for i in range(n_lat):
    for j in range(n_lon):
        time_series = mode1[:,i,j]
        # Also covers an empty series, so no separate N == 0 check is needed
        # after the NaNs are dropped.
        if np.all(np.isnan(time_series)):
            dominant_frequencies[i,j] = np.nan
            continue
        time_series = time_series[~np.isnan(time_series)]
        N = len(time_series)
        # Unit sample spacing: frequencies are in cycles per time step.
        frequencies = np.fft.fftfreq(N, 1)
        fft_values = np.fft.fft(time_series)
        # Amplitude of the non-negative-frequency half of the spectrum.
        amp = 2.0/N * np.abs(fft_values[0:N//2])
        dominant_frequencies[i,j] = frequencies[np.argmax(amp)]

# A dominant frequency of exactly 0 has no finite period and is discarded.
dominant_frequencies = np.where(dominant_frequencies == 0, np.nan, dominant_frequencies)
dominant_periodicities = 1.0 / dominant_frequencies
mean_value = np.nanmean(dominant_periodicities)
min_value = np.nanmin(dominant_periodicities)
max_value = np.nanmax(dominant_periodicities)
print(f"Mean: {mean_value}")
print(f"Min: {min_value}")
print(f"Max: {max_value}")

# Share of grid cells falling into each 0.1-wide periodicity bin.
flat_array = dominant_periodicities.flatten()
flat_array = flat_array[~np.isnan(flat_array)]
bins = np.arange(2, 11, 0.1)
hist, bin_edges = np.histogram(flat_array, bins=bins)
percentages = (hist / flat_array.size) * 100
for i in range(len(bin_edges) - 1):
    print(f'Periodicity between {bin_edges[i]:.1f} and {bin_edges[i+1]:.1f} years: {percentages[i]:.1f}%')

# NOTE(review): the two lists below are literals that replace the histogram
# computed above - they look like values copied from a previous printout;
# confirm they still match whenever the input data change.
periodicities = [2.1, 2.3, 2.4, 2.7, 2.9, 3.3, 3.7, 4.2, 4.9, 5.9, 7.4, 9.9]
percentages = [3.6, 2.2, 5.2, 21.2, 5.2, 13.9, 9.7, 23.2, 3.8, 10.9, 1, 0.2]
weights = np.array(percentages) / 100
# Weighted mean and weighted standard deviation of the listed periodicities.
mean = np.average(periodicities, weights=weights)
std_dev = np.sqrt(np.average((np.asarray(periodicities) - mean)**2, weights=weights))
print(f"Mean: {mean:.1f}, Standard Deviation: {std_dev:.1f}")


#------------------------------------------------------------------------------
# SUB-BASIN AVERAGED TIME SERIES OF MODE-II
# One all-False boolean mask per sub-basin, on the (latitude, longitude) grid
# of the first time step of the masked mode-2 data.
_mask_template = masked_data_all_xr.isel(time=0)
(mask_adria, mask_ion, mask_lev, mask_aeg,
 mask_west, mask_tyr, mask_cre, mask_scm) = [
    xr.zeros_like(_mask_template, dtype=bool) for _ in range(8)
]

# Each sub-basin is the union of rectangular index boxes
# (latitude index range, longitude index range).
_sub_basin_boxes = [
    (mask_adria, [np.s_[98:128, 143:198],
                  np.s_[86:99, 164:208],
                  np.s_[82:87, 192:208]]),
    (mask_ion, [np.s_[32:56, 169:210],
                np.s_[55:81, 178:210],
                np.s_[80:84, 178:191],
                np.s_[55:64, 169:179]]),
    (mask_lev, [np.s_[4:58, 274:]]),
    (mask_aeg, [np.s_[50:90, 228:262],
                np.s_[42:50, 229:273],
                np.s_[50:80, 260:273]]),
    (mask_west, [np.s_[40:105, :125],
                 np.s_[105:118, 70:125]]),
    (mask_tyr, [np.s_[54:116, 126:143],
                np.s_[62:99, 142:164],
                np.s_[64:90, 164:178]]),
    (mask_cre, [np.s_[7:42, 229:273]]),
    (mask_scm, [np.s_[:54, 127:169],
                np.s_[:32, 168:229],
                np.s_[31:75, 210:229],
                np.s_[54:62, 143:164]]),
]
for _basin_mask, _boxes in _sub_basin_boxes:
    for _box in _boxes:
        _basin_mask[_box] = True

# (longitude, latitude) pairs drawn as black dots on the sub-basin map.
points = [(13.758472,45.647361),(12.333333,45.433333),(12.282731,44.49205),(13.6283,45.0833),(14.533333,45.3),
          (16.4417,43.5067),(18.0633,42.6583),(15.235,44.1233),(19.083333,42.083333),(26.847994,37.129675),(26.141189,38.371514),
          (25.878272,40.844139),(22.934933,40.632542),(23.589458,38.460889),(23.626714,37.937328),(24.945808,37.439969),
          (27.416667,37.033333),(26.716667,38.433333),(20.756628,38.959078),(20.712108,38.834544),(14.532955,35.820063),(21.319681,37.644822),
          (22.115839,37.023678),(29.916667,31.216667),(30.7,36.883333),(34.86316,32.47044),(32.3,31.25),(13.333333,38.133333),(14.2575,40.8429),
          (11.816667,42.05),(9.166667,39.2),(8.9,44.4),(8.016667,43.866667),(5.91472,43.112898),(5.35386,43.278801),(3.69911,43.397598),
          (3.206448,42.053832),(2.1657,41.34177),(-0.31128,39.44203),(-0.481229,38.33892),(-4.41546,36.7127),(-5.3649,36.148256),
          (-5.434973,36.12077),(-5.6026,36.12077),(-5.31589,35.8924),(2.638912,39.552381)]

# Integer map of the partition: -1 outside every sub-basin, otherwise the
# position of the sub-basin in the list below (later masks win on overlap).
combined_data = xr.full_like(masked_data_all_xr.isel(time=0), fill_value=-1, dtype=int)
# The loop variable must not be called `mask`: at module level that would
# overwrite the grid mask that the mode-1 section reads again further down.
for idx, region_mask in enumerate([mask_adria, mask_ion, mask_lev, mask_aeg, mask_west, mask_tyr, mask_cre, mask_scm]):
    combined_data = xr.where(region_mask, idx, combined_data)
custom_colors = ['#648FFF', '#FE6100', '#D42121', '#785EF0', '#FFB000','#FFB000','white','#FE6100']
# Leading 'white' entry is the colour of the -1 (no sub-basin) cells.
cmap = ListedColormap(['white'] + custom_colors)
def plot_regions_on_map(data, cmap, lat_bounds, lon_bounds):
    """Draw the sub-basin partition on a Mercator map and show the figure.

    Parameters
    ----------
    data : xarray.DataArray
        2-D field of region indices carrying ``latitude`` and ``longitude``
        coordinates; the plotting grid is taken from these coordinates.
    cmap : matplotlib colormap
        Colormap used for the region indices.
    lat_bounds, lon_bounds : sequence of two numbers
        (min, max) of the area of interest.  The map extends 1 unit beyond
        both latitude bounds, 2 units west and 1 unit east of the longitude
        bounds.

    Notes
    -----
    Reads the module-level ``custom_colors`` (number of contour levels) and
    ``points`` ((longitude, latitude) pairs drawn as black dots).
    """
    fig = plt.figure(figsize=(8, 7), dpi=100)
    extended_lon_bounds = [lon_bounds[0] - 2, lon_bounds[1] + 1]
    extended_lat_bounds = [lat_bounds[0] - 1, lat_bounds[1] + 1]
    mp = Basemap(projection='merc',
                 llcrnrlon=extended_lon_bounds[0],
                 llcrnrlat=extended_lat_bounds[0],
                 urcrnrlat=extended_lat_bounds[1],
                 urcrnrlon=extended_lon_bounds[1],
                 resolution='l')
    mp.drawparallels(np.arange(31,47,5), labels=[True,False,False,False], fontsize=12,dashes=[8,12], linewidth=0.01)
    mp.drawmeridians(np.arange(-5,36,5), labels=[False,True,False,True], fontsize=12,dashes=[8,12], linewidth=0.01)
    # Use the grid of the field being plotted rather than a module-level array.
    lons, lats = np.meshgrid(data.longitude.values, data.latitude.values)
    xx, yy = mp(lons, lats)
    mp.pcolormesh(xx, yy, data.values, cmap=cmap, alpha=0.6)
    # Contours at half-integer levels trace the borders between regions.
    mp.contour(xx, yy, data.values, levels=np.arange(-0.5, len(custom_colors)+0.5, 1), colors='whitesmoke', linewidths=1)
    mp.drawlsmask(ocean_color='white',lakes=False)
    mp.fillcontinents(color='whitesmoke', lake_color='white')
    mp.drawcoastlines(linewidth=0.3)
    # (longitude, latitude, text, colour) of the region numbers on the map.
    coords_and_labels = [
        (5.6, 39.6, '1', 'black'),
        (11.5, 39.2, '2', 'black'),
        (15, 42.3, '5', 'black'),
        (17.5, 36.5, '4', 'black'),
        (12.5, 34, '3', 'black'),
        (24.8, 36, '6', 'black'),
        (25, 33, '7', 'black'),
        (30.8, 33.4, '8', 'black')
    ]
    for lon, lat, label, color in coords_and_labels:
        x, y = mp(lon, lat)
        plt.text(x, y, label, fontsize=18, style='normal', color=color, rotation=0, weight="bold")
    legend_labels = [
        "1. Western Basin",
        "2. Tyrrhenian",
        "3. Southern Central",
        "4. Ionian",
        "5. Adriatic",
        "6. Aegean",
        "7. Southern Crete",
        "8. Levantine"
    ]
    # Text-only legend: handles with no marker and zero handle length.
    handles = [Line2D([0], [0], marker=None, color='w', label=legend_label, markersize=0) for legend_label in legend_labels]
    legend=plt.legend(handles=handles, loc='upper right',bbox_to_anchor=(1.01, 1.02), fontsize='medium', ncol=2,handlelength=0,framealpha=1, prop={'weight':'bold'})
    legend.get_frame().set_linewidth(1)
    legend.get_frame().set_edgecolor('black')
    if points:
        # Project and draw all points with a single scatter call.
        point_lons, point_lats = zip(*points)
        px, py = mp(np.asarray(point_lons), np.asarray(point_lats))
        mp.scatter(px, py, s=30, color='black', marker='o')
    plt.show()

# The stray trailing comma after the call (which turned the statement into a
# discarded one-element tuple) has been removed.
plot_regions_on_map(combined_data, cmap,
                   lat_bounds=[masked_data_all_xr.latitude.min().values, masked_data_all_xr.latitude.max().values],
                   lon_bounds=[masked_data_all_xr.longitude.min().values, masked_data_all_xr.longitude.max().values])


# SAVING MODE 2 SUB-BASIN AVERAGE
# - - - - - - - ------------------------------------------------
# ------------------ - - - - - - - - - - -
def extract_regional_time_series(data, masks, labels):
    """Return the spatial mean and standard deviation of `data` per region.

    `masks` and `labels` are paired by position.  For every pair the data are
    restricted to the mask with ``data.where`` and reduced over the
    'latitude' and 'longitude' dimensions, skipping NaNs.  The result maps
    each label to an ``xr.Dataset`` with the variables 'mean' and 'std'.
    """
    spatial_dims = ['latitude', 'longitude']

    def _region_stats(region_mask):
        inside = data.where(region_mask)
        return xr.Dataset({
            'mean': inside.mean(dim=spatial_dims, skipna=True),
            'std': inside.std(dim=spatial_dims, skipna=True),
        })

    return {label: _region_stats(region_mask)
            for region_mask, label in zip(masks, labels)}

# Masks and labels are paired by position.
_region_masks = [mask_adria, mask_ion, mask_lev, mask_aeg,
                 mask_west, mask_tyr, mask_cre, mask_scm]
_region_labels = ["Adriatic", "Ionian", "Levantine", "Aegean",
                  "Western Mediterranean", "Tyrrhenian Sea",
                  "Crete Sea", "South Central Mediterranean"]
regional_data = extract_regional_time_series(masked_data_all_xr, _region_masks, _region_labels)

def save_series(dataset, region, name_variable, name_file):
    """Write the 'mean' and 'std' series of one region to a text file.

    The file has one row per element and two space-separated columns
    (mean, std), each formatted with four decimals.  `name_variable` is
    accepted but not used.
    """
    region_stats = dataset[region]
    series = [region_stats[key].values for key in ('mean', 'std')]
    # Stack the two series as rows, then flip so that they become columns.
    table = np.transpose(np.vstack(series))
    np.savetxt(name_file, table, fmt='%1.4f', delimiter=' ')

# One "<region name in lower case, spaces as underscores>_basin_2.txt" per region.
for region in regional_data:
    _file_stem = region.lower().replace(" ", "_")
    save_series(regional_data, region, region, f'{_file_stem}_basin_2.txt')


# SAVING MODE 1 SUB-BASIN AVERAGE
# - - - - - - - ------------------------------------------------
# ------------------ - - - - - - - - - - -
# Rebuild the Marmara exclusion mask on the grid of the annual data instead of
# reading the module-level `mask`: a for-loop earlier in the script reuses
# that name as its loop variable, so by this point it may hold the last
# sub-basin mask rather than the grid mask.
marmara_excluded = ((sea_y.latitude < lat_min) | (sea_y.latitude > lat_max) |
                    (sea_y.longitude < lon_min) | (sea_y.longitude > lon_max))

# Stack the masked first IMF of every year into one (year, lat, lon) array.
masked_data_all = np.empty((30, len(lat2), len(lon2)))
min_val = np.inf
max_val = -np.inf
lons, lats = np.meshgrid(lon2, lat2)
for i in range(30):
    masked_data = np.where(marmara_excluded, mode1[i,:,:], np.nan)
    masked_data_all[i,:,:] = masked_data
    min_val = min(np.nanmin(masked_data), min_val)
    max_val = max(np.nanmax(masked_data), max_val)

masked_data_all_xr_1 = xr.DataArray(masked_data_all,
                                  coords=[sea_y.time, sea_y.latitude, sea_y.longitude],
                                  dims=["time", "latitude", "longitude"])

# The sub-basin masks (mask_adria ... mask_scm) and the helpers
# extract_regional_time_series / save_series defined in the mode-2 section
# are reused as they are: the copies that used to be repeated here were
# identical, index for index and line for line.

# Mode-1 statistics must be computed from the mode-1 stack; passing
# `masked_data_all_xr` here would write mode-2 values into the *_basin_1.txt
# files.
regional_data = extract_regional_time_series(masked_data_all_xr_1,
                                             [mask_adria, mask_ion, mask_lev, mask_aeg, mask_west, mask_tyr, mask_cre,mask_scm],
                                             ["Adriatic","Ionian","Levantine","Aegean", "Western Mediterranean", "Tyrrhenian Sea", "Crete Sea", "South Central Mediterranean"])

for region in regional_data.keys():
    save_series(regional_data, region, region, f'{region.lower().replace(" ", "_")}_basin_1.txt')


# -----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
# VORTICITY ---------------------------------------------------------------
# -----------------------------------------------------------------------------
vorticity = xr.open_dataset("path/vorticity_ionian.nc")
ugos, vgos = vorticity['ugos'], vorticity['vgos']
# zeta = d(vgos)/d(longitude) - d(ugos)/d(latitude)
# NOTE(review): differentiate() takes the derivative with respect to the
# coordinate values, so presumably per degree rather than per metre, and
# without a cos(latitude) factor - confirm this matches the units quoted
# on the plots.
vgos_by_x = vgos.differentiate("longitude")
ugos_by_y = ugos.differentiate("latitude")
zeta = vgos_by_x - ugos_by_y
# Area average, then monthly and annual means as pandas Series.
zeta_mean = zeta.mean(dim=['latitude', 'longitude'])
z_monthly = zeta_mean.resample(time='1M').mean('time').to_series()
z_annual = zeta_mean.resample(time='1Y').mean('time').to_series()

# Tab-separated "timestamp<TAB>value" rows without a header line.
for _series, _target in ((z_monthly, "z_monthly.txt"), (z_annual, "z_annual.txt")):
    _series.to_csv(_target, header=False, sep="\t")

# VORTICITY OF IONIAN SEA
# sep=r'\s+' is the documented replacement for the deprecated
# delim_whitespace=True and splits the columns the same way.
vorticity_y = pd.read_csv('path/z_annual.txt', sep=r'\s+', header=None)
vorticity_y.index = datesea_y
# Column 1 holds the values (column 0 is the timestamp written to the file).
vorticity_y = vorticity_y[1]*10 # 1/s
# Move the year-end time stamps five months back.
vorticity_y.index = vorticity_y.index - pd.DateOffset(months=5)

vorticity_m = pd.read_csv('path/z_monthly.txt', sep=r'\s+', header=None)
vorticity_m.index = datesea_m
vorticity_m = vorticity_m[1]*10 # 1/s (values x10^-6)
# Annual and monthly date axes of the satellite record, and the list of years.
datesat = pd.date_range('1/1/1993','12/31/2022', freq='Y')
datesat_m = pd.date_range('1/1/1993','12/31/2022', freq='M')
years = list(range(1993, 2023))

def read_basin_data(file_path, date_index, column_names=None):
    """Read a whitespace-separated basin file into a date-indexed DataFrame.

    Parameters
    ----------
    file_path : str or path-like
        Text file without a header line, one row per time step.
    date_index : index-like
        Time stamps to attach to the rows, one per row of the file.
    column_names : list of str, optional
        Column names; defaults to ``['mean', 'std']``.

    Returns
    -------
    pandas.DataFrame
        The file contents indexed by ``date_index`` moved 11 months back.
    """
    if column_names is None:
        column_names = ['mean', 'std']
    # sep=r'\s+' is the documented replacement for the deprecated
    # delim_whitespace=True and splits the columns the same way.
    basin_data = pd.read_csv(file_path, sep=r'\s+', header=None, names=column_names)
    basin_data.index = date_index
    basin_data.index = basin_data.index - pd.DateOffset(months=11)
    return basin_data
file_base_path = 'path/'
# Mode-2 sub-basin series (columns 'mean' and 'std'), indexed by `datesat`.
adria_basin = read_basin_data(f'{file_base_path}adriatic_basin_2.txt', datesat)
ion_basin = read_basin_data(f'{file_base_path}ionian_basin_2.txt', datesat)
lev_basin = read_basin_data(f'{file_base_path}levantine_basin_2.txt', datesat)
egeo_basin = read_basin_data(f'{file_base_path}aegean_basin_2.txt', datesat)
tyr_basin = read_basin_data(f'{file_base_path}tyrrhenian_sea_basin_2.txt', datesat)
west_basin = read_basin_data(f'{file_base_path}western_mediterranean_basin_2.txt', datesat)
scm_basin = read_basin_data(f'{file_base_path}south_central_mediterranean_basin_2.txt', datesat)


def _pair_average(first, second, column):
    # Plain average of the same column of two basin tables.
    return (first[column] + second[column]) / 2


# Composite of the mean series: two pair averages (west + Tyrrhenian,
# Ionian + south-central) and three single basins, each weighted 1/5, x100.
western_med = _pair_average(west_basin, tyr_basin, 'mean')
ion_basin_combined = _pair_average(ion_basin, scm_basin, 'mean')
med = (western_med + ion_basin_combined + adria_basin['mean'] + lev_basin['mean'] + egeo_basin['mean']) / 5
med = med*100
# Same combination applied to the 'std' columns.
western_med_std = _pair_average(west_basin, tyr_basin, 'std')
ion_basin_combined_std = _pair_average(ion_basin, scm_basin, 'std')
med_std = (western_med_std + ion_basin_combined_std + adria_basin['std'] + lev_basin['std'] + egeo_basin['std']) / 5
med_std = med_std*100

def color_background(ax, vorticity_m, years, alpha=0.3):
    """Shade each year on `ax` according to the sign of its mean vorticity.

    For every entry of `years` the span from 1 January of that year to
    1 January of the next is filled with "#56B4E9" when the mean of
    `vorticity_m` over that year is greater than zero and with "#E69F00"
    otherwise.  `alpha` is the opacity of the shading.
    """
    positive_colour = "#56B4E9"
    other_colour = "#E69F00"
    for year in years:
        yearly_mean = vorticity_m[str(year)].mean()
        if yearly_mean > 0:
            color = positive_colour
        else:
            color = other_colour
        span_start = np.datetime64(f'{year}-01-01')
        span_end = np.datetime64(f'{year+1}-01-01')
        ax.axvspan(span_start, span_end, color=color, alpha=alpha,
                   zorder=0, edgecolor='none', linewidth=0)
# Years marked with a dashed red vertical line.
red_lines_years = [1997, 2006, 2010, 2016, 2020]

# Years in which the annual-mean vorticity changes sign relative to the year before.
average_vorticity = [vorticity_m[str(year)].mean() for year in years]
sign_changes = np.flatnonzero(np.diff(np.sign(average_vorticity)))
change_years = [years[position + 1] for position in sign_changes]
change_dates = [np.datetime64(f'{year}-01-01') for year in change_years]

plt.figure(figsize=(6, 3), dpi=1000)
ax = plt.gca()
color_background(ax, vorticity_m, years)
# White band of +/- 73 days around every sign change, drawn over the shading.
_half_gap = pd.Timedelta("73 days")
for date in change_dates:
    ax.axvspan(date - _half_gap, date + _half_gap, color='white', zorder=1)
for year in red_lines_years:
    ax.axvline(pd.Timestamp(f'{year}-01-01'), color='red', linestyle='--', lw=1.5, zorder=3)
ax.plot(vorticity_m.index, vorticity_m, color='black', zorder=2, linewidth=1.5)
# Right-hand axis: the IMF2 composite with a +/- one `med_std` band.
ax2 = ax.twinx()
ax2.plot(med.index, med, color='red', linewidth=1.5)
ax2.set_ylabel('IMF2 sea level (cm)', color='red', rotation=-90, labelpad=10)
ax2.tick_params(axis='y', colors='red')
ax2.fill_between(med.index, med-med_std, med+med_std, color='red', alpha=0.15, linewidth=0, edgecolor=None)
ax.axhline(0, color='black', linestyle='--', lw=1.5, zorder=2)
ax.set_xlabel("Time")
ax.set_ylabel("Vorticity (1/s x 10⁻⁶)")
# One tick per year, on 1 January.
_year_ticks = [pd.Timestamp(f'{year}-01-01') for year in years]
ax.set_xticks(_year_ticks)
ax.set_xticklabels(years, rotation=45)
ax.set_xlim(pd.Timestamp('1993-01-01'), pd.Timestamp('2022-12-31'))
plt.tight_layout()
plt.show()


# SALINITY ------------------------------------------------------------------
# ----------------------------------------------------------------------------
# -----------------------------------------------------------------------------
salinity = xr.open_dataset("path/med-cmcc-sal-rean-m_1698750481739.nc")
# Annual means of 'so' for 1987-2020 at the first depth level.
sal = salinity['so'].sel(time=slice('1987-01-16','2020-12-16'))
sal_y = sal.resample(time='Y').mean('time')
sal_y = sal_y.isel(depth=0)
lon_s = salinity.variables['lon'][:]
lat_s = salinity.variables['lat'][:]
print(np.nanmin(sal_y), np.nanmean(sal_y), np.nanmax(sal_y))

# Colour-scale limits, shared by the panels and the colorbar ticks.
vmin = 38.06
vcenter = 38.26
vmax = 38.46
lons, lats = np.meshgrid(lon_s, lat_s)

# One map per year, 1987 + panel index; the 35th axes stays unused and is removed.
fig, axs = plt.subplots(5, 7, figsize=(12, 8.5), dpi=100)
axs = axs.ravel()
for i in range(34):
    ax = axs[i]
    mp = Basemap(projection='merc',
                 llcrnrlon=12, llcrnrlat=40, urcrnrlat=46, urcrnrlon=20,
                 resolution='l', ax=ax)
    x, y = mp(lons, lats)
    divnorm = colors.TwoSlopeNorm(vmin=vmin, vcenter=vcenter, vmax=vmax)
    c_scheme = mp.pcolor(x, y, sal_y[i, :, :], cmap=gmt_salinity, norm=divnorm)
    mp.fillcontinents(color='black')
    ax.set_title(f'{1987+i}')
    ax.title.set_size(18)
for j in range(i+1, len(axs)):
    fig.delaxes(axs[j])
cbar_ax = fig.add_axes([0.35, -0.05, 0.3, 0.025])
cbar = fig.colorbar(c_scheme, cax=cbar_ax, orientation='horizontal', extend='both')
_tick_values = [vmin, vcenter, vmax]
cbar.set_ticks(_tick_values)
cbar.set_ticklabels([f'{value:.2f}' for value in _tick_values])  # Format to two decimal places
cbar.ax.tick_params(labelsize=18)
cbar.set_label('psu', size=18, labelpad=-2)
plt.tight_layout()
fig.suptitle('Salinity of the Adriatic Sea', fontsize=20, y=1.02)
plt.show()


