#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Local sensitivity in (phi_d, cos delta_d) space; produces the corresponding figure of the Supplemental Material.
Generates Figure 4 of the manuscript (saved as Figure_4.png). RMS relative deviation of the benchmark predictions
from PDG 2024 inputs at mixed scales (m_d, m_s at 2 GeV; m_b at m_b),
b-normalized so that m_b is reproduced exactly.
"""
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import minimize

# PDG 2024 reference masses in MeV: m_d and m_s at 2 GeV, m_b at the scale m_b.
MD_INPUT, MS_INPUT, MB_INPUT = 4.70, 93.5, 4183.0
# Adopted benchmark point (phi_d, cos delta_d) = (1/3, cos(3*pi/8) / 2).
PHI_BENCH = 1.0 / 3.0
COS_DELTA_BENCH = 0.5 * np.cos(3.0 * np.pi / 8.0)


def benchmark_predictions(phi, cos_delta):
    """Return the three predicted masses (m1, m2, m3) at (phi, cos_delta).

    Each mass is m_i = (A * k_i)**2 with

        k_i = 1 + 2 * T * cos(phi / 3 + shift_i),   T = sqrt((1 + cos_delta) / 2),

    where shift_i = 2*pi/3, 4*pi/3 and 0 for i = 1, 2, 3.  The common scale
    A = sqrt(MB_INPUT) / k_3 is chosen so that m3 reproduces MB_INPUT.
    All operations are element-wise NumPy calls, so *phi* and *cos_delta*
    may be scalars or arrays of matching shape.
    """
    amplitude = 2.0 * np.sqrt((1.0 + cos_delta) / 2.0)
    angle = phi / 3.0
    k_light = 1.0 + amplitude * np.cos(angle + 2.0 * np.pi / 3.0)
    k_middle = 1.0 + amplitude * np.cos(angle + 4.0 * np.pi / 3.0)
    k_heavy = 1.0 + amplitude * np.cos(angle)
    # Normalize on the heaviest eigenvalue so that m3 == MB_INPUT.
    scale = np.sqrt(MB_INPUT) / k_heavy
    return tuple((scale * k) ** 2 for k in (k_light, k_middle, k_heavy))


def rms_residual(phi, cos_delta):
    """Return the RMS relative deviation of the predictions from the inputs.

    Computes sqrt(mean_i ((m_i - M_i) / M_i)**2) over the three masses, with
    targets M = (MD_INPUT, MS_INPUT, MB_INPUT).  The result is a fraction
    (multiply by 100 for percent) and is element-wise in *phi* / *cos_delta*.
    """
    targets = (MD_INPUT, MS_INPUT, MB_INPUT)
    squared = [((pred - ref) / ref) ** 2
               for pred, ref in zip(benchmark_predictions(phi, cos_delta),
                                    targets)]
    return np.sqrt((squared[0] + squared[1] + squared[2]) / 3.0)


def find_best_fit(x0=None, bounds=((0.0, 2.0 / 3.0), (0.0, 0.4))):
    """Locate the (phi_d, cos delta_d) point that minimizes the RMS residual.

    The optimizer works on the mean-square residual (RMS**2) rather than on
    the RMS itself.  Both have the same minimizer, but with two free
    parameters and two non-trivial residuals (m_b is fixed by normalization)
    the fit can be solved (almost) exactly.  The square root would then turn
    the optimum into a non-differentiable cone tip, which degrades
    L-BFGS-B's finite-difference gradients and line search.

    Parameters
    ----------
    x0 : pair of float, optional
        Starting point (phi_d, cos delta_d).  Defaults to the adopted
        benchmark (PHI_BENCH, COS_DELTA_BENCH).
    bounds : pair of (low, high) pairs, optional
        Box constraints on (phi_d, cos delta_d).  The default matches the
        scan window used in ``main``.

    Returns
    -------
    tuple of float
        (phi_best, cos_delta_best, rms_best), where rms_best is the RMS
        relative deviation at the solution (a fraction, not a percentage).

    Warns
    -----
    RuntimeWarning
        If the optimizer reports that it did not converge.
    """
    import warnings

    if x0 is None:
        x0 = (PHI_BENCH, COS_DELTA_BENCH)

    def mean_square(x):
        # Smooth objective with the same argmin as the RMS.
        return float(rms_residual(float(x[0]), float(x[1]))) ** 2

    # ftol is tightened relative to an RMS objective because the objective
    # is now squared (RMS ~ 1e-7 corresponds to mean square ~ 1e-14).
    res = minimize(mean_square, x0=x0,
                   method="L-BFGS-B",
                   bounds=bounds,
                   options={"ftol": 1e-15, "gtol": 1e-10})
    if not res.success:
        warnings.warn("find_best_fit: L-BFGS-B did not converge: %s"
                      % (res.message,), RuntimeWarning, stacklevel=2)
    phi_best, cd_best = float(res.x[0]), float(res.x[1])
    return phi_best, cd_best, float(rms_residual(phi_best, cd_best))


COL_BENCH = "#000000"
COL_BEST = "#D55E00"


def _draw_panel(ax, phi_grid, cd_grid, rms_arr, phi_best, cd_best,
                xlim, ylim, show_legend, vmax):
    """Draw one RMS-deviation panel onto *ax* and return its colour mesh.

    The RMS array (a fraction) is shown in percent using the reversed
    viridis colour map, clipped to the range [0.5, vmax].  White iso-RMS
    contours are overlaid at 0.5, 1, 2 and 5 %.  Two markers are added: a
    square for the adopted benchmark (COL_BENCH) and a triangle for the
    best-fit point (COL_BEST).  Axis limits come from *xlim* / *ylim*, and
    a legend is added only when *show_legend* is true.
    """
    pct = rms_arr * 100.0
    mesh = ax.pcolormesh(phi_grid, cd_grid, pct,
                         cmap=plt.get_cmap("viridis_r"),
                         shading="auto", vmin=0.5, vmax=vmax)
    iso = ax.contour(phi_grid, cd_grid, pct,
                     levels=[0.5, 1.0, 2.0, 5.0],
                     colors="white", linewidths=1.2)
    ax.clabel(iso, inline=True, fmt=r"%.1f%%", fontsize=8)

    # White outline so both markers stand out against the colour map.
    outline = dict(mec="white", mew=1.2)
    bench_label = (r"adopted benchmark "
                   r"$(\phi_d, \cos\delta_d) = (1/3,\, \frac{1}{2}\cos(\frac{3}{8}\pi))$")
    ax.plot([PHI_BENCH], [COS_DELTA_BENCH], marker="s", ms=11,
            mfc=COL_BENCH, label=bench_label, **outline)
    best_coords = r" \approx (%.3f,\, %.3f)$" % (phi_best, cd_best)
    best_label = (r"unconstrained best fit $(\phi^{\ast}, \cos\delta^{\ast})"
                  + best_coords)
    ax.plot([phi_best], [cd_best], marker="^", ms=12,
            mfc=COL_BEST, label=best_label, **outline)

    ax.set_xlabel(r"$\phi_d$")
    ax.set_ylabel(r"$\cos\delta_d$")
    ax.set_xlim(*xlim)
    ax.set_ylim(*ylim)
    ax.grid(True, ls=":", alpha=0.3)
    if show_legend:
        ax.legend(loc="upper right", fontsize=9, framealpha=0.92)
    return mesh


def make_plot(phi_grid, cd_grid, rms_arr, phi_best, cd_best, rms_best,
              output_path="Figure_4.png"):
    """Render the two-panel sensitivity figure and save it to *output_path*.

    Panel (a) shows the full scan window (phi_d in [0, 2/3], cos delta_d in
    [0, 0.4]).  Panel (b) zooms to +/-0.05 around the adopted benchmark.
    The panels use different colour ranges (vmax = 10 % and 5 %), so each
    gets its own colourbar.  A single colourbar built from panel (a) would
    mislabel the colours shown in panel (b).

    Parameters
    ----------
    phi_grid, cd_grid : array
        2-D coordinate grids passed to ``pcolormesh`` / ``contour``
        (``main`` builds them with ``np.meshgrid``).
    rms_arr : array
        RMS relative deviation (fraction) on the grid; plotted in percent.
    phi_best, cd_best : float
        Best-fit point to mark on both panels.
    rms_best : float
        Best-fit RMS.  Accepted for interface compatibility; not drawn.
    output_path : str or path-like, optional
        Destination image file (saved at 200 dpi).
    """
    fig, axes = plt.subplots(1, 2, figsize=(12.0, 5.4),
                             constrained_layout=True)
    try:
        pcm_full = _draw_panel(axes[0], phi_grid, cd_grid, rms_arr,
                               phi_best, cd_best,
                               xlim=(0.0, 2.0 / 3.0), ylim=(0.0, 0.4),
                               show_legend=True, vmax=10.0)
        axes[0].set_title(r"(a) full benchmark window")
        half_w = 0.05
        pcm_zoom = _draw_panel(axes[1], phi_grid, cd_grid, rms_arr,
                               phi_best, cd_best,
                               xlim=(PHI_BENCH - half_w, PHI_BENCH + half_w),
                               ylim=(COS_DELTA_BENCH - half_w,
                                     COS_DELTA_BENCH + half_w),
                               show_legend=False, vmax=5.0)
        axes[1].set_title(r"(b) zoom around the benchmark")
        # One colourbar per panel: their vmax differ, so a shared bar would lie.
        for ax, pcm in zip(axes, (pcm_full, pcm_zoom)):
            cbar = fig.colorbar(pcm, ax=ax, location="right", shrink=0.85,
                                extend="both")
            cbar.set_label(r"RMS relative deviation [%]")
        fig.savefig(output_path, dpi=200)
    finally:
        # Release the figure even if drawing or saving raises.
        plt.close(fig)
    print("Saved " + str(output_path))


def _report_benchmark():
    """Print predicted masses, per-mass deviations and RMS at the benchmark."""
    print("===== Mixed-scale residual at benchmark =====")
    m1b, m2b, m3b = benchmark_predictions(PHI_BENCH, COS_DELTA_BENCH)
    rms_bench = rms_residual(PHI_BENCH, COS_DELTA_BENCH)
    # Signed relative deviations in percent, in d1, d2, d3 order.
    dev_pct = [(pred - ref) / ref * 100.0
               for pred, ref in ((m1b, MD_INPUT), (m2b, MS_INPUT),
                                 (m3b, MB_INPUT))]
    print("  m_d1^pred(2 GeV) = %.4f MeV  (PDG: %.2f)" % (m1b, MD_INPUT))
    print("  m_d2^pred(2 GeV) = %.4f MeV  (PDG: %.2f)" % (m2b, MS_INPUT))
    print("  m_d3^pred(m_b)   = %.4f MeV  (PDG: %.2f)  [b-normalized]"
          % (m3b, MB_INPUT))
    print("  delta_d1 = %+.3f%%" % dev_pct[0])
    print("  delta_d2 = %+.3f%%" % dev_pct[1])
    print("  delta_d3 = %+.3f%%  (zero by normalization)" % dev_pct[2])
    print("  RMS at benchmark: %.3f%%" % (rms_bench * 100.0))
    print()


def _scan_rms(n):
    """Evaluate rms_residual on an n x n grid over phi in [0, 2/3], cos in [0, 0.4]."""
    phi_grid, cd_grid = np.meshgrid(np.linspace(0.0, 2.0 / 3.0, n),
                                    np.linspace(0.0, 0.4, n))
    return phi_grid, cd_grid, rms_residual(phi_grid, cd_grid)


def main():
    """Report the benchmark residuals, scan the window, fit, and plot.

    Prints the benchmark predictions and deviations, the RMS range of a
    200 x 200 grid scan, the best-fit point and its offset from the
    benchmark, and the extent of the grid region with RMS < 1 %.  It then
    writes the figure via ``make_plot``.
    """
    _report_benchmark()

    print("===== 200 x 200 scan =====")
    phi_grid, cd_grid, rms_arr = _scan_rms(200)
    print("  RMS range: [%.4f%%, %.4f%%]"
          % (rms_arr.min() * 100.0, rms_arr.max() * 100.0))

    phi_best, cd_best, rms_best = find_best_fit()
    print("  Best fit: phi*=%.6f, cos delta*=%.6f, RMS=%.5f%%"
          % (phi_best, cd_best, rms_best * 100.0))
    print("  Offset bench->best: d(phi)=%+.5f, d(cos delta)=%+.5f"
          % (phi_best - PHI_BENCH, cd_best - COS_DELTA_BENCH))

    # Bounding box of the grid points whose RMS deviation is below 1 %.
    inside = rms_arr < 0.01
    if inside.any():
        phi_in, cd_in = phi_grid[inside], cd_grid[inside]
        print("  Region RMS<1%%: phi in [%.4f, %.4f], cos delta in [%.4f, %.4f]"
              % (phi_in.min(), phi_in.max(), cd_in.min(), cd_in.max()))
    print()

    make_plot(phi_grid, cd_grid, rms_arr, phi_best, cd_best, rms_best)


if __name__ == "__main__":
    main()
