{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bbe1af3-83b9-4705-bce1-a3fe178c3bfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy\n",
    "import xarray as xr\n",
    "import os \n",
    "import dask\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt \n",
    "import matplotlib.colors as clr \n",
    "import matplotlib.gridspec as gridspec\n",
    "import matplotlib.lines as mlines\n",
    "# NOTE(review): matplotlib.colors is imported twice (as clr above and as mcolors here).\n",
    "import matplotlib.colors as mcolors\n",
    "import cartopy.crs as ccrs\n",
    "import cartopy.feature as cfeature\n",
    "import calendar\n",
    "\n",
    "#import h2_vd_func as vd\n",
    "from tqdm import tqdm\n",
    "# Star import: presumably supplies the helpers used below but not defined in these cells\n",
    "# (avg_soil_hydraulic, find_min_s, soil_moisture_curve, get_ij, cal_yearly_var,\n",
    "# get_monthly_mean_year_range, shapley_analysis_resistance, filter_threshold and possibly\n",
    "# dt_hot / dt_median). NOTE(review): explicit imports would make these dependencies visible.\n",
    "from analysis_util import *\n",
    "\n",
    "# White backgrounds for on-screen axes and for saved figures\n",
    "plt.rcParams['axes.facecolor']='white'\n",
    "plt.rcParams['savefig.facecolor']='white'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dfd730c4-a8f6-4773-b20c-c5e77a3b663d",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Optimal/critical soil moisture analysis\n",
    "### Load soil texture data\n",
    "\n",
    "with xr.open_dataset('/net/fs01/data/cesm2/inputdata/lnd/clm2/surfdata_esmf/ctsm5.3.0/surfdata_0.9x1.25_hist_2000_78pfts_c240908.nc') as soil_data:\n",
    "    pct_sand = soil_data.PCT_SAND[0:3,...]\n",
    "    pct_clay = soil_data.PCT_CLAY[0:3,...]\n",
    "    pct_om = soil_data.ORGANIC[0:3,...]\n",
    "    # Depths of the top three soil layers (presumably m), used by the organic-soil terms below\n",
    "    z_vec = np.array([1,5,9])/100\n",
    "    \n",
    "### Calculate soil hydraulic data\n",
    "# Mineral-soil hydraulic parameters from sand/clay percentages, blended with organic-soil\n",
    "# values by the organic fraction om_frac (via avg_soil_hydraulic).\n",
    "\n",
    "### OM_MAX = 130 (om_frac saturates at 1 once ORGANIC >= 130)\n",
    "b_min = 2.91 + 0.159*pct_clay\n",
    "theta_sat_min = 0.498 - 0.00126*pct_sand\n",
    "psi_sat_min = -10*(10**(1.88-0.0131*pct_sand))\n",
    "\n",
    "theta_sat_om = np.maximum(0.93 - 0.1 * z_vec / 0.5, 0.83)\n",
    "b_om = np.minimum(2.7 +9.3 * z_vec / 0.5, 12.0)\n",
    "psi_sat_om = np.minimum(10.3 - 0.2 * z_vec / 0.5, 10.1) * -1\n",
    "\n",
    "om_frac = np.minimum(pct_om/130,1)\n",
    "\n",
    "b = avg_soil_hydraulic(b_min, b_om, om_frac)\n",
    "theta_sat = avg_soil_hydraulic(theta_sat_min, theta_sat_om, om_frac)\n",
    "# presumably mm of water -> MPa (1 mm H2O is about 9.8e-6 MPa)\n",
    "psi_sat = avg_soil_hydraulic(psi_sat_min, psi_sat_om, om_frac) * 9.8e-6\n",
    "# Relative saturation at matric potentials of -3 MPa (s_min) and -0.3 MPa (s_opt)\n",
    "s_min = (-3/psi_sat)**(-1/b)\n",
    "s_opt = (-0.3/psi_sat)**(-1/b)\n",
    "beta_2 = 0.4*(1-s_opt)/(s_opt-s_min)\n",
    "b_combined = 2+3/b+beta_2\n",
    "\n",
    "# Element-wise call of find_min_s. vectorize=True already wraps the function in\n",
    "# np.vectorize, so no extra np.vectorize layer is needed.\n",
    "s_opt_with_D = xr.apply_ufunc(\n",
    "    find_min_s,\n",
    "    s_min,             # Input 1\n",
    "    s_opt,             # Input 2 (used as upper bound)\n",
    "    b_combined,        # Input 3\n",
    "    input_core_dims=[[], [], []],  # Matches the 3 inputs (scalar core dims)\n",
    "    vectorize=True,\n",
    "    dask=\"parallelized\",\n",
    "    output_dtypes=[float]\n",
    ")\n",
    "\n",
    "# Convert relative saturation to volumetric water content and attach the lat/lon\n",
    "# coordinates of a model history file so these fields align with model output.\n",
    "with xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/ACCESS-ESM1-5_hot_new/lnd/hist/output_shifted.nc') as reind_xr:\n",
    "    ref_xr = reind_xr.F_MET[0,...]\n",
    "    theta_opt_with_D = (s_opt_with_D*theta_sat).rename({'lsmlat':'lat', 'lsmlon':'lon'})\n",
    "    theta_opt_with_D.coords['lat'] = ref_xr.lat\n",
    "    theta_opt_with_D.coords['lon'] = ref_xr.lon\n",
    "\n",
    "    theta_min = (s_min*theta_sat).rename({'lsmlat':'lat', 'lsmlon':'lon'})\n",
    "    theta_min.coords['lat'] = ref_xr.lat\n",
    "    theta_min.coords['lon'] = ref_xr.lon\n",
    "\n",
    "# NOTE(review): theta_min / theta_opt_with_D are not written to disk in this cell, yet the\n",
    "# next cell reads 'theta_min.nc' and 'theta_opt_with_D.nc' -- confirm those files match.\n",
    "\n",
    "def plot_f_theta_pair(curve_a, curve_b, label_a, label_b, out_png):\n",
    "    \"\"\"Plot two normalized soil-uptake curves with dashed theta_opt markers and save as PNG.\n",
    "\n",
    "    curve_a / curve_b: dicts returned by soil_moisture_curve; the keys 'theta', 'f_theta'\n",
    "        and 'theta_opt' are read. curve_a is drawn in blue, curve_b in red.\n",
    "    label_a / label_b: legend labels for the two curves.\n",
    "    out_png: output path passed to fig.savefig (300 dpi, PNG).\n",
    "    Returns (fig, ax).\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(7, 5))\n",
    "    ax.set_xlabel(r'$\\theta$ ($m^3$ $m^{-3}$)', fontsize = 16)\n",
    "    ax.set_ylabel(r'Normalized $G_{soil}$ (unitless)', fontsize = 16)\n",
    "    # Draw order matches the original figures: blue curve, its marker, zero line, red curve.\n",
    "    ax.plot(curve_a['theta'], curve_a['f_theta'], color='blue', label=label_a)\n",
    "    ax.axvline(curve_a['theta_opt'], ymin=0.05, ymax=0.95, color='blue', linestyle='--')\n",
    "    ax.axhline(0, color='black')\n",
    "    ax.plot(curve_b['theta'], curve_b['f_theta'], color='red', label=label_b)\n",
    "    ax.axvline(curve_b['theta_opt'], ymin=0.05, ymax=0.95, color='red', linestyle='--')\n",
    "    ax.set_xlim(0, 0.7)\n",
    "    ax.tick_params(axis='x', labelsize=12)\n",
    "    ax.tick_params(axis='y', labelsize=12)\n",
    "    ax.grid(True)\n",
    "    # Proxy handles: one black dashed entry stands for both coloured theta_opt markers.\n",
    "    handles = [\n",
    "        mlines.Line2D([], [], color='blue', label=label_a),\n",
    "        mlines.Line2D([], [], color='red', label=label_b),\n",
    "        mlines.Line2D([], [], color='black', linestyle='--', label=r\"$\\theta$ = $\\theta_{opt}$ \"),\n",
    "    ]\n",
    "    ax.legend(handles=handles, fontsize=14)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(out_png, dpi=300, format='png')\n",
    "    return fig, ax\n",
    "\n",
    "sm_curve_india = soil_moisture_curve(21,52,0.32) #India\n",
    "sm_curve_sandy = soil_moisture_curve(91,5,0.04)# Southern Africa\n",
    "sm_curve_india_drier = soil_moisture_curve(21,52,0.32, -100) #India, Psi_lb = -100 MPa\n",
    "\n",
    "# Clay-rich vs sandy soil\n",
    "fig, ax = plot_f_theta_pair(sm_curve_india, sm_curve_sandy, 'Clay-rich soil', 'Sandy soil',\n",
    "                            'f_theta_curve.png')\n",
    "# Same soil, lower-bound potential shifted from -3 MPa to -100 MPa\n",
    "fig, ax = plot_f_theta_pair(sm_curve_india, sm_curve_india_drier,\n",
    "                            r'$\\Psi_{lb}$ = -3 MPa', r'$\\Psi_{lb}$ = -100 MPa',\n",
    "                            'f_theta_curve_psi_lb_shift.png')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f556571d-ab8f-4714-af78-2a9a1f3df897",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Plot 2000 - 2009 averages\n",
    "# Figure: (a) map of time-mean v_d with soil-moisture contours (red: below theta_min,\n",
    "# orange: below theta_opt + 0.05), (b) mean in 5-degree latitude bands, (c) monthly cycle.\n",
    "\n",
    "with xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/om_0.03_soilc_1000_psiopt_0.3/lnd/hist/output_shifted.nc') as out_2d_soc:\n",
    "    vd_2d_trunc = out_2d_soc.DRYDEPV_H2.sel(lat = slice(-60,75))\n",
    "\n",
    "# NOTE(review): these two files are not written in the cells shown; presumably saved from\n",
    "# theta_min / theta_opt_with_D of the soil-hydraulics cell -- confirm before Restart & Run All.\n",
    "sm_min_xr = xr.open_dataarray('theta_min.nc').sel(lat = slice(-60,75))\n",
    "sm_opt_xr = xr.open_dataarray('theta_opt_with_D.nc').sel(lat = slice(-60,75))\n",
    "# NOTE(review): sm_xr is opened without a context manager and never closed.\n",
    "sm_xr = xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/om_0.03_soilc_1000_psiopt_0.3/lnd/hist/soil_water_shifted.nc').H2OSOI.sel(lat = slice(-60,75))\n",
    "\n",
    "# Weights: 2/4/6 for the top three soil levels, days-in-month for time. Only levsoi[0:3]\n",
    "# carry depth weights, so the mean presumably covers just those levels (inner join on levsoi).\n",
    "weight_depth = xr.DataArray(np.array([2,4,6]), coords=[sm_xr.levsoi[0:3]], dims = 'levsoi')\n",
    "weight_month = sm_xr.time.dt.days_in_month\n",
    "\n",
    "sm_weighted_xr = sm_xr.weighted(weight_depth*weight_month).mean(dim=['time','levsoi'])\n",
    "\n",
    "# 1/0 masks against the period-mean soil moisture; NaN soil moisture compares False, i.e. 0.\n",
    "# non_optimal_xr and high_uptake_xr are not used in this cell.\n",
    "dry_xr = xr.where(sm_weighted_xr < sm_min_xr, 1, 0)\n",
    "non_optimal_xr = xr.where(sm_weighted_xr < sm_opt_xr, 1, 0)\n",
    "threatened_xr = xr.where(sm_weighted_xr < (sm_opt_xr + 0.05), 1, 0)\n",
    "non_dry_xr = xr.where(sm_weighted_xr > sm_min_xr, 1, 0)\n",
    "high_uptake_xr = vd_2d_trunc * non_dry_xr * threatened_xr\n",
    "\n",
    "# Time mean of v_d weighted by days per month\n",
    "vd_2d = vd_2d_trunc.weighted(vd_2d_trunc.time.dt.days_in_month).mean(dim='time')\n",
    "\n",
    "fig = plt.figure(figsize=(10, 12))\n",
    "\n",
    "gs = fig.add_gridspec(3, 1, height_ratios=[1.5,1,1])\n",
    "\n",
    "# Panel a: GeoAxes in Miller projection; the data are lat/lon, hence transform=PlateCarree\n",
    "ax1 = fig.add_subplot(gs[0,:], projection=ccrs.Miller())\n",
    "im = vd_2d.plot(ax=ax1, vmax=0.1, transform=ccrs.PlateCarree(), add_colorbar=False)\n",
    "dry_xr.plot.contour(\n",
    "    ax=ax1, \n",
    "    transform=ccrs.PlateCarree(),\n",
    "    levels=[0.5], \n",
    "    colors='red', \n",
    "    linewidths=1.5\n",
    ")\n",
    "\n",
    "threatened_xr.plot.contour(\n",
    "    ax=ax1, \n",
    "    transform=ccrs.PlateCarree(),\n",
    "    levels=[0.5], \n",
    "    colors='orange', \n",
    "    linewidths=1.5\n",
    ")\n",
    "\n",
    "\n",
    "cbar = fig.colorbar(im, ax=ax1, shrink=0.9, pad=0.05,extend = 'max')\n",
    "cbar.set_label(r'$v_d$ (cm $s^{-1}$)', fontsize = 16)  # label fontsize\n",
    "cbar.ax.tick_params(labelsize=12)\n",
    "ax1.coastlines()\n",
    "ax1.gridlines()\n",
    "ax1.set_title(\"\")\n",
    "\n",
    "ax1.text(-0.09, -0.03, 'a)', transform=ax1.transAxes, \n",
    "        fontsize=22, fontweight='bold', va='bottom', ha='right')\n",
    "\n",
    "\n",
    "# Second subplot with regular axes (no projection)\n",
    "ax2 = fig.add_subplot(gs[1,:])\n",
    "# Panel b: plain (not area-weighted) mean over each slice [lat, lat+5]; the 'all' variant\n",
    "# counts NaN cells as 0 but is not plotted.\n",
    "lat_band = np.arange(-60, 75, 5)\n",
    "lat_banded_vd_land = []\n",
    "lat_banded_vd_all = []\n",
    "\n",
    "for i, lat_i in enumerate(lat_band):\n",
    "    lat_banded_vd_land.append(vd_2d.sel(lat=slice(int(lat_i), int(lat_i+5))).mean().values[()])\n",
    "    lat_banded_vd_all.append(vd_2d.fillna(0).sel(lat=slice(int(lat_i), int(lat_i+5))).mean().values[()])\n",
    "\n",
    "ax2.plot(lat_band, lat_banded_vd_land,color='black', label = 'Land only')\n",
    "#ax2.plot(lat_band, lat_banded_vd_all,color='blue', label = 'All area')\n",
    "ax2.set_xlabel('Latitude (degree north)',fontsize = 16)\n",
    "ax2.set_ylabel(r'$v_d$ (cm $s^{-1}$)', fontsize = 16)\n",
    "ax2.grid()\n",
    "ax2.tick_params(axis='both', labelsize=12)\n",
    "#ax2.legend(bbox_to_anchor=(1.05, 0.7), loc='upper left')\n",
    "ax2.set_ylim(0,0.06)\n",
    "\n",
    "ax2.text(-0.07, -0.03, 'b)', transform=ax2.transAxes, \n",
    "        fontsize=22, fontweight='bold', va='bottom', ha='right')\n",
    "\n",
    "\n",
    "# Panel c: monthly climatology (all years of each calendar month averaged)\n",
    "vd_monthly = vd_2d_trunc.groupby(vd_2d_trunc.time.dt.month).mean(dim='time')\n",
    "ax3 = fig.add_subplot(gs[2,:])\n",
    "\n",
    "# Get short month names: ['Jan', 'Feb', ..., 'Dec']\n",
    "month_names = [calendar.month_name[m][:3] for m in range(1, 13)]\n",
    "\n",
    "lat_band_ub = [75,30,-30]\n",
    "lat_band_lb = [30,-30,-60]\n",
    "\n",
    "labels_band = [r'$30^{\\circ}$N - $75^{\\circ}$N',\n",
    "              r'$30^{\\circ}S$ - $30^{\\circ}$N',\n",
    "              r'$60^{\\circ}$S - $30^{\\circ}S$']\n",
    "# NOTE(review): the global/regional means below are plain lat-lon averages (no cos(lat) weighting).\n",
    "ax3.plot(np.arange(1,13), vd_monthly.mean(dim=['lon','lat']).values[()],\n",
    "        label = 'Global', color = 'black')\n",
    "for i in range(len(lat_band_ub)):\n",
    "    ax3.plot(np.arange(1,13), \n",
    "             vd_monthly.sel(lat=slice(lat_band_lb[i],lat_band_ub[i])).mean(dim=['lon','lat']).values[()],\n",
    "            label = labels_band[i])\n",
    "    ax3.set_xlabel('Month',fontsize = 16)\n",
    "ax3.set_ylabel(r'$v_d$ (cm $s^{-1}$)', fontsize = 16)\n",
    "ax3.grid()\n",
    "ax3.tick_params(axis='both', labelsize=12)\n",
    "ax3.set_ylim(0,0.06)\n",
    "ax3.legend(bbox_to_anchor=(0.5, 0.5), loc='upper left',fontsize = 14)\n",
    "ax3.xaxis.set_ticks(np.arange(1,13))\n",
    "ax3.xaxis.set_ticklabels(month_names, fontsize=14)\n",
    "\n",
    "ax3.text(-0.07, -0.03, 'c)', transform=ax3.transAxes, \n",
    "             fontsize=22, fontweight='bold', va='bottom', ha='right')\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "fig.savefig('global_2000_2010.png',dpi =300, format = 'png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03443c39-c236-440a-9cfc-0522cccdfac3",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Model observation comparison\n",
    "# Compare modelled monthly-mean v_d at the grid cell nearest each site with observed v_d.\n",
    "\n",
    "harvard_data = pd.read_csv('harvard_vd.csv')\n",
    "\n",
    "# Per site: lat/lon for the nearest-grid-cell lookup (lon in 0-360), 12 monthly obs means\n",
    "# (NaN = no data), optional std devs, and optionally a second obs series ('secondary_*')\n",
    "# with legend labels for both series.\n",
    "sites_config = {\n",
    "    'Helsinki': {\n",
    "        'lat': 60, 'lon': 25,\n",
    "        'obs_mean': [0.0084, np.nan, 0.0012, 0.023, 0.041, 0.059, np.nan, 0.049, np.nan, 0.026, 0.028, 0.029],\n",
    "        'obs_std': None\n",
    "    },\n",
    "    'Mace Head': {\n",
    "        'lat': 53, 'lon': 350,\n",
    "        'obs_mean': [np.nan, np.nan, np.nan, 0.0578, 0.0575, 0.0516, 0.0469, 0.0554, 0.0541, np.nan, np.nan, np.nan],\n",
    "        'obs_std': [np.nan, np.nan, np.nan, 0.0252, 0.0189, 0.0187, 0.0094, 0.0125, 0.0150, np.nan, np.nan, np.nan]\n",
    "    },\n",
    "    'Heidelberg': {\n",
    "        'lat': 49.4, 'lon': 8.7,\n",
    "        'obs_mean': [0.024, 0.018, 0.031, 0.026, 0.024, 0.033, 0.028, 0.037, 0.036, 0.038, np.nan, 0.023],\n",
    "        'obs_std': None\n",
    "    },\n",
    "    'Harvard Forest': {\n",
    "        'lat': 42.5, 'lon': 287.5,\n",
    "        'obs_mean': harvard_data['mean_vd'].values, # from harvard_vd.csv above; assumes 12 rows ordered Jan-Dec -- TODO confirm\n",
    "        'obs_std': harvard_data['sd_vd'].values\n",
    "    },\n",
    "    # NOTE(review): 'Tsubaka' is presumably Tsukuba, Japan; kept as-is because it is also the Site value in central_df.\n",
    "    'Tsubaka': {\n",
    "        'lat': 36, 'lon': 140,\n",
    "        'obs_mean': [0.053, 0.063, 0.060, np.nan, 0.022, np.nan, 0.05, np.nan, 0.046, 0.045, 0.027, 0.047],\n",
    "        'obs_std': [0.009, 0.011, 0.028, np.nan, 0.010, np.nan, 0.014, np.nan, 0.028, 0.016, 0.013, 0.01],\n",
    "        'secondary_obs_mean': [0.063, 0.058, 0.065, np.nan, 0.066, 0.049, 0.069, 0.067, 0.047, 0.062, 0.05, 0.058],\n",
    "        'secondary_obs_std': [0.011, 0.011, 0.01, np.nan, 0.010, 0.012, 0.008, 0.013, 0.01, 0.009, 0.008, 0.009],\n",
    "        'labels': ['Cropland', 'Forest']\n",
    "    },\n",
    "    'San Jacinto (Forest)': {\n",
    "        'lat': 33.81, 'lon': 360 - 116.79,\n",
    "        'obs_mean': [0.046, 0.055, 0.051, 0.073, np.nan, 0.056, 0.057, np.nan , 0.062, np.nan, 0.088, 0.069],\n",
    "        'obs_std': [0.026, 0.022, np.nan, 0.018, np.nan, 0.005, 0.028, np.nan, 0.014, np.nan, 0.009, 0.015]\n",
    "    },\n",
    "    'San Jacinto (Desert)': {\n",
    "        'lat': 34.15, 'lon': 360 - 116.45,\n",
    "        'obs_mean': [0.068, np.nan, 0.084, 0.034, np.nan, 0.026, np.nan, 0.029, np.nan, np.nan, np.nan, 0.046],\n",
    "        'obs_std': [0.011, np.nan, 0.005, 0.008, np.nan, 0.009, np.nan, 0.008, np.nan, np.nan, np.nan, 0.013]\n",
    "    },\n",
    "    'Central Scotland': {\n",
    "        'lat': 55.87, 'lon': 360 - 3.21,\n",
    "        'obs_mean': [np.nan, 0.071, np.nan, 0.075, 0.109, 0.086, 0.079, 0.126 , 0.088, np.nan, 0.070, np.nan],\n",
    "        'obs_std': [np.nan, 0.035, np.nan, 0.030, 0.038, 0.038, 0.03, 0.022, 0.034, np.nan, 0.030, np.nan],\n",
    "        'secondary_obs_mean': [0.001, 0.004, 0.003, 0.0045, 0.024, 0.028, 0.030, 0.002, 0.016, 0.009, 0.002, np.nan],\n",
    "        'secondary_obs_std': [0.003, 0.006, 0.001, 0.005, 0.013, 0.014, 0.012, 0.007, 0.008, 0.004, 0.002, np.nan],\n",
    "        'labels': ['Forest', 'Grass']\n",
    "    }\n",
    "}\n",
    "\n",
    "### Plotting\n",
    "# Get short month names: ['Jan', 'Feb', ..., 'Dec']\n",
    "month_names = [calendar.month_name[m][:3] for m in range(1, 13)]\n",
    "\n",
    "all_data_list = []\n",
    "# One panel per site on a 4x2 grid (sites_config above has 8 sites).\n",
    "fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(14, 22))\n",
    "axes_flat = axes.flatten()\n",
    "\n",
    "months = np.arange(12) + 1\n",
    "\n",
    "for i, (site_name, config) in enumerate(sites_config.items()):\n",
    "    ax = axes_flat[i]\n",
    "    \n",
    "    # 1. Extract Modelled Data\n",
    "    # NOTE(review): out_2d_soc is the dataset from the `with` block of the 2000-2009 plotting\n",
    "    # cell and is already closed here; presumably xarray reopens it lazily -- run that cell first.\n",
    "    # groupby('time.month').mean() averages each calendar month over all years; all 12 months\n",
    "    # are assumed present so the result lines up with `months`.\n",
    "    mod_vd = out_2d_soc['DRYDEPV_H2'].sel(\n",
    "        lat=config['lat'], lon=config['lon'], method='nearest'\n",
    "    ).groupby('time.month').mean().values\n",
    "    \n",
    "    # Get labels if they exist, otherwise default to 'Observed'\n",
    "    labels = config.get('labels', ['Observed', None])\n",
    "    \n",
    "    # 2. Add Primary Data to list\n",
    "    all_data_list.append(pd.DataFrame({\n",
    "        'Site': site_name,\n",
    "        'Type': labels[0],\n",
    "        'Month': months,\n",
    "        'Modelled': mod_vd,\n",
    "        'Obs_Mean': config['obs_mean'],\n",
    "        'Obs_Std': config['obs_std'] if config['obs_std'] is not None else np.nan\n",
    "    }))\n",
    "    \n",
    "    # 3. Add Secondary Data to list (for Tsubaka/Scotland)\n",
    "    if 'secondary_obs_mean' in config:\n",
    "        all_data_list.append(pd.DataFrame({\n",
    "            'Site': site_name,\n",
    "            'Type': labels[1],\n",
    "            'Month': months,\n",
    "            'Modelled': mod_vd, # Modelled value is the same for the grid cell\n",
    "            'Obs_Mean': config['secondary_obs_mean'],\n",
    "            'Obs_Std': config['secondary_obs_std'] if config['secondary_obs_std'] is not None else np.nan\n",
    "        }))\n",
    "\n",
    "    # --- Plotting Logic ---\n",
    "    ax.plot(months, mod_vd, label='Modeled', color='tab:blue', lw=3, zorder=1)\n",
    "    \n",
    "    # Plot Primary Obs\n",
    "    ax.scatter(months, config['obs_mean'], color='black', label=labels[0], s=60, zorder=2)\n",
    "    if config['obs_std'] is not None:\n",
    "        ax.errorbar(months, config['obs_mean'], yerr=config['obs_std'], fmt='none', ecolor='black', alpha=0.5)\n",
    "    \n",
    "    # Plot Secondary Obs\n",
    "    if 'secondary_obs_mean' in config:\n",
    "        ax.scatter(months, config['secondary_obs_mean'], color='red', label=labels[1], s=60, zorder=2)\n",
    "        if config['secondary_obs_std'] is not None:\n",
    "            ax.errorbar(months, config['secondary_obs_mean'], yerr=config['secondary_obs_std'], fmt='none', ecolor='red', alpha=0.5)\n",
    "    \n",
    "    # Formatting\n",
    "    ax.set_title(site_name, fontsize=24, pad=10)\n",
    "    ax.set_ylim(0, 0.14)\n",
    "    ax.set_xticks(months)\n",
    "    ax.set_xticklabels(month_names, fontsize=14)\n",
    "    ax.tick_params(axis='y', labelsize=14)\n",
    "    ax.legend(fontsize=12, loc='upper right')\n",
    "\n",
    "# Global labels\n",
    "fig.supylabel(r'$v_{d}$ (cm $s^{-1}$)', fontsize=28)\n",
    "fig.supxlabel('Month', fontsize=28)\n",
    "plt.tight_layout(rect=[0.03, 0.03, 1, 0.98])\n",
    "\n",
    "# Final Centralized DataFrame\n",
    "central_df = pd.concat(all_data_list, ignore_index=True)\n",
    "plt.savefig('model_evaluation_om_0.03_soilc_1000.png',dpi = 300, format = 'png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7c9d26c-847d-41f1-9629-1ce046678c5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Plot San Jacinto site\n",
    "# Monthly climatology of soil moisture (top three levels, depth-weighted) and rainfall at the\n",
    "# grid cells nearest the San Jacinto forest and desert sites.\n",
    "with xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/om_0.03_soilc_1000_psiopt_0.3/lnd/hist/soil_water_shifted.nc') as sm_xr:\n",
    "    h2osoi = sm_xr.H2OSOI\n",
    "    rain = sm_xr.RAIN_FROM_ATM\n",
    "    lon_vec = sm_xr.lon.values\n",
    "    lat_vec = sm_xr.lat.values\n",
    "\n",
    "# get_ij (analysis_util) presumably returns nearest-cell indices as {'i': lon index, 'j': lat index}.\n",
    "ij_forest = get_ij(360-116.79,33.81,lon_vec,lat_vec)\n",
    "ij_desert = get_ij(360-116.45,34.15,lon_vec,lat_vec)\n",
    "\n",
    "# Depth weights 2/4/6 for the top three soil levels (levsoi[0:3]).\n",
    "weights = xr.DataArray(np.array([2,4,6]), coords=[h2osoi.levsoi[0:3]], dims = 'levsoi')\n",
    "\n",
    "h2osoi_forest = h2osoi[:,0:3,ij_forest['j'],ij_forest['i']].weighted(weights).mean(dim = 'levsoi').groupby(h2osoi.time.dt.month).mean(dim='time')\n",
    "h2osoi_desert = h2osoi[:,0:3,ij_desert['j'],ij_desert['i']].weighted(weights).mean(dim = 'levsoi').groupby(h2osoi.time.dt.month).mean(dim='time')\n",
    "\n",
    "rain_forest = rain[:,ij_forest['j'],ij_forest['i']].groupby(h2osoi.time.dt.month).mean(dim='time')\n",
    "rain_desert = rain[:,ij_desert['j'],ij_desert['i']].groupby(h2osoi.time.dt.month).mean(dim='time')\n",
    "\n",
    "# Days per month from a non-leap year (2001), so February always counts 28 days.\n",
    "days_in_month = np.array([calendar.monthrange(2001, m)[1] for m in range(1, 13)])\n",
    "seconds_per_month = days_in_month * 86400\n",
    "\n",
    "# 2. Convert rainfall from mm/s to total mm per month\n",
    "# This performs element-wise multiplication (Jan * Jan_seconds, etc.)\n",
    "# (presumably relies on the groupby 'month' dimension being ordered 1..12)\n",
    "rain_forest_mm = rain_forest * seconds_per_month\n",
    "rain_desert_mm = rain_desert * seconds_per_month\n",
    "\n",
    "# Setup month labels\n",
    "months = np.arange(12) + 1\n",
    "month_names = [calendar.month_name[m][:3] for m in range(1, 13)]\n",
    "\n",
    "fig, ax1 = plt.subplots(figsize=(10, 6))\n",
    "\n",
    "# --- Plot Soil Moisture (Primary Y-Axis) ---\n",
    "color_sm = 'tab:brown'\n",
    "ax1.set_xlabel('Month', fontsize=18)\n",
    "ax1.set_ylabel('Soil Moisture ($m^3 m^{-3}$)', color=color_sm, fontsize=18)\n",
    "\n",
    "ax1.plot(months, h2osoi_forest, color=color_sm, label='SM: Forest', lw=3, marker='o', markersize=8)\n",
    "ax1.plot(months, h2osoi_desert, color=color_sm, label='SM: Desert', lw=3, marker='s', markersize=8, ls='--')\n",
    "# NOTE(review): 0.103 m3/m3 is a hard-coded threshold whose origin is not shown here -- document it.\n",
    "ax1.axhline(y = 0.103, color=color_sm, ls=':', label='SM Threshold')\n",
    "\n",
    "\n",
    "ax1.tick_params(axis='y', labelcolor=color_sm, labelsize=16)\n",
    "ax1.tick_params(axis='x', labelsize=16)\n",
    "ax1.set_ylim(0.05, 0.2)\n",
    "\n",
    "# --- Plot Rainfall (Secondary Y-Axis) ---\n",
    "ax2 = ax1.twinx() \n",
    "color_rain = 'tab:blue'\n",
    "ax2.set_ylabel('Monthly Rainfall (mm)', color=color_rain, fontsize=18)\n",
    "\n",
    "ax2.plot(months, rain_forest_mm, color=color_rain, label='Rain: Forest', lw=3, marker='o', markersize=8)\n",
    "ax2.plot(months, rain_desert_mm, color=color_rain, label='Rain: Desert', lw=3, marker='s', markersize=8, ls='--')\n",
    "\n",
    "ax2.tick_params(axis='y', labelcolor=color_rain, labelsize=16)\n",
    "ax2.set_ylim(0, max(rain_forest_mm.max(), rain_desert_mm.max()) * 1.2)\n",
    "\n",
    "# --- Formatting ---\n",
    "ax1.set_xticks(months)\n",
    "ax1.set_xticklabels(month_names)\n",
    "\n",
    "# Merge legends and remove frame\n",
    "lines1, labels1 = ax1.get_legend_handles_labels()\n",
    "lines2, labels2 = ax2.get_legend_handles_labels()\n",
    "ax1.legend(lines1 + lines2, labels1 + labels2, loc='upper right', fontsize=16, frameon=False)\n",
    "\n",
    "plt.tight_layout()\n",
    "#plt.show()\n",
    "\n",
    "plt.savefig('theta_calif.png', dpi = 300, format = 'png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75b553b6-bdbe-43d7-9fee-7ce8c1ec5087",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Open the scenario runs: one history file per climate pattern x climate sensitivity\n",
    "# Keys are '<pattern>_<sensitivity>' (e.g. 'access_hot'); insertion order (pattern-major,\n",
    "# hot before median) sets the row/legend order of the cells below. The datasets are left\n",
    "# open because later cells read from them.\n",
    "ARCHIVE_DIR = '/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive'\n",
    "CLIMATE_PATTERNS = {\n",
    "    'access': 'ACCESS-ESM1-5',\n",
    "    'giss': 'GISS-E2-1-G',\n",
    "    'inm': 'INM-CM4-8',\n",
    "    'sam': 'SAM0-UNICON',\n",
    "}\n",
    "CLIMATE_SENSITIVITIES = ('hot', 'median')\n",
    "\n",
    "output_dict = {\n",
    "    f'{pattern}_{sens}': xr.open_dataset(f'{ARCHIVE_DIR}/{model}_{sens}_new/lnd/hist/output_shifted.nc')\n",
    "    for pattern, model in CLIMATE_PATTERNS.items()\n",
    "    for sens in CLIMATE_SENSITIVITIES\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1eb14c9-fc74-47d1-9b53-db9e5fb5ae49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Yearly DRYDEPV_H2 series for every scenario run, keyed exactly like output_dict\n",
    "vd_dict = {\n",
    "    run_key: cal_yearly_var(run_ds, 'DRYDEPV_H2')\n",
    "    for run_key, run_ds in output_dict.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea4caccb-9574-4800-b84b-316dd8f4498e",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Linear trend and temperature sensitivity of each run's yearly v_d series\n",
    "window_size = 10\n",
    "\n",
    "window = np.ones(window_size) / window_size  # 10-year running-mean kernel\n",
    "\n",
    "# Years between the first and last running-mean values (2025 -> 2095); multiplies the\n",
    "# per-year slope into the total change used for the per-kelvin sensitivities.\n",
    "N_YEARS_SPAN = 70\n",
    "\n",
    "# Column names are kept verbatim (including the 'sensivitity' spelling) because later\n",
    "# cells may index df_trend by them.\n",
    "TREND_COLUMNS = [\n",
    "    'Pattern',\n",
    "    'Climate_sensitivity',\n",
    "    'Trend (%/yr)',\n",
    "    'Trend_std',\n",
    "    'Temperature sensivitity (%/K)',\n",
    "    'Temperature sensivitity (cm/s/K)',\n",
    "    'Start_average',\n",
    "    'End_average',\n",
    "]\n",
    "\n",
    "trend_records = []\n",
    "for key, ts in vd_dict.items():\n",
    "    ts_smooth = np.convolve(ts,window,mode='valid')\n",
    "    # Least-squares fit over the raw yearly series; only the slope and its stderr are used.\n",
    "    fit = scipy.stats.linregress(ts.year.values, ts.values)\n",
    "    # NOTE(review): dt_hot / dt_median are not defined in these cells; presumably the warming (K)\n",
    "    # of each sensitivity class, supplied by `from analysis_util import *` -- confirm.\n",
    "    delta_t = (dt_hot if 'hot' in key else dt_median).values[()]\n",
    "    pattern, sensitivity = key.split('_')[:2]\n",
    "    baseline = ts_smooth[0]  # first running-mean value: reference for the % trends\n",
    "\n",
    "    trend_records.append({\n",
    "        'Pattern': pattern,\n",
    "        'Climate_sensitivity': sensitivity,\n",
    "        'Trend (%/yr)': fit.slope*100/baseline,\n",
    "        'Trend_std': fit.stderr*100/baseline,\n",
    "        'Temperature sensivitity (%/K)': (fit.slope*N_YEARS_SPAN)*100/baseline/delta_t,\n",
    "        'Temperature sensivitity (cm/s/K)': (fit.slope*N_YEARS_SPAN)/delta_t,\n",
    "        'Start_average': ts_smooth[0],\n",
    "        'End_average': ts_smooth[-1],\n",
    "    })\n",
    "\n",
    "# Build the table once rather than enlarging it row by row with .loc; the explicit\n",
    "# column list keeps the same schema even when vd_dict is empty.\n",
    "df_trend = pd.DataFrame(trend_records, columns=TREND_COLUMNS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbc7aa2a-ec20-4be1-99fe-7054f103294b",
   "metadata": {},
   "outputs": [],
   "source": [
    "### 10-year running mean of each run's yearly v_d series\n",
    "window_size = 10\n",
    "\n",
    "window = np.ones(window_size) / window_size\n",
    "\n",
    "fig,ax = plt.subplots(figsize = (9,6))\n",
    "\n",
    "for key, ts in vd_dict.items():\n",
    "    ts_smooth = np.convolve(ts,window,mode='valid')\n",
    "    # x values come from the series' own year coordinate, offset by half a window, rather than a\n",
    "    # hard-coded np.arange(2025, 2096): that fixed 71-point axis only fits an 80-year series.\n",
    "    half_window = window_size // 2\n",
    "    years = ts.year.values[half_window:half_window + len(ts_smooth)]\n",
    "    ax.plot(years, ts_smooth, label = key.upper())\n",
    "\n",
    "ax.set_xlabel('Year', fontsize = 20)\n",
    "ax.set_ylabel(r'$v_{d}$ (cm $s^{-1}$)', fontsize = 20)\n",
    "ax.tick_params(axis = 'both', labelsize = 16)\n",
    "\n",
    "ax.legend(title = 'Climate')\n",
    "ax.grid(True)\n",
    "fig.tight_layout()\n",
    "\n",
    "fig.savefig('global_2025_2095_running_mean.png',dpi =300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9943bdf6-acab-418b-8a02-ed5eb0a9224f",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Conduct Shapley Analysis\n",
    "# Per run: Shapley decomposition (shapley_analysis_resistance) of the change in v_d into\n",
    "# canopy (dvd_ra), soil (dvd_rb) and snow (dvd_rs) terms, plus the total change\n",
    "# dvd = mean of the 2090-2100 window minus mean of the 2020-2030 window,\n",
    "# each averaged within 10-degree latitude bands.\n",
    "\n",
    "LAT_BAND_EDGES = np.arange(-60, 70, 10)    # southern edge of each 10-degree band\n",
    "LAT_BAND_CENTERS = LAT_BAND_EDGES + 5      # -55 ... 65\n",
    "SHAPLEY_TERMS = ['dvd_ra', 'dvd_rb', 'dvd_rs']\n",
    "\n",
    "def _band_means(field, edges=LAT_BAND_EDGES, width=10):\n",
    "    \"\"\"Unweighted mean of `field` over each latitude slice [edge, edge + width], as an array.\"\"\"\n",
    "    return np.asarray([field.sel(lat=slice(lo, lo + width)).mean().values[()] for lo in edges])\n",
    "\n",
    "shapley_results = {}  # raw shapley_analysis_resistance output per run, reused for the global means\n",
    "shapley_dict = {}\n",
    "for key, value in output_dict.items():\n",
    "    shapley = shapley_analysis_resistance(value)\n",
    "    shapley_results[key] = shapley\n",
    "    dvd = (get_monthly_mean_year_range(value,'DRYDEPV_H2','2090-01-01','2100-01-01').mean(\"month\")\n",
    "           - get_monthly_mean_year_range(value,'DRYDEPV_H2','2020-01-01','2030-01-01').mean(\"month\"))\n",
    "\n",
    "    # Filter out extreme values once per term (the filter is applied to the whole field and\n",
    "    # does not depend on the band), then average within each band.\n",
    "    band_table = {'lat_band': LAT_BAND_CENTERS}\n",
    "    for term in SHAPLEY_TERMS:\n",
    "        band_table[term] = _band_means(filter_threshold(shapley[term]))\n",
    "    band_table['dvd_annual'] = _band_means(dvd)\n",
    "    shapley_dict[key] = pd.DataFrame(band_table)\n",
    "\n",
    "# Pivot and group by climate sensitivity categories: min / max / mean across the runs of each class\n",
    "plot_df_dict = {}\n",
    "for clim_sens in ['median', 'hot']:\n",
    "    members = [df for run_key, df in shapley_dict.items() if clim_sens in run_key]\n",
    "    summary = {'lat_band': LAT_BAND_CENTERS}\n",
    "    for term in SHAPLEY_TERMS + ['dvd_annual']:\n",
    "        stacked = np.asarray([df[term].to_numpy() for df in members])  # (n_runs, n_bands)\n",
    "        summary[f'{term}_min'] = stacked.min(axis=0)\n",
    "        summary[f'{term}_max'] = stacked.max(axis=0)\n",
    "        summary[f'{term}_mean'] = stacked.mean(axis=0)\n",
    "    plot_df_dict[clim_sens] = pd.DataFrame(summary)\n",
    "\n",
    "# Calculate global area-weighted (cos(lat)) means for every run in output_dict, reusing the\n",
    "# decompositions computed above instead of recomputing them.\n",
    "dvd_ra_global = {}\n",
    "dvd_rb_global = {}\n",
    "dvd_rs_global = {}\n",
    "\n",
    "for key, shapley in shapley_results.items():\n",
    "    weights = np.cos(np.deg2rad(shapley['dvd_ra'].lat))\n",
    "    weights = weights / weights.sum()\n",
    "    weights_broadcast = weights.broadcast_like(shapley['dvd_ra'])\n",
    "\n",
    "    dvd_ra_global[key] = filter_threshold(shapley['dvd_ra']).weighted(weights_broadcast).mean().values[()]\n",
    "    dvd_rb_global[key] = filter_threshold(shapley['dvd_rb']).weighted(weights_broadcast).mean().values[()]\n",
    "    dvd_rs_global[key] = filter_threshold(shapley['dvd_rs']).weighted(weights_broadcast).mean().values[()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c387110-ca7d-4a39-a38b-741bc3e9dbee",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Visualize Shapley analysis result\n",
    "\n",
    "# =====================================================================\n",
    "# 2. MASTER PLOT GENERATION (2 ROWS)\n",
    "#    Row 1 (top):    global area-weighted mean contributions (bar chart)\n",
    "#    Row 2 (bottom): latitudinal-band contributions, one panel per climate sensitivity\n",
    "# =====================================================================\n",
    "\n",
    "fig = plt.figure(figsize=(15, 12))\n",
    "gs = fig.add_gridspec(nrows=2, ncols=2, height_ratios=[1, 1])\n",
    "\n",
    "ax_bar = fig.add_subplot(gs[0, 0:2])  # Top row, spans both columns\n",
    "ax_lat0 = fig.add_subplot(gs[1, 0])\n",
    "ax_lat1 = fig.add_subplot(gs[1, 1])\n",
    "\n",
    "component_colors = {'soil': 'red', 'snow': 'blue', 'canopy': 'green', 'total': 'black'}\n",
    "\n",
    "# (legend label, color key, column prefix in the plot_df_dict frames), drawn in this order\n",
    "lat_series = [\n",
    "    ('Soil', 'soil', 'dvd_rb'),\n",
    "    ('Snow', 'snow', 'dvd_rs'),\n",
    "    ('Canopy', 'canopy', 'dvd_ra'),\n",
    "    (r'Total $\\Delta$$v_{d}$', 'total', 'dvd_annual'),\n",
    "]\n",
    "\n",
    "# Pair each panel with its climate-sensitivity key explicitly, so the panel titles\n",
    "# do not depend on the insertion order of plot_df_dict.\n",
    "lat_panels = [\n",
    "    (ax_lat0, 'median', 'Median climate sensitivity'),\n",
    "    (ax_lat1, 'hot', 'High climate sensitivity'),\n",
    "]\n",
    "\n",
    "# --- Row 2: Latitudinal Band Plots (mean line, min-max range across runs shaded) ---\n",
    "for ax, clim_sens, title in lat_panels:\n",
    "    plot_df = plot_df_dict[clim_sens]\n",
    "    for label, color_key, prefix in lat_series:\n",
    "        color = component_colors[color_key]\n",
    "        ax.plot(plot_df['lat_band'], plot_df[f'{prefix}_mean'], label=label, color=color)\n",
    "        ax.fill_between(plot_df['lat_band'], plot_df[f'{prefix}_min'], plot_df[f'{prefix}_max'], alpha=0.3, color=color)\n",
    "\n",
    "    ax.axhline(y=0, color='black')\n",
    "    ax.grid(True)\n",
    "    ax.set_ylim(-0.002, 0.006)\n",
    "    ax.tick_params(axis='both', labelsize=12)\n",
    "    ax.set_xlabel('Latitude (degree north)', fontsize=14)\n",
    "    ax.set_title(title, fontsize=20)\n",
    "\n",
    "ax_lat0.set_ylabel(r'Contribution to $\\Delta$$v_d$ (cm $s^{-1}$)', fontsize=14)\n",
    "ax_lat1.legend(fontsize=12)\n",
    "\n",
    "\n",
    "# --- Row 1: Global Area-Weighted Mean Bar Chart ---\n",
    "# Waterfall layout: for each climate the three component bars sit side by side and\n",
    "# each one starts where the previous one ended, so the top of the Canopy bar marks\n",
    "# the summed Soil + Snow + Canopy contribution.\n",
    "# Values are looked up by key so all three series stay aligned with the tick labels.\n",
    "climate_keys = list(dvd_rb_global)\n",
    "soil_vals = np.array([dvd_rb_global[k] for k in climate_keys])\n",
    "snow_vals = np.array([dvd_rs_global[k] for k in climate_keys])\n",
    "canopy_vals = np.array([dvd_ra_global[k] for k in climate_keys])\n",
    "x_indices = np.arange(len(climate_keys))\n",
    "bar_width = 0.25\n",
    "\n",
    "ax_bar.bar(x_indices - bar_width, soil_vals, width=bar_width, label='Soil', color=component_colors['soil'])\n",
    "ax_bar.bar(x_indices, snow_vals, bottom=soil_vals, width=bar_width, label='Snow', color=component_colors['snow'])\n",
    "ax_bar.bar(x_indices + bar_width, canopy_vals, bottom=soil_vals + snow_vals, width=bar_width, label='Canopy', color=component_colors['canopy'])\n",
    "\n",
    "# Styling Row 1\n",
    "ax_bar.grid(True)\n",
    "ax_bar.set_xticks(x_indices)\n",
    "ax_bar.set_xticklabels(climate_keys)\n",
    "ax_bar.tick_params(axis='x', labelsize=12, labelrotation=45, length=10)\n",
    "ax_bar.tick_params(axis='y', labelsize=12)\n",
    "ax_bar.set_xlabel('Climate', fontsize=16)\n",
    "ax_bar.set_ylabel(r'Contribution to $\\Delta$$v_d$ (cm $s^{-1}$)', fontsize=16)\n",
    "ax_bar.set_title('Global area-weighted mean', fontsize=20)\n",
    "ax_bar.legend(bbox_to_anchor=(1.01, 0.65), loc='upper left', fontsize=14)\n",
    "\n",
    "# Panel labels: a) on the bar chart (row 1), b) on the leftmost latitudinal panel (row 2)\n",
    "ax_bar.text(-0.05, -0.15, 'a)', transform=ax_bar.transAxes, \n",
    "             fontsize=22, fontweight='bold', va='bottom', ha='right')\n",
    "ax_lat0.text(-0.12, -0.15, 'b)', transform=ax_lat0.transAxes, \n",
    "             fontsize=22, fontweight='bold', va='bottom', ha='right')\n",
    "\n",
    "# Save Master Figure\n",
    "plt.tight_layout()\n",
    "plt.savefig('combined_sh_plots.png', dpi=300, format='png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "311d37aa-4a20-4486-8aa2-cf512306aeba",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Look at CLM land surface variables\n",
    "### Load data\n",
    "\n",
    "clm_data_dict = {\n",
    "    'ACCESS_hot': xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/ACCESS-ESM1-5_hot_new/lnd/hist/output_clm_shifted.nc'),\n",
    "    'GISS_hot': xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/GISS-E2-1-G_hot_new/lnd/hist/output_clm_shifted.nc'),\n",
    "    'INM_hot': xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/INM-CM4-8_hot_new/lnd/hist/output_clm_shifted.nc'),\n",
    "    'SAM0_hot': xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/SAM0-UNICON_hot_new/lnd/hist/output_clm_shifted.nc'),\n",
    "    'ACCESS_median': xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/ACCESS-ESM1-5_median_new/lnd/hist/output_clm_shifted.nc'),\n",
    "    'GISS_median': xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/GISS-E2-1-G_median_new/lnd/hist/output_clm_shifted.nc'),    \n",
    "    'INM_median': xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/INM-CM4-8_median_new/lnd/hist/output_clm_shifted.nc'),    \n",
    "    'SAM0_median': xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/SAM0-UNICON_median_new/lnd/hist/output_clm_shifted.nc')}\n",
    "\n",
    "clm_df_dict = {}\n",
    "lat_band = np.arange(-55, 75, 10)\n",
    "\n",
    "# Periods differenced for every variable: (start, end) passed to the seasonal-mean helpers\n",
    "FUTURE_PERIOD = ('2091-01-01', '2100-01-01')\n",
    "REFERENCE_PERIOD = ('2021-01-01', '2030-01-01')\n",
    "\n",
    "# (output name, CLM history variable, use the soil-surface seasonal-mean helper)\n",
    "clm_delta_vars = [\n",
    "    ('dlai', 'ELAI', False),\n",
    "    ('dsnow', 'SNOWDP', False),\n",
    "    ('dh2osoi', 'H2OSOI', True),\n",
    "    ('dtsoi', 'TSOI', True),\n",
    "    ('dom_act', 'SOM_ACT_C_vr', True),\n",
    "]\n",
    "\n",
    "\n",
    "def clm_period_delta(clm_xr, var, surface=False):\n",
    "    \"\"\"Season-averaged change of `var` (FUTURE_PERIOD minus REFERENCE_PERIOD) in one CLM run.\n",
    "\n",
    "    surface=True selects get_seasonal_mean_year_range_soil_surface (used for the soil\n",
    "    variables); otherwise get_seasonal_mean_year_range is used.\n",
    "    \"\"\"\n",
    "    seasonal_mean = get_seasonal_mean_year_range_soil_surface if surface else get_seasonal_mean_year_range\n",
    "    future = seasonal_mean(clm_xr, var, *FUTURE_PERIOD).mean(dim='season')\n",
    "    reference = seasonal_mean(clm_xr, var, *REFERENCE_PERIOD).mean(dim='season')\n",
    "    return future - reference\n",
    "\n",
    "\n",
    "for clim_sens in ['median', 'hot']:\n",
    "    # For each output variable: one list of per-band means per matching run\n",
    "    band_means = {name: [] for name, _, _ in clm_delta_vars}\n",
    "\n",
    "    for key, clm_xr in clm_data_dict.items():\n",
    "        # Keys encode the sensitivity group as a substring (e.g. 'ACCESS_hot')\n",
    "        if clim_sens not in key:\n",
    "            continue\n",
    "\n",
    "        deltas = {name: clm_period_delta(clm_xr, var, surface) for name, var, surface in clm_delta_vars}\n",
    "        # Threshold filtering is applied to snow depth only\n",
    "        deltas['dsnow'] = filter_threshold(deltas['dsnow'], 5)\n",
    "\n",
    "        for name, delta in deltas.items():\n",
    "            # .values[()] extracts the raw scalar from each 10-degree band mean\n",
    "            values = [delta.sel(lat=slice(x - 5, x + 5)).mean().values[()] for x in lat_band]\n",
    "            if name == 'dom_act':\n",
    "                # Divided by 1000 to match the kgC m^-3 axis label (presumably from gC m^-3)\n",
    "                values = [v / 1000.0 for v in values]\n",
    "            band_means[name].append(values)\n",
    "\n",
    "    # Guard: if no clm_data_dict key contains clim_sens, the lists stay empty and numpy's\n",
    "    # min/max over axis 0 would fail on them.\n",
    "    if not band_means['dlai']:\n",
    "        print(f\"Warning: No keys matched sensitivity criteria: '{clim_sens}'\")\n",
    "        continue\n",
    "\n",
    "    frame = {'lat_band': lat_band}\n",
    "    for name, runs in band_means.items():\n",
    "        runs_arr = np.asarray(runs)\n",
    "        frame[f'{name}_min'] = runs_arr.min(axis=0)\n",
    "        frame[f'{name}_max'] = runs_arr.max(axis=0)\n",
    "        frame[f'{name}_mean'] = runs_arr.mean(axis=0)\n",
    "    clm_df_dict[clim_sens] = pd.DataFrame(frame)\n",
    "\n",
    "# =====================================================================\n",
    "# MASTER PLOT SETUP (UNIFORM PANEL SIZES & CENTERED TOP ROW)\n",
    "# =====================================================================\n",
    "# 2x6 grid; every panel spans exactly 2 columns so all panels are the same size.\n",
    "# Row 1 (2 plots): columns [1:3] and [3:5], leaving one empty column on each side.\n",
    "# Row 2 (3 plots): columns [0:2], [2:4] and [4:6].\n",
    "fig = plt.figure(figsize=(18, 10))\n",
    "gs = gridspec.GridSpec(nrows=2, ncols=6, height_ratios=[1,1])\n",
    "\n",
    "row1_vars = {\n",
    "    'dlai': r'$\\Delta$LAI ($m^2$ $m^{-2}$)',\n",
    "    'dsnow': r'$\\Delta$$l$ (m)'\n",
    "}\n",
    "\n",
    "row2_vars = {\n",
    "    'dom_act': r'$\\Delta$$OM_{act}$ (kgC $m^{-3}$)',\n",
    "    'dh2osoi': r'$\\Delta$$\\theta$ ($m^3$ $m^{-3}$)',\n",
    "    'dtsoi': r'$\\Delta$$T_{soil}$ (K)'\n",
    "}\n",
    "\n",
    "style_dict = {\n",
    "    'hot': {'color': 'red', 'label': 'High climate sensitivity'},\n",
    "    'median': {'color': 'blue', 'label': 'Median climate sensitivity'},\n",
    "    'SSP': {'color': 'purple', 'label': 'SSP scenario'}\n",
    "}\n",
    "\n",
    "\n",
    "def plot_lat_panel(ax, var, label):\n",
    "    \"\"\"Plot the mean (line) and min-max range (shading) of `var` for every group in clm_df_dict.\"\"\"\n",
    "    for clim_sens, plot_df in clm_df_dict.items():\n",
    "        style = style_dict.get(clim_sens, {'color': 'black', 'label': clim_sens})\n",
    "        ax.plot(plot_df['lat_band'], plot_df[f'{var}_mean'], label=style['label'], color=style['color'])\n",
    "        ax.fill_between(plot_df['lat_band'], plot_df[f'{var}_min'], plot_df[f'{var}_max'], alpha=0.3, color=style['color'])\n",
    "    ax.grid(True)\n",
    "    ax.axhline(y=0, color='black')\n",
    "    ax.tick_params(axis='both', labelsize=14)\n",
    "    ax.set_xlabel('Latitude (degree north)', fontsize=14)\n",
    "    ax.set_ylabel(label, fontsize=14)\n",
    "\n",
    "\n",
    "# --- ROW 1 GENERATION (2 center-aligned panels) ---\n",
    "for i, (var, label) in enumerate(row1_vars.items()):\n",
    "    ax = fig.add_subplot(gs[0, (i * 2) + 1 : (i * 2) + 3])\n",
    "    if i == 0:\n",
    "        ax.text(-0.18, 0.05, 'a)', transform=ax.transAxes, \n",
    "                fontsize=22, fontweight='bold', va='bottom', ha='right')\n",
    "    plot_lat_panel(ax, var, label)\n",
    "\n",
    "# --- ROW 2 GENERATION (3 panels) ---\n",
    "for i, (var, label) in enumerate(row2_vars.items()):\n",
    "    ax = fig.add_subplot(gs[1, i * 2 : (i + 1) * 2])\n",
    "    if i == 0:\n",
    "        ax.text(-0.18, 0, 'b)', transform=ax.transAxes, \n",
    "                fontsize=22, fontweight='bold', va='bottom', ha='right')\n",
    "    plot_lat_panel(ax, var, label)\n",
    "    if i == 2:\n",
    "        ax.legend(fontsize=12, loc='best')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('combined_environmental_variables_by_lat.png', dpi=300, format='png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33b4ad5b-e0ac-4cb4-8427-f9b6584e71cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Conduct Shapley Analysis for Rbio\n",
    "\n",
    "rbio_components = ['dgb_kb', 'dgb_fm', 'dgb_tot_est', 'dgb_exact']\n",
    "# Southern edges of the 10-degree latitude bands; band centres are np.arange(-55, 75, 10)\n",
    "lat_band_starts = np.arange(-60, 70, 10)\n",
    "\n",
    "rbio_shapley_dict = {}\n",
    "for key, value in output_dict.items():\n",
    "    shapley = shapley_analysis_monthly_rbio(value)\n",
    "\n",
    "    # Filter out extreme values once per component (the filter always receives the whole\n",
    "    # field, so hoisting it out of the band loop is equivalent provided filter_threshold\n",
    "    # has no side effects), then average each latitude band.\n",
    "    out_data = {'lat_band': np.arange(-55, 75, 10)}\n",
    "    for comp in rbio_components:\n",
    "        comp_filtered = filter_threshold(shapley[comp])\n",
    "        out_data[f'{comp}_annual'] = np.asarray([\n",
    "            comp_filtered.sel(lat=slice(start, start + 10)).mean().values[()]\n",
    "            for start in lat_band_starts\n",
    "        ])\n",
    "    rbio_shapley_dict[key] = pd.DataFrame(out_data)\n",
    "\n",
    "# Prefix used for each component's min/max/mean columns in the plotting frames\n",
    "rbio_plot_prefix = {'dgb_kb': 'dgb_kb', 'dgb_fm': 'dgb_fm', 'dgb_tot_est': 'dgb_tot', 'dgb_exact': 'dgb_exact'}\n",
    "\n",
    "plot_df_rbio_dict = {}\n",
    "\n",
    "for clim_sens in ['median', 'hot']:\n",
    "    group = [df for run_key, df in rbio_shapley_dict.items() if clim_sens in run_key]\n",
    "    # Guard against an empty group, which would make numpy's min/max over axis 0 fail\n",
    "    if not group:\n",
    "        print(f\"Warning: No keys matched sensitivity criteria: '{clim_sens}'\")\n",
    "        continue\n",
    "\n",
    "    plot_data = {'lat_band': np.arange(-55, 75, 10)}\n",
    "    for comp, prefix in rbio_plot_prefix.items():\n",
    "        # x100 to match the cm s^-1 axis label (presumably converting from m s^-1)\n",
    "        runs = np.asarray([list(df[f'{comp}_annual'] * 100) for df in group])\n",
    "        plot_data[f'{prefix}_min'] = runs.min(axis=0)\n",
    "        plot_data[f'{prefix}_max'] = runs.max(axis=0)\n",
    "        plot_data[f'{prefix}_mean'] = runs.mean(axis=0)\n",
    "    plot_df_rbio_dict[clim_sens] = pd.DataFrame(plot_data)\n",
    "\n",
    "fig, axes = plt.subplots(figsize=(13, 5), nrows=1, ncols=2)\n",
    "\n",
    "# Pair panels with groups explicitly so titles stay correct even if a group was skipped\n",
    "panel_titles = {'median': 'Median climate sensitivity', 'hot': 'High climate sensitivity'}\n",
    "for ax, clim_sens in zip(axes, ['median', 'hot']):\n",
    "    plot_df = plot_df_rbio_dict.get(clim_sens)\n",
    "    if plot_df is not None:\n",
    "        ax.plot(plot_df['lat_band'], plot_df['dgb_kb_mean'], label=r'$\\Delta$$G_{soil,OM}$', color='brown')\n",
    "        ax.fill_between(plot_df['lat_band'], plot_df['dgb_kb_min'], plot_df['dgb_kb_max'], alpha=0.3, color='brown')\n",
    "        ax.plot(plot_df['lat_band'], plot_df['dgb_fm_mean'], label=r'$\\Delta$$G_{soil,met}$', color='blue')\n",
    "        ax.fill_between(plot_df['lat_band'], plot_df['dgb_fm_min'], plot_df['dgb_fm_max'], alpha=0.3, color='blue')\n",
    "    ax.axhline(y=0, color='black')\n",
    "    ax.grid(True)\n",
    "    ax.set_ylim(-4e-3, 10e-3)\n",
    "    ax.tick_params(axis='both', labelsize=12)\n",
    "    ax.set_xlabel('Latitude (degree north)', fontsize=14)\n",
    "    ax.set_title(panel_titles[clim_sens], fontsize=24)\n",
    "\n",
    "axes[0].set_ylabel(r'Contribution to $\\Delta$$G_{soil}$ (cm $s^{-1}$)', fontsize=14)\n",
    "axes[1].legend(fontsize=14)\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig('bio_vs_met.png', dpi=300, format='png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddcf2c14-369e-4ce9-910e-648dd8033a1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Do soil dryness analysis\n",
    "\n",
    "results_list = []\n",
    "dist_maps_2025 = {}\n",
    "dist_maps_2095 = {}\n",
    "\n",
    "# Threshold for \"Threatened\": distance to sm_opt is between 0 and upper_limit\n",
    "upper_limit = 0.05\n",
    "\n",
    "for key, item in tqdm(clm_data_dict.items()):\n",
    "    \n",
    "    # 1. Calculate seasonal means for both time periods\n",
    "    sm_2025 = get_seasonal_mean_year_range_soil_surface(item, 'H2OSOI', '2021-01-01', '2031-01-01')\n",
    "    sm_2095 = get_seasonal_mean_year_range_soil_surface(item, 'H2OSOI', '2090-01-01', '2100-01-01')\n",
    "    \n",
    "    # 2. Calculate distance from optimum (mean across seasons)\n",
    "    # Negative values = Dry; Positive values = Above optimum\n",
    "    dist_2025 = (sm_2025 - sm_opt).mean(dim='season')\n",
    "    dist_2095 = (sm_2095 - sm_opt).mean(dim='season')\n",
    "\n",
    "    # Store the 2D maps in dictionaries\n",
    "    dist_maps_2025[key] = dist_2025\n",
    "    dist_maps_2095[key] = dist_2095\n",
    "\n",
    "    # 3. Grid-cell counts (.item() stores plain Python ints in the DataFrame)\n",
    "    # Dry: theta < sm_opt\n",
    "    dry_count_2025 = xr.where(dist_2025 < 0, 1, 0).sum().item()\n",
    "    dry_count_2095 = xr.where(dist_2095 < 0, 1, 0).sum().item()\n",
    "\n",
    "    # Threatened: 0 <= dist < upper_limit ('&' is element-wise 'and' in xarray)\n",
    "    threat_count_2025 = xr.where((dist_2025 >= 0) & (dist_2025 < upper_limit), 1, 0).sum().item()\n",
    "    threat_count_2095 = xr.where((dist_2095 >= 0) & (dist_2095 < upper_limit), 1, 0).sum().item()\n",
    "\n",
    "    # 4. Append results\n",
    "    results_list.append({\n",
    "        'Scenario': key,\n",
    "        'Dry_2025': dry_count_2025,\n",
    "        'Dry_2095': dry_count_2095,\n",
    "        'Threatened_2025': threat_count_2025,\n",
    "        'Threatened_2095': threat_count_2095\n",
    "    })\n",
    "\n",
    "# Per-run counts of dry / threatened grid cells in each period\n",
    "df_counts = pd.DataFrame(results_list).set_index('Scenario')\n",
    "\n",
    "# Stack the per-run distance maps along a new 'model' dimension: (model, lat, lon)\n",
    "da_25 = xr.concat(dist_maps_2025.values(), dim='model')\n",
    "da_95 = xr.concat(dist_maps_2095.values(), dim='model')\n",
    "\n",
    "# Newly dry: at/above optimum in the early period, below it in the late period.\n",
    "# Summing the boolean mask over 'model' gives the number of agreeing runs per cell.\n",
    "newly_dry_agreement = ((da_25 >= 0) & (da_95 < 0)).sum(dim='model')\n",
    "\n",
    "# Newly threatened: at least upper_limit above optimum early, within [0, upper_limit) late\n",
    "newly_threatened_mask = (da_25 >= upper_limit) & (da_95 >= 0) & (da_95 < upper_limit)\n",
    "agreement_newly_threatened = newly_threatened_mask.sum(dim='model')\n",
    "\n",
    "# Map agreement counts to confidence tiers (presumably 0-3, matching the colorbar ticks below)\n",
    "cat_dry = classify_confidence(newly_dry_agreement)\n",
    "cat_threat = classify_confidence(agreement_newly_threatened)\n",
    "\n",
    "# --- Plotting ---\n",
    "\n",
    "# Categorical palette, one color per confidence tier\n",
    "tier_colors = ['#E0E0E0', '#FFD700', '#FF8C00', '#D80000']\n",
    "cmap = mcolors.ListedColormap(tier_colors)\n",
    "bounds = [-0.5, 0.5, 1.5, 2.5, 3.5]\n",
    "norm = mcolors.BoundaryNorm(bounds, cmap.N)\n",
    "\n",
    "fig, axes = plt.subplots(2, 1, figsize=(12, 10), \n",
    "                         subplot_kw={'projection': ccrs.Robinson()})\n",
    "\n",
    "plot_configs = [\n",
    "    (cat_dry, r\"Drying to $\\theta$ < $\\theta_{opt}$\"),\n",
    "    (cat_threat, r\"Drying to $\\theta_{opt}$ < $\\theta$ < $\\theta_{opt}$ + 0.05\")\n",
    "]\n",
    "\n",
    "labels = ['a)', 'b)']\n",
    "\n",
    "for i, (data, title) in enumerate(plot_configs):\n",
    "    ax = axes[i]\n",
    "    im = data.sel(lat=slice(-60, 80)).plot(ax=ax, transform=ccrs.PlateCarree(),\n",
    "                                           cmap=cmap, norm=norm, add_colorbar=False)\n",
    "    ax.coastlines()\n",
    "    ax.set_title(title, fontsize=14, fontweight='bold')\n",
    "    \n",
    "    # Panel label near the lower-left of each map, on a translucent white box for legibility\n",
    "    ax.text(0.15, 0.05, labels[i], transform=ax.transAxes, \n",
    "            fontsize=18, fontweight='bold', va='bottom', ha='left',\n",
    "            bbox=dict(facecolor='white', alpha=0.5, edgecolor='none'))\n",
    "\n",
    "# Colorbar configuration\n",
    "# Applying tight_layout before manual colorbar/adjustments often helps prevent overlaps\n",
    "plt.tight_layout()\n",
    "\n",
    "cbar = fig.colorbar(im, ax=axes, orientation='horizontal', \n",
    "                    pad=0.08, shrink=0.6, aspect=40,\n",
    "                    ticks=[0,1,2,3], location='bottom')\n",
    "\n",
    "# NOTE(review): the count ranges below presumably mirror classify_confidence's thresholds\n",
    "# for the 8 runs in clm_data_dict - keep them in sync if either changes\n",
    "cbar.ax.set_xticklabels(['None (0)', 'Low (1-2)', 'Med (3-5)', 'High (6-8)'], \n",
    "                        fontsize=11)\n",
    "cbar.set_label('Confidence Tier (Number of Models)', fontsize=12, labelpad=10)\n",
    "\n",
    "# Adjust vertical spacing and bottom margin for the colorbar\n",
    "fig.subplots_adjust(bottom=0.15, hspace=0.1)\n",
    "\n",
    "plt.savefig('drying_transition.png', bbox_inches='tight', dpi=300, format='png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6db98f3-c7ae-4ef5-95c0-775390297010",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Now do SSP analysis\n",
    "\n",
    "data_dict_ssp = {\n",
    "    'full': xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/ssp_245_new/lnd/hist/output_shifted.nc'),\n",
    "    'no_lu': xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/ssp_245_no_lu_new//lnd/hist/output_shifted.nc'),\n",
    "    'no_co2': xr.open_dataset('/home/ayhwong/fs12_cesm_storage/CTSM/rundir/archive/ssp_245_lu_new//lnd/hist/output_shifted.nc')\n",
    "}\n",
    "\n",
    "### Time series plots\n",
    "\n",
    "vd_ts_dict = {}\n",
    "\n",
    "for key, value in data_dict_ssp.items():\n",
    "    vd_ts_dict[key] = cal_yearly_var(value, 'DRYDEPV_H2')\n",
    "\n",
    "landuse_ts_xr = xr.open_dataset('/net/fs01/data/cesm2/inputdata/lnd/clm2/surfdata_esmf/ctsm5.3.0/landuse.timeseries_0.9x1.25_SSP2-4.5_1850-2100_78pfts_c240908.nc')\n",
    "# The land-use file indexes its grid by lsmlon/lsmlat; label those dims with the model\n",
    "# grid's lon/lat so the crop data aligns with the CLM output. The coordinates are taken\n",
    "# from the SSP run loaded above rather than from `dvd`, which this cell does not define\n",
    "# (and which the SSP Shapley cell reassigns), so the cell works on a fresh kernel.\n",
    "grid_ref = data_dict_ssp['full']\n",
    "pct_crop = landuse_ts_xr.PCT_CROP.assign_coords({\n",
    "    'lsmlon':grid_ref.lon.values,\n",
    "    'lsmlat':grid_ref.lat.values   \n",
    "}).rename({\n",
    "    'lsmlon':'lon',\n",
    "    'lsmlat':'lat'\n",
    "})\n",
    "\n",
    "# Global cos(lat)-weighted mean cropland percentage per year, 2020-2100\n",
    "weights = np.cos(np.deg2rad(pct_crop.lat))\n",
    "weights.name = \"weights\"\n",
    "pct_crop_ts = pct_crop.weighted(weights).mean(dim = ['lon','lat']).sel(time=slice(2020,2100))\n",
    "\n",
    "# NOTE(review): the '2020' baseline averages 2010-2020, while the other analyses use\n",
    "# 2021-2030 windows - confirm this is intended\n",
    "pct_crop_2020 = pct_crop.sel(time = slice(2010,2020)).mean(dim='time')\n",
    "pct_crop_2090 = pct_crop.sel(time = slice(2090,2100)).mean(dim='time')\n",
    "delta_pct_crop = pct_crop_2090 - pct_crop_2020\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81cb427d-03c4-476a-8f96-615543a66c44",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Now do Shapley analysis\n",
    "\n",
    "shapley_ssp_dict = {}\n",
    "t_start = '2090-01-01'\n",
    "t_end = '2100-01-01'\n",
    "\n",
    "full_case = data_dict_ssp['full']\n",
    "# Southern edges of the 10-degree latitude bands; band centres are np.arange(-55, 75, 10)\n",
    "lat_band_slices = [slice(start, start + 10) for start in np.arange(-60, 70, 10)]\n",
    "\n",
    "for key, value in data_dict_ssp.items():\n",
    "\n",
    "    # 'full' is the reference case every other scenario is compared against\n",
    "    if key == 'full':\n",
    "        continue\n",
    "\n",
    "    shapley = shapley_analysis_datasets(full_case, value, t_start, t_end)\n",
    "    # Total change in DRYDEPV_H2: this case minus the full case, averaged over months\n",
    "    dvd = (get_monthly_mean_year_range(value, 'DRYDEPV_H2', t_start, t_end).mean(\"month\")\n",
    "           - get_monthly_mean_year_range(full_case, 'DRYDEPV_H2', t_start, t_end).mean(\"month\"))\n",
    "\n",
    "    # Filter out extreme values once per component (the filter always receives the whole\n",
    "    # field, so hoisting it out of the band loop is equivalent provided filter_threshold\n",
    "    # has no side effects)\n",
    "    filtered = {comp: filter_threshold(shapley[comp]) for comp in ['dvd_ra', 'dvd_rb', 'dvd_rs']}\n",
    "\n",
    "    shapley_ssp_dict[key] = pd.DataFrame({\n",
    "        'lat_band': np.arange(-55, 75, 10),\n",
    "        'dvd_ra': np.asarray([filtered['dvd_ra'].sel(lat=s).mean().values[()] for s in lat_band_slices]),\n",
    "        'dvd_rb': np.asarray([filtered['dvd_rb'].sel(lat=s).mean().values[()] for s in lat_band_slices]),\n",
    "        'dvd_rs': np.asarray([filtered['dvd_rs'].sel(lat=s).mean().values[()] for s in lat_band_slices]),\n",
    "        'dvd_annual': np.asarray([dvd.sel(lat=s).mean().values[()] for s in lat_band_slices])\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b469fef3-fa1b-4ee1-967b-8dbd0d0e6002",
   "metadata": {},
   "outputs": [],
   "source": [
    "### Combine plots\n",
    "\n",
    "# =====================================================================\n",
    "# 1. DATA PREPARATION\n",
    "# =====================================================================\n",
    "# Land-use signal in the yearly v_d series: full run minus the no-land-use run\n",
    "dvd_lu_ts = (vd_ts_dict['full'] - vd_ts_dict['no_lu']).sel(year=slice(2020, 2100))\n",
    "# The 'dvd_annual' column here is (no_lu - full), so every column is negated below to\n",
    "# share the sign of dvd_lu_ts (presumably the Shapley terms follow the same convention)\n",
    "plot_df = shapley_ssp_dict['no_lu']\n",
    "\n",
    "# =====================================================================\n",
    "# 2. MASTER PLOT SETUP: top = dual-axis time series, bottom = latitude bands\n",
    "# =====================================================================\n",
    "fig = plt.figure(figsize=(10, 10))\n",
    "gs = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[1,1])\n",
    "ax_ts = fig.add_subplot(gs[0])\n",
    "ax_lat = fig.add_subplot(gs[1])\n",
    "\n",
    "ts_color, crop_color = 'tab:blue', 'tab:brown'\n",
    "\n",
    "# --- Top panel, left axis: land-use change in v_d ---\n",
    "ax_ts.plot(dvd_lu_ts.year.values, dvd_lu_ts.values, color=ts_color, linewidth=2)\n",
    "ax_ts.set_xlabel('Year', fontsize=14)\n",
    "ax_ts.set_ylabel(r'$\\Delta$$v_{d,LUC}$ (cm $s^{-1}$)', color=ts_color, fontsize=14)\n",
    "ax_ts.tick_params(axis='y', labelcolor=ts_color, labelsize=12)\n",
    "ax_ts.tick_params(axis='x', labelsize=12)\n",
    "ax_ts.grid(True, linestyle='--', alpha=0.5)\n",
    "\n",
    "# --- Top panel, right axis: global cropland cover ---\n",
    "ax_crop = ax_ts.twinx()\n",
    "ax_crop.plot(pct_crop_ts.time.values, pct_crop_ts.values, color=crop_color, linewidth=2)\n",
    "ax_crop.set_ylabel('Global cropland coverage (%)', color=crop_color, fontsize=14)\n",
    "ax_crop.tick_params(axis='y', labelcolor=crop_color, labelsize=12)\n",
    "\n",
    "# --- Bottom panel: per-component latitudinal contributions (sign flipped, see above) ---\n",
    "lat_lines = [\n",
    "    ('dvd_rb', 'Soil', 'red'),\n",
    "    ('dvd_rs', 'Snow', 'blue'),\n",
    "    ('dvd_ra', 'Canopy', 'green'),\n",
    "    ('dvd_annual', r'Total $\\Delta$$v_{d}$', 'black'),\n",
    "]\n",
    "for column, line_label, line_color in lat_lines:\n",
    "    ax_lat.plot(plot_df['lat_band'], -plot_df[column], label=line_label, color=line_color, linewidth=2)\n",
    "\n",
    "ax_lat.axhline(y=0, color='black', linestyle='-', linewidth=1)\n",
    "ax_lat.grid(True, linestyle='--', alpha=0.5)\n",
    "ax_lat.set_ylim(-0.003, 0.001)\n",
    "ax_lat.tick_params(axis='both', labelsize=12)\n",
    "ax_lat.set_xlabel('Latitude (degree north)', fontsize=14)\n",
    "ax_lat.set_ylabel(r'Contribution to $\\Delta$$v_d$ (cm $s^{-1}$)', fontsize=14)\n",
    "ax_lat.legend(fontsize=12, loc='best')\n",
    "\n",
    "# =====================================================================\n",
    "# 3. SUBFIGURE LABELS, placed left of the y-axis titles\n",
    "# =====================================================================\n",
    "for panel_ax, panel_label in [(ax_ts, 'a)'), (ax_lat, 'b)')]:\n",
    "    panel_ax.text(-0.12, 0, panel_label, transform=panel_ax.transAxes,\n",
    "                  fontsize=20, fontweight='bold', va='bottom', ha='right')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('combined_time_and_lat_plots_lu.png', dpi=300, format='png')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "937378c1-481a-4838-a45c-3f66a443287b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Land-use (LU) impact on CLM state variables, averaged in latitude bands.\n",
    "#   dLAI / dOM / dtheta: full run minus the no-LU run, season-averaged between LU_PERIOD_START and LU_PERIOD_END.\n",
    "#   dcrop: change in prescribed cropland coverage, 2090-2100 mean minus 2010-2020 mean.\n",
    "CLM_ARCHIVE_DIR = '../CTSM/rundir/archive'\n",
    "LU_PERIOD_START, LU_PERIOD_END = '2091-01-01', '2100-01-01'\n",
    "LAT_BAND_WIDTH = 10\n",
    "LAT_BAND_STARTS = np.arange(-60, 70, LAT_BAND_WIDTH)  # lower band edges: -60, -50, ..., 60\n",
    "\n",
    "clm_xr_ssp_245 = xr.open_dataset(f'{CLM_ARCHIVE_DIR}/ssp_245_new/lnd/hist/output_clm_shifted.nc')\n",
    "clm_xr_ssp_245_no_lu = xr.open_dataset(f'{CLM_ARCHIVE_DIR}/ssp_245_no_lu_new/lnd/hist/output_clm_shifted.nc')\n",
    "\n",
    "\n",
    "def _soil_surface_lu_diff(var):\n",
    "    \"\"\"Season-averaged (full - no_lu) difference of soil-surface `var` over the LU period.\"\"\"\n",
    "    full = get_seasonal_mean_year_range_soil_surface(clm_xr_ssp_245, var, LU_PERIOD_START, LU_PERIOD_END)\n",
    "    no_lu = get_seasonal_mean_year_range_soil_surface(clm_xr_ssp_245_no_lu, var, LU_PERIOD_START, LU_PERIOD_END)\n",
    "    return full.mean(dim='season') - no_lu.mean(dim='season')\n",
    "\n",
    "\n",
    "def _lat_band_means(field):\n",
    "    \"\"\"Unweighted mean of `field` over lat in [start, start + LAT_BAND_WIDTH], one value per band start.\"\"\"\n",
    "    return [field.sel(lat=slice(start, start + LAT_BAND_WIDTH)).mean().values[()]\n",
    "            for start in LAT_BAND_STARTS]\n",
    "\n",
    "\n",
    "dLAI_2d = (get_seasonal_mean_year_range(clm_xr_ssp_245, 'ELAI', LU_PERIOD_START, LU_PERIOD_END) -\n",
    "           get_seasonal_mean_year_range(clm_xr_ssp_245_no_lu, 'ELAI', LU_PERIOD_START, LU_PERIOD_END)).mean(dim='season')\n",
    "dOM_2d = _soil_surface_lu_diff('SOM_ACT_C_vr')\n",
    "dtheta_2d = _soil_surface_lu_diff('H2OSOI')\n",
    "\n",
    "landuse_ts_xr = xr.open_dataset('/net/fs01/data/cesm2/inputdata/lnd/clm2/surfdata_esmf/ctsm5.3.0/landuse.timeseries_0.9x1.25_SSP2-4.5_1850-2100_78pfts_c240908.nc')\n",
    "# Relabel the surface-data grid dims with the CLM output's lat/lon so PCT_CROP can be selected by latitude.\n",
    "pct_crop = landuse_ts_xr.PCT_CROP.assign_coords({\n",
    "    'lsmlon': dLAI_2d.lon.values,\n",
    "    'lsmlat': dLAI_2d.lat.values,\n",
    "}).rename({'lsmlon': 'lon', 'lsmlat': 'lat'})\n",
    "\n",
    "pct_crop_2020 = pct_crop.sel(time=slice(2010, 2020)).mean(dim='time')  # 2010-2020 mean\n",
    "pct_crop_2090 = pct_crop.sel(time=slice(2090, 2100)).mean(dim='time')  # 2090-2100 mean\n",
    "delta_pct_crop = pct_crop_2090 - pct_crop_2020\n",
    "\n",
    "# NOTE(review): band means are not area-weighted, and PCT_CROP may be defined over cells\n",
    "# where the CLM fields are NaN (e.g. ocean); confirm the bands are comparable.\n",
    "dLAI = _lat_band_means(dLAI_2d)\n",
    "dOM = _lat_band_means(dOM_2d)\n",
    "dcrop = _lat_band_means(delta_pct_crop)\n",
    "dtheta = _lat_band_means(dtheta_2d)\n",
    "\n",
    "out_clm_df = pd.DataFrame({\n",
    "    'lat_band': LAT_BAND_STARTS + LAT_BAND_WIDTH // 2,  # band centres: -55, -45, ..., 65\n",
    "    'dLAI': dLAI,\n",
    "    'dOM': dOM,\n",
    "    'dcrop': dcrop,\n",
    "    'dtheta': dtheta,\n",
    "})\n",
    "\n",
    "# (column, y-axis label, divisor) for each panel, in row-major order.\n",
    "# NOTE(review): CLM `_vr` fields are usually per unit volume; confirm the m^-2 in the dOM label.\n",
    "clm_panels = [\n",
    "    ('dcrop', r'$\\Delta$Crop Fraction (%)', 1),\n",
    "    ('dOM', r'$\\Delta$$OM_{act}$ (kgC $m^{-2}$)', 1e3),  # /1e3 to match the kgC label\n",
    "    ('dLAI', r'$\\Delta$LAI ($m^2$ $m^{-2}$)', 1),\n",
    "    ('dtheta', r'$\\Delta$$\\theta$ ($m^3$ $m^{-3}$)', 1),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(figsize=(12, 10), nrows=2, ncols=2)\n",
    "for ax, (column, ylabel, divisor) in zip(axes.flat, clm_panels):\n",
    "    ax.plot(out_clm_df['lat_band'], out_clm_df[column] / divisor)\n",
    "    ax.set_ylabel(ylabel, fontsize=16)\n",
    "    ax.tick_params(axis='x', labelsize=14)\n",
    "    ax.tick_params(axis='y', labelsize=14)\n",
    "    ax.axhline(y=0, color='black')\n",
    "    ax.grid(True)\n",
    "\n",
    "fig.supxlabel('Latitude (degree north)', fontsize=16)\n",
    "plt.tight_layout()\n",
    "\n",
    "fig.savefig('clm_var_lu.png', dpi=300, format='png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3298776a-e823-4536-8cde-46ba14f518e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.ticker as mticker\n",
    "\n",
    "# Maps over 60S-70N: (top) change in cropland coverage, `delta_pct_crop`;\n",
    "# (bottom) `dvd_lu`, the land-use change in v_d (not defined in this cell).\n",
    "cbar_shrink = 0.5\n",
    "map_fields = [\n",
    "    (delta_pct_crop, 'Changes in cropland coverage (%)'),\n",
    "    (dvd_lu, r'$\\Delta$$v_{d,lu}$ (cm $s^{-1}$)'),\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(8, 10),\n",
    "                         subplot_kw={'projection': ccrs.PlateCarree()})\n",
    "\n",
    "for ax, (field, cbar_label) in zip(axes, map_fields):\n",
    "    # Declare the data CRS so cartopy transforms the lon/lat grid rather than\n",
    "    # treating it as native map coordinates (matters if lon runs 0-360).\n",
    "    field.sel(lat=slice(-60, 70)).plot(ax=ax, robust=True, transform=ccrs.PlateCarree(),\n",
    "                                       cbar_kwargs={'shrink': cbar_shrink, 'label': cbar_label})\n",
    "    ax.coastlines()\n",
    "\n",
    "    # Dashed 10-degree graticule, labelled on the bottom and left edges only.\n",
    "    gl = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')\n",
    "    gl.xlocator = mticker.FixedLocator(range(-180, 181, 10))\n",
    "    gl.ylocator = mticker.FixedLocator(range(-90, 91, 10))\n",
    "    gl.top_labels = False\n",
    "    gl.right_labels = False\n",
    "    gl.xpadding = 5  # gap between the map frame and the longitude labels\n",
    "    gl.xlabel_style = {'size': 10}\n",
    "    gl.ylabel_style = {'size': 10}\n",
    "\n",
    "# tight_layout() recomputes all subplot spacing, so a preceding subplots_adjust() would be discarded.\n",
    "plt.tight_layout()\n",
    "\n",
    "fig.savefig('dlu_dcrop_maps.png', dpi=300, format='png', bbox_inches='tight')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "default_2026",
   "language": "python",
   "name": "default_2026"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
