"""
GLP-1 Economic Impact Model
Art. 28 — Restrepo-Morales & Garcés Giraldo (2026)

Three-equation structural model:
  1. Bass Diffusion    — household adoption trajectory
  2. Demand Destruction — food expenditure reduction by category and horizon
  3. I-O Multiplier    — sectoral GDP impact via Leontief linkages

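Empirical inputs (adoption anchors, baseline expenditures, demand coefficients,
I-O shares and multipliers) are read from the CSV files in ../data/calibration/;
the market potentials, the abandonment rate and the KPMG $55B benchmark are
set as constants in this script.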
Calibration anchors use HOUSEHOLD-level adoption (Gallup/Circana).
Short-run coefficients (Cornell, 6-month) and long-run coefficients
(KPMG, steady-state) are kept separate throughout.

t = 0 corresponds to Q2-2024 (first reliable household panel data point);
outputs and figures label simulation time t as calendar year 2024.5 + t.
"""

import warnings
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.integrate import odeint
from scipy.optimize import minimize

# Publication-style Matplotlib defaults shared by every figure below.
_RC_STYLE = {
    "font.family": "serif",
    "font.size": 11,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "figure.dpi": 150,
}
plt.rcParams.update(_RC_STYLE)

# ── Paths ──────────────────────────────────────────────────────────────────
# ROOT is the parent of the directory containing this script; calibration
# inputs live in <ROOT>/data/calibration and figures go to <ROOT>/figures.
ROOT = Path(__file__).parent.parent
DATA_DIR = ROOT.joinpath("data", "calibration")
FIG_DIR = ROOT.joinpath("figures")
FIG_DIR.mkdir(exist_ok=True, parents=True)

# ── Load calibration data ──────────────────────────────────────────────────
# Adoption anchors (no index) and keyed parameter tables (category / sector).
anchors_usa = pd.read_csv(DATA_DIR.joinpath("bass_anchors_usa.csv"))
anchors_latam = pd.read_csv(DATA_DIR.joinpath("bass_anchors_latam.csv"))
e0_df = pd.read_csv(DATA_DIR.joinpath("e0_baseline_usa.csv"), index_col="category")
alpha_df = pd.read_csv(DATA_DIR.joinpath("demand_coefficients.csv"), index_col="category")
io_df = pd.read_csv(DATA_DIR.joinpath("io_multipliers.csv"), index_col="sector")

# ── Market potential and abandonment ──────────────────────────────────────
# Bass market potentials: ceiling on the adopting fraction of households
# (passed as M to calibrate_bass / simulate_bass).
M_USA        = 0.35   # Gallup/JP Morgan: 35% of US households by 2030
M_LATAM      = 0.11   # Kantar: 11% use or plan to use (ceiling)
# Fraction of adopters removed from the active base by t = 1 year; used as the
# default of effective_adoption, which phases it in linearly over year one.
# NOTE(review): the citation note below says ~32-33% *persist* at 12 months,
# which would correspond to an abandonment rate of ~2/3, not 1/3. Confirm
# whether the value or the note is the intended one before publishing.
ABANDON_RATE = 1 / 3  # Cornell (2025) + Gleason et al. (2024): ~32-33% persist at 12 months
# NOTE(review): HH_USA does not appear to be referenced elsewhere in this script.
HH_USA       = 130.0  # millions of US households (Census 2023)

# ── Extract calibration anchors ────────────────────────────────────────────
# Each anchor file pairs a time column (years since t = 0, Q2-2024) with the
# observed household adoption fraction.
t_obs_usa, N_obs_usa = (anchors_usa[col].values for col in ("t_years", "N_households"))
# USA anchors:   t ≈ [0.00, 1.25],  N ≈ [0.12, 0.23]  — household %
t_obs_latam, N_obs_latam = (anchors_latam[col].values for col in ("t_years", "N_households"))
# LatAm anchors: t ≈ [0.00, 1.50],  N ≈ [0.03, 0.06]

# ── Extract E0 baseline expenditures (USD billions) ────────────────────────
E0 = e0_df["e0_billion_usd"].to_dict()
# Expected entries (per the source notes):
#   food_total      = 2095  BEA PCE 2023: food at home + food away
#   supermarket     =  800  grocery stores (USDA/NRF 2023)
#   fast_food       =  370  QSR (NRA 2023)
#   ultra_processed =  500  authors' estimate

# ── Extract demand reduction coefficients ─────────────────────────────────
ALPHA = alpha_df["alpha"].to_dict()
# Short-run (Cornell, 6-month):    food_total=0.053, fast_food=0.080, ultra_processed=0.100
# Long-run  (KPMG, steady-state):  supermarket_lt=0.310
# LatAm survey (Kantar):           sugary_drinks=0.600, fatty_foods=0.550, high_sugar=0.510

# ── Extract I-O parameters ─────────────────────────────────────────────────
# theta_share: share of the total demand shock assigned to each sector;
# mu_multiplier: that sector's output multiplier. Both are keyed by sector.
THETA, MU = (io_df[col].to_dict() for col in ("theta_share", "mu_multiplier"))
# Share-weighted multiplier = total GDP loss per USD of shock in gdp_impact.
weighted_mu = sum(share * MU[sector] for sector, share in THETA.items())

# ══════════════════════════════════════════════════════════════════════════
# 1. BASS DIFFUSION MODEL
# ══════════════════════════════════════════════════════════════════════════

def bass_ode(N, t, p, q, M):
    """Right-hand side of the Bass (1969) diffusion equation.

    dN/dt = (p + q*N/M) * (M - N)

    ``t`` is not used (the system is autonomous). It is accepted because
    ``odeint`` calls ``func(y, t, *args)``.
    """
    adoption_hazard = p + q * N / M
    untapped = M - N
    return adoption_hazard * untapped


def calibrate_bass(t_obs, N_obs, M, p0=0.01, q0=0.3):
    """
    Calibrate innovation (p) and imitation (q) coefficients via
    Nelder-Mead minimisation of squared residuals against observed
    household adoption data.

    Parameters
    ----------
    t_obs : array-like
        Observation times in years, in ascending order. The ODE is seeded
        with ``N_obs[0]`` at ``t_obs[0]``; the remaining points are targets.
    N_obs : array-like
        Observed adoption fractions, one per entry of ``t_obs``.
    M : float
        Market potential (ceiling on adoption).
    p0, q0 : float
        Starting point for the optimiser.

    Returns
    -------
    numpy.ndarray
        Calibrated ``[p, q]``.

    Raises
    ------
    ValueError
        If the inputs are not 1-D arrays of equal length, or if there are
        fewer than two observations (no target to fit).

    Notes
    -----
    With a single observation after the initial point there is one residual
    for two unknowns. In that case (p, q) is not uniquely identified, and
    the returned pair depends on the starting point (p0, q0).
    A ``RuntimeWarning`` is emitted when the optimiser reports
    non-convergence.
    """
    t_obs = np.asarray(t_obs, dtype=float)
    N_obs = np.asarray(N_obs, dtype=float)
    if t_obs.ndim != 1 or t_obs.shape != N_obs.shape:
        raise ValueError("t_obs and N_obs must be 1-D arrays of equal length")
    if t_obs.size < 2:
        raise ValueError("calibrate_bass needs at least two observations")

    # The integration grid does not depend on (p, q), so it is built once.
    # It starts at the first observation time, where the initial condition
    # applies, rather than at a fixed 0.
    t_grid = np.linspace(t_obs[0], t_obs.max() + 0.1, 1000)
    N0, t_fit, N_fit = N_obs[0], t_obs[1:], N_obs[1:]

    def loss(params):
        p, q = params
        if p <= 0 or q <= 0:
            return 1e9  # penalty keeps Nelder-Mead in the admissible region
        N_sim = odeint(bass_ode, N0, t_grid, args=(p, q, M)).flatten()
        N_pred = np.interp(t_fit, t_grid, N_sim)
        return np.sum((N_pred - N_fit) ** 2)

    res = minimize(loss, [p0, q0], method="Nelder-Mead",
                   options={"xatol": 1e-10, "fatol": 1e-10, "maxiter": 20_000})
    if not res.success:
        warnings.warn(f"Bass calibration did not converge: {res.message}",
                      RuntimeWarning, stacklevel=2)
    return res.x


def simulate_bass(p, q, M, N_init, years=10, dt=0.01):
    """
    Integrate the Bass ODE from ``N_init`` at t = 0 over ``years`` years.

    Returns
    -------
    (t, N) : tuple of numpy.ndarray
        ``t = 0, dt, 2*dt, ...`` covering ``[0, years]`` with spacing exactly
        ``dt``, and the adoption path clipped to ``[0, M]``.

    Raises
    ------
    ValueError
        If ``dt`` is not positive.

    Notes
    -----
    The grid is built from an integer step count rather than
    ``np.arange(0, years + dt, dt)``. With a float step, ``arange`` can emit
    an extra point beyond ``years`` (``years=1, dt=0.1`` gives 12 points
    ending at 1.1). The spacing stays exactly ``dt`` because downstream code
    converts times to indices as ``t / dt``.
    """
    if dt <= 0:
        raise ValueError("dt must be positive")
    # A tiny tolerance absorbs float noise in years/dt before taking ceil.
    n_steps = max(int(np.ceil(years / dt - 1e-9)), 0)
    t = np.arange(n_steps + 1) * dt
    N = odeint(bass_ode, N_init, t, args=(p, q, M)).flatten()
    return t, np.clip(N, 0, M)


# Calibrate USA (household data, t=0 at Q2-2024) from the default start.
p_usa, q_usa = calibrate_bass(t_obs_usa, N_obs_usa, M_USA)
print(
    f"Bass USA   — p={p_usa:.4f}  q={q_usa:.4f}  "
    f"M={M_USA:.2f}  q/p={q_usa/p_usa:.1f}"
)

# Calibrate LatAm (Kantar survey estimates) from a smaller starting point.
p_latam, q_latam = calibrate_bass(
    t_obs_latam, N_obs_latam, M_LATAM,
    p0=0.005, q0=0.20,
)
print(
    f"Bass LatAm — p={p_latam:.4f}  q={q_latam:.4f}  "
    f"M={M_LATAM:.2f}  q/p={q_latam/p_latam:.1f}"
)

# Forward-simulate 10 years from each region's first anchor. Both runs use
# the same default grid, so the USA time axis serves for both regions.
t_sim, N_usa = simulate_bass(p_usa, q_usa, M_USA, N_obs_usa[0], years=10)
_, N_latam = simulate_bass(p_latam, q_latam, M_LATAM, N_obs_latam[0], years=10)

# ══════════════════════════════════════════════════════════════════════════
# 2. DEMAND DESTRUCTION MODEL
# ══════════════════════════════════════════════════════════════════════════

def effective_adoption(N_path, t_path, abandon_rate=ABANDON_RATE):
    """
    Convert gross (ever-adopter) adoption into currently active users.

    Abandonment is phased in linearly over the first year: at time t the
    retained share is ``1 - abandon_rate * min(t, 1)``. From t = 1 onward it
    stays at ``1 - abandon_rate``.

    N_path       : adoption fraction(s); combined element-wise with t_path.
    t_path       : time(s) in years since t = 0.
    abandon_rate : fraction of adopters no longer on treatment after year 1.
    """
    ramp = np.minimum(t_path, 1.0)
    retained_share = 1.0 - abandon_rate * ramp
    return N_path * retained_share


def demand_destruction(N_eff, alpha, E0_bn):
    """
    Annual demand destruction in USD billions for one spending category.

    Computed as effective adoption x reduction coefficient x baseline spend.

    Parameters
    ----------
    N_eff : float or ndarray
        Effective adoption fraction (0-1) of all households.
    alpha : float
        Demand reduction coefficient for the category.
    E0_bn : float
        Total baseline market expenditure (USD billions).
    """
    lost_share = N_eff * alpha
    return lost_share * E0_bn


N_eff = effective_adoption(N_usa, t_sim)

# Short-run (Cornell, 6-month coefficients)
dd_total_sr  = demand_destruction(N_eff, ALPHA["food_total"],      E0["food_total"])
dd_fastfood  = demand_destruction(N_eff, ALPHA["fast_food"],       E0["fast_food"])
dd_ultra     = demand_destruction(N_eff, ALPHA["ultra_processed"],  E0["ultra_processed"])

# Long-run (KPMG steady-state coefficient — supermarket channel)
dd_super_lr  = demand_destruction(N_eff, ALPHA["supermarket_lt"],  E0["supermarket"])

# Benchmark check at calendar year 2030.
# Outputs and figures map simulation time to calendar year as 2024.5 + t (the
# 2027 and 2034 scenarios use t = 2.5 and t = 9.5), so 2030 is t = 5.5.
# t = 6 would be mid-2030 on that axis. The index is looked up on the
# simulated grid rather than assuming a 0.01-year step.
idx_2030 = int(np.argmin(np.abs((2024.5 + t_sim) - 2030.0)))
print(f"\nDemand destruction at 2030 (USD bn):")
# Coefficient labels are formatted from ALPHA so they track the calibration CSV.
print(f"  Short-run total food   (Cornell a={ALPHA['food_total']:.3f}): {dd_total_sr[idx_2030]:.1f}")
print(f"  Fast food              (Cornell a={ALPHA['fast_food']:.3f}): {dd_fastfood[idx_2030]:.1f}")
print(f"  Ultra-processed        (Cornell a={ALPHA['ultra_processed']:.3f}): {dd_ultra[idx_2030]:.1f}")
print(f"  Long-run supermarket   (KPMG   a={ALPHA['supermarket_lt']:.3f}): {dd_super_lr[idx_2030]:.1f}  <-- benchmark ~$55B")
print(f"  KPMG benchmark: $55B")

# ══════════════════════════════════════════════════════════════════════════
# 3. INPUT-OUTPUT GDP MULTIPLIER
# ══════════════════════════════════════════════════════════════════════════

def gdp_impact(dd_total_bn, theta=None, mu=None):
    """
    Distribute total demand shock across food sectors and apply
    Leontief output multipliers. Returns sector dict and total GDP loss.

    Parameters
    ----------
    dd_total_bn : float or ndarray
        Total demand destruction (USD billions). Arrays are handled
        element-wise, so a whole time path can be passed in one call.
    theta : dict, optional
        Sector -> share of the shock. Defaults to the module-level ``THETA``
        (``theta_share`` column of io_multipliers.csv).
    mu : dict, optional
        Sector -> output multiplier. Defaults to the module-level ``MU``
        (``mu_multiplier`` column of io_multipliers.csv).

    Returns
    -------
    (dict, float or ndarray)
        Per-sector GDP loss (shock x share x multiplier) and their sum.

    Raises
    ------
    KeyError
        If a sector in ``theta`` has no entry in ``mu``.
    """
    # Globals are resolved at call time, so explicit overrides need neither.
    if theta is None:
        theta = THETA
    if mu is None:
        mu = MU
    gdp = {s: dd_total_bn * share * mu[s] for s, share in theta.items()}
    return gdp, sum(gdp.values())


# GDP path uses the long-run supermarket channel as the primary demand shock.
# gdp_impact is element-wise in the shock, so the whole path is evaluated in
# one vectorised call instead of a Python loop over every grid point.
gdp_path_lr = np.asarray(gdp_impact(dd_super_lr)[1], dtype=float)
gdp_2030, gdp_total_2030 = gdp_impact(dd_super_lr[idx_2030])

print(f"\nGDP impact at 2030 via I-O (USD bn):")
for s, v in gdp_2030.items():
    print(f"  {s:<25} {v:.1f}")
print(f"  {'TOTAL':<25} {gdp_total_2030:.1f}")
print(f"  Weighted multiplier: {weighted_mu:.3f}")

# ══════════════════════════════════════════════════════════════════════════
# SAVE MODEL OUTPUT TO CSV
# ══════════════════════════════════════════════════════════════════════════

years_label = 2024.5 + t_sim   # t=0 → Q2-2024

# Column name -> (series, decimals kept). Dict order fixes the CSV column order.
_timeseries_columns = {
    "year":                 (years_label, 2),
    "N_usa_gross":          (N_usa, 4),
    "N_eff_usa":            (N_eff, 4),
    "N_latam_gross":        (N_latam, 4),
    "dd_total_sr_bn":       (dd_total_sr, 2),
    "dd_fastfood_bn":       (dd_fastfood, 2),
    "dd_ultra_bn":          (dd_ultra, 2),
    "dd_supermarket_lr_bn": (dd_super_lr, 2),
    "gdp_impact_lr_bn":     (gdp_path_lr, 2),
}
output_df = pd.DataFrame({
    column: np.round(series, ndigits)
    for column, (series, ndigits) in _timeseries_columns.items()
})
output_path = DATA_DIR / "model_output_timeseries.csv"
output_df.to_csv(output_path, index=False)
print(f"\nTime series saved to {output_path}")

# Demand destruction scenarios (replaces old modelo_destruccion_demanda.csv)
# Scenario years are located on the simulated calendar axis (years_label)
# rather than via int(t / 0.01). That form truncates and silently breaks if
# the grid step changes. Coefficients come from ALPHA, so the table reports
# the values actually used in the simulation.
idx_2027 = int(np.argmin(np.abs(years_label - 2027.0)))
idx_2034 = int(np.argmin(np.abs(years_label - 2034.0)))

scenarios = pd.DataFrame([
    # NOTE(review): this row pairs the short-run total-food shock (dd_bn)
    # with GDP taken from the long-run supermarket path (gdp_bn); the two
    # columns come from different demand channels.
    {"scenario": "Short-run 2027 (Cornell)",  "year": 2027,
     "alpha_used": ALPHA["food_total"], "E0_bn": E0["food_total"],
     "N_eff": round(N_eff[idx_2027], 3),
     "dd_bn": round(dd_total_sr[idx_2027], 1),
     "gdp_bn": round(gdp_path_lr[idx_2027], 1)},
    {"scenario": "Long-run 2030 (KPMG)",      "year": 2030,
     "alpha_used": ALPHA["supermarket_lt"], "E0_bn": E0["supermarket"],
     "N_eff": round(N_eff[idx_2030], 3),
     "dd_bn": round(dd_super_lr[idx_2030], 1),
     "gdp_bn": round(gdp_total_2030, 1)},
    {"scenario": "Long-run 2034 (KPMG)",      "year": 2034,
     "alpha_used": ALPHA["supermarket_lt"], "E0_bn": E0["supermarket"],
     "N_eff": round(N_eff[idx_2034], 3),
     "dd_bn": round(dd_super_lr[idx_2034], 1),
     "gdp_bn": round(gdp_path_lr[idx_2034], 1)},
])
scenarios.to_csv(DATA_DIR / "modelo_destruccion_demanda.csv", index=False)

# ══════════════════════════════════════════════════════════════════════════
# FIGURES
# ══════════════════════════════════════════════════════════════════════════

# Hex palette shared by all panels.
BLUE, GREEN, RED, GOLD, GRAY = (
    "#1f4e79", "#1a6b3c", "#8b0000", "#b8860b", "#555555",
)

# 2 x 2 panel layout; the spacing leaves room for titles and axis labels.
fig = plt.figure(figsize=(14, 10))
gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.42, wspace=0.36)

# ── Panel A: Adoption trajectories ─────────────────────────────────────
# Simulated Bass adoption paths (% of households) against calendar year,
# dotted reference lines at each region's market potential M, and the
# calibration anchors as markers.
ax1 = fig.add_subplot(gs[0, 0])
ax1.plot(years_label, N_usa * 100,   color=BLUE,  lw=2,   label="USA (households)")
ax1.plot(years_label, N_latam * 100, color=GREEN, lw=2,
         linestyle="--", label="Latin America (households)")
# Market-potential ceilings (M_USA, M_LATAM) as faint dotted lines.
ax1.axhline(M_USA * 100,   color=BLUE,  lw=0.8, linestyle=":", alpha=0.5)
ax1.axhline(M_LATAM * 100, color=GREEN, lw=0.8, linestyle=":", alpha=0.5)
# Calibration anchors (household data)
# Observation times are offset by 2024.5 to sit on the same axis as years_label.
ax1.scatter(t_obs_usa   + 2024.5, N_obs_usa   * 100,
            color=BLUE,  zorder=5, s=55, label="Gallup/Circana observed")
ax1.scatter(t_obs_latam + 2024.5, N_obs_latam * 100,
            color=GREEN, zorder=5, s=55, marker="s", label="Kantar observed")
ax1.set_title("A — GLP-1 Adoption: Bass Diffusion", fontweight="bold")
ax1.set_ylabel("% of households")
ax1.set_xlabel("Year")
ax1.legend(fontsize=8, loc="upper left")
# Fixed y-range (0-42%). It leaves headroom above the 35% USA ceiling but
# must be revisited if M_USA is raised.
ax1.set_ylim(0, 42)
# Annotate the calibrated USA coefficients (position in axes coordinates).
ax1.text(0.97, 0.30, f"USA: p={p_usa:.4f}, q={q_usa:.4f}\nq/p = {q_usa/p_usa:.0f}",
         transform=ax1.transAxes, ha="right", fontsize=7.5,
         color=BLUE, style="italic")

# ── Panel B: Demand destruction — SR vs LR ────────────────────────────
# Long-run supermarket path against the three short-run Cornell category
# paths; legend labels embed the alpha values taken from ALPHA.
ax2 = fig.add_subplot(gs[0, 1])
ax2.plot(years_label, dd_super_lr,   color=RED,   lw=2,
         label=f"Supermarket LR (KPMG a={ALPHA['supermarket_lt']})")
ax2.plot(years_label, dd_total_sr,   color=BLUE,  lw=2,
         linestyle="--", label=f"Total food SR (Cornell a={ALPHA['food_total']})")
ax2.plot(years_label, dd_fastfood,   color=GOLD,  lw=1.5,
         linestyle=":",  label=f"Fast food SR (Cornell a={ALPHA['fast_food']})")
ax2.plot(years_label, dd_ultra,      color=GREEN, lw=1.5,
         linestyle="-.", label=f"Ultra-proc. SR (Cornell a={ALPHA['ultra_processed']})")
# KPMG $55B benchmark: hard-coded here (not read from the calibration CSVs).
# The label position (2025.5, 57) is also fixed.
ax2.axhline(55, color=GRAY, lw=1, linestyle=":", alpha=0.8)
ax2.text(2025.5, 57, "KPMG benchmark $55B", fontsize=8, color=GRAY)
ax2.set_title("B — Demand Destruction USA (USD bn/yr)", fontweight="bold")
ax2.set_ylabel("USD billion")
ax2.set_xlabel("Year")
ax2.legend(fontsize=7.5, loc="upper left")

# ── Panel C: GDP impact ───────────────────────────────────────────────
# Long-run GDP-loss path (supermarket shock through the I-O multipliers),
# with a dotted marker at the 2030 evaluation point.
ax3 = fig.add_subplot(gs[1, 0])
ax3.fill_between(years_label, gdp_path_lr, alpha=0.20, color=RED)
ax3.plot(years_label, gdp_path_lr, color=RED, lw=2,
         label=f"GDP loss (mu={weighted_mu:.3f})")
# Draw the marker at the calendar year of the grid point idx_2030 refers to,
# so the line, the annotation and the quoted value always agree. A
# hard-coded x = 2030 drifts from the value whenever idx_2030 maps to
# another year.
x_2030 = years_label[idx_2030]
ax3.axvline(x_2030, color=GRAY, lw=0.8, linestyle=":")
ax3.text(x_2030 + 0.2, gdp_path_lr[idx_2030] * 0.6,
         f"2030: ${gdp_total_2030:.0f}B", fontsize=8, color=GRAY)
ax3.set_title("C — GDP Impact via I-O Multipliers (USD bn/yr)", fontweight="bold")
ax3.set_ylabel("USD billion")
ax3.set_xlabel("Year")
ax3.legend(fontsize=9)

# ── Panel D: Sectoral decomposition at 2030 ───────────────────────────
# Horizontal bars of per-sector GDP loss at idx_2030 (from gdp_impact).
# Underscores in sector names become line breaks in the tick labels.
ax4 = fig.add_subplot(gs[1, 1])
sectors_labels = [s.replace("_", "\n") for s in gdp_2030]
values = list(gdp_2030.values())
# NOTE(review): four colours are listed regardless of how many sectors
# io_multipliers.csv defines. Confirm the palette matches the sector count.
colors = [BLUE, GREEN, GOLD, RED]
bars = ax4.barh(sectors_labels, values, color=colors, alpha=0.85)
ax4.bar_label(bars, fmt="$%.1fB", padding=4, fontsize=8.5)
ax4.set_title("D — GDP Loss by Sector at 2030 (USD bn)", fontweight="bold")
ax4.set_xlabel("USD billion")
# 35% headroom to the right of the longest bar so the value labels fit.
ax4.set_xlim(0, max(values) * 1.35)

plt.suptitle(
    "GLP-1 Macroeconomic Impact Model — USA & Latin America\n"
    "Restrepo-Morales & Garcés Giraldo (2026) · t=0: Q2-2024",
    y=1.01, fontsize=11, fontweight="bold",
)

# Write the same figure as vector (PDF) and raster (PNG), then release it.
for _fig_name in ("glp1_model_results.pdf", "glp1_model_results.png"):
    plt.savefig(FIG_DIR / _fig_name, bbox_inches="tight")
plt.close()
print(f"\nFigures saved to {FIG_DIR}/")
print("Done.")
