#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
RG prescription / threshold-matching robustness for the Supplemental Material.

Compares four QCD-running variants for the down-type spectrum:

  V1 : 4-loop QCD running + continuous threshold matching  (baseline,
       same as Script 1)
  V2 : 4-loop QCD running + finite LO decoupling corrections at thresholds
       (CKS1998-type matching of alpha_s and quark masses at mu = m_h(m_h))
  V3 : 3-loop QCD running + continuous threshold matching
  V4 : 3-loop QCD running + finite LO decoupling corrections at thresholds

For each variant the script

  (i)   scans K_inv(mu) over mu in [1, 1000] GeV;
  (ii)  determines the closed-form Brannen-type fit (A_d^fit, T_d^fit,
        phi_d^fit) at mu = m_b(m_b) from the running masses;
  (iii) reports the QCD dressing factor sqrt(Z_m) =
        sqrt(m_q(m_b)/m_q(2 GeV)) for the d quark and the s quark.

Generates Figure S.7 (saved as Figure_S7.png; two panels, (a) K_inv(mu) and
(b) the direct Koide ratio K(mu), each with one curve per variant) and the
numerical contents of Table S.4. Self-contained (does not import
Script_S1.py); chunked and cached (Script_S7_cache.npz) for resumable
execution.

Decoupling-correction conventions (Chetyrkin, Kniehl, Steinhauser 1998,
arXiv:hep-ph/9708255). At the heavy-quark mass scale mu = m_h(m_h)
[MS-bar], the leading non-trivial corrections in a = alpha_s/pi are

  alpha_s^{(nf-1)}(mu) = alpha_s^{(nf)}(mu) * [1 + (11/72) a^2 + O(a^3)]
  m_q^{(nf-1)}(mu)     = m_q^{(nf)}(mu)     * [1 + (89/432) a^2 + O(a^3)]

and the inverse (nf-1 -> nf) relations follow by flipping the sign of the
a^2 terms at this order. Continuous matching corresponds to dropping the a^2
corrections (V1, V3).
"""
from __future__ import annotations

import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from scipy.optimize import brentq, minimize


# =============================================================================
# Constants (PDG 2024)
# =============================================================================

# Reference scale M_Z [GeV] and the input coupling alpha_s(M_Z).
MZ = 91.1876
ALPHA_S_MZ = 0.1180
# Heavy-quark threshold scales m_h(m_h) [GeV]; get_nf switches nf at these.
MC_MC = 1.273
MB_MB = 4.183
MT_MT = 162.5

# Down-type mass inputs [MeV]: m_d and m_s at mu = MU_LIGHT, m_b at m_b(m_b).
MD_INPUT = 4.7
MS_INPUT = 93.5
MB_INPUT = 4183.0
MU_LIGHT = 2.0  # reference scale [GeV] of MD_INPUT and MS_INPUT

# Benchmark values. NOTE(review): COS_DELTA_BENCH is not referenced in the
# code visible here -- confirm whether it is still needed.
PHI_BENCH = 1.0 / 3.0
COS_DELTA_BENCH = 0.5 * np.cos(3.0 * np.pi / 8.0)


# =============================================================================
# Beta-function and mass-anomalous-dimension coefficients
# =============================================================================


def beta_coeffs(nf, loop=4):
    """QCD beta-function coefficients for ``nf`` active flavours.

    Convention: d a / d ln(mu^2) = -(b0 a^2 + b1 a^3 + b2 a^4 + b3 a^5) with
    a = alpha_s / pi. Coefficients beyond the requested ``loop`` order are
    returned as 0.0; b0 is always returned.

    Returns
    -------
    (b0, b1, b2, b3)
    """
    z3 = 1.2020569031595942  # zeta(3)

    one_loop = (11.0 - 2.0 * nf / 3.0) / 4.0
    two_loop = 0.0
    three_loop = 0.0
    four_loop = 0.0

    if loop >= 2:
        two_loop = (102.0 - 38.0 * nf / 3.0) / 16.0
    if loop >= 3:
        three_loop = (2857.0 / 2.0 - 5033.0 * nf / 18.0
                      + 325.0 * nf**2 / 54.0) / 64.0
    if loop >= 4:
        four_loop = ((149753.0 / 6.0 + 3564.0 * z3)
                     - (1078361.0 / 162.0 + 6508.0 * z3 / 27.0) * nf
                     + (50065.0 / 162.0 + 6472.0 * z3 / 81.0) * nf**2
                     + 1093.0 * nf**3 / 729.0) / 256.0

    return one_loop, two_loop, three_loop, four_loop


def gamma_m_coeffs(nf, loop=4):
    """Mass anomalous-dimension coefficients for ``nf`` active flavours.

    Convention (as used by run_mass_segment):
        d ln m / d ln(mu^2) = -(g0 a + g1 a^2 + g2 a^3 + g3 a^4),
    with a = alpha_s / pi. Coefficients beyond the requested ``loop`` order
    are 0.0; g0 = 1 is always returned.

    Returns
    -------
    (g0, g1, g2, g3)
    """
    z3 = 1.2020569031595942   # zeta(3)
    z4 = np.pi**4 / 90.0      # zeta(4)
    z5 = 1.0369277551433699   # zeta(5)

    coeffs = [1.0, 0.0, 0.0, 0.0]
    if loop >= 2:
        coeffs[1] = (202.0 / 3.0 - 20.0 * nf / 9.0) / 16.0
    if loop >= 3:
        coeffs[2] = (1249.0 - (2216.0 / 27.0 + 160.0 * z3 / 3.0) * nf
                     - 140.0 * nf**2 / 81.0) / 64.0
    if loop >= 4:
        coeffs[3] = ((4603055.0 / 162.0 + 135680.0 * z3 / 27.0
                      - 8800.0 * z5)
                     - (91723.0 / 27.0 + 34192.0 * z3 / 9.0
                        - 880.0 * z4 - 18400.0 * z5 / 9.0) * nf
                     + (5242.0 / 243.0 + 800.0 * z3 / 9.0
                        - 160.0 * z4 / 3.0) * nf**2
                     + (-332.0 / 243.0 + 64.0 * z3 / 27.0) * nf**3
                     ) / 256.0
    return tuple(coeffs)


# =============================================================================
# Threshold matching (continuous vs CKS1998 LO finite decoupling)
# =============================================================================


def alpha_s_decouple(a_s, direction, scheme):
    """Match a = alpha_s/pi across a heavy-quark threshold at mu = m_h(m_h).

    Parameters
    ----------
    a_s : float
        alpha_s/pi on the side being matched from.
    direction : str
        'down' (nf -> nf-1 active flavours) or 'up' (nf-1 -> nf).
    scheme : str
        'continuous' (a_s returned unchanged) or 'finite_LO'.

    For 'finite_LO' the two-loop MS-bar decoupling relation at
    mu = m_h(m_h) (Chetyrkin, Kniehl, Steinhauser) is applied,
        alpha_s^{(nf-1)} = alpha_s^{(nf)} * (1 + (11/72) a^2),
    and for 'up' its inverse to the same order,
        alpha_s^{(nf)} = alpha_s^{(nf-1)} * (1 - (11/72) a^2).
    ``a`` is the input coupling; using the output one instead differs only
    at O(a^4).

    Raises
    ------
    ValueError
        For an unknown ``direction`` or ``scheme``.
    """
    if scheme not in ("continuous", "finite_LO"):
        raise ValueError("unknown matching scheme: %r" % (scheme,))
    if direction not in ("down", "up"):
        raise ValueError("unknown matching direction: %r" % (direction,))
    if scheme == "continuous":
        return a_s
    delta = 11.0 / 72.0 * a_s**2
    if direction == "down":
        # Integrating out the heavy quark raises alpha_s at mu = m_h(m_h).
        return a_s * (1.0 + delta)
    return a_s * (1.0 - delta)  # leading-order inversion


def mass_decouple(m, a_s, direction, scheme):
    """Match a light-quark MS-bar mass across a heavy-quark threshold.

    Parameters
    ----------
    m : float
        Quark mass on the side being matched from (any unit).
    a_s : float
        alpha_s/pi used in the correction.
    direction : str
        'down' (nf -> nf-1 active flavours) or 'up' (nf-1 -> nf).
    scheme : str
        'continuous' (m returned unchanged) or 'finite_LO'.

    For 'finite_LO' the two-loop MS-bar decoupling relation at
    mu = m_h(m_h) (Chetyrkin, Kniehl, Steinhauser) is applied,
        m^{(nf-1)} = m^{(nf)} * (1 + (89/432) a^2),
    and for 'up' its inverse to the same order,
        m^{(nf)} = m^{(nf-1)} * (1 - (89/432) a^2).

    Raises
    ------
    ValueError
        For an unknown ``direction`` or ``scheme``.
    """
    if scheme not in ("continuous", "finite_LO"):
        raise ValueError("unknown matching scheme: %r" % (scheme,))
    if direction not in ("down", "up"):
        raise ValueError("unknown matching direction: %r" % (direction,))
    if scheme == "continuous":
        return m
    delta = 89.0 / 432.0 * a_s**2
    if direction == "down":
        return m * (1.0 + delta)
    return m * (1.0 - delta)  # leading-order inversion


# =============================================================================
# RGE integrators
# =============================================================================


def das_dt(a, nf, loop):
    """Right-hand side d a / d ln(mu^2) of the RGE for a = alpha_s/pi.

    The i-th beta coefficient multiplies a^(i+2), with the overall minus sign
    of the beta_coeffs convention.
    """
    return -sum(coeff * a ** (power + 2)
                for power, coeff in enumerate(beta_coeffs(nf, loop=loop)))


def run_alpha_s_segment(a_s_start, mu_start, mu_end, nf, loop):
    """Integrate a = alpha_s/pi from ``mu_start`` to ``mu_end`` at fixed nf.

    The evolution variable is t = ln(mu^2 / mu_start^2). A (relatively)
    zero-length interval returns ``a_s_start`` unchanged.

    Raises
    ------
    RuntimeError
        If the ODE integration does not succeed.
    """
    same_scale = abs(mu_end - mu_start) / max(mu_start, 1e-30) < 1e-14
    if same_scale:
        return a_s_start

    span = np.log(mu_end**2 / mu_start**2)

    def rhs(_t, state):
        return [das_dt(state[0], nf, loop)]

    result = solve_ivp(rhs, [0.0, span], [a_s_start],
                       method="RK45", rtol=1e-10, atol=1e-14,
                       max_step=abs(span) / 5)
    if not result.success:
        raise RuntimeError("alpha_s RGE failure: " + result.message)
    return float(result.y[0, -1])


def run_mass_segment(m_start, a_s_start, mu_start, mu_end, nf, loop):
    """Jointly run a quark mass and a = alpha_s/pi at fixed nf.

    Integrates (a, ln m) in t = ln(mu^2 / mu_start^2) from ``mu_start`` to
    ``mu_end``. A (relatively) zero-length interval returns the inputs
    unchanged.

    Returns
    -------
    (m_end, a_end)

    Raises
    ------
    RuntimeError
        If the ODE integration does not succeed.
    """
    if abs(mu_end - mu_start) / max(mu_start, 1e-30) < 1e-14:
        return m_start, a_s_start

    span = np.log(mu_end**2 / mu_start**2)
    beta = beta_coeffs(nf, loop=loop)
    gamma = gamma_m_coeffs(nf, loop=loop)

    def rhs(_t, state):
        a = state[0]
        d_a = -(beta[0] * a**2 + beta[1] * a**3
                + beta[2] * a**4 + beta[3] * a**5)
        d_lnm = -(gamma[0] * a + gamma[1] * a**2
                  + gamma[2] * a**3 + gamma[3] * a**4)
        return [d_a, d_lnm]

    result = solve_ivp(rhs, [0.0, span], [a_s_start, np.log(m_start)],
                       method="RK45", rtol=1e-10, atol=1e-14,
                       max_step=abs(span) / 5)
    if not result.success:
        raise RuntimeError("mass RGE failure: " + result.message)
    final = result.y[:, -1]
    return float(np.exp(final[1])), float(final[0])


def get_nf(mu):
    """Number of active quark flavours at scale ``mu`` [GeV].

    A scale exactly equal to a threshold mass counts as lying above it.
    """
    for n_active, threshold in ((6, MT_MT), (5, MB_MB), (4, MC_MC)):
        if mu >= threshold:
            return n_active
    return 3


def alpha_s_at_mu(mu, loop, scheme):
    """Compute a_s = alpha_s/pi at scale ``mu`` [GeV] from alpha_s(M_Z).

    The coupling is run from M_Z with nf = 5. It is matched at m_t(m_t) when
    running upward, and at m_b(m_b) and m_c(m_c) when running downward, using
    ``alpha_s_decouple`` with the given ``scheme``. The result belongs to the
    get_nf(mu)-flavour theory. In particular, a scale exactly equal to a
    threshold returns the coupling of the higher-nf theory.
    """
    a = ALPHA_S_MZ / np.pi

    if mu >= MZ:
        if mu < MT_MT:
            return run_alpha_s_segment(a, MZ, mu, nf=5, loop=loop)
        # Cross m_t: nf = 5 -> 6.
        a = run_alpha_s_segment(a, MZ, MT_MT, nf=5, loop=loop)
        a = alpha_s_decouple(a, "up", scheme)
        return run_alpha_s_segment(a, MT_MT, mu, nf=6, loop=loop)

    # Run down from M_Z.
    if mu >= MB_MB:
        return run_alpha_s_segment(a, MZ, mu, nf=5, loop=loop)
    a = run_alpha_s_segment(a, MZ, MB_MB, nf=5, loop=loop)
    a = alpha_s_decouple(a, "down", scheme)  # nf = 5 -> 4
    if mu >= MC_MC:
        return run_alpha_s_segment(a, MB_MB, mu, nf=4, loop=loop)
    a = run_alpha_s_segment(a, MB_MB, MC_MC, nf=4, loop=loop)
    a = alpha_s_decouple(a, "down", scheme)  # nf = 4 -> 3
    return run_alpha_s_segment(a, MC_MC, mu, nf=3, loop=loop)


def running_mass(m_ref, mu_ref, mu_target, loop, scheme):
    """Run a quark mass from ``mu_ref`` to ``mu_target`` [GeV].

    The path is split at the heavy-quark thresholds lying strictly between
    the two scales. Each piece is integrated with the nf active at its
    midpoint. Whenever a piece's nf differs from the currently active theory,
    alpha_s and the mass are first matched ('up' or 'down') with the given
    ``scheme``.

    ``m_ref`` and the starting alpha_s (from ``alpha_s_at_mu``) are treated
    as belonging to the get_nf(mu_ref)-flavour theory. Consequently, starting
    exactly at a threshold and running downward applies the matching at the
    start. No matching is applied after the final piece, so the result is
    expressed in the nf of the last piece integrated.
    """
    thresholds = [MC_MC, MB_MB, MT_MT]
    a_cur = alpha_s_at_mu(mu_ref, loop=loop, scheme=scheme)
    m_cur = m_ref
    nf_cur = get_nf(mu_ref)

    if mu_target > mu_ref:
        inner = [t for t in thresholds if mu_ref < t < mu_target]
    else:
        inner = [t for t in reversed(thresholds) if mu_target < t < mu_ref]
    waypoints = [mu_ref] + inner + [mu_target]

    for mu_a, mu_b in zip(waypoints[:-1], waypoints[1:]):
        nf_seg = get_nf(0.5 * (mu_a + mu_b))
        if nf_seg != nf_cur:
            # Crossing into a different-nf theory at the threshold mu_a.
            direction = "up" if nf_seg > nf_cur else "down"
            a_cur = alpha_s_decouple(a_cur, direction, scheme)
            m_cur = mass_decouple(m_cur, a_cur, direction, scheme)
            nf_cur = nf_seg
        m_cur, a_cur = run_mass_segment(m_cur, a_cur, mu_a, mu_b,
                                        nf=nf_seg, loop=loop)
    return m_cur


# =============================================================================
# K_inv and Brannen-fit machinery
# =============================================================================


def kinv_from_masses(md, ms, mb):
    """Inverse-mass Koide ratio sum(1/m) / (sum(1/sqrt(m)))^2 of three masses."""
    masses = np.array([md, ms, mb], dtype=float)
    numerator = np.sum(1.0 / masses)
    denominator = np.sum(1.0 / np.sqrt(masses)) ** 2
    return float(numerator / denominator)


def k_from_masses(md, ms, mb):
    """Koide ratio sum(m) / (sum(sqrt(m)))^2 of three masses."""
    masses = np.array([md, ms, mb], dtype=float)
    root_sum = np.sum(np.sqrt(masses))
    return float(np.sum(masses) / root_sum ** 2)


def brannen_fit(md, ms, mb):
    """Closed-form Brannen-type fit at a common scale.

    Returns (A_fit, T_fit, phi_fit) such that
        sqrt(m_n) = A (1 + 2 T cos(phi/3 + 2 n pi/3)),  n = 1, 2, 3,
    with (m_1, m_2, m_3) = (md, ms, mb).

    A is the mean of the square roots, since the three cosines sum to zero.
    T follows from sum_n cos^2(...) = 3/2. phi minimises the least-squares
    residual sum_n (s_n - 2 T cos(phi/3 + 2 n pi/3))^2, where
    s_n = sqrt(m_n)/A - 1, with phi restricted to [0, 2 pi]. That window
    fixes which mass carries the label n.

    With z = sum_n s_n exp(-2 i n pi/3), the residual equals
    const - 4 T |z| cos(phi/3 - arg z). The minimiser is therefore
    phi = 3 arg z when that lies in the window, and otherwise the better of
    the two window edges. No iterative optimiser is needed. For exactly
    degenerate masses (T = 0) every phi fits equally well and 0.0 is
    returned.
    """
    sqs = np.sqrt(np.array([md, ms, mb], dtype=float))
    A = float(np.sum(sqs) / 3.0)
    s = sqs / A - 1.0  # = 2T cos(phi/3 + 2n pi/3) for n = 1, 2, 3
    T = float(np.sqrt(np.sum(s ** 2) / 6.0))
    if T == 0.0:
        return A, T, 0.0

    n = np.arange(1, 4)
    z = complex(np.sum(s * np.exp(-2j * np.pi * n / 3.0)))
    theta_star = float(np.angle(z)) % (2.0 * np.pi)
    theta_hi = 2.0 * np.pi / 3.0  # phi/3 at the phi = 2 pi window edge
    if theta_star <= theta_hi:
        theta = theta_star
    elif np.cos(theta_star) >= np.cos(theta_hi - theta_star):
        theta = 0.0
    else:
        theta = theta_hi
    return A, T, float(3.0 * theta)


# =============================================================================
# Variant scan (cached)
# =============================================================================


# Each entry: (label, description, QCD loop order passed to the RGEs,
# threshold-matching scheme: "continuous" or "finite_LO").
VARIANTS = [
    ("V1", "4-loop + continuous matching", 4, "continuous"),
    ("V2", "4-loop + LO decoupling",        4, "finite_LO"),
    ("V3", "3-loop + continuous matching",  3, "continuous"),
    ("V4", "3-loop + LO decoupling",        3, "finite_LO"),
]


def scan_kinv(loop, scheme, mu_grid):
    """Run m_d, m_s and m_b to every scale of ``mu_grid``; form K_inv and K.

    m_d and m_s start from MD_INPUT / MS_INPUT at MU_LIGHT, and m_b from
    MB_INPUT at MB_MB. ``loop`` and ``scheme`` select the running/matching
    prescription.

    Returns
    -------
    (md, ms, mb, kinv, k) : five float arrays with one entry per grid point.
    """
    def run_to_grid(m_ref, mu_ref):
        return np.array([running_mass(m_ref, mu_ref, mu, loop=loop,
                                      scheme=scheme)
                         for mu in mu_grid], dtype=float)

    md_arr = run_to_grid(MD_INPUT, MU_LIGHT)
    ms_arr = run_to_grid(MS_INPUT, MU_LIGHT)
    mb_arr = run_to_grid(MB_INPUT, MB_MB)

    triples = list(zip(md_arr, ms_arr, mb_arr))
    kinv_arr = np.array([kinv_from_masses(*t) for t in triples])
    k_arr = np.array([k_from_masses(*t) for t in triples])
    return md_arr, ms_arr, mb_arr, kinv_arr, k_arr


def run_variants(mu_grid, cache_path="Script_S7_cache.npz", verbose=False):
    """Run, or load from the cache, the mu scan for every entry of VARIANTS.

    Parameters
    ----------
    mu_grid : numpy.ndarray
        1-D array of renormalisation scales [GeV].
    cache_path : str or os.PathLike or None
        Path of the .npz cache; a falsy value disables caching. numpy appends
        ".npz" when saving, so the same suffix is added here before looking
        for an existing file.
    verbose : bool
        Print per-variant progress.

    Returns
    -------
    dict
        Maps each variant label to a dict with keys "descr", "loop",
        "scheme", "md", "ms", "mb", "kinv" and "k".

    The cache is rewritten after every freshly computed variant, so an
    interrupted run resumes from the last completed variant. A cache is
    discarded when its mu grid or its input fingerprint differs from the
    current run, or when it cannot be read (e.g. after an interrupted write).
    """
    import zipfile

    # Fingerprint of the inputs the cached arrays depend on besides mu_grid.
    # The leading entry is a version tag: bump it whenever the running or
    # matching code changes, so stale results are not silently reused.
    meta = np.array([2.0, MZ, ALPHA_S_MZ, MC_MC, MB_MB, MT_MT,
                     MD_INPUT, MS_INPUT, MB_INPUT, MU_LIGHT], dtype=float)

    if cache_path:
        cache_path = os.fspath(cache_path)
        if not cache_path.endswith(".npz"):
            cache_path += ".npz"

    cache = {}
    if cache_path and os.path.exists(cache_path):
        try:
            # The context manager closes the archive handle, which would
            # otherwise stay open while the same file is rewritten below.
            with np.load(cache_path) as data:
                cache = {key: data[key] for key in data.files}
        except (OSError, ValueError, EOFError, zipfile.BadZipFile):
            cache = {}
        grid = cache.get("mu_grid")
        stored_meta = cache.get("meta")
        if (grid is None or stored_meta is None
                or grid.size != mu_grid.size
                or not np.allclose(grid, mu_grid)
                or stored_meta.shape != meta.shape
                or not np.array_equal(stored_meta, meta)):
            cache = {}

    names = ("md", "ms", "mb", "kinv", "k")
    out = {}
    for label, descr, loop, scheme in VARIANTS:
        keys = [name + "_" + label for name in names]
        if all(key in cache for key in keys):
            md, ms, mb, kinv, kk = (cache[key] for key in keys)
            if verbose:
                print("  cached: " + label)
        else:
            if verbose:
                print("  computing: " + label + " (" + descr + ")")
            md, ms, mb, kinv, kk = scan_kinv(loop, scheme, mu_grid)
            cache.update(zip(keys, (md, ms, mb, kinv, kk)))
            cache["mu_grid"] = mu_grid
            cache["meta"] = meta
            if cache_path:
                np.savez(cache_path, **cache)
        out[label] = {
            "descr": descr, "loop": loop, "scheme": scheme,
            "md": md, "ms": ms, "mb": mb, "kinv": kinv, "k": kk,
        }
    return out


# =============================================================================
# Plot
# =============================================================================


# Per-variant line colour (Okabe-Ito palette hex values) and line style,
# keyed by the VARIANTS labels; used by make_plot.
COLOR_MAP = {"V1": "#0072B2", "V2": "#D55E00",
             "V3": "#009E73", "V4": "#CC79A7"}
LS_MAP = {"V1": "-", "V2": "--", "V3": "-.", "V4": ":"}


def make_plot(mu_grid, results, output_path="Figure_S7.png"):
    """Draw Figure S.7 and save it to ``output_path``.

    Panel (a) shows K_inv(mu) and panel (b) the direct Koide ratio K(mu), one
    curve per variant in ``results``. Each panel also has a dotted reference
    line at 2/3. Panel (a) legend entries include the variant description.
    """
    fig, (ax_kinv, ax_k) = plt.subplots(1, 2, figsize=(12.0, 5.4),
                                        constrained_layout=True)

    # (axis, results key, long legend labels?, reference-line width,
    #  reference-line label, y label, title)
    panels = (
        (ax_kinv, "kinv", True, 1.0,
         r"symmetric $K_{\mathrm{inv}}=2/3$",
         r"$K_{\mathrm{inv}}(\mu)$",
         r"(a) $K_{\mathrm{inv}}(\mu)$ across RG-prescription variants"),
        (ax_k, "k", False, 0.8,
         r"symmetric $K=2/3$",
         r"$K(\mu)$",
         r"(b) Direct Koide $K(\mu)$ across variants"),
    )

    for ax, key, long_labels, ref_lw, ref_label, ylabel, title in panels:
        for label, r in results.items():
            if long_labels:
                curve_label = label + " (" + r["descr"] + ")"
            else:
                curve_label = label
            ax.plot(mu_grid, r[key], color=COLOR_MAP[label],
                    linestyle=LS_MAP[label], lw=2.0, label=curve_label)
        ax.axhline(2.0 / 3.0, color="black", lw=ref_lw, ls=":",
                   label=ref_label)
        ax.set_xscale("log")
        ax.set_xlabel(r"$\mu$ [GeV]")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend(loc="lower right", fontsize=8, framealpha=0.92)
        ax.grid(True, which="both", ls=":", alpha=0.3)

    fig.savefig(output_path, dpi=200)
    plt.close(fig)
    print("Saved " + output_path)


# =============================================================================
# Main
# =============================================================================


def main():
    """Run the RG-prescription robustness study and draw Figure S.7.

    The run:
      - prints the inputs;
      - scans the four variants (cached in Script_S7_cache.npz);
      - prints, per variant, the Brannen fit, K_inv, K and the dressing
        factors sqrt(Z_m) for the d and s quarks at mu = m_b(m_b);
      - prints the spread of K_inv(m_b) across variants;
      - writes Figure_S7.png.

    Returns
    -------
    dict
        Maps each variant label to its summary quantities at m_b.
    """
    print("===== RG-prescription robustness =====")
    print("  PDG 2024 inputs: m_d = %.2f MeV (2 GeV), m_s = %.2f MeV (2 GeV),"
          " m_b = %.1f MeV (m_b)" % (MD_INPUT, MS_INPUT, MB_INPUT))
    print("  Thresholds: m_c = %.3f GeV, m_b = %.3f GeV, m_t = %.1f GeV"
          % (MC_MC, MB_MB, MT_MT))
    print("  alpha_s(M_Z) = %.4f" % ALPHA_S_MZ)

    n_pts = 60
    mu_grid = np.logspace(np.log10(1.0), np.log10(1000.0), n_pts)

    print()
    print("===== Running variants (chunked + cached) =====")
    results = run_variants(mu_grid, cache_path="Script_S7_cache.npz",
                           verbose=True)

    print()
    print("===== Brannen fits at mu = m_b =====")
    print("  %-3s %-44s %12s %12s %12s %12s %12s %12s"
          % ("V", "description", "T_d^fit", "phi_d^fit",
             "K_inv(m_b)", "K(m_b)", "sqrt(Z_m,d)", "sqrt(Z_m,s)"))
    summary = {}
    for label, r in results.items():
        # Evaluate exactly at mu = m_b(m_b) rather than at the nearest grid
        # point of the scan.
        md_at_mb = running_mass(MD_INPUT, MU_LIGHT, MB_MB,
                                loop=r["loop"], scheme=r["scheme"])
        ms_at_mb = running_mass(MS_INPUT, MU_LIGHT, MB_MB,
                                loop=r["loop"], scheme=r["scheme"])
        mb_at_mb = running_mass(MB_INPUT, MB_MB, MB_MB,
                                loop=r["loop"], scheme=r["scheme"])
        A_fit, T_fit, phi_fit = brannen_fit(md_at_mb, ms_at_mb, mb_at_mb)
        kinv_mb = kinv_from_masses(md_at_mb, ms_at_mb, mb_at_mb)
        k_mb = k_from_masses(md_at_mb, ms_at_mb, mb_at_mb)
        # Dressing factors sqrt(Z_m,q) = sqrt(m_q(m_b) / m_q(2 GeV)).
        zm_d = float(np.sqrt(md_at_mb / MD_INPUT))
        zm_s = float(np.sqrt(ms_at_mb / MS_INPUT))
        summary[label] = {
            "A_fit": A_fit, "T_fit": T_fit, "phi_fit": phi_fit,
            "kinv_mb": kinv_mb, "k_mb": k_mb,
            "zm_d": zm_d, "zm_s": zm_s,
            "md_mb": md_at_mb, "ms_mb": ms_at_mb, "mb_mb": mb_at_mb,
        }
        print("  %-3s %-44s %12.6f %12.6f %12.6f %12.6f %12.6f %12.6f"
              % (label, r["descr"], T_fit, phi_fit, kinv_mb, k_mb,
                 zm_d, zm_s))

    print()
    print("===== K_inv variability across variants =====")
    kinv_at_mb = np.array([entry["kinv_mb"] for entry in summary.values()])
    kinv_min = float(kinv_at_mb.min())
    kinv_max = float(kinv_at_mb.max())
    # Reference for the relative spread: the 2/3 line drawn in Figure S.7.
    kinv_ref = 2.0 / 3.0
    print("  min(K_inv(m_b)) = %.6f" % kinv_min)
    print("  max(K_inv(m_b)) = %.6f" % kinv_max)
    print("  spread = %.6e (%.4f%%)"
          % (kinv_max - kinv_min,
             (kinv_max - kinv_min) / kinv_ref * 100.0))

    print()
    print("===== Drawing Figure S.7 (Figure_S7.png) =====")
    make_plot(mu_grid, results, output_path="Figure_S7.png")

    return summary


if __name__ == "__main__":
    main()
