from __future__ import annotations

import sys
from dataclasses import dataclass, replace
from pathlib import Path
from string import Template
import json

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

# NOTE(review): NumPy 2.0 removed the ``np.product`` alias; this restores it as
# ``np.prod``. Presumably required by py4cats (imported below) — confirm which
# dependency still calls ``np.product``.
if not hasattr(np, "product"):
    np.product = np.prod

# All paths are anchored at the directory containing this script.
PROJECT_ROOT = Path(__file__).resolve().parent
# Optional vendored site-packages (./.pydeps) holding a local py4cats install.
LOCAL_SITE_PACKAGES = PROJECT_ROOT / ".pydeps"
# Output artifact paths (figures, GIF animations, HTML dashboard).
FIGURE_PATH = PROJECT_ROOT / "o2_a_band_transit.png"
SNR_WAVE_PATH = PROJECT_ROOT / "snr_vs_wavelength.png"
TIME_MAP_PATH = PROJECT_ROOT / "snr_hours_map.png"
SNR_SCALING_PATH = PROJECT_ROOT / "snr_scaling.png"
DISTANCE_CURVE_PATH = PROJECT_ROOT / "snr_vs_distance.png"
SLANT_PATH_FIG = PROJECT_ROOT / "slant_paths_3d.png"
ABAND_ANIMATION = PROJECT_ROOT / "o2_band_animation.gif"
TRANSIT_ANIMATION = PROJECT_ROOT / "transit_slantpaths.gif"
DASHBOARD_HTML = PROJECT_ROOT / "dashboard.html"
VMR_CURVE_PATH = PROJECT_ROOT / "o2_vmr_threshold.png"
INSTRUMENT_COMPARE_PATH = PROJECT_ROOT / "instrument_compare.png"
# Support a local py4cats install (./.pydeps) to avoid relying on system
# site-packages. The directory is prepended so it is searched first; this must
# run before the py4cats imports that follow.
if LOCAL_SITE_PACKAGES.exists():
    site_path = str(LOCAL_SITE_PACKAGES)
    if site_path not in sys.path:
        sys.path.insert(0, site_path)

from py4cats.lbl.lbl2xs import lbl2xs
from py4cats.lbl.lines import lineArray

# =========================
# 1. PHYSICAL CONSTANTS
# =========================

k_B = 1.380649e-16      # Boltzmann constant, erg/K (cgs)
g_p = 980.0             # surface gravity on the planet, cm/s^2 (≈ Earth)
R_earth = 6.371e8       # Earth radius, cm
R_sun   = 6.957e10      # Solar radius, cm
HITRAN_REF_TEMP = 296.0 # HITRAN reference temperature, K
HITRAN_REF_PRESSURE = 1.01325e6  # HITRAN reference pressure, dyn/cm^2 (1 atm)
h_planck = 6.62607015e-27        # Planck constant, erg*s
c_light = 2.99792458e10          # speed of light, cm/s
pc_to_cm = 3.085677581e18        # one parsec in cm
# Cloud-deck altitude in km. NOTE(review): not referenced in this part of the
# file; presumably passed as ``cloud_top_km`` to build_exoearth_atmosphere —
# confirm at the call site.
CLOUD_TOP_KM = 15.0


@dataclass
class InstrumentConfig:
    """Telescope, detector, and target parameters for the SNR / exposure-time model.

    Read by ``photon_rate_per_bin``, ``snr_time_requirements`` and
    ``effective_noise_floor``. Variants (for example a different
    ``distance_pc``) are derived with ``dataclasses.replace``.
    """

    name: str = "HWO"                       # label used in printed tables and plots
    telescope_diameter_m: float = 6.0       # circular aperture diameter, m
    optical_throughput: float = 0.2         # end-to-end efficiency multiplying the photon rate
    spectral_resolution: float = 150.0      # R = λ/Δλ; each spectral bin spans Δλ = λ/R
    exposure_time_s: float = 1000.0         # single-frame time; read noise is charged once per frame
    read_noise_e: float = 5.0               # read noise per frame, e⁻ rms
    dark_current_e_per_s: float = 0.001     # dark current, e⁻/s
    distance_pc: float = 10.0               # distance to the target system, pc
    star_temperature_K: float = 5778.0      # blackbody temperature of the host star, K
    stellar_radius_cm: float = R_sun        # host-star radius, cm
    noise_floor_ppm: float = 7.0            # systematic floor at the reference distance, ppm
    noise_floor_reference_pc: float = 10.0  # distance at which the floor equals noise_floor_ppm
    noise_floor_exponent: float = 1.0       # floor ∝ (distance / reference)**exponent beyond reference

# Planet and star (tweak to match a preferred exoplanet)
# Planet and star (tweak to match a preferred exoplanet).
# NOTE: R_star is independent of InstrumentConfig.stellar_radius_cm (both
# default to R_sun); keep them in sync when changing the host star.
R_p = R_earth           # planet radius, cm
R_star = R_sun          # stellar radius, cm

# Wavelength range (O2 A-band ~ 760 nm)
lambda_min_nm = 740.0   # lower edge of the modelled window, nm
lambda_max_nm = 780.0   # upper edge of the modelled window, nm


# =========================
# 2. EXO-EARTH ATMOSPHERE
# =========================

def build_exoearth_atmosphere(n_layers=80, z_top_km=120.0, cloud_top_km=None, *,
                              vmr_O2=0.21, scale_height_km=7.5,
                              surface_pressure=1.01325e6):
    """Construct a simple layered Earth-analog atmosphere.

    Pressure falls off exponentially with a single constant scale height,
    independent of the temperature profile. Temperature follows a 6.5 K/km
    lapse rate from 288 K at the surface and is clipped to [200, 288] K.
    Number densities come from the ideal-gas law n = p / (k_B T).

    Parameters
    ----------
    n_layers : int
        Number of equal-thickness layers from the surface to ``z_top_km``.
    z_top_km : float
        Top of the model atmosphere (km).
    cloud_top_km : float or None
        If provided, O₂ is set to zero in every layer whose mid-point lies
        above this altitude, to mimic clouds/haze.
    vmr_O2 : float, keyword-only
        O₂ volume mixing ratio (default 0.21, present-day Earth).
    scale_height_km : float, keyword-only
        Pressure scale height in km (default 7.5).
    surface_pressure : float, keyword-only
        Surface pressure in dyn/cm² (default 1.01325e6, i.e. 1 atm).

    Returns
    -------
    dict
        Keys ``z`` (cm), ``z_km`` (km), ``p`` (dyn/cm²), ``T`` (K),
        ``n_total`` and ``n_O2`` (cm⁻³). All arrays are evaluated at layer
        mid-points and have length ``n_layers``.

    Raises
    ------
    ValueError
        If ``n_layers < 1`` or ``scale_height_km <= 0``.
    """
    if n_layers < 1:
        raise ValueError(f"n_layers must be >= 1, got {n_layers}")
    if scale_height_km <= 0:
        raise ValueError(f"scale_height_km must be positive, got {scale_height_km}")

    # Layer mid-points: half a layer above the surface up to half a layer below the top.
    half_layer_km = 0.5 * z_top_km / n_layers
    z_km = np.linspace(half_layer_km, z_top_km - half_layer_km, n_layers)
    z = z_km * 1.0e5  # km -> cm

    # Pressure (dyn/cm^2): exponential atmosphere with a constant scale height.
    p = surface_pressure * np.exp(-z / (scale_height_km * 1.0e5))

    # Crude T(z): 288 K at the surface, 6.5 K/km lapse rate, floored at 200 K.
    T = np.clip(288.0 - 6.5 * z_km, 200.0, 288.0)

    # Total number density from the ideal-gas law.
    n_total = p / (k_B * T)
    n_O2 = vmr_O2 * n_total

    if cloud_top_km is not None:
        # Layers above the cloud deck contribute no O2 absorption.
        n_O2 = np.where(z_km <= cloud_top_km, n_O2, 0.0)

    return {
        "z": z,             # cm
        "z_km": z_km,       # km
        "p": p,             # dyn/cm^2
        "T": T,             # K
        "n_total": n_total, # cm^-3
        "n_O2": n_O2,       # cm^-3
    }


# =========================
# 3. O2 LINE INGEST (HITRAN)
# =========================

def _slice_number(line, start, end, *, cast=float, default=None, field_name=""):
    """Return a number pulled from fixed-width HITRAN columns."""
    token = line[start:end]
    text = token.strip()
    if not text:
        if default is not None:
            return default
        raise ValueError(f"Failed to read {field_name or 'value'} (cols {start}:{end})")
    text = text.replace("D", "E").replace("d", "E")
    return cast(text)


def _read_hitran_extract(hitran_file):
    """Parse a header-free HITRAN ``.par`` extract into a structured array.

    Only seven fields of each fixed-width record are read. Column bounds are
    0-based and end-exclusive:

    ======  =======  ==============================
    field   columns  meaning
    ======  =======  ==============================
    ``v``   3-15     line position
    ``S``   15-25    line strength
    ``E``   45-55    lower-state energy
    ``a``   35-40    air-broadened width
    ``s``   40-45    self-broadened width (blank→0)
    ``n``   55-59    air-width T exponent (blank→0)
    ``d``   59-67    air pressure shift (blank→0)
    ======  =======  ==============================

    Blank lines are skipped. The result is sorted by ascending ``v``.

    Raises
    ------
    ValueError
        If a non-blank row is shorter than 67 characters, if a required field
        cannot be parsed, or if the file holds no data rows.
    """
    # (field name, start column, end column, error label, value used when blank)
    layout = (
        ("v", 3, 15, "position", None),
        ("S", 15, 25, "strength", None),
        ("E", 45, 55, "lower-state energy", None),
        ("a", 35, 40, "air-width", None),
        ("s", 40, 45, "self-width", 0.0),
        ("n", 55, 59, "n-air", 0.0),
        ("d", 59, 67, "pressure-shift", 0.0),
    )
    min_width = 67  # rightmost column read above

    rows = []
    with open(hitran_file, "r", encoding="ascii", errors="ignore") as handle:
        for row_number, raw in enumerate(handle, start=1):
            if not raw.strip():
                continue
            text = raw.rstrip("\r\n")
            if len(text) < min_width:
                raise ValueError(f"Row {row_number} is too short for the HITRAN format")
            rows.append(tuple(
                _slice_number(text, lo, hi, field_name=label, default=fallback)
                for _, lo, hi, label, fallback in layout
            ))

    if not rows:
        raise ValueError("No HITRAN data rows found in the file")

    table = np.array(rows, dtype=[(name, float) for name, *_ in layout])
    return table[np.argsort(table["v"])]


def load_o2_lines(hitran_file):
    """Read a header-free HITRAN extract and wrap it as a py4cats ``lineArray``.

    The line list is tagged with the HITRAN reference pressure and temperature
    (``HITRAN_REF_PRESSURE``, ``HITRAN_REF_TEMP``) and molecule ``"O2"``, and
    a one-line load summary is printed.
    """
    o2_lines = lineArray(
        _read_hitran_extract(hitran_file),
        p=HITRAN_REF_PRESSURE,
        t=HITRAN_REF_TEMP,
        molec="O2",
    )
    print(f"Loaded {o2_lines.size} O2 lines from {hitran_file}")
    return o2_lines


def planck_lambda(lam_cm, T):
    """Planck spectral radiance B_λ(T) in erg / (s · cm² · sr · cm).

    B_λ = (2 h c² / λ⁵) / (exp(h c / (λ k_B T)) − 1)

    The denominator uses ``np.expm1``. It stays accurate when
    h c / (λ k_B T) ≪ 1 (the Rayleigh–Jeans limit), where ``exp(x) - 1``
    loses precision through cancellation.

    Parameters
    ----------
    lam_cm : float or ndarray
        Wavelength(s) in cm.
    T : float or ndarray
        Temperature in K (broadcast against ``lam_cm``).
    """
    x = h_planck * c_light / (lam_cm * k_B * T)
    prefactor = (2.0 * h_planck * c_light**2) / lam_cm**5
    return prefactor / np.expm1(x)


def stellar_flux_at_earth(lam_cm, config: InstrumentConfig):
    """Host-star spectral flux density at the observer, erg / (s · cm² · cm).

    The star is a blackbody at ``config.star_temperature_K``. Its surface flux
    π·B_λ is diluted by (R_* / d)², with R_* = ``config.stellar_radius_cm``
    and d = ``config.distance_pc`` converted to cm.
    """
    observer_distance_cm = config.distance_pc * pc_to_cm
    dilution = (config.stellar_radius_cm / observer_distance_cm)**2
    emergent = np.pi * planck_lambda(lam_cm, config.star_temperature_K)
    return emergent * dilution


def photon_rate_per_bin(lam_nm, config: InstrumentConfig):
    """Detected photon rate (photons/s) in one spectral resolution element.

    For each wavelength λ the bin width is Δλ = λ / R, with
    R = ``config.spectral_resolution``. The stellar flux from
    ``stellar_flux_at_earth`` is multiplied by the area of a circular aperture
    of diameter ``config.telescope_diameter_m``, by
    ``config.optical_throughput`` and by Δλ. The product is then divided by
    the photon energy h c / λ.
    """
    wavelength_cm = lam_nm * 1e-7  # nm -> cm
    aperture_radius_cm = 0.5 * config.telescope_diameter_m * 100.0
    collecting_area = np.pi * aperture_radius_cm**2
    bin_width_cm = wavelength_cm / config.spectral_resolution
    photon_energy = h_planck * c_light / wavelength_cm
    received = stellar_flux_at_earth(wavelength_cm, config)
    return received * collecting_area * config.optical_throughput * bin_width_cm / photon_energy


def snr_time_requirements(lam_nm, signal_depth, config: InstrumentConfig, snr_targets):
    """Integration time (hours) needed to reach each SNR in ``snr_targets``.

    Noise model per spectral bin, with N_ph the detected photon rate:

        V = N_ph + dark + read_noise² / exposure_time        [counts/s]
        σ_ph²(t) = V / (N_ph² t),   σ² = σ_ph² + σ_floor²

    Solving signal_depth / σ = SNR for t gives

        t = V / (N_ph² [(signal_depth / SNR)² − σ_floor²])

    A bin gets ``inf`` hours when any of these holds:

    * the bracket is non-positive (the floor alone already exceeds the allowed noise);
    * the denominator is not finite;
    * the signal depth is not positive.

    Returns
    -------
    tuple
        ``(photon_rate, noise_norm, snr_tables, sigma_floor)``. ``snr_tables``
        maps each target to an array of hours, and ``sigma_floor`` is the
        fractional systematic floor.
    """
    rate = photon_rate_per_bin(lam_nm, config)
    variance_rate = rate + config.dark_current_e_per_s + (config.read_noise_e**2 / config.exposure_time_s)
    floor_frac, _floor_ppm = effective_noise_floor(config)

    def hours_needed(target):
        allowed_var = np.where(target > 0, (signal_depth / target)**2, np.inf)
        headroom = allowed_var - floor_frac**2
        denominator = rate**2 * headroom
        usable = (denominator > 0) & np.isfinite(denominator) & (signal_depth > 0)
        seconds = np.where(usable, variance_rate / denominator, np.inf)
        return seconds / 3600.0

    tables = {target: hours_needed(target) for target in snr_targets}
    return rate, variance_rate, tables, floor_frac


def snr_for_time(total_hours, depth, photon_rate, noise_norm, sigma_floor):
    """SNR on ``depth`` after ``total_hours`` of integration.

    The photon-noise term is σ_ph = sqrt(noise_norm / t) / photon_rate, with t
    in seconds. It is combined in quadrature with ``sigma_floor``. Both t and
    the photon rate are clamped to at least 1e-30 to avoid division by zero.
    All arguments broadcast as NumPy arrays.
    """
    seconds = np.maximum(total_hours * 3600.0, 1e-30)
    photon_sigma = np.sqrt(noise_norm / seconds) / np.maximum(photon_rate, 1e-30)
    combined_sigma = np.sqrt(photon_sigma**2 + sigma_floor**2)
    return depth / combined_sigma


def summarize_snr(lam_nm, depth, signal_depth, photon_rate, noise_norm, snr_tables, config: InstrumentConfig, sigma_floor):
    """Print an SNR summary for key spectral points and save three diagnostic plots.

    Key points
    ----------
    * "Line core": the bin with the largest ``depth``.
    * "Continuum": the bin whose depth is closest to the 20th percentile of ``depth``.
    * "760 nm": the bin nearest 760 nm.

    Plots written
    -------------
    * ``SNR_WAVE_PATH``: SNR after 100 h versus wavelength.
    * ``TIME_MAP_PATH``: hours needed for SNR = 5 versus wavelength.
    * ``SNR_SCALING_PATH``: SNR growth with integration time at the line core,
      annotated with the hours needed for SNR = 3, 5 and 10. Targets that can
      never be reached (infinite hours) are not annotated.

    Parameters
    ----------
    lam_nm : ndarray
        Wavelength grid (nm).
    depth : ndarray
        Fractional transit depth on ``lam_nm`` (shown in ppm).
    signal_depth : ndarray
        Continuum-subtracted signal used as the SNR numerator.
    photon_rate, noise_norm : ndarray
        Per-bin photon rate and variance rate, as returned by ``snr_time_requirements``.
    snr_tables : dict
        Hours-to-target arrays keyed by SNR; must contain the keys 3, 5 and 10.
    config : InstrumentConfig
        Not used here; kept for interface compatibility.
    sigma_floor : float
        Fractional systematic noise floor.

    Returns
    -------
    tuple
        ``(lam_core_idx, table_rows)``. Each row is
        ``(label, λ, depth_ppm, photon_rate, h_SNR3, h_SNR5, h_SNR10)``, one per key point.
    """
    targets = (3, 5, 10)
    depth_ppm = depth * 1e6
    lam_core_idx = int(np.argmax(depth_ppm))
    continuum_level = np.percentile(depth_ppm, 20)
    lam_cont_idx = int(np.argmin(np.abs(depth_ppm - continuum_level)))
    pivot_nm = 760.0
    lam_pivot_idx = int(np.argmin(np.abs(lam_nm - pivot_nm)))

    key_indices = [
        ("Line core", lam_core_idx),
        ("Continuum", lam_cont_idx),
        ("760 nm", lam_pivot_idx),
    ]

    headers = [
        "Feature", "λ (nm)", "Depth (ppm)", "Photon rate (s^-1)",
        "Hours SNR=3", "Hours SNR=5", "Hours SNR=10",
    ]
    print("\n=== Summary table for key spectral points ===")
    print(" | ".join(headers))
    print("-" * 110)
    table_rows = []
    for label, idx in key_indices:
        hours = [snr_tables[target][idx] for target in targets]
        row = [
            label,
            f"{lam_nm[idx]:7.1f}",
            f"{depth_ppm[idx]:9.1f}",
            f"{photon_rate[idx]:12.3e}",
            *(f"{h:10.2f}" for h in hours),
        ]
        print(" | ".join(row))
        table_rows.append((label, lam_nm[idx], depth_ppm[idx], photon_rate[idx], *hours))

    hours_grid = np.logspace(-1, 2.7, 200)  # 0.1 h to ~500 h
    core_photon = photon_rate[lam_core_idx]
    core_noise = noise_norm[lam_core_idx]
    core_signal = signal_depth[lam_core_idx]
    snr_scaling = snr_for_time(hours_grid, core_signal, core_photon, core_noise, sigma_floor)

    plt.figure(figsize=(8, 4))
    plt.plot(lam_nm, snr_for_time(100.0, signal_depth, photon_rate, noise_norm, sigma_floor), color="#0072B2")
    plt.xlabel("Wavelength (nm)")
    plt.ylabel("SNR (100 h)")
    plt.title("SNR across the A-band (100 h integration)")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(SNR_WAVE_PATH, dpi=200)

    plt.figure(figsize=(8, 4))
    plt.semilogy(lam_nm, snr_tables[5], color="#D55E00")
    plt.xlabel("Wavelength (nm)")
    plt.ylabel("Hours for SNR=5")
    plt.title("Integration time needed for SNR=5")
    plt.grid(True, which="both", alpha=0.3)
    plt.tight_layout()
    plt.savefig(TIME_MAP_PATH, dpi=200)

    plt.figure(figsize=(6, 4))
    plt.loglog(hours_grid, snr_scaling, color="#009E73")
    for target in targets:
        hours_needed = snr_tables[target][lam_core_idx]
        # An unreachable target (floor-limited) has infinite hours. Placing a
        # marker/label at a non-finite position only produces Matplotlib
        # "posx and posy should be finite values" warnings, so skip it.
        if not np.isfinite(hours_needed):
            continue
        plt.scatter(hours_needed, target, color="#CC79A7")
        plt.text(hours_needed * 1.05, target * 1.05, f"SNR={target}\n{hours_needed:.1f} h")
    plt.xlabel("Integration time (hours)")
    plt.ylabel("SNR at O$_2$ core")
    plt.title("SNR growth at the O$_2$ line core")
    plt.grid(True, which="both", alpha=0.3)
    plt.tight_layout()
    plt.savefig(SNR_SCALING_PATH, dpi=200)

    # NOTE: FIGURE_PATH (the spectrum) is not written here; it is listed for the
    # run summary and presumably saved elsewhere.
    print("\nSaved plots:")
    print(f"  - {FIGURE_PATH.name} (spectrum)")
    print(f"  - {SNR_WAVE_PATH.name} (SNR vs wavelength)")
    print(f"  - {TIME_MAP_PATH.name} (hours for SNR=5)")
    print(f"  - {SNR_SCALING_PATH.name} (SNR scaling)")

    return lam_core_idx, table_rows


def distance_sensitivity_analysis(lam_nm, signal_depth, core_idx, snr_target, base_config: InstrumentConfig,
                                  *, distances=None):
    """Integration time and 100 h SNR at the O₂ core as a function of target distance.

    For each distance, ``base_config`` is copied with ``distance_pc``
    replaced. The distance changes the stellar photon rate and, through
    ``effective_noise_floor``, the systematic noise floor. The function does
    the following:

    * saves a log-log plot of hours versus distance to ``DISTANCE_CURVE_PATH``;
    * prints a table of hours and 100 h SNR per distance;
    * prints the largest distance reachable within a 100 h budget.

    Parameters
    ----------
    lam_nm : ndarray
        Wavelength grid (nm).
    signal_depth : ndarray
        Continuum-subtracted signal on ``lam_nm``.
    core_idx : int
        Index of the O₂ line core in ``lam_nm``.
    snr_target : float
        SNR whose required integration time is reported.
    base_config : InstrumentConfig
        Template configuration; only ``distance_pc`` is varied.
    distances : sequence of float, optional, keyword-only
        Distances in pc (default 5, 10, 20, 30).

    Returns
    -------
    tuple
        ``(distances, time_hours, snr_100h)``: the distance array, then lists
        of hours-to-target and SNR after 100 h.
    """
    distances = np.asarray([5.0, 10.0, 20.0, 30.0] if distances is None else distances, dtype=float)
    time_hours = []
    snr_100h = []
    for d in distances:
        cfg = replace(base_config, distance_pc=d)
        # snr_time_requirements already returns this config's noise floor.
        photon_rate, noise_norm, snr_tables, sigma_floor_cfg = snr_time_requirements(
            lam_nm,
            signal_depth,
            cfg,
            (snr_target,),
        )
        time_hours.append(snr_tables[snr_target][core_idx])
        snr_100h.append(
            snr_for_time(100.0, signal_depth[core_idx], photon_rate[core_idx], noise_norm[core_idx], sigma_floor_cfg)
        )

    plt.figure(figsize=(6, 4))
    plt.loglog(distances, time_hours, marker="o")
    plt.xlabel("Distance (pc)")
    plt.ylabel(f"Hours for SNR={snr_target} (O$_2$ core)")
    plt.title("Integration time required to detect O$_2$")
    plt.grid(True, which="both", alpha=0.3)
    plt.tight_layout()
    plt.savefig(DISTANCE_CURVE_PATH, dpi=200)

    # Header reflects the requested target (previously hard-coded as SNR=5).
    print(f"\n=== Observation time vs distance (SNR={snr_target} at O₂ core) ===")
    print("Distance (pc) | Hours | SNR(100h)")
    print("----------------------------------")
    for d, hours, snr100 in zip(distances, time_hours, snr_100h):
        print(f"{d:11.1f} | {hours:6.3f} | {snr100:9.1f}")

    limiting_hours = 100.0
    feasible = distances[np.array(time_hours) <= limiting_hours]
    if feasible.size:
        max_feasible = feasible.max()
        print(
            f"\nWith a {limiting_hours:.0f}h budget, {base_config.name} reaches SNR={snr_target}"
            f" out to ≈{max_feasible:.0f} pc (photon-limited)."
        )
    else:
        print(
            f"\nEven nearby targets need >{limiting_hours:.0f}h for SNR={snr_target};"
            " more collecting area or throughput is required."
        )

    print(f"  - {DISTANCE_CURVE_PATH.name} (hours vs distance)")
    return distances, time_hours, snr_100h


# =========================
# 4. CROSS-SECTIONS σ(λ,p,T)
# =========================

def compute_cross_sections(line, p, T, *, nGrids=3, x_limits=None):
    """Compute line-by-line cross-sections via lbl2xs for a layered atmosphere.

    Per-layer spectra returned by ``lbl2xs`` may sit on different wavenumber
    grids. They are mapped onto one shared grid, built as follows:

    * take the longest layer grid;
    * trim it to the wavenumber range covered by every layer;
    * linearly interpolate the other layers onto it.

    Parameters
    ----------
    line : lineArray | dict[lineArray]
        Line definition for lbl2xs.
    p, T : array-like
        Pressure and temperature arrays of length ``N_layers``.
    nGrids : int, keyword-only
        Passed through unchanged to ``lbl2xs`` (py4cats multigrid setting —
        see its documentation).
    x_limits : tuple or None, keyword-only
        Optional (ν_min, ν_max) bounds in cm^-1.

    Returns
    -------
    wn : ndarray
        Shared wavenumber grid (cm^-1).
    xs_layers : list[ndarray]
        Cross-sections per layer on ``wn`` (σ(ν) in cm^2/molecule).

    Raises
    ------
    TypeError
        If ``lbl2xs`` returns something other than a list or dict.
    ValueError
        In either of these cases:

        * ``lbl2xs`` returns no molecules or no layers;
        * the layer grids do not overlap.
    """
    xss = lbl2xs(line, p, T, nGrids=nGrids, xLimits=x_limits)

    # Depending on the line spec, lbl2xs returns either:
    #  - a list of xsArray objects (single molecule)
    #  - a dict {molecule: [xsArray, ...]}
    if isinstance(xss, list):
        xs_layers = xss
    elif isinstance(xss, dict):
        if not xss:
            raise ValueError("lbl2xs returned no molecules")
        first_key = next(iter(xss))
        print("Molecule in xss:", first_key)
        if len(xss) > 1:
            import warnings
            # Only one molecule is supported downstream; make the drop visible.
            warnings.warn(
                f"lbl2xs returned {len(xss)} molecules; only {first_key!r} is used",
                stacklevel=2,
            )
        xs_layers = xss[first_key]
    else:
        raise TypeError(f"Unsupported lbl2xs result type: {type(xss)}")

    if len(xs_layers) == 0:
        raise ValueError("lbl2xs returned no cross-section layers")

    grids = [xs.grid() for xs in xs_layers]
    grid_min = max(grid[0] for grid in grids)
    grid_max = min(grid[-1] for grid in grids)
    if grid_min >= grid_max:
        raise ValueError("Layer grids do not overlap in wavenumber range")

    # Reference = longest layer grid, trimmed to the range covered by every layer.
    ref_idx = int(np.argmax([len(grid) for grid in grids]))
    ref_grid = grids[ref_idx]
    mask = (ref_grid >= grid_min) & (ref_grid <= grid_max)
    wn = ref_grid[mask]

    xs_common = []
    for idx, (grid, xs) in enumerate(zip(grids, xs_layers)):
        values = np.asarray(xs)
        if idx == ref_idx:
            xs_common.append(values if mask.all() else values[mask])
        else:
            # NOTE(review): np.interp needs an increasing ``grid``; assumed to
            # hold for lbl2xs wavenumber grids — confirm.
            xs_common.append(np.interp(wn, grid, values))

    return wn, xs_common


# =========================
# 5. TRANSIT GEOMETRY
# =========================

def compute_slant_path_matrix(z, R_p):
    """Return impact parameters and the slant-path matrix for transit geometry.

    The atmosphere is a stack of spherical shells, with shell ``i`` centred on
    altitude ``z[i]``. Its boundaries lie half-way between neighbouring
    centres; the outermost half-widths are extrapolated from the nearest
    spacing. Adjacent shells therefore tile the atmosphere without gaps or
    overlaps, even on a non-uniform grid. On a uniform grid the boundaries are
    the usual ``z[i] ± dz/2``. Rays are launched with impact parameters
    ``b[j] = R_p + z[j]``.

    Parameters
    ----------
    z : array-like, shape (n_layers,)
        Increasing layer mid-point altitudes (same length unit as ``R_p``).
    R_p : float
        Planet radius; shell radii are ``R_p + altitude``.

    Returns
    -------
    b : ndarray, shape (n_layers,)
        Impact parameters.
    ds : ndarray, shape (n_layers, n_layers)
        ``ds[i, j]`` is the total chord length of ray ``j`` inside shell ``i``,
        counting both the entry and exit segments. It equals
        ``2 * (sqrt(r_out_i² − b_j²) − sqrt(max(r_in_i, b_j)² − b_j²))`` when
        ``b_j < r_out_i`` and 0 otherwise.

    Raises
    ------
    ValueError
        If ``z`` is not 1-D or has fewer than two layers (the shell thickness
        cannot be inferred).
    """
    z = np.asarray(z, dtype=float)
    if z.ndim != 1 or z.size < 2:
        raise ValueError("compute_slant_path_matrix needs a 1-D altitude grid with at least two layers")

    # Shell boundaries: midpoints between centres plus extrapolated outer edges.
    edges = np.empty(z.size + 1)
    edges[1:-1] = 0.5 * (z[:-1] + z[1:])
    edges[0] = z[0] - 0.5 * (z[1] - z[0])
    edges[-1] = z[-1] + 0.5 * (z[-1] - z[-2])
    r_in = (R_p + edges[:-1])[:, None]   # (n_layers, 1): shells along rows
    r_out = (R_p + edges[1:])[:, None]

    b = R_p + z
    b_row = b[None, :]                   # (1, n_layers): rays along columns
    b_sq = b_row**2
    s_out = np.sqrt(np.clip(r_out**2 - b_sq, 0.0, None))
    # A ray whose impact parameter lies inside the shell only clips its upper part.
    s_in = np.sqrt(np.clip(np.maximum(r_in, b_row)**2 - b_sq, 0.0, None))
    ds = np.where(b_row < r_out, 2.0 * (s_out - s_in), 0.0)
    return b, ds


# =========================
# 6. TRANSMISSION SPECTRUM
# =========================

def compute_transmission_spectrum(wn, xs_layers, atm, R_p, R_star):
    """Build the transit spectrum from per-layer cross-sections and transit geometry.

    This delegates to the module's geometry and radiative-transfer helpers:

    1. ``compute_slant_path_matrix`` builds the impact parameters ``b`` and
       the slant paths ``ds`` from ``atm["z"]``.
    2. ``compute_optical_depth`` computes τ(ν, b) = Σ_i σ_i(ν) n_O2,i ds_ij.
    3. ``transmission_from_tau`` evaluates
       R_eff² = R_p² + 2 ∫ (1 − e^{−τ}) b db and the depth (R_eff / R_star)².

    Parameters
    ----------
    wn : ndarray
        Wavenumber grid of ``xs_layers``. Not used in the computation; kept
        for interface compatibility.
    xs_layers : sequence of array-like
        One cross-section spectrum per atmospheric layer, all on the same grid.
    atm : dict
        Atmosphere with at least ``"z"`` (cm) and ``"n_O2"`` (cm⁻³).
    R_p, R_star : float
        Planet and stellar radii (cm).

    Returns
    -------
    R_eff, transit_depth : ndarray
        One value per spectral point.
    """
    b, ds = compute_slant_path_matrix(atm["z"], R_p)
    tau = compute_optical_depth(xs_layers, atm, ds)
    return transmission_from_tau(tau, b, R_p, R_star)


def transmission_from_tau(tau, b, R_p, R_star):
    T_nu_b = np.exp(-tau)
    sort_idx = np.argsort(b)
    b_sorted = b[sort_idx]
    T_sorted = T_nu_b[:, sort_idx]
    integrand = (1.0 - T_sorted) * b_sorted
    R_eff2 = R_p**2 + 2.0 * np.trapezoid(integrand, b_sorted, axis=1)
    R_eff = np.sqrt(np.clip(R_eff2, a_min=0.0, a_max=None))
    depth = (R_eff / R_star)**2
    return R_eff, depth


def feature_signal(depth):
    """Absorption signal above the continuum, clipped at zero.

    The continuum level is taken as the 5th percentile of ``depth``.
    """
    baseline = np.percentile(depth, 5)
    excess = depth - baseline
    return np.maximum(excess, 0.0)


def effective_noise_floor(config: InstrumentConfig) -> tuple[float, float]:
    """Systematic noise floor for ``config``, scaled with target distance.

        σ_ppm = noise_floor_ppm · max(d / d_ref, 1) ** noise_floor_exponent

    Here d = ``config.distance_pc`` and d_ref =
    ``config.noise_floor_reference_pc``, clamped to at least 1e-6 pc to avoid
    division by zero. Targets closer than d_ref keep the unscaled floor.

    Returns
    -------
    tuple[float, float]
        ``(sigma_fraction, sigma_ppm)``: the floor as a fraction and in ppm.
    """
    ref = max(config.noise_floor_reference_pc, 1e-6)
    scale = max(config.distance_pc / ref, 1.0)
    sigma_ppm = config.noise_floor_ppm * (scale ** config.noise_floor_exponent)
    return sigma_ppm * 1e-6, sigma_ppm


# =========================
# 7. HWO-CONVOLVED SPECTRUM
# =========================

def convolve_to_resolution(wavelength_nm, depth, R=150, *, n_out=200):
    """Gaussian convolution to mimic a given resolving power (default R~150).

    At each output wavelength λ₀ the kernel FWHM is λ₀ / R, i.e.
    σ = FWHM / (2√(2 ln 2)). The kernel is normalised over the input samples.
    The output grid has ``n_out`` evenly spaced points spanning the input
    wavelength range. Input order does not matter; samples are sorted by
    wavelength first.

    Parameters
    ----------
    wavelength_nm : array-like
        Input wavelengths (nm).
    depth : array-like
        Spectrum sampled at ``wavelength_nm`` (same length).
    R : float
        Resolving power λ/Δλ.
    n_out : int, keyword-only
        Number of output samples (default 200).

    Returns
    -------
    lam_out, y_out : ndarray
        Output wavelength grid and smoothed spectrum.

    Raises
    ------
    ValueError
        If the inputs are empty, not 1-D, or differ in length.
    """
    lam = np.asarray(wavelength_nm, dtype=float)
    y = np.asarray(depth, dtype=float)
    if lam.ndim != 1 or lam.shape != y.shape:
        raise ValueError("wavelength_nm and depth must be 1-D arrays of equal length")
    if lam.size == 0:
        raise ValueError("convolve_to_resolution needs at least one sample")

    idx = np.argsort(lam)
    lam = lam[idx]
    y = y[idx]

    lam_out = np.linspace(lam[0], lam[-1], n_out)
    y_out = np.zeros_like(lam_out)
    fwhm_per_sigma = 2.0 * np.sqrt(2.0 * np.log(2.0))

    for i, l0 in enumerate(lam_out):
        sigma = (l0 / R) / fwhm_per_sigma
        w = np.exp(-0.5 * ((lam - l0) / sigma)**2)
        total = np.sum(w)
        if total > 0.0:
            w /= total
            y_out[i] = np.sum(y * w)
        else:
            # The kernel underflowed: no sample lies within many σ of l0 (very
            # high R or sparse input). Interpolate instead of returning NaN.
            y_out[i] = np.interp(l0, lam, y)

    return lam_out, y_out


def compute_optical_depth(xs_layers, atm, ds):
    """Slant optical depth τ(ν, b_j) = Σ_i σ_i(ν) · n_O2,i · ds[i, j].

    ``xs_layers`` holds one cross-section spectrum per layer (rows after
    stacking). ``atm["n_O2"]`` gives the per-layer O₂ density, and ``ds`` is
    the (layer × impact-parameter) slant-path matrix.

    Returns
    -------
    ndarray, shape (n_spectral, n_impact)
    """
    cross_sections = np.vstack([np.asarray(layer) for layer in xs_layers])
    extinction = cross_sections * atm["n_O2"][:, None]
    return np.einsum("ip,ij->pj", extinction, ds)


def visualize_atmosphere_slant_paths(atm, R_p, b, ds, tau_line, output_path):
    """Render a 3-D sketch of the atmosphere and a fan of transit chords.

    The figure shows:

    * translucent spherical shells at 0, 30, 60, 90 and 120 km altitude;
    * 12 chords at evenly spaced impact-parameter indices, each placed at a
      different azimuth;
    * chord colours from ``tau_line`` (one optical depth per impact
      parameter), normalised to [0, 1] over ``tau_line``.

    The figure is saved to ``output_path`` and closed. ``ds`` is accepted but
    not used.
    """
    # Kept for older Matplotlib versions that needed this import to register
    # the '3d' projection.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    km_per_cm = 1e-5
    surface_km = R_p * km_per_cm
    outer_km = (R_p * km_per_cm + atm["z_km"])[-1]

    fig = plt.figure(figsize=(7, 7))
    ax = fig.add_subplot(111, projection='3d')
    polar, azimuth = np.meshgrid(np.linspace(0, np.pi, 50), np.linspace(0, 2 * np.pi, 50))

    for shell_alt in (0, 30, 60, 90, 120):
        r = surface_km + shell_alt
        ax.plot_surface(
            r * np.sin(polar) * np.cos(azimuth),
            r * np.sin(polar) * np.sin(azimuth),
            r * np.cos(polar),
            color='royalblue', alpha=0.03 + 0.002 * shell_alt, linewidth=0, shade=False,
        )

    n_rays = 12
    ray_azimuths = np.linspace(0, 2 * np.pi, n_rays, endpoint=False)
    ray_idx = np.linspace(0, len(b) - 1, n_rays, dtype=int)
    tau_span = max(np.ptp(tau_line), 1e-30)
    ray_weights = (tau_line[ray_idx] - tau_line.min()) / tau_span

    for k, az, wgt in zip(ray_idx, ray_azimuths, ray_weights):
        r = surface_km + ((b[k] * km_per_cm) - surface_km)
        half_len = np.sqrt(max(outer_km**2 - r**2, 0.0))
        xs = np.linspace(-half_len, half_len, 200)
        ones = np.ones_like(xs)
        ax.plot(
            xs, r * np.cos(az) * ones, r * np.sin(az) * ones,
            color=plt.cm.inferno(0.3 + 0.7 * wgt), linewidth=2,
        )

    for set_label, text in ((ax.set_xlabel, 'x (km)'), (ax.set_ylabel, 'y (km)'), (ax.set_zlabel, 'z (km)')):
        set_label(text)
    ax.set_title('Slant paths across the O$_2$ atmosphere')
    ax.set_box_aspect([1, 1, 1])
    for set_lim in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_lim(-outer_km, outer_km)
    plt.tight_layout()
    plt.savefig(output_path, dpi=200)
    plt.close(fig)


def animate_a_band(lam_nm, depth, R_eff, R_p, output_path, fps=20):
    """Save a GIF that sweeps across the spectrum, showing radius and depth.

    The left panel shows the planet as a unit circle plus a translucent halo.
    The halo radius is R_eff / R_p at the current wavelength. The right panel
    shows transit depth (ppm) versus wavelength, with a dashed marker at the
    current wavelength and a text readout.

    Parameters
    ----------
    lam_nm : ndarray
        Wavelengths (nm), indexed in parallel with ``depth`` and ``R_eff``.
    depth : ndarray
        Fractional transit depth (displayed in ppm).
    R_eff : ndarray
        Effective radius per wavelength, in the same units as ``R_p``. The
        readout divides by 1e5 to print km, which assumes cm — confirm.
    R_p : float
        Planet radius.
    output_path : path-like
        Destination of the GIF (written with the Pillow writer).
    fps : int
        Frame rate of the GIF.

    Notes
    -----
    120 frames are drawn at evenly spaced spectral indices.
    """
    R_eff_rel = R_eff / R_p
    lam = lam_nm
    depth_ppm = depth * 1e6

    fig, (ax_planet, ax_spec) = plt.subplots(1, 2, figsize=(10, 4))
    # NOTE(review): layout is tightened before any artists/titles are added.
    plt.tight_layout()

    # Solid planet disc (radius 1 in planet radii) and the wavelength-dependent halo.
    circle = plt.Circle((0, 0), 1.0, color='#1f77b4', alpha=0.8)
    halo = plt.Circle((0, 0), R_eff_rel[0], color='#1f77b4', alpha=0.3)
    ax_planet.add_patch(circle)
    ax_planet.add_patch(halo)
    ax_planet.set_aspect('equal', 'box')
    ax_planet.set_xlim(-1.2, 1.2)
    ax_planet.set_ylim(-1.2, 1.2)
    ax_planet.set_xticks([])
    ax_planet.set_yticks([])
    ax_planet.set_title('Planet radius vs λ')

    ax_spec.plot(lam, depth_ppm, color='#d62728')
    ax_spec.set_xlabel('Wavelength (nm)')
    ax_spec.set_ylabel('Transit depth (ppm)')
    ax_spec.set_title('O$_2$ A-band profile')
    # Moving wavelength marker and readout, updated every frame.
    line = ax_spec.axvline(lam[0], color='k', linestyle='--')
    text = ax_spec.text(0.02, 0.9, '', transform=ax_spec.transAxes)

    def update(frame):
        # ``frame`` is a spectral index from ``frames`` below.
        halo.set_radius(R_eff_rel[frame])
        line.set_xdata([lam[frame], lam[frame]])
        text.set_text(
            f"λ={lam[frame]:.1f} nm\nR_eff={R_eff_rel[frame]*R_p/1e5:.0f} km\nDepth={depth_ppm[frame]:.1f} ppm"
        )
        return halo, line, text

    frames = np.linspace(0, len(lam) - 1, 120, dtype=int)
    anim = animation.FuncAnimation(fig, update, frames=frames, blit=False)
    anim.save(output_path, writer='pillow', fps=fps)
    plt.close(fig)


def animate_transit_geometry(atm, R_p, R_star, b, tau_grid, lam_nm, R_eff, output_path, fps=15):
    """Save a 3-D GIF of the planet crossing the stellar disc with slant-path rays.

    All lengths are in planet radii. The scene contains:

    * the star, drawn as a flat disc in the x = 0 plane, with its radius
      capped at 8 planet radii for readability;
    * the planet (unit sphere) and an R_eff halo, moving along x from
      +(star + 2) to −(star + 2) while the wavelength index advances;
    * 12 rays at evenly spaced impact-parameter indices, recoloured each
      frame by τ normalised over that frame's ``tau_grid`` row.

    Parameters
    ----------
    atm : dict
        Atmosphere; only ``atm["z"]`` (same units as ``R_p``) is used, and its
        last entry sets the drawn atmosphere radius.
    R_p, R_star : float
        Planet and stellar radii (same units).
    b : ndarray
        Impact parameters (same units as ``R_p``).
    tau_grid : ndarray
        Optical depth indexed ``tau_grid[spectral_index][impact_index]`` —
        presumably shape (n_spectral, n_impact); confirm.
    lam_nm, R_eff : ndarray
        Wavelengths (nm) and effective radii per spectral index.
    output_path : path-like
        Destination GIF (Pillow writer).
    fps : int
        Frame rate.

    Notes
    -----
    140 frames are drawn.
    """
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    # NOTE(review): ``scale`` is never used below.
    scale = R_p
    planet_radius = 1.0
    # Uses the top layer's mid-point altitude as the drawn atmosphere edge.
    atmos_radius = 1.0 + atm["z"][-1] / R_p
    R_eff_rel = R_eff / R_p

    # scale star radius for readability
    star_rel = min(R_star / R_p, 8.0)
    # Spherical-angle mesh reused for every planet/halo sphere.
    theta = np.linspace(0, np.pi, 40)
    phi = np.linspace(0, 2 * np.pi, 40)
    theta, phi = np.meshgrid(theta, phi)

    # Star as a filled disc in the x = 0 plane (polar mesh: radius fraction u, angle v).
    star_u = np.linspace(0, 1, 40)
    star_v = np.linspace(0, 2 * np.pi, 60)
    star_u, star_v = np.meshgrid(star_u, star_v)
    star_x = np.zeros_like(star_u)
    star_y = star_rel * star_u * np.cos(star_v)
    star_z = star_rel * star_u * np.sin(star_v)

    num_paths = 12
    selected_indices = np.linspace(0, len(b) - 1, num_paths, dtype=int)
    azimuths = np.linspace(0, 2 * np.pi, num_paths, endpoint=False)
    ray_coords = []

    # Precompute each ray's chord through the atmosphere, centred on the planet
    # (x offset added per frame). Offset from the planet axis = impact parameter.
    for idx, phi0 in zip(selected_indices, azimuths):
        impact = (b[idx] - R_p) / R_p + planet_radius
        radius = impact
        max_extent = np.sqrt(max(atmos_radius**2 - radius**2, 0.0))
        x_vals = np.linspace(-max_extent, max_extent, 200)
        y_vals = radius * np.cos(phi0) * np.ones_like(x_vals)
        z_vals = radius * np.sin(phi0) * np.ones_like(x_vals)
        ray_coords.append((x_vals, y_vals, z_vals, idx))

    # Spectral index and planet x-position for each animation frame.
    frames = np.linspace(0, len(lam_nm) - 1, 140, dtype=int)
    x_positions = np.linspace(star_rel + 2, -star_rel - 2, len(frames))

    fig = plt.figure(figsize=(7, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.plot_surface(star_x, star_y, star_z, color='#f7c74a', alpha=0.4, linewidth=0)
    ax.plot_surface(star_x, -star_y, -star_z, color='#f7c74a', alpha=0.4, linewidth=0)
    ax.set_xlim(-star_rel - 3, star_rel + 3)
    ax.set_ylim(-star_rel - 3, star_rel + 3)
    ax.set_zlim(-star_rel - 3, star_rel + 3)
    ax.set_xlabel('x (scaled)')
    ax.set_ylabel('y (scaled)')
    ax.set_zlabel('z (scaled)')
    ax.set_title('Transit geometry with slant paths')
    ax.set_box_aspect([1, 1, 1])

    # Single-element lists act as mutable cells so ``update`` can replace the
    # current surfaces (3-D surfaces cannot be moved in place).
    planet_surface = [None]
    halo_surface = [None]
    lines = []
    for _ in selected_indices:
        line, = ax.plot([], [], [], linewidth=2)
        lines.append(line)

    text = ax.text2D(0.02, 0.92, '', transform=ax.transAxes)

    def draw_sphere(center_x, radius, color, alpha):
        # Sphere centred at (center_x, 0, 0) using the shared theta/phi mesh.
        x = center_x + radius * np.sin(theta) * np.cos(phi)
        y = radius * np.sin(theta) * np.sin(phi)
        z = radius * np.cos(theta)
        return ax.plot_surface(x, y, z, color=color, alpha=alpha, linewidth=0)

    def update(i):
        # ``i`` is the frame number; ``idx`` the spectral index for that frame.
        idx = frames[i]
        x0 = x_positions[i]
        if planet_surface[0] is not None:
            planet_surface[0].remove()
        if halo_surface[0] is not None:
            halo_surface[0].remove()

        planet_surface[0] = draw_sphere(x0, planet_radius, '#1f77b4', 0.9)
        halo_surface[0] = draw_sphere(x0, R_eff_rel[idx], '#1f77b4', 0.2)

        # Ray colours: τ at this wavelength, normalised over all impact parameters.
        tau_line = tau_grid[idx]
        tau_norm = (tau_line[selected_indices] - tau_line.min()) / max(np.ptp(tau_line), 1e-30)
        # NOTE(review): ``idx_ray`` is unpacked but not used.
        for line, (x_vals, y_vals, z_vals, idx_ray), weight in zip(lines, ray_coords, tau_norm):
            line.set_data(x_vals + x0, y_vals)
            line.set_3d_properties(z_vals)
            color = plt.cm.inferno(0.3 + 0.7 * weight)
            line.set_color(color)
        text.set_text(f"λ={lam_nm[idx]:.1f} nm\nR_eff={R_eff_rel[idx]*R_p/1e5:.0f} km")
        return lines + [planet_surface[0], halo_surface[0], text]

    anim = animation.FuncAnimation(fig, update, frames=len(frames), blit=False)
    anim.save(output_path, writer='pillow', fps=fps)
    plt.close(fig)


def evaluate_vmr_threshold(factors, tau_grid, b, R_p, R_star, wavelength_nm, lam_smooth, core_idx, instrument, sigma_floor):
    """Detectability of O₂ as its abundance is scaled; saves an SNR-vs-VMR plot.

    Optical depth is linear in the O₂ number density
    (see ``compute_optical_depth``). Scaling ``tau_grid`` by a factor f
    therefore models an O₂ abundance of f × 21 %. For each factor this
    function:

    * computes the transit depth;
    * convolves it to ``instrument.spectral_resolution``;
    * resamples it onto ``lam_smooth`` and extracts the continuum-subtracted signal;
    * evaluates the hours to reach SNR = 5 and the SNR after 100 h, both at ``core_idx``.

    The plot is saved to ``VMR_CURVE_PATH``. Points are drawn in increasing
    abundance order, so the line is well-formed whatever the order of ``factors``.

    Returns
    -------
    vmr_rows : list of tuple
        ``(vmr_percent rounded to 2 dp, snr_after_100h, hours_for_snr5)`` per
        factor, in input order.
    threshold : float or None
        VMR (%) of the first factor, in the given order, whose 100 h SNR falls
        below 5. With factors listed in decreasing order, this is the highest
        tested abundance that is no longer detected. ``None`` if every factor
        reaches SNR 5.
    """
    reference_vmr_percent = 21.0  # matches the 0.21 O₂ VMR used to build tau_grid (presumably)
    detection_snr = 5
    reference_hours = 100.0

    results = []
    for f in factors:
        _, depth_f = transmission_from_tau(tau_grid * f, b, R_p, R_star)
        lam_conv, depth_conv = convolve_to_resolution(wavelength_nm, depth_f, R=instrument.spectral_resolution)
        depth_interp = np.interp(lam_smooth, lam_conv, depth_conv)
        signal_interp = feature_signal(depth_interp)
        photon_rate_f, noise_norm_f, snr_tables_f, _ = snr_time_requirements(
            lam_smooth, signal_interp, instrument, (detection_snr,)
        )
        hours = snr_tables_f[detection_snr][core_idx]
        snr100 = snr_for_time(
            reference_hours, signal_interp[core_idx], photon_rate_f[core_idx], noise_norm_f[core_idx], sigma_floor
        )
        results.append((f, hours, snr100))

    fractions = np.array([f for f, _, _ in results])
    snr_values = np.array([snr for _, _, snr in results])

    order = np.argsort(fractions)
    plt.figure(figsize=(6, 4))
    plt.plot(fractions[order] * reference_vmr_percent, snr_values[order], marker='o', color='#8a2be2')
    plt.axhline(detection_snr, color='gray', linestyle='--', label=f'SNR={detection_snr} target')
    plt.xlabel('O₂ volume mixing ratio (%)')
    plt.ylabel('SNR after 100 h')
    plt.title('Detectability of O₂ vs abundance')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(VMR_CURVE_PATH, dpi=200)

    threshold = None
    below = snr_values < detection_snr
    if np.any(below):
        threshold = fractions[np.flatnonzero(below)[0]] * reference_vmr_percent

    vmr_rows = [(round(f * reference_vmr_percent, 2), snr, hours) for f, hours, snr in results]
    return vmr_rows, threshold


def instrument_comparison(lam_nm, signal_depth, core_idx, configs):
    """Compare instrument configurations at the O₂ core and save a bar chart.

    For every config the function computes the hours needed for SNR = 5 and
    the SNR after 100 h, both at ``core_idx``. Hours are plotted per
    instrument name to ``INSTRUMENT_COMPARE_PATH``.

    Returns
    -------
    list of tuple
        ``(name, hours_for_snr5, snr_after_100h)`` per config, in input order.
    """
    def _evaluate(cfg):
        rate, variance, tables, _ = snr_time_requirements(lam_nm, signal_depth, cfg, (5,))
        floor, _ = effective_noise_floor(cfg)
        snr_100 = snr_for_time(100.0, signal_depth[core_idx], rate[core_idx], variance[core_idx], floor)
        return cfg.name, tables[5][core_idx], snr_100

    rows = [_evaluate(cfg) for cfg in configs]

    plt.figure(figsize=(6, 4))
    plt.bar(
        [name for name, _, _ in rows],
        [hrs for _, hrs, _ in rows],
        color=['#1f77b4', '#ff7f0e', '#2ca02c'],
    )
    plt.ylabel('Hours for SNR=5 (O₂ core)')
    plt.title('Instrument comparison at 760 nm')
    plt.grid(True, axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(INSTRUMENT_COMPARE_PATH, dpi=200)
    return rows


def build_dashboard(lam_nm_high, depth_high, lam_nm_low, depth_low, summary_rows, distances, hours_needed, snr100_list,
                    instrument, r_eff_rel_high, r_eff_rel_low, vmr_rows, vmr_threshold, instrument_rows,
                    hero_metrics):
    """Render the self-contained HTML science dashboard to ``DASHBOARD_HTML``.

    The page embeds the saved figures/animations by file name (so it must sit
    next to them), summary tables, and an interactive canvas spectrum that can
    switch between the high-resolution and the smoothed (HWO-like) spectrum.

    Parameters
    ----------
    lam_nm_high, lam_nm_low : numpy.ndarray
        Wavelength grids (nm) of the high-resolution and smoothed spectra.
    depth_high, depth_low : numpy.ndarray
        Fractional transit depths on those grids; multiplied by 1e6 so the
        page shows ppm.
    summary_rows : iterable of tuple
        ``(label, lambda_nm, depth_ppm, photon_rate, hrs_snr3, hrs_snr5,
        hrs_snr10)`` rows for the key-wavelength table.
    distances, hours_needed, snr100_list : iterable
        Parallel sequences for the distance table (pc, hours for SNR=5,
        SNR after 100 h).
    instrument : InstrumentConfig
        Baseline configuration described in the page header.
    r_eff_rel_high, r_eff_rel_low : array-like
        Effective radius relative to the planet radius on each grid. The page
        multiplies these by 6371 to show km, i.e. it assumes an Earth-sized
        planet — NOTE(review): confirm R_p is 1 Earth radius.
    vmr_rows : iterable of tuple
        ``(vmr_percent, snr_100h, hours_snr5)`` rows.
    vmr_threshold : float or None
        O₂ VMR (%) below which SNR=5 is no longer reached in 100 h, or None.
    instrument_rows : iterable of tuple
        ``(name, hours_snr5, snr_100h)`` rows.
    hero_metrics : dict
        Must provide ``depth_ppm``, ``photon_rate``, ``hours_snr5``,
        ``snr100`` and ``noise_floor_ppm``.
    """
    def fmt_val(val, fmt="{:.3f}"):
        return "∞" if val is None or not np.isfinite(val) else fmt.format(val)

    def fmt_snr(val):
        # Mirrors fmt_val's handling of missing / non-finite values so a None
        # SNR renders as "∞" instead of raising TypeError in np.isfinite.
        if val is None or not np.isfinite(val):
            return "∞"
        return "<0.1" if val < 0.1 else f"{val:.1f}"

    summary_html = []
    for label, lam, depth_val, photon, snr3, snr5, snr10 in summary_rows:
        summary_html.append(
            f"<tr><td>{label}</td><td>{lam:.1f}</td><td>{depth_val:.1f}</td><td>{photon:.3e}</td>"
            f"<td>{fmt_val(snr3, '{:.2f}')}</td><td>{fmt_val(snr5, '{:.2f}')}</td><td>{fmt_val(snr10, '{:.2f}')}</td></tr>"
        )
    summary_html_rows = ''.join(summary_html)

    distance_rows_html = ''.join(
        f"<tr><td>{d:.1f}</td><td>{fmt_val(t)}</td><td>{fmt_snr(snr)}</td></tr>"
        for d, t, snr in zip(distances, hours_needed, snr100_list)
    )
    vmr_rows_html = ''.join(
        f"<tr><td>{vmr:.3f}</td><td>{fmt_snr(snr)}</td><td>{fmt_val(hours, '{:.2f}')}</td></tr>"
        for vmr, snr, hours in vmr_rows
    )
    instrument_rows_html = ''.join(
        f"<tr><td>{name}</td><td>{fmt_val(hours)}</td><td>{fmt_snr(snr)}</td></tr>"
        for name, hours, snr in instrument_rows
    )

    # Spectra are serialised to JSON literals and inlined into the <script>.
    lam_high_json = json.dumps(lam_nm_high.tolist())
    depth_high_json = json.dumps((depth_high * 1e6).tolist())
    lam_low_json = json.dumps(lam_nm_low.tolist())
    depth_low_json = json.dumps((depth_low * 1e6).tolist())
    r_eff_high_json = json.dumps((np.array(r_eff_rel_high)).tolist())
    r_eff_low_json = json.dumps((np.array(r_eff_rel_low)).tolist())
    slider_high = max(len(lam_nm_high) - 1, 0)
    slider_low = max(len(lam_nm_low) - 1, 0)
    instrument_desc = (
        f"HWO profile: D={instrument.telescope_diameter_m:.1f} m, efficiency {instrument.optical_throughput*100:.0f}%, "
        f"R≈{instrument.spectral_resolution:.0f}, Sun-like host at {instrument.distance_pc:.0f} pc."
    )
    vmr_text = (
        "Detectability drops once O₂ falls below ≈{:.2f}% (100 h, SNR=5)".format(vmr_threshold)
        if vmr_threshold is not None else "Detectability maintained across explored O₂ abundances"
    )

    hero_cards_html = f"""
        <div class=\"hero-card\">
            <p class=\"label\">Line-core depth</p>
            <p class=\"value\">{hero_metrics['depth_ppm']:.1f} ppm</p>
            <span class=\"chip\">{hero_metrics['photon_rate']:.2e} photons/s</span>
        </div>
        <div class=\"hero-card\">
            <p class=\"label\">Hours for SNR=5</p>
            <p class=\"value\">{fmt_val(hero_metrics['hours_snr5'], '{:.2f}')}</p>
            <span class=\"chip\">noise floor {hero_metrics['noise_floor_ppm']:.1f} ppm</span>
        </div>
        <div class=\"hero-card\">
            <p class=\"label\">SNR after 100 h</p>
            <p class=\"value\">{fmt_snr(hero_metrics['snr100'])}</p>
            <span class=\"chip\">cloud-top {CLOUD_TOP_KM:.0f} km</span>
        </div>
    """

    # string.Template ($NAME placeholders) is used instead of str.format so the
    # literal CSS/JS braces below need no escaping.
    template = Template("""
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8" />
    <title>O₂ A-band Science Dashboard</title>
    <style>
        :root {
            color-scheme: dark;
            --bg:#04070d;
            --panel:#0f1624;
            --card:#141d2b;
            --accent:#66fcf1;
            --text:#d6e4ff;
            --muted:#8ba1c9;
        }
        * { box-sizing:border-box; }
        body { font-family:'Inter', 'Helvetica Neue', Arial, sans-serif; background:var(--bg); color:var(--text); margin:0; }
        header { padding:24px 48px; background:linear-gradient(120deg,#0b1b2b,#122b40); position:sticky; top:0; z-index:10; box-shadow:0 8px 20px rgba(0,0,0,0.5); }
        h1 { margin:0; font-size:32px; letter-spacing:0.02em; }
        header p.tagline { margin:6px 0 0; color:var(--muted); font-size:15px; }
        nav { margin-top:12px; display:flex; gap:16px; flex-wrap:wrap; }
        nav a { color:var(--accent); text-decoration:none; font-size:14px; border:1px solid rgba(102,252,241,0.3); padding:6px 12px; border-radius:999px; }
        main { padding: 24px 48px 80px; }
        .hero { display:grid; grid-template-columns:repeat(auto-fit,minmax(220px,1fr)); gap:18px; margin:30px 0; }
        .hero-card { background:linear-gradient(150deg,#1e2a44,#111a2b); border-radius:14px; padding:18px; box-shadow:0 18px 35px rgba(0,0,0,0.45); }
        .hero-card .label { font-size:13px; color:var(--muted); margin:0 0 6px; text-transform:uppercase; letter-spacing:0.08em; }
        .hero-card .value { font-size:26px; margin:0 0 8px; color:#fff; }
        .chip { display:inline-block; font-size:12px; background:rgba(102,252,241,0.15); color:var(--accent); padding:4px 10px; border-radius:999px; }
        details { background:var(--panel); margin-bottom:28px; border-radius:16px; padding:0 24px 24px; border:1px solid rgba(255,255,255,0.05); box-shadow:0 25px 60px rgba(0,0,0,0.35); }
        details summary { cursor:pointer; font-size:20px; color:var(--accent); font-weight:600; padding:22px 0; }
        .grid { display:flex; flex-wrap:wrap; gap:20px; }
        .card { background:var(--card); padding:20px; border-radius:14px; flex:1 1 280px; border:1px solid rgba(255,255,255,0.04); }
        table { width:100%; border-collapse: collapse; margin-top:15px; font-size:14px; }
        th, td { padding:9px 6px; border-bottom:1px solid rgba(255,255,255,0.06); text-align:center; }
        th { background:rgba(255,255,255,0.04); }
        img { width:100%; border-radius:12px; aspect-ratio:16/9; object-fit:contain; background:#050608; border:1px solid rgba(255,255,255,0.08); }
        .gallery img { margin-bottom:15px; box-shadow:0 12px 30px rgba(0,0,0,0.4); }
        .hidden { display:none; }
        button.play { background:var(--accent); border:none; color:#03121f; padding:10px 20px; border-radius:10px; cursor:pointer; font-weight:600; margin-top:10px; transition:transform 0.15s ease; }
        button.play:hover { transform:translateY(-2px); }
        #spectrumCanvas { background:#050a13; border-radius:14px; width:100%; height:380px; display:block; border:1px solid rgba(255,255,255,0.08); }
        #spectrumControls { margin-top:15px; display:flex; align-items:center; gap:18px; flex-wrap:wrap; }
        #spectrumSlider { width:240px; }
        #spectrumInfo { font-size:15px; color:var(--muted); }
        #tooltip { position:absolute; pointer-events:none; background:rgba(8,13,24,0.85); padding:6px 10px; border-radius:6px; font-size:13px; border:1px solid rgba(255,255,255,0.1); }
        .toggle label { margin-right:16px; font-weight:500; }
        .toggle input { margin-right:6px; }
        footer { text-align:center; padding:22px; background:#0b1524; color:#7c8faf; }
    </style>
</head>
<body>
    <header>
        <h1>Detecting Oxygen in Exoplanet Atmospheres</h1>
        <p class="tagline">Project lead & developer: Antonika Shapovalova — Habitable Worlds Observatory scenario</p>
        <p>$INSTRUMENT_DESC</p>
        <p class="tagline">Theme: “Detecting Oxygen in Exoplanet Atmospheres Using the Habitable Worlds Observatory”.</p>
        <nav>
            <a href="#sec-spectrum">Spectrum</a>
            <a href="#sec-snr">SNR</a>
            <a href="#sec-geometry">Geometry</a>
            <a href="#sec-interactive">Interactive</a>
            <a href="#sec-vmr">Retrieval</a>
            <a href="#sec-instruments">Instruments</a>
        </nav>
    </header>
    <main>
        <section class="hero" id="sec-spectrum">
            <div class="hero-card" style="grid-column: span 2; min-width:260px;">
                <p class="label">Mission statement</p>
                <p class="value" style="font-size:19px; line-height:1.4;">Habitable Worlds Observatory forward model + retrieval benchmark for the O₂ A-band. Dashboard tracks photon budgets, systematics, and detectability thresholds for Earth analogs at 5–30 pc.</p>
                <span class="chip">Theme: Detecting oxygen with HWO</span>
            </div>
            $HERO_CARDS
        </section>
        <details open>
            <summary>1. Spectrum &amp; key diagnostics</summary>
            <div class="grid">
                <div class="card">
                    <h3>Static spectrum</h3>
                    <img src="$FIGURE" alt="spectrum" />
                </div>
                <div class="card">
                    <h3>Key wavelengths</h3>
                    <table>
                        <thead>
                            <tr><th>Feature</th><th>λ (nm)</th><th>Depth (ppm)</th><th>Photon rate</th><th>hrs@SNR=3</th><th>hrs@SNR=5</th><th>hrs@SNR=10</th></tr>
                        </thead>
                        <tbody>
                            $SUMMARY_ROWS
                        </tbody>
                    </table>
                </div>
            </div>
        </details>

        <details open id="sec-snr">
            <summary>2. SNR diagnostics</summary>
            <div class="grid">
                <div class="card"><img src="$SNR_WAVE" alt="SNR vs wavelength" /></div>
                <div class="card"><img src="$TIME_MAP" alt="Hours for SNR=5" /></div>
                <div class="card"><img src="$SNR_SCALING" alt="SNR scaling" /></div>
                <div class="card"><img src="$DISTANCE_CURVE" alt="Distance curve" /></div>
            </div>
            <table>
                <thead><tr><th>Distance (pc)</th><th>Hours for SNR=5</th><th>SNR (100 h)</th></tr></thead>
                <tbody>$DISTANCE_ROWS</tbody>
            </table>
        </details>

        <details open id="sec-geometry">
            <summary>3. Atmosphere &amp; transit geometry</summary>
            <div class="gallery">
                <img src="$SLANT_PATH" alt="slant paths 3D" />
                <button class="play" id="btnTransit">Launch 3D transit visualization</button>
                <div id="transitPlayer" class="hidden">
                    <img src="$TRANSIT_ANIM" alt="transit animation" />
                </div>
                <button class="play" id="btnABand">Play A-band expansion inline</button>
                <div id="abandPlayer" class="hidden">
                    <img src="$ABAND_ANIM" alt="A-band animation" />
                </div>
            </div>
        </details>

        <details open id="sec-interactive">
            <summary>4. Interactive spectrum</summary>
            <div class="toggle">
                <label><input type="radio" name="resMode" value="high" checked> High-res</label>
                <label><input type="radio" name="resMode" value="hwo"> HWO-res (R≈150)</label>
            </div>
            <canvas id="spectrumCanvas"></canvas>
            <div id="spectrumControls">
                <label for="spectrumSlider">Wavelength selector</label>
                <input type="range" id="spectrumSlider" min="0" max="$SLIDER_HIGH" value="0" />
                <div id="spectrumInfo"></div>
            </div>
        </details>

        <details open id="sec-vmr">
            <summary>5. Retrieval-style O₂ threshold</summary>
            <p>$VMR_TEXT</p>
            <img src="$VMR_CURVE" alt="vmr curve" />
            <table>
                <thead><tr><th>O₂ VMR (%)</th><th>SNR (100 h)</th><th>Hours for SNR=5</th></tr></thead>
                <tbody>$VMR_ROWS</tbody>
            </table>
        </details>

        <details open id="sec-instruments">
            <summary>6. Instrument comparison</summary>
            <img src="$INSTRUMENT_CURVE" alt="instrument comparison" />
            <table>
                <thead><tr><th>Instrument</th><th>Hours for SNR=5</th><th>SNR (100 h)</th></tr></thead>
                <tbody>$INSTRUMENT_ROWS</tbody>
            </table>
        </details>
    </main>
    <footer>
        <p>Built from line-by-line physics (py4cats) and HWO science requirements.</p>
        <p>Designed & engineered by Antonika Shapovalova.</p>
    </footer>

    <div id="tooltip" class="hidden"></div>
    <script>
        const datasets = {
            high: { lam: $LAM_HIGH, depth: $DEPTH_HIGH, rEff: $REFF_HIGH, sliderMax: $SLIDER_HIGH },
            hwo: { lam: $LAM_LOW, depth: $DEPTH_LOW, rEff: $REFF_LOW, sliderMax: $SLIDER_LOW }
        };
        let currentMode = 'high';
        const canvas = document.getElementById('spectrumCanvas');
        const ctx = canvas.getContext('2d');
        const slider = document.getElementById('spectrumSlider');
        const info = document.getElementById('spectrumInfo');
        const tooltip = document.getElementById('tooltip');
        let highlightIdx = 0;

        // Loop-based min/max. Math.min(...arr) passes every element as a
        // function argument and throws a RangeError for very long spectra.
        function extent(arr) {
            let lo = Infinity, hi = -Infinity;
            for (let i = 0; i < arr.length; i++) {
                const v = arr[i];
                if (v < lo) lo = v;
                if (v > hi) hi = v;
            }
            return [lo, hi];
        }

        document.querySelectorAll('input[name="resMode"]').forEach(r => {
            r.addEventListener('change', () => {
                currentMode = r.value;
                highlightIdx = 0;
                slider.max = datasets[currentMode].sliderMax;
                slider.value = 0;
                drawSpectrum();
            });
        });

        function resize() {
            canvas.width = canvas.clientWidth || canvas.parentElement.clientWidth || 800;
            canvas.height = canvas.clientHeight || 380;
            drawSpectrum();
        }
        window.addEventListener('resize', resize);
        slider.addEventListener('input', () => {
            highlightIdx = parseInt(slider.value, 10);
            drawSpectrum();
        });

        function drawSpectrum() {
            const lam = datasets[currentMode].lam;
            const depth = datasets[currentMode].depth;
            const rEff = datasets[currentMode].rEff;
            ctx.clearRect(0,0,canvas.width, canvas.height);
            if (!lam.length) { info.textContent = ''; return; }
            const [minLam, maxLam] = extent(lam);
            // Guard against a zero span (single point / flat spectrum).
            const lamSpan = (maxLam - minLam) || 1;
            const maxDepth = (extent(depth)[1]*1.05) || 1;
            ctx.beginPath(); ctx.strokeStyle = '#66fcf1'; ctx.lineWidth = 2;
            lam.forEach((val, i) => {
                const x = ((val - minLam) / lamSpan) * canvas.width;
                const y = canvas.height - (depth[i] / maxDepth) * canvas.height;
                if(i===0) ctx.moveTo(x,y); else ctx.lineTo(x,y);
            });
            ctx.stroke();
            const highlightX = ((lam[highlightIdx]-minLam)/lamSpan)*canvas.width;
            ctx.strokeStyle = '#ff7f0e';
            ctx.beginPath();
            ctx.moveTo(highlightX, 0);
            ctx.lineTo(highlightX, canvas.height);
            ctx.stroke();
            const infoText = 'λ=' + lam[highlightIdx].toFixed(1) + ' nm | depth=' + depth[highlightIdx].toFixed(1) + ' ppm | R_eff=' + (rEff[highlightIdx]*6371).toFixed(0) + ' km';
            info.textContent = infoText;
        }
        resize();

        canvas.addEventListener('mousemove', (event) => {
            const rect = canvas.getBoundingClientRect();
            const x = event.clientX - rect.left;
            const lam = datasets[currentMode].lam;
            if (!lam.length) return;
            const [minLam, maxLam] = extent(lam);
            const lambdaHover = minLam + (x / canvas.width) * (maxLam - minLam);
            const idx = lam.reduce((best, val, i, arr) => Math.abs(val - lambdaHover) < Math.abs(arr[best] - lambdaHover) ? i : best, 0);
            highlightIdx = idx;
            slider.value = idx;
            drawSpectrum();
            tooltip.classList.remove('hidden');
            tooltip.style.left = (event.clientX + 15) + 'px';
            tooltip.style.top = (event.clientY + 15) + 'px';
            tooltip.textContent = 'λ=' + lam[idx].toFixed(1) + ' nm';
        });
        canvas.addEventListener('mouseleave', () => {
            tooltip.classList.add('hidden');
        });

        const btnTransit = document.getElementById('btnTransit');
        const transitPlayer = document.getElementById('transitPlayer');
        if(btnTransit) {
            btnTransit.addEventListener('click', () => {
                transitPlayer.classList.toggle('hidden');
            });
        }
        const btnABand = document.getElementById('btnABand');
        const abandPlayer = document.getElementById('abandPlayer');
        if(btnABand) {
            btnABand.addEventListener('click', () => {
                abandPlayer.classList.toggle('hidden');
            });
        }
    </script>
</body>
</html>
""")

    html = template.substitute(
        INSTRUMENT_DESC=instrument_desc,
        FIGURE=FIGURE_PATH.name,
        SNR_WAVE=SNR_WAVE_PATH.name,
        TIME_MAP=TIME_MAP_PATH.name,
        SNR_SCALING=SNR_SCALING_PATH.name,
        DISTANCE_CURVE=DISTANCE_CURVE_PATH.name,
        SUMMARY_ROWS=summary_html_rows,
        DISTANCE_ROWS=distance_rows_html,
        SLANT_PATH=SLANT_PATH_FIG.name,
        TRANSIT_ANIM=TRANSIT_ANIMATION.name,
        ABAND_ANIM=ABAND_ANIMATION.name,
        SLIDER_HIGH=slider_high,
        SLIDER_LOW=slider_low,
        LAM_HIGH=lam_high_json,
        DEPTH_HIGH=depth_high_json,
        LAM_LOW=lam_low_json,
        DEPTH_LOW=depth_low_json,
        REFF_HIGH=r_eff_high_json,
        REFF_LOW=r_eff_low_json,
        VMR_TEXT=vmr_text,
        VMR_ROWS=vmr_rows_html,
        VMR_CURVE=VMR_CURVE_PATH.name,
        INSTRUMENT_CURVE=INSTRUMENT_COMPARE_PATH.name,
        INSTRUMENT_ROWS=instrument_rows_html,
        HERO_CARDS=hero_cards_html
    )

    # The page declares <meta charset="UTF-8"> and contains non-ASCII text
    # (O₂, λ, ≈, ∞ ...); write it as UTF-8 explicitly rather than relying on
    # the platform default (e.g. cp1252 on Windows, which cannot encode "₂").
    DASHBOARD_HTML.write_text(html, encoding="utf-8")


# =========================
# 8. MAIN SCRIPT
# =========================

def main():
    """Run the full O₂ A-band exo-Earth transit pipeline end to end.

    In order, this:

    1. builds the layered atmosphere and loads the HITRAN O₂ line list;
    2. computes the slant-path geometry and per-layer cross sections, then the
       high-resolution transit spectrum and optical depths (sorted by
       wavelength);
    3. renders the slant-path figure and convolves the spectrum to R=150;
    4. plots and saves the spectrum figure;
    5. evaluates SNR / exposure-time requirements, the distance sweep, the
       A-band and transit animations, the O₂ abundance sweep and the
       instrument comparison;
    6. writes the HTML dashboard.

    All outputs go to the module-level ``*_PATH`` / animation / dashboard
    locations.
    """
    # HITRAN extract downloaded separately. Look in the current working
    # directory first (original behaviour), then next to this script, so the
    # pipeline also works when launched from another directory — consistent
    # with every other input/output being anchored at PROJECT_ROOT.
    hitran_path = Path("69336a6b.par")
    if not hitran_path.exists() and (PROJECT_ROOT / hitran_path.name).exists():
        hitran_path = PROJECT_ROOT / hitran_path.name
    hitran_file = str(hitran_path)

    print("▶ Building the exo-Earth atmosphere...")
    atm = build_exoearth_atmosphere(n_layers=80, z_top_km=120.0, cloud_top_km=CLOUD_TOP_KM)

    print("▶ Reading O₂ line list from:", hitran_file)
    line = load_o2_lines(hitran_file)

    # Wavenumber window (cm^-1): nu = 1e7 / lambda[nm], so the bounds swap.
    nu_bounds = (1e7 / lambda_max_nm, 1e7 / lambda_min_nm)

    print("▶ Computing slant-path geometry...")
    b_geo, ds_geo = compute_slant_path_matrix(atm["z"], R_p)

    print("▶ Computing σ(ν, p, T)...")
    wn, xs_layers = compute_cross_sections(line, atm["p"], atm["T"], nGrids=3, x_limits=nu_bounds)

    # Convert to wavelength, nm
    wavelength_nm = 1e7 / wn

    print("▶ Building the transit spectrum...")
    R_eff, depth = compute_transmission_spectrum(
        wn,
        xs_layers,
        atm,
        R_p,
        R_star
    )

    tau_grid = compute_optical_depth(xs_layers, atm, ds_geo)

    # Wavenumber order is descending in wavelength; sort everything by
    # ascending wavelength (np.interp below requires increasing x).
    sort_idx = np.argsort(wavelength_nm)
    wavelength_nm = wavelength_nm[sort_idx]
    depth = depth[sort_idx]
    R_eff = R_eff[sort_idx]
    tau_grid = tau_grid[sort_idx]

    # Visualise slant paths at the wavelength of maximum transit depth.
    line_idx = int(np.argmax(depth))
    tau_line = tau_grid[line_idx]
    visualize_atmosphere_slant_paths(atm, R_p, b_geo, ds_geo, tau_line, SLANT_PATH_FIG)

    print("▶ Convolving to HWO spectral resolution...")
    lam_smooth, depth_smooth = convolve_to_resolution(wavelength_nm, depth, R=150)
    signal_high = depth
    signal_low = depth_smooth

    # ===== Plotting =====

    plt.figure(figsize=(8, 4))
    plt.plot(wavelength_nm, depth * 1e6, label="High-res")
    plt.plot(lam_smooth, depth_smooth * 1e6, label="HWO-like (R≈150)", linewidth=2)
    plt.gca().invert_xaxis()
    plt.xlabel("Wavelength (nm)")
    plt.ylabel("Transit depth (ppm)")   # parts per million
    plt.title("O$_2$ A-band transmission spectrum (exo-Earth, HWO)")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig(FIGURE_PATH, dpi=200)

    # Only block on plt.show() for GUI backends; headless/inline runs just
    # report where the figure was written.
    backend = matplotlib.get_backend().lower()
    if "agg" not in backend and "backend_inline" not in backend:
        plt.show()
    else:
        print(f"Spectrum saved to {FIGURE_PATH}")

    print("▶ Evaluating SNR and observation requirements...")
    instrument = InstrumentConfig()
    print(
        "Using HWO configuration: D={:.1f} m, η={:.0f}%, R≈{}, d={:.0f} pc, T★={} K".format(
            instrument.telescope_diameter_m,
            instrument.optical_throughput * 100.0,
            int(instrument.spectral_resolution),
            instrument.distance_pc,
            int(instrument.star_temperature_K)
        )
    )
    snr_targets = (3, 5, 10)
    photon_rate, noise_norm, snr_tables, sigma_floor = snr_time_requirements(
        lam_smooth,
        signal_low,
        instrument,
        snr_targets
    )
    sigma_floor_ppm = sigma_floor * 1e6

    core_idx, summary_rows = summarize_snr(
        lam_smooth,
        depth_smooth,
        signal_low,
        photon_rate,
        noise_norm,
        snr_tables,
        instrument,
        sigma_floor
    )

    distances_pc, hours_snr5, snr100_list = distance_sensitivity_analysis(
        lam_smooth,
        signal_low,
        core_idx,
        snr_target=5,
        base_config=instrument
    )

    print("▶ Animating the A-band planetary inflation...")
    R_eff_rel_high = R_eff / R_p
    R_eff_smooth = np.interp(lam_smooth, wavelength_nm, R_eff)
    R_eff_rel_low = R_eff_smooth / R_p
    animate_a_band(lam_smooth, depth_smooth, R_eff_smooth, R_p, ABAND_ANIMATION)

    print("▶ Rendering 3D transit and slant-path animation...")
    animate_transit_geometry(
        atm,
        R_p,
        R_star,
        b_geo,
        tau_grid,
        wavelength_nm,
        R_eff,
        TRANSIT_ANIMATION
    )

    print("▶ Running retrieval-style O₂ abundance sweep...")
    # Scale factors on the O₂ abundance, log-spaced from 1 down to 1e-3.
    vmr_factors = np.geomspace(1.0, 0.001, 14)
    vmr_rows, vmr_threshold = evaluate_vmr_threshold(
        vmr_factors,
        tau_grid,
        b_geo,
        R_p,
        R_star,
        wavelength_nm,
        lam_smooth,
        core_idx,
        instrument,
        sigma_floor
    )

    print("▶ Comparing HWO vs JWST/NIRSpec vs Roman...")
    # Comparison instruments share the baseline target/star so only the
    # telescope and detector parameters differ.
    instrument_suite = [
        instrument,
        InstrumentConfig(
            name="JWST/NIRSpec",
            telescope_diameter_m=6.5,
            optical_throughput=0.3,
            spectral_resolution=270.0,
            exposure_time_s=1000.0,
            read_noise_e=6.0,
            dark_current_e_per_s=0.005,
            distance_pc=instrument.distance_pc,
            star_temperature_K=instrument.star_temperature_K,
            stellar_radius_cm=instrument.stellar_radius_cm,
            noise_floor_ppm=10.0
        ),
        InstrumentConfig(
            name="Roman",
            telescope_diameter_m=2.4,
            optical_throughput=0.25,
            spectral_resolution=120.0,
            exposure_time_s=1000.0,
            read_noise_e=5.0,
            dark_current_e_per_s=0.002,
            distance_pc=instrument.distance_pc,
            star_temperature_K=instrument.star_temperature_K,
            stellar_radius_cm=instrument.stellar_radius_cm,
            noise_floor_ppm=12.0
        )
    ]
    instrument_rows = instrument_comparison(lam_smooth, signal_low, core_idx, instrument_suite)

    # Headline numbers come from the first summary row:
    # (label, lambda_nm, depth_ppm, photon_rate, hrs@3, hrs@5, hrs@10).
    # NOTE(review): assumes summarize_snr puts the line core first — confirm.
    hero_metrics = {
        "depth_ppm": summary_rows[0][2],
        "photon_rate": summary_rows[0][3],
        "hours_snr5": summary_rows[0][5],
        "snr100": snr_for_time(100.0, signal_low[core_idx], photon_rate[core_idx], noise_norm[core_idx], sigma_floor),
        "noise_floor_ppm": sigma_floor_ppm
    }

    print("▶ Building interactive dashboard...")
    build_dashboard(
        wavelength_nm,
        depth,
        lam_smooth,
        depth_smooth,
        summary_rows,
        distances_pc,
        hours_snr5,
        snr100_list,
        instrument,
        R_eff_rel_high,
        R_eff_rel_low,
        vmr_rows,
        vmr_threshold,
        instrument_rows,
        hero_metrics
    )

    print("Done ✅")


if __name__ == "__main__":
    # Run the full pipeline only when executed as a script, not on import.
    main()
