from pathlib import Path
import warnings

import geopandas as gpd
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches
from matplotlib.ticker import FuncFormatter, FixedLocator
from pyproj import Geod
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec


# ============================================================
# USER INSTRUCTIONS
# ============================================================
# Please edit ONLY the following sections as required:
#
# 1. INPUT FILE NAMES
#    Update the file names of your shapefile and Excel datasets (all are read from the data/ folder).
#
# 2. OUTPUT FILE NAMES
#    Modify the desired names for exported PNG and PDF figures.
#
# 3. FIGURE SIZE / DPI
#    Change figure width, height, and resolution if needed.
#
# 4. MAP / BOXPLOT LAYOUT SETTINGS
#    Adjust map extent, subplot spacing, and width ratios if needed.
#
# 5. COLUMN NAME SETTINGS
#    Change only if your file headers differ.
#
# 6. TYPOLOGY / CATEGORY LABELS
#    Update only if your dataset uses different group labels.
#
# 7. PALETTE / STYLE SETTINGS
#    Update only if needed for your own visualization style.
#
# Recommended folder structure:
# project/
# ├── Global_CDD_Map_Boxplot_Analysis.py
# ├── data/
# │   ├── world_boundary.shp   (keep its .shx/.dbf/.prj companion files alongside)
# │   ├── cdd_data.xlsx
# │   ├── city_coordinates.xlsx
# │   ├── simple_regression.xlsx
# │   └── log_regression.xlsx
# └── outputs/
# ============================================================


# ============================================================
# USER CUSTOMIZATION SECTION
# Edit ONLY this section for your own data / figure settings
# ============================================================

# Project folders. ``__file__`` is undefined when this code is executed in an
# interactive session (e.g. a Jupyter cell), so fall back to the current
# working directory there instead of crashing with NameError.
try:
    BASE_DIR = Path(__file__).resolve().parent
except NameError:
    BASE_DIR = Path.cwd()
DATA_DIR = BASE_DIR / "data"
OUTPUT_DIR = BASE_DIR / "outputs"

# ----------------------------
# INPUT FILES
# ----------------------------
SHAPEFILE_PATH = DATA_DIR / "world_boundary.shp"
CDD_DATA_PATH = DATA_DIR / "cdd_data.xlsx"
CITY_DETAILS_PATH = DATA_DIR / "city_coordinates.xlsx"
SIMPLE_SLOPE_PATH = DATA_DIR / "simple_regression.xlsx"
LOG_SLOPE_PATH = DATA_DIR / "log_regression.xlsx"

# ----------------------------
# OUTPUT FILES
# ----------------------------
# Written into OUTPUT_DIR.
OUTPUT_PNG_NAME = "FIG_global_CDD.png"
OUTPUT_PDF_NAME = "FIG_global_CDD.pdf"

# ----------------------------
# FIGURE SIZE / DPI
# ----------------------------
FIG_WIDTH_MM = 180
FIG_HEIGHT_MM = 210
FIG_DPI = 600
OUTPUT_PDF_DPI = 450

# ----------------------------
# FONT SIZES
# ----------------------------
# NOTE: TITLE_FONTSIZE, TICK_FONTSIZE, BOX_LABEL_FONTSIZE and N_FONTSIZE are
# not referenced elsewhere in this script.
TITLE_FONTSIZE = 7
LABEL_FONTSIZE = 7
TICK_FONTSIZE = 7
COLORBAR_LABEL_FONTSIZE = 7
COLORBAR_TICK_FONTSIZE = 7
BOX_LABEL_FONTSIZE = 6.5
BOX_TICK_FONTSIZE = 6
ANNOTATION_FONTSIZE = 14  # panel letters "a", "b", "c"
MAP_TICK_FONTSIZE = 7
N_FONTSIZE = 5.8
MANUAL_NAME_FONTSIZE = 6.1  # continent name + n= labels inside the boxplots

# ----------------------------
# LAYOUT
# ----------------------------
MAIN_HSPACE = 0.20
MAIN_WSPACE = 0.30
BROKEN_WSPACE = 0.20  # gap between the two halves of a broken boxplot axis

MAP_WIDTH_RATIO = 2.4
BOX_WIDTH_RATIO = 0.4

# ----------------------------
# PANEL OUTLINE CONTROL
# ----------------------------
SHOW_PANEL_OUTLINE = True
PANEL_OUTLINE_LW = 0.8
PANEL_OUTLINE_COLOR = "black"

# Outline rectangles in axes-fraction coordinates (0..1).
MAP_PANEL_BOX_X0 = 0.00
MAP_PANEL_BOX_Y0 = 0.00
MAP_PANEL_BOX_W = 1.00
MAP_PANEL_BOX_H = 1.00

BOX_PANEL_BOX_X0 = 0.00
BOX_PANEL_BOX_Y0 = 0.00
BOX_PANEL_BOX_W = 1.00
BOX_PANEL_BOX_H = 1.00

# ----------------------------
# MAP EXTENT CONTROL
# ----------------------------
# Map window in degrees of longitude / latitude.
MAP_XMIN = -180
MAP_XMAX = 180
MAP_YMIN = -60
MAP_YMAX = 90

MAP_XTICKS = [-150, -100, -50, 0, 50, 100, 150]
MAP_YTICKS = [-50, -25, 0, 25, 50, 75]

# ----------------------------
# BOXPLOT CONTROL
# ----------------------------
BOX_WIDTH = 0.50

STRIP_SIZE = 3.0
STRIP_ALPHA = 0.95
STRIP_JITTER = 0.10
STRIP_EDGE_COLOR = "none"
STRIP_LINEWIDTH = 0.0

USE_SAME_COLOR_AS_PANEL_FOR_POINTS = True

# Broken x-axis: used when max(values) > BREAK_TRIGGER_RATIO * Q3.
SHOW_BROKEN_AXIS = True
BREAK_TRIGGER_RATIO = 2.8
LEFT_SEGMENT_QUANTILE = 0.90  # left segment ends at this quantile of the data
RIGHT_SEGMENT_PADDING = 0.08
LEFT_SEGMENT_PADDING = 0.06
BREAK_MARK_SIZE = 0.014
BREAK_MARK_LW = 0.6

BROKEN_LEFT_RATIO = 0.76
BROKEN_RIGHT_RATIO = 0.24

# ----------------------------
# LABEL BOX CONTROL
# ----------------------------
LABEL_BOX_X = 0.50  # axes-fraction x of the in-panel continent labels
LABEL_BOX_FACE_COLOR = "white"
LABEL_BOX_ALPHA = 0.55
LABEL_BOX_EDGE_COLOR = "none"
LABEL_BOX_PAD = 0.16
LABEL_TEXT_HA = "center"
LABEL_TEXT_VA = "center"
LABEL_Y_FINE_SHIFT = 0.00

# ----------------------------
# BUBBLE SIZE
# ----------------------------
# Range of scatter marker sizes (matplotlib ``s``) that values are mapped onto.
MIN_BUBBLE_SIZE = 15
MAX_BUBBLE_SIZE = 50

# ----------------------------
# STYLE SETTINGS
# Update palette/settings only if needed
# ----------------------------
CMAP_AVG = ListedColormap(["#fff5e1", "#f9cdaa", "#f08a67", "#d83b3b", "#990000", "#660000"])
CMAP_TREND = ListedColormap(["#ffe6e6", "#ffb3b3", "#ff8080", "#ff4d4d", "#e60000", "#b30000", "#800000"])
# NOTE(review): unlike CMAP_AVG / CMAP_TREND (light -> dark), this palette runs
# dark -> light, so larger relative rises get lighter colours. Confirm intended.
CMAP_REL = ListedColormap([
    "#4d1c00",
    "#8c2d04",
    "#cc4c02",
    "#ec7014",
    "#fe9929",
    "#fec44f",
    "#fff7bc"
])

BOX_COLOR_AVG = "#9ecae1"
BOX_COLOR_TREND = "#f4a582"
BOX_COLOR_REL = "#d8b365"

# ----------------------------
# OPTIONAL COLUMN NAME SETTINGS
# Change only if your file headers differ
# ----------------------------
# Headers are compared after standardization (trimmed, spaces -> "_", lower-case).
CITY_COL = "city"
LAT_COL = "latitude"
LON_COL = "longitude"
TYPOLOGY_COL = "typology"
CONTINENT_COL = "continent"
AVG_CDD_COL = "average_cdd"
ABS_TREND_COL = "actual_increase_per_year"
REL_TREND_COL = "relative_increase_per_year_(%)"


# ============================================================
# GLOBAL SETTINGS
# ============================================================

# NOTE: this silences *every* UserWarning for the whole run (geopandas,
# seaborn, matplotlib, ...), which can also hide genuine problems.
warnings.filterwarnings("ignore", category=UserWarning)

# Exact conversion factor: 1 inch == 25.4 mm.
MM_TO_IN = 1 / 25.4
FIG_W = FIG_WIDTH_MM * MM_TO_IN
FIG_H = FIG_HEIGHT_MM * MM_TO_IN

# sns.set() (alias of seaborn.set_theme) rewrites rcParams, including
# "font.family", so the seaborn theme must be applied BEFORE the explicit
# font choice; otherwise the Arial setting is silently overwritten.
sns.set(style="white")
plt.rcParams["font.family"] = "Arial"

# WGS84 ellipsoid used for the geodesic km-per-degree conversion of scale bars.
GEOD = Geod(ellps="WGS84")


# ============================================================
# HELPER FUNCTIONS
# ============================================================

def standardize_columns(df):
    """Return a copy of *df* with normalized column headers.

    Every header is cast to ``str``, trimmed, has spaces replaced by
    underscores and is lower-cased (``" Average CDD "`` -> ``"average_cdd"``).
    The input frame is left untouched.
    """
    result = df.copy()
    headers = result.columns.astype(str)
    headers = headers.str.strip()
    headers = headers.str.replace(" ", "_", regex=False)
    result.columns = headers.str.lower()
    return result


def clean_text_column(df, col):
    """Return a copy of *df* whose column *col* is cast to str and trimmed.

    If *col* is not a column of *df*, an unchanged copy is returned.
    """
    result = df.copy()
    if col not in result.columns:
        return result
    result[col] = result[col].astype(str).str.strip()
    return result


def clean_city_names(df, city_col=CITY_COL):
    """Return a copy of *df* with *city_col* cast to str, trimmed and lower-cased.

    This makes the city keys comparable across the input tables that are
    merged on them. Missing values become the literal string ``"nan"``
    because of ``astype(str)``. If *city_col* is absent, an unchanged copy
    is returned.
    """
    cleaned = df.copy()
    if city_col not in cleaned.columns:
        return cleaned
    names = cleaned[city_col].astype(str)
    cleaned[city_col] = names.str.strip().str.lower()
    return cleaned


def validate_files():
    """Ensure every configured input file exists.

    Raises
    ------
    FileNotFoundError
        Listing all missing paths, one per line.
    """
    candidates = (
        SHAPEFILE_PATH,
        CDD_DATA_PATH,
        CITY_DETAILS_PATH,
        SIMPLE_SLOPE_PATH,
        LOG_SLOPE_PATH,
    )
    absent = []
    for path in candidates:
        if not path.exists():
            absent.append(str(path))
    if not absent:
        return
    raise FileNotFoundError("Missing required files:\n" + "\n".join(absent))


def check_required_columns(df, required_cols, df_name):
    """Raise KeyError if any of *required_cols* is missing from *df*.

    The error message names the table (*df_name*), the missing columns and
    all available columns, to help users fix their file headers.
    """
    present = df.columns
    absent = [name for name in required_cols if name not in present]
    if not absent:
        return
    raise KeyError(
        f"Missing required column(s) in '{df_name}': {absent}\n"
        f"Available columns: {list(df.columns)}"
    )


def normalize_sizes(values, min_size=MIN_BUBBLE_SIZE, max_size=MAX_BUBBLE_SIZE):
    """Linearly map *values* onto marker sizes in ``[min_size, max_size]``.

    Returns a float array of the same length as *values*:

    - Non-finite entries (NaN, +/-inf) get ``min_size``.
    - If every entry is non-finite, all entries get the midpoint size.
    - If all finite values are equal, the finite entries get the midpoint size.
    """
    arr = np.array(values, dtype=float)
    ok = np.isfinite(arr)
    midpoint = (min_size + max_size) / 2

    if not ok.any():
        return np.full(len(arr), midpoint)

    lo = np.nanmin(arr[ok])
    hi = np.nanmax(arr[ok])
    span = hi - lo

    if span == 0:
        sizes = np.full(len(arr), midpoint)
    else:
        sizes = min_size + (arr - lo) / span * (max_size - min_size)

    sizes[~ok] = min_size
    return sizes


def lon_formatter(x, pos):
    """Format a longitude tick value *x* (degrees) as ``"150°W"``, ``"30°E"`` or ``"0°"``.

    *pos* is the tick index passed by matplotlib's FuncFormatter and is unused.
    Fractional ticks keep their decimals (``22.5`` -> ``"22.5°E"``) instead of
    being truncated toward zero as ``int(x)`` did (``-0.5`` used to render as
    ``"0°W"``). Integer ticks render exactly as before.
    """
    if x < 0:
        return f"{abs(x):g}°W"
    if x > 0:
        return f"{x:g}°E"
    return "0°"


def lat_formatter(y, pos):
    """Format a latitude tick value *y* (degrees) as ``"50°S"``, ``"25°N"`` or ``"0°"``.

    *pos* is the tick index passed by matplotlib's FuncFormatter and is unused.
    Fractional ticks keep their decimals (``12.5`` -> ``"12.5°N"``) instead of
    being truncated toward zero as ``int(y)`` did. Integer ticks render
    exactly as before.
    """
    if y < 0:
        return f"{abs(y):g}°S"
    if y > 0:
        return f"{y:g}°N"
    return "0°"


def add_colorbar(fig, ax, mappable, label, pad=0.006, width=0.012):
    """Attach a vertical colorbar just to the right of *ax* and return it.

    The colorbar axes spans the full height of *ax*. It starts *pad* to the
    right of *ax* and is *width* wide, both in figure-fraction units. The
    label, tick font sizes and spine width follow the module settings.
    """
    bbox = ax.get_position()
    cbar_ax = fig.add_axes([bbox.x1 + pad, bbox.y0, width, bbox.height])
    colorbar = fig.colorbar(mappable, cax=cbar_ax)
    colorbar.set_label(label, fontsize=COLORBAR_LABEL_FONTSIZE)
    colorbar.ax.tick_params(labelsize=COLORBAR_TICK_FONTSIZE)
    for edge in cbar_ax.spines.values():
        edge.set_linewidth(0.3)
    return colorbar


def set_map_extent(ax, xmin=MAP_XMIN, xmax=MAP_XMAX, ymin=MAP_YMIN, ymax=MAP_YMAX):
    """Configure a lon/lat map axes.

    Applies the view window, the fixed MAP_XTICKS / MAP_YTICKS positions,
    degree-style tick labels, the axis titles and a free (non-equal) aspect.
    """
    ax.set_xlim([xmin, xmax])
    ax.set_ylim([ymin, ymax])

    axis_setup = (
        (ax.xaxis, MAP_XTICKS, lon_formatter),
        (ax.yaxis, MAP_YTICKS, lat_formatter),
    )
    for axis, tick_values, label_func in axis_setup:
        axis.set_major_locator(FixedLocator(tick_values))
        axis.set_major_formatter(FuncFormatter(label_func))

    ax.tick_params(axis="both", labelsize=MAP_TICK_FONTSIZE, pad=1)

    ax.set_xlabel("Longitude", fontsize=LABEL_FONTSIZE)
    ax.set_ylabel("Latitude", fontsize=LABEL_FONTSIZE)
    ax.set_aspect("auto")


def add_panel_outline(ax, x0, y0, w, h, lw=PANEL_OUTLINE_LW, color=PANEL_OUTLINE_COLOR):
    """Draw an unfilled frame on *ax*, given in axes-fraction coordinates.

    Nothing is drawn when ``SHOW_PANEL_OUTLINE`` is False. The frame is not
    clipped to the axes and is drawn at a high z-order so it stays on top.
    """
    if SHOW_PANEL_OUTLINE:
        frame = mpatches.Rectangle(
            xy=(x0, y0),
            width=w,
            height=h,
            transform=ax.transAxes,
            fill=False,
            edgecolor=color,
            linewidth=lw,
            zorder=100,
            clip_on=False,
        )
        ax.add_patch(frame)


def add_map_outline(ax):
    """Frame a map axes using the MAP_PANEL_BOX_* rectangle settings."""
    add_panel_outline(
        ax,
        x0=MAP_PANEL_BOX_X0,
        y0=MAP_PANEL_BOX_Y0,
        w=MAP_PANEL_BOX_W,
        h=MAP_PANEL_BOX_H,
    )


def draw_box_outline_sides(
    ax, show_left=True, show_right=True, show_top=True, show_bottom=True,
    x0=BOX_PANEL_BOX_X0, y0=BOX_PANEL_BOX_Y0, w=BOX_PANEL_BOX_W, h=BOX_PANEL_BOX_H,
    lw=PANEL_OUTLINE_LW, color=PANEL_OUTLINE_COLOR
):
    """Draw selected sides of a rectangular frame on *ax* (axes-fraction coordinates).

    Individual sides can be suppressed, e.g. the inner edges of a broken-axis
    pair. Nothing is drawn when ``SHOW_PANEL_OUTLINE`` is False.
    """
    if not SHOW_PANEL_OUTLINE:
        return

    x_far = x0 + w
    y_far = y0 + h

    # Kept in the order bottom, top, left, right (the order lines are added).
    sides = (
        (show_bottom, [x0, x_far], [y0, y0]),
        (show_top, [x0, x_far], [y_far, y_far]),
        (show_left, [x0, x0], [y0, y_far]),
        (show_right, [x_far, x_far], [y0, y_far]),
    )
    for wanted, xs, ys in sides:
        if wanted:
            ax.plot(xs, ys, transform=ax.transAxes, color=color, lw=lw, clip_on=False, zorder=100)


def add_geodesic_scale_bar(ax, total_length_km=5000, n_segments=2,
                           location=(0.60, 0.10), bar_height_deg=0.35,
                           fontsize=6, line_width=0.6, reference_lat=None):
    """Draw an alternating black/white distance scale bar on a lon/lat map axes.

    The bar length is converted from kilometres to degrees of longitude using
    the WGS84 geodesic length of one degree along *reference_lat*. On a plain
    lon/lat (equirectangular) axes that conversion only holds at that latitude.

    Parameters
    ----------
    ax : matplotlib Axes
        Map axes in lon/lat data coordinates; its current x/y limits are used.
    total_length_km : float
        Total bar length in kilometres.
    n_segments : int
        Number of alternating segments; each is labelled at its right end.
    location : (float, float)
        Lower-left corner of the bar as axes fractions of the current limits.
    bar_height_deg : float
        Bar height in degrees of latitude.
    fontsize : float
        Font size of all scale-bar labels. Previously ignored; labels were
        always size 6, which is still the default.
    line_width : float
        Edge line width of the segments.
    reference_lat : float, optional
        Latitude used for the km -> degree conversion. Defaults to the
        vertical centre of the current y-limits (the previous behaviour).
        Pass the bar's own latitude for a bar that is true where it is drawn.

    Raises
    ------
    ValueError
        If ``n_segments < 1`` or the km-per-degree factor is not positive
        (e.g. a reference latitude at a pole).
    """
    if n_segments < 1:
        raise ValueError("n_segments must be at least 1")

    x0, x1 = ax.get_xlim()
    y0, y1 = ax.get_ylim()
    center_lon = (x0 + x1) / 2
    ref_lat = (y0 + y1) / 2 if reference_lat is None else reference_lat

    try:
        _, _, m_per_deg = GEOD.inv(center_lon, ref_lat, center_lon + 1, ref_lat)
        km_per_deg = m_per_deg / 1000
    except Exception:
        # Deliberate best-effort fallback: spherical length of 1° of longitude.
        km_per_deg = 111.32 * np.cos(np.deg2rad(ref_lat))

    if not np.isfinite(km_per_deg) or km_per_deg <= 0:
        raise ValueError(
            f"Cannot build a scale bar at latitude {ref_lat}: km per degree is {km_per_deg}"
        )

    total_deg = total_length_km / km_per_deg
    seg_deg = total_deg / n_segments
    # True (float) segment length; integer division made labels drift, e.g.
    # 5000 km / 3 segments used to be labelled 1666, 3332, 4998.
    seg_km = total_length_km / n_segments

    ax_x, ax_y = location
    left_x = x0 + ax_x * (x1 - x0)
    bottom_y = y0 + ax_y * (y1 - y0)
    label_y = bottom_y - bar_height_deg * 0.6

    colors = ["black", "white"]
    for i in range(n_segments):
        rect = mpatches.Rectangle(
            (left_x + i * seg_deg, bottom_y),
            seg_deg, bar_height_deg,
            facecolor=colors[i % 2],
            edgecolor="black",
            linewidth=line_width,
            transform=ax.transData,
            zorder=10
        )
        ax.add_patch(rect)

        ax.text(
            left_x + (i + 1) * seg_deg,
            label_y,
            f"{(i + 1) * seg_km:g}",
            fontsize=fontsize,
            ha="center", va="top",
            transform=ax.transData
        )

    ax.text(
        left_x,
        label_y,
        "0",
        fontsize=fontsize,
        ha="center", va="top",
        transform=ax.transData
    )

    ax.text(
        left_x + total_deg / 2,
        bottom_y + bar_height_deg * 1.25,
        "Distance\n(km)",
        fontsize=fontsize,
        ha="center", va="bottom",
        transform=ax.transData,
        bbox=dict(facecolor="white", alpha=0.0, edgecolor="none")
    )


def need_axis_break(vals, ratio=None):
    """Decide whether a boxplot row should use a broken x-axis.

    A break is requested when the largest finite value exceeds *ratio* times
    the upper quartile (Q3). That is, a few extreme values would otherwise
    squash the bulk of the distribution.

    Parameters
    ----------
    vals : array-like
        Values to inspect; NaN and +/-inf are ignored.
    ratio : float, optional
        Trigger ratio ``max / Q3``. Defaults to ``BREAK_TRIGGER_RATIO``
        (resolved at call time).

    Returns
    -------
    bool
        False when fewer than five finite values are available, or when Q3 is
        not positive. For Q3 < 0 the old test ``max > ratio * Q3`` was
        satisfied by almost any data. The broken-axis layout also assumes the
        bulk of the data is positive.
    """
    if ratio is None:
        ratio = BREAK_TRIGGER_RATIO

    arr = np.asarray(vals, dtype=float)
    arr = arr[np.isfinite(arr)]
    if len(arr) < 5:
        return False

    q3 = np.percentile(arr, 75)
    if q3 <= 0:
        return False
    return bool(np.max(arr) > ratio * q3)


def get_break_limits(vals, quantile=None, left_padding=None, right_padding=None):
    """Compute x-limits for the left and right segments of a broken axis.

    The left segment runs from 0 up to the *quantile* of the data. When
    negative values exist it runs from slightly below the data minimum
    instead, so they are no longer clipped out of view. The right segment
    shows the top quarter of the range between that quantile and the
    maximum, plus padding. Values falling between the two segments are not
    visible.

    Parameters
    ----------
    vals : array-like
        Data values; NaN and +/-inf are ignored.
    quantile : float, optional
        Quantile (0..1) where the left segment ends.
        Defaults to ``LEFT_SEGMENT_QUANTILE``.
    left_padding, right_padding : float, optional
        Relative paddings. Default to ``LEFT_SEGMENT_PADDING`` and
        ``RIGHT_SEGMENT_PADDING``.

    Returns
    -------
    (left_xlim, right_xlim)
        Two ``(min, max)`` tuples; ``((0, 1), (1, 2))`` when no finite
        values are given.
    """
    if quantile is None:
        quantile = LEFT_SEGMENT_QUANTILE
    if left_padding is None:
        left_padding = LEFT_SEGMENT_PADDING
    if right_padding is None:
        right_padding = RIGHT_SEGMENT_PADDING

    arr = np.asarray(vals, dtype=float)
    arr = arr[np.isfinite(arr)]
    if len(arr) == 0:
        return (0, 1), (1, 2)

    left_end = np.percentile(arr, quantile * 100)
    vmax = np.max(arr)

    # 0 for non-negative data (unchanged behaviour); the data minimum otherwise.
    left_min = min(0.0, float(np.min(arr)))

    left_pad_right = (left_end - left_min) * left_padding if left_end > left_min else 0.1
    right_pad = (vmax - left_end) * right_padding if vmax > left_end else 0.1

    left_lo = left_min - left_pad_right * 0.3 if left_min < 0 else 0.0
    left_xlim = (left_lo, left_end + left_pad_right * 0.3)
    right_xlim = (max(left_end, vmax - (vmax - left_end) * 0.25), vmax + right_pad)

    if left_xlim[1] <= left_xlim[0]:
        left_xlim = (0.0, max(1.0, vmax * 0.8))

    if right_xlim[0] < 0:
        right_xlim = (0.0, right_xlim[1])

    if right_xlim[0] >= right_xlim[1]:
        right_xlim = (max(0.0, vmax - 0.1), vmax + 0.1)

    return left_xlim, right_xlim


def draw_break_marks(ax, where="right", size=BREAK_MARK_SIZE, lw=BREAK_MARK_LW):
    """Draw the two diagonal break ticks on the *where* edge ("right"/"left") of *ax*.

    Marks are drawn in axes-fraction coordinates at the bottom and top
    corners of that edge. Any other *where* value draws nothing.
    """
    if where == "right":
        edge = 1
    elif where == "left":
        edge = 0
    else:
        return

    line_style = {"transform": ax.transAxes, "color": "black", "clip_on": False, "linewidth": lw}
    xs = (edge - size, edge + size)
    ax.plot(xs, (-size, +size), **line_style)
    ax.plot(xs, (1 - size, 1 + size), **line_style)


def fix_broken_axis_ticks(ax_left, ax_right):
    """Thin out the ticks of a broken-axis pair.

    Left axes: keep the locator's ticks at or above the lower end of the
    visible range, then drop the last remaining one (presumably to keep it
    clear of the break). Right axes: a single tick at its upper x-limit.
    """
    left_ticks = ax_left.get_xticks()
    if len(left_ticks) > 1:
        # The old filter was ``t >= 0``. min(0, lower limit) reproduces that for
        # axes starting at 0 but keeps the ticks of an axis extending below 0.
        lower = min(0.0, ax_left.get_xlim()[0])
        left_ticks = [t for t in left_ticks if t >= lower]
        if len(left_ticks) > 1:
            ax_left.set_xticks(left_ticks[:-1])
        else:
            ax_left.set_xticks(left_ticks)

    _, x1 = ax_right.get_xlim()
    ax_right.set_xticks([x1])


def add_internal_names_and_n(ax, order, counts, x_text=LABEL_BOX_X):
    """Replace y tick labels with in-panel "name / n=count" labels.

    Row i of ``len(order)`` equal-height rows (counted from the top) gets its
    label at axes-fraction height ``1 - (i + 0.5) / n``, plus
    ``LABEL_Y_FINE_SHIFT``. The x position is *x_text*. Counts missing from
    *counts* are shown as 0.
    """
    row_total = len(order)
    ax.set_yticks([])
    ax.tick_params(axis="y", left=False, labelleft=False)

    for row, category in enumerate(order):
        centre_y = 1 - ((row + 0.5) / row_total) + LABEL_Y_FINE_SHIFT
        n_obs = counts.get(category, 0)
        ax.text(
            x_text,
            centre_y,
            f"{category}\n n={n_obs}",
            transform=ax.transAxes,
            ha=LABEL_TEXT_HA,
            va=LABEL_TEXT_VA,
            fontsize=MANUAL_NAME_FONTSIZE,
            zorder=8,
            bbox={
                "facecolor": LABEL_BOX_FACE_COLOR,
                "edgecolor": LABEL_BOX_EDGE_COLOR,
                "alpha": LABEL_BOX_ALPHA,
                "boxstyle": f"round,pad={LABEL_BOX_PAD}",
            },
        )


def style_box_axis(ax, show_left_labels=False):
    """Apply the shared boxplot styling to *ax*.

    Clears the axis titles, sets small x ticks, hides the y tick marks, adds
    a dashed x grid and draws all four spines thin. Unless
    *show_left_labels* is True, the y tick labels are also hidden.
    """
    ax.set_xlabel("")
    ax.set_ylabel("")

    ax.tick_params(axis="x", labelsize=BOX_TICK_FONTSIZE, length=2, pad=1)
    ax.tick_params(axis="y", length=0)

    ax.grid(axis="x", linestyle="--", linewidth=0.3, alpha=0.35)
    ax.grid(axis="y", visible=False)

    for name in ("top", "right", "left", "bottom"):
        spine = ax.spines[name]
        spine.set_linewidth(0.35)
        spine.set_visible(True)

    if not show_left_labels:
        ax.tick_params(axis="y", left=False, labelleft=False)

    ax.set_aspect("auto")


def overlay_stripplot(ax, plot_df, value_col, order, point_color):
    """Overlay jittered per-city points, one row per continent in *order*, on *ax*."""
    appearance = dict(
        size=STRIP_SIZE,
        alpha=STRIP_ALPHA,
        jitter=STRIP_JITTER,
        orient="h",
        edgecolor=STRIP_EDGE_COLOR,
        linewidth=STRIP_LINEWIDTH,
        zorder=4,
    )
    sns.stripplot(
        data=plot_df,
        x=value_col,
        y=CONTINENT_COL,
        order=order,
        ax=ax,
        color=point_color,
        **appearance,
    )


def _draw_box_and_strip(ax, plot_df, value_col, order, base_color, point_color):
    """Draw one horizontal per-continent boxplot (no fliers) plus its point overlay on *ax*."""
    sns.boxplot(
        data=plot_df,
        x=value_col,
        y=CONTINENT_COL,
        order=order,
        ax=ax,
        color=base_color,
        width=BOX_WIDTH,
        fliersize=0,
        linewidth=0.6
    )
    overlay_stripplot(ax, plot_df, value_col, order, point_color)


def add_row_boxplot_broken(fig, subspec, data, value_col, order=None, base_color="skyblue"):
    """Draw the per-continent boxplot panel for one figure row.

    If ``SHOW_BROKEN_AXIS`` is set and ``need_axis_break`` flags the values,
    the panel is split into a left and a right axes joined by break marks.
    Otherwise a single axes is used.

    Parameters
    ----------
    fig : matplotlib Figure
    subspec : SubplotSpec
        Grid cell that the panel occupies.
    data : DataFrame
        Must contain ``CONTINENT_COL`` and *value_col*.
    value_col : str
        Column to plot; rows with NaN or +/-inf are dropped.
    order : list, optional
        Continent order (top to bottom). Defaults to the sorted unique
        continents present.
    base_color : str
        Box colour. Also used for the points when
        ``USE_SAME_COLOR_AS_PANEL_FOR_POINTS`` is True.

    Returns
    -------
    tuple
        ``(ax_left, ax_right)`` for a broken axis, otherwise ``(ax, None)``.
    """
    plot_df = data[[CONTINENT_COL, value_col]].copy()
    plot_df = plot_df.replace([np.inf, -np.inf], np.nan).dropna()

    if order is None:
        order = sorted(plot_df[CONTINENT_COL].dropna().unique().tolist())

    counts = plot_df.groupby(CONTINENT_COL)[value_col].count().to_dict()
    all_vals = plot_df[value_col].values

    point_color = base_color if USE_SAME_COLOR_AS_PANEL_FOR_POINTS else "#4d4d4d"
    use_break = SHOW_BROKEN_AXIS and need_axis_break(all_vals)

    if use_break:
        inner = GridSpecFromSubplotSpec(
            1, 2,
            subplot_spec=subspec,
            width_ratios=[BROKEN_LEFT_RATIO, BROKEN_RIGHT_RATIO],
            wspace=BROKEN_WSPACE,
            hspace=0.0
        )

        ax_left = fig.add_subplot(inner[0, 0])
        ax_right = fig.add_subplot(inner[0, 1], sharey=ax_left)

        ax_left.set_aspect("auto")
        ax_right.set_aspect("auto")

        left_xlim, right_xlim = get_break_limits(all_vals)

        # Both halves show the full data; only their x-limits differ.
        for ax in [ax_left, ax_right]:
            _draw_box_and_strip(ax, plot_df, value_col, order, base_color, point_color)

        ax_left.set_xlim(*left_xlim)
        ax_right.set_xlim(*right_xlim)

        style_box_axis(ax_left, show_left_labels=False)
        style_box_axis(ax_right, show_left_labels=False)

        fix_broken_axis_ticks(ax_left, ax_right)

        ax_left.spines["right"].set_visible(False)
        ax_right.spines["left"].set_visible(False)

        draw_break_marks(ax_left, where="right")
        draw_break_marks(ax_right, where="left")

        ax_right.tick_params(axis="y", left=False, labelleft=False)

        add_internal_names_and_n(ax_left, order, counts, x_text=LABEL_BOX_X)

        draw_box_outline_sides(ax_left, show_left=True, show_right=False, show_top=True, show_bottom=True)
        draw_box_outline_sides(ax_right, show_left=False, show_right=True, show_top=True, show_bottom=True)

        return ax_left, ax_right

    ax = fig.add_subplot(subspec)
    ax.set_aspect("auto")

    _draw_box_and_strip(ax, plot_df, value_col, order, base_color, point_color)

    # Anchor the axis at 0 for non-negative data (as before), but do not clip
    # negative values (e.g. declining trends) out of view; leave autoscaling.
    has_negative = len(all_vals) > 0 and np.min(all_vals) < 0
    if not has_negative:
        ax.set_xlim(left=0)

    style_box_axis(ax, show_left_labels=False)
    add_internal_names_and_n(ax, order, counts, x_text=LABEL_BOX_X)

    draw_box_outline_sides(ax, show_left=True, show_right=True, show_top=True, show_bottom=True)

    return ax, None


# ============================================================
# DATA LOADING AND PREPARATION
# ============================================================

def load_and_prepare_data():
    """Load all input files and build the per-city table used for plotting.

    Reads the world shapefile, the CDD workbook, the city-coordinate workbook
    and the two regression workbooks (paths from the user configuration).
    It normalizes their column headers and city names, then left-merges
    everything onto the city-coordinate table.

    Returns
    -------
    gdf : GeoDataFrame
        Geometries read from ``SHAPEFILE_PATH``.
    merged_df : DataFrame
        City coordinates, typology and continent, plus all regression columns
        and ``AVG_CDD_COL``. NOTE(review): every merge is a left merge on the
        city key, so rows multiply if a regression/CDD table lists the same
        city more than once. Confirm the inputs are one row per city.
    continent_order : list
        Sorted unique non-NaN values of ``CONTINENT_COL``.

    Raises
    ------
    KeyError
        If a required column is missing from an input table, or if
        ``REL_TREND_COL`` / ``ABS_TREND_COL`` are absent after merging.
    """
    print("Loading shapefile...")
    gdf = gpd.read_file(SHAPEFILE_PATH)

    print("Loading CDD data...")
    cdd_data = standardize_columns(pd.read_excel(CDD_DATA_PATH))

    print("Loading city details...")
    city_details = standardize_columns(pd.read_excel(CITY_DETAILS_PATH))

    print("Loading regression results...")
    simple_slope = standardize_columns(pd.read_excel(SIMPLE_SLOPE_PATH))
    log_slope = standardize_columns(pd.read_excel(LOG_SLOPE_PATH))

    # Normalize city keys (str, trimmed, lower-cased) so the merges below match
    # across files regardless of case/whitespace differences.
    city_details = clean_city_names(city_details)
    cdd_data = clean_city_names(cdd_data)
    simple_slope = clean_city_names(simple_slope)
    log_slope = clean_city_names(log_slope)

    check_required_columns(city_details, [CITY_COL, LAT_COL, LON_COL], "city_details")
    check_required_columns(simple_slope, [CITY_COL], "simple_regression")
    check_required_columns(log_slope, [CITY_COL, REL_TREND_COL], "log_regression")

    # Optional metadata: use a placeholder label when the column is absent.
    if TYPOLOGY_COL not in city_details.columns:
        city_details[TYPOLOGY_COL] = "Unknown"

    # Take the continent from the first regression table that has it, otherwise
    # use a placeholder. NOTE: the full merges below then add that table's
    # continent column a second time, with the "_simple"/"_log" suffix.
    if CONTINENT_COL not in city_details.columns:
        if CONTINENT_COL in simple_slope.columns:
            city_details = city_details.merge(simple_slope[[CITY_COL, CONTINENT_COL]], on=CITY_COL, how="left")
        elif CONTINENT_COL in log_slope.columns:
            city_details = city_details.merge(log_slope[[CITY_COL, CONTINENT_COL]], on=CITY_COL, how="left")
        else:
            city_details[CONTINENT_COL] = "Unknown"

    # Columns present on both sides of a merge keep the unsuffixed name from
    # the left table; the regression copy gets "_simple" / "_log".
    merged_df = city_details[[CITY_COL, LAT_COL, LON_COL, TYPOLOGY_COL, CONTINENT_COL]].copy()
    merged_df = merged_df.merge(simple_slope, on=CITY_COL, how="left", suffixes=("", "_simple"))
    merged_df = merged_df.merge(log_slope, on=CITY_COL, how="left", suffixes=("", "_log"))

    # Row-wise mean over all "days_exceeding_22_*" columns. It is used below only
    # as a fallback when AVG_CDD_COL is not provided by the regression tables.
    # NOTE(review): the colorbar labels AVG_CDD_COL as "Average CDD (°C-year)";
    # confirm this fallback has the same meaning/units.
    exceed_cols = [c for c in cdd_data.columns if c.startswith("days_exceeding_22_")]
    if exceed_cols:
        cdd_data["avg_exceedance"] = cdd_data[exceed_cols].mean(axis=1)
    else:
        cdd_data["avg_exceedance"] = np.nan

    # Without a city column in the CDD table, "avg_exceedance" never reaches merged_df.
    if CITY_COL in cdd_data.columns:
        merged_df = merged_df.merge(cdd_data[[CITY_COL, "avg_exceedance"]], on=CITY_COL, how="left")

    # DataFrame.get returns the scalar np.nan when "avg_exceedance" is absent,
    # which fills the whole column with NaN.
    if AVG_CDD_COL not in merged_df.columns:
        merged_df[AVG_CDD_COL] = merged_df.get("avg_exceedance", np.nan)

    if REL_TREND_COL not in merged_df.columns:
        raise KeyError(f"Column '{REL_TREND_COL}' not found in merged dataframe.")

    if ABS_TREND_COL not in merged_df.columns:
        raise KeyError(f"Column '{ABS_TREND_COL}' not found in merged dataframe.")

    continent_order = sorted(merged_df[CONTINENT_COL].dropna().unique().tolist())

    return gdf, merged_df, continent_order


# ============================================================
# FIGURE GENERATION
# ============================================================

def _draw_map_panel(fig, ax_map, gdf, merged_df, value_col, cmap, cbar_label, letter):
    """Draw one map panel: base map, bubbles, colorbar, scale bar, panel letter and frame."""
    gdf.plot(ax=ax_map, color="lightgrey", edgecolor="white", linewidth=0.3)
    # Bubble area and colour both encode *value_col*.
    bubble_sizes = normalize_sizes(merged_df[value_col])
    scatter = ax_map.scatter(
        merged_df[LON_COL], merged_df[LAT_COL],
        c=merged_df[value_col], s=bubble_sizes,
        cmap=cmap, edgecolor="grey", linewidth=0.2, alpha=0.85
    )
    set_map_extent(ax_map)
    add_colorbar(fig, ax_map, scatter, cbar_label)
    add_geodesic_scale_bar(ax_map)
    ax_map.text(-0.05, 1.00, letter, transform=ax_map.transAxes, fontsize=ANNOTATION_FONTSIZE, fontweight="bold")
    add_map_outline(ax_map)


def create_figure(gdf, merged_df, continent_order):
    """Build the three-row figure and return it.

    Each row holds a bubble map (left) and a per-continent boxplot (right):
    (a) average CDD, (b) absolute yearly rise, (c) relative yearly rise.

    Parameters
    ----------
    gdf : GeoDataFrame
        Base-map geometries.
    merged_df : DataFrame
        Per-city table with coordinates and the plotted value columns.
    continent_order : list
        Row order for the boxplots.
    """
    # (panel letter, value column, colormap, colorbar label, boxplot colour)
    panels = [
        ("a", AVG_CDD_COL, CMAP_AVG, "Average CDD (°C-year)", BOX_COLOR_AVG),
        ("b", ABS_TREND_COL, CMAP_TREND, "Absolute rise in CDD (°C-year/year)", BOX_COLOR_TREND),
        ("c", REL_TREND_COL, CMAP_REL, "Relative rise in CDD (%/year)", BOX_COLOR_REL),
    ]

    fig = plt.figure(figsize=(FIG_W, FIG_H), dpi=FIG_DPI)

    gs = GridSpec(
        nrows=len(panels), ncols=2, figure=fig,
        width_ratios=[MAP_WIDTH_RATIO, BOX_WIDTH_RATIO],
        height_ratios=[1] * len(panels),
        hspace=MAIN_HSPACE,
        wspace=MAIN_WSPACE
    )

    # All map axes are created up front, before any drawing, so axes are added
    # to the figure in the same order as in the original hand-written panels.
    map_axes = [fig.add_subplot(gs[row, 0]) for row in range(len(panels))]
    for ax_map in map_axes:
        ax_map.set_aspect("auto")

    for row, (ax_map, panel) in enumerate(zip(map_axes, panels)):
        letter, value_col, cmap, cbar_label, box_color = panel
        _draw_map_panel(fig, ax_map, gdf, merged_df, value_col, cmap, cbar_label, letter)
        add_row_boxplot_broken(fig, gs[row, 1], merged_df, value_col, order=continent_order, base_color=box_color)

    return fig


# ============================================================
# MAIN
# ============================================================

def main():
    """Script entry point: prepare folders, validate inputs, build the figure, export it.

    Saves a PNG (FIG_DPI) and a PDF (OUTPUT_PDF_DPI) into OUTPUT_DIR, shows
    the figure, then prints both output paths.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    validate_files()

    gdf, merged_df, continent_order = load_and_prepare_data()
    fig = create_figure(gdf, merged_df, continent_order)

    # (console message, output path, savefig options) in save order.
    exports = [
        ("Saved PNG:", OUTPUT_DIR / OUTPUT_PNG_NAME,
         {"dpi": FIG_DPI, "bbox_inches": "tight"}),
        ("Saved PDF:", OUTPUT_DIR / OUTPUT_PDF_NAME,
         {"dpi": OUTPUT_PDF_DPI, "bbox_inches": "tight", "format": "pdf"}),
    ]
    for _, target, options in exports:
        fig.savefig(target, **options)

    plt.show()

    for message, target, _ in exports:
        print(message, target)


if __name__ == "__main__":
    main()