#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Monte Carlo input uncertainty propagation for the b-anchored benchmark.

Generates Figure S.5 (saved as Figure_S5.png) and the numerical contents of Table S.2. Samples the PDG
2024 down-type input masses with their quoted Gaussian 1-sigma uncertainties
and propagates the perturbations to (i) the benchmark RMS deviation and
(ii) the unconstrained best-fit values (phi*, cos delta*) of the Brannen-type
benchmark form, b-normalized so that m_b is reproduced exactly per sample.

This script is self-contained: it does NOT import the 4-loop running engine
of Script 1. alpha_s is held fixed and only the three PDG quark masses are
sampled, independently, about their central values. Only the quoted
input-mass uncertainties are propagated; correlations among the masses and
between m_q and alpha_s are therefore not included.
"""
from __future__ import annotations

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize
from scipy.stats import chi2 as scipy_chi2


# =============================================================================
# Inputs
# =============================================================================

# PDG 2024 down-type quark masses in MeV as (central value, Gaussian 1-sigma):
# m_d and m_s at 2 GeV, m_b at the scale m_b.
MD_CENTRAL, MD_SIGMA = 4.70, 0.07
MS_CENTRAL, MS_SIGMA = 93.5, 0.8
MB_CENTRAL, MB_SIGMA = 4183.0, 7.0

# Benchmark point of the Brannen-type form:
# phi = 1/3 and cos(delta) = cos(3 pi / 8) / 2.
PHI_BENCH = 1 / 3
COS_DELTA_BENCH = np.cos(3.0 * np.pi / 8.0) / 2.0

# Monte Carlo sample count and seed for numpy.random.default_rng.
N_MC = 10_000
RNG_SEED = 20_260_430


# =============================================================================
# Benchmark machinery
# =============================================================================


def benchmark_predictions(phi, cos_delta, mb):
    """Return the three b-normalised benchmark masses (m_1, m_2, m_3).

    Each mass is m_i = (A * k_i)**2 with
        k_i = 1 + 2 T cos(phi / 3 + theta_i),  T = sqrt((1 + cos_delta) / 2),
    and theta = (2 pi / 3, 4 pi / 3, 0) for i = 1, 2, 3.  The scale
    A = sqrt(mb) / k_3 is chosen so that m_3 reproduces ``mb`` (up to
    floating-point rounding); k_3 == 0 therefore divides by zero.
    """
    amplitude = 2.0 * np.sqrt((1.0 + cos_delta) / 2.0)
    base_angle = phi / 3.0
    k_first = 1.0 + amplitude * np.cos(base_angle + 2.0 * np.pi / 3.0)
    k_second = 1.0 + amplitude * np.cos(base_angle + 4.0 * np.pi / 3.0)
    k_third = 1.0 + amplitude * np.cos(base_angle)
    scale = np.sqrt(mb) / k_third
    return tuple((scale * k) ** 2 for k in (k_first, k_second, k_third))


def rms_residual(phi, cos_delta, md, ms, mb):
    """Return the RMS of the fractional deviations of the benchmark masses.

    The three predictions from ``benchmark_predictions(phi, cos_delta, mb)``
    are compared with (md, ms, mb) as (pred - obs) / obs.  The result is
    sqrt(mean of the three squared deviations), a fraction (not percent).
    The third deviation is ~0 because the prediction is b-normalised.
    """
    predicted = benchmark_predictions(phi, cos_delta, mb)
    observed = (md, ms, mb)
    squared = [((pred - obs) / obs) ** 2
               for pred, obs in zip(predicted, observed)]
    return np.sqrt(sum(squared) / 3.0)


def find_best_fit(md, ms, mb, x0=(PHI_BENCH, COS_DELTA_BENCH)):
    """Minimise the benchmark RMS residual over (phi, cos_delta).

    Uses L-BFGS-B with box bounds 0 <= phi <= 2/3 and
    0 <= cos_delta <= 0.4, starting from ``x0`` (the benchmark point by
    default).

    Returns
    -------
    tuple of float
        (phi*, cos_delta*, rms*) where rms* is the minimised RMS residual.
    """
    search_box = ((0.0, 2.0 / 3.0), (0.0, 0.4))
    solver_opts = {"ftol": 1e-12, "gtol": 1e-10}

    def _rms_at(params):
        trial_phi, trial_cd = params
        return float(rms_residual(float(trial_phi), float(trial_cd),
                                  md, ms, mb))

    fit = minimize(_rms_at, x0=x0, method="L-BFGS-B", bounds=search_box,
                   options=solver_opts)
    phi_star, cd_star = fit.x
    return float(phi_star), float(cd_star), float(fit.fun)


# =============================================================================
# Monte Carlo (chunked + cached for resumable execution)
# =============================================================================


def run_monte_carlo(n_samples=N_MC, seed=RNG_SEED, cache_path=None,
                    chunk_size=5000, verbose=False):
    """Propagate the input-mass uncertainties by Monte Carlo.

    Draws ``n_samples`` independent Gaussian samples of (m_d, m_s, m_b) from
    ``numpy.random.default_rng(seed)``. For each sample it evaluates the RMS
    deviation of the fixed benchmark (PHI_BENCH, COS_DELTA_BENCH) and the
    unconstrained best fit returned by ``find_best_fit``.

    Parameters
    ----------
    n_samples : int
        Number of Monte Carlo samples.
    seed : int
        Seed for the NumPy random generator.
    cache_path : str or os.PathLike, optional
        If given, progress is written to exactly this path (NumPy .npz
        format) after every chunk, and an existing compatible cache is
        resumed. A cache is reused only when its sample count, seed and
        input central values / sigmas all match this run. Missing,
        unreadable, truncated or incompatible caches are ignored and the run
        starts from scratch.
    chunk_size : int
        Number of samples evaluated between cache writes (must be >= 1).
    verbose : bool
        If true, print progress messages.

    Returns
    -------
    dict
        ``md``, ``ms``, ``mb`` (sampled inputs) and ``rms_bench``,
        ``phi_best``, ``cd_best``, ``rms_best`` (per-sample results), each
        an array of length ``n_samples``.

    Raises
    ------
    ValueError
        If ``chunk_size`` < 1; the chunk loop would otherwise never advance.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1, got %r" % (chunk_size,))

    # Everything that defines the input distribution. A cache produced from
    # different inputs must not be resumed.
    inputs = np.array([MD_CENTRAL, MD_SIGMA, MS_CENTRAL, MS_SIGMA,
                       MB_CENTRAL, MB_SIGMA], dtype=float)

    rng = np.random.default_rng(seed)
    md_all = rng.normal(MD_CENTRAL, MD_SIGMA, n_samples)
    ms_all = rng.normal(MS_CENTRAL, MS_SIGMA, n_samples)
    mb_all = rng.normal(MB_CENTRAL, MB_SIGMA, n_samples)

    rms_bench = np.empty(n_samples)
    phi_best = np.empty(n_samples)
    cd_best = np.empty(n_samples)
    rms_best = np.empty(n_samples)
    done_to = 0

    if cache_path is not None:
        cached = _load_mc_cache(cache_path, n_samples, seed, inputs)
        if cached is not None:
            done_to = cached["done_to"]
            rms_bench[:done_to] = cached["rms_bench"][:done_to]
            phi_best[:done_to] = cached["phi_best"][:done_to]
            cd_best[:done_to] = cached["cd_best"][:done_to]
            rms_best[:done_to] = cached["rms_best"][:done_to]
            # Continue with the cached draws so that already-processed and
            # remaining samples come from one and the same input set.
            md_all, ms_all, mb_all = cached["md"], cached["ms"], cached["mb"]
            if verbose:
                print("  resumed from cache: %d / %d done"
                      % (done_to, n_samples))

    while done_to < n_samples:
        end = min(done_to + chunk_size, n_samples)
        for i in range(done_to, end):
            rms_bench[i] = rms_residual(PHI_BENCH, COS_DELTA_BENCH,
                                        md_all[i], ms_all[i], mb_all[i])
            phi_best[i], cd_best[i], rms_best[i] = find_best_fit(
                md_all[i], ms_all[i], mb_all[i])
        done_to = end
        if cache_path is not None:
            _save_mc_cache(cache_path, n_samples=n_samples, seed=seed,
                           done_to=done_to, inputs=inputs,
                           rms_bench=rms_bench, phi_best=phi_best,
                           cd_best=cd_best, rms_best=rms_best,
                           md=md_all, ms=ms_all, mb=mb_all)
        if verbose:
            print("  chunk completed: %d / %d" % (done_to, n_samples))

    return {
        "md": md_all, "ms": ms_all, "mb": mb_all,
        "rms_bench": rms_bench,
        "phi_best": phi_best,
        "cd_best": cd_best,
        "rms_best": rms_best,
    }


def _load_mc_cache(cache_path, n_samples, seed, inputs):
    """Return the contents of a compatible partial MC cache, or None.

    A cache is compatible when it stores the same ``n_samples``, ``seed``
    and ``inputs`` vector, satisfies 0 < done_to <= n_samples, and every
    stored per-sample array has length ``n_samples``. A missing, unreadable
    or truncated file (e.g. left by an interrupted write) counts as absent.
    Caches written before the ``inputs`` key existed are also rejected.
    """
    import zipfile

    array_keys = ("md", "ms", "mb", "rms_bench", "phi_best", "cd_best",
                  "rms_best")
    try:
        # allow_pickle stays at its default (False): the cache only holds
        # numeric arrays. The context manager closes the underlying file.
        with np.load(cache_path) as cache:
            done_to = int(cache["done_to"])
            if not (0 < done_to <= n_samples
                    and int(cache["n_samples"]) == n_samples
                    and int(cache["seed"]) == seed
                    and np.array_equal(cache["inputs"], inputs)):
                return None
            # Indexing an NpzFile reads each member into memory, so the
            # arrays remain valid after the file is closed.
            loaded = {key: np.asarray(cache[key]) for key in array_keys}
    except (OSError, EOFError, KeyError, ValueError, zipfile.BadZipFile):
        return None
    if any(arr.shape != (n_samples,) for arr in loaded.values()):
        return None
    loaded["done_to"] = done_to
    return loaded


def _save_mc_cache(cache_path, **arrays):
    """Atomically write ``arrays`` to ``cache_path`` in NumPy .npz format.

    The data are written to a temporary sibling file that then replaces
    ``cache_path``, so an interrupted write never leaves a truncated cache.
    Writing through an open file handle also stops ``np.savez`` from
    appending ".npz" to paths lacking that suffix, which keeps the save path
    identical to the path that ``_load_mc_cache`` reads.
    """
    import os

    tmp_path = os.fspath(cache_path) + ".tmp"
    with open(tmp_path, "wb") as fh:
        np.savez(fh, **arrays)
    os.replace(tmp_path, cache_path)


# =============================================================================
# Confidence-region helpers (joint 2-D chi^2_2 quantiles)
# =============================================================================


def covariance_eigen(x, y):
    """Eigen-decompose the 2x2 sample covariance of (x, y).

    Returns
    -------
    (eigenvalues, eigenvectors, cov)
        Eigenvalues sorted in descending order, the matching eigenvectors as
        columns in the same order, and the ``np.cov(x, y)`` matrix itself.
    """
    cov = np.cov(x, y)
    vals, vecs = np.linalg.eigh(cov)
    descending = np.flip(np.argsort(vals))
    return vals[descending], vecs[:, descending], cov


def confidence_ellipse(x, y, level=0.6827):
    """Return the joint 2-D confidence ellipse of the samples (x, y).

    The ellipse is centred on the sample means. Its semi-axes are
    sqrt(q * lambda_i), where q is the chi^2 quantile with 2 degrees of
    freedom at ``level`` and lambda_i are the covariance eigenvalues
    (largest first).

    Returns
    -------
    (cx, cy, a, b, theta)
        Centre, semi-major and semi-minor axes, and the angle (radians) of
        the major axis measured from the x axis.
    """
    eigvals, eigvecs, _cov = covariance_eigen(x, y)
    scale = float(scipy_chi2.ppf(level, df=2))
    semi_major, semi_minor = (float(np.sqrt(scale * lam))
                              for lam in eigvals[:2])
    major_axis = eigvecs[:, 0]
    tilt = float(np.arctan2(major_axis[1], major_axis[0]))
    return float(np.mean(x)), float(np.mean(y)), semi_major, semi_minor, tilt


def point_inside_ellipse(px, py, cx, cy, a, b, theta):
    """Test whether (px, py) lies inside or on the given ellipse.

    The offset from the centre (cx, cy) is rotated by -theta into the
    ellipse frame. The point is inside when (u / a)**2 + (v / b)**2 <= 1.
    Works element-wise on NumPy arrays.
    """
    angle = -theta
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    dx, dy = px - cx, py - cy
    along_major = cos_a * dx - sin_a * dy
    along_minor = sin_a * dx + cos_a * dy
    return (along_major / a) ** 2 + (along_minor / b) ** 2 <= 1.0


def mahalanobis_sigma(px, py, x, y):
    """Return the Mahalanobis distance of (px, py) from the (x, y) samples.

    Computes sqrt(d^T C^-1 d), where d is the offset from the sample means
    and C = np.cov(x, y). This is the raw 2-D distance. It is NOT converted
    to an equivalent one-dimensional Gaussian significance; in 2 dof the
    68.27% contour lies at a distance of about 1.52, not 1.
    """
    precision = np.linalg.inv(np.cov(x, y))
    offset = np.array([px - float(np.mean(x)), py - float(np.mean(y))])
    return float(np.sqrt(offset @ precision @ offset))


# =============================================================================
# Plot
# =============================================================================


# Plot colours (hex RGB), named by what they mark in make_plot.
COL_BENCH = "#000000"    # benchmark marker and PDG-central RMS line
COL_BEST = "#D55E00"     # MC-mean marker in panel (b)
COL_HIST = "#56B4E9"     # RMS histogram bars
COL_SCATTER = "#999999"  # thinned MC best-fit scatter
COL_E1 = "#0072B2"       # 68.27% band (a) and 1-sigma ellipse (b)
COL_E2 = "#CC79A7"       # median line (a) and 95% CL ellipse (b)


def _draw_ellipse(ax, cx, cy, a, b, theta, color, lw, ls="-", label=None):
    """Draw the outline of an ellipse on ``ax``.

    The ellipse has centre (cx, cy), semi-axes (a, b) and is rotated by
    ``theta`` radians. It is traced with 360 parameter values from 0 to
    2 pi inclusive, so the curve closes.
    """
    param = np.linspace(0.0, 2.0 * np.pi, 360)
    cos_tilt, sin_tilt = np.cos(theta), np.sin(theta)
    u = a * np.cos(param)
    v = b * np.sin(param)
    ax.plot(cx + u * cos_tilt - v * sin_tilt,
            cy + u * sin_tilt + v * cos_tilt,
            color=color, lw=lw, ls=ls, label=label)


def make_plot(mc, output_path="Figure_S5.png"):
    """Draw Figure S.5 and return the confidence-region summary.

    Panel (a) is a histogram of the per-sample benchmark RMS (in %). It
    shows the RMS at the PDG central inputs, the central 68.27% band and
    the median.

    Panel (b) is a thinned scatter of the per-sample best fits
    (phi*, cos delta*). It shows the joint 68.27% and 95% CL ellipses, the
    MC mean and the benchmark point. The view always includes the benchmark.

    Parameters
    ----------
    mc : dict
        Output of ``run_monte_carlo``. Uses ``rms_bench`` (fractions),
        ``phi_best`` and ``cd_best``.
    output_path : str or os.PathLike
        Destination image file (written at 200 dpi).

    Returns
    -------
    dict
        ``ellipse_1s`` and ``ellipse_95`` as (cx, cy, a, b, theta);
        ``bench_inside_1s`` and ``bench_inside_95`` (bool);
        ``bench_sigma``, the raw 2-D Mahalanobis distance of the benchmark;
        ``rms_central_at_pdg``, the benchmark RMS at central inputs, in %.
    """
    rms_bench = mc["rms_bench"] * 100.0
    phi = mc["phi_best"]
    cd = mc["cd_best"]

    cx, cy, a1, b1, th1 = confidence_ellipse(phi, cd, level=0.6827)
    _, _, a2, b2, th2 = confidence_ellipse(phi, cd, level=0.95)
    bench_inside_1s = bool(point_inside_ellipse(
        PHI_BENCH, COS_DELTA_BENCH, cx, cy, a1, b1, th1))
    bench_inside_95 = bool(point_inside_ellipse(
        PHI_BENCH, COS_DELTA_BENCH, cx, cy, a2, b2, th2))
    bench_sigma = mahalanobis_sigma(PHI_BENCH, COS_DELTA_BENCH, phi, cd)

    rms_central = rms_residual(PHI_BENCH, COS_DELTA_BENCH,
                               MD_CENTRAL, MS_CENTRAL, MB_CENTRAL) * 100.0

    fig, axes = plt.subplots(1, 2, figsize=(12.0, 5.4),
                             constrained_layout=True)
    try:
        # Every percent sign in a label is written as \% INSIDE $...$.
        # Matplotlib mathtext only interprets \% in math mode; outside it the
        # backslash is drawn literally. The in-math form also works with
        # text.usetex. Strings that are %-formatted use \%% to yield \%.

        # Panel (a): benchmark RMS distribution.
        axA = axes[0]
        axA.hist(rms_bench, bins=60, color=COL_HIST,
                 edgecolor="white", linewidth=0.5)
        axA.axvline(rms_central, color=COL_BENCH, ls="--", lw=1.5,
                    label=r"benchmark RMS at PDG central $\approx %.3f\%%$"
                          % rms_central)
        p16, p50, p84 = np.percentile(rms_bench, [15.865, 50.0, 84.135])
        axA.axvspan(p16, p84, alpha=0.18, color=COL_E1,
                    label=r"central $68.27\%%$ $[%.3f, %.3f]\%%$"
                          % (p16, p84))
        axA.axvline(p50, color=COL_E2, ls=":", lw=1.2,
                    label=r"median $\approx %.3f\%%$" % p50)
        axA.set_xlabel(r"$\mathrm{RMS}_{\rm bench}$  [$\%$]")
        axA.set_ylabel(r"MC frequency")
        axA.set_title(r"(a) benchmark RMS distribution "
                      r"($N_{\rm MC} = %d$)" % len(rms_bench))
        axA.legend(loc="upper right", fontsize=9, framealpha=0.92)
        axA.grid(True, ls=":", alpha=0.3)

        # Panel (b): best fits and joint confidence ellipses.
        axB = axes[1]
        # Thin to roughly 2000 plotted points to keep the figure light.
        sub = slice(None, None, max(1, len(phi) // 2000))
        axB.scatter(phi[sub], cd[sub], s=4, color=COL_SCATTER, alpha=0.45,
                    label=r"MC best fits ($N=%d$)" % len(phi))
        _draw_ellipse(axB, cx, cy, a1, b1, th1, color=COL_E1, lw=1.6,
                      label=r"$1\sigma$ joint ($68.27\%$ CL)")
        _draw_ellipse(axB, cx, cy, a2, b2, th2, color=COL_E2, lw=1.6,
                      ls="--", label=r"$95\%$ CL")
        axB.plot([cx], [cy], marker="^", ms=11, mfc=COL_BEST, mec="white",
                 mew=1.2,
                 label=(r"MC mean $(\bar\phi^{\ast}, "
                        r"\overline{\cos\delta^{\ast}})"
                        r" \approx (%.3f, %.3f)$" % (cx, cy)))
        label_b = (r"benchmark "
                   r"$(\phi_d, \cos\delta_d) = "
                   r"(1/3,\ \frac{1}{2}\cos(\frac{3}{8}\pi))$"
                   r"  [%s 1$\sigma$, %s $95\%%$]"
                   % ("inside" if bench_inside_1s else "outside",
                      "inside" if bench_inside_95 else "outside"))
        axB.plot([PHI_BENCH], [COS_DELTA_BENCH], marker="s", ms=10,
                 mfc=COL_BENCH, mec="white", mew=1.2, label=label_b)
        axB.set_xlabel(r"$\phi^{\ast}$")
        axB.set_ylabel(r"$\cos\delta^{\ast}$")
        axB.set_title(r"(b) unconstrained best fits and confidence ellipses")
        # Default: a square window around the MC mean. Widen it (never
        # shrink) so the benchmark marker stays on-canvas even when it lies
        # far outside the 95% ellipse.
        half_w = max(1.5 * a2, 0.012)
        pad = 0.1 * half_w
        axB.set_xlim(min(cx - half_w, PHI_BENCH - pad),
                     max(cx + half_w, PHI_BENCH + pad))
        axB.set_ylim(min(cy - half_w, COS_DELTA_BENCH - pad),
                     max(cy + half_w, COS_DELTA_BENCH + pad))
        axB.legend(loc="upper right", fontsize=8, framealpha=0.92)
        axB.grid(True, ls=":", alpha=0.3)

        fig.savefig(output_path, dpi=200)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print("Saved %s" % (output_path,))
    return {
        "ellipse_1s": (cx, cy, a1, b1, th1),
        "ellipse_95": (cx, cy, a2, b2, th2),
        "bench_inside_1s": bench_inside_1s,
        "bench_inside_95": bench_inside_95,
        "bench_sigma": bench_sigma,
        "rms_central_at_pdg": rms_central,
    }


# =============================================================================
# Main
# =============================================================================


def main():
    """Run the MC study, print the summary statistics and write Figure S.5.

    Prints the PDG inputs and summary statistics of the benchmark RMS, of
    the best-fit RMS and of (phi*, cos delta*). It then draws
    Figure_S5.png and reports the joint confidence ellipses and where the
    benchmark lies relative to them. Monte Carlo progress is cached in
    Script_S5_cache.npz in the working directory.
    """
    print("===== Monte Carlo input uncertainty propagation =====")
    print("  PDG inputs (central +/- 1 sigma):")
    print("    m_d(2 GeV) = %.3f +/- %.3f MeV" % (MD_CENTRAL, MD_SIGMA))
    print("    m_s(2 GeV) = %.3f +/- %.3f MeV" % (MS_CENTRAL, MS_SIGMA))
    print("    m_b(m_b)   = %.1f +/- %.1f MeV" % (MB_CENTRAL, MB_SIGMA))
    print("  Benchmark (phi_bench, cos delta_bench) = (%.6f, %.6f)"
          % (PHI_BENCH, COS_DELTA_BENCH))
    print("  N_MC = %d, seed = %d" % (N_MC, RNG_SEED))
    print()

    print("===== Running MC =====")
    mc = run_monte_carlo(cache_path="Script_S5_cache.npz",
                         chunk_size=5000, verbose=True)

    rms_bench_pct = mc["rms_bench"] * 100.0
    rms_best_pct = mc["rms_best"] * 100.0
    phi = mc["phi_best"]
    cd = mc["cd_best"]

    p2_5, p16, p50, p84, p97_5 = np.percentile(
        rms_bench_pct, [2.5, 15.865, 50.0, 84.135, 97.5])
    mean_rms = float(np.mean(rms_bench_pct))
    std_rms = float(np.std(rms_bench_pct, ddof=1))
    print()
    print("===== RMS_bench distribution (in %) =====")
    print("  mean +/- std    = %.4f +/- %.4f" % (mean_rms, std_rms))
    print("  median          = %.4f" % p50)
    print("  68.27%% interval = [%.4f, %.4f]" % (p16, p84))
    print("  95%% interval    = [%.4f, %.4f]" % (p2_5, p97_5))

    rb_mean = float(np.mean(rms_best_pct))
    rb_std = float(np.std(rms_best_pct, ddof=1))
    rb_med = float(np.median(rms_best_pct))
    rb16, rb84 = np.percentile(rms_best_pct, [15.865, 84.135])
    print()
    print("===== RMS_best distribution (in %) =====")
    print("  mean +/- std    = %.4f +/- %.4f" % (rb_mean, rb_std))
    print("  median          = %.4f" % rb_med)
    print("  68.27%% interval = [%.4f, %.4f]" % (rb16, rb84))

    phi_mean = float(np.mean(phi))
    phi_std = float(np.std(phi, ddof=1))
    cd_mean = float(np.mean(cd))
    cd_std = float(np.std(cd, ddof=1))
    rho = float(np.corrcoef(phi, cd)[0, 1])
    phi16, phi84 = np.percentile(phi, [15.865, 84.135])
    cd16, cd84 = np.percentile(cd, [15.865, 84.135])
    print()
    print("===== (phi*, cos delta*) marginals =====")
    print("  phi*       : mean = %.5f, std = %.5f, 68.27%% = [%.5f, %.5f]"
          % (phi_mean, phi_std, phi16, phi84))
    print("  cos delta* : mean = %.5f, std = %.5f, 68.27%% = [%.5f, %.5f]"
          % (cd_mean, cd_std, cd16, cd84))
    print("  correlation rho(phi*, cos delta*) = %+.4f" % rho)
    print("  benchmark:    (%.5f, %.5f)" % (PHI_BENCH, COS_DELTA_BENCH))
    print("  offset bench->mean: d(phi)=%+.5f, d(cos delta)=%+.5f"
          % (PHI_BENCH - phi_mean, COS_DELTA_BENCH - cd_mean))

    print()
    print("===== Drawing Figure S.5 (Figure_S5.png) =====")
    summary = make_plot(mc, output_path="Figure_S5.png")
    cx, cy, a1, b1, th1 = summary["ellipse_1s"]
    _, _, a2, b2, _ = summary["ellipse_95"]
    # These two strings are not %-formatted, so they take a single "%".
    print("  Joint 1-sigma ellipse (68.27% CL, 2 dof):")
    print("    centre = (%.5f, %.5f)" % (cx, cy))
    print("    semi-axes (a, b) = (%.5f, %.5f), tilt = %+.4f rad"
          % (a1, b1, th1))
    print("  Joint 95% CL ellipse:")
    print("    semi-axes (a, b) = (%.5f, %.5f)" % (a2, b2))

    # summary["bench_sigma"] is the raw 2-D Mahalanobis distance d. For
    # 2 dof, d^2 follows chi^2_2, so the 68.27% ellipse sits at d ~ 1.52,
    # not d = 1. Map its tail probability to the equivalent two-sided 1-D
    # Gaussian significance, which agrees with the ellipse CLs above.
    # sf/isf keep precision deep in the tail; if p underflows to 0 the
    # result is reported as inf.
    d_mahal = summary["bench_sigma"]
    p_tail = float(scipy_chi2.sf(d_mahal ** 2, df=2))
    sigma_eq = float(np.sqrt(scipy_chi2.isf(p_tail, df=1)))
    print("  Benchmark Mahalanobis distance (2-D)             = %.3f"
          % d_mahal)
    print("  Benchmark equivalent 1-D Gaussian significance   = %.3f sigma"
          % sigma_eq)
    print("  Benchmark inside joint 1-sigma ellipse?   %s"
          % ("YES" if summary["bench_inside_1s"] else "NO"))
    print("  Benchmark inside joint 95%% CL ellipse?    %s"
          % ("YES" if summary["bench_inside_95"] else "NO"))


if __name__ == "__main__":
    # Script entry point: importing this module only defines the constants
    # and functions above and does not start the Monte Carlo run.
    main()
