"""
This script:
- reads two paired CDD scenario datasets (CDD_1 and CDD_2)
- averages corresponding outputs
- merges with city coordinates and typology information
- generates publication-ready maps and boxplots
- saves the final figure as PNG and PDF
"""


from pathlib import Path
import warnings

import geopandas as gpd
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches
from matplotlib.patches import FancyArrow
from matplotlib.ticker import FuncFormatter
from pyproj import Geod


# ============================================================
# USER INSTRUCTIONS
# ============================================================
# Please edit ONLY the following sections as required:
#
# 1. INPUT FILE NAMES
#    Update the names of your input Excel and shapefile datasets.
#
# 2. OUTPUT FILE NAMES
#    Modify the desired names for exported PNG and PDF figures.
#
# 3. FIGURE SIZE / DPI
#    Change figure width, height, and resolution if needed.
#
# 4. BUBBLE SIZE / MARKER SETTINGS
#    Adjust the minimum and maximum bubble sizes for map plots.
#
# 5. TYPOLOGY / CATEGORY LABELS
#    If your dataset uses different class names, update them here.
#
# 6. PALETTE / STYLE SETTINGS
#    Update only if required for your own visualization style.
#
# Recommended folder structure:
# project/
# ├── India_Typology_CDD_Analysis.py
# ├── data/
# │   ├── CDD_1.xlsx
# │   ├── CDD_1_log.xlsx
# │   ├── CDD_2.xlsx
# │   ├── CDD_2_log.xlsx
# │   ├── city_coordinates.xlsx
# │   ├── city_typology.xlsx
# │   └── STATE_BOUNDARY.shp
# └── outputs/
# ============================================================


# ============================================================
# USER SETTINGS
# ============================================================

# All input and output locations are resolved relative to the folder that
# contains this script, so the script can be run from any working directory.
BASE_DIR = Path(__file__).resolve().parent
DATA_DIR = BASE_DIR / "data"
OUTPUT_DIR = BASE_DIR / "outputs"

# ----------------------------
# INPUT FILES
# Every path listed here is checked by validate_files() before loading.
# ----------------------------
SHAPEFILE_PATH = DATA_DIR / "STATE_BOUNDARY.shp"
CITY_DETAILS_PATH = DATA_DIR / "city_coordinates.xlsx"
UPDATED_TYPOLOGY_PATH = DATA_DIR / "city_typology.xlsx"

# Scenario 1: the "simple" workbook supplies typology, average CDD and trend
# slope; the "log" workbook supplies the relative increase per year.
CDD_1_SIMPLE_PATH = DATA_DIR / "CDD_1.xlsx"
CDD_1_LOG_PATH = DATA_DIR / "CDD_1_log.xlsx"

# Scenario 2: same layout as scenario 1.
CDD_2_SIMPLE_PATH = DATA_DIR / "CDD_2.xlsx"
CDD_2_LOG_PATH = DATA_DIR / "CDD_2_log.xlsx"

# ----------------------------
# OUTPUT FILES
# File names only; both are written inside OUTPUT_DIR.
# ----------------------------
OUTPUT_PNG_NAME = "FIG_CDD_avg_CDD1_CDD2.png"
OUTPUT_PDF_NAME = "FIG_CDD_avg_CDD1_CDD2.pdf"

# ----------------------------
# FIGURE SIZE / DPI
# Width and height are given in millimetres and converted to inches
# (via MM_TO_IN) before being passed to matplotlib.
# ----------------------------
FIG_WIDTH_MM = 180
FIG_HEIGHT_MM = 210
FIG_DPI = 600

# ----------------------------
# BUBBLE SIZE
# Default lower/upper bounds used by normalize_sizes(); the result is passed
# to the map scatter as its marker-size argument.
# ----------------------------
MIN_BUBBLE_SIZE = 15
MAX_BUBBLE_SIZE = 50

# ----------------------------
# TYPOLOGY LABELS / STYLE MAP
# Update labels if your dataset uses different class names
# Update palette/settings only if needed
# ----------------------------
# Colours for the typology read from the CDD workbooks (left-hand panels).
# Labels that are not keys of this dict are drawn in gray.
DEFAULT_TYPOLOGY_COLORS = {
    "TYPE-I": "#d95f02",
    "TYPE-II": "#1b9e77",
    "TYPE-III": "#7570b3",
    "TYPE-IV": "#e7298a"
}

# Colours for the updated typology (right-hand panels). They are assigned to
# the sorted typology labels in order; if there are more labels than colours,
# the surplus labels receive no colour and fall back to gray.
UPDATED_TYPOLOGY_PALETTE = [
    "#9D4EDD",
    "#00B4D8",
    "#3A86FF",
    "#FF006E",
    "#FF9F1C",
    "#00D4A6",
    "#8E7DBE",
    "#5F0F40",
    "#2A9D8F",
    "#BC4749"
]

# ----------------------------
# OPTIONAL COLUMN NAME SETTINGS
# Change only if your file headers differ
# ----------------------------
# NOTE: headers are matched AFTER standardize_columns() has stripped them,
# replaced spaces with underscores and lower-cased them, so the names below
# must be written in that standardized form.
CITY_COL = "city"
LAT_COL = "latitude"
LON_COL = "longitude"
TYPOLOGY_COL = "typology"
AVG_CDD_COL = "average_cdd"
TREND_SLOPE_COL = "trend_slope"
REL_INCREASE_COL = "relative_increase_per_year_(%)"


# ============================================================
# STYLE AND GLOBAL SETTINGS
# ============================================================

# Silence UserWarning-category messages for the whole run (applies to every
# library used by the script, not only matplotlib).
warnings.filterwarnings("ignore", category=UserWarning)

# Millimetre -> inch conversion; matplotlib figure sizes are given in inches.
MM_TO_IN = 0.0393701
FIG_W = FIG_WIDTH_MM * MM_TO_IN
FIG_H = FIG_HEIGHT_MM * MM_TO_IN

# Apply the seaborn theme FIRST. sns.set() rewrites matplotlib's rcParams
# (including "font.family"), so any rcParams customised before this call
# would be silently overwritten. All overrides therefore follow it.
sns.set(style="white")
plt.rcParams["font.family"] = "Arial"

# Short, thin, inward-pointing major ticks; minor ticks have zero length.
plt.rcParams["xtick.direction"] = "in"
plt.rcParams["ytick.direction"] = "in"
plt.rcParams["xtick.major.size"] = 1.5
plt.rcParams["ytick.major.size"] = 1.5
plt.rcParams["xtick.major.width"] = 0.2
plt.rcParams["ytick.major.width"] = 0.2
plt.rcParams["xtick.minor.size"] = 0
plt.rcParams["ytick.minor.size"] = 0

# Font sizes used throughout the figure.
AXIS_LABEL_FONTSIZE = 6
TICK_LABEL_FONTSIZE = 6
LEGEND_FONTSIZE = 5
LEGEND_TITLE_FONTSIZE = 6
COLORBAR_LABEL_FONTSIZE = 7
COLORBAR_TICK_FONTSIZE = 6
SUBPLOT_LETTER_FONTSIZE = 14
BOXPLOT_TICK_FONTSIZE = 5
N_INSIDE_FONTSIZE = 4.8

# Grid spacing: OUTER_* separates the six panels of the 3x2 grid,
# INNER_* separates map / box plot / legend inside one panel.
OUTER_WSPACE = 0.14
OUTER_HSPACE = 0.04
INNER_WSPACE = 0.5
INNER_HSPACE = 0.28

# Relative widths of [map, box plot] and heights of [map row, legend row].
MAP_BOX_WIDTH_RATIO = [2.55, 0.95]
PANEL_LEGEND_HEIGHT_RATIO = [1.0, 0.18]

# Colour bar placement relative to the map axes, in figure coordinates.
COLORBAR_GAP = 0.004
COLORBAR_WIDTH = 0.014

# Individual observations overlaid on the box plots.
BOX_SCATTER_SIZE = 10
BOX_SCATTER_ALPHA = 0.85
BOX_JITTER = 0.10

# Base-map styling.
STATE_BOUNDARY_COLOR = "#c7c7c7"
STATE_BOUNDARY_WIDTH = 0.25
INDIA_OUTLINE_COLOR = "black"
INDIA_OUTLINE_WIDTH = 1.0
AXES_FACECOLOR = "white"

# Discrete colour maps, one per mapped quantity. CMAP_AVG and CMAP_TREND run
# light -> dark; CMAP_REL runs dark -> light.
CMAP_AVG = ListedColormap(["#fff5e1", "#f9cdaa", "#f08a67", "#d83b3b", "#990000", "#660000"])
CMAP_TREND = ListedColormap(["#ffe6e6", "#ffb3b3", "#ff8080", "#ff4d4d", "#e60000", "#b30000", "#800000"])
CMAP_REL = ListedColormap([
    "#4d1c00",
    "#8c2d04",
    "#cc4c02",
    "#ec7014",
    "#fe9929",
    "#fec44f",
    "#fff7bc"
])

# WGS84 ellipsoid used to convert the scale bar length from km to degrees.
GEOD = Geod(ellps="WGS84")


# ============================================================
# HELPER FUNCTIONS
# ============================================================

def standardize_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* whose headers are normalized for matching.

    Each header is converted to text, stripped of surrounding whitespace,
    has its spaces replaced by underscores and is lower-cased. The input
    frame is left untouched.
    """
    result = df.copy()
    headers = result.columns.astype(str)
    headers = headers.str.strip()
    headers = headers.str.replace(" ", "_", regex=False)
    result.columns = headers.str.lower()
    return result


def clean_city_names(df: pd.DataFrame, city_col: str = CITY_COL) -> pd.DataFrame:
    """Return a copy of *df* with the city column prepared as a merge key.

    The values of *city_col* are converted to text, stripped of surrounding
    whitespace and lower-cased. When *city_col* is not a column of *df*, an
    unmodified copy is returned.
    """
    cleaned = df.copy()
    if city_col not in cleaned.columns:
        return cleaned

    as_text = cleaned[city_col].astype(str)
    cleaned[city_col] = as_text.str.strip().str.lower()
    return cleaned


def check_required_columns(df: pd.DataFrame, required_cols: list, df_name: str) -> None:
    """Verify that *df* contains every column named in *required_cols*.

    Raises:
        KeyError: if one or more columns are absent. The message names the
            frame (*df_name*), the absent columns and the available ones.
    """
    available = df.columns
    missing = []
    for name in required_cols:
        if name not in available:
            missing.append(name)

    if not missing:
        return

    raise KeyError(
        f"Missing required column(s) in '{df_name}': {missing}\n"
        f"Available columns: {list(df.columns)}"
    )


def validate_files() -> None:
    """Check that every configured input file exists on disk.

    Raises:
        FileNotFoundError: listing, one per line, each configured path that
            does not exist.
    """
    expected_paths = (
        SHAPEFILE_PATH,
        CITY_DETAILS_PATH,
        UPDATED_TYPOLOGY_PATH,
        CDD_1_SIMPLE_PATH,
        CDD_1_LOG_PATH,
        CDD_2_SIMPLE_PATH,
        CDD_2_LOG_PATH,
    )

    absent = []
    for candidate in expected_paths:
        if not candidate.exists():
            absent.append(str(candidate))

    if absent:
        listing = "\n".join(absent)
        raise FileNotFoundError(
            "The following required file(s) are missing:\n" + listing
        )


def lon_formatter(x, pos):
    """Format a longitude tick as whole degrees with an E/W suffix.

    Negative values are labelled "W", everything else "E". *pos* is accepted
    for the matplotlib ``FuncFormatter`` call signature and is not used.
    """
    hemisphere = "W" if x < 0 else "E"
    return f"{abs(x):.0f}°{hemisphere}"


def lat_formatter(x, pos):
    """Format a latitude tick as whole degrees with an N/S suffix.

    Negative values are labelled "S", everything else "N". *pos* is accepted
    for the matplotlib ``FuncFormatter`` call signature and is not used.
    """
    hemisphere = "S" if x < 0 else "N"
    return f"{abs(x):.0f}°{hemisphere}"


def normalize_sizes(values, min_size=MIN_BUBBLE_SIZE, max_size=MAX_BUBBLE_SIZE):
    """Linearly rescale *values* into the range [min_size, max_size].

    NaN entries are ignored when finding the data range and stay NaN in the
    result. An empty input yields an empty array. When every value is NaN,
    or all finite values are equal, every entry is set to the midpoint of
    the size range.
    """
    arr = np.array(values, dtype=float)
    if arr.size == 0:
        return np.array([])

    midpoint = (min_size + max_size) / 2.0
    if np.isnan(arr).all():
        return np.full(arr.shape, midpoint)

    lowest = np.nanmin(arr)
    highest = np.nanmax(arr)
    if highest - lowest == 0:
        return np.full(arr.shape, midpoint)

    return min_size + (arr - lowest) / (highest - lowest) * (max_size - min_size)


def add_colorbar(fig, ax, mappable, label):
    """Attach a vertical colour bar for *mappable* just right of *ax*.

    A dedicated axes of width COLORBAR_WIDTH is created COLORBAR_GAP to the
    right of *ax* (figure coordinates), spanning the same vertical extent.
    """
    bbox = ax.get_position()
    rect = [bbox.x1 + COLORBAR_GAP, bbox.y0, COLORBAR_WIDTH, bbox.height]
    colorbar_axes = fig.add_axes(rect)

    colorbar = fig.colorbar(mappable, cax=colorbar_axes)
    colorbar.set_label(label, fontsize=COLORBAR_LABEL_FONTSIZE)
    colorbar.ax.tick_params(labelsize=COLORBAR_TICK_FONTSIZE)
    colorbar.outline.set_linewidth(0.5)


def nice_round_value(x, mode="nearest"):
    """Round *x* to a step chosen from its magnitude.

    Steps: 0.1 below 1, 0.5 below 10, 1 below 100, 5 below 1000, else 10.
    *mode* selects "floor", "ceil" or (any other value) round-to-nearest.
    Non-finite inputs are returned unchanged and zero returns ``0``.
    """
    if not np.isfinite(x):
        return x

    magnitude = abs(x)
    if magnitude == 0:
        return 0

    # First threshold the magnitude falls under decides the step.
    for upper_bound, candidate in ((1, 0.1), (10, 0.5), (100, 1), (1000, 5)):
        if magnitude < upper_bound:
            step = candidate
            break
    else:
        step = 10

    rounders = {"floor": np.floor, "ceil": np.ceil}
    rounder = rounders.get(mode, np.round)
    return rounder(x / step) * step


def add_geodesic_scale_bar(ax, total_length_km=500, n_segments=2,
                           location=(0.66, 0.08), bar_height_frac=0.012,
                           fontsize=5.5, line_width=0.55):
    """Draw a segmented scale bar on *ax*, whose data units are degrees.

    The bar length in degrees of longitude is derived from the WGS84
    geodesic length of one degree of longitude at the centre of the current
    axis limits, so the axis limits must be final before calling this.

    *location* is the bar's lower-left corner as a fraction of the axis
    extent; *bar_height_frac* is the bar height as a fraction of the y extent.
    Segments alternate black/white and are labelled with cumulative km.
    """
    lon_min, lon_max = ax.get_xlim()
    lat_min, lat_max = ax.get_ylim()
    mid_lat = (lat_min + lat_max) / 2
    mid_lon = (lon_min + lon_max) / 2

    # Third element of Geod.inv is the distance in metres.
    metres_per_degree = GEOD.inv(mid_lon, mid_lat, mid_lon + 1, mid_lat)[2]
    km_per_deg = metres_per_degree / 1000
    if km_per_deg == 0:
        # Fall back to the spherical approximation.
        km_per_deg = 111.32 * np.cos(np.deg2rad(mid_lat))

    bar_deg = total_length_km / km_per_deg
    segment_deg = bar_deg / n_segments

    frac_x, frac_y = location
    bar_left = lon_min + frac_x * (lon_max - lon_min)
    bar_bottom = lat_min + frac_y * (lat_max - lat_min)
    bar_height = bar_height_frac * (lat_max - lat_min)
    tick_label_y = bar_bottom - bar_height * 0.55

    for index in range(n_segments):
        segment_left = bar_left + index * segment_deg
        fill = "black" if index % 2 == 0 else "white"
        ax.add_patch(
            mpatches.Rectangle(
                (segment_left, bar_bottom), segment_deg, bar_height,
                facecolor=fill, edgecolor="black",
                linewidth=line_width, transform=ax.transData, zorder=10
            )
        )

        # Cumulative distance at the right-hand end of this segment.
        distance_km = (index + 1) * (total_length_km / n_segments)
        ax.text(
            segment_left + segment_deg, tick_label_y,
            f"{int(distance_km)}",
            fontsize=fontsize, ha="center", va="top",
            transform=ax.transData
        )

    # Origin label at the left end of the bar.
    ax.text(
        bar_left, tick_label_y, "0",
        fontsize=fontsize, ha="center", va="top",
        transform=ax.transData
    )

    # Caption centred above the bar.
    ax.text(
        bar_left + bar_deg / 2, bar_bottom + bar_height * 1.45,
        "Distance (km)",
        fontsize=fontsize, ha="center", va="bottom",
        transform=ax.transData
    )


def add_north_arrow(ax, location=(0.94, 0.88), fontsize=7):
    """Draw a bold "N" with an upward arrow above it on *ax*.

    *location* is the centre of the letter in axes-fraction coordinates;
    the arrow starts 0.04 above it and extends a further 0.05 upwards.
    """
    letter_x = location[0]
    arrow_base_y = location[1] + 0.04

    ax.annotate(
        "N",
        xy=location,
        xycoords="axes fraction",
        fontsize=fontsize,
        fontweight="bold",
        ha="center",
        va="center"
    )

    arrow = FancyArrow(
        letter_x, arrow_base_y,
        0, 0.05,
        width=0.01,
        transform=ax.transAxes,
        color="black"
    )
    ax.add_patch(arrow)


def draw_india_base(ax, gdf, india_outline):
    """Draw the base map on *ax*.

    Plots the polygons of *gdf* unfilled with the state-boundary styling,
    then the boundary of *india_outline* on top with the outline styling.
    """
    ax.set_facecolor(AXES_FACECOLOR)

    state_style = {
        "facecolor": "none",
        "edgecolor": STATE_BOUNDARY_COLOR,
        "linewidth": STATE_BOUNDARY_WIDTH,
        "zorder": 0.5,
    }
    gdf.plot(ax=ax, **state_style)

    outline_style = {
        "color": INDIA_OUTLINE_COLOR,
        "linewidth": INDIA_OUTLINE_WIDTH,
        "zorder": 1.0,
    }
    india_outline.boundary.plot(ax=ax, **outline_style)


def add_horizontal_boxplot(ax, data, y_col, typology_col, color_map):
    """Draw horizontal box plots of *y_col* per typology on a broken x axis.

    *ax* is switched off and used only as a container for two inset axes that
    share the y axis: a wide left axis showing the bulk of the data and a
    narrow right axis showing the upper tail, separated by diagonal break
    marks. The same boxes and jittered observations are drawn on both axes;
    only the x limits differ.

    Args:
        ax: Axes reserved for the box plot panel.
        data: Table indexed by column name (accessed via ``data[...]`` and
            ``data.loc[...]``) holding *y_col* and *typology_col*.
        y_col: Name of the numeric column to plot.
        typology_col: Name of the grouping column; its sorted, non-null
            unique values define the boxes from bottom (position 1) upwards.
        color_map: Mapping from typology label to colour; labels without an
            entry are drawn in gray.

    If there are no typology labels, or no non-null values at all, the panel
    shows the text "No data" and nothing else is drawn.
    """
    typologies = sorted(data[typology_col].dropna().unique())

    # No groups at all: show a placeholder and stop.
    if len(typologies) == 0:
        ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=5)
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_linewidth(0.6)
        return

    # One array of non-null values per typology, in the same order as `typologies`.
    box_data = [data.loc[data[typology_col] == t, y_col].dropna().values for t in typologies]
    positions = np.arange(1, len(typologies) + 1)

    # Pool every value to derive the global range and median.
    all_vals = (
        np.concatenate([v for v in box_data if len(v) > 0])
        if any(len(v) > 0 for v in box_data)
        else np.array([])
    )

    # Groups exist but none of them has a value: same placeholder as above.
    if len(all_vals) == 0:
        ax.text(0.5, 0.5, "No data", ha="center", va="center", fontsize=5)
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_linewidth(0.6)
        return

    global_min = np.nanmin(all_vals)
    global_max = np.nanmax(all_vals)
    global_median = np.nanmedian(all_vals)

    # Hide the host axes; all drawing happens on the two insets below.
    ax.set_axis_off()
    left_ax = ax.inset_axes([0.00, 0.00, 0.74, 1.00])
    right_ax = ax.inset_axes([0.80, 0.00, 0.20, 1.00], sharey=left_ax)

    # Upper whisker end of each group: the largest value not exceeding
    # Q3 + 1.5 * IQR (falling back to the group maximum).
    whisker_highs = []
    for vals in box_data:
        if len(vals) == 0:
            continue
        q1, q3 = np.percentile(vals, [25, 75])
        iqr = q3 - q1
        upper_whisker_limit = q3 + 1.5 * iqr
        whisker_high = np.max(vals[vals <= upper_whisker_limit]) if np.any(vals <= upper_whisker_limit) else np.max(vals)
        whisker_highs.append(whisker_high)

    # `detail_max` is where the left (detail) axis ends: the highest whisker
    # across all groups.
    detail_max = np.max(whisker_highs) if whisker_highs else global_median

    # If no value lies beyond the highest whisker, there is no tail to split
    # off, so break the axis at 72 % of the data range instead.
    if detail_max >= global_max:
        detail_max = global_min + 0.72 * (global_max - global_min)

    # Padding around each sub-range; a unit span is used when the range is degenerate.
    left_pad = 0.06 * (detail_max - global_min if detail_max > global_min else 1.0)
    right_pad = 0.08 * (global_max - detail_max if global_max > detail_max else 1.0)

    left_xlim = (global_min - left_pad, detail_max + left_pad)

    # The right axis starts 35 % of the way from the break to the maximum.
    right_start = detail_max + 0.35 * (global_max - detail_max)
    if right_start >= global_max:
        right_start = global_max - 0.25 * max((global_max - global_min), 1.0)

    right_xlim = (right_start, global_max + right_pad)

    # Fixed seed so the jitter is reproducible between runs. The generator is
    # shared by both axes, so the two axes receive different jitter offsets.
    rng = np.random.default_rng(42)

    for draw_ax in [left_ax, right_ax]:
        # Outliers are hidden here because every observation is overlaid below.
        bp = draw_ax.boxplot(
            box_data,
            positions=positions,
            vert=False,
            widths=0.55,
            patch_artist=True,
            showfliers=False,
            whis=1.5,
            medianprops={"color": "black", "linewidth": 0.7},
            boxprops={"linewidth": 0.7},
            whiskerprops={"linewidth": 0.7},
            capprops={"linewidth": 0.7}
        )

        # Fill each box with its typology colour.
        for i, patch in enumerate(bp["boxes"]):
            patch.set_facecolor(color_map.get(typologies[i], "gray"))
            patch.set_edgecolor("black")
            patch.set_linewidth(0.6)

        # Overlay the individual observations with vertical jitter.
        for i, t in enumerate(typologies):
            vals = box_data[i]
            if len(vals) == 0:
                continue

            y_jitter = positions[i] + rng.uniform(-BOX_JITTER, BOX_JITTER, size=len(vals))
            draw_ax.scatter(
                vals, y_jitter,
                s=BOX_SCATTER_SIZE,
                color=color_map.get(t, "gray"),
                edgecolors="white",
                linewidths=0.25,
                alpha=BOX_SCATTER_ALPHA,
                zorder=3
            )

        # No y ticks or labels on either inset.
        draw_ax.set_yticks([])
        draw_ax.set_yticklabels([])
        draw_ax.tick_params(axis="y", left=False, right=False, labelleft=False)
        draw_ax.tick_params(axis="x", labelsize=BOXPLOT_TICK_FONTSIZE, width=0.4, length=2)

        for spine in draw_ax.spines.values():
            spine.set_linewidth(0.6)

    # Limits are applied after plotting so they are not changed by autoscaling.
    left_ax.set_xlim(*left_xlim)
    right_ax.set_xlim(*right_xlim)

    # Open the facing sides of the two axes to suggest one continuous axis.
    left_ax.spines["right"].set_visible(False)
    right_ax.spines["left"].set_visible(False)
    right_ax.tick_params(axis="y", left=False, labelleft=False)

    # Only three x ticks are shown in total: rounded minimum and median on
    # the left axis, rounded maximum on the right axis.
    tick_min = nice_round_value(global_min, mode="floor")
    tick_mid = nice_round_value(global_median, mode="nearest")
    tick_max = nice_round_value(global_max, mode="ceil")

    left_range = left_xlim[1] - left_xlim[0]
    min_sep = 0.18 * left_range

    # Keep the median tick away from the minimum tick and from the axis ends
    # so the labels do not collide.
    if abs(tick_mid - tick_min) < min_sep:
        tick_mid = tick_min + min_sep
    if abs(tick_mid - left_xlim[1]) < 0.12 * left_range:
        tick_mid = left_xlim[1] - 0.12 * left_range

    tick_mid = np.clip(
        tick_mid,
        left_xlim[0] + 0.12 * left_range,
        left_xlim[1] - 0.12 * left_range
    )

    left_ticks = [tick_min, tick_mid]
    right_ticks = [tick_max]

    # Labels: no decimals from 10 upwards, one decimal below.
    left_ax.set_xticks(left_ticks)
    left_ax.set_xticklabels(
        [f"{t:.0f}" if abs(t) >= 10 else f"{t:.1f}" for t in left_ticks],
        fontsize=BOXPLOT_TICK_FONTSIZE
    )

    right_ax.set_xticks(right_ticks)
    right_ax.set_xticklabels(
        [f"{t:.0f}" if abs(t) >= 10 else f"{t:.1f}" for t in right_ticks],
        fontsize=BOXPLOT_TICK_FONTSIZE
    )

    # Diagonal break marks at the facing corners of the two axes, drawn in
    # axes coordinates and allowed to extend beyond the axes (clip_on=False).
    d = 0.012
    kwargs = dict(transform=left_ax.transAxes, color="k", clip_on=False, linewidth=0.6)
    left_ax.plot((1 - d, 1 + d), (-d, +d), **kwargs)
    left_ax.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)

    kwargs.update(transform=right_ax.transAxes)
    right_ax.plot((-d, +d), (-d, +d), **kwargs)
    right_ax.plot((-d, +d), (1 - d, 1 + d), **kwargs)

    # Sample-size label on each box, placed at the group median but kept
    # within the central 80 % of the left axis.
    for i, vals in enumerate(box_data):
        if len(vals) == 0:
            continue

        _, med, _ = np.percentile(vals, [25, 50, 75])

        x_text = np.clip(
            med,
            left_xlim[0] + 0.10 * (left_xlim[1] - left_xlim[0]),
            left_xlim[1] - 0.10 * (left_xlim[1] - left_xlim[0])
        )

        left_ax.text(
            x_text, positions[i],
            f"n={len(vals)}",
            ha="center", va="center",
            fontsize=N_INSIDE_FONTSIZE,
            color="black",
            zorder=4,
            bbox=dict(facecolor="white", edgecolor="none", alpha=0.65, pad=0.15)
        )


def add_fullwidth_typology_legend(ax_leg, color_map, title="Typology"):
    """Fill *ax_leg* with a framed, single-row legend of typology colours.

    One patch is created per key of *color_map*, in sorted key order. The
    rounded frame is always drawn; the legend itself is added only when
    *color_map* has at least one entry.
    """
    ax_leg.set_xlim(0, 1)
    ax_leg.set_ylim(0, 1)
    ax_leg.axis("off")

    handles = []
    for name in sorted(color_map.keys()):
        swatch = mpatches.Patch(
            facecolor=color_map[name],
            edgecolor="black",
            linewidth=0.45,
            label=name
        )
        handles.append(swatch)

    # Rounded rectangle behind the legend entries.
    ax_leg.add_patch(
        mpatches.FancyBboxPatch(
            (0.01, 0.12), 0.98, 0.76,
            boxstyle="round,pad=0.012,rounding_size=0.008",
            linewidth=0.55,
            edgecolor="black",
            facecolor="white",
            transform=ax_leg.transAxes,
            zorder=0
        )
    )

    if not handles:
        return

    # All entries on one row, centred inside the frame.
    n_columns = max(1, len(handles))
    ax_leg.legend(
        handles=handles,
        loc="center",
        bbox_to_anchor=(0.5, 0.50),
        ncol=n_columns,
        fontsize=LEGEND_FONTSIZE,
        title=title,
        title_fontsize=LEGEND_TITLE_FONTSIZE,
        frameon=False,
        handlelength=1.2,
        handletextpad=0.35,
        columnspacing=0.9,
        borderaxespad=0.0
    )


def draw_map_panel(ax_map, fig, data, y_col, edge_series, cmap, cbar_label, map_letter, gdf, india_outline):
    """Render one bubble map: base map, city bubbles, map furniture and colour bar.

    Each row of *data* is drawn at its longitude/latitude. Both the fill
    colour (through *cmap*) and the bubble size encode ``data[y_col]``, while
    *edge_series* supplies the per-city edge colour. *map_letter* is the
    panel label placed outside the top-left corner.
    """
    for frame_line in ax_map.spines.values():
        frame_line.set_linewidth(1.0)

    draw_india_base(ax_map, gdf, india_outline)

    metric = data[y_col].values
    bubble_sizes = normalize_sizes(metric)
    bubbles = ax_map.scatter(
        data[LON_COL],
        data[LAT_COL],
        s=bubble_sizes,
        c=metric,
        cmap=cmap,
        edgecolor=edge_series,
        linewidths=1.0,
        alpha=0.92,
        zorder=4
    )

    # Fixed map extent; it must be set before the scale bar, which reads
    # the axis limits.
    ax_map.set_xlim((68, 97.5))
    ax_map.set_ylim((6.5, 38.5))

    add_geodesic_scale_bar(ax_map, location=(0.56, 0.08))
    add_north_arrow(ax_map, location=(0.94, 0.84))
    add_colorbar(fig, ax_map, bubbles, cbar_label)

    # Panel letter, placed left of and slightly above the axes.
    ax_map.text(
        -0.20, 1.02, map_letter,
        transform=ax_map.transAxes,
        fontsize=SUBPLOT_LETTER_FONTSIZE,
        fontweight="bold",
        va="top"
    )

    ax_map.set_xlabel("Longitude", fontsize=AXIS_LABEL_FONTSIZE)
    ax_map.set_ylabel("Latitude", fontsize=AXIS_LABEL_FONTSIZE)
    ax_map.tick_params(labelsize=TICK_LABEL_FONTSIZE)
    ax_map.xaxis.set_major_formatter(FuncFormatter(lon_formatter))
    ax_map.yaxis.set_major_formatter(FuncFormatter(lat_formatter))


# ============================================================
# DATA LOADING AND PROCESSING
# ============================================================

def load_and_prepare_data():
    """Load all inputs, average the two CDD scenarios and build the plotting table.

    Steps:
        1. Read the boundary shapefile and the six Excel workbooks; Excel
           headers and city names are normalized for matching.
        2. Check that each table has the columns it needs.
        3. Reproject the boundaries to EPSG:4326 when they declare another CRS.
        4. Outer-join the four scenario tables on city and average the paired
           columns row-wise (a value present in only one scenario is used as is).
        5. Left-join the result, and the updated typology, onto the city
           coordinate table, so only cities with coordinates are kept.

    Returns:
        tuple: ``(gdf, india_outline, merged_df, color_mapping_updated)``.
        ``merged_df`` holds the original typology in ``typology_orig`` and
        the updated typology (falling back to the original where no update
        exists) in ``TYPOLOGY_COL``. ``color_mapping_updated`` maps each
        updated typology label to a palette colour.

    Raises:
        KeyError: if a required column is missing from any table.
    """
    gdf = gpd.read_file(SHAPEFILE_PATH)
    city_details = clean_city_names(standardize_columns(pd.read_excel(CITY_DETAILS_PATH)))
    updated_typo_df = clean_city_names(standardize_columns(pd.read_excel(UPDATED_TYPOLOGY_PATH)))

    cdd1_simple = clean_city_names(standardize_columns(pd.read_excel(CDD_1_SIMPLE_PATH)))
    cdd1_log = clean_city_names(standardize_columns(pd.read_excel(CDD_1_LOG_PATH)))
    cdd2_simple = clean_city_names(standardize_columns(pd.read_excel(CDD_2_SIMPLE_PATH)))
    cdd2_log = clean_city_names(standardize_columns(pd.read_excel(CDD_2_LOG_PATH)))

    check_required_columns(city_details, [CITY_COL, LAT_COL, LON_COL], "city_details")
    check_required_columns(updated_typo_df, [CITY_COL, TYPOLOGY_COL], "updated_typology")
    check_required_columns(cdd1_simple, [CITY_COL, TYPOLOGY_COL, AVG_CDD_COL, TREND_SLOPE_COL], "CDD_1 simple")
    check_required_columns(cdd2_simple, [CITY_COL, TYPOLOGY_COL, AVG_CDD_COL, TREND_SLOPE_COL], "CDD_2 simple")
    check_required_columns(cdd1_log, [CITY_COL, REL_INCREASE_COL], "CDD_1 log")
    check_required_columns(cdd2_log, [CITY_COL, REL_INCREASE_COL], "CDD_2 log")

    # Normalize the updated typology labels (no spaces, upper case). Missing
    # entries must stay missing: converting them to text would create a bogus
    # "NAN" class and block the fallback to the original typology below.
    raw_typology = updated_typo_df[TYPOLOGY_COL]
    updated_typo_df[TYPOLOGY_COL] = (
        raw_typology
        .astype(str)
        .str.replace(" ", "", regex=False)
        .str.upper()
        .where(raw_typology.notna())
    )

    # City coordinates are longitude/latitude, so the boundaries are brought
    # to EPSG:4326. A shapefile without a declared CRS is used as is.
    if gdf.crs is not None and gdf.crs.to_string().lower() != "epsg:4326":
        try:
            gdf = gdf.to_crs(epsg=4326)
        except Exception as exc:
            # Best effort: keep going, but say so. RuntimeWarning is used
            # because this module filters out UserWarning globally.
            warnings.warn(
                f"Could not reproject '{SHAPEFILE_PATH.name}' to EPSG:4326 ({exc}); "
                "boundaries are plotted in their native CRS and may not align "
                "with the city coordinates.",
                RuntimeWarning,
                stacklevel=2,
            )

    # Suffix the per-scenario columns so they survive the joins side by side.
    simple1_use = cdd1_simple[[CITY_COL, TYPOLOGY_COL, AVG_CDD_COL, TREND_SLOPE_COL]].copy()
    simple1_use = simple1_use.rename(columns={
        TYPOLOGY_COL: "typology_1",
        AVG_CDD_COL: "average_cdd_1",
        TREND_SLOPE_COL: "trend_slope_1"
    })

    simple2_use = cdd2_simple[[CITY_COL, TYPOLOGY_COL, AVG_CDD_COL, TREND_SLOPE_COL]].copy()
    simple2_use = simple2_use.rename(columns={
        TYPOLOGY_COL: "typology_2",
        AVG_CDD_COL: "average_cdd_2",
        TREND_SLOPE_COL: "trend_slope_2"
    })

    log1_use = cdd1_log[[CITY_COL, REL_INCREASE_COL]].copy()
    log1_use = log1_use.rename(columns={REL_INCREASE_COL: "relative_increase_per_year_(%)_1"})

    log2_use = cdd2_log[[CITY_COL, REL_INCREASE_COL]].copy()
    log2_use = log2_use.rename(columns={REL_INCREASE_COL: "relative_increase_per_year_(%)_2"})

    # Outer joins keep a city that appears in only some of the four tables.
    avg_df = (
        simple1_use
        .merge(simple2_use, on=CITY_COL, how="outer")
        .merge(log1_use, on=CITY_COL, how="outer")
        .merge(log2_use, on=CITY_COL, how="outer")
    )

    # Typology: scenario 1 wins, scenario 2 fills the gaps. The numeric
    # columns are averaged row-wise across the two scenarios.
    avg_df[TYPOLOGY_COL] = avg_df["typology_1"].combine_first(avg_df["typology_2"])
    avg_df[AVG_CDD_COL] = avg_df[["average_cdd_1", "average_cdd_2"]].mean(axis=1)
    avg_df[TREND_SLOPE_COL] = avg_df[["trend_slope_1", "trend_slope_2"]].mean(axis=1)
    avg_df[REL_INCREASE_COL] = avg_df[
        ["relative_increase_per_year_(%)_1", "relative_increase_per_year_(%)_2"]
    ].mean(axis=1)

    # The coordinate table drives the result: left joins keep exactly its cities.
    merged_df = (
        city_details[[CITY_COL, LAT_COL, LON_COL]]
        .merge(
            avg_df[[CITY_COL, TYPOLOGY_COL, AVG_CDD_COL, TREND_SLOPE_COL, REL_INCREASE_COL]],
            on=CITY_COL,
            how="left"
        )
        .merge(
            updated_typo_df[[CITY_COL, TYPOLOGY_COL]],
            on=CITY_COL,
            how="left",
            suffixes=("", "_updated")
        )
    )

    # Keep the original typology for the left-hand panels, then overwrite the
    # main column with the updated typology wherever one is available.
    merged_df["typology_orig"] = merged_df[TYPOLOGY_COL]
    merged_df[TYPOLOGY_COL] = merged_df["typology_updated"].fillna(merged_df[TYPOLOGY_COL])

    # Assign palette colours in sorted label order. zip() stops at the shorter
    # input, so labels beyond the palette length get no entry here.
    typology_updated_values = sorted(merged_df[TYPOLOGY_COL].dropna().unique())
    color_mapping_updated = dict(
        zip(
            typology_updated_values,
            UPDATED_TYPOLOGY_PALETTE[:len(typology_updated_values)]
        )
    )

    # Merge all boundary features into one geometry for the national outline.
    india_outline = gdf.dissolve()

    return gdf, india_outline, merged_df, color_mapping_updated


# ============================================================
# FIGURE GENERATION
# ============================================================

def create_figure(gdf, india_outline, merged_df, color_mapping_updated):
    """Assemble the 3x2 figure and return it.

    Rows show average CDD, absolute trend and relative trend. The left column
    (panels a-c) groups and colours cities by the original typology with the
    default colours; the right column (panels d-f) uses the updated typology
    with *color_mapping_updated*. Every cell holds a map, a box plot and a
    legend strip below them.
    """
    fig = plt.figure(figsize=(FIG_W, FIG_H), dpi=FIG_DPI)
    outer = fig.add_gridspec(3, 2, wspace=OUTER_WSPACE, hspace=OUTER_HSPACE)

    # One entry per grid row: colour map, data column, colour bar label.
    metrics = [
        (CMAP_AVG, AVG_CDD_COL, "Average CDD (°C-year)"),
        (CMAP_TREND, TREND_SLOPE_COL, "Absolute rise in CDD (°C-year/year)"),
        (CMAP_REL, REL_INCREASE_COL, "Relative rise in CDD (%/year)"),
    ]

    # One entry per grid column: typology column, its colours, panel letters.
    typology_views = [
        (0, "typology_orig", DEFAULT_TYPOLOGY_COLORS, "abc"),
        (1, TYPOLOGY_COL, color_mapping_updated, "def"),
    ]

    # Column-major order: the whole left column is built before the right one.
    for grid_col, group_col, palette, letters in typology_views:
        for grid_row, (cmap, value_col, cbar_label) in enumerate(metrics):
            cell = outer[grid_row, grid_col].subgridspec(
                2, 2,
                height_ratios=PANEL_LEGEND_HEIGHT_RATIO,
                width_ratios=MAP_BOX_WIDTH_RATIO,
                hspace=INNER_HSPACE,
                wspace=INNER_WSPACE
            )

            ax_map = fig.add_subplot(cell[0, 0])
            ax_box = fig.add_subplot(cell[0, 1])
            ax_leg = fig.add_subplot(cell[1, :])

            # Cities whose typology has no colour get a gray edge.
            edge_colors = merged_df[group_col].map(palette).fillna("gray")

            draw_map_panel(
                ax_map, fig, merged_df, value_col, edge_colors,
                cmap, cbar_label, letters[grid_row], gdf, india_outline
            )
            add_horizontal_boxplot(ax_box, merged_df, value_col, group_col, palette)
            add_fullwidth_typology_legend(ax_leg, palette, title="Typology")

    return fig


# ============================================================
# MAIN
# ============================================================

def main():
    """Run the pipeline: validate inputs, build the figure, save and show it.

    The output folder is created if needed. The figure is written as PNG
    (at FIG_DPI) and as PDF, a box plot explanation is printed for use in
    the figure caption, and the figure is finally displayed.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
    validate_files()

    fig = create_figure(*load_and_prepare_data())

    print(
        "Box plots show the median (centre line), interquartile range "
        "(box bounds: 25th–75th percentiles), and whiskers extending to the most "
        "extreme values within 1.5×IQR from the quartiles. All individual observations, "
        "including minima and maxima, are overlaid as scatter points. The sample size (n) "
        "denotes the total number of observations."
    )

    # (label, destination, format-specific savefig options)
    exports = (
        ("PNG", OUTPUT_DIR / OUTPUT_PNG_NAME, {"dpi": FIG_DPI}),
        ("PDF", OUTPUT_DIR / OUTPUT_PDF_NAME, {}),
    )

    for _, destination, extra_options in exports:
        fig.savefig(destination, bbox_inches="tight", facecolor="white", **extra_options)

    for label, destination, _ in exports:
        print(f"Saved {label}: {destination}")

    plt.show()


# Run the pipeline only when this file is executed as a script; importing it
# as a module does not call main().
if __name__ == "__main__":
    main()