#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Log-x / log-y plot of the down-type running masses reconstructed from the
mass formula

    sqrt(m_dn) = A_d(mu) * [1 + 2 T_d(mu) cos(phi_d/3 + 2 pi n / 3)]

with n = 1, 2, 3,

where A_d(mu), T_d(mu), phi_d(mu) are extracted from the running masses
m_d1(mu), m_d2(mu), m_d3(mu) in MSbar QCD.

What is plotted
---------------
Top panel:
  - masses reconstructed from the Brannen-type formula evaluated with the fit
    parameters (A_d^fit(mu), T_d^fit(mu), phi_d^fit(mu)), all of which are
    extracted at each scale mu,
      m_d1^fit(mu), m_d2^fit(mu), m_d3^fit(mu),
    with logarithmic y-axis.
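Bottom panel:
  - absolute relative difference between the fit-reconstructed masses and the
    PDG-2024 running masses m_{dn}(mu) obtained by 4-loop MS-bar evolution,
      |Delta m^fit| / m * 100 [%],
      Delta m^fit(mu) = m_{dn}^fit(mu) - m_{dn}(mu),
    also with logarithmic y-axis. Because A_d, T_d and phi_d are obtained at
    each mu by exactly inverting the three running masses, this difference is
    expected to be at the level of floating-point round-off; it serves as a
    numerical consistency check of the inversion.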

Why percentage difference is recommended
----------------------------------------
The three down-type masses span about three orders of magnitude, from a few MeV to several
GeV. A raw Delta [MeV] plot is therefore dominated by the heaviest state and is less useful
for comparing reconstruction quality across all three curves. For that reason the
bottom panel shows

    |Delta m| / m * 100 [%]

on a logarithmic scale.

What is saved internally
------------------------
The CSV/NPZ files contain:
  - mu, alpha_s, running masses m_d1, m_d2, m_d3
  - sqrt(m_dn), 1/sqrt(m_dn)
  - K, K_inv
  - A_d(mu), T_d(mu), phi_d(mu)
  - masses reconstructed from the mass formula
  - signed differences Delta m [MeV]
  - signed relative differences [%]
  - absolute relative differences [%]

Input reference values (given by the user)
------------------------------------------
  m_d1(2.000 GeV)   = 4.7   MeV
  m_d2(2.000 GeV)   = 93.5  MeV
  m_d3(4.183 GeV)   = 4183  MeV
  m_u2(1.273 GeV)   = 1273  MeV   -> used here as the charm threshold scale
  m_u3(162.500 GeV) = 162500 MeV  -> used here as the top threshold scale
  alpha_s(M_Z^2)    = 0.1180

Physics assumptions implemented here
------------------------------------
1) MSbar scheme for quark masses and alpha_s.
2) 4-loop QCD running for both alpha_s and the quark-mass anomalous dimension.
3) Piecewise running with flavour thresholds at
       mu_c = m_c(m_c) = 1.273 GeV,
       mu_b = m_b(m_b) = 4.183 GeV,
       mu_t = m_t(m_t) = 162.5 GeV.
4) "Continuous matching" baseline, as requested:
       alpha_s^(nf)(mu_th) = alpha_s^(nf+1)(mu_th)
       m_q^(nf)(mu_th)     = m_q^(nf+1)(mu_th)
   i.e. finite MSbar decoupling constants at the thresholds are NOT applied.
   This is the simplest continuous prescription.

Important note on matching
--------------------------
In strict MSbar EFT matching at mu_th = m_h(m_h), the finite decoupling corrections for alpha_s
and for the lighter quark masses vanish at one loop and first appear at O(alpha_s^2).
So the present script is intentionally a continuous-matching approximation, not a full
4-loop decoupling implementation. If you want the strict EFT treatment, replace the two
functions
    continuous_match_alpha(...)
    continuous_match_mass(...)
by the appropriate decoupling factors.

Important note on the bottom contribution below mu_b
---------------------------------------------------
The d3 input is interpreted as the bottom quark mass m_b(mu). To compute single continuous
curves over the full range 1--1000 GeV, this script continues m_b(mu) below mu_b using the
same continuous-matching convention. For a stricter EFT-oriented treatment, set
INCLUDE_HEAVY_BELOW_OWN_THRESHOLD = False. In that case all three-mass derived quantities
are set to NaN for mu < mu_b because m_d3(mu) is then excluded from the three-mass formula.
"""

from __future__ import annotations

import csv
import math
from dataclasses import dataclass
from typing import Callable, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import solve_ivp

# -----------------------------------------------------------------------------
# User-facing switches
# -----------------------------------------------------------------------------
# If True, m_d3 (interpreted as the bottom mass) is also used below its own
# threshold MU_B; if False, main() replaces m_d3 by NaN for mu < MU_B, so every
# quantity built from all three masses is NaN there.
INCLUDE_HEAVY_BELOW_OWN_THRESHOLD = True
SAVEFIG = True              # write the figure to OUTPUT_FIG
SAVE_INTERNAL_DATA = True   # write the CSV/NPZ exports
OUTPUT_FIG = "Figure_S3.png"
# Output paths for the per-mu data export (same columns in both files).
OUTPUT_CSV = "quark_running_mass_formula_compare_v5_colorblind_data.csv"
OUTPUT_NPZ = "quark_running_mass_formula_compare_v5_colorblind_data.npz"
N_GRID = 800                # log-spaced grid points before exact scales are merged in

# -----------------------------------------------------------------------------
# Input parameters
# -----------------------------------------------------------------------------
MZ_GEV = 91.1876            # Z-boson pole mass [GeV]
ALPHA_S_MZ = 0.1180         # alpha_s(M_Z^2)

# Evolution / plotting range [GeV].
MU_MIN = 1.0
MU_MAX = 1000.0

# Quark thresholds used for piecewise running.
# These are the assumptions requested to be stated explicitly in the code.
MU_C = 1.273                # charm threshold scale  mu_c = m_c(m_c) [GeV]
MU_B = 4.183                # bottom threshold scale mu_b = m_b(m_b) [GeV]
MU_T = 162.5                # top threshold scale    mu_t = m_t(m_t) [GeV]

# Reference values m_dn(mu_ref) from which each mass is evolved.
M_D1_REF = 4.7              # MeV at mu = 2 GeV
MU_D1_REF = 2.0             # GeV
M_D2_REF = 93.5             # MeV at mu = 2 GeV
MU_D2_REF = 2.0             # GeV
M_D3_REF = 4183.0           # MeV at mu = 4.183 GeV  (interpreted as bottom)
MU_D3_REF = MU_B            # GeV

# Mathematical constants
PI = math.pi
ZETA3 = 1.2020569031595942854   # zeta(3), enters beta3, gamma2, gamma3
ZETA4 = PI**4 / 90.0            # zeta(4), enters gamma3
ZETA5 = 1.0369277551433699263   # zeta(5), enters gamma3


# -----------------------------------------------------------------------------
# 4-loop MSbar beta function and mass anomalous dimension
# Convention: a = alpha_s / (4*pi), and t = ln(mu^2)
# Then
#   da/dt      = - sum_{n=0}^3 beta_n  a^(n+2)
#   d ln m /dt = - sum_{n=0}^3 gamma_n a^(n+1)
# -----------------------------------------------------------------------------
def beta_coeffs(nf: int) -> Tuple[float, float, float, float]:
    """
    Return the 4-loop MSbar beta-function coefficients (beta0, beta1, beta2, beta3).

    Normalization: a = alpha_s / (4 pi), t = ln(mu^2), and

        da/dt = - (beta0 a^2 + beta1 a^3 + beta2 a^4 + beta3 a^5),

    which is exactly how beta_a() combines these numbers.  In this
    normalization beta0 = 11 - 2 nf / 3; the higher coefficients are the
    standard 2-, 3- and 4-loop results (4-loop: van Ritbergen, Vermaseren &
    Larin, 1997).

    Parameters
    ----------
    nf : int
        Number of active quark flavours.
    """
    beta0 = 11.0 - 2.0 / 3.0 * nf
    beta1 = 102.0 - 38.0 / 3.0 * nf
    beta2 = 2857.0 / 2.0 - 5033.0 / 18.0 * nf + 325.0 / 54.0 * nf**2
    # Four-loop coefficient: first place where zeta(3) and an nf^3 term appear.
    beta3 = (
        149753.0 / 6.0
        + 3564.0 * ZETA3
        - (1078361.0 / 162.0 + 6508.0 / 27.0 * ZETA3) * nf
        + (50065.0 / 162.0 + 6472.0 / 81.0 * ZETA3) * nf**2
        + 1093.0 / 729.0 * nf**3
    )
    return beta0, beta1, beta2, beta3



def gamma_coeffs(nf: int) -> Tuple[float, float, float, float]:
    """
    Return the 4-loop MSbar quark-mass anomalous-dimension coefficients (gamma0..gamma3).

    Normalization: a = alpha_s / (4 pi), t = ln(mu^2), and

        d ln m / dt = - (gamma0 a + gamma1 a^2 + gamma2 a^3 + gamma3 a^4),

    which is how gamma_m() and solve_mass_segment() use these numbers.
    In this normalization the coefficients are 4^(n+1) times the values
    usually quoted for a_s = alpha_s / pi (4-loop: Chetyrkin 1997;
    Vermaseren, Larin & van Ritbergen 1997), e.g. gamma0 = 4.

    Parameters
    ----------
    nf : int
        Number of active quark flavours.
    """
    gamma0 = 4.0
    gamma1 = 202.0 / 3.0 - 20.0 / 9.0 * nf
    gamma2 = 1249.0 - (2216.0 / 27.0 + 160.0 / 3.0 * ZETA3) * nf - 140.0 / 81.0 * nf**2
    # Four-loop coefficient: zeta(4) and zeta(5) enter here for the first time.
    gamma3 = (
        4603055.0 / 162.0
        + 135680.0 / 27.0 * ZETA3
        - 8800.0 * ZETA5
        + (-91723.0 / 27.0 - 34192.0 / 9.0 * ZETA3 + 880.0 * ZETA4 + 18400.0 / 9.0 * ZETA5) * nf
        + (5242.0 / 243.0 + 800.0 / 9.0 * ZETA3 - 160.0 / 3.0 * ZETA4) * nf**2
        + (-332.0 / 243.0 + 64.0 / 27.0 * ZETA3) * nf**3
    )
    return gamma0, gamma1, gamma2, gamma3



def beta_a(a: float, nf: int) -> float:
    """
    Return da/dt for a = alpha_s / (4 pi) and t = ln(mu^2), truncated at four loops.

    The terms are accumulated from the lowest power of a upward, i.e.
    -(beta0 a^2 + beta1 a^3 + beta2 a^4 + beta3 a^5).
    """
    coefficients = beta_coeffs(nf)
    powers = (a**2, a**3, a**4, a**5)
    return -sum(c * p for c, p in zip(coefficients, powers))



def gamma_m(a: float, nf: int) -> float:
    """
    Return the mass anomalous dimension gamma0 a + gamma1 a^2 + gamma2 a^3 + gamma3 a^4.

    Here a = alpha_s / (4 pi); callers apply the overall minus sign in
    d ln m / d ln(mu^2) = -gamma_m.
    """
    coefficients = gamma_coeffs(nf)
    powers = (a, a**2, a**3, a**4)
    return sum(c * p for c, p in zip(coefficients, powers))


# -----------------------------------------------------------------------------
# Threshold / matching assumptions
# -----------------------------------------------------------------------------
def active_nf(mu: float) -> int:
    """
    Return the number of active flavours used for the EFT running at scale mu.

    nf = 3 below MU_C, 4 on [MU_C, MU_B), 5 on [MU_B, MU_T) and 6 from MU_T
    upward; each threshold scale itself belongs to the window above it.
    """
    nf = 3
    for threshold in (MU_C, MU_B, MU_T):
        if mu < threshold:
            return nf
        nf += 1
    return nf



def continuous_match_alpha(alpha_in: float, nf_from: int, nf_to: int, mu_th: float) -> float:
    """
    Carry alpha_s across a flavour threshold with the continuous-matching baseline.

    No finite MSbar decoupling correction is applied, so alpha_in is returned
    unchanged: alpha_s^(nf_to)(mu_th) = alpha_s^(nf_from)(mu_th).

    nf_from, nf_to and mu_th are accepted but unused. They are kept in the
    signature so that a strict EFT relation,
    alpha_s^(nf_to)(mu_th) = zeta_g^2 * alpha_s^(nf_from)(mu_th),
    can replace the return value without touching any caller.
    """
    del nf_from, nf_to, mu_th  # intentionally unused in the continuous baseline
    return alpha_in



def continuous_match_mass(m_in: float, nf_from: int, nf_to: int, mu_th: float, heavy_name: str | None = None) -> float:
    """
    Carry a running quark mass across a flavour threshold with continuous matching.

    The mass is returned unchanged, m_q^(nf_to)(mu_th) = m_q^(nf_from)(mu_th).
    The same rule is used for the bottom contribution (heavy_name='b'), so the
    three-mass formula can be evaluated continuously over the full mu range
    when desired.

    The threshold data and heavy_name are unused here. They exist so that an
    MSbar mass decoupling factor can be substituted without changing callers.
    """
    del nf_from, nf_to, mu_th, heavy_name  # intentionally unused in the continuous baseline
    return m_in


# -----------------------------------------------------------------------------
# Dense piecewise solutions for alpha_s(mu)
# -----------------------------------------------------------------------------
@dataclass
class AlphaSegment:
    """
    One piece of the piecewise alpha_s(mu) solution.

    Stores the scale interval covered by a single solve_ivp integration, the
    flavour number used in it, and the dense-output interpolant.
    """

    mu_lo: float  # lower end of the covered interval [GeV]
    mu_hi: float  # upper end of the covered interval [GeV]
    nf: int  # active flavours used in the beta function on this interval
    # solve_ivp dense output: takes t = ln(mu^2), returns [a] with a = alpha_s / (4 pi)
    dense: Callable[[float], np.ndarray]



def solve_alpha_segment(mu_start: float, mu_end: float, alpha_start: float, nf: int) -> Tuple[AlphaSegment, float]:
    """
    Integrate the nf-flavour alpha_s RGE from mu_start to mu_end.

    The ODE is solved for a = alpha_s / (4 pi) as a function of t = ln(mu^2)
    with DOP853 (rtol=1e-10, atol=1e-12, max_step=0.05 in t). The integration
    may run downward (mu_end < mu_start) as well as upward.

    Returns
    -------
    (AlphaSegment, float)
        The segment covering [min(mu_start, mu_end), max(mu_start, mu_end)],
        including its dense interpolant, and alpha_s at mu_end.

    Raises
    ------
    RuntimeError
        If solve_ivp reports failure.
    """

    def rhs(_t: float, y: np.ndarray) -> List[float]:
        # The beta function has no explicit t dependence.
        return [beta_a(y[0], nf)]

    result = solve_ivp(
        fun=rhs,
        t_span=(math.log(mu_start**2), math.log(mu_end**2)),
        y0=[alpha_start / (4.0 * PI)],
        method="DOP853",
        rtol=1e-10,
        atol=1e-12,
        dense_output=True,
        max_step=0.05,
    )
    if not result.success:
        raise RuntimeError(f"alpha_s segment solve failed between {mu_start} and {mu_end} GeV")

    segment = AlphaSegment(
        mu_lo=min(mu_start, mu_end),
        mu_hi=max(mu_start, mu_end),
        nf=nf,
        dense=result.sol,
    )
    return segment, float(result.y[0, -1]) * 4.0 * PI



def _alpha_segments_towards(mu_start: float, alpha_start: float, mu_target: float) -> List[AlphaSegment]:
    """Integrate alpha_s from mu_start to mu_target, ending a segment at every flavour threshold crossed."""
    going_up = mu_target > mu_start
    segments: List[AlphaSegment] = []
    alpha_curr = alpha_start
    mu_curr = mu_start
    for mu_th in thresholds_between(mu_start, mu_target):
        # Flavour number inside the window between mu_curr and mu_th: the
        # geometric midpoint lies strictly between two consecutive thresholds.
        nf_from = active_nf(math.sqrt(mu_curr * mu_th))
        # Flavour number just beyond the threshold, on the side we move into.
        nf_to = active_nf(mu_th * (1.0 + 1e-12) if going_up else mu_th * (1.0 - 1e-12))
        seg, alpha_at_th = solve_alpha_segment(mu_curr, mu_th, alpha_curr, nf_from)
        segments.append(seg)
        alpha_curr = continuous_match_alpha(alpha_at_th, nf_from=nf_from, nf_to=nf_to, mu_th=mu_th)
        mu_curr = mu_th
    if mu_curr != mu_target:
        seg, _ = solve_alpha_segment(mu_curr, mu_target, alpha_curr, active_nf(math.sqrt(mu_curr * mu_target)))
        segments.append(seg)
    return segments


def build_alpha_segments(mu_min: float, mu_max: float, mu_anchor: float, alpha_anchor: float) -> List[AlphaSegment]:
    """
    Solve the 4-loop alpha_s RGE piecewise so that it covers [mu_min, mu_max].

    Starting from alpha_s(mu_anchor) = alpha_anchor, the equation is integrated
    downward to mu_min (if mu_min < mu_anchor) and upward to mu_max (if
    mu_anchor < mu_max). Every flavour threshold crossed ends a segment, and
    alpha_s is carried across it with continuous_match_alpha().

    The flavour number of each segment comes from active_nf() rather than
    being hard-coded. mu_anchor may therefore sit in any flavour window; a
    fixed nf = 5 start would only be right for mu_b < mu_anchor < mu_t.

    Returns
    -------
    list of AlphaSegment
        Downward segments first (ordered away from the anchor), then the
        upward ones. alpha_of_mu() uses the first segment containing mu.
    """
    segments: List[AlphaSegment] = []
    if mu_min < mu_anchor:
        segments.extend(_alpha_segments_towards(mu_anchor, alpha_anchor, mu_min))
    if mu_anchor < mu_max:
        segments.extend(_alpha_segments_towards(mu_anchor, alpha_anchor, mu_max))
    return segments



def alpha_of_mu(mu: float, segments: List[AlphaSegment], rel_tol: float = 1e-12) -> float:
    """
    Evaluate alpha_s(mu) from the piecewise dense solutions.

    Parameters
    ----------
    mu : float
        Renormalization scale [GeV].
    segments : list of AlphaSegment
        The segments from build_alpha_segments(). The first segment whose
        interval [mu_lo, mu_hi], widened by rel_tol, contains mu is used.
    rel_tol : float, optional
        Relative widening of each segment interval.
        - It is relative because mass-ODE right-hand sides call this at
          exp(0.5 * ln(mu^2)), which can overshoot an endpoint by an ULP.
        - A fixed absolute 1e-14 is below half an ULP for mu >= 128 GeV, so
          it had no effect at mu_t or at the 1000 GeV endpoint.

    Returns
    -------
    float
        alpha_s(mu) = 4 pi a(mu).

    Raises
    ------
    ValueError
        If mu lies outside every segment.
    """
    for seg in segments:
        if seg.mu_lo * (1.0 - rel_tol) <= mu <= seg.mu_hi * (1.0 + rel_tol):
            t = math.log(mu**2)
            a = float(seg.dense(t)[0])
            return 4.0 * PI * a
    raise ValueError(f"mu={mu} GeV is outside the solved range")


# -----------------------------------------------------------------------------
# Dense piecewise solutions for m_q(mu)
# -----------------------------------------------------------------------------
@dataclass
class MassSegment:
    """
    One piece of the piecewise running-mass solution m_q(mu).

    Stores the scale interval covered by a single solve_ivp integration, the
    flavour number used for the anomalous dimension, and the dense interpolant.
    """

    mu_lo: float  # lower end of the covered interval [GeV]
    mu_hi: float  # upper end of the covered interval [GeV]
    nf: int  # active flavours used in gamma_m on this interval
    # solve_ivp dense output: takes t = ln(mu^2), returns [m] in the units of the input mass
    dense: Callable[[float], np.ndarray]



def solve_mass_segment(
    mu_start: float,
    mu_end: float,
    m_start: float,
    nf: int,
    alpha_fun: Callable[[float], float],
) -> Tuple[MassSegment, float]:
    """
    Integrate d m / d ln(mu^2) = -gamma_m(a) m from mu_start to mu_end.

    Parameters
    ----------
    mu_start, mu_end : float
        Start and end scales [GeV]; the integration may run in either direction.
    m_start : float
        Mass at mu_start.
    nf : int
        Flavour number used in the anomalous dimension.
    alpha_fun : callable
        alpha_s(mu); it is evaluated at every right-hand-side call.

    Returns
    -------
    (MassSegment, float)
        The segment with its dense interpolant in t = ln(mu^2), and m(mu_end).

    Raises
    ------
    RuntimeError
        If solve_ivp reports failure.
    """

    def dm_dt(t: float, y: np.ndarray) -> List[float]:
        scale = math.exp(0.5 * t)
        a_s = alpha_fun(scale) / (4.0 * PI)
        return [-gamma_m(a_s, nf) * y[0]]

    result = solve_ivp(
        fun=dm_dt,
        t_span=(math.log(mu_start**2), math.log(mu_end**2)),
        y0=[m_start],
        method="DOP853",
        rtol=1e-10,
        atol=1e-12,
        dense_output=True,
        max_step=0.05,
    )
    if not result.success:
        raise RuntimeError(f"mass segment solve failed between {mu_start} and {mu_end} GeV")

    segment = MassSegment(
        mu_lo=min(mu_start, mu_end),
        mu_hi=max(mu_start, mu_end),
        nf=nf,
        dense=result.sol,
    )
    return segment, float(result.y[0, -1])



def thresholds_between(mu_a: float, mu_b: float) -> List[float]:
    """
    List the flavour thresholds (MU_C, MU_B, MU_T) lying strictly between mu_a and mu_b.

    The thresholds come in the order in which they are met when moving from
    mu_a towards mu_b: ascending if mu_a < mu_b, descending otherwise.
    """
    lower, upper = sorted((mu_a, mu_b))
    crossed = [th for th in (MU_C, MU_B, MU_T) if lower < th < upper]
    if not mu_a < mu_b:
        crossed.reverse()
    return crossed



def build_mass_segments(
    mu_ref: float,
    m_ref: float,
    mu_min: float,
    mu_max: float,
    alpha_fun: Callable[[float], float],
    heavy_name: str | None = None,
) -> List[MassSegment]:
    """
    Evolve m_q from m_q(mu_ref) = m_ref down to mu_min and up to mu_max.

    Each direction is split at every flavour threshold it crosses.
    - The flavour number of a piece comes from active_nf() evaluated between
      its end points.
    - The mass is carried across each threshold with continuous_match_mass().

    The downward pieces are returned first, then the upward ones.
    """
    segments: List[MassSegment] = []

    for mu_target, going_up in ((mu_min, False), (mu_max, True)):
        mu_here = mu_ref
        m_here = m_ref
        for mu_th in thresholds_between(mu_ref, mu_target):
            # Probe points nudged to the correct side of the threshold.
            if going_up:
                probe_from = mu_th / (1.0 + 1e-12)
                probe_to = mu_th * (1.0 + 1e-12)
            else:
                probe_from = mu_th * (1.0 + 1e-12)
                probe_to = mu_th * (1.0 - 1e-12)
            nf_from = active_nf(math.sqrt(mu_here * probe_from))
            nf_to = active_nf(probe_to)
            seg, m_at_th = solve_mass_segment(mu_here, mu_th, m_here, nf_from, alpha_fun)
            segments.append(seg)
            m_here = continuous_match_mass(m_at_th, nf_from=nf_from, nf_to=nf_to, mu_th=mu_th, heavy_name=heavy_name)
            mu_here = mu_th

        # Final piece from the last threshold (or mu_ref) to the range end.
        needs_tail = mu_here < mu_target if going_up else mu_target < mu_here
        if needs_tail:
            nf = active_nf(math.sqrt(mu_here * mu_target))
            seg, _ = solve_mass_segment(mu_here, mu_target, m_here, nf, alpha_fun)
            segments.append(seg)

    return segments



def mass_of_mu(mu: float, segments: List[MassSegment], rel_tol: float = 1e-12) -> float:
    """
    Evaluate the running mass m_q(mu) from the piecewise dense solutions.

    Parameters
    ----------
    mu : float
        Renormalization scale [GeV].
    segments : list of MassSegment
        The segments from build_mass_segments(). The first segment whose
        interval [mu_lo, mu_hi], widened by rel_tol, contains mu is used.
    rel_tol : float, optional
        Relative widening of each segment interval. A fixed absolute 1e-14
        is below half an ULP for mu >= 128 GeV and so gives no slack there;
        a relative tolerance behaves uniformly over the whole range.

    Returns
    -------
    float
        m_q(mu), in the units of the reference mass used to build the segments.

    Raises
    ------
    ValueError
        If mu lies outside every segment.
    """
    for seg in segments:
        if seg.mu_lo * (1.0 - rel_tol) <= mu <= seg.mu_hi * (1.0 + rel_tol):
            t = math.log(mu**2)
            return float(seg.dense(t)[0])
    raise ValueError(f"mu={mu} GeV is outside the mass solution range")


# -----------------------------------------------------------------------------
# Koide-type quantities and the A_d, T_d, phi_d parametrization
# -----------------------------------------------------------------------------
def koide_k_from_x(x1: np.ndarray, x2: np.ndarray, x3: np.ndarray) -> np.ndarray:
    """Return K = (x1^2 + x2^2 + x3^2) / (x1 + x2 + x3)^2, elementwise."""
    total = x1 + x2 + x3
    sum_of_squares = x1**2 + x2**2 + x3**2
    return sum_of_squares / total**2



def koide_k(m1: np.ndarray, m2: np.ndarray, m3: np.ndarray) -> np.ndarray:
    """Return the Koide ratio K = (m1 + m2 + m3) / (sqrt(m1) + sqrt(m2) + sqrt(m3))^2."""
    roots = (np.sqrt(m) for m in (m1, m2, m3))
    return koide_k_from_x(*roots)



def koide_k_inv(m1: np.ndarray, m2: np.ndarray, m3: np.ndarray) -> np.ndarray:
    """Return the inverse-mass Koide ratio: K evaluated on 1/sqrt(m1), 1/sqrt(m2), 1/sqrt(m3)."""
    inverse_roots = (1.0 / np.sqrt(m) for m in (m1, m2, m3))
    return koide_k_from_x(*inverse_roots)



def koide_ad_t_phi_from_sqrt(x1: np.ndarray, x2: np.ndarray, x3: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Invert x_n = sqrt(m_dn) = A_d * [1 + 2 T_d cos(phi_d/3 + 2 pi n / 3)], n = 1, 2, 3.

    A_d and T_d follow from the first two moments.
    - Because sum_n cos(theta + 2 pi n / 3) = 0, A_d is the mean of the x_n.
    - With K = sum x_n^2 / (sum x_n)^2 one gets T_d = sqrt((3 K - 1) / 2).

    With theta = phi_d / 3 and u_n = (x_n / A_d - 1) / (2 T_d):

        cos(theta) = u_3,    sin(theta) = (u_2 - u_1) / sqrt(3),

    so theta = atan2(sin, cos) after projecting (cos, sin) onto the unit
    circle, and phi_d = 3 theta.

    Returns
    -------
    (A_d, T_d, phi_d)
    """
    a_d = (x1 + x2 + x3) / 3.0
    t_d = np.sqrt((3.0 * koide_k_from_x(x1, x2, x3) - 1.0) / 2.0)

    # Reduced coordinates u_n.
    two_t = 2.0 * t_d
    u1, u2, u3 = ((x / a_d - 1.0) / two_t for x in (x1, x2, x3))

    # Rescale onto the unit circle to absorb round-off before taking the angle.
    c = u3
    s = (u2 - u1) / np.sqrt(3.0)
    r = np.hypot(c, s)
    theta = np.arctan2(s / r, c / r)

    return a_d, t_d, 3.0 * theta



def masses_from_formula(a_d: np.ndarray, t_d: np.ndarray, phi_d: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Reconstruct (m_d1, m_d2, m_d3) from the Brannen-type mass formula.

        sqrt(m_dn) = A_d * [1 + 2 T_d cos(phi_d/3 + 2 pi n / 3)],  n = 1, 2, 3.

    Returns a 3-tuple of the squared right-hand sides.
    """
    theta = phi_d / 3.0
    sqrt_masses = [
        a_d * (1.0 + 2.0 * t_d * np.cos(theta + (2.0 * n) * PI / 3.0))
        for n in (1, 2, 3)
    ]
    return tuple(x**2 for x in sqrt_masses)


# -----------------------------------------------------------------------------
# Internal data export
# -----------------------------------------------------------------------------
def save_internal_data(
    mu_grid: np.ndarray,
    alpha_values: np.ndarray,
    m_d1: np.ndarray,
    m_d2: np.ndarray,
    m_d3: np.ndarray,
    sqrt_d1: np.ndarray,
    sqrt_d2: np.ndarray,
    sqrt_d3: np.ndarray,
    invsqrt_d1: np.ndarray,
    invsqrt_d2: np.ndarray,
    invsqrt_d3: np.ndarray,
    k_values: np.ndarray,
    k_inv_values: np.ndarray,
    a_d_values: np.ndarray,
    t_d_values: np.ndarray,
    phi_d_values: np.ndarray,
    m_d1_formula: np.ndarray,
    m_d2_formula: np.ndarray,
    m_d3_formula: np.ndarray,
    delta_m_d1: np.ndarray,
    delta_m_d2: np.ndarray,
    delta_m_d3: np.ndarray,
    rel_pct_d1: np.ndarray,
    rel_pct_d2: np.ndarray,
    rel_pct_d3: np.ndarray,
    abs_rel_pct_d1: np.ndarray,
    abs_rel_pct_d2: np.ndarray,
    abs_rel_pct_d3: np.ndarray,
) -> None:
    """
    Write every per-mu quantity to OUTPUT_CSV and OUTPUT_NPZ.

    A single ordered (name, array) table drives both outputs, so the CSV
    header and the NPZ keys always agree.
    - CSV: one header row, then one row per grid point, each value formatted
      as ``.16e``.
    - NPZ: each column stored as a named array.
    """
    columns = (
        ("mu_GeV", mu_grid),
        ("alpha_s", alpha_values),
        ("m_d1_MeV", m_d1),
        ("m_d2_MeV", m_d2),
        ("m_d3_MeV", m_d3),
        ("sqrt_m_d1_MeV_half", sqrt_d1),
        ("sqrt_m_d2_MeV_half", sqrt_d2),
        ("sqrt_m_d3_MeV_half", sqrt_d3),
        ("inv_sqrt_m_d1_MeV_minus_half", invsqrt_d1),
        ("inv_sqrt_m_d2_MeV_minus_half", invsqrt_d2),
        ("inv_sqrt_m_d3_MeV_minus_half", invsqrt_d3),
        ("K", k_values),
        ("K_inv", k_inv_values),
        ("A_d_MeV_half", a_d_values),
        ("T_d", t_d_values),
        ("phi_d", phi_d_values),
        ("m_d1_formula_MeV", m_d1_formula),
        ("m_d2_formula_MeV", m_d2_formula),
        ("m_d3_formula_MeV", m_d3_formula),
        ("delta_m_d1_MeV", delta_m_d1),
        ("delta_m_d2_MeV", delta_m_d2),
        ("delta_m_d3_MeV", delta_m_d3),
        ("delta_m_d1_percent", rel_pct_d1),
        ("delta_m_d2_percent", rel_pct_d2),
        ("delta_m_d3_percent", rel_pct_d3),
        ("abs_delta_m_d1_percent", abs_rel_pct_d1),
        ("abs_delta_m_d2_percent", abs_rel_pct_d2),
        ("abs_delta_m_d3_percent", abs_rel_pct_d3),
    )
    header = [name for name, _ in columns]
    data = [values for _, values in columns]

    with open(OUTPUT_CSV, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        for row in zip(*data):
            writer.writerow([f"{x:.16e}" for x in row])

    np.savez(OUTPUT_NPZ, **dict(columns))


# -----------------------------------------------------------------------------
# mu-grid construction
# -----------------------------------------------------------------------------
def build_mu_grid() -> np.ndarray:
    """
    Build the plotting/export grid, guaranteeing the reference and threshold scales.

    A pure logspace grid between MU_MIN and MU_MAX does not in general hit
    scales such as mu = 4.183 GeV exactly, yet that value is required in the
    CSV output. The exact reference/threshold scales are therefore merged into
    the N_GRID-point log grid. np.unique returns the sorted union with exact
    duplicates removed.
    """
    exact_scales = sorted({MU_D1_REF, MU_D2_REF, MU_D3_REF, MU_C, MU_B, MU_T})
    log_grid = np.logspace(math.log10(MU_MIN), math.log10(MU_MAX), N_GRID)
    return np.unique(np.concatenate([log_grid, np.array(exact_scales, dtype=float)]))



def finite_positive_floor(*arrays: np.ndarray) -> float:
    """
    Return a positive lower bound for log-scale plotting of the given arrays.

    The bound is half the smallest finite, strictly positive entry across all
    arrays, but never below 1e-30. If no such entry exists, 1e-30 is returned.
    """
    pool = np.concatenate([arr[np.isfinite(arr) & (arr > 0.0)] for arr in arrays])
    if not pool.size:
        return 1e-30
    return max(0.5 * float(pool.min()), 1e-30)



def add_threshold_lines(ax: plt.Axes) -> None:
    """Draw a dashed, semi-transparent gray vertical line at mu_c, mu_b and mu_t on ax."""
    line_style = dict(ls="--", lw=1, color="gray", alpha=0.6)
    for mu_th in (MU_C, MU_B, MU_T):
        ax.axvline(mu_th, **line_style)


# -----------------------------------------------------------------------------
# Main execution
# -----------------------------------------------------------------------------
def main() -> None:
    """
    Run the full workflow: RG evolution, parametrization, export, plot and diagnostics.

    1. Build the mu grid and the piecewise 4-loop alpha_s(mu) solution anchored
       at alpha_s(M_Z) = ALPHA_S_MZ.
    2. Evolve m_d1, m_d2 and m_d3 from their reference scales over
       [MU_MIN, MU_MAX] and sample them on the grid.
    3. If INCLUDE_HEAVY_BELOW_OWN_THRESHOLD is False, replace m_d3 by NaN for
       mu < MU_B; every three-mass quantity then becomes NaN there.
    4. Compute K, K_inv, A_d, T_d and phi_d, reconstruct the masses from the
       formula, and form signed, relative and absolute-relative differences.
    5. Optionally save the CSV/NPZ data. Then draw the two-panel figure,
       optionally save it, print diagnostics, and call plt.show().
    """
    mu_grid = build_mu_grid()

    alpha_segments = build_alpha_segments(
        mu_min=MU_MIN,
        mu_max=MU_MAX,
        mu_anchor=MZ_GEV,
        alpha_anchor=ALPHA_S_MZ,
    )

    def alpha_fun(mu: float) -> float:
        # alpha_s(mu) from the piecewise solution (a named function instead
        # of an assigned lambda, per PEP 8).
        return alpha_of_mu(mu, alpha_segments)

    d1_segments = build_mass_segments(MU_D1_REF, M_D1_REF, MU_MIN, MU_MAX, alpha_fun, heavy_name=None)
    d2_segments = build_mass_segments(MU_D2_REF, M_D2_REF, MU_MIN, MU_MAX, alpha_fun, heavy_name=None)
    d3_segments = build_mass_segments(MU_D3_REF, M_D3_REF, MU_MIN, MU_MAX, alpha_fun, heavy_name="b")

    alpha_values = np.array([alpha_fun(mu) for mu in mu_grid])
    m_d1 = np.array([mass_of_mu(mu, d1_segments) for mu in mu_grid])
    m_d2 = np.array([mass_of_mu(mu, d2_segments) for mu in mu_grid])
    m_d3 = np.array([mass_of_mu(mu, d3_segments) for mu in mu_grid])

    if not INCLUDE_HEAVY_BELOW_OWN_THRESHOLD:
        # Exclude the bottom mass below its own threshold; NaN then
        # propagates into every quantity built from all three masses.
        mask = mu_grid < MU_B
        m_d3 = np.where(mask, np.nan, m_d3)

    sqrt_d1 = np.sqrt(m_d1)
    sqrt_d2 = np.sqrt(m_d2)
    sqrt_d3 = np.sqrt(m_d3)
    invsqrt_d1 = 1.0 / sqrt_d1
    invsqrt_d2 = 1.0 / sqrt_d2
    invsqrt_d3 = 1.0 / sqrt_d3

    k_values = koide_k(m_d1, m_d2, m_d3)
    k_inv_values = koide_k_inv(m_d1, m_d2, m_d3)
    a_d_values, t_d_values, phi_d_values = koide_ad_t_phi_from_sqrt(sqrt_d1, sqrt_d2, sqrt_d3)

    m_d1_formula, m_d2_formula, m_d3_formula = masses_from_formula(a_d_values, t_d_values, phi_d_values)

    delta_m_d1 = m_d1_formula - m_d1
    delta_m_d2 = m_d2_formula - m_d2
    delta_m_d3 = m_d3_formula - m_d3

    rel_pct_d1 = 100.0 * delta_m_d1 / m_d1
    rel_pct_d2 = 100.0 * delta_m_d2 / m_d2
    rel_pct_d3 = 100.0 * delta_m_d3 / m_d3

    abs_rel_pct_d1 = np.abs(rel_pct_d1)
    abs_rel_pct_d2 = np.abs(rel_pct_d2)
    abs_rel_pct_d3 = np.abs(rel_pct_d3)

    if SAVE_INTERNAL_DATA:
        save_internal_data(
            mu_grid,
            alpha_values,
            m_d1,
            m_d2,
            m_d3,
            sqrt_d1,
            sqrt_d2,
            sqrt_d3,
            invsqrt_d1,
            invsqrt_d2,
            invsqrt_d3,
            k_values,
            k_inv_values,
            a_d_values,
            t_d_values,
            phi_d_values,
            m_d1_formula,
            m_d2_formula,
            m_d3_formula,
            delta_m_d1,
            delta_m_d2,
            delta_m_d3,
            rel_pct_d1,
            rel_pct_d2,
            rel_pct_d3,
            abs_rel_pct_d1,
            abs_rel_pct_d2,
            abs_rel_pct_d3,
        )
        print(f"Saved internal data to: {OUTPUT_CSV}")
        print(f"Saved internal data to: {OUTPUT_NPZ}")

    fig, (ax1, ax2) = plt.subplots(
        2,
        1,
        figsize=(9, 8),
        sharex=True,
        gridspec_kw={"height_ratios": [3.0, 1.8], "hspace": 0.08},
    )

    # Color-blind-safe palette (Okabe-Ito family) chosen to avoid the
    # red/green confusion common in deuteranopia/protanopia.
    # We also separate the series by line style so the curves remain
    # distinguishable in grayscale printouts.
    color_d1 = "#0072B2"  # blue
    color_d2 = "#D55E00"  # vermillion
    color_d3 = "#CC79A7"  # reddish purple
    ls_d1 = "-"
    ls_d2 = "--"
    ls_d3 = "-."

    # Top panel: formula-reconstructed masses (evaluated with the fit parameters).
    ax1.plot(mu_grid, m_d1_formula, lw=2.2, color=color_d1, ls=ls_d1, label=r"$m_{d1}^\mathrm{fit}(\mu)$")
    ax1.plot(mu_grid, m_d2_formula, lw=2.2, color=color_d2, ls=ls_d2, label=r"$m_{d2}^\mathrm{fit}(\mu)$")
    ax1.plot(mu_grid, m_d3_formula, lw=2.2, color=color_d3, ls=ls_d3, label=r"$m_{d3}^\mathrm{fit}(\mu)$")
    add_threshold_lines(ax1)
    ax1.set_xscale("log")
    ax1.set_yscale("log")
    ax1.set_xlim(MU_MIN, MU_MAX)
    ax1.set_ylabel(r"Fit mass $m_{dn}^\mathrm{fit}(\mu)$ [MeV]")
    ax1.set_title(r"Down-type masses reconstructed from the fit in $\overline{\mathrm{MS}}$ QCD (4-loop, continuous matching)")
    ax1.grid(True, which="both", ls=":", alpha=0.5)
    ax1.legend(loc="best")

    # Bottom panel: recommended comparison metric = absolute relative difference [%].
    # Delta m^fit_{dn}(mu) = m^fit_{dn}(mu) - m_{dn}(mu), where m_{dn}(mu) is the
    # PDG-2024 running mass obtained by 4-loop MS-bar evolution from the quoted inputs.
    # One shared positive floor (computed once) keeps exact zeros drawable on the
    # log axis; NaN entries stay NaN under np.maximum and are simply not drawn.
    abs_rel_floor = finite_positive_floor(abs_rel_pct_d1, abs_rel_pct_d2, abs_rel_pct_d3)
    plot_abs_rel_d1 = np.maximum(abs_rel_pct_d1, abs_rel_floor)
    plot_abs_rel_d2 = np.maximum(abs_rel_pct_d2, abs_rel_floor)
    plot_abs_rel_d3 = np.maximum(abs_rel_pct_d3, abs_rel_floor)

    ax2.plot(mu_grid, plot_abs_rel_d1, lw=2.2, color=color_d1, ls=ls_d1, label=r"$|\Delta m_{d1}^\mathrm{fit}|/m_{d1}$")
    ax2.plot(mu_grid, plot_abs_rel_d2, lw=2.2, color=color_d2, ls=ls_d2, label=r"$|\Delta m_{d2}^\mathrm{fit}|/m_{d2}$")
    ax2.plot(mu_grid, plot_abs_rel_d3, lw=2.2, color=color_d3, ls=ls_d3, label=r"$|\Delta m_{d3}^\mathrm{fit}|/m_{d3}$")
    add_threshold_lines(ax2)
    ax2.set_xscale("log")
    ax2.set_yscale("log")
    ax2.set_xlim(MU_MIN, MU_MAX)
    ax2.set_xlabel(r"Energy scale $\mu$ [GeV]")
    ax2.set_ylabel(r"$|\Delta m^\mathrm{fit}|/m$ [%]")
    ax2.grid(True, which="both", ls=":", alpha=0.5)
    ax2.legend(loc="best")

    # Threshold labels on the upper panel only. They are skipped when no finite
    # m_d3 value exists, where np.nanmax on an empty selection would raise.
    finite_d3 = m_d3_formula[np.isfinite(m_d3_formula)]
    if finite_d3.size > 0:
        y_top = np.nanmax(finite_d3)
        ax1.text(MU_C * 1.02, y_top / 1.6, r"$\mu_c$", va="top")
        ax1.text(MU_B * 1.02, y_top / 1.6, r"$\mu_b$", va="top")
        ax1.text(MU_T * 1.02, y_top / 1.6, r"$\mu_t$", va="top")

    fig.tight_layout()

    if SAVEFIG:
        fig.savefig(OUTPUT_FIG, dpi=180, bbox_inches="tight")
        print(f"Saved figure to: {OUTPUT_FIG}")

    # Diagnostics.
    print("\nDiagnostic alpha_s values:")
    for mu in [1.0, MU_C, 2.0, MU_B, MZ_GEV, MU_T, 1000.0]:
        print(f"  alpha_s({mu:8.3f} GeV) = {alpha_fun(mu):.6f}")

    print("\nReference-point check:")
    print(f"  m_d1({MU_D1_REF:.3f} GeV) = {mass_of_mu(MU_D1_REF, d1_segments):.6f} MeV")
    print(f"  m_d2({MU_D2_REF:.3f} GeV) = {mass_of_mu(MU_D2_REF, d2_segments):.6f} MeV")
    print(f"  m_d3({MU_D3_REF:.3f} GeV) = {mass_of_mu(MU_D3_REF, d3_segments):.6f} MeV")

    print("\nSample A_d(mu), T_d(mu), phi_d(mu) values:")
    for mu in [1.0, 2.0, MU_B, 10.0, MZ_GEV, MU_T, 1000.0]:
        # Report the grid point nearest to the requested scale.
        i = int(np.argmin(np.abs(mu_grid - mu)))
        print(
            f"  A_d({mu_grid[i]:8.3f} GeV) = {a_d_values[i]:.10f}, "
            f"T_d({mu_grid[i]:8.3f} GeV) = {t_d_values[i]:.10f}, "
            f"phi_d({mu_grid[i]:8.3f} GeV) = {phi_d_values[i]:.10f}"
        )

    print("\nMaximum absolute relative reconstruction differences [%]:")
    print(f"  d1: {np.nanmax(abs_rel_pct_d1):.16e}")
    print(f"  d2: {np.nanmax(abs_rel_pct_d2):.16e}")
    print(f"  d3: {np.nanmax(abs_rel_pct_d3):.16e}")

    plt.show()


if __name__ == "__main__":
    # Run the full workflow only when executed as a script, not on import.
    main()
