#!/usr/bin/env python3
"""
==========================================================================
Generate a two-dimensional log-log plot based on the down-type quark mass
formula
==========================================================================

Problem setup
-------------
  Horizontal axis: mu_d3 [GeV] (evaluation scale of the bottom quark)
  Vertical axis:   mu_d1 [GeV], mu_d2 [GeV]

  The script computes the predicted masses from the down-type quark mass
  formula and then determines the running-mass scales that reproduce those
  masses by root finding.

Formula and algorithm
---------------------
  sqrt(m_dn) = A(mu_d3) * (1 + sqrt(2 + sqrt(2 - sqrt(2))/2)
                             * cos(1/9 + 2 n pi / 3))  [MeV^(1/2)]
  n = 1: down, n = 2: strange, n = 3: bottom

  A(mu_d3) is determined from m_b^(MS-bar)(mu_d3):
    sqrt(m_b(mu_d3)) = A(mu_d3) * k_3
  therefore
    A(mu_d3) = sqrt(m_b(mu_d3) [MeV]) / k_3

  predicted mass: m_dn^pred = (A(mu_d3) * k_n)^2  [MeV]

  mu_d1 and mu_d2 are defined by
    m_d^(MS-bar)(mu_d1) = m_d1^pred(mu_d3)
    m_s^(MS-bar)(mu_d2) = m_d2^pred(mu_d3)

Running-mass and matching assumptions
-------------------------------------
  - MS-bar scheme, 4-loop QCD running
  - alpha_s(m_Z^2) = 0.1180  (PDG 2024)
  - m_Z = 91.1876 GeV        (PDG 2024)
  - Quark thresholds (MS-bar masses):
      m_c(m_c) = 1.273 GeV   (PDG 2024, charm)
      m_b(m_b) = 4.183 GeV   (PDG 2024, bottom)
      m_t(m_t) = 162.5 GeV   (PDG 2024, top)
  - Threshold treatment: n_f -> n_f - 1 (or the reverse) at mu = m_q(m_q)
    alpha_s matching: continuous at mu = m_q(m_q)
    mass matching:    continuous at mu = m_q(m_q)
    Finite higher-order matching corrections from decoupling constants are
    not implemented in this script.
  - PDG 2024 input values:
      m_d = 4.7 MeV   at mu = 2.0 GeV
      m_s = 93.5 MeV  at mu = 2.0 GeV
      m_b = 4183 MeV  at mu = 4.183 GeV (= m_b(m_b))

How to run
----------
  $ python down_quark_mass.py

Required libraries
------------------
  numpy, scipy, matplotlib
==========================================================================
"""

import numpy as np
from scipy.integrate import solve_ivp
from scipy.optimize import brentq
import matplotlib.pyplot as plt
import warnings

# =========================================================================
# Constants
# =========================================================================
MZ = 91.1876       # Z boson mass [GeV] (PDG 2024); default starting scale for the alpha_s running
ALPHA_S_MZ = 0.1180  # alpha_s(m_Z^2) (PDG 2024); note: alpha_s itself, not alpha_s/pi

# Quark thresholds: MS-bar masses m_q(m_q) [GeV]  (PDG 2024)
# The number of active flavours n_f is switched at exactly these scales.
MC_MC = 1.273      # charm:  m_c(m_c) = 1.273 GeV  (n_f: 3 <-> 4)
MB_MB = 4.183      # bottom: m_b(m_b) = 4.183 GeV  (n_f: 4 <-> 5)
MT_MT = 162.5      # top:    m_t(m_t) = 162.5 GeV  (n_f: 5 <-> 6)

# Input quark masses [MeV] at given scales [GeV]  (PDG 2024 central values)
# Masses are carried in MeV throughout, scales in GeV.
MD_INPUT = 4.7       # m_d at mu = MU_D_INPUT [MeV]
MS_INPUT = 93.5      # m_s at mu = MU_D_INPUT [MeV]
MB_INPUT = 4183.0    # m_b at mu = MB_MB [MeV] (same number as MB_MB, in MeV)
MU_D_INPUT = 2.0     # evaluation scale for m_d, m_s [GeV]

# =========================================================================
# QCD beta-function coefficients (MS-bar scheme)
# Normalization, with a_s = alpha_s/pi and t = ln(mu^2):
#   d a_s / dt = -beta_0 a_s^2 - beta_1 a_s^3 - beta_2 a_s^4 - beta_3 a_s^5
#              = - sum_{i=0..3} beta_i a_s^{i+2}
# =========================================================================
def beta_coeffs(nf):
    """
    Return the MS-bar QCD beta-function coefficients up to four loops.

    The coefficients are normalised for the expansion parameter
    a_s = alpha_s / pi:
      d a_s / d ln mu^2 = - beta_0 a_s^2 - beta_1 a_s^3 - beta_2 a_s^4 - beta_3 a_s^5

    Parameters
    ----------
    nf : int or float
        Number of active quark flavours.

    Returns
    -------
    tuple : (beta_0, beta_1, beta_2, beta_3)
    """
    zeta3 = 1.2020569031595942  # Riemann zeta(3)

    one_loop = (11.0 - 2.0 * nf / 3.0) / 4.0
    two_loop = (102.0 - 38.0 * nf / 3.0) / 16.0
    three_loop = (2857.0 / 2.0 - 5033.0 * nf / 18.0 + 325.0 * nf**2 / 54.0) / 64.0

    # 4-loop coefficient (van Ritbergen, Vermaseren, Larin), assembled
    # power by power in nf.
    nf0_part = 149753.0 / 6.0 + 3564.0 * zeta3
    nf1_part = (1078361.0 / 162.0 + 6508.0 * zeta3 / 27.0) * nf
    nf2_part = (50065.0 / 162.0 + 6472.0 * zeta3 / 81.0) * nf**2
    nf3_part = 1093.0 * nf**3 / 729.0
    four_loop = (nf0_part - nf1_part + nf2_part + nf3_part) / 256.0

    return one_loop, two_loop, three_loop, four_loop


# =========================================================================
# Quark mass anomalous dimension coefficients (MS-bar)
# gamma_m(alpha_s) = gamma_0 a_s + gamma_1 a_s^2 + gamma_2 a_s^3 + gamma_3 a_s^4
# d ln m / d ln mu^2 = -gamma_m
# =========================================================================
def gamma_m_coeffs(nf):
    """
    Return the MS-bar quark-mass anomalous-dimension coefficients up to
    four loops.

    The coefficients are normalised for a_s = alpha_s / pi:
      d ln m / d ln mu^2 = -gamma_0 a_s - gamma_1 a_s^2 - gamma_2 a_s^3 - gamma_3 a_s^4

    References: Chetyrkin (1997); Vermaseren, Larin, van Ritbergen (1997);
    Baikov et al.

    Parameters
    ----------
    nf : int or float
        Number of active quark flavours.

    Returns
    -------
    tuple : (gamma_0, gamma_1, gamma_2, gamma_3)
    """
    # Riemann zeta values entering the 3- and 4-loop terms.
    zeta3 = 1.2020569031595942
    zeta4 = np.pi**4 / 90.0
    zeta5 = 1.0369277551433699

    one_loop = 1.0
    two_loop = (202.0 / 3.0 - 20.0 * nf / 9.0) / 16.0
    three_loop = (1249.0 - (2216.0 / 27.0 + 160.0 * zeta3 / 3.0) * nf
                  - 140.0 * nf**2 / 81.0) / 64.0

    # 4-loop coefficient, assembled power by power in nf.
    nf0_part = (4603055.0 / 162.0 + 135680.0 * zeta3 / 27.0
                - 8800.0 * zeta5)
    nf1_part = (91723.0 / 27.0 + 34192.0 * zeta3 / 9.0
                - 880.0 * zeta4 - 18400.0 * zeta5 / 9.0) * nf
    nf2_part = (5242.0 / 243.0 + 800.0 * zeta3 / 9.0
                - 160.0 * zeta4 / 3.0) * nf**2
    nf3_part = (-332.0 / 243.0 + 64.0 * zeta3 / 27.0) * nf**3
    four_loop = (nf0_part - nf1_part + nf2_part + nf3_part) / 256.0

    return one_loop, two_loop, three_loop, four_loop


# =========================================================================
# alpha_s threshold matching at mu = m_q(m_q)
# =========================================================================
def alpha_s_matching_up(a_s_low, nf_low):
    """
    Threshold matching of a_s = alpha_s/pi when going from n_f = nf_low
    to n_f = nf_low + 1 at mu = m_q(m_q).

    Finite decoupling constants are not implemented in this script, so the
    coupling is matched continuously: the input value is handed back as is.
    ``nf_low`` is accepted but not used.
    """
    # Continuous matching at mu = m_q(m_q): no finite correction is applied.
    a_s_matched = a_s_low
    return a_s_matched


def alpha_s_matching_down(a_s_high, nf_high):
    """
    Threshold matching of a_s = alpha_s/pi when going from n_f = nf_high
    to n_f = nf_high - 1 at mu = m_q(m_q).

    Finite decoupling constants are not implemented in this script, so the
    coupling is matched continuously: the input value is handed back as is.
    ``nf_high`` is accepted but not used.
    """
    # Continuous matching at mu = m_q(m_q): no finite correction is applied.
    a_s_matched = a_s_high
    return a_s_matched


# =========================================================================
# RGE for a_s = alpha_s/pi as function of t = ln(mu^2/mu_0^2)
# =========================================================================
def das_dt(t, a_s, nf):
    """
    Right-hand side of the 4-loop RGE for a_s = alpha_s/pi.

    Evaluates d(a_s)/dt with t = ln(mu^2/mu_0^2). The equation is
    autonomous, so ``t`` is accepted only to fit the ODE-solver call
    signature and does not enter the result.

    Parameters
    ----------
    t : float
        Evolution variable (unused).
    a_s : float
        Current value of alpha_s/pi.
    nf : int
        Number of active flavours.

    Returns
    -------
    float : d(a_s)/dt
    """
    beta0, beta1, beta2, beta3 = beta_coeffs(nf)
    series = beta0 * a_s**2 + beta1 * a_s**3 + beta2 * a_s**4 + beta3 * a_s**5
    return -series


def run_alpha_s_segment(a_s_start, mu_start, mu_end, nf, rtol=1e-12):
    """
    Evolve a_s = alpha_s/pi from mu_start to mu_end at fixed n_f (4-loop).

    Parameters
    ----------
    a_s_start : float
        alpha_s/pi at mu_start.
    mu_start, mu_end : float
        Initial and final scales [GeV].
    nf : int
        Number of active flavours, held fixed over the whole segment.
    rtol : float
        Relative tolerance handed to the ODE solver.

    Returns
    -------
    float : alpha_s/pi at mu_end.

    Raises
    ------
    RuntimeError
        If the ODE solver reports failure.
    """
    # A segment of (numerically) zero length needs no integration.
    relative_gap = abs(mu_end - mu_start) / mu_start
    if relative_gap < 1e-14:
        return a_s_start

    # Integrate in t = ln(mu^2/mu_start^2), which starts at 0.
    t_initial = 0.0
    t_final = np.log(mu_end**2 / mu_start**2)

    def beta_rhs(t, state):
        return das_dt(t, state[0], nf)

    # max_step forces at least ten steps across the segment.
    result = solve_ivp(beta_rhs,
                       [t_initial, t_final], [a_s_start],
                       method='RK45', rtol=rtol, atol=1e-15,
                       max_step=abs(t_final - t_initial) / 10)
    if not result.success:
        raise RuntimeError(f"ODE solver failed for alpha_s running: {result.message}")
    return result.y[0, -1]


# =========================================================================
# Complete alpha_s running from m_Z to arbitrary mu, with threshold matching
# =========================================================================
def alpha_s_at_mu(mu, a_s_mz=ALPHA_S_MZ / np.pi, mu0=MZ):
    """
    Compute a_s = alpha_s/pi at scale mu [GeV], starting from a known value
    at the reference scale mu0 (by default alpha_s(m_Z)/pi at m_Z).

    Includes 4-loop running with threshold crossings at m_c, m_b, m_t.
    Finite decoupling constants are not included; alpha_s is matched
    continuously at mu = m_q(m_q).

    Threshold structure:
      mu < m_c:           n_f = 3
      m_c <= mu < m_b:    n_f = 4
      m_b <= mu < m_t:    n_f = 5
      mu >= m_t:          n_f = 6

    The path from mu0 to mu is split at every threshold lying strictly
    between the two scales, so the reference scale mu0 may sit in any
    flavour region (the previous implementation silently assumed
    m_b <= mu0 < m_t). For the default mu0 = m_Z the sequence of segments
    is unchanged.

    Parameters
    ----------
    mu : float
        Target scale [GeV].
    a_s_mz : float
        alpha_s/pi at the reference scale mu0.
    mu0 : float
        Reference scale [GeV].

    Returns
    -------
    float : alpha_s/pi at mu.

    Raises
    ------
    RuntimeError
        Propagated from run_alpha_s_segment if the ODE solver fails.
    """
    def active_flavours(scale):
        """Number of active flavours at the given scale."""
        if scale >= MT_MT:
            return 6
        if scale >= MB_MB:
            return 5
        if scale >= MC_MC:
            return 4
        return 3

    thresholds = sorted([MC_MC, MB_MB, MT_MT])

    # Thresholds strictly inside the interval, listed in the order in which
    # they are met along the path mu0 -> mu.
    if mu >= mu0:
        crossed = [thr for thr in thresholds if mu0 < thr < mu]
        match = alpha_s_matching_up
    else:
        crossed = [thr for thr in reversed(thresholds) if mu < thr < mu0]
        match = alpha_s_matching_down

    waypoints = [mu0] + crossed + [mu]
    last_segment = len(waypoints) - 2

    a_s_current = a_s_mz
    for index in range(len(waypoints) - 1):
        mu_from = waypoints[index]
        mu_to = waypoints[index + 1]
        # The segment midpoint fixes n_f without boundary ambiguity.
        nf = active_flavours((mu_from + mu_to) / 2.0)
        a_s_current = run_alpha_s_segment(a_s_current, mu_from, mu_to, nf=nf)
        # Every segment except the last one ends on a threshold: match there.
        if index < last_segment:
            a_s_current = match(a_s_current, nf)

    return a_s_current


# =========================================================================
# Quark mass running (MS-bar, 4-loop)
# =========================================================================
def run_mass_segment(m_start, a_s_start, mu_start, mu_end, nf, rtol=1e-12):
    """
    Evolve a quark mass and a_s = alpha_s/pi together from mu_start to
    mu_end at fixed n_f (4-loop).

    Coupled system, with t = ln(mu^2/mu_0^2):
      d(a_s)/dt  = -beta_0 a_s^2 - beta_1 a_s^3 - beta_2 a_s^4 - beta_3 a_s^5
      d(ln m)/dt = -gamma_0 a_s - gamma_1 a_s^2 - gamma_2 a_s^3 - gamma_3 a_s^4

    Parameters
    ----------
    m_start : float
        Mass at mu_start; its logarithm is taken, and the result is returned
        in the same unit.
    a_s_start : float
        alpha_s/pi at mu_start.
    mu_start, mu_end : float
        Initial and final scales [GeV].
    nf : int
        Number of active flavours, held fixed over the whole segment.
    rtol : float
        Relative tolerance handed to the ODE solver.

    Returns
    -------
    tuple : (m_end, a_s_end) at mu_end.

    Raises
    ------
    RuntimeError
        If the ODE solver reports failure.
    """
    # A segment of (numerically) zero length needs no integration.
    relative_gap = abs(mu_end - mu_start) / max(mu_start, 1e-30)
    if relative_gap < 1e-14:
        return m_start, a_s_start

    beta = beta_coeffs(nf)
    gamma = gamma_m_coeffs(nf)

    def coupled_rhs(t, state):
        # state = [a_s, ln m]; the right-hand side depends on a_s only.
        coupling = state[0]
        d_coupling = -(beta[0] * coupling**2 + beta[1] * coupling**3
                       + beta[2] * coupling**4 + beta[3] * coupling**5)
        d_log_mass = -(gamma[0] * coupling + gamma[1] * coupling**2
                       + gamma[2] * coupling**3 + gamma[3] * coupling**4)
        return [d_coupling, d_log_mass]

    # Integrate in t = ln(mu^2/mu_start^2), which starts at 0.
    t_initial = 0.0
    t_final = np.log(mu_end**2 / mu_start**2)
    initial_state = [a_s_start, np.log(m_start)]

    # max_step forces at least ten steps across the segment.
    result = solve_ivp(coupled_rhs, [t_initial, t_final],
                       initial_state,
                       method='RK45', rtol=rtol, atol=1e-15,
                       max_step=abs(t_final - t_initial) / 10)
    if not result.success:
        raise RuntimeError(f"ODE solver failed for mass running: {result.message}")

    final_coupling = result.y[0, -1]
    final_mass = np.exp(result.y[1, -1])
    return final_mass, final_coupling


def running_mass(m_ref, mu_ref, mu_target):
    """
    Compute the MS-bar running mass m(mu_target) from a known m(mu_ref).

    The coupling at mu_ref is obtained from alpha_s_at_mu; mass and coupling
    are then evolved together, segment by segment, with n_f switched at
    every heavy-quark threshold lying strictly between the two scales.
    Mass and coupling are both treated as continuous across thresholds.

    Parameters
    ----------
    m_ref : float
        Quark mass at mu_ref [MeV]
    mu_ref : float
        Reference scale [GeV]
    mu_target : float
        Target scale [GeV]

    Returns
    -------
    float : m(mu_target) [MeV]
    """
    def flavours_at(scale):
        """Number of active flavours at the given scale."""
        if scale >= MT_MT:
            return 6
        if scale >= MB_MB:
            return 5
        if scale >= MC_MC:
            return 4
        return 3

    coupling = alpha_s_at_mu(mu_ref)
    thresholds = sorted([MC_MC, MB_MB, MT_MT])

    # Thresholds strictly inside the interval, in the order they are met.
    if mu_target > mu_ref:
        crossed = [thr for thr in thresholds if mu_ref < thr < mu_target]
    else:
        crossed = [thr for thr in reversed(thresholds)
                   if mu_target < thr < mu_ref]
    waypoints = [mu_ref, *crossed, mu_target]

    mass = m_ref
    for mu_from, mu_to in zip(waypoints, waypoints[1:]):
        # The segment midpoint fixes n_f without boundary ambiguity.
        nf = flavours_at((mu_from + mu_to) / 2.0)
        mass, coupling = run_mass_segment(mass, coupling, mu_from, mu_to, nf)

    return mass


# =========================================================================
# Coefficients of the down-type quark mass formula
# =========================================================================
def mass_formula_k(n):
    """
    Evaluate the bracket coefficient k_n of the mass formula:
      k_n = 1 + sqrt(2 + sqrt(2 - sqrt(2))/2) * cos(1/9 + 2 n pi / 3)

    Parameters
    ----------
    n : int
        Generation index: 1 (down), 2 (strange), 3 (bottom).

    Returns
    -------
    float : k_n (the angle is taken in radians).
    """
    nested_root = np.sqrt(2.0 - np.sqrt(2.0))
    amplitude = np.sqrt(2.0 + nested_root / 2.0)
    phase = 1.0 / 9.0 + 2.0 * n * np.pi / 3.0  # [radians]
    return 1.0 + amplitude * np.cos(phase)


# Bracket coefficients for down (K1), strange (K2) and bottom (K3),
# evaluated once at import time.
K1, K2, K3 = [mass_formula_k(generation) for generation in (1, 2, 3)]


def compute_A(mu_d3):
    """
    Return the amplitude A(mu_d3) of the mass formula.

    The bottom-quark MS-bar mass is run from its reference point
    m_b(m_b) = 4183 MeV to the scale mu_d3, and the amplitude follows from
      A(mu_d3) = sqrt(m_b(mu_d3) [MeV]) / k_3

    Parameters
    ----------
    mu_d3 : float
        Evaluation scale of the bottom quark [GeV].

    Returns
    -------
    float : A(mu_d3) [MeV^(1/2)]
    """
    bottom_mass = running_mass(MB_INPUT, MB_MB, mu_d3)  # m_b(mu_d3) [MeV]
    return np.sqrt(bottom_mass) / K3


def predicted_mass(A, n):
    """
    Return the mass predicted by the mass formula [MeV]:
      m_dn^pred = (A * k_n)^2

    Parameters
    ----------
    A : float
        Amplitude of the mass formula [MeV^(1/2)].
    n : int
        Generation index: 1 (down), 2 (strange), 3 (bottom).
    """
    sqrt_mass = A * mass_formula_k(n)
    return sqrt_mass ** 2


# =========================================================================
# Numerical determination of mu_d1, mu_d2 via root finding
# =========================================================================
def find_mu_scale(m_pred, m_ref, mu_ref, mu_low=0.65, mu_high=1000.0):
    """
    Find the scale mu that satisfies m_q^(MS-bar)(mu) = m_pred using brentq.

    The running mass is a decreasing function of mu (asymptotic freedom),
    so we solve f(mu) = running_mass(m_ref, mu_ref, mu) - m_pred = 0.

    Near the lower bound alpha_s becomes large and perturbation theory is
    unreliable; the code therefore searches dynamically for an effective
    lower bound. Following the user's specification, tentative solutions
    below 1 GeV are also allowed for mu_d1 and mu_d2.

    The candidate lower bounds are mu_low times 1, 1.1, 1.3, 1.5 and 2,
    plus 1 GeV when that lies above mu_low. They are tried in ascending
    order, so the effective lower bound is the smallest candidate in
    [mu_low, mu_high) at which the running mass can be evaluated to a
    finite number; the search never starts below mu_low.

    Parameters
    ----------
    m_pred : float
        Target mass [MeV]
    m_ref : float
        Reference mass [MeV] at mu_ref
    mu_ref : float
        Reference scale [GeV]
    mu_low, mu_high : float
        Search bounds [GeV]

    Returns
    -------
    float : mu [GeV] or NaN if no solution
    """
    def func(mu):
        return running_mass(m_ref, mu_ref, mu) - m_pred

    # Candidate lower bounds: the nominal one first, then progressively
    # higher fallbacks. The 1 GeV fallback is only meaningful when it lies
    # above the requested lower bound.
    candidates = [mu_low, mu_low * 1.1, mu_low * 1.3, mu_low * 1.5,
                  mu_low * 2.0]
    if mu_low < 1.0:
        candidates.append(1.0)
    # Sort so the bound is raised monotonically, and drop anything that
    # would leave an empty or inverted bracket.
    candidates = sorted(c for c in candidates if c < mu_high)

    try:
        f_low = None
        effective_mu_low = mu_low
        for test_mu in candidates:
            try:
                value = func(test_mu)
            except (RuntimeError, ValueError):
                # ODE failed at this scale: raise the lower bound and retry.
                continue
            if not np.isfinite(value):
                # A non-finite mass is as unusable as a solver failure.
                continue
            f_low = value
            effective_mu_low = test_mu
            break

        if f_low is None:
            return np.nan

        f_high = func(mu_high)
        if not np.isfinite(f_high):
            return np.nan

        if f_low * f_high > 0:
            # No sign change: no solution in range
            return np.nan

        return brentq(func, effective_mu_low, mu_high, xtol=1e-8, rtol=1e-10)
    except (ValueError, RuntimeError):
        # Failure at the upper bound or inside brentq: report "no solution".
        return np.nan


# =========================================================================
# Main calculation
# =========================================================================
def main():
    """
    Run the complete analysis and write the log-log plot to disk.

    Steps:
      1. print the mass-formula coefficients k_1, k_2, k_3;
      2. check A(m_b) against the hard-coded value 25.522566;
      3. print the predicted masses at mu_d3 = m_b(m_b) and two
         running-mass sanity checks;
      4. scan mu_d3 over 100 log-spaced points in [1, 1000] GeV and solve
         for mu_d1 and mu_d2 at each point (NaN where no root is found);
      5. plot both curves and save 'down_quark_mass_plot_fixed.png'
         (relative path);
      6. print a summary of the numerical setup.

    Raises
    ------
    AssertionError
        If A(m_b) differs from 25.522566 by 0.01 or more (the check is an
        ``assert`` statement, so it is skipped when Python runs with -O).
    """
    print("=" * 70)
    print("Down-type quark mass formula: generation of a 2D log-log plot")
    print("=" * 70)

    # ------------------------------------------------------------------
    # 1. Print the mass-formula coefficients
    # ------------------------------------------------------------------
    print(f"\n* Mass-formula coefficients k_n:")
    print(f"  k_1 (down)    = {K1:.10f}")
    print(f"  k_2 (strange) = {K2:.10f}")
    print(f"  k_3 (bottom)  = {K3:.10f}")

    # ------------------------------------------------------------------
    # 2. Consistency check for A(4.183 GeV)
    # ------------------------------------------------------------------
    # At mu_d3 = MB_MB the bottom mass is not run at all, so this checks
    # sqrt(MB_INPUT) / K3 against the hard-coded expectation.
    A_check = compute_A(MB_MB)
    print(f"\n* A(4.183 GeV) consistency check:")
    print(f"  m_b(m_b)  = {MB_INPUT:.1f} MeV")
    print(f"  k_3       = {K3:.10f}")
    print(f"  A(4.183)  = sqrt({MB_INPUT:.1f}) / {K3:.10f}")
    print(f"           = {A_check:.6f}")
    print(f"  Expected  ~= 25.522566")
    print(f"  Difference = {abs(A_check - 25.522566):.6e}")
    assert abs(A_check - 25.522566) < 0.01, \
        f"A(4.183 GeV) = {A_check}, expected ~= 25.522566"
    print("  -> OK: consistency check passed")

    # ------------------------------------------------------------------
    # 3. Predicted masses at the reference point
    # ------------------------------------------------------------------
    print(f"\n* Predicted masses at the reference point (mu_d3 = 4.183 GeV):")
    A_ref = A_check
    m_d_pred_ref = predicted_mass(A_ref, 1)
    m_s_pred_ref = predicted_mass(A_ref, 2)
    m_b_pred_ref = predicted_mass(A_ref, 3)
    print(f"  m_d^pred = {m_d_pred_ref:.4f} MeV  (input: {MD_INPUT} MeV at 2 GeV)")
    print(f"  m_s^pred = {m_s_pred_ref:.4f} MeV  (input: {MS_INPUT} MeV at 2 GeV)")
    print(f"  m_b^pred = {m_b_pred_ref:.4f} MeV  (input: {MB_INPUT} MeV at m_b)")

    # Run m_d and m_s from MU_D_INPUT (= 2.0 GeV) to 2.0 GeV. Start and
    # target coincide, so the inputs are expected to come back unchanged.
    m_d_at_2 = running_mass(MD_INPUT, MU_D_INPUT, 2.0)
    m_s_at_2 = running_mass(MS_INPUT, MU_D_INPUT, 2.0)
    print(f"\n* Running-mass check:")
    print(f"  m_d(2 GeV) = {m_d_at_2:.4f} MeV  (should be {MD_INPUT})")
    print(f"  m_s(2 GeV) = {m_s_at_2:.4f} MeV  (should be {MS_INPUT})")

    # Additional internal consistency check: run m_b up to 10 GeV and back
    # down to m_b; the difference from MB_INPUT is printed as the error.
    mu_test = 10.0
    mb_at_10 = running_mass(MB_INPUT, MB_MB, mu_test)
    mb_back = running_mass(mb_at_10, mu_test, MB_MB)
    print(f"\n* Running-mass round-trip check:")
    print(f"  m_b(10 GeV)         = {mb_at_10:.4f} MeV")
    print(f"  m_b(10 GeV -> m_b)  = {mb_back:.4f} MeV")
    print(f"  round-trip error    = {abs(mb_back - MB_INPUT):.6e} MeV")

    # ------------------------------------------------------------------
    # 4. Scan mu_d3 and generate plot data
    # ------------------------------------------------------------------
    print(f"\n* mu_d3 scan range: 1.0 - 1000.0 GeV (log-spaced)")
    N_POINTS = 100  # number of grid points
    mu_d3_array = np.logspace(np.log10(1.0), np.log10(1000.0), N_POINTS)

    # Results default to NaN; entries are overwritten only where a root
    # is found.
    mu_d1_array = np.full(N_POINTS, np.nan)
    mu_d2_array = np.full(N_POINTS, np.nan)

    print("  Computing...")
    for i, mu_d3 in enumerate(mu_d3_array):
        try:
            # Step 1: Obtain the running bottom mass m_b(mu_d3).
            # Step 2: Determine A(mu_d3).
            # (compute_A performs both steps.)
            A_val = compute_A(mu_d3)

            # Step 3: Compute the predicted mass [MeV].
            m_d1_pred = predicted_mass(A_val, 1)  # down
            m_d2_pred = predicted_mass(A_val, 2)  # strange

            # Step 4: Determine mu_d1, mu_d2 via root finding.
            # m_d^(MS-bar)(mu_d1) = m_d1_pred
            # -> running_mass(MD_INPUT, 2.0 GeV, mu_d1) = m_d1_pred
            if m_d1_pred > 0:
                mu_d1_array[i] = find_mu_scale(
                    m_d1_pred, MD_INPUT, MU_D_INPUT, mu_low=0.65, mu_high=1000.0)

            # m_s^(MS-bar)(mu_d2) = m_d2_pred
            if m_d2_pred > 0:
                mu_d2_array[i] = find_mu_scale(
                    m_d2_pred, MS_INPUT, MU_D_INPUT, mu_low=0.65, mu_high=1000.0)

        except Exception as e:
            # Any failure at this grid point is reported as a warning and the
            # point keeps whatever values were set so far (NaN by default).
            # The `continue` also skips the progress message below.
            warnings.warn(f"mu_d3 = {mu_d3:.3f} GeV: {e}")
            continue

        if (i + 1) % 20 == 0:
            print(f"    ... {i + 1}/{N_POINTS} points completed")

    print(f"  All {N_POINTS} points computed")

    # Count valid (non-NaN) data points.
    n_valid_d1 = np.sum(~np.isnan(mu_d1_array))
    n_valid_d2 = np.sum(~np.isnan(mu_d2_array))
    print(f"  Valid data points: mu_d1 = {n_valid_d1}, mu_d2 = {n_valid_d2}")

    # ------------------------------------------------------------------
    # 5. Produce the plot
    # ------------------------------------------------------------------
    print("\n* Generating plot...")
    fig, ax = plt.subplots(1, 1, figsize=(10, 7))

    # Color-blind-safe palette (Okabe-Ito family) chosen to avoid the
    # red/green confusion common in deuteranopia/protanopia.
    # We also separate the series by line style so the curves remain
    # distinguishable in grayscale printouts.
    color_d1 = "#0072B2"  # blue
    color_d2 = "#D55E00"  # vermillion
    color_ref = "#999999"  # neutral gray for the reference marker
    ls_d1 = "-"
    ls_d2 = "--"
    ls_ref = "-."

    # mu_d1 curve (down quark); NaN entries are masked out before plotting.
    mask1 = ~np.isnan(mu_d1_array)
    ax.plot(mu_d3_array[mask1], mu_d1_array[mask1],
            color=color_d1, linestyle=ls_d1, linewidth=2.2,
            label=r'$\mu_{d1}$ (down)')

    # mu_d2 curve (strange quark); NaN entries are masked out before plotting.
    mask2 = ~np.isnan(mu_d2_array)
    ax.plot(mu_d3_array[mask2], mu_d2_array[mask2],
            color=color_d2, linestyle=ls_d2, linewidth=2.2,
            label=r'$\mu_{d2}$ (strange)')

    # Reference-point marker: vertical line at mu_d3 = m_b(m_b).
    ax.axvline(x=MB_MB, color=color_ref, linestyle=ls_ref, alpha=0.7,
               linewidth=1.5,
               label=r'$\mu_{d3} = m_b(m_b) = 4.183$ GeV')

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlabel(r'$\mu_{d3}$ [GeV]', fontsize=14)
    ax.set_ylabel(r'$\mu_{d1},\, \mu_{d2}$ [GeV]', fontsize=14)
    ax.set_title(
        r'Down-type quark mass formula: $\mu_{d1}(\mu_{d3})$ and $\mu_{d2}(\mu_{d3})$'
        '\n(4-loop QCD running, continuous threshold matching, $\\overline{\\mathrm{MS}}$ scheme)',
        fontsize=13)
    ax.legend(fontsize=12, loc='best')
    ax.grid(True, which='both', alpha=0.3)
    ax.tick_params(labelsize=12)

    plt.tight_layout()

    # Save figure (relative path, 150 dpi).
    output_path = 'down_quark_mass_plot_fixed.png'
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"  Plot saved: {output_path}")

    # Close the figure explicitly once it has been written.
    plt.close(fig)

    # ------------------------------------------------------------------
    # 6. Print numerical notes
    # ------------------------------------------------------------------
    print("\n" + "=" * 70)
    print("* Numerical notes")
    print("=" * 70)
    print("""
  1. Uses the 4-loop beta-function and 4-loop mass anomalous dimension.
  2. At each threshold n_f is switched and the RG is solved piecewise.
     Finite higher-order decoupling corrections are not included;
     alpha_s and the light-quark masses are matched continuously at
     mu = m_q(m_q).
  3. Root finding uses brentq (bracketing method, numerically stable).
     Following the user's specification, the search range is
     [0.65 GeV, 1000 GeV], and tentative solutions below 1 GeV
     for mu_{d1}, mu_{d2} are accepted.
  4. ODE solver: RK45 (relative tolerance 1e-12).
  5. Units: masses in MeV, scales in GeV. Conversions are explicit.
  6. PDG 2024 values used:
       alpha_s(m_Z) = 0.1180
       m_Z          = 91.1876 GeV
       m_c(m_c)     = 1.273 GeV
       m_b(m_b)     = 4.183 GeV
       m_t(m_t)     = 162.5 GeV
       m_d(2 GeV)   = 4.7 MeV
       m_s(2 GeV)   = 93.5 MeV
""")


# Run the analysis only when the file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()
