#!/usr/bin/env python3
"""
Black Hole Simulator - Main Entry Point
Powered by the Fractal Correction Engine (FCE)

Usage:
    python run_simulation.py                    # Default: stellar BH, circular orbit
    python run_simulation.py --preset m87       # M87* supermassive BH
    python run_simulation.py --preset stellar --orbit plunge   # --orbit: circular | radial | plunge
    python run_simulation.py --validate         # Run validation suite
    python run_simulation.py --demo             # Full demonstration
"""

import argparse
import os
import sys
import time
import numpy as np

# Put this file's own directory at the front of sys.path so the blackhole_sim imports below resolve from it
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from blackhole_sim.config import BlackHoleConfig, PRESETS
from blackhole_sim.metrics import create_metric
from blackhole_sim.geodesics.geodesic_engine import GeodesicIntegrator
from blackhole_sim.geodesics.timelike_geodesics import TimelikeGeodesic
from blackhole_sim.geodesics.null_geodesics import NullGeodesic
from blackhole_sim.physics.horizons import HorizonCalculator
from blackhole_sim.physics.frame_dragging import FrameDragging
from blackhole_sim.physics.tidal_forces import TidalForces
from blackhole_sim.physics.orbits import OrbitalPhysics
from blackhole_sim.fce_integration.fce_adapter import BlackHoleFCEAdapter
from blackhole_sim.fce_integration.geodesic_corrector import GeodesicFCECorrector
from blackhole_sim.fce_integration.curvature_mapper import CurvatureMapper
from blackhole_sim.fce_integration.trajectory_predictor import TrajectoryPredictor
from blackhole_sim.compute.gpu_engine import GPUEngine
from blackhole_sim.compute.adaptive_scheduler import AdaptiveScheduler
from blackhole_sim import constants


def print_header(config):
    """Write the simulator banner and the main configuration values to stdout.

    Reads ``metric_type``, ``mass_solar``, ``spin``, ``charge``,
    ``lifecycle_stage``, ``use_fce``, ``n_cpu_workers`` and ``use_gpu``
    from *config*. Returns nothing.
    """
    heavy_rule = "=" * 70
    fce_state = 'Enabled' if config.use_fce else 'Disabled'
    gpu_state = 'Enabled' if config.use_gpu else 'Disabled'

    banner = (
        heavy_rule,
        "     BLACK HOLE SIMULATOR - Fractal Correction Engine",
        heavy_rule,
        f"  Metric:      {config.metric_type.upper()}",
        f"  Mass:        {config.mass_solar} M_sun",
        f"  Spin:        a* = {config.spin}",
        f"  Charge:      Q* = {config.charge}",
        f"  Lifecycle:   {config.lifecycle_stage}",
        f"  FCE:         {fce_state}",
        f"  CPU Workers: {config.n_cpu_workers}",
        f"  GPU:         {gpu_state}",
        "-" * 70,
    )
    for row in banner:
        print(row)


def print_bh_properties(metric, config):
    """Report horizon, thermodynamic and ISCO data for *metric* on stdout.

    The geometric-unit values come from ``HorizonCalculator(...).summary()``
    and ``metric.isco_radius``. The closing SI section is derived from
    ``config.mass_solar`` via the helpers in ``constants``.
    """
    props = HorizonCalculator(metric.M, metric.a, metric.Q).summary()

    print("\n  BLACK HOLE PROPERTIES")
    print(f"  Event horizon:     r_+ = {props['r_plus']:.6f} M")
    # The inner horizon line is only shown when the summary reports a positive radius.
    if props['r_minus'] > 0:
        print(f"  Cauchy horizon:    r_- = {props['r_minus']:.6f} M")
    print(f"  Surface gravity:   kappa = {props['surface_gravity']:.6f}")
    print(f"  Hawking temp:      T_H = {props['T_hawking_natural']:.6e} (natural)")
    print(f"  Horizon area:      A = {props['horizon_area']:.4f} M^2")
    print(f"  BH entropy:        S = {props['horizon_area']/4:.4f}")

    # Rotation-specific quantities are skipped unless the configured spin is positive.
    if config.spin > 0:
        print(f"  Omega_horizon:     {props['Omega_H']:.6f}")
        print(f"  Penrose E_max:     {props['penrose_energy_max']:.4f} M")

    # ISCO: first try the one-argument form (True, then False). A TypeError or
    # AttributeError anywhere in that attempt switches to the zero-argument call.
    try:
        isco_pro = metric.isco_radius(True)
        print(f"  ISCO (prograde):   r = {isco_pro:.4f} M")
        if config.spin > 0:
            isco_retro = metric.isco_radius(False)
            print(f"  ISCO (retrograde): r = {isco_retro:.4f} M")
    except (TypeError, AttributeError):
        isco_plain = metric.isco_radius()
        print(f"  ISCO:              r = {isco_plain:.4f} M")

    # SI section: every helper receives mass_solar * M_sun.
    mass_scaled = config.mass_solar * constants.M_sun
    radius_m = constants.schwarzschild_radius(mass_scaled)
    temperature_k = constants.hawking_temperature_si(mass_scaled)
    print("\n  SI Units:")
    print(f"  Schwarzschild radius: {radius_m:.2e} m ({radius_m/1e3:.2f} km)")
    print(f"  Hawking temperature:  {temperature_k:.3e} K")
    entropy_kb = constants.bekenstein_hawking_entropy(mass_scaled)
    print(f"  BH entropy:          {entropy_kb:.3e} k_B")


def _relative_spread(values):
    """Return std(values) / |mean(values)|, or the plain std when the mean is zero."""
    spread = float(np.std(values))
    scale = abs(float(np.mean(values)))
    if scale == 0.0:
        # A relative figure is undefined here (e.g. zero angular momentum);
        # report the absolute spread instead of dividing by zero.
        return spread
    return spread / scale


def run_geodesic_simulation(config, orbit_type='circular', r_orbit=10.0, tau_max=1000.0):
    """Integrate one timelike geodesic and print conservation diagnostics.

    Args:
        config: Simulation configuration. This function reads ``use_fce``,
            ``fce_correction_strength``, ``fce_correction_interval`` and
            ``dt`` from it, and passes it to ``create_metric``.
        orbit_type: ``'circular'``, ``'radial'`` or ``'plunge'``. Any other
            value runs a circular orbit after printing a warning.
        r_orbit: Radius handed to the chosen orbit method.
        tau_max: Maximum proper time handed to the chosen orbit method.

    Returns:
        ``(trajectory, metric)``: the object returned by the orbit method and
        the metric built from *config*.
    """
    print(f"\n  GEODESIC SIMULATION: {orbit_type} orbit at r={r_orbit}M")
    print("-" * 70)

    metric = create_metric(config)
    t_start = time.time()

    # Initialize FCE (lazy -- FCE backends load on first use)
    fce = None
    corrector = None
    if config.use_fce:
        fce = BlackHoleFCEAdapter()
        corrector = GeodesicFCECorrector(
            fce,
            correction_strength=config.fce_correction_strength,
        )
        print(f"  FCE: enabled (lazy loading -- backends init on first correction)")

    # Create geodesic integrator
    tl = TimelikeGeodesic(metric, fce_corrector=corrector, fce_interval=config.fce_correction_interval)

    # Every orbit type uses the same step cap.
    max_step = config.dt * 10

    if orbit_type == 'radial':
        trajectory = tl.radial_infall(r_orbit, tau_max=tau_max, max_step=max_step)
    elif orbit_type == 'plunge':
        # Use 95% of the circular-orbit angular momentum when the metric can
        # provide it, otherwise the fixed fallback value 3.5.
        if hasattr(metric, 'circular_orbit_angular_momentum'):
            L_plunge = metric.circular_orbit_angular_momentum(r_orbit) * 0.95
        else:
            L_plunge = 3.5
        trajectory = tl.plunge_orbit(r_orbit, L_plunge, tau_max=tau_max, max_step=max_step)
    else:
        if orbit_type != 'circular':
            # Keep the historical fallback, but make it visible: the header
            # above already named the requested (unsupported) type.
            print(f"  Warning: unknown orbit type {orbit_type!r}; running a circular orbit instead")
        trajectory = tl.circular_orbit(r_orbit, tau_max=tau_max, max_step=max_step)

    t_elapsed = time.time() - t_start

    # Report
    E = trajectory.energy
    L = trajectory.angular_momentum
    r = trajectory.coordinates[:, 1]

    print(f"  Integration time: {t_elapsed:.2f}s")
    print(f"  Points: {len(trajectory.tau)}")
    print(f"  Proper time: {trajectory.tau[-1]:.1f}M")
    print(f"  Energy: E = {E[0]:.8f}, dE/E = {_relative_spread(E):.2e}")
    print(f"  Ang. Mom: L = {L[0]:.8f}, dL/L = {_relative_spread(L):.2e}")
    print(f"  Radius: [{np.min(r):.4f}, {np.max(r):.4f}]M")
    print(f"  4-vel norm: {trajectory.norm[0]:.10f} (expected: -1)")

    # Identity check: do not depend on how the corrector object defines truthiness.
    if corrector is not None:
        print(f"  FCE corrections: {len(corrector.correction_log)}")

    return trajectory, metric


def run_hawking_analysis(config):
    """Print Hawking temperature, entropy, luminosity, lifetime and spectrum peak.

    Natural-unit quantities are evaluated with the literal mass ``1.0`` plus
    ``config.spin`` (and ``config.charge`` where the helper takes it). The SI
    helpers receive ``config.mass_solar``.

    Returns:
        ``(omega, dN, power, T_nat)``: the three values returned by
        ``hawking_spectrum`` followed by the natural-unit temperature.
    """
    print("\n  HAWKING RADIATION ANALYSIS")
    print("-" * 70)

    # entropy_natural and evaporation_rate_natural are imported here but not
    # referenced below.
    from blackhole_sim.hawking.temperature import (
        hawking_temperature_natural,
        hawking_temperature_si,
    )
    from blackhole_sim.hawking.entropy import (
        bekenstein_hawking_entropy,
        entropy_natural,
    )
    from blackhole_sim.hawking.evaporation import (
        evaporation_lifetime_si,
        evaporation_rate_natural,
    )
    from blackhole_sim.hawking.spectrum import hawking_spectrum, total_luminosity

    spin, charge, mass_solar = config.spin, config.charge, config.mass_solar

    T_nat = hawking_temperature_natural(1.0, spin, charge)
    temperature_si = hawking_temperature_si(mass_solar)
    entropy = bekenstein_hawking_entropy(mass_solar, spin, charge)
    _lifetime_s, lifetime_yr = evaporation_lifetime_si(mass_solar)
    luminosity = total_luminosity(1.0, spin)

    print(f"  Temperature: {T_nat:.6e} (natural), {temperature_si:.3e} K")
    print(f"  Entropy: {entropy:.3e} k_B")
    print(f"  Luminosity: {luminosity:.6e} (natural units)")
    print(f"  Evaporation lifetime: {lifetime_yr:.3e} years")

    # Spectrum on 200 frequency points; the peak is where `power` is largest.
    omega, dN, power = hawking_spectrum(1.0, spin, n_freq=200)
    n_points = len(omega)
    peak_over_T = omega[np.argmax(power)] / T_nat
    print(f"  Spectrum: {n_points} frequency points, peak at omega/T = {peak_over_T:.2f}")

    return omega, dN, power, T_nat


def run_accretion_analysis(config):
    """Print a thin-disk summary, or a one-line notice when there is no accretion.

    Returns:
        The ``ThinAccretionDisk`` instance, or ``None`` when
        ``config.accretion_rate_eddington`` is zero or negative.
    """
    rate = config.accretion_rate_eddington
    if rate <= 0:
        print("\n  ACCRETION: No accretion (M_dot = 0)")
        return None

    print("\n  ACCRETION DISK ANALYSIS")
    print("-" * 70)

    from blackhole_sim.accretion.thin_disk import ThinAccretionDisk
    # eddington_luminosity is imported here but not referenced below.
    from blackhole_sim.accretion.eddington import eddington_luminosity

    disk = ThinAccretionDisk(config.mass_solar, config.spin, rate)
    info = disk.summary()

    print(f"  ISCO: {info['r_isco']:.4f} r_g")

    efficiency_pct = info['eta'] * 100
    print(f"  Efficiency: eta = {efficiency_pct:.1f}%")

    mdot_pct = info['M_dot_eddington_fraction'] * 100
    print(f"  M_dot: {info['M_dot_kg_s']:.3e} kg/s ({mdot_pct:.1f}% Eddington)")

    lum_pct = info['L_eddington_fraction'] * 100
    print(f"  Luminosity: {info['L_total_watts']:.3e} W ({lum_pct:.1f}% L_Edd)")

    print(f"  Peak temp: {info['T_peak_kelvin']:.0f} K")
    print(f"  Peak band: {info['peak_band']} (lambda = {info['peak_wavelength_m']:.2e} m)")

    return disk


def run_gw_analysis(m1_solar=30.0, m2_solar=30.0, distance_mpc=400.0):
    """Print gravitational-wave figures for a binary and build its inspiral waveform.

    Args:
        m1_solar: First component mass, printed with the unit M_sun.
        m2_solar: Second component mass, printed with the unit M_sun.
        distance_mpc: Source distance, printed with the unit Mpc.

    Returns:
        ``(t_array, h_plus, h_cross, f_gw)``: the time grid and the first
        three values returned by ``inspiral_waveform``.
    """
    print("\n  GRAVITATIONAL WAVE ANALYSIS")
    print("-" * 70)

    # gw_luminosity is imported here but not referenced below.
    from blackhole_sim.gravitational_waves.quadrupole import (
        chirp_mass, gw_strain_amplitude, time_to_merger, gw_luminosity
    )
    from blackhole_sim.gravitational_waves.ringdown import qnm_frequency_damping, ringdown_waveform
    from blackhole_sim.gravitational_waves.inspiral import inspiral_waveform

    masses = (m1_solar, m2_solar)

    Mc = chirp_mass(*masses)
    print(f"  Binary: {m1_solar} + {m2_solar} M_sun at {distance_mpc} Mpc")
    print(f"  Chirp mass: {Mc/constants.M_sun:.2f} M_sun")

    # Strain uses the literal 100.0 (reported as f=100Hz); time-to-merger uses
    # the literal 10.0 (reported as 10 r_s).
    strain = gw_strain_amplitude(*masses, distance_mpc, 100.0)
    merge_time = time_to_merger(*masses, 10.0)
    print(f"  Strain at f=100Hz: h ~ {strain:.2e}")
    print(f"  Time to merger (10 r_s): {merge_time:.2f} s")

    # Remnant used for the ringdown figures.
    M_final = (m1_solar + m2_solar) * 0.95  # ~5% radiated as GW
    a_final = 0.69  # Typical for equal-mass merger
    f_qnm, tau, Q = qnm_frequency_damping(M_final, a_final)
    print(f"  Ringdown: f_QNM = {f_qnm:.1f} Hz, tau = {tau*1000:.2f} ms, Q = {Q:.1f}")

    # Inspiral waveform on a fixed grid ending just after t = 0.
    t_array = np.linspace(-1.0, 0.05, 10000)
    h_plus, h_cross, f_gw, _phase = inspiral_waveform(
        m1_solar, m2_solar, distance_mpc, t_array, t_merge=0.0
    )

    # Ringdown waveform scaled to the inspiral's peak |h_plus|. The result is
    # computed but not returned or printed.
    ring_grid = np.linspace(0, 0.1, 5000)
    peak_amplitude = np.max(np.abs(h_plus))
    _h_ring, _, _ = ringdown_waveform(M_final, a_final, ring_grid, amplitude=peak_amplitude)

    return t_array, h_plus, h_cross, f_gw


def run_binary_simulation(m1=30.0, m2=30.0, a1=0.0, a2=0.0, separation=50.0):
    """Evolve a binary inspiral and print its parameters, merger estimate and outcome.

    Args:
        m1, m2: Component masses, printed with the unit M_sun.
        a1, a2: Component spins, forwarded to ``BinaryState``.
        separation: Initial separation, printed with the unit M.

    Returns:
        ``(timeline, merger)``: the result of ``BinaryEvolution.evolve()`` and
        the result of ``BinaryEvolution.final_state()``.
    """
    print("\n  BINARY BLACK HOLE INSPIRAL")
    print("-" * 70)

    from blackhole_sim.physics.binary_dynamics import BinaryState, BinaryEvolution

    pair = BinaryState(m1=m1, m2=m2, a1=a1, a2=a2, separation=separation)
    evolution = BinaryEvolution(pair)

    print(f"  Binary: {m1} + {m2} M_sun (q={pair.mass_ratio:.3f})")
    print(f"  Eta: {pair.eta:.4f}, Chirp mass: {pair.chirp_mass:.2f} M_sun")
    print(f"  Chi_eff: {pair.chi_eff:.3f}")
    print(f"  Initial separation: {separation} M")
    print(f"  Peters timescale: {evolution.peters_timescale():.2e} M")

    # The remnant estimate is obtained before the inspiral is integrated.
    merger = evolution.final_state()
    print("\n  MERGER RESULT (NR fitting formulae):")
    final_mass_solar = merger.final_mass * pair.total_mass
    print(f"  Final mass: {merger.final_mass:.4f} M ({final_mass_solar:.2f} M_sun)")
    print(f"  Final spin: a_f = {merger.final_spin:.4f}")
    print(f"  Energy radiated: {merger.energy_radiated_fraction * 100:.2f}%")
    print(f"  Kick velocity: {merger.kick_velocity_kms:.1f} km/s")
    print(f"  Peak luminosity: {merger.peak_luminosity:.4f} c^5/G")

    # Time the inspiral integration on its own.
    print("\n  Integrating 2.5PN inspiral...")
    started = time.time()
    timeline = evolution.evolve()
    elapsed = time.time() - started
    n_points = len(timeline.time)
    print(f"  Done in {elapsed:.1f}s ({n_points} points)")
    print(f"  Final separation: {timeline.separation[-1]:.4f} M")
    print(f"  Peak GW frequency: {np.max(timeline.gw_frequency):.6f} /M")

    return timeline, merger


def run_back_reaction(config):
    """Evolve Hawking back-reaction over a tenth of the evaporation lifetime.

    The model starts from the literal mass ``1.0`` with ``config.spin`` and
    ``config.charge``. Initial values and the first/last samples of mass, spin
    and temperature are printed.

    Returns:
        The timeline object returned by ``HawkingBackReaction.evolve``.
    """
    print("\n  HAWKING BACK-REACTION EVOLUTION")
    print("-" * 70)

    from blackhole_sim.hawking.back_reaction import HawkingBackReaction

    model = HawkingBackReaction(
        M_initial=1.0,
        a_initial=config.spin,
        Q_initial=config.charge,
        n_species=1,
    )

    initial = model.summary()
    print(f"  Initial mass: {initial['M_initial']:.4f} (natural units)")
    print(f"  Initial spin: a* = {config.spin}")
    print(f"  Hawking temperature: {initial['T_initial']:.6e}")
    print(f"  BH entropy: {initial['S_initial']:.4f}")
    print(f"  Hawking luminosity: {initial['L_initial']:.6e}")
    print(f"  Page spin factor: {initial['page_factor']:.3f}")
    print(f"  Evaporation lifetime: {initial['evaporation_lifetime']:.2e} (natural)")

    def first_to_last(series, spec):
        """Format the first and last entries of *series* with format spec *spec*."""
        return f"{series[0]:{spec}} -> {series[-1]:{spec}}"

    # Evolve for 10% of the lifetime reported by the model, in 1000 steps.
    lifetime = model.evaporation_lifetime()
    print("\n  Evolving for 10% of evaporation lifetime...")
    started = time.time()
    timeline = model.evolve(t_max_natural=lifetime * 0.1, n_steps=1000)
    elapsed = time.time() - started
    print(f"  Done in {elapsed:.2f}s")
    print(f"  Mass: {first_to_last(timeline.mass, '.6f')}")
    print(f"  Spin: {first_to_last(timeline.spin, '.4f')}")
    print(f"  Temperature: {first_to_last(timeline.temperature, '.6e')}")

    return timeline


def run_jet_analysis(config):
    """Build an ``MHDDiskModel`` from *config* and print its disk and jet figures.

    The accretion rate is ``config.accretion_rate_eddington`` when that is
    positive and ``0.01`` otherwise. The jet is included only for positive
    ``config.spin``.

    Returns:
        The ``MHDDiskModel`` instance.
    """
    print("\n  ACCRETION + JET ANALYSIS")
    print("-" * 70)

    from blackhole_sim.accretion.mhd_disk import MHDDiskModel

    # Substitute a small default rate when the config specifies no accretion.
    if config.accretion_rate_eddington > 0:
        rate = config.accretion_rate_eddington
    else:
        rate = 0.01

    model = MHDDiskModel(
        mass_solar=config.mass_solar,
        spin=config.spin,
        mdot_eddington=rate,
        include_jet=config.spin > 0,
    )

    unified = model.unified_result()
    overview = model.summary()

    print(f"  Disk type: {unified.disk_type}")
    print(f"  mdot/mdot_Edd: {unified.mdot_eddington}")
    print(f"  mdot_critical: {unified.mdot_critical:.4f}")
    print(f"  Radiative efficiency: {unified.radiative_efficiency:.4f}")
    print(f"  Disk luminosity: {unified.luminosity_erg_s:.3e} erg/s")
    print(f"  Jet power (BZ): {unified.jet_power_erg_s:.3e} erg/s")
    print(f"  Total power: {unified.total_power_erg_s:.3e} erg/s")
    print(f"  ISCO: {unified.isco_radius:.4f} M")
    print(f"  Jet dominated: {overview['jet_dominated']}")

    return model


def run_shadow(config, resolution=64):
    """Trace the shadow image for the configured metric and save it as ``shadow.png``.

    Args:
        config: Simulation configuration. ``metric_type``, ``spin`` and
            ``output_dir`` are read here, and it is passed to ``create_metric``.
        resolution: Pixel count per side; the renderer gets
            ``(resolution, resolution)``.

    Returns:
        The object returned by ``ShadowRenderer.trace_all_rays``.
    """
    print("\n  BLACK HOLE SHADOW RENDERING")
    print("-" * 70)

    from blackhole_sim.visualization.shadow_renderer import ShadowRenderer

    renderer = ShadowRenderer(
        create_metric(config),
        r_observer=500.0,
        theta_observer=np.pi / 4,
        resolution=(resolution, resolution),
        fov_M=15.0,
    )

    print(f"  Metric: {config.metric_type}, a*={config.spin}")
    observer_deg = np.degrees(renderer.theta_observer)
    print(f"  Observer: r={renderer.r_observer}M, theta={observer_deg:.0f}deg")
    print(f"  Resolution: {resolution}x{resolution}")

    # Ray tracing. disk_inner is passed as None.
    shadow = renderer.trace_all_rays(
        lambda_max=1500.0, disk_inner=None, disk_outer=50.0
    )

    stats = renderer.summary(shadow)
    print("\n  Results:")
    print(f"  Captured: {stats['n_captured']} ({stats['capture_fraction']*100:.1f}%)")
    print(f"  Disk hits: {stats['n_disk_hit']}")
    print(f"  Scattered: {stats['n_scattered']}")
    # The mean redshift line is only printed when at least one ray hit the disk.
    if stats['n_disk_hit'] > 0:
        print(f"  Mean disk redshift: {stats['mean_redshift_disk']:.4f}")
    print(f"  Max half-orbits: {stats['max_n_orbits']:.1f}")

    # Make sure the output directory exists, then hand the target path to the renderer.
    target_dir = config.output_dir
    os.makedirs(target_dir, exist_ok=True)
    renderer.render_image(shadow, save_path=os.path.join(target_dir, 'shadow.png'))

    return shadow


def run_plot_all(config):
    """Run each simulation mode in turn and save its figures under ``config.output_dir``.

    Six stages are executed in order: geodesic orbit, binary inspiral, IMR
    waveform, Hawking back-reaction, accretion + jet (skipped unless
    ``config.spin`` is positive) and the shadow at 32x32. Each stage prints
    the file names it saved. Returns nothing.
    """
    # Select the Agg backend before pyplot is imported.
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    output_dir = config.output_dir
    os.makedirs(output_dir, exist_ok=True)

    def target(filename):
        """Return the path of *filename* inside the output directory."""
        return os.path.join(output_dir, filename)

    def announce(filename):
        """Print the confirmation line for a saved figure."""
        print(f"    Saved: {filename}")

    print("\n  GENERATING ALL VISUALIZATIONS")
    print("=" * 70)

    # ---- Stage 1: geodesic orbit and dashboard ----
    print("\n  [1/6] Geodesic simulation...")
    trajectory, metric = run_geodesic_simulation(config, 'circular', r_orbit=10.0, tau_max=500.0)

    from blackhole_sim.visualization.geodesic_plots import (
        plot_orbit_2d, plot_conservation_laws, plot_orbit_3d, plot_effective_potential
    )
    from blackhole_sim.visualization.dashboard import (
        simulation_dashboard, back_reaction_dashboard, binary_dashboard, shadow_dashboard
    )

    plot_orbit_2d(trajectory, metric, save_path=target('orbit.png'))
    announce('orbit.png')
    plot_orbit_3d(trajectory, metric, save_path=target('orbit_3d.png'))
    announce('orbit_3d.png')
    plot_conservation_laws(trajectory, save_path=target('conservation.png'))
    announce('conservation.png')
    simulation_dashboard(trajectory, metric, config, save_path=target('dashboard.png'))
    announce('dashboard.png')
    potential_args = [3.0, 3.464, 4.0, 5.0]
    plot_effective_potential(metric, potential_args,
                             save_path=target('effective_potential.png'))
    announce('effective_potential.png')
    plt.close('all')

    # ---- Stage 2: binary inspiral ----
    print("\n  [2/6] Binary inspiral...")
    from blackhole_sim.visualization.waveform_plots import plot_binary_orbit
    timeline, merger = run_binary_simulation(30.0, 30.0)
    binary_dashboard(timeline, merger, save_path=target('binary_dashboard.png'))
    announce('binary_dashboard.png')
    plot_binary_orbit(timeline, merger, save_path=target('binary_orbit.png'))
    announce('binary_orbit.png')
    plt.close('all')

    # ---- Stage 3: IMR waveform ----
    print("\n  [3/6] IMR waveform...")
    from blackhole_sim.gravitational_waves.imr_waveform import IMRWaveform
    from blackhole_sim.visualization.waveform_plots import plot_imr_waveform, plot_spectrogram
    waveform = IMRWaveform(36.0, 29.0, distance_mpc=410.0)
    time_domain = waveform.generate_time_domain()
    plot_imr_waveform(time_domain, save_path=target('imr_waveform.png'))
    announce('imr_waveform.png')
    # The spectrogram is only drawn when there are more than 512 samples.
    if len(time_domain.h_plus) > 512:
        sample_step = time_domain.time[1] - time_domain.time[0]
        plot_spectrogram(time_domain.h_plus, sample_step,
                         save_path=target('spectrogram.png'))
        announce('spectrogram.png')
    plt.close('all')

    # ---- Stage 4: Hawking back-reaction ----
    print("\n  [4/6] Hawking back-reaction...")
    evaporation = run_back_reaction(config)
    back_reaction_dashboard(evaporation, save_path=target('back_reaction.png'))
    announce('back_reaction.png')
    plt.close('all')

    # ---- Stage 5: accretion + jet (positive spin only) ----
    print("\n  [5/6] Accretion + jet...")
    if config.spin > 0:
        mhd = run_jet_analysis(config)
        from blackhole_sim.visualization.accretion_render import (
            plot_unified_accretion, plot_jet_structure, plot_disk_cross_section
        )
        plot_unified_accretion(mhd, save_path=target('unified_accretion.png'))
        announce('unified_accretion.png')

        if mhd.include_jet and mhd.spin > 0:
            jet = mhd._ensure_jet()
            if jet is not None:
                plot_jet_structure(jet, save_path=target('jet_structure.png'))
                announce('jet_structure.png')

        # The cross-section figure is drawn only for the 'adaf' disk type.
        unified = mhd.unified_result()
        if unified.disk_type == 'adaf':
            thick_disk = mhd._ensure_thick_disk()
            plot_disk_cross_section(thick_disk, save_path=target('disk_cross_section.png'))
            announce('disk_cross_section.png')
        plt.close('all')
    else:
        print("    Skipping jet (a*=0)")

    # ---- Stage 6: shadow (small resolution for speed) ----
    print("\n  [6/6] Black hole shadow (32x32)...")
    shadow = run_shadow(config, resolution=32)
    shadow_dashboard(shadow, save_path=target('shadow_dashboard.png'))
    announce('shadow_dashboard.png')
    plt.close('all')

    closing_rule = '=' * 70
    print("\n" + closing_rule)
    print(f"  ALL VISUALIZATIONS SAVED to {output_dir}/")
    print(closing_rule)


def run_compute_status():
    """Print CPU, GPU and scheduler information for the current machine.

    The worker count shown is half the detected core count with a floor of
    one. GPU details come from ``GPUEngine.status()`` and
    ``GPUEngine.memory_info()``. Returns nothing.
    """
    print("\n  COMPUTE STATUS")
    print("-" * 70)

    # os.cpu_count() returns None when the count cannot be determined. Guard
    # it so the worker line does not fail with a TypeError on `None // 2`.
    n_cores = os.cpu_count()
    print(f"  CPU cores: {n_cores if n_cores is not None else 'unknown'}")
    print(f"  Workers: {max((n_cores or 1) // 2, 1)}")

    gpu = GPUEngine()
    gpu_status = gpu.status()
    if gpu_status.get('gpu_available'):
        mem = gpu.memory_info()
        print(f"  GPU: {mem.get('device_name', 'Unknown')}")
        print(f"  VRAM: {mem.get('total_gb', 0):.1f} GB total, {mem.get('free_gb', 0):.1f} GB free")
    else:
        print("  GPU: Not available (CPU-only mode)")

    scheduler = AdaptiveScheduler(gpu_engine=gpu)
    print(f"  GPU batch threshold: {scheduler.GPU_BATCH_THRESHOLD}")
    print(f"  Numba CUDA: {gpu_status.get('numba_available', False)}")


def run_validation():
    """Print the suite banner, then return ``run_all_validations(verbose=True)``."""
    for banner_line in ("\n  VALIDATION SUITE", "=" * 70):
        print(banner_line)

    # Imported only after the banner has been printed.
    from blackhole_sim.validation.analytic_tests import run_all_validations

    outcome = run_all_validations(verbose=True)
    return outcome


def run_demo(config=None):
    """Run the analysis modes back to back, then save four orbit figures.

    Args:
        config: Simulation configuration. When ``None``, ``PRESETS['stellar']()``
            is used.

    The accretion analysis runs only for a positive
    ``config.accretion_rate_eddington`` and the jet analysis only for a
    positive ``config.spin``. Returns nothing.
    """
    config = PRESETS['stellar']() if config is None else config

    print_header(config)
    metric = create_metric(config)
    print_bh_properties(metric, config)

    # The geodesic run returns its own metric, which is the one used for plotting.
    trajectory, metric = run_geodesic_simulation(config, 'circular', r_orbit=10.0, tau_max=500.0)

    run_hawking_analysis(config)

    if config.accretion_rate_eddington > 0:
        run_accretion_analysis(config)

    run_gw_analysis()

    if config.spin > 0:
        run_jet_analysis(config)

    run_back_reaction(config)
    run_compute_status()

    print("\n  GENERATING VISUALIZATIONS")
    print("-" * 70)

    output_dir = config.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Select the Agg backend before pyplot is imported.
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt

    from blackhole_sim.visualization.geodesic_plots import plot_orbit_2d, plot_conservation_laws, plot_orbit_3d
    from blackhole_sim.visualization.dashboard import simulation_dashboard

    # Each job pairs an output file name with the call that draws it.
    figure_jobs = (
        ('orbit.png',
         lambda path: plot_orbit_2d(trajectory, metric, save_path=path)),
        ('conservation.png',
         lambda path: plot_conservation_laws(trajectory, save_path=path)),
        ('dashboard.png',
         lambda path: simulation_dashboard(trajectory, metric, config, save_path=path)),
        ('orbit_3d.png',
         lambda path: plot_orbit_3d(trajectory, metric, save_path=path)),
    )
    for filename, draw in figure_jobs:
        draw(os.path.join(output_dir, filename))
        print(f"  Saved: {output_dir}/{filename}")

    plt.close('all')

    closing_rule = '=' * 70
    print("\n" + closing_rule)
    print("  SIMULATION COMPLETE")
    print(closing_rule)


def _build_arg_parser():
    """Create the command-line parser with every simulator option."""
    parser = argparse.ArgumentParser(description='Black Hole Simulator (FCE-powered)')
    parser.add_argument('--preset', type=str, default='stellar',
                        choices=list(PRESETS.keys()),
                        help='Black hole preset')
    parser.add_argument('--mass', type=float, default=None, help='Mass in solar masses')
    parser.add_argument('--spin', type=float, default=None, help='Spin parameter a*')
    parser.add_argument('--charge', type=float, default=None, help='Charge parameter Q*')
    parser.add_argument('--orbit', type=str, default='circular',
                        choices=['circular', 'radial', 'plunge'],
                        help='Orbit type')
    parser.add_argument('--r-orbit', type=float, default=10.0, help='Orbital radius (M)')
    parser.add_argument('--tau-max', type=float, default=500.0, help='Max proper time (M)')
    parser.add_argument('--no-fce', action='store_true', help='Disable FCE correction')
    parser.add_argument('--no-gpu', action='store_true', help='Disable GPU')
    parser.add_argument('--validate', action='store_true', help='Run validation suite')
    parser.add_argument('--demo', action='store_true', help='Full demonstration')
    parser.add_argument('--binary', action='store_true', help='Run binary BH inspiral')
    parser.add_argument('--binary-m1', type=float, default=30.0, help='Primary mass (M_sun)')
    parser.add_argument('--binary-m2', type=float, default=30.0, help='Secondary mass (M_sun)')
    parser.add_argument('--binary-sep', type=float, default=50.0, help='Initial separation (M)')
    parser.add_argument('--back-reaction', action='store_true', help='Run Hawking back-reaction')
    parser.add_argument('--jet', action='store_true', help='Run accretion + jet analysis')
    parser.add_argument('--shadow', action='store_true', help='Render BH shadow image')
    parser.add_argument('--shadow-res', type=int, default=64, help='Shadow resolution (NxN)')
    parser.add_argument('--plot-all', action='store_true', help='Generate all visualizations')
    parser.add_argument('--output', type=str, default='results', help='Output directory')
    return parser


def _config_from_args(args):
    """Build the preset config, overriding only the options the user actually set."""
    overrides = {'output_dir': args.output}
    # Options left at None keep whatever the preset defines.
    for option, field in (('mass', 'mass_solar'), ('spin', 'spin'), ('charge', 'charge')):
        value = getattr(args, option)
        if value is not None:
            overrides[field] = value
    if args.no_fce:
        overrides['use_fce'] = False
    if args.no_gpu:
        overrides['use_gpu'] = False
    return PRESETS[args.preset](**overrides)


def main(argv=None):
    """Parse command-line options and run the selected simulation mode.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing callers are unaffected.

    Exactly one mode runs. When several mode flags are given the first match
    in this order wins: ``--validate``, ``--plot-all``, ``--demo``,
    ``--binary``, ``--back-reaction``, ``--jet``, ``--shadow``; with none of
    them, a single geodesic simulation runs.

    Raises:
        SystemExit: always for ``--validate`` (status 0 when the suite result
            is truthy, 1 otherwise), and from argparse on invalid arguments.
    """
    args = _build_arg_parser().parse_args(argv)

    # Validation needs no configuration and ends the process with its own status.
    if args.validate:
        success = run_validation()
        sys.exit(0 if success else 1)

    config = _config_from_args(args)

    if args.plot_all:
        print_header(config)
        run_plot_all(config)
    elif args.demo:
        # run_demo prints the header itself.
        run_demo(config)
    elif args.binary:
        print_header(config)
        run_binary_simulation(args.binary_m1, args.binary_m2, separation=args.binary_sep)
    elif args.back_reaction:
        print_header(config)
        run_back_reaction(config)
    elif args.jet:
        print_header(config)
        run_jet_analysis(config)
    elif args.shadow:
        print_header(config)
        run_shadow(config, resolution=args.shadow_res)
    else:
        print_header(config)
        metric = create_metric(config)
        print_bh_properties(metric, config)
        run_geodesic_simulation(config, args.orbit, args.r_orbit, args.tau_max)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
