#!/usr/bin/env python3
"""
REAL SPARC DATA TEST - NO SYNTHETIC DATA
=========================================

This script tests the LFM c_dyn formula against REAL rotation curve data
from the SPARC database (Lelli+ 2016).

Data source: http://astroweb.cwru.edu/SPARC/
175 galaxies with observed rotation curves and baryonic mass models.

The test:
- Compute V_bar from SPARC baryonic components (Vgas, Vdisk, Vbul)
- Apply LFM enhancement: V_LFM = V_bar * (1 + a0/a)^(1/4)
- Apply MOND formula: V_MOND via the RAR interpolating function (McGaugh+ 2016)
- Compare both to V_obs

This is a REAL DATA test - no synthetic or simulated data.
"""

import numpy as np
from pathlib import Path
import json
from datetime import datetime

# Physical constants (SI units)
G = 6.67430e-11  # m^3 kg^-1 s^-2 (gravitational constant; appears unused in the visible part of this script)
c = 2.998e8  # m/s (speed of light; appears unused in the visible part of this script)
a0 = 1.2e-10  # m/s^2 (MOND/LFM acceleration scale)
kpc = 3.086e19  # m (length of one kiloparsec; used for kpc → m conversions)

# Directory globbed for "*_rotmod.dat" rotation-curve files. It is resolved
# relative to this script: <parent of the script's directory>/
# ceff_closure_experiment/data/sparc_rotcurves
SPARC_DIR = Path(__file__).parent.parent / "ceff_closure_experiment" / "data" / "sparc_rotcurves"


def load_sparc_galaxy(filename):
    """Load a single SPARC galaxy rotation curve.

    The file is read line by line:

    * a header line starting with ``# Distance`` supplies the distance
      (the text after ``=`` with the ``Mpc`` suffix removed);
    * every other ``#`` line and every blank line is ignored;
    * data rows are whitespace separated; the first six columns are read as
      radius [kpc], observed velocity, velocity error, and the gas, disk and
      bulge velocity contributions [km/s].  Extra columns are ignored and
      rows with fewer than six columns are skipped.

    Parameters
    ----------
    filename : str or os.PathLike
        Path to the rotation-curve file.  Plain strings are accepted as well
        as ``pathlib.Path`` objects.

    Returns
    -------
    dict
        Keys ``rad``, ``vobs``, ``errv``, ``vgas``, ``vdisk``, ``vbul``
        (1-D float NumPy arrays of equal length), ``distance_mpc`` (float,
        or ``None`` when no distance header is present) and ``name`` (the
        file stem with ``_rotmod`` removed).

    Raises
    ------
    OSError
        If the file cannot be opened.
    ValueError
        If a distance header or a data column cannot be parsed as a float.
    """
    # Normalise once so that ``.stem`` below works for str inputs too.
    path = Path(filename)
    columns = ('rad', 'vobs', 'errv', 'vgas', 'vdisk', 'vbul')

    rows = []
    distance = None

    # Explicit encoding keeps parsing independent of the platform locale;
    # undecodable bytes can only matter in comment lines, so replace them
    # instead of aborting the whole galaxy.
    with open(path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            line = line.strip()
            if line.startswith('# Distance'):
                distance = float(line.split('=')[1].replace('Mpc', '').strip())
            elif line.startswith('#') or not line:
                continue
            else:
                parts = line.split()
                if len(parts) >= len(columns):
                    rows.append([float(p) for p in parts[:len(columns)]])

    # reshape keeps the (0, 6) shape for files without any data rows, so
    # every column is still a 1-D (empty) float array.
    table = np.array(rows, dtype=float).reshape(-1, len(columns))
    data = {key: table[:, i].copy() for i, key in enumerate(columns)}

    data['distance_mpc'] = distance
    data['name'] = path.stem.replace('_rotmod', '')

    return data


def compute_vbar(vgas, vdisk, vbul, upsilon_disk=1.0, upsilon_bul=1.0):
    """
    Compute baryonic velocity from components.

    V_bar² = V_gas|V_gas| + Υ_disk·V_disk|V_disk| + Υ_bul·V_bul|V_bul|

    The signed form V|V| is used instead of V² because SPARC velocity
    contributions can be negative (the component then pulls outward, e.g.
    gas with a central depression); a plain square would add that
    contribution instead of subtracting it (Lelli+ 2016, Eq. 2).

    Parameters
    ----------
    vgas, vdisk, vbul : float or array_like
        Gas, disk and bulge velocity contributions (same units, broadcast
        against each other).
    upsilon_disk, upsilon_bul : float, optional
        Stellar mass-to-light ratios scaling the disk and bulge terms.
        The default of 1.0 reproduces the unscaled sum.

    Returns
    -------
    float or ndarray
        Baryonic velocity in the input units.  Where the signed sum is
        negative the result is 0 rather than NaN.
    """
    signed_square = (
        vgas * np.abs(vgas)
        + upsilon_disk * vdisk * np.abs(vdisk)
        + upsilon_bul * vbul * np.abs(vbul)
    )
    # A net outward pull has no real circular velocity; clip so that
    # np.sqrt never sees a negative argument.
    return np.sqrt(np.maximum(signed_square, 0.0))


def compute_acceleration(v_km_s, r_kpc):
    """Return the centripetal acceleration a = v²/r in m/s².

    ``v_km_s`` is a speed in km/s and ``r_kpc`` a radius in kpc; both are
    converted to SI units (using the module-level ``kpc`` constant) before
    the ratio is formed.
    """
    radius_m = r_kpc * kpc        # kpc → m
    speed_m_s = v_km_s * 1000     # km/s → m/s
    return speed_m_s**2 / radius_m


def lfm_enhancement(a):
    """
    Return the LFM velocity enhancement factor (1 + a0/a)^(1/4).

    The predicted velocity is the baryonic velocity multiplied by this
    factor: V_obs = V_bar * (1 + a0/a)^(1/4), with ``a0`` the module-level
    acceleration scale.

    Per the original author's note, the expression is derived from
    c_dyn(a) = c * sqrt(a/(a+a0)).
    """
    scale_ratio = a0 / a
    return (1 + scale_ratio)**0.25


def mond_mu(x):
    """MOND interpolation function mu(x) = x / sqrt(1 + x²).

    This is the form usually called the "standard" interpolating function
    (the "simple" form in the literature is x / (1 + x)).  ``x`` may be a
    scalar or a NumPy array.
    """
    denominator = np.sqrt(1 + x**2)
    return x / denominator


def mond_velocity_rar(vbar, r_kpc):
    """
    Predict the rotation velocity from the Radial Acceleration Relation.

    RAR (McGaugh+ 2016): g_obs = g_bar / (1 - exp(-sqrt(g_bar/a0)))

    ``vbar`` is the baryonic velocity in km/s and ``r_kpc`` the radius in
    kpc.  The baryonic acceleration g_bar = vbar²/r is formed in SI units,
    boosted with the RAR formula, and converted back into a circular
    velocity in km/s.
    """
    radius_m = r_kpc * kpc  # kpc → m

    # Newtonian acceleration produced by the baryons alone.
    g_newton = (vbar * 1000)**2 / radius_m

    # RAR boost: the denominator tends to 1 for g_newton >> a0.
    rar_denominator = 1 - np.exp(-np.sqrt(g_newton / a0))
    g_total = g_newton / rar_denominator

    # Circular velocity for the boosted acceleration, m/s → km/s.
    return np.sqrt(g_total * radius_m) / 1000


def mond_velocity(vbar, r_kpc):
    """
    Predict the rotation velocity with the mu-function form of MOND.

    In MOND: g_obs = g_bar / mu(g_bar/a0), with mu given by ``mond_mu``.
    For a circular orbit v²/r = g_obs, so the velocity follows from the
    boosted acceleration.  ``vbar`` is in km/s, ``r_kpc`` in kpc, and the
    result is in km/s.
    """
    radius_m = r_kpc * kpc  # kpc → m

    # Baryonic (Newtonian) acceleration in m/s².
    g_newton = (vbar * 1000)**2 / radius_m

    # Divide by mu evaluated at the acceleration in units of a0.
    g_total = g_newton / mond_mu(g_newton / a0)

    # Back to a circular velocity, m/s → km/s.
    return np.sqrt(g_total * radius_m) / 1000


def analyze_galaxy(data):
    """Compare Newtonian, LFM and MOND (RAR) predictions for one galaxy.

    Parameters
    ----------
    data : dict
        Galaxy record as returned by ``load_sparc_galaxy``: NumPy arrays
        under ``rad``, ``vobs``, ``errv``, ``vgas``, ``vdisk``, ``vbul`` and
        a ``name`` entry.

    Returns
    -------
    dict or None
        ``None`` when fewer than three usable points remain.  Otherwise a
        dict with ``name``, ``n_points``, the χ² of each model
        (``chi2_newton``, ``chi2_lfm``, ``chi2_mond``) and the same values
        divided by the number of points (``red_chi2_*``).  All χ² values
        are plain Python floats.
    """
    r = data['rad']
    vobs = data['vobs']
    verr = data['errv']

    # Compute baryonic velocity
    vbar = compute_vbar(data['vgas'], data['vdisk'], data['vbul'])

    # Skip points where vbar is zero or very small (km/s).  Points at
    # non-positive radius are dropped too: r = 0 gives an infinite
    # acceleration and a NaN RAR velocity (sqrt(inf * 0)), which would
    # silently turn the MOND χ² of the whole galaxy into NaN.
    valid = (vbar > 5) & (r > 0)
    if np.sum(valid) < 3:
        return None

    r = r[valid]
    vobs = vobs[valid]
    verr = verr[valid]
    vbar = vbar[valid]

    # Compute acceleration using baryonic velocity
    a_bar = compute_acceleration(vbar, r)

    # LFM prediction
    v_lfm = vbar * lfm_enhancement(a_bar)

    # MOND prediction (using RAR - the correct formula)
    v_mond = mond_velocity_rar(vbar, r)

    # Floor the quoted uncertainties at 1 km/s so zero (or very small)
    # errors cannot cause a division by zero or dominate the sum.
    verr_safe = np.maximum(verr, 1.0)

    # float() keeps the result dict free of NumPy scalar types.
    chi2_newton = float(np.sum((vobs - vbar)**2 / verr_safe**2))
    chi2_lfm = float(np.sum((vobs - v_lfm)**2 / verr_safe**2))
    chi2_mond = float(np.sum((vobs - v_mond)**2 / verr_safe**2))

    n_points = len(r)

    return {
        'name': data['name'],
        'n_points': n_points,
        'chi2_newton': chi2_newton,
        'chi2_lfm': chi2_lfm,
        'chi2_mond': chi2_mond,
        'red_chi2_newton': chi2_newton / n_points,
        'red_chi2_lfm': chi2_lfm / n_points,
        'red_chi2_mond': chi2_mond / n_points
    }


def main():
    """Run the LFM vs. MOND comparison over every SPARC file and save it.

    Every ``*_rotmod.dat`` file in ``SPARC_DIR`` is loaded and analysed.
    Aggregate χ² values, the MOND − LFM difference and per-galaxy win
    counts are printed and written to ``results/real_sparc_test.json`` next
    to this script.

    Returns
    -------
    dict or None
        The summary that was written to disk, or ``None`` when no galaxy
        could be analysed (for example when the data directory is missing
        or empty); nothing is written in that case.
    """
    print("=" * 70)
    print("REAL SPARC DATA TEST")
    print("Testing LFM c_dyn formula against 175 real galaxies")
    print("=" * 70)
    print()

    # Find all SPARC files
    sparc_files = sorted(SPARC_DIR.glob("*_rotmod.dat"))
    print(f"Found {len(sparc_files)} SPARC galaxies")
    print()

    results = []
    failed = []
    n_skipped = 0  # galaxies with too few usable points (analysis returned None)

    for f in sparc_files:
        # Top-level boundary: one unreadable file must not abort the whole
        # run, but the reason is recorded and reported below.
        try:
            result = analyze_galaxy(load_sparc_galaxy(f))
        except Exception as e:
            failed.append((f.name, str(e)))
            continue
        if result:
            results.append(result)
        else:
            n_skipped += 1

    print(f"Successfully analyzed: {len(results)} galaxies")
    print(f"Failed/skipped: {len(failed) + n_skipped} galaxies")
    for failed_name, reason in failed:
        print(f"  FAILED {failed_name}: {reason}")
    print()

    # Without any analysed galaxy every ratio below would divide by zero.
    if not results:
        print(f"No galaxies could be analyzed; check that {SPARC_DIR} "
              "contains *_rotmod.dat files.")
        return None

    # Aggregate statistics
    total_chi2_newton = sum(r['chi2_newton'] for r in results)
    total_chi2_lfm = sum(r['chi2_lfm'] for r in results)
    total_chi2_mond = sum(r['chi2_mond'] for r in results)
    total_points = sum(r['n_points'] for r in results)

    print("=" * 70)
    print("AGGREGATE RESULTS (ALL GALAXIES)")
    print("=" * 70)
    print()
    print(f"Total data points: {total_points}")
    print()
    print(f"{'Model':<15} {'Total χ²':>15} {'Reduced χ²':>15}")
    print("-" * 50)
    print(f"{'Newtonian':<15} {total_chi2_newton:>15.1f} {total_chi2_newton/total_points:>15.2f}")
    print(f"{'MOND':<15} {total_chi2_mond:>15.1f} {total_chi2_mond/total_points:>15.2f}")
    print(f"{'LFM':<15} {total_chi2_lfm:>15.1f} {total_chi2_lfm/total_points:>15.2f}")
    print()

    # Which is better?  Positive Δχ² means LFM fits better than MOND.
    delta_chi2 = total_chi2_mond - total_chi2_lfm
    # NOTE(review): sqrt(2|Δχ²|) is kept as the quoted "σ" exactly as before;
    # confirm this is the intended significance measure.
    sigma = np.sqrt(2 * abs(delta_chi2))

    print(f"Δχ² (MOND - LFM) = {delta_chi2:.1f}")
    if delta_chi2 > 0:
        print(f"LFM is better by {sigma:.1f}σ")
    else:
        print(f"MOND is better by {sigma:.1f}σ")

    # Count wins
    lfm_wins = sum(1 for r in results if r['chi2_lfm'] < r['chi2_mond'])
    mond_wins = sum(1 for r in results if r['chi2_mond'] < r['chi2_lfm'])
    ties = len(results) - lfm_wins - mond_wins

    print()
    print("Galaxy-by-galaxy comparison:")
    print(f"  LFM wins:  {lfm_wins} ({100*lfm_wins/len(results):.1f}%)")
    print(f"  MOND wins: {mond_wins} ({100*mond_wins/len(results):.1f}%)")
    print(f"  Ties:      {ties}")

    # Save results
    output = {
        'timestamp': datetime.now().isoformat(),
        'n_galaxies': len(results),
        'n_failed': len(failed),
        'n_skipped': n_skipped,
        'total_points': total_points,
        'total_chi2': {
            'newtonian': total_chi2_newton,
            'mond': total_chi2_mond,
            'lfm': total_chi2_lfm
        },
        'reduced_chi2': {
            'newtonian': total_chi2_newton / total_points,
            'mond': total_chi2_mond / total_points,
            'lfm': total_chi2_lfm / total_points
        },
        'delta_chi2_mond_minus_lfm': delta_chi2,
        'sigma': sigma,
        'winner': 'LFM' if delta_chi2 > 0 else 'MOND',
        'galaxy_wins': {
            'lfm': lfm_wins,
            'mond': mond_wins,
            'tie': ties
        },
        'per_galaxy': results
    }

    results_dir = Path(__file__).parent / "results"
    results_dir.mkdir(exist_ok=True)

    with open(results_dir / "real_sparc_test.json", 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2)

    print()
    print(f"Results saved to: {results_dir / 'real_sparc_test.json'}")

    return output


# Run the full analysis only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
