#!/usr/bin/env python
"""
ULTIMATE ΔC-L TEST SUITE
========================
Multi-Database Cosmological Analysis

Databases:
- Pantheon+ (SNe Ia)
- CosmicFlows-4 (distances + velocities)
- 2M++ (density field)
- GWTC-3 (gravitational waves) - OPTIONAL

Tests:
1. H₀ vs. Density (3 environments)
2. H₀ Directional Dependence (cos²θ)
3. H₀ vs. z per environment
4. Field-Depression Model
5. Large-scale anisotropy
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import RegularGridInterpolator
from scipy.stats import chi2
from scipy.optimize import curve_fit
from astropy.coordinates import SkyCoord
from astropy import units as u
import emcee
import corner
import urllib.request
import warnings
import os
warnings.filterwarnings('ignore')

# Seed NumPy's global random generator so repeated runs draw the same numbers.
np.random.seed(42)
c = 299792.458  # speed of light in km/s

# Start-up banner: title framed by two 80-character rules.
print("=" * 80, "🌌 ULTIMATE ΔC-L TEST SUITE", "=" * 80, sep="\n")

# ============================================================================
# CONFIGURATION
# ============================================================================
CONFIG = dict(
    # Dataset switches.
    use_pantheon=True,
    use_cosmicflows4=True,
    use_2mpp=True,
    use_gwtc3=False,

    # One switch per analysis test; each test function checks its own key.
    test_density_bins=True,
    test_direction=True,
    test_z_evolution=True,
    test_field_depression=True,
    test_anisotropy=True,

    # MCMC sampler settings (no sampler call is visible in this file).
    mcmc_steps=5000,
    mcmc_burn=2000,
    mcmc_thin=50,
)

# ============================================================================
# 1. DATA LOADERS
# ============================================================================

def load_pantheon_plus():
    """Download and parse the Pantheon+SH0ES supernova distance table.

    Blank lines and lines starting with ``#`` are dropped; the first
    remaining line is taken as the whitespace-separated column header and
    every following line as one data row. Rows that are too short or contain
    non-numeric values in the required columns are skipped.

    Returns:
        A DataFrame with the columns ``name`` (first column of the table),
        ``z`` (``zCMB``), ``mu`` (``MU_SH0ES``), ``mu_err``
        (``MU_SH0ES_ERR_DIAG``), ``ra``, ``dec`` and ``source``, or ``None``
        when the download fails, the header lacks a required column, or no
        row could be parsed.
    """
    print("\n📥 [1/4] Lade Pantheon+ SNe Ia...")

    url = "https://raw.githubusercontent.com/PantheonPlusSH0ES/DataRelease/main/Pantheon%2B_Data/4_DISTANCES_AND_COVAR/Pantheon%2BSH0ES.dat"

    try:
        # The context manager releases the connection even if read/decode fails.
        with urllib.request.urlopen(url, timeout=30) as response:
            data_text = response.read().decode('utf-8')

        lines = [line for line in data_text.split('\n')
                 if line and not line.startswith('#')]
        header = lines[0].split()

        idx_name = 0
        idx_zcmb = header.index('zCMB')
        idx_mu = header.index('MU_SH0ES')
        idx_mu_err = header.index('MU_SH0ES_ERR_DIAG')
        idx_ra = header.index('RA')
        idx_dec = header.index('DEC')

        # A row is usable only if it reaches the right-most column we read.
        min_fields = max(idx_name, idx_zcmb, idx_mu,
                         idx_mu_err, idx_ra, idx_dec) + 1

        data_list = []
        for line in lines[1:]:
            parts = line.split()
            if len(parts) < min_fields:
                continue
            try:
                data_list.append({
                    'name': parts[idx_name],
                    'z': float(parts[idx_zcmb]),
                    'mu': float(parts[idx_mu]),
                    'mu_err': float(parts[idx_mu_err]),
                    'ra': float(parts[idx_ra]),
                    'dec': float(parts[idx_dec]),
                    'source': 'Pantheon+'
                })
            except ValueError:
                # Non-numeric field in a required column: skip this row.
                continue

        if not data_list:
            # An empty frame would have no columns and break every consumer.
            print("   ❌ Fehler: keine gültigen Datenzeilen gefunden")
            return None

        df = pd.DataFrame(data_list)
        print(f"   ✅ {len(df)} SNe Ia geladen")
        return df

    except Exception as e:
        # Deliberate best-effort boundary: any network or format problem is
        # reported and turned into "no data" for the caller.
        print(f"   ❌ Fehler: {e}")
        return None

def load_cosmicflows4():
    """Load a CosmicFlows-4 table from the working directory, if present.

    The candidate file names are tried in order; the first one that exists
    and parses as a whitespace-delimited table is returned. A file that
    exists but cannot be read is reported and the next candidate is tried.

    Returns:
        The parsed DataFrame, or ``None`` when no candidate could be loaded.
    """
    print("\n📥 [2/4] Versuche CosmicFlows-4...")

    cf4_files = (
        'CF4_distances.txt',
        'CF4_TF.dat',
        'cosmicflows4.csv',
    )

    for fname in cf4_files:
        if not os.path.exists(fname):
            continue
        try:
            # sep=r'\s+' is the supported spelling of the deprecated
            # delim_whitespace=True.
            # NOTE(review): the '.csv' candidate is also parsed as
            # whitespace-delimited, exactly as before — confirm its format.
            df = pd.read_csv(fname, sep=r'\s+')
        except (OSError, ValueError) as e:
            # pandas parser errors derive from ValueError; report instead of
            # silently swallowing, then fall through to the next candidate.
            print(f"   ⚠️  {fname} nicht lesbar: {e}")
            continue
        print(f"   ✅ {len(df)} CF4 Objekte aus {fname}")
        return df

    print("   ⚠️  Keine CF4-Daten gefunden (optional)")
    print("   💡 Download: https://edd.ifa.hawaii.edu/CF4data/")
    return None

def load_2mpp_density():
    """Load the 2M++ density grid and wrap it in a linear interpolator.

    Reads ``twompp_density.npy`` from the working directory. Each grid axis
    is mapped onto evenly spaced coordinates from ``-grid_size/2`` to
    ``+grid_size/2``, with the number of points taken from the array shape
    (the original hard-coded 257 and failed for any other shape).

    Returns:
        A ``RegularGridInterpolator`` over the three axes, or ``None`` when
        the file is missing or the array is not three-dimensional. Points
        outside the box evaluate to ``0.0``.
    """
    print("\n📥 [3/4] Lade 2M++ Dichtefeld...")

    try:
        rho_data = np.load('twompp_density.npy')
    except FileNotFoundError:
        print("   ❌ twompp_density.npy nicht gefunden!")
        print("   💡 wget https://cosmicflows.iap.fr/assets/data/twompp_density.npy")
        return None

    print(f"   ✅ 2M++ geladen: {rho_data.shape}")

    if rho_data.ndim != 3:
        print(f"   ❌ Unerwartete Form des Dichtefelds: {rho_data.shape}")
        return None

    # Side length of the cube, centred on the origin.
    # NOTE(review): presumably Mpc/h for the public 2M++ release, while the
    # caller builds positions in Mpc with H0 = 70 — confirm the units.
    grid_size = 400
    axes = tuple(
        np.linspace(-grid_size / 2, grid_size / 2, n_points)
        for n_points in rho_data.shape
    )

    # NOTE(review): fill_value=0.0 makes every out-of-box position look like
    # a zero density contrast rather than "unknown" — confirm that is wanted.
    interpolator = RegularGridInterpolator(
        axes,
        rho_data,
        method='linear',
        bounds_error=False,
        fill_value=0.0
    )
    return interpolator

def load_gwtc3():
    """Placeholder loader for GWTC-3 gravitational-wave events.

    Prints a status line depending on ``CONFIG['use_gwtc3']``: either that
    the step was skipped, or that the loader is not implemented yet.

    Returns:
        Always ``None``.
    """
    print("\n📥 [4/4] Versuche GWTC-3 (GW-Ereignisse)...")

    if CONFIG['use_gwtc3']:
        # Enabled, but there is nothing to load yet.
        print("   ⚠️  GWTC-3 Loader noch nicht implementiert")
        print("   💡 Siehe: https://www.gw-openscience.org/eventapi/html/GWTC/")
    else:
        print("   ⏭️  Übersprungen (CONFIG)")

    return None

# ============================================================================
# 2. COORDINATE TRANSFORMATIONS
# ============================================================================

def z_to_distance(z, H0=70.0):
    """Convert a redshift to a distance with the linear Hubble law.

    Evaluates ``c * z / H0`` using the module-level ``c`` (km/s), so the
    result is in Mpc when ``H0`` is given in km/s/Mpc. This is the
    first-order relation only; no cosmological model is integrated.
    """
    recession_velocity = c * z
    return recession_velocity / H0

def get_supergalactic_coords(ra, dec, dist):
    """Convert an ICRS position to supergalactic Cartesian coordinates.

    Args:
        ra: Right ascension in degrees.
        dec: Declination in degrees.
        dist: Distance in Mpc.

    Returns:
        Tuple ``(x, y, z)`` of the supergalactic Cartesian components as
        plain numbers (units stripped).
    """
    icrs_position = SkyCoord(
        ra=ra*u.deg,
        dec=dec*u.deg,
        distance=dist*u.Mpc,
        frame='icrs',
    )
    cartesian = icrs_position.supergalactic.cartesian
    return cartesian.x.value, cartesian.y.value, cartesian.z.value

def get_local_density(ra, dec, z, interpolator):
    """Interpolate the local density at a single sky position and redshift.

    Args:
        ra: Right ascension in degrees.
        dec: Declination in degrees.
        z: Redshift (scalar).
        interpolator: Callable returned by ``load_2mpp_density`` or ``None``.

    Returns:
        ``1.0 + delta`` where ``delta`` is the interpolated grid value, or
        the sentinel ``0.0`` when no interpolator is available or ``z`` is
        NaN or outside ``[0.001, 2.5]``.
    """
    # Written as a positive range test so that a NaN redshift is rejected
    # too (NaN fails both "<" and ">" and previously slipped through).
    if interpolator is None or not (0.001 <= z <= 2.5):
        return 0.0

    dist = z_to_distance(z)
    # NOTE(review): the position is built in the supergalactic frame, in Mpc
    # for H0 = 70. The public 2M++ grid is presumably stored in Galactic
    # Cartesian coordinates in Mpc/h — confirm frame and units against the
    # data release before trusting the environments.
    x, y, z_coord = get_supergalactic_coords(ra, dec, dist)

    delta_rho = interpolator(np.array([x, y, z_coord]))
    # A single query point comes back as a one-element array; take the
    # element explicitly because float() on a non-0-d array is deprecated
    # in NumPy.
    return 1.0 + float(np.ravel(delta_rho)[0])

def get_sg_latitude(ra, dec):
    """Return the supergalactic latitude (SGB) of an ICRS direction.

    Args:
        ra: Right ascension in degrees.
        dec: Declination in degrees.

    Returns:
        The signed latitude in degrees, i.e. the angular offset from the
        supergalactic plane.
    """
    sky_position = SkyCoord(ra=ra*u.deg, dec=dec*u.deg, frame='icrs')
    return sky_position.supergalactic.sgb.degree

# ============================================================================
# 3. DATA PROCESSING
# ============================================================================

def process_all_data(df_pantheon, interpolator):
    """Annotate the supernova table with environment and per-object H0.

    The input frame is modified in place and returned. Added columns:

    * ``rho``    - local density from ``get_local_density``
    * ``sg_lat`` - supergalactic latitude in degrees
    * ``env``    - 'void' (0 < rho <= 0.5), 'field' (0.5 < rho <= 1.5) or
      'cluster' (rho > 1.5); the sentinel ``rho == 0.0`` falls outside all
      bins and gets a missing ``env``
    * ``H0``     - ``c * z / d_L`` with ``d_L = 10**((mu - 25) / 5)`` Mpc
    * ``H0_err`` - ``H0 * ln(10)/5 * mu_err``

    Args:
        df_pantheon: DataFrame with ``ra``, ``dec``, ``z``, ``mu`` and
            ``mu_err`` columns, or ``None``.
        interpolator: Density interpolator or ``None``.

    Returns:
        The annotated DataFrame, or ``None`` when the input is ``None`` or
        has no rows.
    """
    print("\n🔬 Verarbeite Daten...")

    # An empty frame has no columns to work on; treat it like missing data.
    if df_pantheon is None or df_pantheon.empty:
        return None

    print("   📍 Berechne lokale Dichten...")
    n_total = len(df_pantheon)
    densities = []
    positions = zip(df_pantheon['ra'], df_pantheon['dec'], df_pantheon['z'])
    # Count rows explicitly so the progress output does not depend on the
    # frame's index labels.
    for n_done, (ra, dec, z) in enumerate(positions, start=1):
        densities.append(get_local_density(ra, dec, z, interpolator))

        if n_done % 500 == 0:
            print(f"      ... {n_done}/{n_total}")

    df_pantheon['rho'] = densities
    # One vectorised coordinate transform instead of one SkyCoord per row.
    df_pantheon['sg_lat'] = get_sg_latitude(
        df_pantheon['ra'].to_numpy(),
        df_pantheon['dec'].to_numpy(),
    )

    df_pantheon['env'] = pd.cut(
        df_pantheon['rho'],
        bins=[0, 0.5, 1.5, np.inf],
        labels=['void', 'field', 'cluster']
    )

    # Distance modulus -> luminosity distance in Mpc.
    d_L = 10 ** ((df_pantheon['mu'] - 25) / 5.0)
    # NOTE(review): c*z/d_L is the linear Hubble law; at higher z it is not
    # a model-independent H0 estimate — confirm this is acceptable for the
    # full redshift range.
    df_pantheon['H0'] = c * df_pantheon['z'] / d_L
    # First-order propagation of the distance-modulus uncertainty.
    df_pantheon['H0_err'] = df_pantheon['H0'] * (np.log(10)/5.0) * df_pantheon['mu_err']

    print(f"   ✅ Fertig!")
    print(f"\n   📊 Umgebungsverteilung:")
    print(df_pantheon['env'].value_counts())

    return df_pantheon

# ============================================================================
# 4. ANALYSIS TESTS
# ============================================================================

def test_density_bins(df):
    """Test 1: inverse-variance weighted mean H0 per density environment.

    For each ``env`` group with at least 10 objects, both before and after
    restricting to ``z > 0.01``, the weighted mean of ``H0`` (weights
    ``1 / H0_err**2``) and its standard error are computed. When all three
    environments are present, the "field depression"
    ``(H0_void + H0_cluster) / 2 - H0_field`` is printed.

    Args:
        df: Frame with ``env``, ``z``, ``rho``, ``H0`` and ``H0_err``.

    Returns:
        DataFrame with one row per environment (``env``, ``H0``, ``H0_err``,
        ``rho_mean``, ``n``), or ``None`` when the test is disabled in
        ``CONFIG`` or no environment has enough objects.
    """
    if not CONFIG['test_density_bins']:
        return None

    print("\n" + "="*80)
    print("📊 TEST 1: H₀ vs. Dichte (Void/Field/Cluster)")
    print("="*80)

    results = []
    for env, grp in df.groupby('env', observed=True):
        if len(grp) < 10:
            continue

        grp = grp[grp['z'] > 0.01]
        if len(grp) < 10:
            continue

        weights = 1 / grp['H0_err'] ** 2

        results.append({
            'env': env,
            'H0': np.average(grp['H0'], weights=weights),
            # Standard error of an inverse-variance weighted mean.
            'H0_err': np.sqrt(1 / weights.sum()),
            'rho_mean': grp['rho'].mean(),
            'n': len(grp)
        })

    if not results:
        # Returning a column-less frame would make the model fit and the
        # summary plot fail with KeyError; signal "no result" like the
        # other tests do.
        print("\n   ⚠️  Keine Umgebung mit genügend SNe (≥10)")
        return None

    df_res = pd.DataFrame(results)
    print("\n✨ Ergebnisse:")
    print(df_res.to_string(index=False))

    if len(df_res) == 3:
        h_void = df_res[df_res['env'] == 'void']['H0'].values[0]
        h_field = df_res[df_res['env'] == 'field']['H0'].values[0]
        h_cluster = df_res[df_res['env'] == 'cluster']['H0'].values[0]

        depression = (h_void + h_cluster)/2 - h_field
        print(f"\n🔍 Field-Depression: {depression:.2f} km/s/Mpc")

        # NOTE(review): fixed 5 km/s/Mpc threshold; the bin uncertainties
        # are not taken into account when calling this significant.
        if depression > 5:
            print("   ⚠️  SIGNIFIKANTE FIELD-ANOMALIE DETEKTIERT!")
            print("   → Konsistent mit D5-Schatten-Hypothese (κ₅→₄ ≈ 0.1)")

    return df_res

def test_directional_dependence(df):
    """Test 2: H0 of field supernovae versus |supergalactic latitude|.

    Field objects are binned by ``|sg_lat|`` into five latitude bands; for
    every band with at least 10 objects the inverse-variance weighted mean
    H0 and its standard error are computed. With three or more bands the
    difference between the last and the first band mean is reported as the
    trend (no cos²θ curve is fitted here).

    Args:
        df: Frame with ``env``, ``sg_lat``, ``H0`` and ``H0_err``.

    Returns:
        DataFrame with ``lat_bin``, ``H0``, ``H0_err`` and ``n`` per band, or
        ``None`` when the test is disabled, fewer than 50 field objects
        exist, or no band has enough objects.
    """
    if not CONFIG['test_direction']:
        return None

    rule = "=" * 80
    print("\n" + rule)
    print("📊 TEST 2: Richtungsabhängigkeit (cos²θ)")
    print(rule)

    field_sne = df[df['env'] == 'field']
    if len(field_sne) < 50:
        print("   ⚠️  Zu wenige Field-SNe")
        return None

    band_edges = [0, 15, 30, 45, 60, 90]
    band_names = ['0-15°', '15-30°', '30-45°', '45-60°', '60-90°']
    # assign() returns a new frame, so the caller's data stays untouched.
    field_sne = field_sne.assign(
        sg_lat_bin=pd.cut(field_sne['sg_lat'].abs(),
                          bins=band_edges, labels=band_names)
    )

    rows = []
    for band, members in field_sne.groupby('sg_lat_bin', observed=True):
        n_members = len(members)
        if n_members >= 10:
            inv_var = 1 / members['H0_err'] ** 2
            rows.append({
                'lat_bin': band,
                'H0': np.average(members['H0'], weights=inv_var),
                'H0_err': np.sqrt(1 / np.sum(inv_var)),
                'n': n_members
            })

    if not rows:
        return None

    df_dir = pd.DataFrame(rows)
    print("\n✨ H₀ vs. SG-Latitude:")
    print(df_dir.to_string(index=False))

    if len(df_dir) >= 3:
        band_means = df_dir['H0'].values
        # Highest-latitude band minus lowest-latitude band.
        trend = band_means[-1] - band_means[0]
        print(f"\n🔍 Trend: {trend:.2f} km/s/Mpc")

        if abs(trend) > 2:
            print("   ⚠️  RICHTUNGSABHÄNGIGKEIT DETEKTIERT!")

    return df_dir

def test_z_evolution(df):
    """Test 3: weighted mean H0 per redshift bin, separately per environment.

    Environments with fewer than 50 objects are skipped. The remaining ones
    are split into four redshift bins (edges 0, 0.05, 0.15, 0.5, 2.5) and
    every bin with at least 5 objects contributes an inverse-variance
    weighted mean H0 and its standard error.

    Args:
        df: Frame with ``env``, ``z``, ``H0`` and ``H0_err``.

    Returns:
        DataFrame with ``env``, ``z_bin``, ``H0``, ``H0_err`` and ``n``, or
        ``None`` when the test is disabled or nothing qualified.
    """
    if not CONFIG['test_z_evolution']:
        return None

    rule = "=" * 80
    print("\n" + rule)
    print("📊 TEST 3: H₀-Evolution mit z")
    print(rule)

    z_edges = [0, 0.05, 0.15, 0.5, 2.5]
    z_names = ['0-0.05', '0.05-0.15', '0.15-0.5', '0.5+']

    records = []
    for env in ('void', 'field', 'cluster'):
        subset = df[df['env'] == env]
        if len(subset) < 50:
            continue

        # assign() builds a new frame, leaving the caller's data untouched.
        subset = subset.assign(
            z_bin=pd.cut(subset['z'], bins=z_edges, labels=z_names)
        )

        for z_bin, members in subset.groupby('z_bin', observed=True):
            n_members = len(members)
            if n_members < 5:
                continue

            inv_var = 1 / members['H0_err'] ** 2
            records.append({
                'env': env,
                'z_bin': z_bin,
                'H0': np.average(members['H0'], weights=inv_var),
                'H0_err': np.sqrt(1 / np.sum(inv_var)),
                'n': n_members
            })

    if not records:
        return None

    df_z = pd.DataFrame(records)
    print("\n✨ H₀(z) pro Umgebung:")
    # sort=False keeps the environments in order of first appearance.
    for env, env_rows in df_z.groupby('env', sort=False):
        print(f"\n{env.upper()}:")
        print(env_rows.to_string(index=False))

    return df_z

def test_field_depression_model(df_bins):
    """Test 4: fit a Gaussian "field depression" to the binned H0 values.

    The model is ``H0_base - A_dep * exp(-(rho - 1)**2 / (2 * sigma**2))``,
    fitted to ``H0`` versus ``rho_mean`` with ``H0_err`` as 1-sigma
    uncertainties. The ratio ``A_dep / H0_base`` is printed as κ₅→₄.

    Args:
        df_bins: Result of ``test_density_bins`` (``rho_mean``, ``H0``,
            ``H0_err`` columns) or ``None``.

    Returns:
        ``(popt, pcov)`` from ``curve_fit``, or ``None`` when the test is
        disabled, ``df_bins`` is ``None`` or has fewer than three rows, or
        the fit fails.
    """
    if not CONFIG['test_field_depression'] or df_bins is None:
        return None

    print("\n" + "="*80)
    print("📊 TEST 4: Field-Depression Modell")
    print("="*80)

    # Three free parameters need at least three points; this also keeps an
    # empty (column-less) frame away from the column lookups below.
    if len(df_bins) < 3:
        print("   ⚠️  Zu wenige Dichte-Bins für den 3-Parameter-Fit")
        return None

    def model(rho, H0_base, A_dep, sigma):
        depression = A_dep * np.exp(-(rho - 1)**2 / (2*sigma**2))
        return H0_base - depression

    rho_vals = df_bins['rho_mean'].values
    H0_vals = df_bins['H0'].values
    H0_errs = df_bins['H0_err'].values

    try:
        # absolute_sigma=True: H0_errs are real uncertainties, not relative
        # weights. With as many points as parameters the default rescaling
        # has zero degrees of freedom and cannot produce a finite covariance.
        popt, pcov = curve_fit(
            model, rho_vals, H0_vals,
            sigma=H0_errs,
            absolute_sigma=True,
            p0=[70, 8, 0.5],
            bounds=([65, 0, 0.1], [75, 15, 2.0])
        )

        perr = np.sqrt(np.diag(pcov))

        print("\n✨ Fit-Parameter:")
        print(f"   H0_base    = {popt[0]:.2f} ± {perr[0]:.2f} km/s/Mpc")
        print(f"   Depression = {popt[1]:.2f} ± {perr[1]:.2f} km/s/Mpc")
        print(f"   Sigma      = {popt[2]:.2f} ± {perr[2]:.2f}")

        # Fractional depth of the depression relative to the baseline.
        kappa_5to4_obs = popt[1] / popt[0]
        print(f"\n🔍 κ₅→₄ (obs): {kappa_5to4_obs:.3f}")
        print(f"   Theorie: ~0.10")

        if 0.05 < kappa_5to4_obs < 0.15:
            print("   ✅ KONSISTENT mit D5-Schatten-Modell!")

        return popt, pcov

    except Exception as e:
        # Best-effort: a failed fit is reported, not fatal.
        print(f"   ❌ Fit fehlgeschlagen: {e}")
        return None

def test_large_scale_anisotropy(df):
    """Test 5: compare the weighted mean H0 of the two celestial hemispheres.

    Adds a ``hemisphere`` column to ``df`` in place ('South' for
    ``dec < 0``, otherwise 'North'), then computes the inverse-variance
    weighted mean H0 and its standard error for each hemisphere holding at
    least 50 objects.

    Args:
        df: Frame with ``dec``, ``H0`` and ``H0_err``.

    Returns:
        DataFrame with ``hemisphere``, ``H0``, ``H0_err`` and ``n`` when both
        hemispheres qualify, otherwise ``None`` (also when the test is
        disabled).
    """
    if not CONFIG['test_anisotropy']:
        return None

    rule = "=" * 80
    print("\n" + rule)
    print("📊 TEST 5: Großräumige Anisotropie")
    print(rule)

    # The split is by declination sign, i.e. equatorial hemispheres.
    df['hemisphere'] = np.where(df['dec'] < 0, 'South', 'North')

    per_hemisphere = []
    for name in ('North', 'South'):
        members = df[df['hemisphere'] == name]
        n_members = len(members)
        if n_members >= 50:
            inv_var = 1 / members['H0_err'] ** 2
            per_hemisphere.append({
                'hemisphere': name,
                'H0': np.average(members['H0'], weights=inv_var),
                'H0_err': np.sqrt(1 / np.sum(inv_var)),
                'n': n_members
            })

    # A comparison needs both hemispheres.
    if len(per_hemisphere) != 2:
        return None

    df_hem = pd.DataFrame(per_hemisphere)
    print("\n✨ Nord vs. Süd:")
    print(df_hem.to_string(index=False))

    h_north, h_south = df_hem['H0'].iloc[0], df_hem['H0'].iloc[1]
    delta_H0 = abs(h_north - h_south)
    print(f"\n🔍 Differenz: {delta_H0:.2f} km/s/Mpc")

    if delta_H0 > 2:
        print("   ⚠️  ANISOTROPIE DETEKTIERT!")

    return df_hem

# ============================================================================
# 5. VISUALIZATION
# ============================================================================

def create_summary_plot(df, df_bins, results_dict):
    """Draw the multi-panel summary figure and save it as a PNG.

    Panels: H0 versus density (only if ``df_bins`` has rows), H0 versus
    supergalactic-latitude band, H0 per redshift bin and environment, the
    density histogram, and a Mollweide sky map coloured by H0. The figure is
    written to ``ULTIMATE_DCL_test_summary.png``.

    Args:
        df: Processed frame with ``rho``, ``ra``, ``dec`` and ``H0``.
        df_bins: Result of the density-bin test, or ``None``.
        results_dict: Mapping of test names to results; the keys
            ``'directional'`` and ``'z_evolution'`` are read.

    Returns:
        None.
    """
    print("\n📊 Erstelle Summary-Plot...")

    fig = plt.figure(figsize=(18, 12))
    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

    # Plot 1: H0 vs Density
    # An empty result frame has no columns, so skip the panel for it too.
    if df_bins is not None and not df_bins.empty:
        ax1 = fig.add_subplot(gs[0, :2])
        ax1.errorbar(df_bins['rho_mean'], df_bins['H0'],
                     yerr=df_bins['H0_err'],
                     fmt='o', capsize=5, markersize=12,
                     color='darkblue', label='Beobachtung')

        # Highlight the field point only when all three environments exist.
        if len(df_bins) == 3:
            field_idx = df_bins[df_bins['env'] == 'field'].index[0]
            ax1.plot(df_bins.loc[field_idx, 'rho_mean'],
                    df_bins.loc[field_idx, 'H0'],
                    'r*', markersize=20, label='Field-Depression')

        ax1.set_xscale('log')
        ax1.set_xlabel(r'$\rho/\bar{\rho}$ (2M++)', fontsize=12)
        ax1.set_ylabel(r'$H_0$ [km/s/Mpc]', fontsize=12)
        ax1.set_title('TEST 1: H₀ vs. Dichte', fontweight='bold', fontsize=14)
        ax1.legend()
        ax1.grid(True, alpha=0.4)

    # Plot 2: Directional
    if results_dict.get('directional') is not None:
        ax2 = fig.add_subplot(gs[0, 2])
        df_dir = results_dict['directional']
        x_pos = range(len(df_dir))
        ax2.errorbar(x_pos, df_dir['H0'], yerr=df_dir['H0_err'],
                     fmt='o-', capsize=5, color='green')
        ax2.set_xticks(x_pos)
        ax2.set_xticklabels(df_dir['lat_bin'], rotation=45, fontsize=8)
        ax2.set_ylabel(r'$H_0$ [km/s/Mpc]', fontsize=10)
        ax2.set_title('TEST 2: Richtung', fontweight='bold')
        ax2.grid(True, alpha=0.4)

    # Plot 3: z-Evolution
    if results_dict.get('z_evolution') is not None:
        ax3 = fig.add_subplot(gs[1, :2])
        df_z = results_dict['z_evolution']
        for env in df_z['env'].unique():
            data_env = df_z[df_z['env'] == env]
            # NOTE(review): positions are per-environment row numbers, so an
            # environment missing a redshift bin is shifted against the others.
            x_pos = range(len(data_env))
            ax3.errorbar(x_pos, data_env['H0'], yerr=data_env['H0_err'],
                        fmt='o-', capsize=5, label=env, alpha=0.7)
        ax3.set_ylabel(r'$H_0$ [km/s/Mpc]', fontsize=12)
        ax3.set_title('TEST 3: H₀(z) Evolution', fontweight='bold', fontsize=14)
        ax3.legend()
        ax3.grid(True, alpha=0.4)

    # Plot 4: density histogram
    ax4 = fig.add_subplot(gs[1, 2])
    # Drop NaN densities; the histogram range cannot be derived from them.
    ax4.hist(df['rho'].dropna(), bins=50, alpha=0.7,
             color='steelblue', edgecolor='black')
    ax4.axvline(0.5, color='orange', linestyle='--', label='Void/Field')
    ax4.axvline(1.5, color='red', linestyle='--', label='Field/Cluster')
    ax4.set_xlabel(r'$\rho/\bar{\rho}$', fontsize=10)
    ax4.set_ylabel('Anzahl SNe', fontsize=10)
    ax4.set_title('Dichte-Verteilung', fontweight='bold')
    ax4.legend(fontsize=8)
    ax4.set_yscale('log')

    # Plot 5: Sky Distribution
    ax5 = fig.add_subplot(gs[2, :], projection='mollweide')
    # Shift RA by 180° so the longitudes are centred on zero, in radians.
    ra_rad = np.deg2rad(df['ra'] - 180)
    dec_rad = np.deg2rad(df['dec'])
    scatter = ax5.scatter(ra_rad, dec_rad, c=df['H0'], s=1,
                         cmap='RdYlBu_r', alpha=0.5, vmin=60, vmax=75)
    ax5.set_xlabel('RA', fontsize=12)
    ax5.set_ylabel('Dec', fontsize=12)
    ax5.set_title('TEST 5: H₀ am Himmel', fontweight='bold', fontsize=14)
    ax5.grid(True, alpha=0.3)
    plt.colorbar(scatter, ax=ax5, label=r'$H_0$ [km/s/Mpc]',
                 orientation='horizontal', pad=0.05, shrink=0.8)

    plt.suptitle('ΔC-L MULTI-TEST SUMMARY', fontsize=16, fontweight='bold', y=0.995)

    plt.savefig('ULTIMATE_DCL_test_summary.png', dpi=300, bbox_inches='tight')
    print("   ✅ Gespeichert: ULTIMATE_DCL_test_summary.png")

# ============================================================================
# 6. MAIN EXECUTION
# ============================================================================

def _load_if_enabled(flag, loader, label):
    """Run ``loader()`` if ``CONFIG[flag]`` is set, else report the skip and return None."""
    if CONFIG[flag]:
        return loader()
    print(f"\n⏭️  {label} übersprungen (CONFIG)")
    return None


def main():
    """Run the whole pipeline: load, process, test, plot, summarise, save.

    Each dataset is loaded only when its ``CONFIG['use_*']`` switch is set.
    The run stops early when no supernova data is available or processing
    fails. Otherwise all five tests run (each honours its own ``CONFIG``
    switch), the summary figure is written, and the processed table is saved
    to ``processed_data_DCL.csv``.
    """

    # Load data
    df_pantheon = _load_if_enabled('use_pantheon', load_pantheon_plus, 'Pantheon+')
    # CosmicFlows-4 and GWTC-3 are loaded for their status output only; no
    # test below consumes them yet.
    _load_if_enabled('use_cosmicflows4', load_cosmicflows4, 'CosmicFlows-4')
    interpolator_2mpp = _load_if_enabled('use_2mpp', load_2mpp_density, '2M++')
    load_gwtc3()  # checks CONFIG['use_gwtc3'] itself

    if df_pantheon is None:
        print("\n❌ Keine Daten geladen! Abbruch.")
        return

    # Process
    df = process_all_data(df_pantheon, interpolator_2mpp)

    if df is None:
        print("\n❌ Datenverarbeitung fehlgeschlagen!")
        return

    # Run tests
    results = {}

    results['density_bins'] = test_density_bins(df)
    results['directional'] = test_directional_dependence(df)
    results['z_evolution'] = test_z_evolution(df)
    # The model fit works on the binned output of test 1.
    results['field_model'] = test_field_depression_model(results['density_bins'])
    results['anisotropy'] = test_large_scale_anisotropy(df)

    # Visualize
    create_summary_plot(df, results['density_bins'], results)

    # Final Summary
    print("\n" + "="*80)
    print("🏆 FINALE ZUSAMMENFASSUNG")
    print("="*80)

    print("\n📊 Tests durchgeführt:")
    for test_name, result in results.items():
        status = "✅" if result is not None else "⏭️"
        print(f"   {status} {test_name}")

    print("\n🔬 Hauptergebnisse:")

    if results['density_bins'] is not None:
        df_bins = results['density_bins']
        # The depression is defined only when all three environments exist.
        if len(df_bins) == 3:
            h_void = df_bins[df_bins['env'] == 'void']['H0'].values[0]
            h_field = df_bins[df_bins['env'] == 'field']['H0'].values[0]
            h_cluster = df_bins[df_bins['env'] == 'cluster']['H0'].values[0]

            depression = (h_void + h_cluster)/2 - h_field

            print(f"\n   Field-Depression: {depression:.2f} km/s/Mpc")

            if depression > 5:
                print("   ⚠️  SIGNIFIKANT! Konsistent mit κ₅→₄ ≈ 0.1")
                print("   → D5-Schatten-Hypothese UNTERSTÜTZT")
            else:
                print("   → Keine signifikante Depression")

    print("\n" + "="*80)
    print("✅ ANALYSE ABGESCHLOSSEN!")
    print("="*80)

    # Save data
    df.to_csv('processed_data_DCL.csv', index=False)
    print("\n💾 Daten gespeichert: processed_data_DCL.csv")


if __name__ == "__main__":
    main()