"""
Author: fiona Fix, 2023-07

This script contains functions used in the data preprocessing for Lagranto.
"""

import numpy as np
import xarray as xa
import pandas as pd
import os
from scipy.signal import butter, filtfilt
from matplotlib.colors import LinearSegmentedColormap
import cartopy.crs as ccrs

# CONSTANTS
# Gravitational acceleration, default `g` of the geopotential conversions
# (presumably in m s^-2 -- confirm).
G = 9.80665
# Earth radius, default `R` of the height conversions (presumably in m -- confirm).
R_earth = 6371229
# NetCDF file providing the `cell_area` variable that get_AD_data attaches to its output.
gridcellarea_file = '/gpfs/data/fs72140/c4031046_vsc/winter_lightning/ERA5-nAf/single-level/gridarea.nc'

def set_plotting_details(data, early_date, late_date, context, region, season, landonly, daytime_selection, length, outpath):
    """
    Set up plotting details and create the output directory for climatology plots.

    Parameters
    ----------
    data : xarray.Dataset
        Dataset providing `lat` and `lon`; only their minima and maxima are
        used, to define the map extent.
    early_date : str
        Start date, anything `pandas.to_datetime` accepts (e.g. 'YYYY-MM-DD').
    late_date : str
        End date, anything `pandas.to_datetime` accepts (e.g. 'YYYY-MM-DD').
    context : str
        Context or description of the data; used in the directory name.
    region : str
        Name of the region being analysed; used in the directory name.
    season : str
        Season being analysed (e.g. "DJF"); used in the directory name.
    landonly : bool
        If True, "_landonly" is appended to the directory name.
    daytime_selection : bool
        If True, "_daytime" is appended to the directory name.
    length : int
        Length of the climatology period; used in the directory name.
    outpath : str
        Path below which the output directory is created.

    Returns
    -------
    tuple
        - outdir (str): path of the output directory, with trailing "/".
        - map_extent (list): [west, east, south, north], each entry being the
          result of `data.lon.min()`, `data.lon.max()`, `data.lat.min()`,
          `data.lat.max()` (0-d xarray objects for an xarray Dataset).
        - early_date_str (pandas.Timestamp): parsed start date.
        - late_date_str (pandas.Timestamp): parsed end date.
    """
    print('setting plotting details')
    # NOTE: despite their names these are pandas Timestamps, not strings.
    early_date_str = pd.to_datetime(early_date)
    late_date_str = pd.to_datetime(late_date)
    map_extent = [data.lon.min(), data.lon.max(), data.lat.min(), data.lat.max()]

    outdir = f"{outpath}/{context}_Climatology_length{length}_"\
             f"{early_date_str.year:04}{early_date_str.month:02}{early_date_str.day:02}to"\
             f"{late_date_str.year:04}{late_date_str.month:02}{late_date_str.day:02}_"\
             f"{region}_{season}"
    if landonly:
        outdir = f"{outdir}_landonly"
    if daytime_selection:
        outdir = f"{outdir}_daytime"
    outdir = f"{outdir}/"

    # exist_ok avoids the check-then-create race when several jobs start at once.
    os.makedirs(outdir, exist_ok=True)
    print(f"outdir: {outdir}")
    return outdir, map_extent, early_date_str, late_date_str

def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100):
    """
    Build a colormap from a sub-range of an existing colormap.

    Parameters:
        cmap: The original colormap (e.g., plt.get_cmap('Greys')); it is called
            with positions in [0, 1] and must expose a `name` attribute.
        minval: Lower end of the range to keep (0.0 = start, 1.0 = end).
        maxval: Upper end of the range to keep (0.0 = start, 1.0 = end).
        n: Number of equally spaced colours sampled from `cmap` between
            `minval` and `maxval`; these become the anchor colours of the
            new colormap.

    Returns:
        A LinearSegmentedColormap named 'trunc(<name>,<minval>,<maxval>)'.
    """
    truncated_name = f'trunc({cmap.name},{minval:.2f},{maxval:.2f})'
    sample_positions = np.linspace(minval, maxval, n)
    anchor_colours = cmap(sample_positions)
    return LinearSegmentedColormap.from_list(truncated_name, anchor_colours)



def get_AD_data(infiles, dropvars,\
               south, north, west, east, early_date, late_date, daytime_selection, season, landonly,infiles_landmask=None,\
               morning=10, afternoon=15):
    """
    Load and preprocess AD data for a region, time period and set of filters.

    Parameters
    ----------
    infiles : str or list of str
        Path(s) to the input NetCDF file(s), opened with `xarray.open_mfdataset`.
    dropvars : list of str
        Variable names to drop from the dataset.
    south, north : float
        Latitude bounds, used as `slice(south, north)` on `lat`.
    west, east : float
        Longitude bounds, used as `slice(west, east)` on `lon`.
    early_date, late_date :
        Bounds of the time selection, used as `slice(early_date, late_date)`.
    daytime_selection : bool
        If True, keep only time steps whose hour lies in [`morning`, `afternoon`].
    season : str
        'DJF', 'MAM', 'JJA' or 'SON' keeps the corresponding months; any other
        value (e.g. 'allyear') applies no seasonal filter.
    landonly : bool
        If True, mask out grid cells whose land-sea mask `lsm` is <= 0.5.
    infiles_landmask : str or list of str, optional
        Path(s) to the land-sea mask file(s). Required if `landonly` is True.
    morning : int, optional
        First hour kept by the daytime selection (default 10).
    afternoon : int, optional
        Last hour kept by the daytime selection (default 15).

    Returns
    -------
    xarray.Dataset
        The filtered dataset with the variables added by `AD_count_percent`
        plus:
        - `grid_cell_area`: `cell_area` from `gridcellarea_file` divided by 1e6
          (km^2, presumably from m^2 -- confirm against the file).
        - `AD_colYesSmooth`: `AD_colYes` with one-step gaps along time filled.

    Raises
    ------
    ValueError
        If `landonly` is True but `infiles_landmask` is not given.
    """
    print('getting AD data')
    print(infiles)
    data = xa.open_mfdataset(infiles)
    data = data.drop_vars(dropvars)
    data = data.sel(lat=slice(south, north), lon=slice(west, east))
    data = data.sel(time=slice(early_date, late_date))
    data['time'] = data['time'].dt.round('h')

    # If the AD data does not cover the requested period, the period that is
    # actually available becomes the valid period from here on.
    early_date = data.time.min().values
    late_date = data.time.max().values

    if landonly:
        if infiles_landmask is None:
            raise ValueError("infiles_landmask is required when landonly is True")
        print('land only selection')
        land_sea_mask = xa.open_mfdataset(infiles_landmask)
        land_sea_mask = land_sea_mask.rename({"longitude": "lon", "latitude": "lat"})
        land_sea_mask = land_sea_mask.sortby('lat')
        land_sea_mask = land_sea_mask.sel(lat=slice(south, north), lon=slice(west, east))
        land_sea_mask = land_sea_mask.sel(time=slice(early_date, late_date))
        # Build the mask once and reuse it for every variable.
        is_land = land_sea_mask.lsm > 0.5
        # First pass sets everything outside land to False, so that the
        # boolean AD_colYes stays boolean ...
        data = data.where(is_land, other=False)
        # ... second pass turns the non-land cells of all other variables into NaN.
        for var in list(data.data_vars):
            if var != "AD_colYes":
                data[var] = data[var].where(is_land)

    if daytime_selection:
        print('daytime selection')
        hours = data['time'].dt.hour
        daytime_mask = (hours >= morning) & (hours <= afternoon)
        data = data.sel(time=data['time'][daytime_mask])

    print(f"selecting season: {season}")
    season_months = {
        'DJF': [12, 1, 2],
        'MAM': [3, 4, 5],
        'JJA': [6, 7, 8],
        'SON': [9, 10, 11],
    }
    months = season_months.get(season)
    if months is not None:
        data = data.sel(time=data.time.dt.month.isin(months))

    data = AD_count_percent(data)
    grid_cell_area = xa.open_dataset(gridcellarea_file).rename({"longitude": "lon", "latitude": "lat"})
    data['grid_cell_area'] = grid_cell_area.cell_area / 1e6  # in km2
    print('check')
    # smooth: fill 1h-long interruptions within True values along the time axis
    data['AD_colYesSmooth'] = smooth_short_interruptions(data.AD_colYes)
    return data



def calc_streaks(data, dim='time'):
    """
    Calculate streak metrics for a boolean xarray.DataArray along one dimension.

    A streak is a maximal run of consecutive `True` values along `dim`.
    Lengths are expressed in number of steps along `dim`.

    Parameters
    ----------
    data : xarray.DataArray
        Boolean array in which streaks of `True` values are analysed.
    dim : str, optional
        The dimension along which streaks are calculated. Default is `'time'`.

    Returns
    -------
    num_streaks : xarray.DataArray
        Number of streaks along `dim`, including a streak that starts on the
        first step and one that is still running on the last step.
    mean_streak_length : xarray.DataArray
        Mean streak length along `dim`.
    max_streak_length : xarray.DataArray
        Maximum streak length along `dim`.
    nintieth_perc_streak_length : xarray.DataArray
        90th percentile of the streak lengths along `dim`.
    all_streak_length : xarray.DataArray
        Same shape as `data`; holds the streak length at the last step of
        each streak and NaN everywhere else.
    """
    as_int = data.astype(int)

    # change_points[i] = data[i+1] - data[i]:
    #   -1 -> step i is the last step of a streak
    #   +1 -> a streak starts at step i+1
    # The final entry has no successor; it is set to -1 when the series ends
    # inside a streak so that this streak is closed as well, otherwise to 0.
    steps = as_int.diff(dim=dim)
    closing_entry = -as_int.isel({dim: -1})
    change_points = xa.concat([steps, closing_entry], dim=dim)
    change_points = change_points.assign_coords({dim: data[dim]})
    streak_end = change_points == -1

    # Every streak has exactly one end marker. Counting ends rather than
    # starts also counts a streak that begins on the very first step, which
    # has no preceding False -> True transition.
    num_streaks = streak_end.sum(dim=dim)

    # Position within the current streak: number of True values so far minus
    # the number of True values up to the most recent False.
    running_total = data.cumsum(dim=dim)
    total_at_last_false = running_total.where(~data).ffill(dim=dim).fillna(0)
    position_in_streak = data * (running_total - total_at_last_false)

    max_streak_length = position_in_streak.max(dim=dim)
    all_streak_length = position_in_streak.where(streak_end)
    mean_streak_length = all_streak_length.mean(dim=dim)

    # quantile needs `dim` in a single chunk
    nintieth_perc_streak_length = all_streak_length.chunk({dim: -1}).quantile(0.9, dim=dim)

    return num_streaks, mean_streak_length, max_streak_length, nintieth_perc_streak_length, all_streak_length


def calc_lid_frequencies(data_daytime, BLH_threshold):
    """
    Add lid- and AD-related counts and frequencies to a dataset.

    The dataset is modified in place and also returned.

    Parameters
    ----------
    data_daytime : xarray.Dataset
        Dataset with the variables `lid_presentYes`, `AD_colYes`,
        `diff_lowestAD_BLH` and `counts`, and the dimensions `time`, `lat`
        and `lon`.
    BLH_threshold : float
        A cell counts as "below BLH" where `diff_lowestAD_BLH < -BLH_threshold`.

    Returns
    -------
    xarray.Dataset
        `data_daytime` with these variables added:
        - `lid_presentYesSmooth` : `lid_presentYes` with one-step gaps filled.
        - `Num_cols_covered` : `AD_colYes` summed over `lon` and `lat`.
        - `Num_AD_lowest_lid` : `lid_presentYes` summed over `lon` and `lat`.
        - `Frequency_Lid` : mean of `lid_presentYes` over `time`.
        - `Frequency_LidgivenAD` : time mean of (lid and AD) divided by the
          time mean of `AD_colYes`.
        - `Num_ADbelowBLH_thresh` : number of cells below the BLH threshold,
          summed over `lon` and `lat`.
        - `ADbelowBLH` : boolean flag of the cells below the BLH threshold.
        - `AD_cellYes` : True where `counts` is not null.
        - `Frequency_ADcell` : mean of `AD_cellYes` over `time`.
    """
    horizontal = ['lon', 'lat']
    lid = data_daytime.lid_presentYes
    ad_column = data_daytime.AD_colYes

    data_daytime['lid_presentYesSmooth'] = smooth_short_interruptions(lid)

    # domain-wide counts
    data_daytime['Num_cols_covered'] = ad_column.sum(dim=horizontal)
    data_daytime['Num_AD_lowest_lid'] = lid.sum(dim=horizontal)

    # frequencies in time
    data_daytime['Frequency_Lid'] = lid.mean(dim='time')
    lid_and_ad = lid & ad_column
    data_daytime['Frequency_LidgivenAD'] = lid_and_ad.mean(dim='time') / ad_column.mean(dim='time')

    # lowest AD below the boundary layer height minus threshold
    below_blh = data_daytime.diff_lowestAD_BLH < -BLH_threshold
    data_daytime['Num_ADbelowBLH_thresh'] = below_blh.sum(dim=horizontal)
    data_daytime['ADbelowBLH'] = below_blh

    data_daytime['AD_cellYes'] = data_daytime['counts'].notnull()
    data_daytime['Frequency_ADcell'] = data_daytime['AD_cellYes'].mean(dim='time')
    return data_daytime



def AD_count_percent(data):
    """
    Add trajectory counts, AD frequency and column coverage to a dataset.

    The dataset is modified in place and also returned.

    Parameters
    ----------
    data : xarray.Dataset
        Dataset with dimensions `lon`, `lat` and `time` and the variables:
        - `AD_colcounts`: number of AD occurrences per grid column.
        - `AD_colYes`: boolean flag, True where AD occurred.

    Returns
    -------
    xarray.Dataset
        `data` with these variables added:
        - `NumTrajs`: `AD_colcounts` summed over `lon` and `lat`.
        - `Frequency_AD`: mean of `AD_colYes` over `time`.
        - `colcount`: `AD_colYes` summed over `lon` and `lat`.
        - `percent_cols_covered`: `colcount` divided by the number of grid
          columns (`lon` size times `lat` size). NOTE: despite the name this
          is a fraction between 0 and 1, not a value in percent.
    """
    horizontal = ['lon', 'lat']
    n_columns_in_domain = data.sizes['lon'] * data.sizes['lat']

    data['NumTrajs'] = data['AD_colcounts'].sum(dim=horizontal)
    data['Frequency_AD'] = data['AD_colYes'].mean(dim='time')
    data['colcount'] = data['AD_colYes'].sum(dim=horizontal)
    data['percent_cols_covered'] = data['colcount'] / n_columns_in_domain
    return data

def smooth_short_interruptions(data_array):
    """
    Fill isolated False values in a boolean time series with True.

    A value is "isolated" when it is False while both its predecessor and its
    successor along `time` are True. The first and last time steps are never
    filled because the missing neighbour is treated as False.

    Parameters
    ----------
    data_array : xarray.DataArray
        Array with a `time` dimension where True marks the presence of a
        condition. Missing values (NaN) are treated as False; 0/1 valued
        non-boolean input is accepted and converted to boolean.

    Returns
    -------
    xarray.DataArray
        Boolean array that is True wherever the input is True and wherever an
        isolated False value was detected.
    """
    # Replace NaN first (NaN would otherwise cast to True), then force a real
    # boolean dtype so that `~`, `&` and `|` act as logical operators even when
    # the input arrives as float or object dtype.
    present = data_array.fillna(False).astype(bool)
    following_step = present.shift(time=-1, fill_value=False)
    preceding_step = present.shift(time=1, fill_value=False)
    isolated_false = ~present & following_step & preceding_step
    return present | isolated_false


def group_and_count_AD(ds, variables, dz, dx):
    """
    Bin trajectory points onto a regular lon/lat/height grid and aggregate them.

    Longitude, latitude and height above mean sea level (`hamsl`) are rounded
    to the grid spacing; all points falling into the same cell are counted and
    averaged.

    Parameters:
    -----------
    ds : xarray.Dataset
        Trajectory dataset with the variables `lon`, `lat` and `hamsl`.
        The rounded variables `lon_r`, `lat_r` and `hamsl_r` are added to it
        in place.
    variables : list of str
        Variables taken into the grouping. The list has to contain `lon_r`,
        `lat_r` and `hamsl_r` (the grouping keys) and `hamsl` (dropped from
        the result).
    dz : float
        Vertical grid spacing used to round `hamsl`.
    dx : float
        Horizontal grid spacing used to round `lon` and `lat`.

    Returns:
    --------
    xarray.Dataset
        Dataset on the coordinates `lon`, `lat`, `hamsl` (the rounded values)
        holding `counts` (points per cell), `hamsl_r2` (a copy of the rounded
        height) and the per-cell means of the remaining numeric variables.
    """
    # snap every point to its grid cell
    for raw_name, rounded_name, spacing in (
        ('lon', 'lon_r', dx),
        ('lat', 'lat_r', dx),
        ('hamsl', 'hamsl_r', dz),
    ):
        ds[rounded_name] = np.round(ds[raw_name] / spacing) * spacing

    cell_keys = ['lon_r', 'lat_r', 'hamsl_r']

    # pandas does the grouping
    per_cell = ds[variables].to_dataframe().groupby(cell_keys)
    cell_counts = per_cell.size().reset_index(name='counts')
    cell_means = per_cell.mean(numeric_only=True).reset_index()
    cell_counts['hamsl_r2'] = cell_means['hamsl_r']

    combined = pd.merge(cell_counts, cell_means, on=cell_keys)
    gridded = xa.Dataset.from_dataframe(combined.set_index(cell_keys))

    # the mean of the unrounded height is not kept; the rounded coordinates
    # take over the plain names
    gridded = gridded.drop_vars('hamsl')
    return gridded.rename({'lat_r': 'lat', 'lon_r': 'lon', 'hamsl_r': 'hamsl'})



def get_lowest_AD_cell(hamsl_data):
    """
    Select, per column, the first non-NaN value along the 'hamsl' dimension.

    "Lowest" refers to the smallest index along 'hamsl'. For a column that
    contains only NaN, index 0 is selected, so the result is NaN there.

    Parameters:
    -----------
    hamsl_data : xarray.DataArray
        DataArray with a 'hamsl' dimension; its remaining dimensions must be
        'time', 'lat' and 'lon', in that order.

    Returns:
    --------
    xarray.DataArray
        The values of `hamsl_data` at the first non-NaN index along 'hamsl',
        with dimensions ('time', 'lat', 'lon').
    """
    hamsl_axis = hamsl_data.get_axis_num('hamsl')
    is_missing = np.isnan(hamsl_data)
    # argmin over booleans returns the first False, i.e. the first valid entry
    first_valid = np.nanargmin(is_missing, axis=hamsl_axis)
    column_coords = {
        'time': hamsl_data.time,
        'lat': hamsl_data.lat,
        'lon': hamsl_data.lon,
    }
    level_index = xa.DataArray(first_valid, coords=column_coords, dims=('time', 'lat', 'lon'))
    # pointwise selection: one level per (time, lat, lon)
    return hamsl_data.isel(hamsl=level_index)

def get_highest_AD_cell(hamsl_data):
    """
    Select, per column, the last non-NaN value along the 'hamsl' dimension.

    "Highest" refers to the largest index along 'hamsl'. For a column that
    contains only NaN, the last index is selected, so the result is NaN there.

    Parameters:
    -----------
    hamsl_data : xarray.DataArray
        DataArray with a 'hamsl' dimension; its remaining dimensions must be
        'time', 'lat' and 'lon', in that order.

    Returns:
    --------
    xarray.DataArray
        The values of `hamsl_data` at the last non-NaN index along 'hamsl',
        with dimensions ('time', 'lat', 'lon').
    """
    hamsl_axis = hamsl_data.get_axis_num('hamsl')
    n_levels = hamsl_data.shape[hamsl_axis]

    # search from the top: reverse 'hamsl', then take the first valid entry
    top_down = np.flip(hamsl_data, hamsl_axis)
    first_valid_from_top = np.nanargmin(np.isnan(top_down), hamsl_axis)
    # translate the position in the reversed array back to the original order
    last_valid = n_levels - 1 - first_valid_from_top

    column_coords = {
        'time': hamsl_data.time,
        'lat': hamsl_data.lat,
        'lon': hamsl_data.lon,
    }
    level_index = xa.DataArray(last_valid, coords=column_coords, dims=('time', 'lat', 'lon'))
    # pointwise selection: one level per (time, lat, lon)
    return hamsl_data.isel(hamsl=level_index)

def calculate_num_streaks(arr, axis):
    """
    Count the continuous non-NaN streaks along an axis of an array.

    A streak is a maximal run of consecutive non-NaN values along `axis`.
    A streak that starts at the very first position along `axis` is counted
    as well.

    Parameters:
    -----------
    arr : numpy.ndarray
        The input array in which to count the streaks of non-NaN values.
    axis : int
        The axis along which to count the streaks.

    Returns:
    --------
    numpy.ndarray
        Number of non-NaN streaks along `axis`; same shape as `arr` with
        `axis` removed.
    """
    not_nan = ~np.isnan(arr)
    # Prepending a 0 ("NaN") gives a streak starting at index 0 its own
    # 0 -> 1 transition; without it that streak would not be counted.
    change_points = np.diff(not_nan.astype(int), axis=axis, prepend=0)
    # every streak has exactly one start, marked by +1
    streak_starts = (change_points == 1).sum(axis=axis)
    return streak_starts






def geopot_to_geoh(geopot, g=G):
    """
    Convert geopotential to geopotential height.

    Parameters
    ---------
    geopot: array
          geopotential
    g     : float
          Earth's gravitational acceleration (defaults to the module constant G)
    Returns
    --------
    geoh: array
          geopotential height, `geopot / g` (in m, presumably for geopot in
          m^2 s^-2 -- confirm)
    """
    return geopot / g

def geoh_to_hamsl(geoh, R=R_earth):
    """
    Convert geopotential height to height above mean sea level.

    Computes `z = geoh * R / (R - geoh)`.

    Parameters
    ---------
    geoh: array
          geopotential height
    R     : float
          Earth's radius (defaults to the module constant R_earth), in the
          same length unit as `geoh`
    Returns
    --------
    z: array
        height above mean sea level
    """
    scaled_height = geoh * R
    remaining_radius = R - geoh
    return scaled_height / remaining_radius

def geopot_to_hamsl(geopot, g=G, R=R_earth):
    """
    Convert geopotential directly to height above mean sea level.

    Computes `z = geopot * R / (g * R - geopot)`.

    Parameters
    ---------
    geopot: array
          geopotential
    g     : float
          Earth's gravitational acceleration (defaults to the module constant G)
    R     : float
          Earth's radius (defaults to the module constant R_earth)
    Returns
    --------
    z: array
         height above mean sea level (in m, presumably for geopot in
         m^2 s^-2 -- confirm)
    """
    numerator = geopot * R
    denominator = g * R - geopot
    return numerator / denominator

def hag_to_hamsl(hag, topo):
    """
    Convert height above ground to height above mean sea level.

    Parameters
    ---------
    hag: array
         height above ground in m
    topo: array
         topography height above mean sea level in m
    Returns
    ------
    hamsl: array
         height above mean sea level, `hag + topo`
    """
    return hag + topo


def lowpass_filter(data, cutoff, fs, order=4):
    """
    Low-pass filter a time series with a Butterworth filter.

    The filter is designed with `scipy.signal.butter` and applied with
    `scipy.signal.filtfilt`.

    Parameters:
    ----------
    data : array-like
        The time series to be filtered.
    cutoff : float
        Cutoff frequency, in the same unit as `fs`. Frequencies above this
        value are attenuated.
    fs : float
        Sampling frequency of `data`; half of it is the Nyquist frequency
        used to normalise `cutoff`.
    order : int, optional
        Order of the Butterworth filter. Higher orders give a steeper
        roll-off around the cutoff frequency. Default is 4.

    Returns:
    -------
    y : ndarray
        The filtered time series, with the same shape as `data`.
    """
    nyquist_frequency = 0.5 * fs
    relative_cutoff = cutoff / nyquist_frequency  # cutoff as a fraction of Nyquist
    numerator, denominator = butter(order, relative_cutoff, btype='low', analog=False)
    return filtfilt(numerator, denominator, data)


def get_African_Polygon():
    """
    Return coastline points outlining northern Africa.

    The Natural Earth 110m physical coastline is read through cartopy's
    shapereader. All coastline vertices inside the box 17W-51E, 20N-37N are
    collected in file order, and two hard-coded index ranges of that list
    are returned.

    NOTE(review): the index ranges [23:110] and [180:-1] were presumably
    picked by hand for one specific version of the coastline file -- confirm
    they still select the intended points if the data set changes.

    Returns
    -------
    list of tuple
        Coastline vertices as returned by the geometries' `coords`
        (longitude first, latitude second).
    """
    import cartopy.io.shapereader as shpreader

    # Load the Natural Earth coastline data
    shapefile_path = shpreader.natural_earth(resolution='110m', category='physical', name='coastline')
    reader = shpreader.Reader(shapefile_path)

    # Bounding box for northern Africa
    min_lat, max_lat = 20, 37
    min_lon, max_lon = -17, 51

    def _points_in_box(line):
        # Keep the vertices of one line that lie inside the bounding box.
        return [coord for coord in line.coords
                if min_lon <= coord[0] <= max_lon and min_lat <= coord[1] <= max_lat]

    northern_africa_coords = []
    for record in reader.records():
        geometry = record.geometry
        if geometry.geom_type == 'LineString':
            northern_africa_coords.extend(_points_in_box(geometry))
        elif geometry.geom_type == 'MultiLineString':
            # Multi-part geometries are not iterable in Shapely 2.x;
            # `.geoms` works in both Shapely 1.x and 2.x.
            for line in geometry.geoms:
                northern_africa_coords.extend(_points_in_box(line))

    coordinates = northern_africa_coords[23:110] + northern_africa_coords[180:-1]
    return coordinates

def type_trajcluster(delta_hamsl, delta_TH, delta_Q, delta_Q_rel):
    """
    Assign a plotting colour to a trajectory cluster from its net changes.

    The classes are tested in the order blue, gold, red, limegreen, cyan;
    the first matching one wins and 'k' is returned when none matches.
    A message stating whether more than one class matched is printed on
    every call. The arguments must be scalars because the conditions are
    used in `if` statements.

    Parameters
    ----------
    delta_hamsl : float
        Change in height above mean sea level.
    delta_TH : float
        Change in TH.
    delta_Q : float
        Change in Q.
    delta_Q_rel : float
        Relative change in Q.

    Returns
    -------
    str
        'blue'      : delta_hamsl > 1000, delta_TH > 0 and delta_Q < 0
        'gold'      : delta_TH < 0 and delta_Q_rel < -0.1
        'red'       : delta_TH < 0 and delta_Q_rel > 0.1
        'limegreen' : delta_TH < -1.5 and |delta_Q_rel| < 0.1
        'cyan'      : |delta_TH| < 1.5 and |delta_Q_rel| < 0.1
        'k'         : none of the above
    """
    cooling = (delta_TH < 0)
    small_rel_moisture_change = (np.abs(delta_Q_rel) < .1)

    is_blue = (delta_hamsl > 1000) & (delta_TH > 0) & (delta_Q < 0)
    is_yellow = cooling & (delta_Q_rel < -.1)
    is_red = cooling & (delta_Q_rel > .1)
    is_green = (delta_TH < -1.5) & small_rel_moisture_change
    is_cyan = (np.abs(delta_TH) < 1.5) & small_rel_moisture_change

    n_matching = sum([is_blue, is_red, is_cyan, is_yellow, is_green])
    if n_matching > 1:
        print("Warning: Overlapping conditions detected!")
    else:
        print("conditions not overlapping")

    # priority order of the classes
    colour_by_priority = (
        (is_blue, 'blue'),
        (is_yellow, 'gold'),
        (is_red, 'red'),
        (is_green, 'limegreen'),
        (is_cyan, 'cyan'),
    )
    for matched, colour in colour_by_priority:
        if matched:
            return colour
    return 'k'

def plot_trajcluster(ax_hamsl, ax_TH, ax_Q, axmap, 
                     dat_mean, dat_N37, dat_Africa, delta_hamsl, delta_TH, delta_THE, delta_Q, delta_Q_rel,
                     colour, label, vline=None, step=24, put_text=True):
    """
    Draw one trajectory cluster into three time-series axes and a map.

    Parameters
    ----------
    ax_hamsl, ax_TH, ax_Q : matplotlib axes
        Axes for height, TH and Q versus `timestamp`. TH and Q are drawn as
        differences to the first value of `dat_mean.TH` / `dat_mean.Q`.
    axmap : cartopy GeoAxes
        Map axes; coordinates are passed with a PlateCarree transform.
    dat_mean : object with `TH`, `Q`, `lon`, `lat`
        Mean trajectory of the cluster. Supplies the reference values and the
        scatter markers on the map.
    dat_N37 : object with `timestamp`, `hamsl`, `TH`, `Q`, `lon`, `lat`
        Part of the trajectory drawn as lines.
    dat_Africa : object with `timestamp`, `hamsl`, `TH`, `Q`, `lon`, `lat`
        Part of the trajectory drawn as small, semi-transparent dots.
    delta_hamsl, delta_TH, delta_THE, delta_Q, delta_Q_rel : float
        Net changes written into the text box (delta_hamsl is shown divided
        by 1e3, labelled km).
    colour : str
        Colour used for every element of this cluster.
    label : str
        Label written into `ax_hamsl` and next to the last point on the map.
    vline : float, optional
        If given, a vertical dotted line is drawn at this x position in
        `ax_hamsl`.
    step : int, optional
        Every `step`-th point of `dat_mean` gets a marker on the map
        (default 24).
    put_text : bool, optional
        If True, write the text box with the net changes into `ax_hamsl`.
    """
    # solid part of the trajectory
    ax_hamsl.plot(dat_N37.timestamp, dat_N37.hamsl,
                  c=colour, ls='-')
    ax_TH.plot(  dat_N37.timestamp, (dat_N37.TH[:] - dat_mean.TH[0]),
                  c=colour, ls=':')
    ax_Q.plot(   dat_N37.timestamp, (dat_N37.Q[:] - dat_mean.Q[0]), 
                  c=colour, ls='-.')

    # dotted part of the trajectory
    ax_hamsl.plot(dat_Africa.timestamp, dat_Africa.hamsl,
                  c=colour, ls='', marker='.', alpha=.3, markersize=.5)
    ax_TH.plot( dat_Africa.timestamp, (dat_Africa.TH[:] - dat_mean.TH[0]), 
                  c=colour, ls='', marker='.',   alpha=.3, markersize=.5)
    ax_Q.plot(  dat_Africa.timestamp, (dat_Africa.Q[:] - dat_mean.Q[0]), 
                  c=colour, ls='', marker='.',   alpha=.3, markersize=.5)

    ax_TH.axhline(0,lw=.5, color='grey')
    if put_text:
        # Backslashes are escaped explicitly and the braces of the exponent
        # are doubled, so mathtext receives "$^{-1}$" instead of the
        # f-string evaluating {-1} to "-1".
        summary = (
            f"$\\Delta$h a.m.s.l. = {(delta_hamsl/1e3):.2f} km \n"
            f"$\\Delta \\theta$    = {delta_TH:.2f} K \n"
            f"$\\Delta q$    = {delta_Q:.2f} gkg$^{{-1}}$\n"
            f"$\\Delta q_r$   = {delta_Q_rel:.2f}\n"
            f"$\\Delta \\theta_E$   = {delta_THE:.2f}K"
        )
        ax_hamsl.text(2, 7000, summary, fontsize='small')
    ax_hamsl.text(110,10000,
                f"{label}", fontsize=6)
    if vline is not None:
        ax_hamsl.axvline(vline, color='grey', lw=.5, ls=':')

    # map view
    axmap.plot(dat_N37.lon, dat_N37.lat,
               ls='-',marker='',lw=1,\
               transform=ccrs.PlateCarree(), color=colour)
    axmap.plot(dat_Africa.lon, dat_Africa.lat,
               ls='',marker='.',lw=1, alpha=.3, markersize=.5,
               transform=ccrs.PlateCarree(), color=colour)
    # markers every `step` points, shaded from light to dark along the path
    axmap.scatter(dat_mean.lon[::step], dat_mean.lat[::step],
                  c=np.arange(dat_mean.lon[::step].shape[0]) ,
                  cmap='Greys', edgecolor=colour,\
                  transform=ccrs.PlateCarree(),
                 s=10)
    axmap.text(dat_mean.lon[-1]+0.5, dat_mean.lat[-1]+0.5,
               f"{label}", color=colour,
               transform=ccrs.PlateCarree())