#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
==========================================================================
Inverse-tuple Brannen parametrization: scale dependence under 4-loop
MS-bar QCD running.

Generates Figure 2 of the manuscript (saved as Figure_2.png).

What is computed
----------------
For each scale mu in the range [1, 1000] GeV, we evaluate the 4-loop
MS-bar running masses m_d(mu), m_s(mu), m_b(mu) using the same machinery
as Script 1 (Appendix A), then form the inverse square-root tuple
1 / sqrt(m_dn(mu)) and fit it to the Brannen-type form

    1/sqrt(m_dn(mu)) = A_inv(mu) * (1 + 2 T_inv(mu) cos(phi_inv(mu)/3 + 2 k pi/3)),
    k = 4 - n,   n = 1, 2, 3.

Compared with the direct mass formula cos(phi_d/3 + 2 n pi/3), this inverse
form keeps the cosine in the same Brannen shape but reverses the index via
k = 4 - n. This single discrete relabelling reflects the inversion of the
underlying mass ordering: sqrt(m_dn) is monotonically increasing in n,
whereas 1/sqrt(m_dn) is monotonically decreasing, so that the Brannen
reference (the dominant branch at angle 0) is taken at index 3 in both
cases: n = 3 (bottom) for the direct form and k = 3 (i.e. n = 1, down)
for the inverse form. In Brannen's geometric reading, the inverse-tuple
parametrization corresponds to the same spherical triangle traversed in
the opposite orientation.

The fit is closed-form (3 unknowns, 3 data points). Under flavour-universal
QCD running the parameters T_inv and phi_inv are scale-independent by
construction; only the overall amplitude A_inv(mu) carries the
multiplicative dressing.

Output
------
    Figure_2.png     : top panel = 1/sqrt(m_dn(mu)) vs mu (log-log);
                      middle panel = T_inv(mu) with the reference
                      line 1/sqrt(2) ~= 0.7071 marked;
                      bottom panel = phi_inv(mu).

Imports the running-mass machinery from Script_S1.py (Appendix A).
==========================================================================
"""

from __future__ import annotations

import math
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from Script_S1 import (
    running_mass,
    MD_INPUT,
    MS_INPUT,
    MB_INPUT,
    MU_D_INPUT,
    MB_MB,
)

# Module-level alias for pi. NOTE(review): not referenced in the visible code
# of this script; presumably kept for external importers or parity with
# Script_S1 -- confirm before removing.
PI: float = math.pi


def inverse_brannen_fit(m1: float, m2: float, m3: float):
    """
    Fit three masses to the inverse-tuple Brannen form.

    Solve

        1/sqrt(m_n) = A_inv * (1 + 2 T_inv cos(phi_inv/3 + 2 k pi / 3)),
        k = 4 - n,   n = 1, 2, 3,

    for (A_inv, T_inv, phi_inv) given the three masses (m1, m2, m3).

    The inversion is closed-form. The three cosines with offsets 0, 2 pi/3
    and 4 pi/3 sum to zero, so A_inv is the arithmetic mean of the three
    1/sqrt(m_n). With theta = phi_inv/3 and
    u_n = (1/sqrt(m_n) / A_inv - 1) / 2 we have

        n = 1 (down):    k = 3,    u_1 = T_inv cos(theta + 2 pi)
                                       = T_inv cos(theta),
        n = 2 (strange): k = 2,    u_2 = T_inv cos(theta + 4 pi/3),
        n = 3 (bottom):  k = 1,    u_3 = T_inv cos(theta + 2 pi/3),

    so the closed-form inversion is

        T_inv cos(theta) = u_1,
        T_inv sin(theta) = (u_2 - u_3) / sqrt(3),

    which is identical (modulo the period-2 pi equivalence
    cos(theta + 4 pi/3) = cos(theta - 2 pi/3)) to the relations of any
    Brannen-type three-mass inversion.

    Parameters
    ----------
    m1, m2, m3 : float or array_like
        Masses of the n = 1, 2, 3 members of the tuple (down, strange,
        bottom in this script). All must be strictly positive. Array inputs
        are processed element-wise with NumPy broadcasting.

    Returns
    -------
    A : float or ndarray
        Amplitude A_inv, in units of (mass unit of the inputs)^(-1/2).
    T : float or ndarray
        T_inv >= 0 (taken as the non-negative root).
    phi : float or ndarray
        phi_inv = 3 * theta with theta in [-pi, pi], i.e. phi in
        [-3 pi, 3 pi]. If T == 0 exactly, the phase is undefined and 0 is
        returned (arctan2(0, 0) == 0).

    Raises
    ------
    ValueError
        If any mass is zero, negative or NaN.
    """
    # Reject unphysical input up front instead of letting inf/NaN leak into
    # the fit; np.all/np.asarray keeps this valid for array inputs too.
    for name, m in (("m1", m1), ("m2", m2), ("m3", m3)):
        if not np.all(np.asarray(m) > 0):
            raise ValueError(f"{name} must be strictly positive, got {m!r}")

    inv1 = 1.0 / np.sqrt(m1)
    inv2 = 1.0 / np.sqrt(m2)
    inv3 = 1.0 / np.sqrt(m3)
    A = (inv1 + inv2 + inv3) / 3.0

    u1 = (inv1 / A - 1.0) / 2.0
    u2 = (inv2 / A - 1.0) / 2.0
    u3 = (inv3 / A - 1.0) / 2.0

    T_cos_theta = u1
    T_sin_theta = (u2 - u3) / np.sqrt(3.0)
    T = np.hypot(T_cos_theta, T_sin_theta)

    # arctan2 only depends on the direction of (x, y), so normalising by T
    # is unnecessary; skipping it also avoids 0/0 = NaN when T == 0.
    theta = np.arctan2(T_sin_theta, T_cos_theta)
    phi = 3.0 * theta
    return A, T, phi


def _scan_scales(mu_grid: np.ndarray) -> dict[str, np.ndarray]:
    """Run the masses to every scale in *mu_grid* and fit the inverse form.

    Returns a dict of arrays shaped like *mu_grid* with keys "inv_d1",
    "inv_d2", "inv_d3" (1/sqrt of the down, strange and bottom running
    masses) and "A_inv", "T_inv", "phi_inv" (the inverse Brannen fit).
    Prints a progress line every 50 points.
    """
    keys = ("inv_d1", "inv_d2", "inv_d3", "A_inv", "T_inv", "phi_inv")
    scan = {key: np.zeros_like(mu_grid) for key in keys}
    n_points = len(mu_grid)

    for i, mu in enumerate(mu_grid):
        scale = float(mu)
        md = running_mass(MD_INPUT, MU_D_INPUT, scale)
        # NOTE(review): m_s is run from MU_D_INPUT too, i.e. it is assumed to
        # be quoted at the same reference scale as m_d (no separate
        # strange-quark reference scale is imported from Script_S1).
        ms = running_mass(MS_INPUT, MU_D_INPUT, scale)
        mb = running_mass(MB_INPUT, MB_MB, scale)

        scan["inv_d1"][i] = 1.0 / np.sqrt(md)
        scan["inv_d2"][i] = 1.0 / np.sqrt(ms)
        scan["inv_d3"][i] = 1.0 / np.sqrt(mb)

        A, T, phi = inverse_brannen_fit(md, ms, mb)
        scan["A_inv"][i] = A
        scan["T_inv"][i] = T
        scan["phi_inv"][i] = phi

        if (i + 1) % 50 == 0:
            print(f"  ... {i + 1}/{n_points} points")
    return scan


def _print_diagnostics(mu_grid, scan, mu_targets):
    """Print the fit parameters at the grid points nearest to *mu_targets*.

    Targets are in GeV; the printed mu is the grid value actually used.
    """
    for mu_target in mu_targets:
        i_t = int(np.argmin(np.abs(mu_grid - mu_target)))
        print(
            f"  mu = {mu_grid[i_t]:8.3f} GeV : "
            f"A_inv = {scan['A_inv'][i_t]:.6e} MeV^(-1/2), "
            f"T_inv = {scan['T_inv'][i_t]:.6f}, "
            f"phi_inv = {scan['phi_inv'][i_t]:.6f} rad"
        )


def _plot_figure(mu_grid, scan, output_path):
    """Draw the three-panel figure and save it as a PNG to *output_path*.

    All panels share a logarithmic mu axis. The top panel shows
    1/sqrt(m_dn) for the three quarks (log-log), the middle panel T_inv with
    the 1/sqrt(2) reference line, and the bottom panel phi_inv.
    """
    T_inv = scan["T_inv"]
    phi_inv = scan["phi_inv"]
    t_ref = 1.0 / np.sqrt(2.0)

    fig, (ax1, ax2, ax3) = plt.subplots(
        3, 1, figsize=(9, 9.5), sharex=True,
        gridspec_kw={"height_ratios": [2.4, 1.2, 1.2], "hspace": 0.06},
    )
    try:
        color_d = "#0072B2"
        color_s = "#D55E00"
        color_b = "#CC79A7"
        color_T = "#009E73"
        color_phi = "#E69F00"

        # Top panel: 1/sqrt(m_dn) vs mu
        ax1.plot(mu_grid, scan["inv_d1"], lw=2.0, color=color_d, ls="-",
                 label=r"$1/\sqrt{m_{d1}(\mu)}$ (down)")
        ax1.plot(mu_grid, scan["inv_d2"], lw=2.0, color=color_s, ls="--",
                 label=r"$1/\sqrt{m_{d2}(\mu)}$ (strange)")
        ax1.plot(mu_grid, scan["inv_d3"], lw=2.0, color=color_b, ls="-.",
                 label=r"$1/\sqrt{m_{d3}(\mu)}$ (bottom)")
        ax1.set_yscale("log")
        ax1.set_xscale("log")
        ax1.set_ylabel(r"$1/\sqrt{m_{dn}(\mu)}\;\;[\mathrm{MeV}^{-1/2}]$",
                       fontsize=12)
        ax1.legend(loc="best", fontsize=10)
        ax1.grid(True, which="both", alpha=0.3)
        ax1.set_title(
            r"Inverse-tuple Brannen parametrization under 4-loop "
            r"$\overline{\mathrm{MS}}$ QCD running",
            fontsize=12,
        )

        # Middle panel: T_inv(mu) and reference 1/sqrt(2)
        ax2.plot(mu_grid, T_inv, lw=2.0, color=color_T, ls="-",
                 label=r"$T_{d\,\mathrm{inv}}^{\mathrm{fit}}(\mu)$")
        ax2.axhline(t_ref, color="black", ls=":", lw=1.5,
                    label=r"$1/\sqrt{2}\,\approx\,0.7071$")
        ax2.set_ylabel(r"$T_{d\,\mathrm{inv}}^{\mathrm{fit}}$", fontsize=12)
        ax2.set_xscale("log")
        # The floor keeps the y-range non-degenerate when T_inv is
        # (numerically) constant, as expected under flavour-universal
        # running; the min/max also keep the reference line in view.
        band_T = max(0.001, 4 * (T_inv.max() - T_inv.min()))
        centre_T = 0.5 * (T_inv.max() + T_inv.min())
        ax2.set_ylim(min(centre_T - band_T, t_ref - 0.0008),
                     max(centre_T + band_T, t_ref + 0.0015))
        ax2.legend(loc="best", fontsize=10)
        ax2.grid(True, which="both", alpha=0.3)

        # Bottom panel: phi_inv(mu)
        ax3.plot(mu_grid, phi_inv, lw=2.0, color=color_phi, ls="-",
                 label=r"$\phi_{d\,\mathrm{inv}}^{\mathrm{fit}}(\mu)$")
        ax3.set_xlabel(r"$\mu\;\;[\mathrm{GeV}]$", fontsize=12)
        ax3.set_ylabel(
            r"$\phi_{d\,\mathrm{inv}}^{\mathrm{fit}}\;\;[\mathrm{rad}]$",
            fontsize=12,
        )
        ax3.set_xscale("log")
        band_phi = max(0.005, 4 * (phi_inv.max() - phi_inv.min()))
        centre_phi = 0.5 * (phi_inv.max() + phi_inv.min())
        ax3.set_ylim(centre_phi - band_phi, centre_phi + band_phi)
        ax3.legend(loc="best", fontsize=10)
        ax3.grid(True, which="both", alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def main(output_path: str = "Figure_2.png"):
    """Scan mu over [1, 1000] GeV, print diagnostics and write the figure.

    Parameters
    ----------
    output_path : str, optional
        Destination PNG file. The default "Figure_2.png" is the name used
        for Figure 2 of the manuscript.
    """
    mu_grid = np.logspace(0.0, 3.0, 200)  # 1 GeV ... 1000 GeV

    print("Computing 4-loop MS-bar running masses and inverse Brannen fit ...")
    scan = _scan_scales(mu_grid)

    # Diagnostic at a few representative scales: 2 GeV, mu = MB_MB and
    # 91.1876 GeV (the Z-boson mass).
    _print_diagnostics(mu_grid, scan, (2.0, MB_MB, 91.1876))

    _plot_figure(mu_grid, scan, output_path)
    print(f"Saved {output_path}")


# Run the scan and write the figure only when executed as a script, so the
# module can be imported (e.g. for inverse_brannen_fit) without side effects.
if __name__ == "__main__":
    main()
