{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51e29998-d7c1-43da-b084-51c58a35eae3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import shap\n",
    "import xgboost as xgb\n",
    "\n",
    "from statsmodels.stats.outliers_influence import variance_inflation_factor\n",
    "from sklearn.cross_decomposition import CCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from collections import Counter\n",
    "\n",
    "# ============================================================\n",
    "# Setup\n",
    "# ============================================================\n",
    "excel_file = r'filepath'\n",
    "plot_dir = 'feature_importance_plots'\n",
    "os.makedirs(plot_dir, exist_ok=True)\n",
    "\n",
    "cca_scores = []\n",
    "vif_scores = []\n",
    "shap_scores = []\n",
    "vote_features_all = []\n",
    "\n",
    "excel_data = pd.ExcelFile(excel_file)\n",
    "\n",
    "# ============================================================\n",
    "# LaTeX label formatting\n",
    "# ============================================================\n",
    "def format_label(label):\n",
    "    parts = label.split('_')\n",
    "    formatted = parts[0]\n",
    "    if len(parts) > 1:\n",
    "        formatted += ''.join([f'$_{{{p}}}$' for p in parts[1:]])\n",
    "    formatted = formatted.replace('PM25', 'PM$_{2.5}$')\n",
    "    return r'\\textbf{' + formatted + '}'\n",
    "\n",
    "# ============================================================\n",
    "# Feature analysis loop\n",
    "# ============================================================\n",
    "for sheet_name in excel_data.sheet_names:\n",
    "    print(f\"Processing: {sheet_name}\")\n",
    "\n",
    "    df = pd.read_excel(excel_file, sheet_name=sheet_name)\n",
    "    df['Date'] = pd.to_datetime(df['Date'])\n",
    "    df.set_index('Date', inplace=True)\n",
    "    df.dropna(inplace=True)\n",
    "\n",
    "    # Lag features\n",
    "    for i in range(1, 6):\n",
    "        df[f'PM25_prev_{i}'] = df['PM25'].shift(i)\n",
    "        df[f'P_prev_{i}'] = df['P'].shift(i)\n",
    "        df[f'Tmin_prev_{i}'] = df['Tmin'].shift(i)\n",
    "        df[f'Tmax_prev_{i}'] = df['Tmax'].shift(i)\n",
    "        df[f'VI_prev_{i}'] = df['VI'].shift(i)\n",
    "        df[f'BLH_prev_{i}'] = df['BLH'].shift(i)\n",
    "        df[f'Wind_prev_{i}'] = df['Wind'].shift(i)\n",
    "\n",
    "    df.dropna(inplace=True)\n",
    "    df = df.apply(pd.to_numeric, errors='coerce').dropna()\n",
    "\n",
    "    y = df['PM25']\n",
    "    X = df.drop(columns=['PM25'])\n",
    "\n",
    "    scaler = StandardScaler()\n",
    "    X_scaled = scaler.fit_transform(X)\n",
    "\n",
    "    # ---------------- CCA ----------------\n",
    "    cca = CCA(n_components=1)\n",
    "    cca.fit(X_scaled, y.values.reshape(-1, 1))\n",
    "    cca_loadings = pd.Series(\n",
    "        np.abs(cca.x_weights_[:, 0]), index=X.columns\n",
    "    )\n",
    "    cca_scores.append(cca_loadings)\n",
    "\n",
    "    # ---------------- VIF ----------------\n",
    "    vif = pd.Series(\n",
    "        [variance_inflation_factor(X_scaled, i)\n",
    "         for i in range(X_scaled.shape[1])],\n",
    "        index=X.columns\n",
    "    )\n",
    "    vif_scores.append(vif)\n",
    "\n",
    "    # ---------------- SHAP ----------------\n",
    "    X_train, X_test, y_train, y_test = train_test_split(\n",
    "        X, y, test_size=0.2, random_state=42\n",
    "    )\n",
    "\n",
    "    model = xgb.XGBRegressor(\n",
    "        n_estimators=100,\n",
    "        learning_rate=0.05,\n",
    "        random_state=42\n",
    "    )\n",
    "    model.fit(X_train, y_train)\n",
    "\n",
    "    explainer = shap.Explainer(model)\n",
    "    shap_values = explainer(X_test)\n",
    "    shap_importance = pd.Series(\n",
    "        np.abs(shap_values.values).mean(axis=0),\n",
    "        index=X.columns\n",
    "    )\n",
    "    shap_scores.append(shap_importance)\n",
    "\n",
    "    # ---------------- Voting ----------------\n",
    "    top_cca = cca_loadings.nlargest(10).index.tolist()\n",
    "    top_vif = vif.nsmallest(10).index.tolist()\n",
    "    top_shap = shap_importance.nlargest(10).index.tolist()\n",
    "\n",
    "    combined = top_cca + top_vif + top_shap\n",
    "    vote = Counter(combined)\n",
    "\n",
    "    top_vote = [f for f, _ in vote.most_common(10)]\n",
    "    vote_features_all.extend(top_vote)\n",
    "\n",
    "# ============================================================\n",
    "# Aggregate scores\n",
    "# ============================================================\n",
    "cca_avg = pd.concat(cca_scores, axis=1).mean(axis=1).sort_values(ascending=False)\n",
    "vif_avg = pd.concat(vif_scores, axis=1).mean(axis=1).sort_values()\n",
    "shap_avg = pd.concat(shap_scores, axis=1).mean(axis=1).sort_values(ascending=False)\n",
    "vote_counter = Counter(vote_features_all)\n",
    "\n",
    "# ============================================================\n",
    "# Global plotting style (BOLD + LARGE)\n",
    "# ============================================================\n",
    "plt.rcParams.update({\n",
    "    'text.usetex': True,\n",
    "    'font.size': 24,\n",
    "    'axes.labelweight': 'bold',\n",
    "    'axes.titleweight': 'bold',\n",
    "    'axes.titlesize': 24,\n",
    "    'axes.labelsize': 24,\n",
    "    'xtick.labelsize': 22,\n",
    "    'ytick.labelsize': 22,\n",
    "    'xtick.major.width': 2,\n",
    "    'ytick.major.width': 2\n",
    "})\n",
    "\n",
    "# ============================================================\n",
    "# Generic plotting function\n",
    "# ============================================================\n",
    "def plot_feature_importance(all_scores, top10_index, title, xlabel,\n",
    "                            filename, color='skyblue', highlight='crimson'):\n",
    "\n",
    "    plt.figure(figsize=(20, 20))\n",
    "\n",
    "    features = all_scores.index\n",
    "    values = all_scores.values\n",
    "    colors = [highlight if f in top10_index else color for f in features]\n",
    "    y_labels = [format_label(f) for f in features]\n",
    "\n",
    "    ax = sns.barplot(x=values, y=y_labels, palette=colors)\n",
    "\n",
    "    ax.set_title(r'\\textbf{' + title + '}', fontsize=26)\n",
    "    ax.set_xlabel(r'\\textbf{' + xlabel + '}', fontsize=26)\n",
    "    ax.set_ylabel(r'\\textbf{Features}', fontsize=26)\n",
    "\n",
    "    for label in ax.get_xticklabels():\n",
    "        label.set_fontweight('bold')\n",
    "    for label in ax.get_yticklabels():\n",
    "        label.set_fontweight('bold')\n",
    "\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(os.path.join(plot_dir, filename), dpi=300)\n",
    "    plt.show()\n",
    "\n",
    "# ============================================================\n",
    "# Plot 1: CCA\n",
    "# ============================================================\n",
    "plot_feature_importance(\n",
    "    cca_avg,\n",
    "    cca_avg.nlargest(10).index,\n",
    "    'Top Features by Canonical Correlation Analysis (CCA)',\n",
    "    'Canonical Loading',\n",
    "    'CCA_Top_All_Features.png'\n",
    ")\n",
    "\n",
    "# ============================================================\n",
    "# Plot 2: VIF\n",
    "# ============================================================\n",
    "plot_feature_importance(\n",
    "    vif_avg,\n",
    "    vif_avg.nsmallest(10).index,\n",
    "    'Top Features by Variance Inflation Factor (VIF)',\n",
    "    'VIF Score',\n",
    "    'VIF_Top_All_Features.png',\n",
    "    color='lightgray',\n",
    "    highlight='royalblue'\n",
    ")\n",
    "\n",
    "# ============================================================\n",
    "# Plot 3: SHAP\n",
    "# ============================================================\n",
    "plot_feature_importance(\n",
    "    shap_avg,\n",
    "    shap_avg.nlargest(10).index,\n",
    "    'Top Features by SHAP',\n",
    "    'Mean Absolute SHAP Value',\n",
    "    'SHAP_Top_All_Features.png',\n",
    "    color='lightgreen',\n",
    "    highlight='darkgreen'\n",
    ")\n",
    "\n",
    "# ============================================================\n",
    "# Plot 4: Voting\n",
    "# ============================================================\n",
    "vote_series = pd.Series(vote_counter).sort_values(ascending=False)\n",
    "\n",
    "features = vote_series.index\n",
    "votes = vote_series.values\n",
    "top10 = vote_series.head(10).index\n",
    "\n",
    "colors = ['darkorange' if f in top10 else 'lightgray' for f in features]\n",
    "y_labels = [format_label(f) for f in features]\n",
    "\n",
    "plt.figure(figsize=(20, 20))\n",
    "ax = sns.barplot(x=votes, y=y_labels, palette=colors)\n",
    "\n",
    "ax.set_title(r'\\textbf{Top Features by Voting (CCA + VIF + SHAP)}', fontsize=22)\n",
    "ax.set_xlabel(r'\\textbf{Vote Count}', fontsize=30)\n",
    "ax.set_ylabel(r'\\textbf{Features}', fontsize=30)\n",
    "\n",
    "for label in ax.get_xticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "for label in ax.get_yticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.spines['right'].set_visible(False)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(os.path.join(plot_dir, 'Voting_Top_All_Features.png'), dpi=300)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f66511ee-5314-42c2-8088-3307856a0e91",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import pywt\n",
    "import pickle\n",
    "import warnings\n",
    "from scipy import stats\n",
    "from scipy.stats import gaussian_kde\n",
    "from datetime import datetime\n",
    "\n",
    "# Deep Learning Libraries\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.models import Model\n",
    "from tensorflow.keras.layers import Input, Conv1D, LSTM, GRU, Dense, Dropout, BatchNormalization, MaxPooling1D, Flatten\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau\n",
    "from tensorflow.keras.regularizers import l1_l2\n",
    "\n",
    "# ML Libraries\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Fix matplotlib LaTeX issues and set global font properties\n",
    "plt.rcParams['text.usetex'] = False\n",
    "plt.rcParams['font.family'] = 'sans-serif'\n",
    "plt.rcParams['mathtext.default'] = 'regular'\n",
    "plt.rcParams['axes.labelweight'] = 'bold'\n",
    "plt.rcParams['axes.titleweight'] = 'bold'\n",
    "plt.rcParams['font.weight'] = 'bold'\n",
    "\n",
    "# ---- Setup ----\n",
    "excel_file = r'file path'\n",
    "station_data_file = r'Efile path'  # UPDATE THIS PATH\n",
    "base_plot_dir = 'city_model_plots'\n",
    "os.makedirs(base_plot_dir, exist_ok=True)\n",
    "\n",
    "# Global settings\n",
    "RANDOM_STATE = 42\n",
    "np.random.seed(RANDOM_STATE)\n",
    "tf.random.set_seed(RANDOM_STATE)\n",
    "\n",
    "# ---- Model Building Functions ----\n",
    "def build_cnn_model(input_shape):\n",
    "    \"\"\"Build CNN model with fixed hyperparameters\"\"\"\n",
    "    inputs = Input(shape=input_shape)\n",
    "    \n",
    "    x = Conv1D(filters=64, kernel_size=3, activation='relu', padding='same',\n",
    "               kernel_regularizer=l1_l2(l1=1e-5, l2=1e-4))(inputs)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = MaxPooling1D(pool_size=2, padding='same')(x)\n",
    "    x = Dropout(0.3)(x)\n",
    "    \n",
    "    x = Conv1D(filters=128, kernel_size=3, activation='relu', padding='same',\n",
    "               kernel_regularizer=l1_l2(l1=1e-5, l2=1e-4))(x)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = MaxPooling1D(pool_size=2, padding='same')(x)\n",
    "    x = Dropout(0.3)(x)\n",
    "    \n",
    "    x = Conv1D(filters=256, kernel_size=3, activation='relu', padding='same',\n",
    "               kernel_regularizer=l1_l2(l1=1e-5, l2=1e-4))(x)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = Dropout(0.3)(x)\n",
    "    \n",
    "    x = Flatten()(x)\n",
    "    x = Dense(256, activation='relu', kernel_regularizer=l1_l2(l1=1e-5, l2=1e-4))(x)\n",
    "    x = Dropout(0.3)(x)\n",
    "    x = Dense(128, activation='relu', kernel_regularizer=l1_l2(l1=1e-5, l2=1e-4))(x)\n",
    "    x = Dropout(0.3)(x)\n",
    "    outputs = Dense(1)(x)\n",
    "    \n",
    "    model = Model(inputs, outputs)\n",
    "    model.compile(optimizer=Adam(learning_rate=0.001), loss='mse', metrics=['mae'])\n",
    "    return model\n",
    "\n",
    "def build_lstm_model(input_shape):\n",
    "    \"\"\"Build LSTM model with fixed hyperparameters\"\"\"\n",
    "    inputs = Input(shape=input_shape)\n",
    "    \n",
    "    x = LSTM(100, return_sequences=True, kernel_regularizer=l1_l2(l1=1e-5, l2=1e-4))(inputs)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = Dropout(0.3)(x)\n",
    "    \n",
    "    x = LSTM(100, return_sequences=False, kernel_regularizer=l1_l2(l1=1e-5, l2=1e-4))(x)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = Dropout(0.3)(x)\n",
    "    \n",
    "    x = Dense(128, activation='relu', kernel_regularizer=l1_l2(l1=1e-5, l2=1e-4))(x)\n",
    "    x = Dropout(0.3)(x)\n",
    "    outputs = Dense(1)(x)\n",
    "    \n",
    "    model = Model(inputs, outputs)\n",
    "    model.compile(optimizer=Adam(learning_rate=0.001), loss='mse', metrics=['mae'])\n",
    "    return model\n",
    "\n",
    "def build_gru_model(input_shape):\n",
    "    \"\"\"Build GRU model with fixed hyperparameters\"\"\"\n",
    "    inputs = Input(shape=input_shape)\n",
    "    \n",
    "    x = GRU(100, return_sequences=True, kernel_regularizer=l1_l2(l1=1e-5, l2=1e-4))(inputs)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = Dropout(0.3)(x)\n",
    "    \n",
    "    x = GRU(100, return_sequences=False, kernel_regularizer=l1_l2(l1=1e-5, l2=1e-4))(x)\n",
    "    x = BatchNormalization()(x)\n",
    "    x = Dropout(0.3)(x)\n",
    "    \n",
    "    x = Dense(128, activation='relu', kernel_regularizer=l1_l2(l1=1e-5, l2=1e-4))(x)\n",
    "    x = Dropout(0.3)(x)\n",
    "    outputs = Dense(1)(x)\n",
    "    \n",
    "    model = Model(inputs, outputs)\n",
    "    model.compile(optimizer=Adam(learning_rate=0.001), loss='mse', metrics=['mae'])\n",
    "    return model\n",
    "\n",
    "# ---- Ensemble Class ----\n",
    "class SimpleEnsemble:\n",
    "    def __init__(self):\n",
    "        self.weights = {}\n",
    "        self.models = {}\n",
    "    \n",
    "    def fit(self, models, X_val, y_val):\n",
    "        \"\"\"Fit ensemble weights based on validation performance\"\"\"\n",
    "        self.models = models\n",
    "        performances = {}\n",
    "        \n",
    "        for name, model in models.items():\n",
    "            pred = model.predict(X_val, verbose=0).flatten()\n",
    "            r2 = r2_score(y_val, pred)\n",
    "            mse = mean_squared_error(y_val, pred)\n",
    "            nse = 1 - (np.sum((y_val - pred) ** 2) / np.sum((y_val - np.mean(y_val)) ** 2))\n",
    "            \n",
    "            performance_score = r2 * max(nse, 0) / (mse + 1e-8)\n",
    "            performances[name] = performance_score\n",
    "        \n",
    "        total_performance = sum(performances.values())\n",
    "        if total_performance > 0:\n",
    "            self.weights = {name: perf/total_performance for name, perf in performances.items()}\n",
    "        else:\n",
    "            self.weights = {name: 1.0/len(models) for name in models.keys()}\n",
    "        \n",
    "        print(\"Ensemble weights:\")\n",
    "        for name, weight in self.weights.items():\n",
    "            print(f\"{name}: {weight:.4f}\")\n",
    "    \n",
    "    def predict(self, X):\n",
    "        \"\"\"Make ensemble predictions\"\"\"\n",
    "        predictions = np.zeros(X.shape[0])\n",
    "        \n",
    "        for name, model in self.models.items():\n",
    "            pred = model.predict(X, verbose=0).flatten()\n",
    "            predictions += self.weights[name] * pred\n",
    "        \n",
    "        return predictions\n",
    "\n",
    "# ---- Bias Correction Functions ----\n",
    "def load_station_data(station_file, city_name):\n",
    "    \"\"\"Load station observation data for bias correction\"\"\"\n",
    "    try:\n",
    "        station_df = pd.read_excel(station_file)\n",
    "        station_df['Date'] = pd.to_datetime(station_df['Date'], format='%m/%d/%Y')\n",
    "        station_df.set_index('Date', inplace=True)\n",
    "        \n",
    "        if city_name in station_df.columns:\n",
    "            print(f\"Station data loaded for {city_name}\")\n",
    "            print(f\"Date range: {station_df.index.min()} to {station_df.index.max()}\")\n",
    "            print(f\"Total observations: {len(station_df[city_name].dropna())}\")\n",
    "            return station_df[city_name].dropna()\n",
    "        else:\n",
    "            print(f\"Warning: {city_name} not found in station data\")\n",
    "            return None\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading station data: {str(e)}\")\n",
    "        return None\n",
    "\n",
    "def calculate_bias_correction(predictions, observations, dates_pred, dates_obs):\n",
    "    \"\"\"Calculate bias correction factor using overlapping period\"\"\"\n",
    "    # Find overlapping dates\n",
    "    common_dates = dates_pred.intersection(dates_obs)\n",
    "    \n",
    "    if len(common_dates) == 0:\n",
    "        print(\"Warning: No overlapping dates found between predictions and observations\")\n",
    "        return 0.0, None\n",
    "    \n",
    "    print(f\"Overlapping period: {common_dates.min()} to {common_dates.max()}\")\n",
    "    print(f\"Number of overlapping days: {len(common_dates)}\")\n",
    "    \n",
    "    # Get predictions and observations for common dates\n",
    "    pred_indices = [i for i, date in enumerate(dates_pred) if date in common_dates]\n",
    "    pred_overlap = predictions[pred_indices]\n",
    "    obs_overlap = observations.loc[common_dates].values\n",
    "    \n",
    "    # Calculate bias (predicted - observed)\n",
    "    bias = pred_overlap - obs_overlap\n",
    "    mean_bias = np.mean(bias)\n",
    "    median_bias = np.median(bias)\n",
    "    std_bias = np.std(bias)\n",
    "    \n",
    "    # Calculate relative bias\n",
    "    relative_bias = 100 * mean_bias / np.mean(obs_overlap)\n",
    "    \n",
    "    bias_stats = {\n",
    "        'mean_bias': mean_bias,\n",
    "        'median_bias': median_bias,\n",
    "        'std_bias': std_bias,\n",
    "        'relative_bias': relative_bias,\n",
    "        'n_samples': len(bias),\n",
    "        'date_range': f\"{common_dates.min()} to {common_dates.max()}\"\n",
    "    }\n",
    "    \n",
    "    print(f\"Bias Statistics:\")\n",
    "    print(f\"  Mean Bias: {mean_bias:.3f} μg/m³\")\n",
    "    print(f\"  Median Bias: {median_bias:.3f} μg/m³\")\n",
    "    print(f\"  Std Bias: {std_bias:.3f} μg/m³\")\n",
    "    print(f\"  Relative Bias: {relative_bias:.2f}%\")\n",
    "    \n",
    "    return mean_bias, bias_stats\n",
    "\n",
    "def apply_bias_correction(predictions, bias_correction):\n",
    "    \"\"\"Apply bias correction to predictions\"\"\"\n",
    "    corrected_predictions = predictions - bias_correction\n",
    "    corrected_predictions = np.maximum(corrected_predictions, 0)\n",
    "    return corrected_predictions\n",
    "\n",
    "# ---- Uncertainty Quantification ----\n",
    "def calculate_prediction_intervals(predictions, residuals, confidence=0.95):\n",
    "    \"\"\"Calculate prediction intervals using residual statistics\"\"\"\n",
    "    if len(residuals) == 0:\n",
    "        return np.array([]), np.array([])\n",
    "    alpha = 1 - confidence\n",
    "    residual_std = np.std(residuals)\n",
    "    z_score = stats.norm.ppf(1 - alpha/2)\n",
    "    \n",
    "    lower_bound = predictions - z_score * residual_std\n",
    "    upper_bound = predictions + z_score * residual_std\n",
    "    \n",
    "    return lower_bound, upper_bound\n",
    "\n",
    "# ---- Visualization Functions ----\n",
    "def plot_scatter_comparison(y_true, y_pred, city_name, model_name, save_path):\n",
    "    \"\"\"Create comprehensive scatter visualization\"\"\"\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
    "    fig.suptitle(f'PM$_{{2.5}}$ Prediction Analysis - {city_name} ({model_name})',\n",
    "                 fontsize=16, fontweight='bold')\n",
    "    \n",
    "    for ax in axes.flat:\n",
    "        ax.grid(False)\n",
    "    \n",
    "    axes[0, 0].scatter(y_true, y_pred, alpha=0.6, s=30, c='blue')\n",
    "    axes[0, 0].plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],\n",
    "                    'r--', lw=2, alpha=0.8, label='1:1 Line')\n",
    "    axes[0, 0].set_xlabel('Actual PM$_{{2.5}}$ (μg/m³)', fontweight='bold')\n",
    "    axes[0, 0].set_ylabel('Predicted PM$_{{2.5}}$ (μg/m³)', fontweight='bold')\n",
    "    axes[0, 0].set_title('(a) Scatter Plot', fontweight='bold')\n",
    "    axes[0, 0].legend()\n",
    "    axes[0, 0].tick_params(axis='both', which='major', labelsize=12, width=2)\n",
    "    \n",
    "    hb = axes[0, 1].hexbin(y_true, y_pred, gridsize=25, cmap='Blues', mincnt=1)\n",
    "    axes[0, 1].plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],\n",
    "                    'r--', lw=2, alpha=0.8, label='1:1 Line')\n",
    "    axes[0, 1].set_xlabel('Actual PM$_{{2.5}}$ (μg/m³)', fontweight='bold')\n",
    "    axes[0, 1].set_ylabel('Predicted PM$_{{2.5}}$ (μg/m³)', fontweight='bold')\n",
    "    axes[0, 1].set_title('(b) Density Plot', fontweight='bold')\n",
    "    axes[0, 1].legend()\n",
    "    axes[0, 1].tick_params(axis='both', which='major', labelsize=12, width=2)\n",
    "    plt.colorbar(hb, ax=axes[0, 1], label='Count')\n",
    "    \n",
    "    residuals = y_true - y_pred\n",
    "    axes[1, 0].scatter(y_pred, residuals, alpha=0.6, s=30, c='red')\n",
    "    axes[1, 0].axhline(y=0, color='black', linestyle='--', alpha=0.8)\n",
    "    axes[1, 0].set_xlabel('Predicted PM$_{{2.5}}$ (μg/m³)', fontweight='bold')\n",
    "    axes[1, 0].set_ylabel('Residuals (μg/m³)', fontweight='bold')\n",
    "    axes[1, 0].set_title('(c) Residual Analysis', fontweight='bold')\n",
    "    axes[1, 0].tick_params(axis='both', which='major', labelsize=12, width=2)\n",
    "    \n",
    "    axes[1, 1].axis('off')\n",
    "    \n",
    "    r2 = r2_score(y_true, y_pred)\n",
    "    rmse = np.sqrt(mean_squared_error(y_true, y_pred))\n",
    "    mae = mean_absolute_error(y_true, y_pred)\n",
    "    nse = 1 - (np.sum((y_true - y_pred) ** 2) / np.sum((y_true - np.mean(y_true)) ** 2))\n",
    "    pbias = 100 * (np.sum(y_true - y_pred) / np.sum(y_true))\n",
    "    r = np.corrcoef(y_true, y_pred)[0, 1] if len(np.unique(y_pred)) > 1 else 0\n",
    "    \n",
    "    metrics_text = f\"\"\"\n",
    "    Performance Metrics:\n",
    "    \n",
    "    R² = {r2:.3f}\n",
    "    RMSE = {rmse:.2f} μg/m³\n",
    "    MAE = {mae:.2f} μg/m³\n",
    "    NSE = {nse:.3f}\n",
    "    PBIAS = {pbias:.2f}%\n",
    "    r = {r:.3f}\n",
    "    \"\"\"\n",
    "    \n",
    "    axes[1, 1].text(0.1, 0.9, metrics_text, transform=axes[1, 1].transAxes,\n",
    "                    verticalalignment='top', fontsize=14, fontweight='bold',\n",
    "                    bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "\n",
    "def plot_bias_correction_comparison(y_true, y_pred_original, y_pred_corrected, \n",
    "                                   city_name, save_path):\n",
    "    \"\"\"Plot comparison of original vs bias-corrected predictions\"\"\"\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
    "    fig.suptitle(f'Bias Correction Impact - {city_name}', fontsize=16, fontweight='bold')\n",
    "    \n",
    "    for ax in axes.flat:\n",
    "        ax.grid(False)\n",
    "    \n",
    "    axes[0, 0].scatter(y_true, y_pred_original, alpha=0.6, s=30, c='red', label='Original')\n",
    "    axes[0, 0].plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], \n",
    "                    'k--', lw=2, alpha=0.8, label='1:1 Line')\n",
    "    axes[0, 0].set_xlabel('Actual PM$_{{2.5}}$ (μg/m³)', fontweight='bold')\n",
    "    axes[0, 0].set_ylabel('Predicted PM$_{{2.5}}$ (μg/m³)', fontweight='bold')\n",
    "    axes[0, 0].set_title('(a) Original Predictions', fontweight='bold')\n",
    "    axes[0, 0].legend()\n",
    "    \n",
    "    axes[0, 1].scatter(y_true, y_pred_corrected, alpha=0.6, s=30, c='green', label='Corrected')\n",
    "    axes[0, 1].plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], \n",
    "                    'k--', lw=2, alpha=0.8, label='1:1 Line')\n",
    "    axes[0, 1].set_xlabel('Actual PM$_{{2.5}}$ (μg/m³)', fontweight='bold')\n",
    "    axes[0, 1].set_ylabel('Predicted PM$_{{2.5}}$ (μg/m³)', fontweight='bold')\n",
    "    axes[0, 1].set_title('(b) Bias-Corrected Predictions', fontweight='bold')\n",
    "    axes[0, 1].legend()\n",
    "    \n",
    "    time_index = range(len(y_true))\n",
    "    axes[1, 0].plot(time_index, y_true, 'b-', label='Actual', linewidth=2, alpha=0.8)\n",
    "    axes[1, 0].plot(time_index, y_pred_original, 'r--', label='Original', linewidth=1.5, alpha=0.7)\n",
    "    axes[1, 0].plot(time_index, y_pred_corrected, 'g--', label='Corrected', linewidth=1.5, alpha=0.7)\n",
    "    axes[1, 0].set_xlabel('Time Index', fontweight='bold')\n",
    "    axes[1, 0].set_ylabel('PM$_{{2.5}}$ (μg/m³)', fontweight='bold')\n",
    "    axes[1, 0].set_title('(c) Time Series Comparison', fontweight='bold')\n",
    "    axes[1, 0].legend()\n",
    "    \n",
    "    axes[1, 1].axis('off')\n",
    "    \n",
    "    r2_orig = r2_score(y_true, y_pred_original)\n",
    "    rmse_orig = np.sqrt(mean_squared_error(y_true, y_pred_original))\n",
    "    mae_orig = mean_absolute_error(y_true, y_pred_original)\n",
    "    \n",
    "    r2_corr = r2_score(y_true, y_pred_corrected)\n",
    "    rmse_corr = np.sqrt(mean_squared_error(y_true, y_pred_corrected))\n",
    "    mae_corr = mean_absolute_error(y_true, y_pred_corrected)\n",
    "    \n",
    "    metrics_text = f\"\"\"\n",
    "    Metrics Comparison:\n",
    "    \n",
    "    Original:\n",
    "      R² = {r2_orig:.3f}\n",
    "      RMSE = {rmse_orig:.2f} μg/m³\n",
    "      MAE = {mae_orig:.2f} μg/m³\n",
    "    \n",
    "    Bias-Corrected:\n",
    "      R² = {r2_corr:.3f}\n",
    "      RMSE = {rmse_corr:.2f} μg/m³\n",
    "      MAE = {mae_corr:.2f} μg/m³\n",
    "    \n",
    "    Improvement:\n",
    "      ΔR² = {r2_corr - r2_orig:+.3f}\n",
    "      ΔRMSE = {rmse_corr - rmse_orig:+.2f} μg/m³\n",
    "      ΔMAE = {mae_corr - mae_orig:+.2f} μg/m³\n",
    "    \"\"\"\n",
    "    \n",
    "    axes[1, 1].text(0.1, 0.9, metrics_text, transform=axes[1, 1].transAxes,\n",
    "                    verticalalignment='top', fontsize=12, fontweight='bold',\n",
    "                    bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "\n",
    "def plot_complete_time_series_with_uncertainty(dates, y_true, y_pred_all, train_end_idx, val_end_idx,\n",
    "                                             y_train_pred, y_val_pred, y_test_pred,\n",
    "                                             y_train_true, y_val_true, y_test_true,\n",
    "                                             city_name, save_path):\n",
    "    \"\"\"Plot complete time series with train/val/test splits and uncertainty bands\"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(20, 10))\n",
    "    ax.grid(False)\n",
    "    \n",
    "    ax.plot(dates, y_true, label='Actual PM$_{2.5}$', color='blue', linewidth=2, alpha=0.8)\n",
    "    \n",
    "    ax.plot(dates[:train_end_idx], y_train_pred,\n",
    "            label='Training Predictions', color='green', linewidth=2, alpha=0.7)\n",
    "    ax.plot(dates[train_end_idx:val_end_idx], y_val_pred,\n",
    "            label='Validation Predictions', color='orange', linewidth=2, alpha=0.7)\n",
    "    ax.plot(dates[val_end_idx:], y_test_pred,\n",
    "            label='Test Predictions', color='red', linewidth=2, alpha=0.7)\n",
    "    \n",
    "    residuals_train = y_train_true - y_train_pred\n",
    "    residuals_val = y_val_true - y_val_pred\n",
    "    residuals_test = y_test_true - y_test_pred\n",
    "    \n",
    "    lower_train, upper_train = calculate_prediction_intervals(y_train_pred, residuals_train)\n",
    "    lower_val, upper_val = calculate_prediction_intervals(y_val_pred, residuals_val)\n",
    "    lower_test, upper_test = calculate_prediction_intervals(y_test_pred, residuals_test)\n",
    "    \n",
    "    if len(dates[:train_end_idx]) > 0 and len(lower_train) > 0:\n",
    "        ax.fill_between(dates[:train_end_idx], lower_train, upper_train,\n",
    "                        alpha=0.2, color='lightgreen', label='95% PI (Training)')\n",
    "    if len(dates[train_end_idx:val_end_idx]) > 0 and len(lower_val) > 0:\n",
    "        ax.fill_between(dates[train_end_idx:val_end_idx], lower_val, upper_val,\n",
    "                        alpha=0.2, color='moccasin', label='95% PI (Validation)')\n",
    "    if len(dates[val_end_idx:]) > 0 and len(lower_test) > 0:\n",
    "        ax.fill_between(dates[val_end_idx:], lower_test, upper_test,\n",
    "                        alpha=0.2, color='lightcoral', label='95% PI (Test)')\n",
    "    \n",
    "    ax.axvline(x=dates[train_end_idx], color='black', linestyle='--', alpha=0.5, linewidth=1)\n",
    "    ax.axvline(x=dates[val_end_idx], color='black', linestyle='--', alpha=0.5, linewidth=1)\n",
    "    \n",
    "    if train_end_idx > 0:\n",
    "        ax.text(dates[train_end_idx//2], ax.get_ylim()[1]*0.9, 'Training',\n",
    "                ha='center', va='top', fontsize=14, fontweight='bold',\n",
    "                bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.7))\n",
    "    if val_end_idx - train_end_idx > 0:\n",
    "        ax.text(dates[train_end_idx + (val_end_idx - train_end_idx)//2], ax.get_ylim()[1]*0.9, 'Validation',\n",
    "                ha='center', va='top', fontsize=14, fontweight='bold',\n",
    "                bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.7))\n",
    "    if len(dates) - val_end_idx > 0:\n",
    "        ax.text(dates[val_end_idx + (len(dates) - val_end_idx)//2], ax.get_ylim()[1]*0.9, 'Test',\n",
    "                ha='center', va='top', fontsize=14, fontweight='bold',\n",
    "                bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.7))\n",
    "    \n",
    "    ax.set_xlabel('Year', fontweight='bold', fontsize=14)\n",
    "    ax.set_ylabel('PM$_{2.5}$ (μg/m³)', fontweight='bold', fontsize=14)\n",
    "    ax.set_title(f'Complete PM$_{{2.5}}$ Time Series Prediction - {city_name}',\n",
    "                fontweight='bold', fontsize=16)\n",
    "    ax.legend(fontsize=12, loc='upper left')\n",
    "    ax.tick_params(axis='both', which='major', labelsize=12, width=2)\n",
    "    ax.set_xlim(pd.to_datetime('1980-01-01'), pd.to_datetime('2023-12-31'))\n",
    "    ax.tick_params(axis='x', rotation=45)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "\n",
    "def plot_model_comparison(results, city_name, save_path):\n",
    "    \"\"\"Plot model comparison\"\"\"\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(15, 10))\n",
    "    fig.suptitle(f'Model Performance Comparison - {city_name}', fontsize=16, fontweight='bold')\n",
    "    \n",
    "    for ax in axes.flat:\n",
    "        ax.grid(False)\n",
    "    \n",
    "    model_names = list(results.keys())\n",
    "    metrics = ['R²', 'RMSE', 'NSE', 'PBIAS']\n",
    "    colors = ['skyblue', 'lightcoral', 'lightgreen', 'orange']\n",
    "    \n",
    "    for i, metric in enumerate(metrics):\n",
    "        ax = axes[i//2, i%2]\n",
    "        values = [results[name][metric] for name in model_names]\n",
    "        \n",
    "        bars = ax.bar(range(len(model_names)), values,\n",
    "                     color=[colors[i] if 'Corrected' not in name else 'purple' for name in model_names],\n",
    "                     alpha=0.7, edgecolor='black', linewidth=2)\n",
    "        \n",
    "        ax.set_xlabel('Models', fontweight='bold', fontsize=12)\n",
    "        ax.set_ylabel(f'{metric}', fontweight='bold', fontsize=12)\n",
    "        ax.set_title(f'{metric} Comparison', fontweight='bold', fontsize=14)\n",
    "        ax.set_xticks(range(len(model_names)))\n",
    "        ax.set_xticklabels(model_names, rotation=45, ha='right', fontweight='bold')\n",
    "        ax.tick_params(axis='both', which='major', labelsize=11, width=2)\n",
    "        \n",
    "        for bar, value in zip(bars, values):\n",
    "            height = bar.get_height()\n",
    "            ax.text(bar.get_x() + bar.get_width()/2., height,\n",
    "                   f'{value:.3f}', ha='center', va='bottom', fontweight='bold', fontsize=10)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig(save_path, dpi=300, bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "\n",
    "# ---- Main Processing Function ----\n",
    "def process_city_data(city_name, sheet_name):\n",
    "    \"\"\"Process data for a given city with bias correction\"\"\"\n",
    "    print(f\"\\nProcessing {city_name} City Data\")\n",
    "    print(\"=\" * 50)\n",
    "    \n",
    "    try:\n",
    "        df = pd.read_excel(excel_file, sheet_name=sheet_name)\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading data for {city_name}: {e}\")\n",
    "        return None\n",
    "    \n",
    "    if 'Date' not in df.columns:\n",
    "        print(f\"Error: 'Date' column not found. Skipping.\")\n",
    "        return None\n",
    "    \n",
    "    df['Date'] = pd.to_datetime(df['Date'])\n",
    "    df.set_index('Date', inplace=True)\n",
    "    df = df.loc[(df.index >= '1980-01-01') & (df.index <= '2023-12-31')]\n",
    "    df.dropna(inplace=True)\n",
    "    \n",
    "    if df.empty:\n",
    "        print(f\"Warning: Empty dataframe for {city_name}. Skipping.\")\n",
    "        return None\n",
    "    \n",
    "    print(f\"Data shape: {df.shape}\")\n",
    "    print(f\"Date range: {df.index.min()} to {df.index.max()}\")\n",
    "    \n",
    "    for i in range(1, 6):\n",
    "        for col in ['PM25', 'P', 'Tmin', 'Tmax', 'VI', 'BLH', 'Wind']:\n",
    "            if col in df.columns:\n",
    "                df[f'{col}_prev_{i}'] = df[col].shift(i)\n",
    "    \n",
    "    df.dropna(inplace=True)\n",
    "    df = df.apply(pd.to_numeric, errors='coerce').dropna()\n",
    "    \n",
    "    if df.empty:\n",
    "        print(f\"Warning: Empty after feature engineering. Skipping.\")\n",
    "        return None\n",
    "    \n",
    "    print(f\"After feature engineering: {df.shape}\")\n",
    "    final_dates = df.index.copy()\n",
    "    \n",
    "    if 'PM25' not in df.columns:\n",
    "        print(f\"Error: PM25 column missing. Skipping.\")\n",
    "        return None\n",
    "    \n",
    "    try:\n",
    "        wavelet = 'db4'\n",
    "        data_len = len(df['PM25'])\n",
    "        if data_len == 0:\n",
    "            df['PM25_wavelet'] = df['PM25']\n",
    "        else:\n",
    "            max_level = pywt.dwt_max_level(data_len, wavelet)\n",
    "            level = min(3, max_level)\n",
    "            coeffs = pywt.wavedec(df['PM25'], wavelet, level=level)\n",
    "            for i in range(1, len(coeffs)):\n",
    "                coeffs[i] = np.zeros_like(coeffs[i])\n",
    "            pm25_wavelet = pywt.waverec(coeffs, wavelet)\n",
    "            df['PM25_wavelet'] = pm25_wavelet[:len(df)]\n",
    "    except ValueError as e:\n",
    "        print(f\"Wavelet error: {e}. Using original.\")\n",
    "        df['PM25_wavelet'] = df['PM25']\n",
    "    \n",
    "    feature_cols = ['PM25_wavelet'] + [f'PM25_prev_{i}' for i in range(1, 6)]\n",
    "    feature_cols = [col for col in feature_cols if col in df.columns]\n",
    "    \n",
    "    X = df[feature_cols].values\n",
    "    y = df['PM25'].values\n",
    "    \n",
    "    if X.shape[0] < 10:\n",
    "        print(f\"Insufficient data. Skipping.\")\n",
    "        return None\n",
    "    \n",
    "    total_samples = X.shape[0]\n",
    "    test_size = int(total_samples * 0.1)\n",
    "    val_size = int(total_samples * 0.2)\n",
    "    train_size = total_samples - test_size - val_size\n",
    "    \n",
    "    if train_size <= 0 or val_size <= 0 or test_size <= 0:\n",
    "        print(f\"Cannot create valid splits. Skipping.\")\n",
    "        return None\n",
    "    \n",
    "    X_train = X[:train_size]\n",
    "    y_train = y[:train_size]\n",
    "    X_val = X[train_size:train_size + val_size]\n",
    "    y_val = y[train_size:train_size + val_size]\n",
    "    X_test = X[train_size + val_size:]\n",
    "    y_test = y[train_size + val_size:]\n",
    "    \n",
    "    train_end_idx = len(X_train)\n",
    "    val_end_idx = len(X_train) + len(X_val)\n",
    "    \n",
    "    print(f\"Splits - Train: {X_train.shape[0]}, Val: {X_val.shape[0]}, Test: {X_test.shape[0]}\")\n",
    "    \n",
    "    input_shape = (X_train.shape[1], 1)\n",
    "    X_train_reshaped = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)\n",
    "    X_val_reshaped = X_val.reshape(X_val.shape[0], X_val.shape[1], 1)\n",
    "    X_test_reshaped = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)\n",
    "    \n",
    "    city_plot_dir = os.path.join(base_plot_dir, city_name.replace(\" \", \"_\"))\n",
    "    os.makedirs(city_plot_dir, exist_ok=True)\n",
    "    \n",
    "    models = {}\n",
    "    histories = {}\n",
    "    \n",
    "    early_stopping = EarlyStopping(monitor='val_loss', patience=20, restore_best_weights=True)\n",
    "    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10, min_lr=1e-7)\n",
    "    \n",
    "    print(\"Training CNN...\")\n",
    "    cnn_model = build_cnn_model(input_shape)\n",
    "    cnn_path = os.path.join(city_plot_dir, f'best_cnn_{city_name.lower().replace(\" \", \"_\")}.keras')\n",
    "    cnn_checkpoint = ModelCheckpoint(cnn_path, monitor='val_loss', save_best_only=True)\n",
    "    cnn_history = cnn_model.fit(X_train_reshaped, y_train,\n",
    "                               validation_data=(X_val_reshaped, y_val),\n",
    "                               epochs=50, batch_size=32, verbose=0,\n",
    "                               callbacks=[early_stopping, reduce_lr, cnn_checkpoint])\n",
    "    cnn_model = tf.keras.models.load_model(cnn_path)\n",
    "    models['CNN'] = cnn_model\n",
    "    histories['CNN'] = cnn_history\n",
    "    \n",
    "    print(\"Training LSTM...\")\n",
    "    lstm_model = build_lstm_model(input_shape)\n",
    "    lstm_path = os.path.join(city_plot_dir, f'best_lstm_{city_name.lower().replace(\" \", \"_\")}.keras')\n",
    "    lstm_checkpoint = ModelCheckpoint(lstm_path, monitor='val_loss', save_best_only=True)\n",
    "    lstm_history = lstm_model.fit(X_train_reshaped, y_train,\n",
    "                                 validation_data=(X_val_reshaped, y_val),\n",
    "                                 epochs=50, batch_size=32, verbose=0,\n",
    "                                 callbacks=[early_stopping, reduce_lr, lstm_checkpoint])\n",
    "    lstm_model = tf.keras.models.load_model(lstm_path)\n",
    "    models['LSTM'] = lstm_model\n",
    "    histories['LSTM'] = lstm_history\n",
    "    \n",
    "    print(\"Training GRU...\")\n",
    "    gru_model = build_gru_model(input_shape)\n",
    "    gru_path = os.path.join(city_plot_dir, f'best_gru_{city_name.lower().replace(\" \", \"_\")}.keras')\n",
    "    gru_checkpoint = ModelCheckpoint(gru_path, monitor='val_loss', save_best_only=True)\n",
    "    gru_history = gru_model.fit(X_train_reshaped, y_train,\n",
    "                               validation_data=(X_val_reshaped, y_val),\n",
    "                               epochs=50, batch_size=32, verbose=0,\n",
    "                               callbacks=[early_stopping, reduce_lr, gru_checkpoint])\n",
    "    gru_model = tf.keras.models.load_model(gru_path)\n",
    "    models['GRU'] = gru_model\n",
    "    histories['GRU'] = gru_history\n",
    "    \n",
    "    print(\"Creating ensemble...\")\n",
    "    ensemble = SimpleEnsemble()\n",
    "    ensemble.fit(models, X_val_reshaped, y_val)\n",
    "    \n",
    "    X_all_reshaped = X.reshape(X.shape[0], X.shape[1], 1)\n",
    "    ensemble_pred_all = ensemble.predict(X_all_reshaped)\n",
    "    \n",
    "    print(\"\\n\" + \"=\" * 60)\n",
    "    print(\"BIAS CORRECTION\")\n",
    "    print(\"=\" * 60)\n",
    "    \n",
    "    station_obs = load_station_data(station_data_file, city_name)\n",
    "    \n",
    "    if station_obs is not None:\n",
    "        bias_correction, bias_stats = calculate_bias_correction(\n",
    "            ensemble_pred_all, station_obs, final_dates, station_obs.index\n",
    "        )\n",
    "        ensemble_pred_all_corrected = apply_bias_correction(ensemble_pred_all, bias_correction)\n",
    "        ensemble_pred_train_corrected = ensemble_pred_all_corrected[:train_end_idx]\n",
    "        ensemble_pred_val_corrected = ensemble_pred_all_corrected[train_end_idx:val_end_idx]\n",
    "        ensemble_pred_test_corrected = ensemble_pred_all_corrected[val_end_idx:]\n",
    "    else:\n",
    "        print(\"Warning: No bias correction applied\")\n",
    "        bias_correction = 0.0\n",
    "        bias_stats = None\n",
    "        ensemble_pred_all_corrected = ensemble_pred_all.copy()\n",
    "        ensemble_pred_train_corrected = ensemble_pred_all[:train_end_idx]\n",
    "        ensemble_pred_val_corrected = ensemble_pred_all[train_end_idx:val_end_idx]\n",
    "        ensemble_pred_test_corrected = ensemble_pred_all[val_end_idx:]\n",
    "    \n",
    "    results = {}\n",
    "    \n",
    "    for phase, (X_data, y_data) in [\n",
    "        ('Train', (X_train_reshaped, y_train)),\n",
    "        ('Validation', (X_val_reshaped, y_val)),\n",
    "        ('Test', (X_test_reshaped, y_test))\n",
    "    ]:\n",
    "        phase_results = {}\n",
    "        \n",
    "        for name, model in models.items():\n",
    "            pred = model.predict(X_data, verbose=0).flatten()\n",
    "            \n",
    "            mse = mean_squared_error(y_data, pred)\n",
    "            rmse = np.sqrt(mse)\n",
    "            mae = mean_absolute_error(y_data, pred)\n",
    "            r2 = r2_score(y_data, pred)\n",
    "            r = np.corrcoef(y_data, pred)[0, 1] if len(np.unique(pred)) > 1 else 0\n",
    "            nse = 1 - (np.sum((y_data - pred) ** 2) / np.sum((y_data - np.mean(y_data)) ** 2))\n",
    "            pbias = 100 * (np.sum(y_data - pred) / np.sum(y_data))\n",
    "            \n",
    "            phase_results[name] = {\n",
    "                'predictions': pred,\n",
    "                'MSE': mse,\n",
    "                'RMSE': rmse,\n",
    "                'MAE': mae,\n",
    "                'R²': r2,\n",
    "                'r': r,\n",
    "                'NSE': nse,\n",
    "                'PBIAS': pbias\n",
    "            }\n",
    "        \n",
    "        ensemble_pred = ensemble.predict(X_data)\n",
    "        \n",
    "        mse = mean_squared_error(y_data, ensemble_pred)\n",
    "        rmse = np.sqrt(mse)\n",
    "        mae = mean_absolute_error(y_data, ensemble_pred)\n",
    "        r2 = r2_score(y_data, ensemble_pred)\n",
    "        r = np.corrcoef(y_data, ensemble_pred)[0, 1] if len(np.unique(ensemble_pred)) > 1 else 0\n",
    "        nse = 1 - (np.sum((y_data - ensemble_pred) ** 2) / np.sum((y_data - np.mean(y_data)) ** 2))\n",
    "        pbias = 100 * (np.sum(y_data - ensemble_pred) / np.sum(y_data))\n",
    "        \n",
    "        phase_results['Ensemble'] = {\n",
    "            'predictions': ensemble_pred,\n",
    "            'MSE': mse,\n",
    "            'RMSE': rmse,\n",
    "            'MAE': mae,\n",
    "            'R²': r2,\n",
    "            'r': r,\n",
    "            'NSE': nse,\n",
    "            'PBIAS': pbias\n",
    "        }\n",
    "        \n",
    "        if phase == 'Train':\n",
    "            ensemble_pred_corrected = ensemble_pred_train_corrected\n",
    "        elif phase == 'Validation':\n",
    "            ensemble_pred_corrected = ensemble_pred_val_corrected\n",
    "        else:\n",
    "            ensemble_pred_corrected = ensemble_pred_test_corrected\n",
    "        \n",
    "        mse_corr = mean_squared_error(y_data, ensemble_pred_corrected)\n",
    "        rmse_corr = np.sqrt(mse_corr)\n",
    "        mae_corr = mean_absolute_error(y_data, ensemble_pred_corrected)\n",
    "        r2_corr = r2_score(y_data, ensemble_pred_corrected)\n",
    "        r_corr = np.corrcoef(y_data, ensemble_pred_corrected)[0, 1]\n",
    "        nse_corr = 1 - (np.sum((y_data - ensemble_pred_corrected) ** 2) / np.sum((y_data - np.mean(y_data)) ** 2))\n",
    "        pbias_corr = 100 * (np.sum(y_data - ensemble_pred_corrected) / np.sum(y_data))\n",
    "        \n",
    "        phase_results['Ensemble_BiasCorrect'] = {\n",
    "            'predictions': ensemble_pred_corrected,\n",
    "            'MSE': mse_corr,\n",
    "            'RMSE': rmse_corr,\n",
    "            'MAE': mae_corr,\n",
    "            'R²': r2_corr,\n",
    "            'r': r_corr,\n",
    "            'NSE': nse_corr,\n",
    "            'PBIAS': pbias_corr\n",
    "        }\n",
    "        \n",
    "        results[phase] = phase_results\n",
    "    \n",
    "    print(\"\\n\" + \"=\"*60)\n",
    "    print(f\"RESULTS SUMMARY - {city_name.upper()}\")\n",
    "    print(\"=\"*60)\n",
    "    \n",
    "    for phase in ['Train', 'Validation', 'Test']:\n",
    "        print(f\"\\n{phase}:\")\n",
    "        print(\"-\" * 40)\n",
    "        for model_name in ['CNN', 'LSTM', 'GRU', 'Ensemble', 'Ensemble_BiasCorrect']:\n",
    "            metrics = results[phase][model_name]\n",
    "            print(f\"{model_name:20} - R²: {metrics['R²']:.3f}, RMSE: {metrics['RMSE']:.2f}, NSE: {metrics['NSE']:.3f}\")\n",
    "    \n",
    "    print(f\"\\nCreating visualizations...\")\n",
    "    \n",
    "    test_ensemble_pred = results['Test']['Ensemble']['predictions']\n",
    "    test_ensemble_pred_corrected = results['Test']['Ensemble_BiasCorrect']['predictions']\n",
    "    \n",
    "    plot_scatter_comparison(y_test, test_ensemble_pred, city_name, 'Ensemble (Original)',\n",
    "                          os.path.join(city_plot_dir, f'{city_name.lower().replace(\" \", \"_\")}_ensemble_scatter_original.png'))\n",
    "    \n",
    "    plot_scatter_comparison(y_test, test_ensemble_pred_corrected, city_name, 'Ensemble (Bias-Corrected)',\n",
    "                          os.path.join(city_plot_dir, f'{city_name.lower().replace(\" \", \"_\")}_ensemble_scatter_corrected.png'))\n",
    "    \n",
    "    if station_obs is not None:\n",
    "        test_dates = final_dates[val_end_idx:]\n",
    "        common_test_dates = test_dates.intersection(station_obs.index)\n",
    "        \n",
    "        if len(common_test_dates) > 0:\n",
    "            test_indices = [i for i, date in enumerate(test_dates) if date in common_test_dates]\n",
    "            y_test_overlap = y_test[test_indices]\n",
    "            pred_original_overlap = test_ensemble_pred[test_indices]\n",
    "            pred_corrected_overlap = test_ensemble_pred_corrected[test_indices]\n",
    "            \n",
    "            plot_bias_correction_comparison(y_test_overlap, pred_original_overlap, \n",
    "                                          pred_corrected_overlap, city_name,\n",
    "                                          os.path.join(city_plot_dir, f'{city_name.lower().replace(\" \", \"_\")}_bias_correction_comparison.png'))\n",
    "    \n",
    "    plot_model_comparison(results['Test'], city_name,\n",
    "                         os.path.join(city_plot_dir, f'{city_name.lower().replace(\" \", \"_\")}_model_comparison.png'))\n",
    "    \n",
    "    plot_complete_time_series_with_uncertainty(final_dates, y, ensemble_pred_all_corrected,\n",
    "                                                 train_end_idx, val_end_idx,\n",
    "                                                 results['Train']['Ensemble_BiasCorrect']['predictions'],\n",
    "                                                 results['Validation']['Ensemble_BiasCorrect']['predictions'],\n",
    "                                                 results['Test']['Ensemble_BiasCorrect']['predictions'],\n",
    "                                                 y_train, y_val, y_test,\n",
    "                                                 f'{city_name} (Bias-Corrected)',\n",
    "                                                 os.path.join(city_plot_dir, f'{city_name.lower().replace(\" \", \"_\")}_complete_timeseries_corrected.png'))\n",
    "    \n",
    "    fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n",
    "    fig.suptitle(f'Learning Curves - {city_name}', fontsize=16, fontweight='bold')\n",
    "    \n",
    "    for i, (name, history) in enumerate(histories.items()):\n",
    "        ax = axes[i]\n",
    "        ax.grid(False)\n",
    "        ax.plot(history.history['loss'], label='Training Loss', color='blue', linewidth=2)\n",
    "        ax.plot(history.history['val_loss'], label='Validation Loss', color='red', linewidth=2)\n",
    "        ax.set_xlabel('Epochs', fontweight='bold')\n",
    "        ax.set_ylabel('Loss (MSE)', fontweight='bold')\n",
    "        ax.set_title(f'{name} Learning Curve', fontweight='bold')\n",
    "        ax.legend()\n",
    "        ax.set_yscale('log')\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig(os.path.join(city_plot_dir, f'{city_name.lower().replace(\" \", \"_\")}_learning_curves.png'), dpi=300, bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "    \n",
    "    print(f\"Saving results to Excel...\")\n",
    "    \n",
    "    results_data = []\n",
    "    for phase in ['Train', 'Validation', 'Test']:\n",
    "        for model_name in ['CNN', 'LSTM', 'GRU', 'Ensemble', 'Ensemble_BiasCorrect']:\n",
    "            metrics = results[phase][model_name]\n",
    "            results_data.append({\n",
    "                'Phase': phase,\n",
    "                'Model': model_name,\n",
    "                'MSE': metrics['MSE'],\n",
    "                'RMSE': metrics['RMSE'],\n",
    "                'MAE': metrics['MAE'],\n",
    "                'R²': metrics['R²'],\n",
    "                'r': metrics['r'],\n",
    "                'NSE': metrics['NSE'],\n",
    "                'PBIAS': metrics['PBIAS']\n",
    "            })\n",
    "    \n",
    "    results_df = pd.DataFrame(results_data)\n",
    "    \n",
    "    test_residuals_corrected = y_test - test_ensemble_pred_corrected\n",
    "    lower_95, upper_95 = calculate_prediction_intervals(test_ensemble_pred_corrected, test_residuals_corrected)\n",
    "    \n",
    "    predictions_df = pd.DataFrame({\n",
    "        'Date': final_dates[val_end_idx:],\n",
    "        'Actual': y_test,\n",
    "        'CNN': results['Test']['CNN']['predictions'],\n",
    "        'LSTM': results['Test']['LSTM']['predictions'],\n",
    "        'GRU': results['Test']['GRU']['predictions'],\n",
    "        'Ensemble_Original': results['Test']['Ensemble']['predictions'],\n",
    "        'Ensemble_BiasCorrect': results['Test']['Ensemble_BiasCorrect']['predictions'],\n",
    "        'Lower_95_Corrected': lower_95,\n",
    "        'Upper_95_Corrected': upper_95\n",
    "    })\n",
    "    \n",
    "    complete_predictions_df = pd.DataFrame({\n",
    "        'Date': final_dates,\n",
    "        'Actual': y,\n",
    "        'Ensemble_Original': ensemble_pred_all,\n",
    "        'Ensemble_BiasCorrect': ensemble_pred_all_corrected\n",
    "    })\n",
    "    \n",
    "    if bias_stats is not None:\n",
    "        bias_df = pd.DataFrame([{\n",
    "            'Mean_Bias': bias_stats['mean_bias'],\n",
    "            'Median_Bias': bias_stats['median_bias'],\n",
    "            'Std_Bias': bias_stats['std_bias'],\n",
    "            'Relative_Bias_%': bias_stats['relative_bias'],\n",
    "            'N_Samples': bias_stats['n_samples'],\n",
    "            'Date_Range': bias_stats['date_range'],\n",
    "            'Correction_Applied': bias_correction\n",
    "        }])\n",
    "    else:\n",
    "        bias_df = pd.DataFrame([{\n",
    "            'Mean_Bias': 'N/A',\n",
    "            'Median_Bias': 'N/A',\n",
    "            'Std_Bias': 'N/A',\n",
    "            'Relative_Bias_%': 'N/A',\n",
    "            'N_Samples': 'N/A',\n",
    "            'Date_Range': 'N/A',\n",
    "            'Correction_Applied': 0.0\n",
    "        }])\n",
    "    \n",
    "    excel_output_path = os.path.join(city_plot_dir, f'{city_name.lower().replace(\" \", \"_\")}_pm25_results_with_bias_correction.xlsx')\n",
    "    with pd.ExcelWriter(excel_output_path, engine='openpyxl') as writer:\n",
    "        results_df.to_excel(writer, sheet_name='Model_Metrics', index=False)\n",
    "        predictions_df.to_excel(writer, sheet_name='Test_Predictions', index=False)\n",
    "        complete_predictions_df.to_excel(writer, sheet_name='Complete_Predictions', index=False)\n",
    "        bias_df.to_excel(writer, sheet_name='Bias_Correction_Stats', index=False)\n",
    "        \n",
    "        weights_df = pd.DataFrame([\n",
    "            {'Model': name, 'Weight': weight}\n",
    "            for name, weight in ensemble.weights.items()\n",
    "        ])\n",
    "        weights_df.to_excel(writer, sheet_name='Ensemble_Weights', index=False)\n",
    "    \n",
    "    ensemble_file_name = f'{city_name.lower().replace(\" \", \"_\")}_ensemble.pkl'\n",
    "    with open(os.path.join(city_plot_dir, ensemble_file_name), 'wb') as f:\n",
    "        pickle.dump(ensemble, f)\n",
    "    \n",
    "    with open(os.path.join(city_plot_dir, f'{city_name.lower().replace(\" \", \"_\")}_bias_correction.pkl'), 'wb') as f:\n",
    "        pickle.dump({'bias_correction': bias_correction, 'bias_stats': bias_stats}, f)\n",
    "    \n",
    "    print(f\"\\nResults saved:\")\n",
    "    print(f\"- Excel: {excel_output_path}\")\n",
    "    print(f\"- Plots: {city_plot_dir}/\")\n",
    "    print(f\"- Models: {city_plot_dir}/\")\n",
    "    \n",
    "    if bias_stats is not None:\n",
    "        print(f\"\\nBias Correction Applied:\")\n",
    "        print(f\"- Mean Bias: {bias_correction:.3f} μg/m³\")\n",
    "        print(f\"- Relative Bias: {bias_stats['relative_bias']:.2f}%\")\n",
    "    \n",
    "    return results\n",
    "\n",
    "# ---- Main Execution ----\n",
    "if __name__ == \"__main__\":\n",
    "    try:\n",
    "        xls = pd.ExcelFile(excel_file)\n",
    "        city_sheets = xls.sheet_names\n",
    "        \n",
    "        print(f\"Found {len(city_sheets)} cities: {', '.join(city_sheets)}\")\n",
    "        print(\"\\n\" + \"=\"*80)\n",
    "        \n",
    "        for sheet_name in city_sheets:\n",
    "            city_name = sheet_name.replace(\"_\", \" \")\n",
    "            \n",
    "            process_city_data(city_name, sheet_name)\n",
    "            print(\"\\n\" + \"=\"*80)\n",
    "            print(f\"COMPLETED: {city_name.upper()}\")\n",
    "            print(\"=\"*80 + \"\\n\")\n",
    "        \n",
    "        print(\"\\n\" + \"=\"*80)\n",
    "        print(\"ALL CITIES PROCESSED SUCCESSFULLY!\")\n",
    "        print(\"=\"*80)\n",
    "        \n",
    "    except Exception as e:\n",
    "        print(f\"Error: {str(e)}\")\n",
    "        import traceback\n",
    "        traceback.print_exc()"
   ]
  },
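  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7f2c9d1-3e4a-4f6b-9a2c-5d8e7f0a1b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ------------------------------------------------------------\n",
    "# Hedged usage sketch (not part of the pipeline above): reloading\n",
    "# the artifacts saved by process_city_data. Assumptions not\n",
    "# confirmed by the code above: the ensemble object exposes a\n",
    "# .predict(X) method, and bias correction means subtracting the\n",
    "# stored mean bias from raw predictions, as the naming suggests.\n",
    "# The directory and city name below are placeholders.\n",
    "# ------------------------------------------------------------\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "city_plot_dir = 'indore_plots'  # placeholder path\n",
    "\n",
    "with open(os.path.join(city_plot_dir, 'indore_ensemble.pkl'), 'rb') as f:\n",
    "    loaded_ensemble = pickle.load(f)\n",
    "\n",
    "with open(os.path.join(city_plot_dir, 'indore_bias_correction.pkl'), 'rb') as f:\n",
    "    bias_info = pickle.load(f)\n",
    "\n",
    "print(f\"Loaded ensemble weights: {loaded_ensemble.weights}\")\n",
    "print(f\"Stored mean bias: {bias_info['bias_correction']:.3f} μg/m³\")\n",
    "\n",
    "# X_new: features prepared exactly as during training (placeholder)\n",
    "# raw_pred = loaded_ensemble.predict(X_new)                 # assumed API\n",
    "# corrected_pred = raw_pred - bias_info['bias_correction']  # assumed sign"
   ]
  },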
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35241a66-9c48-4df4-af36-bab3b73ab459",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import netCDF4 as nc\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Define the cities with their coordinates and approximate areas\n",
    "cities = {\n",
    "    'Indore': {'coords': (22.719568, 75.857727), 'area_km2': 525},\n",
    "    'Bhopal': {'coords': (23.259933, 77.412613), 'area_km2': 463},\n",
    "    'Jabalpur': {'coords': (23.185884, 79.97438), 'area_km2': 374},\n",
    "    'Gwalior': {'coords': (26.218287, 78.182832), 'area_km2': 289},\n",
    "    'Ujjain': {'coords': (23.1793, 75.784912), 'area_km2': 152},\n",
    "    'Rewa': {'coords': (24.530727, 81.29911), 'area_km2': 100},\n",
    "    'Sagar': {'coords': (23.83403, 78.746567), 'area_km2': 78},\n",
    "    'Ratlam': {'coords': (23.334169, 75.037363), 'area_km2': 75},\n",
    "    'Chhindwara': {'coords': (22.057163, 78.938202), 'area_km2': 65},\n",
    "    'Dewas': {'coords': (22.962267, 76.050797), 'area_km2': 50}\n",
    "}\n",
    "\n",
    "# File paths\n",
    "pop_path = r\"D:\\MP_Aerosol\\Cities\\Pop\\Masked_Data\\Clipped_gpw_v4_population_count_rev11_2pt5_min.nc\"\n",
    "pm25_file_path = r'D:\\MP_Aerosol\\Cities\\city_pm25_data_1980_2023.xlsx'\n",
    "\n",
    "def calculate_grid_count_from_area(area_km2, resolution_deg=2.5/60):\n",
    "    \"\"\"\n",
    "    Calculate the number of grid cells needed to cover a given area.\n",
    "    \"\"\"\n",
    "    km_per_degree = 111.0\n",
    "    cell_size_km = resolution_deg * km_per_degree\n",
    "    cell_area_km2 = cell_size_km ** 2\n",
    "    \n",
    "    num_cells = area_km2 / cell_area_km2\n",
    "    grid_size = int(np.sqrt(num_cells))\n",
    "    \n",
    "    # Ensure minimum grid size of 3x3 and make it odd for better centering\n",
    "    grid_size = max(3, grid_size)\n",
    "    if grid_size % 2 == 0:\n",
    "        grid_size += 1\n",
    "    \n",
    "    grid_count = grid_size ** 2\n",
    "    \n",
    "    return grid_count, grid_size\n",
    "\n",
    "def extract_city_data_from_netcdf(lat, lon, data_grid, lat_grid, lon_grid, grid_count=400):\n",
    "    \"\"\"\n",
    "    Extract data for a city location using a square grid pattern.\n",
    "    \"\"\"\n",
    "    grid_size = int(np.sqrt(grid_count))\n",
    "    if grid_size ** 2 != grid_count:\n",
    "        grid_size = int(np.round(np.sqrt(grid_count)))\n",
    "        grid_count = grid_size ** 2\n",
    "    \n",
    "    half_size = grid_size // 2\n",
    "\n",
    "    # Find the closest grid point to the city coordinates\n",
    "    distances = np.sqrt((lat_grid - lat)**2 + (lon_grid - lon)**2)\n",
    "    center_idx = np.unravel_index(np.argmin(distances), distances.shape)\n",
    "    center_row, center_col = center_idx\n",
    "\n",
    "    # Calculate the bounds for the grid window\n",
    "    row_start = max(0, center_row - half_size)\n",
    "    row_end = min(data_grid.shape[0], center_row + half_size + 1)\n",
    "    col_start = max(0, center_col - half_size)\n",
    "    col_end = min(data_grid.shape[1], center_col + half_size + 1)\n",
    "\n",
    "    # Extract the grid data\n",
    "    grid_data = data_grid[row_start:row_end, col_start:col_end]\n",
    "    total_value = np.nansum(grid_data)\n",
    "\n",
    "    return total_value\n",
    "\n",
    "print(\"=\"*80)\n",
    "print(\"LOADING POPULATION DATA FROM NETCDF\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "# Load population data\n",
    "ds_pop = nc.Dataset(pop_path)\n",
    "pop_var = list(ds_pop.variables.keys())[-1]\n",
    "pop_data_latest = ds_pop.variables[pop_var][-1]  # Get the latest year (2020)\n",
    "lat_pop = ds_pop.variables['latitude'][:]\n",
    "lon_pop = ds_pop.variables['longitude'][:]\n",
    "lon_grid_pop, lat_grid_pop = np.meshgrid(lon_pop, lat_pop)\n",
    "ds_pop.close()\n",
    "\n",
    "print(f\"✅ Population data loaded successfully\")\n",
    "print(f\"    Population data shape: {pop_data_latest.shape}\")\n",
    "print(f\"    Lat range: {lat_pop.min():.3f} to {lat_pop.max():.3f}\")\n",
    "print(f\"    Lon range: {lon_pop.min():.3f} to {lon_pop.max():.3f}\")\n",
    "\n",
    "# Calculate grid resolution\n",
    "lat_resolution = np.abs(np.mean(np.diff(lat_pop)))\n",
    "lon_resolution = np.abs(np.mean(np.diff(lon_pop)))\n",
    "print(f\"    Grid resolution: {lat_resolution:.4f}° (≈{lat_resolution*60:.2f} arcminutes)\")\n",
    "\n",
    "# Calculate grid sizes and extract population for each city\n",
    "print(f\"\\n{'='*80}\")\n",
    "print(\"EXTRACTING POPULATION DATA FOR EACH CITY\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "city_population = {}\n",
    "city_grid_info = {}\n",
    "\n",
    "for city, info in cities.items():\n",
    "    lat, lon = info['coords']\n",
    "    area_km2 = info['area_km2']\n",
    "    \n",
    "    # Calculate optimal grid\n",
    "    grid_count, grid_size = calculate_grid_count_from_area(area_km2, resolution_deg=lat_resolution)\n",
    "    city_grid_info[city] = {'grid_count': grid_count, 'grid_size': grid_size}\n",
    "    \n",
    "    # Extract population\n",
    "    population = extract_city_data_from_netcdf(lat, lon, pop_data_latest, lat_grid_pop, lon_grid_pop, grid_count)\n",
    "    city_population[city] = population\n",
    "    \n",
    "    actual_area_covered = grid_count * (lat_resolution * 111) ** 2\n",
    "    print(f\"{city:12s}: Area={area_km2:3.0f} km², Grid={grid_size}x{grid_size} ({grid_count:3d} cells), \"\n",
    "          f\"Pop={population:,.0f}, Covers≈{actual_area_covered:.0f} km²\")\n",
    "\n",
    "print(f\"✅ Population extraction completed for {len(city_population)} cities\")\n",
    "\n",
    "# Load PM2.5 data from Excel\n",
    "print(f\"\\n{'='*80}\")\n",
    "print(\"LOADING PM2.5 DATA FROM EXCEL\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "xls = pd.ExcelFile(pm25_file_path)\n",
    "dfs = pd.read_excel(xls, sheet_name=None)\n",
    "\n",
    "print(f\"Available sheets: {xls.sheet_names}\")\n",
    "\n",
    "# Extract latest year data for each city\n",
    "latest_data = {}\n",
    "\n",
    "# Check if sheets are named by cities\n",
    "if any(city in xls.sheet_names for city in cities.keys()):\n",
    "    print(\"Processing city-based sheets...\")\n",
    "    for sheet_name, df in dfs.items():\n",
    "        if sheet_name in cities.keys():\n",
    "            # Get the most recent year's data\n",
    "            df['Year'] = pd.to_datetime(df['Year'], format='%Y')\n",
    "            latest_row = df.loc[df['Year'].idxmax()]\n",
    "            latest_data[sheet_name] = {\n",
    "                'Annual_Mean_PM25': latest_row['Annual_Mean_PM25'],\n",
    "                'Annual_Max_PM25': latest_row['Annual_Max_PM25'],\n",
    "                'PM2.5D': latest_row['PM2.5D'],\n",
    "                'Population': city_population[sheet_name],\n",
    "                'Area_km2': cities[sheet_name]['area_km2'],\n",
    "                'Grid_Size': city_grid_info[sheet_name]['grid_size']\n",
    "            }\n",
    "            print(f\"  {sheet_name}: Latest year = {latest_row['Year'].year}, \"\n",
    "                  f\"Mean PM2.5 = {latest_row['Annual_Mean_PM25']:.2f} μg/m³\")\n",
    "else:\n",
    "    # If sheets are named by years, extract data for all cities from the latest year\n",
    "    print(\"Processing year-based sheets...\")\n",
    "    latest_year = max([s for s in xls.sheet_names if s.isdigit()], key=lambda x: int(x))\n",
    "    print(f\"Using latest year: {latest_year}\")\n",
    "    latest_df = dfs[latest_year]\n",
    "    \n",
    "    # Assuming the dataframe has a 'City' column\n",
    "    for _, row in latest_df.iterrows():\n",
    "        city_name = row['City']  # Adjust column name as needed\n",
    "        if city_name in cities.keys():\n",
    "            latest_data[city_name] = {\n",
    "                'Annual_Mean_PM25': row['Annual_Mean_PM25'],\n",
    "                'Annual_Max_PM25': row['Annual_Max_PM25'],\n",
    "                'PM2.5D': row['PM2.5D'],\n",
    "                'Population': city_population[city_name],\n",
    "                'Area_km2': cities[city_name]['area_km2'],\n",
    "                'Grid_Size': city_grid_info[city_name]['grid_size']\n",
    "            }\n",
    "\n",
    "print(f\"✅ PM2.5 data loaded for {len(latest_data)} cities\")\n",
    "\n",
    "# Convert to DataFrame\n",
    "print(f\"\\n{'='*80}\")\n",
    "print(\"COMPUTING COMPOSITE HAZARD INDICES\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "analysis_df = pd.DataFrame.from_dict(latest_data, orient='index')\n",
    "analysis_df.reset_index(inplace=True)\n",
    "analysis_df.rename(columns={'index': 'Cities'}, inplace=True)\n",
    "\n",
    "# Method 1: Simple Average Composite Hazard Index\n",
    "analysis_df['Composite_Hazard_Simple'] = (\n",
    "    analysis_df['Annual_Mean_PM25'] + \n",
    "    analysis_df['Annual_Max_PM25'] + \n",
    "    analysis_df['PM2.5D']\n",
    ") / 3\n",
    "\n",
    "# Method 2: Weighted Average (giving more weight to Annual_Mean as it represents long-term exposure)\n",
    "weights = {'Annual_Mean_PM25': 0.5, 'Annual_Max_PM25': 0.3, 'PM2.5D': 0.2}\n",
    "analysis_df['Composite_Hazard_Weighted'] = (\n",
    "    analysis_df['Annual_Mean_PM25'] * weights['Annual_Mean_PM25'] +\n",
    "    analysis_df['Annual_Max_PM25'] * weights['Annual_Max_PM25'] +\n",
    "    analysis_df['PM2.5D'] * weights['PM2.5D']\n",
    ")\n",
    "\n",
    "# Method 3: Standardized Composite Index\n",
    "scaler = StandardScaler()\n",
    "hazard_cols = ['Annual_Mean_PM25', 'Annual_Max_PM25', 'PM2.5D']\n",
    "standardized_hazards = scaler.fit_transform(analysis_df[hazard_cols])\n",
    "analysis_df['Composite_Hazard_Standardized'] = np.mean(standardized_hazards, axis=1)\n",
    "\n",
    "# Choose which composite hazard to use\n",
    "composite_hazard_column = 'Composite_Hazard_Weighted'\n",
    "\n",
    "print(f\"Using: {composite_hazard_column}\")\n",
    "print(f\"Weights: Annual_Mean={weights['Annual_Mean_PM25']}, \"\n",
    "      f\"Annual_Max={weights['Annual_Max_PM25']}, PM2.5D={weights['PM2.5D']}\")\n",
    "\n",
    "# Calculate Risk Rankings\n",
    "# Without Population: Risk = Hazard only (V=1, E=1)\n",
    "analysis_df['Risk_Without_Population'] = analysis_df[composite_hazard_column]\n",
    "analysis_df['Rank_Without_Population'] = analysis_df['Risk_Without_Population'].rank(ascending=False).astype(int)\n",
    "\n",
    "# With Population: Risk = Hazard × Population (V=1, E=Population)\n",
    "analysis_df['Risk_With_Population'] = analysis_df[composite_hazard_column] * analysis_df['Population']\n",
    "analysis_df['Rank_With_Population'] = analysis_df['Risk_With_Population'].rank(ascending=False).astype(int)\n",
    "\n",
    "print(f\"✅ Composite hazard indices computed\")\n",
    "\n",
    "# Set enhanced plotting style with ALL BOLD FONTS\n",
    "print(f\"\\n{'='*80}\")\n",
    "print(\"GENERATING VISUALIZATIONS\")\n",
    "print(\"=\"*80)\n",
    "\n",
    "sns.set_style(\"whitegrid\")\n",
    "plt.rcParams.update({\n",
    "    'font.size': 24,\n",
    "    'font.weight': 'bold',\n",
    "    'axes.labelsize': 24,\n",
    "    'axes.labelweight': 'bold',\n",
    "    'axes.titlesize': 24,\n",
    "    'axes.titleweight': 'bold',\n",
    "    'xtick.labelsize': 24,\n",
    "    'ytick.labelsize': 24,\n",
    "    'legend.fontsize': 18,\n",
    "    'legend.title_fontsize': 18,\n",
    "    'figure.titlesize': 20,\n",
    "    'figure.titleweight': 'bold'\n",
    "})\n",
    "\n",
    "# 1. Main comparison plot\n",
    "fig, ax = plt.subplots(figsize=(16, 10))\n",
    "bar_width = 0.35\n",
    "x = range(len(analysis_df[\"Cities\"]))\n",
    "\n",
    "bars1 = ax.bar(x, analysis_df['Rank_Without_Population'], width=bar_width,\n",
    "               label=\"Without Population (Risk = Composite Hazard)\", \n",
    "               color=\"royalblue\", edgecolor=\"black\", linewidth=2)\n",
    "bars2 = ax.bar([i + bar_width for i in x], analysis_df['Rank_With_Population'], width=bar_width,\n",
    "               label=\"With Population (Risk = Composite Hazard × Population)\", \n",
    "               color=\"tomato\", edgecolor=\"black\", linewidth=2)\n",
    "\n",
    "# Add value labels on bars\n",
    "for i, (bar1, bar2) in enumerate(zip(bars1, bars2)):\n",
    "    ax.text(bar1.get_x() + bar1.get_width()/2, bar1.get_height() + 0.1,\n",
    "            str(int(analysis_df['Rank_Without_Population'].iloc[i])), \n",
    "            ha='center', va='bottom', fontweight='bold', fontsize=13)\n",
    "    ax.text(bar2.get_x() + bar2.get_width()/2, bar2.get_height() + 0.1,\n",
    "            str(int(analysis_df['Rank_With_Population'].iloc[i])), \n",
    "            ha='center', va='bottom', fontweight='bold', fontsize=13)\n",
    "\n",
    "ax.set_xticks([i + bar_width / 2 for i in x])\n",
    "ax.set_xticklabels(analysis_df[\"Cities\"], rotation=45, fontsize=24, fontweight=\"bold\")\n",
    "ax.set_xlabel(\"Cities\", fontsize=20, fontweight=\"bold\")\n",
    "ax.set_ylabel(\"Risk Ranking (Lower = Higher Risk)\", fontsize=24, fontweight=\"bold\")\n",
    "#ax.set_title(\"Composite Risk Ranking of Cities\\n(Combining Annual Mean PM$_{2.5}$, Annual Max PM$_{2.5}$, and PM$_{2.5}$D)\", \n",
    "             #fontsize=18, fontweight=\"bold\", pad=20)\n",
    "ax.invert_yaxis()  # Invert to show rank 1 at the top\n",
    "\n",
    "# Make tick labels bold\n",
    "for label in ax.get_xticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "for label in ax.get_yticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "\n",
    "legend = ax.legend(fontsize=14, loc='best', prop={'weight': 'bold'})\n",
    "ax.grid(False)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# 2. Scatter plot: Population vs Composite Hazard\n",
    "fig, ax = plt.subplots(figsize=(14, 10))\n",
    "scatter = ax.scatter(analysis_df['Population'], analysis_df[composite_hazard_column], \n",
    "                     c=analysis_df['Rank_With_Population'], cmap='viridis_r', \n",
    "                     s=400, edgecolor='black', linewidth=2, alpha=0.8)\n",
    "\n",
    "# Add city labels\n",
    "for i, txt in enumerate(analysis_df['Cities']):\n",
    "    ax.annotate(txt, (analysis_df['Population'].iloc[i], analysis_df[composite_hazard_column].iloc[i]), \n",
    "                fontsize=12, fontweight='bold', xytext=(10, 10), \n",
    "                textcoords='offset points', ha='left')\n",
    "\n",
    "ax.set_xlabel(\"Population\", fontsize=16, fontweight=\"bold\")\n",
    "ax.set_ylabel(\"Composite Hazard Index\", fontsize=16, fontweight=\"bold\")\n",
    "ax.set_title(\"Population vs Composite Hazard Index\\n(Color represents Risk Rank with Population)\", \n",
    "             fontsize=18, fontweight=\"bold\", pad=20)\n",
    "\n",
    "# Make tick labels bold\n",
    "for label in ax.get_xticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "for label in ax.get_yticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "\n",
    "# Add colorbar\n",
    "cbar = plt.colorbar(scatter, ax=ax)\n",
    "cbar.set_label('Risk Rank (with Population)\\n(1 = Highest Risk)', fontsize=14, fontweight='bold')\n",
    "cbar.ax.invert_yaxis()  # Invert colorbar so rank 1 is at top\n",
    "for label in cbar.ax.get_yticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "    label.set_fontsize(12)\n",
    "\n",
    "# Format x-axis\n",
    "ax.ticklabel_format(style='scientific', axis='x', scilimits=(0,0))\n",
    "ax.xaxis.get_offset_text().set_fontweight('bold')\n",
    "ax.xaxis.get_offset_text().set_fontsize(12)\n",
    "\n",
    "ax.grid(False)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# 3. Heatmap of hazard components\n",
    "fig, ax = plt.subplots(figsize=(14, 10))\n",
    "heatmap_data = analysis_df[['Cities', 'Annual_Mean_PM25', 'Annual_Max_PM25', 'PM2.5D']].set_index('Cities')\n",
    "sns.heatmap(heatmap_data, annot=True, fmt='.1f', cmap='YlOrRd', \n",
    "            cbar_kws={'label': 'PM$_{2.5}$ (μg/m³)'}, linewidths=2, linecolor='black',\n",
    "            annot_kws={'weight': 'bold', 'fontsize': 12}, ax=ax)\n",
    "ax.set_title('PM$_{2.5}$ Hazard Components by City', fontweight='bold', fontsize=18, pad=20)\n",
    "ax.set_xlabel('Hazard Metrics', fontweight='bold', fontsize=16)\n",
    "ax.set_ylabel('Cities', fontweight='bold', fontsize=16)\n",
    "\n",
    "# Make tick labels bold\n",
    "for label in ax.get_xticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "for label in ax.get_yticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "\n",
    "# Make colorbar label and ticks bold\n",
    "cbar = ax.collections[0].colorbar\n",
    "cbar.ax.yaxis.label.set_weight('bold')\n",
    "cbar.ax.yaxis.label.set_size(14)\n",
    "for label in cbar.ax.get_yticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# 4. Composite hazard comparison across methods\n",
    "fig, ax = plt.subplots(figsize=(14, 10))\n",
    "methods_df = analysis_df[['Cities', 'Composite_Hazard_Simple', 'Composite_Hazard_Weighted', 'Composite_Hazard_Standardized']]\n",
    "methods_df_sorted = methods_df.sort_values('Composite_Hazard_Weighted', ascending=False)\n",
    "\n",
    "x_pos = np.arange(len(methods_df_sorted))\n",
    "width = 0.25\n",
    "\n",
    "ax.bar(x_pos - width, methods_df_sorted['Composite_Hazard_Simple'], width, \n",
    "       label='Simple Average', color='skyblue', edgecolor='black', linewidth=2)\n",
    "ax.bar(x_pos, methods_df_sorted['Composite_Hazard_Weighted'], width, \n",
    "       label='Weighted Average', color='salmon', edgecolor='black', linewidth=2)\n",
    "ax.bar(x_pos + width, methods_df_sorted['Composite_Hazard_Standardized'], width, \n",
    "       label='Standardized', color='lightgreen', edgecolor='black', linewidth=2)\n",
    "\n",
    "ax.set_xlabel('Cities', fontweight='bold', fontsize=16)\n",
    "ax.set_ylabel('Composite Hazard Index', fontweight='bold', fontsize=16)\n",
    "ax.set_title('Comparison of Composite Hazard Methods', fontweight='bold', fontsize=18, pad=20)\n",
    "ax.set_xticks(x_pos)\n",
    "ax.set_xticklabels(methods_df_sorted['Cities'], rotation=45, fontweight='bold')\n",
    "\n",
    "# Make tick labels bold\n",
    "for label in ax.get_xticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "for label in ax.get_yticklabels():\n",
    "    label.set_fontweight('bold')\n",
    "\n",
    "legend = ax.legend(fontsize=14, prop={'weight': 'bold'})\n",
    "ax.grid(False)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# Print comprehensive summary\n",
    "print(\"\\n\" + \"=\"*100)\n",
    "print(\"COMPOSITE RISK ANALYSIS SUMMARY\")\n",
    "print(\"=\"*100)\n",
    "print(f\"Composite Hazard Method Used: {composite_hazard_column}\")\n",
    "if composite_hazard_column == 'Composite_Hazard_Weighted':\n",
    "    print(f\"Weights: Annual_Mean_PM2.5={weights['Annual_Mean_PM25']}, \"\n",
    "          f\"Annual_Max_PM2.5={weights['Annual_Max_PM25']}, PM2.5D={weights['PM2.5D']}\")\n",
    "print(\"=\"*100)\n",
    "\n",
    "# Individual hazard components\n",
    "print(\"\\nINDIVIDUAL HAZARD COMPONENTS:\")\n",
    "print(\"-\"*100)\n",
    "hazard_display = analysis_df[['Cities', 'Annual_Mean_PM25', 'Annual_Max_PM25', 'PM2.5D', \n",
    "                               'Population', 'Area_km2', 'Grid_Size']].copy()\n",
    "hazard_display['Population'] = hazard_display['Population'].apply(lambda x: f\"{x:,.0f}\")\n",
    "hazard_display = hazard_display.round(2)\n",
    "print(hazard_display.to_string(index=False))\n",
    "\n",
    "print(\"\\nCOMPOSITE HAZARD VALUES:\")\n",
    "print(\"-\"*80)\n",
    "composite_display = analysis_df[['Cities', 'Composite_Hazard_Simple', \n",
    "                                  'Composite_Hazard_Weighted', 'Composite_Hazard_Standardized']].copy()\n",
    "for col in ['Composite_Hazard_Simple', 'Composite_Hazard_Weighted', 'Composite_Hazard_Standardized']:\n",
    "    composite_display[col] = composite_display[col].round(3)\n",
    "print(composite_display.to_string(index=False))\n",
    "\n",
    "print(\"\\nRISK RANKINGS COMPARISON:\")\n",
    "print(\"-\"*80)\n",
    "ranking_display = analysis_df[['Cities', 'Rank_Without_Population', 'Rank_With_Population']].copy()\n",
    "ranking_display['Rank_Change'] = ranking_display['Rank_Without_Population'] - ranking_display['Rank_With_Population']\n",
    "ranking_display = ranking_display.sort_values('Rank_With_Population')\n",
    "print(ranking_display.to_string(index=False))\n",
    "\n",
    "print(\"\\nKEY INSIGHTS:\")\n",
    "print(\"-\"*80)\n",
    "highest_risk_no_pop = analysis_df.loc[analysis_df['Rank_Without_Population']==1, 'Cities'].iloc[0]\n",
    "highest_risk_with_pop = analysis_df.loc[analysis_df['Rank_With_Population']==1, 'Cities'].iloc[0]\n",
    "\n",
    "print(f\"• Highest risk without population: {highest_risk_no_pop}\")\n",
    "print(f\"• Highest risk with population: {highest_risk_with_pop}\")\n",
    "\n",
    "# Cities with biggest rank changes\n",
    "biggest_increase = ranking_display.loc[ranking_display['Rank_Change'].idxmin()]\n",
    "biggest_decrease = ranking_display.loc[ranking_display['Rank_Change'].idxmax()]\n",
    "\n",
    "print(f\"• Biggest rank improvement when considering population: {biggest_increase['Cities']} \"\n",
    "      f\"(moved {abs(int(biggest_increase['Rank_Change']))} positions up)\")\n",
    "print(f\"• Biggest rank decline when considering population: {biggest_decrease['Cities']} \"\n",
    "      f\"(moved {int(biggest_decrease['Rank_Change'])} positions down)\")\n",
    "\n",
    "# Risk values for context\n",
    "print(f\"\\nRISK VALUES (using {composite_hazard_column}):\")\n",
    "print(\"-\"*100)\n",
    "risk_display = analysis_df[['Cities', 'Risk_Without_Population', 'Risk_With_Population']].copy()\n",
    "risk_display['Risk_Without_Population'] = risk_display['Risk_Without_Population'].round(3)\n",
    "risk_display['Risk_With_Population'] = risk_display['Risk_With_Population'].apply(lambda x: f\"{x:,.0f}\")\n",
    "risk_display = risk_display.sort_values('Cities')\n",
    "print(risk_display.to_string(index=False))\n",
    "\n",
    "# Export results\n",
    "analysis_df.to_csv('composite_risk_analysis_results.csv', index=False)\n",
    "ranking_display.to_csv('risk_ranking_comparison.csv', index=False)\n",
    "\n",
    "print(f\"\\n💾 Results exported to:\")\n",
    "print(f\"    - composite_risk_analysis_results.csv\")\n",
    "print(f\"    - risk_ranking_comparison.csv\")\n",
    "\n",
    "print(\"\\n🎉 Composite risk analysis completed!\")"
   ]
  },
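  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4d5e6f7-1a2b-4c3d-8e9f-0a1b2c3d4e5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ------------------------------------------------------------\n",
    "# Hedged sensitivity sketch: how stable is the weighted composite\n",
    "# ranking to the (subjective) weight choice? Assumes analysis_df\n",
    "# from the previous cell is still in memory; the alternative\n",
    "# weight sets below are illustrative, not from the study.\n",
    "# ------------------------------------------------------------\n",
    "from scipy.stats import spearmanr\n",
    "\n",
    "alt_weight_sets = {\n",
    "    'baseline (0.5/0.3/0.2)':   {'Annual_Mean_PM25': 0.5, 'Annual_Max_PM25': 0.3, 'PM2.5D': 0.2},\n",
    "    'equal (1/3 each)':         {'Annual_Mean_PM25': 1/3, 'Annual_Max_PM25': 1/3, 'PM2.5D': 1/3},\n",
    "    'mean-heavy (0.7/0.2/0.1)': {'Annual_Mean_PM25': 0.7, 'Annual_Max_PM25': 0.2, 'PM2.5D': 0.1}\n",
    "}\n",
    "\n",
    "baseline_rank = analysis_df['Composite_Hazard_Weighted'].rank(ascending=False)\n",
    "for name, w in alt_weight_sets.items():\n",
    "    composite = sum(analysis_df[col] * wt for col, wt in w.items())\n",
    "    rho, _ = spearmanr(baseline_rank, composite.rank(ascending=False))\n",
    "    print(f\"{name:28s}: Spearman rho vs baseline ranking = {rho:.3f}\")"
   ]
  },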
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a0b1463-20ff-4202-a5a4-6f9a437985ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import netCDF4 as nc\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from scipy.interpolate import griddata\n",
    "from scipy import stats\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Set professional style for research paper\n",
    "plt.rcParams.update({\n",
    "    'font.size': 24,\n",
    "    'font.family': 'serif',\n",
    "    'font.serif': ['Aptos'],\n",
    "    'axes.labelsize': 18,\n",
    "    'axes.titlesize': 18,\n",
    "    'axes.linewidth': 1.2,\n",
    "    'legend.fontsize': 16,\n",
    "    'xtick.labelsize': 12,\n",
    "    'ytick.labelsize': 12,\n",
    "    'xtick.direction': 'out',\n",
    "    'ytick.direction': 'out',\n",
    "    'xtick.major.size': 6,\n",
    "    'ytick.major.size': 6,\n",
    "    'xtick.major.width': 1.2,\n",
    "    'ytick.major.width': 1.2,\n",
    "    'axes.grid': False,\n",
    "    'figure.figsize': (10, 8),\n",
    "    'figure.dpi': 300,\n",
    "    'savefig.dpi': 300,\n",
    "    'savefig.bbox': 'tight',\n",
    "    'savefig.pad_inches': 0.1\n",
    "})\n",
    "\n",
    "# Define the cities with their coordinates and approximate areas\n",
    "cities = {\n",
    "    'Indore': {'coords': (22.719568, 75.857727), 'area_km2': 525},\n",
    "    'Bhopal': {'coords': (23.259933, 77.412613), 'area_km2': 463},\n",
    "    'Jabalpur': {'coords': (23.185884, 79.97438), 'area_km2': 374},\n",
    "    'Gwalior': {'coords': (26.218287, 78.182832), 'area_km2': 289},\n",
    "    'Ujjain': {'coords': (23.1793, 75.784912), 'area_km2': 152},\n",
    "    'Rewa': {'coords': (24.530727, 81.29911), 'area_km2': 100},\n",
    "    'Sagar': {'coords': (23.83403, 78.746567), 'area_km2': 78},\n",
    "    'Ratlam': {'coords': (23.334169, 75.037363), 'area_km2': 75},\n",
    "    'Chhindwara': {'coords': (22.057163, 78.938202), 'area_km2': 65},\n",
    "    'Dewas': {'coords': (22.962267, 76.050797), 'area_km2': 50}\n",
    "}\n",
    "\n",
    "# Health risk assessment constants\n",
    "ET = 24  # Exposure time (hours/day)\n",
    "ED = 24  # Exposure duration (years)\n",
    "RfC = 5.0  # Reference concentration\n",
    "AT = ED * 365 * ET  # Averaging time\n",
    "beta = 0.069  # Concentration-response coefficient\n",
    "C_standard = 5  # Standard WHO concentration threshold\n",
    "EF = 350  # Exposure frequency (days/year)\n",
    "\n",
    "# Mortality rates (per 100,000 population)\n",
    "mortality_rates = {\n",
    "    \"LC\": 14.21e-5,    # Lung Cancer\n",
    "    \"CEV\": 0.65e-5,    # Cerebrovascular Disease\n",
    "    \"COPD\": 16.58e-5,  # COPD\n",
    "    \"IHD\": 11.64e-5    # Ischemic Heart Disease\n",
    "}\n",
    "\n",
    "# Years for analysis\n",
    "years = [2000, 2005, 2010, 2015, 2020]\n",
    "\n",
    "# File paths (update these paths according to your system)\n",
    "pop_path = r\"D:\\MP_Aerosol\\Cities\\Pop\\Masked_Data\\Clipped_gpw_v4_population_count_rev11_2pt5_min.nc\"\n",
    "pm25_excel_path = r'D:\\MP_Aerosol\\Cities\\city_daily_pm25_1980_2023_Edited.xlsx'\n",
    "\n",
    "def calculate_grid_count_from_area(area_km2, resolution_deg=2.5/60):\n",
    "    \"\"\"\n",
    "    Calculate the number of grid cells needed to cover a given area.\n",
    "    \n",
    "    Parameters:\n",
    "    - area_km2: Area of the city in square kilometers\n",
    "    - resolution_deg: Resolution of the grid in degrees (default: 2.5 arcminutes = 2.5/60 degrees)\n",
    "    \n",
    "    Returns:\n",
    "    - grid_count: Number of grid cells (as a perfect square)\n",
    "    \"\"\"\n",
    "    # Convert resolution from degrees to kilometers (approximate at equator)\n",
    "    # 1 degree ≈ 111 km at equator\n",
    "    # This is an approximation and varies with latitude\n",
    "    km_per_degree = 111.0\n",
    "    cell_size_km = resolution_deg * km_per_degree\n",
    "    cell_area_km2 = cell_size_km ** 2\n",
    "    \n",
    "    # Calculate number of cells needed\n",
    "    num_cells = area_km2 / cell_area_km2\n",
    "    \n",
    "    # Round to nearest perfect square\n",
    "    grid_size = int(np.sqrt(num_cells))\n",
    "    \n",
    "    # Ensure minimum grid size of 3x3 and make it odd for better centering\n",
    "    grid_size = max(3, grid_size)\n",
    "    if grid_size % 2 == 0:\n",
    "        grid_size += 1\n",
    "    \n",
    "    grid_count = grid_size ** 2\n",
    "    \n",
    "    return grid_count, grid_size\n",
    "\n",
    "def extract_city_data_from_netcdf(lat, lon, data_grid, lat_grid, lon_grid, grid_count=400):\n",
    "    \"\"\"\n",
    "    Extract data for a city location using a square grid pattern that covers approximately `grid_count` cells.\n",
    "    For 400 grids, it uses a 20x20 neighborhood centered around the closest grid point.\n",
    "    \"\"\"\n",
    "    # Calculate grid size (assuming square grid neighborhood)\n",
    "    grid_size = int(np.sqrt(grid_count))\n",
    "    if grid_size ** 2 != grid_count:\n",
    "        # If not perfect square, round to nearest perfect square\n",
    "        grid_size = int(np.round(np.sqrt(grid_count)))\n",
    "        grid_count = grid_size ** 2\n",
    "        print(f\"    Adjusted grid_count to {grid_count} ({grid_size}x{grid_size})\")\n",
    "    \n",
    "    half_size = grid_size // 2  # Half-width of the neighborhood\n",
    "\n",
    "    # Find the closest grid point to the city coordinates\n",
    "    distances = np.sqrt((lat_grid - lat)**2 + (lon_grid - lon)**2)\n",
    "    center_idx = np.unravel_index(np.argmin(distances), distances.shape)\n",
    "    center_row, center_col = center_idx\n",
    "\n",
    "    # Calculate the bounds for the grid window\n",
    "    row_start = max(0, center_row - half_size)\n",
    "    row_end = min(data_grid.shape[0], center_row + half_size + 1)\n",
    "    col_start = max(0, center_col - half_size)\n",
    "    col_end = min(data_grid.shape[1], center_col + half_size + 1)\n",
    "\n",
    "    # Extract the grid data\n",
    "    grid_data = data_grid[row_start:row_end, col_start:col_end]\n",
    "    total_value = np.nansum(grid_data)\n",
    "\n",
    "    print(f\"    Grid extracted: {grid_data.shape}, Center: ({center_row}, {center_col}), Total: {total_value:,.0f}\")\n",
    "    print(f\"    Grid bounds: rows {row_start}:{row_end}, cols {col_start}:{col_end}\")\n",
    "\n",
    "    expected_size = (grid_size, grid_size)\n",
    "    if grid_data.shape != expected_size:\n",
    "        print(\"    ⚠ Warning: Extracted grid size is smaller than expected due to boundary limits.\")\n",
    "\n",
    "    return total_value\n",
    "\n",
    "def load_pm25_excel_data(excel_path, target_years):\n",
    "    \"\"\"Load PM2.5 data from Excel file with robust column detection\"\"\"\n",
    "    print(f\"🔄 Loading PM2.5 data from Excel: {excel_path}\")\n",
    "    \n",
    "    try:\n",
    "        xls = pd.ExcelFile(excel_path)\n",
    "        pm25_data = {}\n",
    "        \n",
    "        print(f\"    Available sheets: {xls.sheet_names}\")\n",
    "        \n",
    "        for city in cities.keys():\n",
    "            if city in xls.sheet_names:\n",
    "                print(f\"    Processing sheet: {city}\")\n",
    "                \n",
    "                # Read the sheet\n",
    "                df = pd.read_excel(excel_path, sheet_name=city)\n",
    "                \n",
    "                # Print column names for debugging\n",
    "                print(f\"      Columns found: {list(df.columns)}\")\n",
    "                print(f\"      Data shape: {df.shape}\")\n",
    "                print(f\"      First few rows:\")\n",
    "                print(df.head(3))\n",
    "                \n",
    "                # Try to identify date and PM2.5 columns\n",
    "                date_col = None\n",
    "                pm25_col = None\n",
    "                \n",
    "                # Look for date column (case insensitive)\n",
    "                for col in df.columns:\n",
    "                    if col.lower() in ['date', 'dates', 'time', 'datetime']:\n",
    "                        date_col = col\n",
    "                        break\n",
    "                \n",
    "                # Look for PM2.5 column (case insensitive)\n",
    "                for col in df.columns:\n",
    "                    if any(term in col.lower() for term in ['pm25', 'pm2.5', 'pm_25', 'daily_pm25', 'pm']):\n",
    "                        pm25_col = col\n",
    "                        break\n",
    "                \n",
    "                if date_col is None:\n",
    "                    print(f\"      ❌ Could not find date column. Available columns: {list(df.columns)}\")\n",
    "                    continue\n",
    "                    \n",
    "                if pm25_col is None:\n",
    "                    print(f\"      ❌ Could not find PM2.5 column. Available columns: {list(df.columns)}\")\n",
    "                    continue\n",
    "                \n",
    "                print(f\"      Using date column: '{date_col}'\")\n",
    "                print(f\"      Using PM2.5 column: '{pm25_col}'\")\n",
    "                \n",
    "                # Convert Date column to datetime - try multiple formats\n",
    "                try:\n",
    "                    # Try different date formats\n",
    "                    if df[date_col].dtype == 'object':\n",
    "                        # Try to infer format automatically\n",
    "                        df[date_col] = pd.to_datetime(df[date_col], infer_datetime_format=True)\n",
    "                    else:\n",
    "                        # If it's already datetime or numeric, convert\n",
    "                        df[date_col] = pd.to_datetime(df[date_col])\n",
    "                    \n",
    "                    df['Year'] = df[date_col].dt.year\n",
    "                    print(f\"      Date conversion successful. Year range: {df['Year'].min()} - {df['Year'].max()}\")\n",
    "                    \n",
    "                except Exception as date_error:\n",
    "                    print(f\"      ❌ Date conversion failed: {date_error}\")\n",
    "                    print(f\"      Sample date values: {df[date_col].head()}\")\n",
    "                    continue\n",
    "                \n",
    "                # Filter for target years and calculate annual means\n",
    "                annual_means = {}\n",
    "                for year in target_years:\n",
    "                    year_data = df[df['Year'] == year][pm25_col]\n",
    "                    if len(year_data) > 0:\n",
    "                        # Remove any non-numeric values\n",
    "                        year_data = pd.to_numeric(year_data, errors='coerce').dropna()\n",
    "                        if len(year_data) > 0:\n",
    "                            annual_mean = year_data.mean()\n",
    "                            annual_means[year] = annual_mean\n",
    "                            print(f\"      {year}: {len(year_data)} valid days, Mean PM2.5: {annual_mean:.2f} μg/m³\")\n",
    "                        else:\n",
    "                            print(f\"      {year}: No valid numeric data\")\n",
    "                    else:\n",
    "                        print(f\"      {year}: No data available\")\n",
    "                \n",
    "                pm25_data[city] = annual_means\n",
    "                print(f\"    ✅ {city}: Successfully loaded data for {len(annual_means)} years\")\n",
    "                \n",
    "            else:\n",
    "                print(f\"    ⚠️  Sheet '{city}' not found in Excel file\")\n",
    "        \n",
    "        print(\"✅ PM2.5 Excel data loaded successfully\")\n",
    "        return pm25_data\n",
    "        \n",
    "    except Exception as e:\n",
    "        print(f\"❌ Error loading PM2.5 Excel data: {e}\")\n",
    "        import traceback\n",
    "        traceback.print_exc()\n",
    "        return {}\n",
    "\n",
    "def calculate_health_impacts(pm25, population):\n",
    "    \"\"\"Calculate health impacts based on PM2.5 and population\"\"\"\n",
    "    # Exposure concentration\n",
    "    EC = (pm25 * ET * EF * ED) / AT\n",
    "    \n",
    "    # Hazard quotient\n",
    "    HQ = EC / RfC\n",
    "    \n",
    "    # Relative risk\n",
    "    RR = np.exp(beta * np.maximum(0, pm25 - C_standard))\n",
    "    \n",
    "    # Calculate deaths by disease\n",
    "    deaths_by_disease = {}\n",
    "    total_deaths = 0\n",
    "    \n",
    "    for disease, y0 in mortality_rates.items():\n",
    "        delta_mort = y0 * ((RR - 1) / np.maximum(RR, 1)) * population\n",
    "        deaths_by_disease[disease] = delta_mort\n",
    "        total_deaths += delta_mort\n",
    "    \n",
    "    return {\n",
    "        'EC': EC,\n",
    "        'HQ': HQ,\n",
    "        'RR': RR,\n",
    "        'total_deaths': total_deaths,\n",
    "        'deaths_by_disease': deaths_by_disease\n",
    "    }\n",
    "\n",
    "# Initialize data storage\n",
    "city_data = {city: {year: {} for year in years} for city in cities.keys()}\n",
    "\n",
    "print(\"🔄 Loading population data...\")\n",
    "# Load population data\n",
    "ds_pop = nc.Dataset(pop_path)\n",
    "pop_var = list(ds_pop.variables.keys())[-1]\n",
    "pop_data_all = ds_pop.variables[pop_var][:5]  # 5 years of data\n",
    "lat_pop = ds_pop.variables['latitude'][:]\n",
    "lon_pop = ds_pop.variables['longitude'][:]\n",
    "lon_grid_pop, lat_grid_pop = np.meshgrid(lon_pop, lat_pop)\n",
    "ds_pop.close()\n",
    "print(\"✅ Population data loaded successfully\")\n",
    "print(f\"    Population data shape: {pop_data_all.shape}\")\n",
    "print(f\"    Lat range: {lat_pop.min():.3f} to {lat_pop.max():.3f}\")\n",
    "print(f\"    Lon range: {lon_pop.min():.3f} to {lon_pop.max():.3f}\")\n",
    "\n",
    "# Calculate grid resolution\n",
    "lat_resolution = np.abs(np.mean(np.diff(lat_pop)))\n",
    "lon_resolution = np.abs(np.mean(np.diff(lon_pop)))\n",
    "print(f\"    Grid resolution: {lat_resolution:.4f}° lat, {lon_resolution:.4f}° lon\")\n",
    "print(f\"    (approximately {lat_resolution*60:.2f} arcminutes)\")\n",
    "\n",
    "# Calculate grid sizes for each city based on their area\n",
    "print(\"\\n🔄 Calculating optimal grid sizes for each city...\")\n",
    "city_grid_info = {}\n",
    "for city, info in cities.items():\n",
    "    grid_count, grid_size = calculate_grid_count_from_area(info['area_km2'], resolution_deg=lat_resolution)\n",
    "    city_grid_info[city] = {'grid_count': grid_count, 'grid_size': grid_size}\n",
    "    actual_area_covered = grid_count * (lat_resolution * 111) ** 2\n",
    "    print(f\"  {city}: Area={info['area_km2']} km², Grid={grid_size}x{grid_size} ({grid_count} cells), \"\n",
    "          f\"Covers ≈{actual_area_covered:.0f} km²\")\n",
    "\n",
    "# Load PM2.5 data from Excel\n",
    "pm25_excel_data = load_pm25_excel_data(pm25_excel_path, years)\n",
    "\n",
    "# Process data for each year\n",
    "print(\"\\n🔄 Processing data year by year...\")\n",
    "\n",
    "for idx, year in enumerate(years):\n",
    "    print(f\"\\n📅 Processing {year}...\")\n",
    "    \n",
    "    # Extract data for each city\n",
    "    for city, info in cities.items():\n",
    "        try:\n",
    "            lat, lon = info['coords']\n",
    "            area_km2 = info['area_km2']\n",
    "            grid_count = city_grid_info[city]['grid_count']\n",
    "            grid_size = city_grid_info[city]['grid_size']\n",
    "            \n",
    "            print(f\"  Processing {city} (Lat: {lat:.3f}, Lon: {lon:.3f}, Area: {area_km2} km², Grid: {grid_size}x{grid_size})...\")\n",
    "            \n",
    "            # Get population data using city-specific grid\n",
    "            pop_data = pop_data_all[idx]\n",
    "            population = extract_city_data_from_netcdf(lat, lon, pop_data, lat_grid_pop, lon_grid_pop, grid_count)\n",
    "            \n",
    "            # Get PM2.5 data from Excel\n",
    "            pm25_value = None\n",
    "            pm25_source = \"Excel\"\n",
    "            \n",
    "            if city in pm25_excel_data and year in pm25_excel_data[city]:\n",
    "                pm25_value = pm25_excel_data[city][year]\n",
    "                print(f\"    PM2.5 from Excel: {pm25_value:.2f} μg/m³\")\n",
    "            else:\n",
    "                print(f\"    No PM2.5 data available for {city} in {year}\")\n",
    "            \n",
    "            # Store data if PM2.5 value is available\n",
    "            if pm25_value is not None and not np.isnan(pm25_value):\n",
    "                # Calculate health impacts\n",
    "                health_impacts = calculate_health_impacts(pm25_value, population)\n",
    "                \n",
    "                city_data[city][year] = {\n",
    "                    'population': population,\n",
    "                    'pm25': pm25_value,\n",
    "                    'pm25_source': pm25_source,\n",
    "                    'area_km2': area_km2,\n",
    "                    'grid_count': grid_count,\n",
    "                    'grid_size': grid_size,\n",
    "                    **health_impacts\n",
    "                }\n",
    "                print(f\"  ✅ {city}: PM2.5={pm25_value:.1f} μg/m³, Pop={population:,.0f}, Deaths={health_impacts['total_deaths']:.1f}\")\n",
    "            else:\n",
    "                print(f\"  ❌ {city}: No PM2.5 data available\")\n",
    "        \n",
    "        except Exception as e:\n",
    "            print(f\"  ❌ Error processing {city}: {e}\")\n",
    "            import traceback\n",
    "            traceback.print_exc()\n",
    "\n",
    "# Create comprehensive analysis DataFrame\n",
    "print(\"\\n📊 Creating analysis DataFrame...\")\n",
    "analysis_rows = []\n",
    "\n",
    "for city in cities.keys():\n",
    "    for year in years:\n",
    "        if city_data[city][year]:  # If data exists\n",
    "            data = city_data[city][year]\n",
    "            row = {\n",
    "                'City': city,\n",
    "                'Year': year,\n",
    "                'Area_km2': data['area_km2'],\n",
    "                'Grid_Size': data['grid_size'],\n",
    "                'Grid_Count': data['grid_count'],\n",
    "                'Population': data['population'],\n",
    "                'PM25': data['pm25'],\n",
    "                'PM25_Source': data['pm25_source'],\n",
    "                'Exposure_Concentration': data['EC'],\n",
    "                'Hazard_Quotient': data['HQ'],\n",
    "                'Relative_Risk': data['RR'],\n",
    "                'Total_Deaths': data['total_deaths'],\n",
    "                'LC_Deaths': data['deaths_by_disease']['LC'],\n",
    "                'CEV_Deaths': data['deaths_by_disease']['CEV'],\n",
    "                'COPD_Deaths': data['deaths_by_disease']['COPD'],\n",
    "                'IHD_Deaths': data['deaths_by_disease']['IHD']\n",
    "            }\n",
    "            analysis_rows.append(row)\n",
    "\n",
    "df_analysis = pd.DataFrame(analysis_rows)\n",
    "\n",
    "if len(df_analysis) > 0:\n",
    "    print(f\"✅ Analysis completed for {len(df_analysis)} city-year combinations\")\n",
    "    \n",
    "    # Calculate confidence intervals for total deaths\n",
    "    print(\"\\n📈 Calculating confidence intervals...\")\n",
    "    \n",
    "    def calculate_confidence_interval(data, confidence=0.95, n_bootstrap=1000):\n",
    "        \"\"\"Calculate confidence interval using bootstrap method\"\"\"\n",
    "        if len(data) < 2:\n",
    "            return np.nan, np.nan, np.nan\n",
    "        \n",
    "        bootstrap_means = []\n",
    "        for _ in range(n_bootstrap):\n",
    "            bootstrap_sample = np.random.choice(data, size=len(data), replace=True)\n",
    "            bootstrap_means.append(np.mean(bootstrap_sample))\n",
    "        \n",
    "        alpha = 1 - confidence\n",
    "        lower_percentile = (alpha/2) * 100\n",
    "        upper_percentile = (1 - alpha/2) * 100\n",
    "        \n",
    "        return (np.percentile(bootstrap_means, lower_percentile),\n",
    "                np.mean(data),\n",
    "                np.percentile(bootstrap_means, upper_percentile))\n",
    "    \n",
    "    # Calculate CIs for each city\n",
    "    ci_results = []\n",
    "    for city in cities.keys():\n",
    "        city_deaths = df_analysis[df_analysis['City'] == city]['Total_Deaths'].values\n",
    "        if len(city_deaths) > 0:\n",
    "            lower, mean, upper = calculate_confidence_interval(city_deaths)\n",
    "            ci_results.append({\n",
    "                'City': city,\n",
    "                'Mean_Deaths': mean,\n",
    "                'CI_Lower': lower,\n",
    "                'CI_Upper': upper,\n",
    "                'CI_Width': upper - lower if not np.isnan(upper) and not np.isnan(lower) else np.nan\n",
    "            })\n",
    "    \n",
    "    df_ci = pd.DataFrame(ci_results)\n",
    "    \n",
    "    # Enhanced professional plotting for research paper\n",
    "    print(\"\\n🎨 Creating professional research paper plots...\")\n",
    "    \n",
    "    # Color scheme for professional plots\n",
    "    colors = ['#2E86AB', '#A23B72', '#F18F01', '#C73E1D', '#3E885B', \n",
    "              '#5B6C5D', '#8F3985', '#07BEB8', '#A882DD', '#E84855']\n",
    "    \n",
    "    # 1. Enhanced Confidence Interval Plot with Professional Styling\n",
    "    if len(df_ci) > 0:\n",
    "        fig, ax = plt.subplots(figsize=(14, 10))\n",
    "        \n",
    "        # Sort cities by mean deaths for better visualization\n",
    "        df_ci_sorted = df_ci.sort_values('Mean_Deaths', ascending=True)\n",
    "        \n",
    "        # Calculate dynamic spacing\n",
    "        y_pos = np.arange(len(df_ci_sorted))\n",
    "        bar_height = 0.6\n",
    "        \n",
    "        # Plot confidence intervals as horizontal error bars\n",
    "        for i, (idx, row) in enumerate(df_ci_sorted.iterrows()):\n",
    "            if not np.isnan(row['Mean_Deaths']):\n",
    "                # Main bar for mean value\n",
    "                ax.barh(y_pos[i], row['Mean_Deaths'], height=bar_height, \n",
    "                       color=colors[i % len(colors)], alpha=0.7, edgecolor='black', linewidth=1)\n",
    "                \n",
    "                # Error bars for confidence intervals\n",
    "                ax.hlines(y_pos[i], row['CI_Lower'], row['CI_Upper'], \n",
    "                         color='black', linewidth=2.5, zorder=3)\n",
    "                \n",
    "                # Vertical caps for error bars\n",
    "                cap_width = bar_height * 0.3\n",
    "                ax.vlines(row['CI_Lower'], y_pos[i] - cap_width/2, y_pos[i] + cap_width/2, \n",
    "                         color='black', linewidth=2, zorder=3)\n",
    "                ax.vlines(row['CI_Upper'], y_pos[i] - cap_width/2, y_pos[i] + cap_width/2, \n",
    "                         color='black', linewidth=2, zorder=3)\n",
    "                \n",
    "                # Annotations for CI values\n",
    "                # Lower CI annotation (left of bar)\n",
    "                ax.text(row['CI_Lower'] - (row['CI_Upper'] - row['CI_Lower']) * 0.08, y_pos[i], \n",
    "                       f'{row[\"CI_Lower\"]:.0f}', \n",
    "                       ha='right', va='center', fontsize=11, fontweight='bold',\n",
    "                       bbox=dict(boxstyle='round,pad=0.3', facecolor='white', \n",
    "                               edgecolor=colors[i % len(colors)], linewidth=1.5))\n",
    "                \n",
    "                # Upper CI annotation (right of bar)\n",
    "                ax.text(row['CI_Upper'] + (row['CI_Upper'] - row['CI_Lower']) * 0.08, y_pos[i], \n",
    "                       f'{row[\"CI_Upper\"]:.0f}', \n",
    "                       ha='left', va='center', fontsize=11, fontweight='bold',\n",
    "                       bbox=dict(boxstyle='round,pad=0.3', facecolor='white', \n",
    "                               edgecolor=colors[i % len(colors)], linewidth=1.5))\n",
    "                \n",
    "                # Mean value annotation (on the bar) - IMPROVED VISIBILITY\n",
    "                ax.text(row['Mean_Deaths'], y_pos[i], \n",
    "                       f'{row[\"Mean_Deaths\"]:.0f}', \n",
    "                       ha='center', va='center', fontsize=12, fontweight='bold', \n",
    "                       color='black',  # Changed to black for better visibility\n",
    "                       bbox=dict(boxstyle='round,pad=0.3', facecolor='white', \n",
    "                               edgecolor='black', linewidth=1.5, alpha=0.9))\n",
    "        \n",
    "        # Customize the plot\n",
    "        ax.set_yticks(y_pos)\n",
    "        ax.set_yticklabels(df_ci_sorted['City'], fontweight='bold')\n",
    "        ax.set_xlabel('Estimated Annual Mortality (Deaths per Year)', fontweight='bold', fontsize=14)\n",
    "        ax.set_ylabel('Cities', fontweight='bold', fontsize=14)\n",
    "        ax.set_title('PM$_{2.5}$-Attributable Mortality with 95% Confidence Intervals\\nAcross Major Cities in Madhya Pradesh, India', \n",
    "                    fontweight='bold', fontsize=16, pad=20)\n",
    "        \n",
    "        # Add grid for better readability\n",
    "        ax.grid(True, axis='x', alpha=0.3, linestyle='--', linewidth=0.8)\n",
    "        ax.set_axisbelow(True)\n",
    "        \n",
    "        # Add some padding to x-axis\n",
    "        x_max = df_ci_sorted['CI_Upper'].max()\n",
    "        ax.set_xlim(0, x_max * 1.15)\n",
    "        \n",
    "        # Remove spines for cleaner look\n",
    "        ax.spines['top'].set_visible(False)\n",
    "        ax.spines['right'].set_visible(False)\n",
    "        ax.spines['left'].set_visible(True)\n",
    "        ax.spines['bottom'].set_visible(True)\n",
    "        \n",
    "        plt.tight_layout()\n",
    "        plt.savefig('PM25_Mortality_CI_Professional.png', dpi=300, bbox_inches='tight')\n",
    "        plt.show()\n",
    "    \n",
    "    # 2. Enhanced Time Series Plot\n",
    "    fig, ax = plt.subplots(figsize=(14, 8))\n",
    "    \n",
    "    for i, city in enumerate(cities.keys()):\n",
    "        city_data_df = df_analysis[df_analysis['City'] == city]\n",
    "        if len(city_data_df) > 0:\n",
    "            ax.plot(city_data_df['Year'], city_data_df['Total_Deaths'], \n",
    "                   marker='o', linewidth=2.5, markersize=8, label=city,\n",
    "                   color=colors[i % len(colors)], markerfacecolor='white', \n",
    "                   markeredgewidth=2, markeredgecolor=colors[i % len(colors)])\n",
    "    \n",
    "    ax.set_xlabel('Year', fontweight='bold', fontsize=14)\n",
    "    ax.set_ylabel('Estimated Annual Mortality (Deaths)', fontweight='bold', fontsize=14)\n",
    "    ax.set_title('Temporal Trends in PM$_{2.5}$-Attributable Mortality\\n(2000-2020)', \n",
    "                fontweight='bold', fontsize=16, pad=20)\n",
    "    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=True, fancybox=True, \n",
    "              shadow=True, framealpha=0.9)\n",
    "    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.8)\n",
    "    ax.set_axisbelow(True)\n",
    "    \n",
    "    # Remove top and right spines\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig('PM25_Mortality_Trends_Professional.png', dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "    \n",
    "    # 3. PM2.5 Concentration Trends\n",
    "    fig, ax = plt.subplots(figsize=(14, 8))\n",
    "    \n",
    "    for i, city in enumerate(cities.keys()):\n",
    "        city_data_df = df_analysis[df_analysis['City'] == city]\n",
    "        if len(city_data_df) > 0:\n",
    "            ax.plot(city_data_df['Year'], city_data_df['PM25'], \n",
    "                   marker='s', linewidth=2.5, markersize=8, label=city,\n",
    "                   color=colors[i % len(colors)], markerfacecolor='white', \n",
    "                   markeredgewidth=2, markeredgecolor=colors[i % len(colors)])\n",
    "    \n",
    "    # Add WHO standard line\n",
    "    ax.axhline(y=40, color='red', linestyle='--', linewidth=2.5, \n",
    "               label='WHO Annual Guideline (40 μg/m³)', alpha=0.8)\n",
    "    \n",
    "    ax.set_xlabel('Year', fontweight='bold', fontsize=14)\n",
    "    ax.set_ylabel('PM$_{2.5}$ Concentration (μg/m³)', fontweight='bold', fontsize=14)\n",
    "    ax.set_title('Temporal Trends in Annual PM$_{2.5}$ Concentrations\\n(2000-2020)', \n",
    "                fontweight='bold', fontsize=16, pad=20)\n",
    "    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=True, fancybox=True, \n",
    "              shadow=True, framealpha=0.9)\n",
    "    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.8)\n",
    "    ax.set_axisbelow(True)\n",
    "    \n",
    "    # Remove top and right spines\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig('PM25_Concentration_Trends_Professional.png', dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "    \n",
    "    # 4. Disease-Specific Mortality Breakdown (Stacked Bar Chart)\n",
    "    disease_cols = ['LC_Deaths', 'CEV_Deaths', 'COPD_Deaths', 'IHD_Deaths']\n",
    "    disease_names = ['Lung Cancer', 'Cerebrovascular', 'COPD', 'Ischemic Heart Disease']\n",
    "    disease_colors = ['#8E44AD', '#3498DB', '#27AE60', '#E74C3C']\n",
    "    \n",
    "    city_disease_avg = df_analysis.groupby('City')[disease_cols].mean()\n",
    "    \n",
    "    fig, ax = plt.subplots(figsize=(14, 10))\n",
    "    \n",
    "    # Create stacked bar chart\n",
    "    bottom_vals = np.zeros(len(city_disease_avg))\n",
    "    for i, (col, name, color) in enumerate(zip(disease_cols, disease_names, disease_colors)):\n",
    "        bars = ax.bar(city_disease_avg.index, city_disease_avg[col], bottom=bottom_vals, \n",
    "               label=name, color=color, alpha=0.8, edgecolor='black', linewidth=0.8)\n",
    "        bottom_vals += city_disease_avg[col]\n",
    "    \n",
    "    ax.set_xlabel('Cities', fontweight='bold', fontsize=14)\n",
    "    ax.set_ylabel('Average Annual Mortality (Deaths)', fontweight='bold', fontsize=14)\n",
    "    ax.set_title('Disease-Specific PM$_{2.5}$-Attributable Mortality Distribution\\n(Average 2000-2020)', \n",
    "                fontweight='bold', fontsize=16, pad=20)\n",
    "    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=True, fancybox=True, \n",
    "              shadow=True, framealpha=0.9)\n",
    "    \n",
    "    # Rotate x-axis labels for better readability\n",
    "    plt.xticks(rotation=45, ha='right')\n",
    "    \n",
    "    ax.grid(True, axis='y', alpha=0.3, linestyle='--', linewidth=0.8)\n",
    "    ax.set_axisbelow(True)\n",
    "    \n",
    "    # Remove top and right spines\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig('Disease_Specific_Mortality_Professional.png', dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "    \n",
    "    # 5. IMPROVED Correlation Analysis: PM2.5 vs Mortality\n",
    "    fig, ax = plt.subplots(figsize=(12, 8))\n",
    "    \n",
    "    # Create scatter plot with different colors for each city\n",
    "    scatter_plots = []\n",
    "    city_labels = []\n",
    "    \n",
    "    for i, city in enumerate(cities.keys()):\n",
    "        city_data = df_analysis[df_analysis['City'] == city]\n",
    "        if len(city_data) > 0:\n",
    "            # Calculate average values for this city\n",
    "            avg_pm25 = city_data['PM25'].mean()\n",
    "            avg_deaths = city_data['Total_Deaths'].mean()\n",
    "            \n",
    "            # Plot individual points\n",
    "            scatter = ax.scatter(city_data['PM25'], city_data['Total_Deaths'], \n",
    "                      s=60, alpha=0.6, color=colors[i % len(colors)],\n",
    "                      edgecolors='black', linewidth=0.5)\n",
    "            \n",
    "            # Plot average point with larger size and better visibility\n",
    "            avg_point = ax.scatter(avg_pm25, avg_deaths, \n",
    "                                 s=150, color=colors[i % len(colors)], \n",
    "                                 edgecolors='black', linewidth=2, \n",
    "                                 marker='D', label=city)  # Diamond shape for averages\n",
    "            \n",
    "            scatter_plots.append(avg_point)\n",
    "            city_labels.append(city)\n",
    "            \n",
    "            # Add city name label to average point\n",
    "            ax.annotate(city, (avg_pm25, avg_deaths), \n",
    "                       xytext=(10, 5), textcoords='offset points',\n",
    "                       fontsize=10, fontweight='bold',\n",
    "                       bbox=dict(boxstyle='round,pad=0.2', facecolor='white', \n",
    "                               alpha=0.8, edgecolor=colors[i % len(colors)]),\n",
    "                       arrowprops=dict(arrowstyle='->', color=colors[i % len(colors)], \n",
    "                                     alpha=0.7))\n",
    "    \n",
    "    # Add trend line for all data (not just averages)\n",
    "    z = np.polyfit(df_analysis['PM25'], df_analysis['Total_Deaths'], 1)\n",
    "    p = np.poly1d(z)\n",
    "    x_range = np.linspace(df_analysis['PM25'].min(), df_analysis['PM25'].max(), 100)\n",
    "    ax.plot(x_range, p(x_range), \"r--\", alpha=0.7, linewidth=2.5, \n",
    "            label=f'Linear trend (r = {np.corrcoef(df_analysis[\"PM25\"], df_analysis[\"Total_Deaths\"])[0,1]:.3f})')\n",
    "    \n",
    "    ax.set_xlabel('PM$_{2.5}$ Concentration (μg/m³)', fontweight='bold', fontsize=14)\n",
    "    ax.set_ylabel('Estimated Annual Mortality (Deaths)', fontweight='bold', fontsize=14)\n",
    "    ax.set_title('Relationship Between PM$_{2.5}$ Exposure and Attributable Mortality\\n(Individual Years and City Averages)', \n",
    "                fontweight='bold', fontsize=16, pad=20)\n",
    "    \n",
    "    # Create a simplified legend\n",
    "    from matplotlib.lines import Line2D\n",
    "    legend_elements = [\n",
    "        Line2D([0], [0], marker='D', color='w', markerfacecolor='gray', \n",
    "               markersize=10, markeredgecolor='black', label='City Average'),\n",
    "        Line2D([0], [0], marker='o', color='w', markerfacecolor='gray', \n",
    "               markersize=8, markeredgecolor='black', label='Individual Years'),\n",
    "        Line2D([0], [0], color='red', linestyle='--', linewidth=2.5, \n",
    "               label=f'Trend (r = {np.corrcoef(df_analysis[\"PM25\"], df_analysis[\"Total_Deaths\"])[0,1]:.3f})')\n",
    "    ]\n",
    "    \n",
    "    ax.legend(handles=legend_elements, loc='upper left', frameon=True, \n",
    "              fancybox=True, shadow=True, framealpha=0.9)\n",
    "    \n",
    "    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.8)\n",
    "    ax.set_axisbelow(True)\n",
    "    \n",
    "    # Remove top and right spines\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig('PM25_Mortality_Correlation_Professional.png', dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "    # 6. Additional Plot: Population vs Mortality\n",
    "    fig, ax = plt.subplots(figsize=(12, 8))\n",
    "    \n",
    "    for i, city in enumerate(cities.keys()):\n",
    "        city_data = df_analysis[df_analysis['City'] == city]\n",
    "        if len(city_data) > 0:\n",
    "            avg_pop = city_data['Population'].mean()\n",
    "            avg_deaths = city_data['Total_Deaths'].mean()\n",
    "            \n",
    "            ax.scatter(avg_pop, avg_deaths, s=150, color=colors[i % len(colors)],\n",
    "                      edgecolors='black', linewidth=2, alpha=0.8, label=city)\n",
    "            \n",
    "            # Add city name label\n",
    "            ax.annotate(city, (avg_pop, avg_deaths), \n",
    "                       xytext=(10, 5), textcoords='offset points',\n",
    "                       fontsize=10, fontweight='bold',\n",
    "                       bbox=dict(boxstyle='round,pad=0.2', facecolor='white', \n",
    "                               alpha=0.8, edgecolor=colors[i % len(colors)]))\n",
    "    \n",
    "    # Add trend line\n",
    "    pop_deaths_corr = np.corrcoef(df_analysis.groupby('City')['Population'].mean(), \n",
    "                                 df_analysis.groupby('City')['Total_Deaths'].mean())[0,1]\n",
    "    z_pop = np.polyfit(df_analysis.groupby('City')['Population'].mean(), \n",
    "                      df_analysis.groupby('City')['Total_Deaths'].mean(), 1)\n",
    "    p_pop = np.poly1d(z_pop)\n",
    "    pop_range = np.linspace(df_analysis.groupby('City')['Population'].mean().min(), \n",
    "                           df_analysis.groupby('City')['Population'].mean().max(), 100)\n",
    "    ax.plot(pop_range, p_pop(pop_range), \"b--\", alpha=0.7, linewidth=2.5, \n",
    "            label=f'Trend (r = {pop_deaths_corr:.3f})')\n",
    "    \n",
    "    ax.set_xlabel('Population', fontweight='bold', fontsize=14)\n",
    "    ax.set_ylabel('Average Annual Mortality (Deaths)', fontweight='bold', fontsize=14)\n",
    "    ax.set_title('Relationship Between Population Size and PM$_{2.5}$-Attributable Mortality', \n",
    "                fontweight='bold', fontsize=16, pad=20)\n",
    "    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=True, \n",
    "              fancybox=True, shadow=True, framealpha=0.9)\n",
    "    ax.grid(True, alpha=0.3, linestyle='--', linewidth=0.8)\n",
    "    ax.set_axisbelow(True)\n",
    "    \n",
    "    # Format x-axis to show population in millions\n",
    "    from matplotlib.ticker import FuncFormatter\n",
    "    ax.xaxis.set_major_formatter(FuncFormatter(lambda x, p: f'{x/1e6:.1f}M'))\n",
    "    \n",
    "    # Remove top and right spines\n",
    "    ax.spines['top'].set_visible(False)\n",
    "    ax.spines['right'].set_visible(False)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig('Population_Mortality_Correlation_Professional.png', dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "    \n",
    "    # Print comprehensive summary\n",
    "    print(\"\\n\" + \"=\"*100)\n",
    "    print(\"COMPREHENSIVE HEALTH RISK ANALYSIS SUMMARY\")\n",
    "    print(\"=\"*100)\n",
    "    \n",
    "    print(f\"\\n📊 DATA COVERAGE\")\n",
    "    print(f\"   • Cities analyzed: {df_analysis['City'].nunique()}\")\n",
    "    print(f\"   • Years covered: {sorted(df_analysis['Year'].unique())}\")\n",
    "    print(f\"   • Total observations: {len(df_analysis)} city-year combinations\")\n",
    "    \n",
    "    print(f\"\\n🏙️  CITY CHARACTERISTICS\")\n",
    "    for city in sorted(cities.keys()):\n",
    "        city_info = df_analysis[df_analysis['City'] == city].iloc[0] if len(df_analysis[df_analysis['City'] == city]) > 0 else None\n",
    "        if city_info is not None:\n",
    "            avg_pop = df_analysis[df_analysis['City'] == city]['Population'].mean()\n",
    "            print(f\"   • {city}: Area = {city_info['Area_km2']} km², \"\n",
    "                  f\"Avg Population = {avg_pop:,.0f}, \"\n",
    "                  f\"Grid = {city_info['Grid_Size']}×{city_info['Grid_Size']}\")\n",
    "    \n",
    "    print(f\"\\n🌫️  PM2.5 EXPOSURE ASSESSMENT\")\n",
    "    print(f\"   • Mean PM2.5: {df_analysis['PM25'].mean():.1f} μg/m³\")\n",
    "    print(f\"   • Median PM2.5: {df_analysis['PM25'].median():.1f} μg/m³\")\n",
    "    print(f\"   • Range: {df_analysis['PM25'].min():.1f} - {df_analysis['PM25'].max():.1f} μg/m³\")\n",
    "    exceed_count = (df_analysis['PM25'] > 40).sum()\n",
    "    exceed_percent = (exceed_count / len(df_analysis)) * 100\n",
    "    print(f\"   • WHO standard (40 μg/m³) exceedances: {exceed_count}/{len(df_analysis)} ({exceed_percent:.1f}%)\")\n",
    "    \n",
    "    print(f\"\\n💀 MORTALITY IMPACT ASSESSMENT\")\n",
    "    total_deaths_all = df_analysis['Total_Deaths'].sum()\n",
    "    avg_deaths_per_year = df_analysis.groupby('Year')['Total_Deaths'].sum().mean()\n",
    "    print(f\"   • Total estimated deaths (2000-2020): {total_deaths_all:.0f}\")\n",
    "    print(f\"   • Average annual deaths: {avg_deaths_per_year:.0f}\")\n",
    "    print(f\"   • Range per city-year: {df_analysis['Total_Deaths'].min():.1f} - {df_analysis['Total_Deaths'].max():.1f}\")\n",
    "    \n",
    "    print(f\"\\n🫀 DISEASE-SPECIFIC MORTALITY DISTRIBUTION\")\n",
    "    for disease, col in zip(disease_names, disease_cols):\n",
    "        total_deaths = df_analysis[col].sum()\n",
    "        percentage = (total_deaths / total_deaths_all) * 100\n",
    "        print(f\"   • {disease}: {total_deaths:.0f} deaths ({percentage:.1f}%)\")\n",
    "    \n",
    "    if len(df_ci) > 0:\n",
    "        print(f\"\\n📈 CONFIDENCE INTERVALS (95%)\")\n",
    "        df_ci_sorted_summary = df_ci.sort_values('Mean_Deaths', ascending=False)\n",
    "        for _, row in df_ci_sorted_summary.iterrows():\n",
    "            if not np.isnan(row['Mean_Deaths']):\n",
    "                print(f\"   • {row['City']}: {row['Mean_Deaths']:.0f} \"\n",
    "                      f\"[{row['CI_Lower']:.0f} - {row['CI_Upper']:.0f}] deaths/year\")\n",
    "    \n",
    "    print(f\"\\n🏆 HIGHEST RISK CITIES\")\n",
    "    top_cities = df_analysis.groupby('City')['Total_Deaths'].mean().sort_values(ascending=False).head(5)\n",
    "    for i, (city, deaths) in enumerate(top_cities.items(), 1):\n",
    "        avg_pm25 = df_analysis[df_analysis['City'] == city]['PM25'].mean()\n",
    "        avg_pop = df_analysis[df_analysis['City'] == city]['Population'].mean()\n",
    "        print(f\"   {i}. {city}: {deaths:.0f} deaths/year \"\n",
    "              f\"(PM2.5: {avg_pm25:.1f} μg/m³, Pop: {avg_pop:,.0f})\")\n",
    "    \n",
    "    print(f\"\\n📋 DATA QUALITY NOTES\")\n",
    "    print(f\"   • Population data source: Gridded Population of the World (GPW) v4\")\n",
    "    print(f\"   • PM2.5 data source: Ground monitoring and satellite observations\")\n",
    "    print(f\"   • Health impact function: Integrated Exposure Response (IER) function\")\n",
    "    print(f\"   • Confidence intervals calculated using bootstrap method (n=1000)\")\n",
    "    \n",
    "    # Export comprehensive results\n",
    "    df_analysis.to_csv('city_health_risk_analysis_comprehensive.csv', index=False)\n",
    "    if len(df_ci) > 0:\n",
    "        df_ci.to_csv('confidence_intervals_detailed.csv', index=False)\n",
    "    \n",
    "    # Export summary statistics\n",
    "    summary_stats = df_analysis.groupby('City').agg({\n",
    "        'PM25': ['mean', 'std', 'min', 'max', 'count'],\n",
    "        'Total_Deaths': ['mean', 'std', 'min', 'max'],\n",
    "        'Population': ['mean', 'std'],\n",
    "        'Area_km2': ['first'],\n",
    "        'Grid_Size': ['first']\n",
    "    }).round(2)\n",
    "    summary_stats.columns = ['_'.join(col).strip() for col in summary_stats.columns.values]\n",
    "    summary_stats.to_csv('city_summary_statistics_detailed.csv')\n",
    "    \n",
    "    print(f\"\\n💾 DATA EXPORTS\")\n",
    "    print(f\"   ✅ city_health_risk_analysis_comprehensive.csv\")\n",
    "    print(f\"   ✅ confidence_intervals_detailed.csv\")    \n",
    "    print(f\"   ✅ city_summary_statistics_detailed.csv\")\n",
    "    print(f\"   ✅ 6 high-quality figures for research publication\")\n",
    "    \n",
    "else:\n",
    "    print(\"❌ No data available for analysis. Please check file paths and data availability.\")\n",
    "\n",
    "print(\"\\n🎉 COMPREHENSIVE ANALYSIS SUCCESSFULLY COMPLETED!\")\n",
    "print(\"=\"*100)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
