#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Comparison with competing parametrizations for the Supplemental Material.

For each of the five models below, fit to the PDG 2024 down-type masses
(m_d, m_s, m_b) and compute chi^2, RMS relative deviation, AIC, and BIC.
Generates Figure S.6 (saved as Figure_S6.png; AIC/BIC and RMS bar charts) and the
numerical contents of Table S.3 (model comparison).

Models (n = 1, 2, 3 indexes the down-type quark family):

  M0  Free Brannen form, three free parameters:
        sqrt(m_n) = A * (1 + 2 T cos(phi/3 + 2 n pi/3)),       k = 3.
  M1  Benchmark-fixed Brannen form (this paper):
        T  = T_pred = sqrt(1/2 + sqrt(2 - sqrt(2))/8) ~ 0.77180
        phi = 1/3
        Only the overall amplitude A is free,                    k = 1.
  M2  K_inv-symmetric Brannen form, phi_d = 1/3 frozen,
      T_d numerically determined by imposing K_inv = 2/3 exactly
      on the predicted spectrum (this gives T_d ~ 0.77197, very
      close to the benchmark T_d^pred ~ 0.77180), one free
      parameter:
        sqrt(m_n) = A * (1 + 2 T_d^Kinv cos(phi/3 + 2 n pi/3)),  k = 1.
  M3  Geometric progression in sqrt-mass space, two free:
        sqrt(m_n) = a * b^n,                                     k = 2.
  M4  Power-law plus offset in mass space, three free:
        m_n = a * n^p + c,                                       k = 3.

Fit metric:
  chi^2  = sum_n ((m_n^pred - m_n^PDG) / sigma_n)^2
  RMS    = sqrt((1/3) sum_n ((m_n^pred - m_n^PDG)/m_n^PDG)^2)
  AIC    = chi^2 + 2 k
  BIC    = chi^2 + k ln N         (N = 3 data points)

This script is self-contained (no dependency on Script 1); alpha_s is held
fixed at its central value.
"""
from __future__ import annotations

import warnings

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import brentq, minimize


# =============================================================================
# Inputs
# =============================================================================

# PDG 2024 inputs for (m_d, m_s, m_b): mixed-scale central values and their
# 1-sigma uncertainties, both in MeV.
M_PDG = np.array([4.70, 93.5, 4183.0])   # MeV
SIGMA = np.array([0.07, 0.8, 7.0])       # MeV
N_DATA = len(M_PDG)
N_FIDX = np.arange(N_DATA) + 1           # family index n = 1, 2, 3

# Benchmark Brannen parameters (Sections 3.2, 3.2.1).
PHI_BENCH = 1.0 / 3.0
# Not used by the code visible in this script.  Mathematically (up to
# rounding) T_BENCH**2 == 1/2 + COS_DELTA_BENCH / 2, because
# sqrt(2 - sqrt(2)) / 2 == cos(3 pi / 8).
COS_DELTA_BENCH = np.cos(3.0 * np.pi / 8.0) / 2.0
T_BENCH = float(np.sqrt(0.5 + np.sqrt(2.0 - np.sqrt(2.0)) / 8.0))   # ~ 0.77180


# K_inv-symmetric T_d: numerically determined from K_inv(T, phi=1/3) = 2/3.
def _kinv_from_brannen(T, phi):
    """Return K_inv = sum(1/k_n^2) / (sum(1/k_n))^2 for the Brannen form.

    k_n = 1 + 2 T cos(phi/3 + 2 n pi/3) for n = 1, 2, 3 is sqrt(m_n) with the
    overall amplitude A divided out (A cancels in the ratio).
    """
    family = np.arange(1, 4)
    k_vals = 1.0 + 2.0 * T * np.cos(phi / 3.0 + 2.0 * np.pi * family / 3.0)
    sum_inv_sq = np.sum(1.0 / k_vals ** 2)
    sum_inv = np.sum(1.0 / k_vals)
    return float(sum_inv_sq / sum_inv ** 2)


# Root of K_inv(T, phi_d = 1/3) = 2/3, bracketed in [0.5, 0.9]  (~ 0.77197).
T_KINV = float(
    brentq(
        lambda t: _kinv_from_brannen(t, PHI_BENCH) - 2.0 / 3.0,
        0.5,
        0.9,
        xtol=1e-14,
    )
)


# =============================================================================
# Common metric helpers
# =============================================================================


def chi2(m_pred):
    """Return the chi-square of *m_pred* against the PDG inputs.

    Each residual (m_pred - M_PDG) is divided by the matching 1-sigma
    uncertainty in SIGMA before squaring and summing.
    """
    pulls = (m_pred - M_PDG) / SIGMA
    return float(np.sum(pulls ** 2))


def rms_rel(m_pred):
    """Return the RMS relative deviation of *m_pred* from M_PDG.

    The value is a fraction; callers multiply by 100 to report percent.
    """
    rel_dev = (m_pred - M_PDG) / M_PDG
    return float(np.sqrt(np.mean(rel_dev ** 2)))


def aic_bic(chi2_val, k):
    """Return ``(AIC, BIC)`` for a fit with chi-square *chi2_val* and *k*
    free parameters.

    AIC = chi^2 + 2 k and BIC = chi^2 + k ln(N_DATA), natural logarithm.
    """
    aic = chi2_val + 2.0 * k
    bic = chi2_val + k * np.log(N_DATA)
    return aic, bic


# =============================================================================
# Brannen-type prediction (used by M0, M1, M2)
# =============================================================================


def brannen_pred(A, T, phi, n=None):
    """Return Brannen-form masses m_n = (A * (1 + 2 T cos(phi/3 + 2 n pi/3)))^2.

    Parameters
    ----------
    A : float
        Overall amplitude of sqrt(m_n).
    T, phi : float
        Brannen shape parameters.
    n : array_like of int, optional
        Family indices to evaluate.  Defaults to N_FIDX (n = 1, 2, 3).  The
        default is looked up at call time instead of binding a shared numpy
        array as a mutable default argument.  Lists are accepted as well as
        arrays.

    Returns
    -------
    numpy.ndarray
        Predicted masses (squares, hence non-negative), shaped like *n*.
    """
    if n is None:
        n = N_FIDX
    n = np.asarray(n)
    sqrt_m = A * (1.0 + 2.0 * T * np.cos(phi / 3.0 + 2.0 * np.pi * n / 3.0))
    return sqrt_m ** 2


# =============================================================================
# Per-model fitters
# =============================================================================


def fit_M0():
    """Fit the free three-parameter Brannen form (A, T, phi), k = 3.

    Starting point: T and phi at the benchmark values.  A starts at
    mean(sqrt(M_PDG)) because of the Brannen sum rule: the three cosines
    cos(phi/3 + 2 n pi/3), n = 1, 2, 3, sum to zero, so sum_n sqrt(m_n) = 3 A
    for any T and phi (~25.5 for these inputs).

    If the optimizer does not report convergence, a RuntimeWarning is
    emitted, so a failed fit is not tabulated silently.
    """
    def obj(x):
        return chi2(brannen_pred(x[0], x[1], x[2]))

    a0 = float(np.mean(np.sqrt(M_PDG)))
    res = minimize(obj, x0=[a0, T_BENCH, PHI_BENCH], method="L-BFGS-B",
                   bounds=[(1.0, 200.0), (0.0, 0.99), (0.0, 2.0 * np.pi / 3.0)],
                   options={"ftol": 1e-14, "gtol": 1e-12, "maxiter": 5000})
    if not res.success:
        warnings.warn("M0 fit: optimizer did not report convergence (%s)"
                      % res.message, RuntimeWarning)
    A, T, phi = (float(v) for v in res.x)
    m_pred = brannen_pred(A, T, phi)
    return _summary("M0", "free 3-param Brannen", 3, A, T, phi, m_pred,
                    extra_params={"A": A, "T": T, "phi": phi})


def fit_M1():
    """Fit the benchmark-fixed Brannen form, k = 1.

    T = T_BENCH and phi = PHI_BENCH are frozen; only the amplitude A is free.
    A starts at mean(sqrt(M_PDG)): the three Brannen cosines sum to zero,
    so sum_n sqrt(m_n) = 3 A for any T and phi.

    If the optimizer does not report convergence, a RuntimeWarning is emitted.
    """
    def obj(x):
        return chi2(brannen_pred(x[0], T_BENCH, PHI_BENCH))

    a0 = float(np.mean(np.sqrt(M_PDG)))
    res = minimize(obj, x0=[a0], method="L-BFGS-B",
                   bounds=[(1.0, 200.0)],
                   options={"ftol": 1e-14, "gtol": 1e-12})
    if not res.success:
        warnings.warn("M1 fit: optimizer did not report convergence (%s)"
                      % res.message, RuntimeWarning)
    A = float(res.x[0])
    m_pred = brannen_pred(A, T_BENCH, PHI_BENCH)
    return _summary("M1", "benchmark-fixed Brannen", 1, A, T_BENCH, PHI_BENCH,
                    m_pred, extra_params={"A": A})


def fit_M2():
    """Fit the K_inv-symmetric Brannen form, k = 1.

    phi = PHI_BENCH is fixed.  T = T_KINV is the value for which K_inv = 2/3
    holds exactly on the predicted spectrum.  Only the amplitude A is free.
    A starts at mean(sqrt(M_PDG)) (Brannen sum rule:
    sum_n sqrt(m_n) = 3 A).

    If the optimizer does not report convergence, a RuntimeWarning is emitted.
    """
    def obj(x):
        return chi2(brannen_pred(x[0], T_KINV, PHI_BENCH))

    a0 = float(np.mean(np.sqrt(M_PDG)))
    res = minimize(obj, x0=[a0], method="L-BFGS-B",
                   bounds=[(1.0, 200.0)],
                   options={"ftol": 1e-14, "gtol": 1e-12})
    if not res.success:
        warnings.warn("M2 fit: optimizer did not report convergence (%s)"
                      % res.message, RuntimeWarning)
    A = float(res.x[0])
    m_pred = brannen_pred(A, T_KINV, PHI_BENCH)
    return _summary("M2", "K_inv-symmetric Brannen (K_inv=2/3, phi_d=1/3)", 1,
                    A, T_KINV, PHI_BENCH, m_pred,
                    extra_params={"A": A, "T_d": T_KINV})


def fit_M3():
    """Fit a geometric progression in sqrt-mass space, k = 2.

    The model is sqrt(m_n) = a * b^n.  Starting values come from one
    unweighted straight-line fit of log sqrt(m_n) against n: the slope gives
    log b and the intercept gives log a.

    If the optimizer does not report convergence, a RuntimeWarning is emitted.
    """
    def obj(x):
        a, b = x
        sqrt_m = a * (b ** N_FIDX)
        return chi2(sqrt_m ** 2)

    log_sqrt = 0.5 * np.log(M_PDG)
    # np.polyfit returns coefficients highest power first: (slope, intercept).
    slope, intercept = (float(v) for v in np.polyfit(N_FIDX, log_sqrt, 1))
    a0 = float(np.exp(intercept))
    b0 = float(np.exp(slope))
    res = minimize(obj, x0=[a0, b0], method="L-BFGS-B",
                   bounds=[(1e-3, 1e3), (1.001, 100.0)],
                   options={"ftol": 1e-14, "gtol": 1e-12})
    if not res.success:
        warnings.warn("M3 fit: optimizer did not report convergence (%s)"
                      % res.message, RuntimeWarning)
    a, b = (float(v) for v in res.x)
    m_pred = (a * (b ** N_FIDX)) ** 2
    return _summary("M3", "geometric progression sqrt(m_n) = a b^n", 2,
                    a, b, None, m_pred, extra_params={"a": a, "b": b})


def fit_M4():
    """Fit a power law plus offset in mass space, k = 3.

    The model is m_n = a * n^p + c.  The fit is an unbounded Nelder-Mead
    search starting from (a, p, c) = (1, 6, 0).  If the simplex does not
    report convergence, a RuntimeWarning is emitted rather than silently
    tabulating the last iterate.
    """
    def obj(x):
        a, p, c = x
        m_pred = a * (N_FIDX ** p) + c
        return chi2(m_pred)

    res = minimize(obj, x0=[1.0, 6.0, 0.0], method="Nelder-Mead",
                   options={"fatol": 1e-12, "xatol": 1e-9, "maxiter": 50000})
    if not res.success:
        warnings.warn("M4 fit: optimizer did not report convergence (%s)"
                      % res.message, RuntimeWarning)
    a, p, c = (float(v) for v in res.x)
    m_pred = a * (N_FIDX ** p) + c
    return _summary("M4", "power-law + offset m_n = a n^p + c", 3,
                    a, p, c, m_pred, extra_params={"a": a, "p": p, "c": c})


def _summary(name, descr, k, p1, p2, p3, m_pred, extra_params):
    """Bundle one model's fit result into the dict used by main/make_plot.

    *p1*, *p2*, *p3* are the model's raw parameters as passed by each fitter
    (e.g. A, T, phi for the Brannen forms; a, b, None for M3).  *k* is the
    number of free parameters, used for the AIC/BIC penalties.  The returned
    dict carries the chi-square, RMS relative deviation (a fraction), AIC and
    BIC of *m_pred*.
    """
    chi2_val = chi2(m_pred)
    aic, bic = aic_bic(chi2_val, k)
    return dict(
        name=name,
        description=descr,
        k=k,
        params=(p1, p2, p3),
        extra=extra_params,
        m_pred=m_pred,
        chi2=chi2_val,
        rms=rms_rel(m_pred),
        aic=aic,
        bic=bic,
    )


# =============================================================================
# Plot
# =============================================================================


# Bar colours.  These hex codes appear to be taken from the Okabe-Ito
# colour-blind-friendly palette.
COL_AIC = "#0072B2"
COL_BIC = "#CC79A7"
COL_HIGHLIGHT = "#D55E00"


def make_plot(rows, output_path="Figure_S6.png"):
    """Draw Figure S.6 and save it to *output_path* at 200 dpi.

    Panel (a) shows AIC and BIC per model on a log axis.  Panel (b) shows the
    RMS relative deviation in percent, also on a log axis.  The benchmark
    model M1 is drawn in COL_HIGHLIGHT in both panels.  Its BIC bar is
    additionally hatched so it stays distinguishable from its AIC bar.

    Parameters
    ----------
    rows : list of dict
        Model summaries as built by ``_summary``.  The keys "name", "aic",
        "bic" and "rms" (a fraction) are read.
    output_path : str
        Destination file passed to ``Figure.savefig``.
    """
    names = [r["name"] for r in rows]
    aics = np.array([r["aic"] for r in rows])
    bics = np.array([r["bic"] for r in rows])
    rmss = np.array([r["rms"] for r in rows]) * 100.0   # fraction -> percent
    x = np.arange(len(names))

    fig, axes = plt.subplots(1, 2, figsize=(12.0, 5.0),
                             constrained_layout=True)
    try:
        # Panel (a): AIC and BIC
        axA = axes[0]
        width = 0.38
        barsA = axA.bar(x - width / 2, aics, width, label="AIC",
                        color=COL_AIC, edgecolor="white")
        barsB = axA.bar(x + width / 2, bics, width, label="BIC",
                        color=COL_BIC, edgecolor="white")
        # Highlight M1 (benchmark) bars.  Only the face colour is changed.
        # set_color() would also recolour the edge, and the hatch is drawn in
        # the edge colour, so the "//" hatch on the BIC bar would vanish.
        for i, name in enumerate(names):
            if name == "M1":
                barsA[i].set_facecolor(COL_HIGHLIGHT)
                barsB[i].set_facecolor(COL_HIGHLIGHT)
                barsB[i].set_hatch("//")
        axA.set_xticks(x)
        axA.set_xticklabels(names)
        axA.set_ylabel("information criterion")
        axA.set_title(r"(a) AIC and BIC at PDG 2024 central values (log scale)")
        axA.set_yscale("log")
        axA.legend(loc="upper left", fontsize=10, framealpha=0.92)
        axA.grid(True, axis="y", ls=":", alpha=0.3, which="both")
        # Annotate bar heights; the offset is multiplicative because the
        # axis is logarithmic.
        for i, (a, b) in enumerate(zip(aics, bics)):
            axA.text(i - width / 2, a * 1.06, "%.2f" % a, ha="center",
                     va="bottom", fontsize=8)
            axA.text(i + width / 2, b * 1.06, "%.2f" % b, ha="center",
                     va="bottom", fontsize=8)
        # With chi^2 >= 0 and k >= 1 for every model, AIC >= 2 and
        # BIC >= ln 3 > 1, so a lower limit of 1.0 never clips a bar.
        axA.set_ylim(1.0, max(aics.max(), bics.max()) * 5.0)

        # Panel (b): RMS relative deviation
        axB = axes[1]
        bars = axB.bar(x, rmss, color=COL_AIC, edgecolor="white")
        for i, name in enumerate(names):
            if name == "M1":
                bars[i].set_facecolor(COL_HIGHLIGHT)
        axB.set_xticks(x)
        axB.set_xticklabels(names)
        # Plain "%": without usetex, Matplotlib draws text outside $...$
        # literally, so "\%" would appear with its backslash.
        axB.set_ylabel("RMS relative deviation [%]")
        axB.set_title(r"(b) RMS relative deviation at PDG 2024 central values")
        axB.set_yscale("log")
        axB.grid(True, axis="y", ls=":", alpha=0.3, which="both")
        for i, r in enumerate(rmss):
            label = "%.3f%%" % r if r >= 0.01 else "%.1e%%" % r
            axB.text(i, r * 1.06, label, ha="center", va="bottom", fontsize=8)

        fig.savefig(output_path, dpi=200)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print("Saved " + output_path)


# =============================================================================
# Main
# =============================================================================


def main():
    """Run all five fits, print the diagnostics and Table S.3, draw Figure S.6."""
    print("===== Comparison with competing parametrizations =====")
    print("  PDG 2024 inputs (central +/- 1 sigma):")
    for n, m, s in zip(N_FIDX, M_PDG, SIGMA):
        print("    m_d%d = %8.3f +/- %5.3f MeV" % (n, m, s))
    print("  Benchmark T_pred ~ %.5f, phi_pred = %.5f"
          % (T_BENCH, PHI_BENCH))
    print("  K_inv-symmetric T_d (K_inv=2/3, phi_d=1/3) = %.10f"
          % T_KINV)
    print("  N_data = %d" % N_DATA)
    print()

    rows = [fit_M0(), fit_M1(), fit_M2(), fit_M3(), fit_M4()]

    print("===== Per-model fits =====")
    for r in rows:
        m_pred = r["m_pred"]
        rel_pct = (m_pred - M_PDG) / M_PDG * 100.0
        print("  %s : %s" % (r["name"], r["description"]))
        print("    k = %d, parameters = %s" % (r["k"], r["extra"]))
        # Joined per entry so the lines follow the number of data points.
        print("    m_pred = [%s] MeV" % ", ".join("%.4f" % v for v in m_pred))
        print("    delta  = [%s]" % ", ".join("%+.3f%%" % d for d in rel_pct))
        print("    chi^2 = %.4f, RMS = %.4f%%, AIC = %.4f, BIC = %.4f"
              % (r["chi2"], r["rms"] * 100.0, r["aic"], r["bic"]))
        print()

    # Widen the "form" column to the longest description.  Some descriptions
    # (e.g. M2's) are longer than 44 characters and would push later columns
    # out of alignment.
    descr_w = max([44] + [len(r["description"]) for r in rows])
    print("===== Table S.3 (model-comparison summary) =====")
    print("  %-3s %-*s %-3s %-10s %-10s %-10s %-10s"
          % ("M", descr_w, "form", "k", "chi^2", "RMS [%]", "AIC", "BIC"))
    for r in rows:
        print("  %-3s %-*s %-3d %-10.4f %-10.4f %-10.4f %-10.4f"
              % (r["name"], descr_w, r["description"], r["k"], r["chi2"],
                 r["rms"] * 100.0, r["aic"], r["bic"]))
    print()

    print("===== Drawing Figure S.6 (Figure_S6.png) =====")
    make_plot(rows, output_path="Figure_S6.png")

    # Highlight the headline comparison.
    aic_min = min(rows, key=lambda r: r["aic"])
    bic_min = min(rows, key=lambda r: r["bic"])
    print()
    print("  Lowest AIC: %s (%.4f)" % (aic_min["name"], aic_min["aic"]))
    print("  Lowest BIC: %s (%.4f)" % (bic_min["name"], bic_min["bic"]))
    rng = max(r["aic"] for r in rows) - min(r["aic"] for r in rows)
    print("  Spread of AIC across M0..M4: %.4f" % rng)


if __name__ == "__main__":
    main()
