#!/usr/bin/env python3
"""
==========================================================================
Down-type quark mass formula: low-energy window with local linear
extrapolation near the left edge
==========================================================================

Purpose:
  Starting from the low-energy plot generated by down_quark_mass.py,
  fit a straight line to the left-most finite points of the mu_d1 and
  mu_d2 curves, and extend these fits, purely for visualization, down to
  mu_d3 = 0.587 GeV.

Important:
  - This is not a physical prediction. It is only a visualization of a local
    linear extrapolation near the left edge of the scanned window.
  - Below about 0.59 GeV the calculation is already close to the Landau-pole
    region and perturbation theory is not reliable.
  - The figure is therefore only a visual aid for the apparent crossing trend
    at low energy, not evidence for an actual unification scale.

Model assumptions:
  - same 4-loop MS-bar running as in down_quark_mass.py
  - continuous matching at flavour thresholds
  - finite decoupling constants are not included
==========================================================================
"""

from __future__ import annotations

import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import warnings

from Script_S1 import (
    alpha_s_at_mu,
    compute_A,
    predicted_mass,
    find_mu_scale,
    MD_INPUT,
    MS_INPUT,
    MU_D_INPUT,
)

# --- Scan window -----------------------------------------------------------
MU_D3_MIN: float = 0.5900           # scanned lower edge [GeV]
MU_D3_MAX: float = 1.0000           # scanned upper edge [GeV]
N_POINTS: int = 80                  # log-spaced samples across the scan window

# --- Left-edge extrapolation (visual aid only) -----------------------------
MU_EXTRAP_MIN: float = 0.5870       # extrapolation stop [GeV]
N_FIT: int = 6                      # number of left-edge points used for local linear fit

# --- Plot guides ------------------------------------------------------------
LANDAU_POLE_APPROX: float = 0.5870  # [GeV], visual marker only


def _solve_mu_scale(m_pred, m_input, mu_low, mu_d3, label):
    """Return ``find_mu_scale(m_pred, m_input, MU_D_INPUT, ...)`` or NaN.

    NaN is returned silently when ``m_pred`` is not positive (this includes
    NaN).  NaN is returned with a warning naming ``label`` when the search
    raises.
    """
    try:
        if not m_pred > 0.0:
            return np.nan
        return float(find_mu_scale(
            m_pred,
            m_input,
            MU_D_INPUT,
            mu_low=mu_low,
            mu_high=1000.0,
        ))
    except Exception as exc:
        # Best-effort scan: report the failure and leave this entry as NaN.
        warnings.warn(f"mu_d3 = {mu_d3:.6f} GeV ({label}): {exc}")
        return np.nan


def compute_scan(mu_min: float = MU_D3_MIN,
                 mu_max: float = MU_D3_MAX,
                 n_points: int = N_POINTS):
    """Scan mu_d3 and compute mu_d1(mu_d3), mu_d2(mu_d3).

    The mu_d3 grid is log-spaced between ``mu_min`` and ``mu_max``.  At each
    grid point, ``compute_A`` is evaluated once.  The predicted masses
    ``predicted_mass(A, 1)`` and ``predicted_mass(A, 2)`` are then passed to
    ``find_mu_scale`` together with ``MD_INPUT`` and ``MS_INPUT``
    respectively.

    The two scale searches are independent: a failure for mu_d1 does not
    prevent mu_d2 from being evaluated at the same point.

    Args:
        mu_min: lower edge of the mu_d3 scan [GeV].  It is also passed as
            ``mu_low`` to ``find_mu_scale``.
        mu_max: upper edge of the mu_d3 scan [GeV].
        n_points: number of grid points.

    Returns:
        ``(mu_d3_array, mu_d1_array, mu_d2_array)``.  Entries that could not
        be computed (non-positive predicted mass, or an exception) are NaN.
        Exceptions are reported via ``warnings.warn``.
    """
    mu_d3_array = np.logspace(np.log10(mu_min), np.log10(mu_max), n_points)
    mu_d1_array = np.full(n_points, np.nan)
    mu_d2_array = np.full(n_points, np.nan)

    for i, mu_d3 in enumerate(mu_d3_array):
        if (i + 1) % 20 == 0:
            print(f"  ... {i + 1}/{n_points} points")
        try:
            A_val = compute_A(mu_d3)
            m_d1_pred = predicted_mass(A_val, 1)
            m_d2_pred = predicted_mass(A_val, 2)
        except Exception as exc:
            # Without A, neither scale can be evaluated at this point.
            warnings.warn(f"mu_d3 = {mu_d3:.6f} GeV: {exc}")
            continue

        mu_d1_array[i] = _solve_mu_scale(m_d1_pred, MD_INPUT, mu_min, mu_d3, "mu_d1")
        mu_d2_array[i] = _solve_mu_scale(m_d2_pred, MS_INPUT, mu_min, mu_d3, "mu_d2")

    return mu_d3_array, mu_d1_array, mu_d2_array


def fit_local_line(x: np.ndarray, y: np.ndarray, n_fit: int = N_FIT):
    """Fit ``y ~= a x + b`` to the ``n_fit`` left-most finite points.

    Points where either coordinate is non-finite are dropped first.  The
    remaining points are ordered by ``x``, so "left edge" means the smallest
    ``x`` values even if the input is not sorted.  For the ascending grid
    produced by ``compute_scan``, this is the same as taking the first
    points.  At least two points are always used, because a straight line
    is undetermined by a single point.

    Args:
        x, y: sample coordinates (same length).
        n_fit: number of left-edge points to fit; values below 2 are
            raised to 2.

    Returns:
        ``(a, b, xf, yf)``: slope, intercept, and the points used in the
        fit, in ascending ``x`` order.

    Raises:
        ValueError: if fewer than ``max(2, n_fit)`` finite points exist.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n_use = max(2, int(n_fit))

    mask = np.isfinite(x) & np.isfinite(y)
    xv = x[mask]
    yv = y[mask]
    if xv.size < n_use:
        raise ValueError("Not enough finite points for local linear extrapolation")

    # Stable sort keeps the original order for ties and is the identity for
    # already-ascending input.
    order = np.argsort(xv, kind="stable")[:n_use]
    xf = xv[order]
    yf = yv[order]
    a, b = np.polyfit(xf, yf, deg=1)
    return float(a), float(b), xf, yf


def line_crossing_with_diagonal(a: float, b: float):
    """Solve ``a x + b = x`` for ``x``.

    Returns NaN when the line is numerically parallel to the diagonal,
    i.e. when ``|a - 1| < 1e-14``.
    """
    slope_gap = a - 1.0
    return np.nan if abs(slope_gap) < 1e-14 else -b / slope_gap


def line_crossing(a1: float, b1: float, a2: float, b2: float):
    """Solve ``a1 x + b1 = a2 x + b2`` for ``x``.

    Returns NaN when the two slopes differ by less than 1e-14 (the lines
    are treated as parallel).
    """
    slope_gap = a1 - a2
    if abs(slope_gap) < 1e-14:
        return np.nan
    intercept_gap = b1 - b2
    return -intercept_gap / slope_gap


def local_extrapolation(mu_d3: np.ndarray, mu_d1: np.ndarray, mu_d2: np.ndarray):
    """Extend left-edge line fits of mu_d1 and mu_d2 down to MU_EXTRAP_MIN.

    Each curve is fitted with ``fit_local_line`` using N_FIT points.  The
    fits are evaluated on a 60-point linear grid running from MU_EXTRAP_MIN
    to MU_D3_MIN, and the diagonal y = x is sampled on the same grid.  The
    pairwise linear crossings are also returned: each fit against the
    diagonal, and the two fits against each other.

    Returns:
        dict containing:
          - the grid: ``x_ext``;
          - the extrapolated curves: ``y1_ext``, ``y2_ext``, ``y3_ext``;
          - the fit coefficients: ``a1``, ``b1``, ``a2``, ``b2``;
          - the fitted points: ``fit_x1``, ``fit_y1``, ``fit_x2``, ``fit_y2``;
          - the crossings (NaN for parallel lines): ``cross_d1``,
            ``cross_d2``, ``cross_12``.

    Raises:
        ValueError: propagated from ``fit_local_line`` when a curve has too
            few finite points.
    """
    slope1, icpt1, pts_x1, pts_y1 = fit_local_line(mu_d3, mu_d1, N_FIT)
    slope2, icpt2, pts_x2, pts_y2 = fit_local_line(mu_d3, mu_d2, N_FIT)

    grid = np.linspace(MU_EXTRAP_MIN, MU_D3_MIN, 60)

    return {
        "x_ext": grid,
        "y1_ext": slope1 * grid + icpt1,
        "y2_ext": slope2 * grid + icpt2,
        "y3_ext": grid,  # the diagonal y = x
        "a1": slope1,
        "b1": icpt1,
        "a2": slope2,
        "b2": icpt2,
        "fit_x1": pts_x1,
        "fit_y1": pts_y1,
        "fit_x2": pts_x2,
        "fit_y2": pts_y2,
        "cross_d1": float(line_crossing_with_diagonal(slope1, icpt1)),
        "cross_d2": float(line_crossing_with_diagonal(slope2, icpt2)),
        "cross_12": float(line_crossing(slope1, icpt1, slope2, icpt2)),
    }


def closest_approach(mu_d3: np.ndarray, mu_d1: np.ndarray, mu_d2: np.ndarray):
    """Locate the scan point closest to mu_d1 = mu_d2 = mu_d3.

    Only points where all three arrays are finite are considered.
    "Closest" means the smallest worst-case pairwise distance
    ``max(|mu_d1 - mu_d3|, |mu_d2 - mu_d3|, |mu_d1 - mu_d2|)``.

    Returns:
        dict with ``mu_d3``, ``mu_d1``, ``mu_d2`` and ``spread`` at that
        point, or None if no point is fully finite.
    """
    valid = np.isfinite(mu_d3) & np.isfinite(mu_d1) & np.isfinite(mu_d2)
    if not valid.any():
        return None

    x, y1, y2 = mu_d3[valid], mu_d1[valid], mu_d2[valid]
    pairwise = np.stack([np.abs(y1 - x), np.abs(y2 - x), np.abs(y1 - y2)])
    spread = pairwise.max(axis=0)
    best = np.argmin(spread)

    return dict(
        mu_d3=float(x[best]),
        mu_d1=float(y1[best]),
        mu_d2=float(y2[best]),
        spread=float(spread[best]),
    )


def make_plot(mu_d3_array: np.ndarray,
              mu_d1_array: np.ndarray,
              mu_d2_array: np.ndarray,
              closest: dict | None,
              extrap: dict,
              output_path: str):
    """Draw the low-energy window with left-edge extrapolations and save it.

    Args:
        mu_d3_array, mu_d1_array, mu_d2_array: scan results from
            ``compute_scan``.  Non-finite mu_d1 / mu_d2 entries are skipped.
        closest: result of ``closest_approach``, or None to skip the
            closest-approach markers.
        extrap: result of ``local_extrapolation``, providing the extension
            lines, the fitted points and the linear crossing estimates.
        output_path: file the figure is written to (150 dpi).  Matplotlib
            infers the format from the extension.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # Color-blind-safe palette (Okabe-Ito family) chosen to avoid the
    # red/green confusion common in deuteranopia/protanopia.
    # We also separate the series by line style so the curves remain
    # distinguishable in grayscale printouts.
    color_d1 = "#0072B2"   # blue
    color_d2 = "#D55E00"   # vermillion
    color_d3 = "#CC79A7"   # reddish purple
    color_aux = "#999999"  # neutral gray for scan-edge / Landau guide
    ls_d1 = "-"
    ls_d2 = "--"
    ls_d3 = "-."

    m1 = np.isfinite(mu_d1_array)
    m2 = np.isfinite(mu_d2_array)

    # Main curves in the scanned window
    ax.plot(mu_d3_array[m1], mu_d1_array[m1], linewidth=2.5,
            color=color_d1, linestyle=ls_d1,
            label=r'$\mu_{d1}$ (down)', zorder=3)
    ax.plot(mu_d3_array[m2], mu_d2_array[m2], linewidth=2.5,
            color=color_d2, linestyle=ls_d2,
            label=r'$\mu_{d2}$ (strange)', zorder=3)
    ax.plot(mu_d3_array, mu_d3_array, linewidth=2.5,
            color=color_d3, linestyle=ls_d3,
            label=r'$\mu_{d3}$ (diagonal $y=x$)', zorder=2)

    # Extrapolation helper lines: dotted (':') so they stay distinct from the
    # solid / dashed / dash-dot scanned curves.
    ax.plot(extrap["x_ext"], extrap["y1_ext"], linestyle=':', linewidth=2.0,
            color=color_d1, alpha=0.85,
            label=r'local linear extrapolation of $\mu_{d1}$', zorder=2)
    ax.plot(extrap["x_ext"], extrap["y2_ext"], linestyle=':', linewidth=2.0,
            color=color_d2, alpha=0.85,
            label=r'local linear extrapolation of $\mu_{d2}$', zorder=2)
    ax.plot(extrap["x_ext"], extrap["y3_ext"], linestyle=':', linewidth=2.0,
            color=color_d3, alpha=0.85,
            label=rf'extended diagonal to {MU_EXTRAP_MIN:.3f} GeV', zorder=1)

    # Mark the scan lower edge and the Landau-pole guide
    ax.axvline(MU_D3_MIN, linestyle=(0, (1, 1)), alpha=0.75, linewidth=1.5,
               color=color_aux,
               label=rf'scan lower edge = {MU_D3_MIN:.3f} GeV')
    ax.axvline(LANDAU_POLE_APPROX, linestyle=(0, (3, 1, 1, 1)),
               alpha=0.9, linewidth=1.5, color='black',
               label=rf'Landau-pole guide = {LANDAU_POLE_APPROX:.3f} GeV')
    ax.axhline(LANDAU_POLE_APPROX, linestyle=(0, (3, 1, 1, 1)),
               alpha=0.6, linewidth=1.2, color='black')

    # Show the points used for the left-edge fit
    ax.scatter(extrap["fit_x1"], extrap["fit_y1"],
               color=color_d1, marker='o', s=20, alpha=0.9, zorder=4)
    ax.scatter(extrap["fit_x2"], extrap["fit_y2"],
               color=color_d2, marker='s', s=20, alpha=0.9, zorder=4)

    # Closest approach within the scanned window
    if closest is not None:
        x0 = closest["mu_d3"]
        y1 = closest["mu_d1"]
        y2 = closest["mu_d2"]
        ax.plot([x0], [x0], marker='D', color=color_d3,
                markersize=7, zorder=5, markeredgecolor='black')
        ax.plot([x0], [y1], marker='o', color=color_d1,
                markersize=7, zorder=5, markeredgecolor='black')
        ax.plot([x0], [y2], marker='s', color=color_d2,
                markersize=7, zorder=5, markeredgecolor='black')

    # Mark the mean of the finite pairwise linear crossings when it falls in
    # (or just around) the extension window.  This does not check that the
    # individual crossings agree with each other; the per-pair values are
    # listed in the info box below for that purpose.
    crosses = np.array([extrap["cross_d1"], extrap["cross_d2"], extrap["cross_12"]], dtype=float)
    finite_crosses = crosses[np.isfinite(crosses)]
    if len(finite_crosses) > 0:
        x_cross = float(np.mean(finite_crosses))
        if MU_EXTRAP_MIN * 0.995 <= x_cross <= MU_D3_MIN * 1.01:
            ax.plot([x_cross], [x_cross], marker='*', color='black',
                    markersize=18, markeredgecolor='white', zorder=6)
            ax.annotate(
                'local-linear crossing estimate\n'
                rf'$\mu \approx {x_cross:.4f}$ GeV',
                xy=(x_cross, x_cross),
                # The star sits at the very bottom of the visible y-range
                # (ylim starts at 0.995 * MU_EXTRAP_MIN), so the label must
                # go above it to stay inside the axes.
                xytext=(x_cross * 1.045, x_cross * 1.06),
                fontsize=11,
                arrowprops=dict(arrowstyle='->', lw=1.4, color='black'),
                bbox=dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.9),
                color='black',
                zorder=7,
            )

    info_lines = [
        'Dotted segments: local linear extrapolation from left-edge points',
        rf'fit points used at left edge: {N_FIT}',
        rf'linear root of $\mu_{{d1}}=\mu_{{d3}}$: {extrap["cross_d1"]:.4f} GeV',
        rf'linear root of $\mu_{{d2}}=\mu_{{d3}}$: {extrap["cross_d2"]:.4f} GeV',
        rf'linear root of $\mu_{{d1}}=\mu_{{d2}}$: {extrap["cross_12"]:.4f} GeV',
        'Interpretation: visual aid only, not a perturbative result',
    ]
    ax.text(0.98, 0.03, '\n'.join(info_lines), transform=ax.transAxes,
            fontsize=9, ha='right', va='bottom',
            bbox=dict(boxstyle='round,pad=0.4', facecolor='white', alpha=0.92))

    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_xlim(MU_EXTRAP_MIN, MU_D3_MAX * 1.005)
    ax.set_ylim(MU_EXTRAP_MIN * 0.995, MU_D3_MAX * 1.02)
    ax.set_xlabel(r'$\mu_{d3}$ [GeV]', fontsize=14)
    ax.set_ylabel(r'$\mu_{d1},\,\mu_{d2},\,\mu_{d3}$ [GeV]', fontsize=14)
    ax.set_title(
        r'Down-type quark: low-energy window with local linear extrapolation'
        '\n(4-loop QCD running, $\\overline{\\mathrm{MS}}$, continuous matching)',
        fontsize=13,
    )
    ax.legend(fontsize=10, loc='upper left')
    ax.grid(True, which='both', alpha=0.3)
    ax.tick_params(labelsize=12)

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)


def main(output_path: str = '/mnt/data/unification_scale_with_extrapolation.png'):
    """Run the scan, print the left-edge fit diagnostics, and save the figure.

    Args:
        output_path: destination of the figure.  Its parent directory is
            created (if missing) before the scan starts, so an unusable
            destination is reported immediately instead of after the full
            computation.
    """
    from pathlib import Path

    # Fail fast: the scan below is by far the slowest part of the script.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    print('=' * 78)
    print('Low-energy window with local linear extrapolation near the left edge')
    print(f'  scanned window: mu_d3 = {MU_D3_MIN:.3f} -- {MU_D3_MAX:.3f} GeV')
    print(f'  helper extrapolation down to: {MU_EXTRAP_MIN:.3f} GeV')
    print(f'  left-edge fit points used: {N_FIT}')
    print(f'  note: below {MU_D3_MIN:.2f} GeV this is only a simple extrapolation into the nonperturbative region')
    print('=' * 78)

    mu_d3_array, mu_d1_array, mu_d2_array = compute_scan()
    closest = closest_approach(mu_d3_array, mu_d1_array, mu_d2_array)
    extrap = local_extrapolation(mu_d3_array, mu_d1_array, mu_d2_array)

    print('local linear fit results near the left edge:')
    print(f"  mu_d1 ~= ({extrap['a1']:.6f}) mu_d3 + ({extrap['b1']:.6f})")
    print(f"  mu_d2 ~= ({extrap['a2']:.6f}) mu_d3 + ({extrap['b2']:.6f})")
    print(f"  linear root mu_d1 = mu_d3: {extrap['cross_d1']:.6f} GeV")
    print(f"  linear root mu_d2 = mu_d3: {extrap['cross_d2']:.6f} GeV")
    print(f"  linear root mu_d1 = mu_d2: {extrap['cross_12']:.6f} GeV")

    probe_mu = MU_EXTRAP_MIN
    y1_probe = extrap['a1'] * probe_mu + extrap['b1']
    y2_probe = extrap['a2'] * probe_mu + extrap['b2']
    print(f"  extrapolated values at mu_d3={probe_mu:.3f} GeV:")
    print(f"    mu_d1 ~= {y1_probe:.6f} GeV")
    print(f"    mu_d2 ~= {y2_probe:.6f} GeV")
    print(f"    diagonal = {probe_mu:.6f} GeV")
    try:
        # NOTE(review): the factor pi suggests alpha_s_at_mu returns
        # a_s = alpha_s / pi -- confirm against Script_S1.
        alpha_s_value = alpha_s_at_mu(probe_mu) * np.pi
        print(f"  alpha_s({probe_mu:.3f} GeV) ~= {alpha_s_value:.3f}")
        if alpha_s_value > 1.0:
            print('  warning: this probe point is deep in the nonperturbative regime')
    except Exception as exc:
        # Diagnostic only: near the Landau-pole guide the coupling may not be
        # computable, which must not prevent the plot from being written.
        print(f'  alpha_s evaluation failed near the extrapolation edge: {exc}')

    make_plot(mu_d3_array, mu_d1_array, mu_d2_array, closest, extrap, output_path)
    print(f'plot saved to: {output_path}')


if __name__ == '__main__':
    main()
