#!/usr/bin/env python3
"""
formula_reference.py
════════════════════════════════════════════════════════════════════════════
Reference implementation of all equations from:

  Flynn, D.C. (2026)
  "Baryonic Validation of the Omega Kinematic Correction
   Across 84 SPARC Galaxies"
  Submitted to Astronomy & Astrophysics

Each function is a direct Python translation of the corresponding
numbered equation in the paper. All variable names match the paper's
notation. Self-tests verify boundary conditions and unit consistency.

Units throughout:
  Velocities  : km/s
  Radii       : kpc
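  omega       : rad/Gyr as labelled in the paper. Note: Eq 1 evaluates V/R,
                i.e. km/s/kpc, and 1 km/s/kpc ≈ 1.0227 Gyr^-1 (so
                1 rad/Gyr ≈ 0.978 km/s/kpc); Eq 3 applies OMEGA_CONV = 1.022.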
  Acceleration: m/s^2
  Mass        : solar masses (M_sun)
════════════════════════════════════════════════════════════════════════════
"""

import numpy as np

# ── Physical constants ────────────────────────────────────────────────────
# NOTE(review): 1 kpc = 3.0857e16 km and 1 Gyr ≈ 3.1558e16 s, so
# 1 km/s/kpc ≈ 1.0227 Gyr^-1, i.e. 1 rad/Gyr ≈ 0.978 km/s/kpc. A factor of
# 1.022 therefore corresponds to km/s/kpc → rad/Gyr, not the
# rad/Gyr → km/s/kpc direction the original label claimed. eq1_omega returns
# V/R (numerically km/s/kpc), and eq3_vadj / sigma_vadj multiply by
# OMEGA_CONV, so the applied correction is ~2.2% larger than R*omega.
# The value is kept unchanged so results reproduce the paper — confirm the
# intended convention before reusing this module elsewhere.
OMEGA_CONV  = 1.022          # factor applied to omega in Eq 3 / sigma_vadj (see NOTE above)
G0_MOND     = 1.2e-10        # m/s^2   MOND acceleration scale (McGaugh+2016)
KPC_TO_M    = 3.0857e19      # kpc → metres
G_GRAV      = 4.3009e-6      # kpc (km/s)^2 M_sun^-1  (Corbelli & Salucci 2000)
EPS         = 1e-6           # (km/s)^2  numerical smoothing (Eq 5)
UPSILON_MIN = 0.1            # lower bound on mass-to-light ratio (clip in eq5_upsilon_max)
UPSILON_MAX = 1.0            # upper bound (Maximum Disk cap; clip in eq5_upsilon_max)


# ── Equation 1: omega — two-boundary angular velocity offset ──────────────
def eq1_omega(V1, R1, V2, R2):
    """
    Eq (1): Two-boundary angular velocity offset.

    omega = V2/R2 - (V1/R1) * (R1/R2)^(3/2)

    The outer angular rate V2/R2 minus the inner angular rate V1/R1
    propagated outward along a Keplerian (R^-3/2) falloff.

    NOTE(review): with V in km/s and R in kpc this expression is numerically
    in km/s/kpc; the paper labels the result rad/Gyr
    (1 km/s/kpc ≈ 1.0227 Gyr^-1) — confirm the intended convention.

    Parameters
    ----------
    V1 : float  — observed velocity at innermost boundary R1  [km/s]
    R1 : float  — innermost measured radius                   [kpc]
    V2 : float  — observed velocity at outermost boundary R2  [km/s]
    R2 : float  — outermost measured radius                   [kpc]

    Returns
    -------
    omega : float  [paper label: rad/Gyr]
    """
    outer_rate = V2 / R2
    inner_rate = V1 / R1
    radius_ratio = R1 / R2
    return outer_rate - inner_rate * radius_ratio ** 1.5


# ── Equation 2: V_Kep — Keplerian velocity from boundary (R1, V1) ─────────
def eq2_vkep(R, V1, R1):
    """
    Eq (2): Keplerian velocity anchored at boundary point (R1, V1).

    V_Kep(R) = V1 * (R1/R)^(1/2)

    Satisfies V_Kep(R1) = V1 exactly. Radii are floored at 1e-10 kpc so that
    R = 0 yields a large finite value instead of a division by zero.

    Parameters
    ----------
    R  : array-like  — radii at which to evaluate             [kpc]
    V1 : float       — boundary velocity                      [km/s]
    R1 : float       — boundary radius                        [kpc]

    Returns
    -------
    V_Kep : ndarray  [km/s]
    """
    safe_radii = np.maximum(np.asarray(R, dtype=float), 1e-10)
    falloff = np.sqrt(R1 / safe_radii)
    return V1 * falloff


# ── Equation 3: V_adj — omega-corrected velocity ──────────────────────────
def eq3_vadj(Vobs, R, omega):
    """
    Eq (3): Omega-corrected velocity.

    V_adj(R) = V_obs(R) - R * omega

    The subtracted term is evaluated as R * omega * OMEGA_CONV, giving a
    velocity in km/s for R in kpc.

    NOTE(review): omega from eq1_omega is numerically V/R in km/s/kpc, so
    R * omega is already km/s and the OMEGA_CONV (= 1.022) factor enlarges
    the correction by ~2.2%. If omega were genuinely in rad/Gyr, the
    rad/Gyr → km/s/kpc factor would be ~0.978 (1 km/s/kpc ≈ 1.0227 Gyr^-1).
    Behaviour is left as-is to match the paper — confirm the convention.

    Parameters
    ----------
    Vobs  : array  — observed rotation curve                  [km/s]
    R     : array  — galactocentric radii                     [kpc]
    omega : float  — angular velocity offset (paper label: rad/Gyr)

    Returns
    -------
    V_adj : ndarray  [km/s]
    """
    radii = np.asarray(R, dtype=float)
    v_observed = np.asarray(Vobs, dtype=float)
    # Term linear in R that is removed from the observed curve.
    correction = radii * omega * OMEGA_CONV
    return v_observed - correction


# ── Equation 4: V_bary — sign-preserving baryonic quadrature ──────────────
def eq4_vbary(Vgas, Vdisk, Vbul, upsilon):
    """
    Eq (4): Baryonic velocity via sign-preserving quadrature
    (Corbelli et al. 2014).

    V_bary = sqrt( sgn(Vg)*Vg^2 + Upsilon*Vd^2 + Upsilon*Vb^2 )

    A negative V_gas keeps its sign in the quadrature sum, i.e. it
    contributes a net outward (negative) term. If the total is negative,
    it is floored at zero before the square root.

    Parameters
    ----------
    Vgas    : array  — gas rotation velocity (signed)         [km/s]
    Vdisk   : array  — stellar disk contribution              [km/s]
    Vbul    : array  — bulge contribution                     [km/s]
    upsilon : float  — stellar mass-to-light ratio, applied to disk and bulge

    Returns
    -------
    V_bary : ndarray  [km/s]  (always >= 0)
    """
    gas = np.asarray(Vgas, dtype=float)
    disk = np.asarray(Vdisk, dtype=float)
    bulge = np.asarray(Vbul, dtype=float)
    signed_gas_sq = np.where(gas >= 0, gas**2, -(gas**2))
    total = signed_gas_sq + upsilon * disk**2 + upsilon * bulge**2
    floored = np.maximum(total, 0.0)
    return np.sqrt(floored)


# ── Equation 5: Upsilon_max — Maximum Disk upper bound ────────────────────
def eq5_upsilon_max(Vobs, Vgas, Vdisk, Vbul):
    """
    Eq (5): Maximum Disk upper bound on stellar mass-to-light ratio
    (van Albada & Sancisi 1986).

    Upsilon_max = min_R [ (Vobs^2 - sgn(Vg)*Vg^2) /
                          (Vd^2 + Vb^2 + eps) ]  clipped to [0.1, 1.0]

    The per-radius ratio is the largest Upsilon for which the stellar
    components do not exceed what V_obs leaves after the (signed) gas term.
    The tightest (minimum) radius sets the bound. EPS (1e-6 (km/s)^2) keeps
    the denominator non-zero where Vd and Vb both vanish.

    Parameters
    ----------
    Vobs  : array  — observed velocity                        [km/s]
    Vgas  : array  — gas velocity (signed)                    [km/s]
    Vdisk : array  — stellar disk velocity                    [km/s]
    Vbul  : array  — bulge velocity                           [km/s]

    Returns
    -------
    upsilon_max : float  [dimensionless], clipped to [UPSILON_MIN, UPSILON_MAX]

    Raises
    ------
    ValueError
        If the inputs are empty (np.min of a zero-size array).
    """
    obs = np.asarray(Vobs, dtype=float)
    gas = np.asarray(Vgas, dtype=float)
    disk = np.asarray(Vdisk, dtype=float)
    bulge = np.asarray(Vbul, dtype=float)

    signed_gas_sq = np.where(gas >= 0, gas**2, -(gas**2))
    stellar_sq = disk**2 + bulge**2 + EPS
    per_radius = (obs**2 - signed_gas_sq) / stellar_sq

    tightest = np.min(per_radius)
    return float(np.clip(tightest, UPSILON_MIN, UPSILON_MAX))


# ── Equation 6: MOND RAR prediction ───────────────────────────────────────
def eq6_mond_vobs(R_kpc, Vgas, Vdisk, Vbul, upsilon=0.5):
    """
    Eq (6): MOND Radial Acceleration Relation prediction for V_obs
    (McGaugh, Lelli & Schombert 2016).

    g_obs = g_bar / (1 - exp(-sqrt(g_bar / g0)))
    g0    = 1.2e-10 m/s^2

    with g_bar = V_bary^2 / R (V_bary from Eq 4) and V_MOND = sqrt(g_obs * R).
    Default Upsilon = 0.5 (published SPARC calibration, McGaugh+2016).

    NOTE(review): the single ``upsilon`` is applied to both disk and bulge
    (via eq4_vbary); confirm this matches the cited calibration.

    Parameters
    ----------
    R_kpc   : array  — galactocentric radii                   [kpc]
    Vgas    : array  — gas velocity (signed)                  [km/s]
    Vdisk   : array  — stellar disk velocity                  [km/s]
    Vbul    : array  — bulge velocity                         [km/s]
    upsilon : float  — mass-to-light ratio (default 0.5)      [dimensionless]

    Returns
    -------
    V_MOND : ndarray  [km/s]  — MOND predicted observed velocity; 0 where
             R <= 0 or g_bar <= 0.
    """
    R_kpc = np.asarray(R_kpc, dtype=float)
    Vb    = eq4_vbary(Vgas, Vdisk, Vbul, upsilon)
    R_m   = R_kpc * KPC_TO_M
    # np.where evaluates BOTH branches, so the R = 0 division (x/0 and 0/0)
    # must run inside errstate; otherwise it emits spurious RuntimeWarnings
    # (or raises when warnings are errors) for entries that are discarded.
    with np.errstate(invalid='ignore', divide='ignore'):
        gbar  = np.where(R_m > 0, (Vb * 1e3)**2 / R_m, 0.0)   # km/s → m/s
        x     = np.sqrt(np.maximum(gbar / G0_MOND, 0.0))
        # -expm1(-x) equals 1 - exp(-x) but avoids cancellation at small x
        # (the low-acceleration regime this relation is designed for).
        denom = -np.expm1(-x)
        gobs  = np.where((gbar > 0) & (denom > 1e-10),
                         gbar / denom, 0.0)
    V_ms = np.sqrt(np.maximum(gobs * R_m, 0.0))
    return V_ms / 1e3   # m/s → km/s


# ── Uncertainty propagation ───────────────────────────────────────────────
def sigma_omega(sigma_V1, R1, sigma_V2, R2):
    """
    Uncertainty on omega from linear propagation through Eq (1).

    sigma_omega^2 = (sigma_V2/R2)^2 + (sigma_V1/R1)^2 * (R1/R2)^3

    The two terms are the squared partial derivatives of Eq (1) with
    respect to V2 and V1, scaled by their uncertainties. V1 and V2 are
    treated as independent (no covariance term).

    Parameters
    ----------
    sigma_V1 : float  — uncertainty on V1                    [km/s]
    R1       : float  — innermost radius                      [kpc]
    sigma_V2 : float  — uncertainty on V2                    [km/s]
    R2       : float  — outermost radius                      [kpc]

    Returns
    -------
    sigma_omega : float  [same units as eq1_omega]
    """
    outer_term = (sigma_V2 / R2) ** 2
    inner_term = (sigma_V1 / R1) ** 2 * (R1 / R2) ** 3
    return np.sqrt(outer_term + inner_term)


def sigma_vadj(sigma_Vobs, R, sig_omega):
    """
    Uncertainty on V_adj at each radius (Section 2.6).

    sigma_Vadj^2 = sigma_Vobs^2 + R^2 * sigma_omega^2

    The omega term carries the same OMEGA_CONV factor that eq3_vadj applies
    to omega, so both uncertainties are combined in km/s.

    Parameters
    ----------
    sigma_Vobs : array  — per-ring velocity uncertainties     [km/s]
    R          : array  — galactocentric radii                [kpc]
    sig_omega  : float  — uncertainty on omega (paper label: rad/Gyr)

    Returns
    -------
    sigma_Vadj : ndarray  [km/s]
    """
    radii = np.asarray(R, dtype=float)
    per_ring = np.asarray(sigma_Vobs, dtype=float)
    omega_term = radii * sig_omega * OMEGA_CONV
    return np.sqrt(per_ring**2 + omega_term**2)


# ── Enclosed baryonic mass ────────────────────────────────────────────────
def enclosed_mass(Vbary, R_kpc):
    """
    Enclosed baryonic mass from V_bary (Section 2.5, Corbelli & Salucci 2000).

    M(<R) = V_bary^2 * R / G,   G = G_GRAV = 4.3009e-6 kpc (km/s)^2 M_sun^-1

    Parameters
    ----------
    Vbary : array-like  — baryonic velocity                   [km/s]
    R_kpc : array-like  — galactocentric radii                [kpc]

    Returns
    -------
    M : ndarray  [M_sun]
    """
    v_sq = np.asarray(Vbary) ** 2
    radii = np.asarray(R_kpc)
    return v_sq * radii / G_GRAV


# ── Self-tests ────────────────────────────────────────────────────────────
def run_self_tests():
    """
    Verify boundary conditions and basic consistency for all equations.

    Every check runs and is reported, even after an earlier failure. Checks
    go through an explicit helper rather than ``assert`` statements. Python
    strips ``assert`` under ``-O``, which previously let every test "pass"
    without being evaluated; the ``failed`` counter was also never updated.

    Returns
    -------
    bool
        True when every check passes.

    Raises
    ------
    AssertionError
        If any check failed. It is raised after the summary is printed, so
        callers that relied on the original exception-on-failure still see one.
    """
    print("Running self-tests...\n")
    passed = 0
    failed = 0

    def check(ok, fail_msg):
        """Record one check result; print *fail_msg* on failure. Returns *ok*."""
        nonlocal passed, failed
        if ok:
            passed += 1
        else:
            failed += 1
            print(fail_msg)
        return ok

    # Eq 1: omega from M33 paper data
    V1, R1, V2, R2 = 37.3, 0.24, 119.6, 22.73
    om = eq1_omega(V1, R1, V2, R2)
    print(f"Eq 1  M33 omega: {om:.3f} rad/Gyr  (paper: 5.10)")
    check(abs(om - 5.10) < 0.05, f"FAIL: {om}")

    # Eq 2: boundary condition V_Kep(R1) = V1
    V1, R1 = 50.0, 2.0
    vk = eq2_vkep(np.array([R1, 5.0, 10.0]), V1, R1)
    if check(abs(vk[0] - V1) < 1e-10, f"FAIL Eq2 BC: {vk[0]} != {V1}"):
        print(f"Eq 2  V_Kep(R1)={vk[0]:.4f} = V1={V1:.4f}  ✓")

    # Eq 3: V_adj at R=0 equals V_obs
    Vobs = np.array([80., 100., 120.])
    R    = np.array([0.0, 5.0,  10.0])
    va   = eq3_vadj(Vobs, R, 5.0)
    if check(abs(va[0] - Vobs[0]) < 1e-10, f"FAIL Eq3: {va[0]}"):
        print(f"Eq 3  V_adj(R=0)={va[0]:.4f} = V_obs={Vobs[0]:.4f}  ✓")

    # Eq 4: gas-dominated galaxy (Vdisk=Vbul=0).
    # Positive gas → V_bary = Vgas; negative gas → the signed gas term makes
    # the quadrature sum negative, which eq4_vbary floors to V_bary = 0.
    Vgas   = np.array([50., 30.])
    vb_pos = eq4_vbary(Vgas, np.zeros(2), np.zeros(2), 0.5)
    vb_neg = eq4_vbary(np.array([-10.]), np.zeros(1), np.zeros(1), 0.5)
    ok_eq4 = bool(np.allclose(vb_pos, Vgas)) and vb_neg[0] == 0.0
    if check(ok_eq4, f"FAIL Eq4: pos={vb_pos}, neg gas={vb_neg}"):
        print("Eq 4  Gas-dominated: V_bary=Vgas (pos), V_bary=0 (neg pressure)  ✓")

    # Eq 5: Upsilon_max clipped to [0.1, 1.0]
    umax = eq5_upsilon_max(np.array([100., 120., 110.]),
                           np.array([20.,  25.,  22.]),
                           np.array([80.,  90.,  85.]),
                           np.zeros(3))
    if check(0.1 <= umax <= 1.0, f"FAIL Eq5 clip: {umax}"):
        print(f"Eq 5  Upsilon_max={umax:.3f} in [0.1,1.0]  ✓")

    # Eq 6: for g_bar > 0, 1 - exp(-x) < 1, so g_obs > g_bar. The MOND
    # prediction must therefore be positive and exceed the Newtonian V_bary
    # (computed with eq6's default upsilon of 0.5).
    R    = np.array([50.0])
    Vg   = np.array([5.0])
    Vd   = np.array([10.0])
    Vbul = np.zeros(1)
    vm    = eq6_mond_vobs(R, Vg, Vd, Vbul)
    vbary = eq4_vbary(Vg, Vd, Vbul, 0.5)
    if check(vm[0] > 0 and vm[0] > vbary[0], f"FAIL Eq6: {vm} (V_bary={vbary})"):
        print(f"Eq 6  MOND V_obs at R=50 kpc: {vm[0]:.2f} km/s  ✓")

    # Unit consistency: at R=1 kpc and omega=1, eq3_vadj subtracts exactly
    # OMEGA_CONV km/s.
    va_unit = eq3_vadj(np.array([100.0]), np.array([1.0]), 1.0)
    correction = 100.0 - va_unit[0]
    if check(abs(correction - OMEGA_CONV) < 1e-10, f"FAIL unit: {correction}"):
        print(f"Unit  R=1kpc, omega=1 rad/Gyr → correction={correction:.4f} km/s = OMEGA_CONV  ✓")

    print(f"\n{passed} tests passed, {failed} failed.")
    if failed:
        # Explicit raise (not `assert`) so failures surface even under -O.
        raise AssertionError(f"{failed} self-test(s) failed")
    print("All self-tests PASSED — formulas match paper equations.")
    return True


if __name__ == "__main__":
    run_self_tests()

    print("\n── Example: DDO161 ──────────────────────────────────────")
    # Sample values from Table 2; the inner boundary (V1, R1) is approximate
    # and the V_obs values are illustrative.
    omega = 4.69  # rad/Gyr (paper label)
    radii = np.array([0.5, 1.0, 2.0, 5.0, 10.0, 13.0])
    v_inner, r_inner = 20.0, 0.5
    v_obs = np.array([35., 50., 62., 65., 66., 65.])

    v_kep = eq2_vkep(radii, v_inner, r_inner)
    v_adj = eq3_vadj(v_obs, radii, omega)

    print(f"  omega = {omega} rad/Gyr = {omega*OMEGA_CONV:.3f} km/s/kpc")
    print(f"  {'R':>6}  {'V_Kep':>8}  {'V_obs':>8}  {'V_adj':>8}")
    for r, vk, vo, va in zip(radii, v_kep, v_obs, v_adj):
        print(f"  {r:6.1f}  {vk:8.2f}  {vo:8.2f}  {va:8.2f}")
