
import numpy as np
import matplotlib.pyplot as plt
import csv
import os
import argparse
import concurrent.futures as cf
from dataclasses import dataclass

# ============================================================
# PBC / ALINEA / OPEN — Synthetic Freeway CTM
# + ROTATION SUITE:
#   1) temporal rotation (time-shift of demand + pulses + incident)
#   2) spatial rotation  (shift ramp cells + incident cell + heterogeneity + latent center)
#   3) module rotation   (offset the module partition)
#
# Outputs:
#   - CSV with per-run metrics
#   - printed summary (mean/std + tail risk)
#   - optional summary plots
# ============================================================

# ---------------------------
# Demand generator
# ---------------------------

def gaussian_pulse(t, mu, sigma, amp):
    """Evaluate an unnormalised Gaussian bump at ``t``.

    The bump is centred on ``mu``, has spread ``sigma`` and reaches its peak
    value ``amp`` at ``t == mu``.
    """
    z = (t - mu) / sigma
    return amp * np.exp(-0.5 * z ** 2)

def _wrap(x, period):
    """Fold ``x`` into ``[0, period)``.

    A non-positive ``period`` disables wrapping and ``x`` is returned as is.
    """
    return x if period <= 0 else x % period

def demand_mainline(t, p, rng):
    """
    Mainline inflow demand at time ``t``.

    The value is ``q_base`` plus a diurnal hump scaled by ``q_amp`` plus the
    scheduled Gaussian pulses in ``p["q_event_pulses"]``, floored at zero.
    The multiplicative AR(1) noise is applied by the caller; ``rng`` is
    accepted but not used here.

    Rotation support:
      - ``p["time_shift_h"]`` (default 0) is added to ``t`` before the
        diurnal term (wrapped on ``p["day_h"]``) and the pulse schedule
        (wrapped on ``p["T_h"]``) are evaluated.
    """
    shifted_t = t + float(p.get("time_shift_h", 0.0))

    # Diurnal hump, periodic over day_h.
    day_len = p["day_h"]
    clock = _wrap(shifted_t, day_len)
    hump = 0.5 * (1.0 + np.sin(2*np.pi*(clock/day_len - 0.25)))
    flow = p["q_base"] + p["q_amp"] * hump

    # Scheduled pulses are placed on a T_h-periodic clock; they are added one
    # after another onto the running total.
    pulse_clock = _wrap(shifted_t, p["T_h"])
    flow = sum(
        (gaussian_pulse(pulse_clock, mu, sigma, amp)
         for (mu, sigma, amp) in p["q_event_pulses"]),
        flow,
    )

    return max(0.0, flow)

def demand_ramp(t, p, ramp_id):
    """
    Arrival demand of on-ramp ``ramp_id`` at time ``t``.

    The profile ``p["ramp_profiles"][ramp_id]`` supplies ``base``, ``amp`` and
    an optional list of ``pulses``; the result is base + amp * diurnal hump +
    pulses, floored at zero.

    Rotation support:
      - ``p["time_shift_h"]`` (default 0) is added to ``t`` before the
        diurnal term (wrapped on ``p["day_h"]``) and the pulses (wrapped on
        ``p["T_h"]``) are evaluated.
    """
    shifted_t = t + float(p.get("time_shift_h", 0.0))
    profile = p["ramp_profiles"][ramp_id]

    day_len = p["day_h"]
    clock = _wrap(shifted_t, day_len)
    hump = 0.5 * (1.0 + np.sin(2*np.pi*(clock/day_len - 0.25)))
    flow = profile["base"] + profile["amp"] * hump

    # Pulses are accumulated in list order onto the running total.
    pulse_clock = _wrap(shifted_t, p["T_h"])
    flow = sum(
        (gaussian_pulse(pulse_clock, mu, sigma, a)
         for (mu, sigma, a) in profile.get("pulses", [])),
        flow,
    )

    return max(0.0, flow)

# ---------------------------
# Capacity modifiers
# ---------------------------

def incident_capacity_scale(n, t, inc):
    """
    Per-cell capacity multiplier caused by an incident.

    Outside the closed interval ``[inc["t0"], inc["t1"]]`` every cell gets 1.
    Inside it, capacity is reduced by a Gaussian bump centred on
    ``inc["cell"]`` with spread ``inc["width"]`` and depth ``inc["severity"]``,
    clipped to ``[inc["min_scale"] (default 0.2), 1]``.
    """
    if not (inc["t0"] <= t <= inc["t1"]):
        return np.ones(n)

    offsets = np.abs(np.arange(n, dtype=float) - float(inc["cell"]))
    two_var = 2.0 * (float(inc["width"]) ** 2)
    bump = np.exp(-(offsets**2) / two_var)
    reduced = 1.0 - float(inc["severity"]) * bump
    return np.clip(reduced, inc.get("min_scale", 0.2), 1.0)

def latent_debuff_vector(n, center, width, severity, min_scale):
    """Per-cell capacity multiplier for a latent discharge.

    A Gaussian bump of depth ``severity`` and spread ``width`` centred on cell
    ``center`` is subtracted from 1; the result is clipped to
    ``[min_scale, 1]``.
    """
    offsets = np.abs(np.arange(n, dtype=float) - float(center))
    two_var = 2.0 * (float(width) ** 2)
    bump = np.exp(-(offsets**2) / two_var)
    return np.clip(1.0 - float(severity) * bump, min_scale, 1.0)

# ---------------------------
# CTM core with merges + ramp queues
# ---------------------------

def ctm_step_with_merges(rho, q_in_main, ramps, params, cap_scale_vec):
    """
    Advance the cell-transmission model by one time step, including on-ramp merges.

    rho: (n,) densities [veh/km]
    q_in_main: upstream sending to cell 0 [veh/h]
    ramps: list of dict per on-ramp:
        - cell: merge into this mainline cell
        - r_meter: metering rate (max inflow from queue) [veh/h]
        - q_queue: current ramp queue [veh]
        - q_dem: ramp arrival demand [veh/h]
    params: dict read for "n", "dt_h", "L_km", "vf_kmh", "w_kmh", "rho_j",
        "qmax_vph" and "merge_main_priority"
    cap_scale_vec: (n,) capacity multipliers in (0,1]
    returns: rho_next, q_out, updated ramps, debug flows

    Side effects: the dicts in ``ramps`` are mutated in place (``q_queue`` is
    increased by arrivals and decreased by the admitted ramp flow); the same
    list object is returned.

    NOTE(review): merges are only processed for destination cells 1..n-1, so a
    ramp whose ``cell`` is 0 (or outside that range) never receives arrivals
    and never discharges.
    """
    p = params
    n = p["n"]
    dt = p["dt_h"]
    L  = p["L_km"]

    vf = p["vf_kmh"]
    w  = p["w_kmh"]
    rho_j = p["rho_j"]
    qmax_base = p["qmax_vph"]

    # Per-cell capacity after scaling, and the matching critical density
    # (rho_c is only reported in the debug dict).
    qmax = qmax_base * cap_scale_vec
    rho_c = qmax / max(vf, 1e-12)

    # Sending (S) and receiving (R) functions of each cell.
    S = np.minimum(vf * rho, qmax)
    R = np.minimum(qmax, w * (rho_j - rho))

    f = np.zeros(n + 1)        # flows between cells
    ramp_in = np.zeros(n)      # actual ramp inflow to each cell

    # Index ramps by the cell they merge into.
    ramps_by_cell = {}
    for rid, r in enumerate(ramps):
        ramps_by_cell.setdefault(r["cell"], []).append(rid)

    # Upstream boundary: inflow limited by what cell 0 can receive.
    f[0] = min(q_in_main, R[0])

    for i in range(n - 1):
        cell_to = i + 1
        R_to = R[cell_to]

        # Add this step's arrivals to each ramp queue, then total the flow the
        # meters would release into cell_to (converted to veh/h).
        ramp_supply = 0.0
        if cell_to in ramps_by_cell:
            for rid in ramps_by_cell[cell_to]:
                r = ramps[rid]
                r["q_queue"] += r["q_dem"] * dt
                max_out = r["r_meter"] * dt
                ramp_out = min(r["q_queue"], max_out)
                ramp_supply += ramp_out / dt if dt > 0 else 0.0

        main_demand = S[i]
        total_demand = main_demand + ramp_supply

        if total_demand <= R_to + 1e-9:
            # Uncongested merge: mainline and ramps are both fully served.
            f[i+1] = main_demand
            if cell_to in ramps_by_cell:
                for rid in ramps_by_cell[cell_to]:
                    r = ramps[rid]
                    max_out = r["r_meter"] * dt
                    ramp_out = min(r["q_queue"], max_out)
                    r["q_queue"] -= ramp_out
                    ramp_in[cell_to] += ramp_out / dt
        else:
            # Congested merge: mainline gets up to a beta share of the
            # receiving capacity, ramps get the remainder.
            beta = p["merge_main_priority"]
            main_alloc = min(main_demand, beta * R_to)
            ramp_alloc_total = max(0.0, R_to - main_alloc)

            # Hand capacity the ramps cannot use back to the mainline.
            if ramp_supply < ramp_alloc_total and main_demand > main_alloc:
                extra = min(main_demand - main_alloc, ramp_alloc_total - ramp_supply)
                main_alloc += extra
                ramp_alloc_total -= extra

            f[i+1] = main_alloc

            # Split the ramp allocation among ramps in proportion to what each
            # would have released; unserved vehicles stay in the queue.
            if cell_to in ramps_by_cell and ramp_supply > 1e-9:
                for rid in ramps_by_cell[cell_to]:
                    r = ramps[rid]
                    max_out = r["r_meter"] * dt
                    ramp_out = min(r["q_queue"], max_out)
                    ramp_out_h = ramp_out / dt
                    share = ramp_out_h / ramp_supply
                    alloc_h = share * ramp_alloc_total
                    actual_h = min(alloc_h, ramp_out_h)
                    actual_veh = actual_h * dt
                    r["q_queue"] -= actual_veh
                    ramp_in[cell_to] += actual_h

    # Downstream boundary: the last cell discharges its full sending flow.
    f[n] = S[n-1]

    # Density update from the flow balance of each cell.
    rho_next = rho.copy()
    for i in range(n):
        inflow = f[i] + ramp_in[i]
        outflow = f[i+1]
        rho_next[i] = rho[i] + (dt / L) * (inflow - outflow)

    # Keep densities inside [0, rho_j].
    rho_next = np.clip(rho_next, 0.0, rho_j)
    q_out = f[n]

    debug = dict(f=f, ramp_in=ramp_in, S=S, R=R, rho_c=rho_c, qmax=qmax)
    return rho_next, q_out, ramps, debug

# ---------------------------
# Controllers helpers
# ---------------------------

def alinea_meter(r_prev, rho_meas, rho_target, K, rmin, rmax):
    """Integral metering update.

    The previous rate is moved by ``K`` times the density error
    ``rho_target - rho_meas`` and the result is clipped to ``[rmin, rmax]``.
    """
    density_error = rho_target - rho_meas
    r_unclipped = r_prev + K * density_error
    return float(np.clip(r_unclipped, rmin, rmax))

def apply_slew(r_old, r_cmd, slew_vph_per_h, dt_h):
    """Move ``r_old`` towards ``r_cmd`` under a rate-of-change limit.

    The change applied in one step is bounded in magnitude by
    ``slew_vph_per_h * dt_h``.
    """
    step_cap = slew_vph_per_h * dt_h
    delta = np.clip(r_cmd - r_old, -step_cap, step_cap)
    return float(r_old + delta)

# ---------------------------
# PBC sensors + gates
# ---------------------------

def ema(x_s, x, a):
    """One exponential-moving-average update.

    ``x_s`` is the previous smoothed value, ``x`` the new sample and ``a`` the
    weight given to the new sample.
    """
    keep = 1 - a
    return keep * x_s + a * x

def rolling_slope(y):
    """Least-squares slope of ``y`` against its sample index.

    Parameters
    ----------
    y : sequence of numbers
        Samples taken at indices 0, 1, ..., len(y) - 1.

    Returns
    -------
    float
        Slope per index step. Returns 0.0 when fewer than three samples are
        given or when the index variance is numerically zero.

    The input is never modified: the samples are copied before centring, so a
    NumPy float array passed by the caller keeps its values.
    """
    if len(y) < 3:
        return 0.0
    # np.array always copies; np.asarray would alias a float ndarray and the
    # in-place centring below would then overwrite the caller's data.
    vals = np.array(y, dtype=float)
    x = np.arange(len(vals), dtype=float)
    x -= x.mean()
    vals -= vals.mean()
    denom = np.sum(x*x)
    return float(np.sum(x*vals)/denom) if denom > 1e-12 else 0.0

def module_sensors(rho, rho_c_ref, mods, L_km):
    """Aggregate per-module congestion, gradient and coherence readings.

    ``mods`` is a sequence of index arrays into ``rho``, one per module.
    For every module the function returns:

    - congestion: 0.7 * mean + 0.3 * 85th percentile of the density excess
      above ``rho_c_ref`` (excess floored at 0),
    - gradient: mean absolute spatial density gradient over the module's cells,
    - coherence: ``1 / (1 + std(rho))`` over the module's cells.

    Returns three float arrays of length ``len(mods)``.
    """
    # Absolute density jump across each cell interface, per km.
    jump = np.abs(np.diff(rho)) / max(L_km, 1e-12)

    # Cell gradient: interior cells average their two interfaces, boundary
    # cells take their single interface.
    cell_grad = np.zeros_like(rho)
    if len(jump) > 0:
        cell_grad[0] = jump[0]
        cell_grad[-1] = jump[-1]
        cell_grad[1:-1] = 0.5*(jump[:-1] + jump[1:])

    cong_vals, grad_vals, coh_vals = [], [], []
    for cells in mods:
        local = rho[cells]
        excess = np.maximum(0.0, local - rho_c_ref)
        cong_vals.append(0.7*float(np.mean(excess)) + 0.3*float(np.quantile(excess, 0.85)))
        grad_vals.append(float(np.mean(cell_grad[cells])))
        coh_vals.append(1.0/(1.0 + float(np.std(local))))

    return (
        np.array(cong_vals, dtype=float),
        np.array(grad_vals, dtype=float),
        np.array(coh_vals, dtype=float),
    )

# ---------------------------
# Metrics
# ---------------------------

def C_global(rho, rho_c_ref):
    """Mean density excess above ``rho_c_ref`` over all cells (excess floored at 0)."""
    excess = np.maximum(0.0, rho - rho_c_ref)
    return float(np.mean(excess))

def tau_relax_first(relax_gate, dt_h, sustain_steps, start_step=0):
    """Time at which the relax gate first stays on for ``sustain_steps`` steps.

    Scanning starts at index ``start_step``. A sample counts as "on" when it is
    ``>= 0.5``. The return value is the start index of the first qualifying
    streak multiplied by ``dt_h``, or NaN if no such streak exists.
    """
    streak = 0
    k = start_step
    last = len(relax_gate)
    while k < last:
        if not (relax_gate[k] >= 0.5):
            # Gate dropped: any streak in progress is void.
            streak = 0
        else:
            streak += 1
            if streak >= sustain_steps:
                return (k - sustain_steps + 1) * dt_h
        k += 1
    return np.nan

def compute_kpis(R, p):
    """Derive travel-time, delay, queue and fuel-proxy KPIs from a run log.

    Reads ``R["rho_mean"]`` and ``R["q_out"]`` (per-step series) plus the
    optional ``R["queue"]`` and ``R["U2"]``, and ``p["dt_h"]``, ``p["n"]``,
    ``p["L_km"]``, ``p["vf_kmh"]``.

    Returns a dict with keys ``TTT_veh_h``, ``N_out``, ``TT_mean_h``,
    ``TT_ff_h``, ``delay_total_min``, ``delay_per_vehicle_min``,
    ``queue_int_veh_h``, ``U2`` and ``fuel_proxy``.
    """
    step_h = p["dt_h"]
    road_km = p["n"] * p["L_km"]

    # Vehicles on the road per step -> total time spent; vehicles that left.
    vehicles_on_road = R["rho_mean"] * road_km
    ttt = float(np.sum(vehicles_on_road) * step_h)
    served = float(np.sum(R["q_out"]) * step_h)
    served_safe = max(served, 1e-9)  # avoids division by zero

    tt_mean = ttt / served_safe
    tt_free = (road_km / max(p["vf_kmh"], 1e-12))

    # Delay relative to free-flow traversal, never negative.
    delay_total = max(0.0, (ttt - served * tt_free) * 60.0)
    delay_each = delay_total / served_safe

    queue_integral = float(np.sum(R.get("queue", 0.0)) * step_h)
    control_energy = float(R.get("U2", 0.0))

    # Weighted proxy: time on road, ramp queueing, control effort.
    w_ttt, w_queue, w_u2 = 1.0, 0.2, 50.0
    fuel = w_ttt * ttt + w_queue * queue_integral + w_u2 * control_energy

    return {
        "TTT_veh_h": ttt,
        "N_out": served,
        "TT_mean_h": tt_mean,
        "TT_ff_h": tt_free,
        "delay_total_min": delay_total,
        "delay_per_vehicle_min": delay_each,
        "queue_int_veh_h": queue_integral,
        "U2": control_energy,
        "fuel_proxy": fuel,
    }

# ---------------------------
# Simulation runner
# ---------------------------

def run(seed, mode, p):
    """Simulate one scenario with one controller and return its time series.

    Parameters
    ----------
    seed : int
        Seed for the demand-noise random generator.
    mode : {"openloop", "alinea", "pbc"}
        Ramp-metering controller to use.
    p : dict
        Scenario parameters (see the base parameter dict built in ``main``).

    Returns
    -------
    dict
        Per-step arrays (``ts``, ``rho_max``, ``rho_mean``, ``q_out``,
        ``q_in``, ``cum_out``, ``meter``, ``queue``, ``C_raw``, ``C_ema``,
        ``slope_per_h``, ``relax_gate``, ``open_frac``, ``latent``,
        ``events``) and scalars (``discharge_events``, ``A_cong``,
        ``throughput``, ``U2``, ``tau_relax`` in hours, NaN if never relaxed).

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported controllers.

    Notes
    -----
    - The incident window is evaluated on the same rotated, ``T_h``-periodic
      clock as the scheduled demand pulses, so a temporal rotation shifts the
      incident together with the demand. ``p["incident"]["t0"]``/``["t1"]``
      are therefore expressed in unrotated schedule time.
    - ``U2`` (normalised control energy) is accumulated for every run, whether
      or not the latent mechanism is enabled.
    """
    if mode not in ("openloop", "alinea", "pbc"):
        raise ValueError(f"unknown controller mode: {mode!r}")

    rng = np.random.default_rng(seed)
    n = p["n"]
    dt = p["dt_h"]
    T  = p["T_h"]
    # Floor with a small tolerance so that a ratio such as 1079.9999999 caused
    # by floating-point rounding does not silently drop the last step.
    steps = int(np.floor(T / dt + 1e-9))
    tshift = float(p.get("time_shift_h", 0.0))

    rho = np.ones(n) * p["rho_init"]

    ramps = [
        dict(cell=int(cell), r_meter=p["ramp_r_init"], q_queue=0.0, q_dem=0.0)
        for cell in p["ramp_cells"]
    ]

    # Module partition with optional offset rotation. Modules are contiguous
    # in a circular sense; indices are built explicitly and wrapped modulo n.
    M = p["M"]
    m_size = max(1, n//M)
    offset = int(p.get("module_offset", 0)) % n
    mods = [(np.arange(m_size) + (m*m_size + offset) % n) % n for m in range(M)]

    rho_c_ref = p["qmax_vph"]/p["vf_kmh"]
    C_ema = 0.0
    hist = []

    s_lat = 0.0
    latent_active_until = -1.0
    discharge_events = 0

    noise = 0.0

    # logs
    ts = np.zeros(steps)
    rho_max = np.zeros(steps)
    rho_mean = np.zeros(steps)
    q_out = np.zeros(steps)
    q_in_main = np.zeros(steps)
    ramp_meter_mean = np.zeros(steps)
    ramp_queue_sum = np.zeros(steps)

    C_raw = np.zeros(steps)
    C_ema_log = np.zeros(steps)
    slope_per_h = np.zeros(steps)
    relax_gate = np.zeros(steps)
    open_frac = np.zeros(steps)

    latent_log = np.zeros(steps)
    events_log = np.zeros(steps)

    A_cong = 0.0
    U2 = 0.0
    throughput = 0.0

    r_prev = np.array([r["r_meter"] for r in ramps], float)

    # Window schedule (kept for comparability; PBC alpha does not require
    # binary windows). until[m]: window open while t < until[m];
    # cooldown[m]: earliest time module m may open again.
    until = np.full(M, -1.0)
    cooldown = np.full(M, -1.0)

    for s in range(steps):
        t = s*dt
        ts[s] = t

        # AR(1) demand noise (log-multiplicative), clipped to [0.7, 1.4]
        noise = p["noise_phi"]*noise + p["noise_sigma"]*rng.normal()
        mult = np.exp(noise)
        mult = np.clip(mult, 0.7, 1.4)

        q_main = demand_mainline(t, p, rng) * mult
        q_in_main[s] = q_main

        for rid in range(len(ramps)):
            ramps[rid]["q_dem"] = demand_ramp(t, p, rid) * mult

        # The rotation builder does not touch the incident window, so the
        # temporal rotation is applied here, on the same wrapped clock that
        # the scheduled demand pulses use.
        t_incident = _wrap(t + tshift, T)
        cap_exo = incident_capacity_scale(n, t_incident, p["incident"])

        # congestion C(t)
        Cnow = C_global(rho, rho_c_ref)
        C_raw[s] = Cnow
        C_ema = ema(C_ema, Cnow, p["relax_alpha"])
        C_ema_log[s] = C_ema

        # Rolling window of the smoothed congestion for the slope estimate
        hist.append(C_ema)
        if len(hist) > p["relax_W"]:
            hist.pop(0)
        slope = rolling_slope(hist)  # per step
        slope_h = slope / max(dt, 1e-12)
        slope_per_h[s] = slope_h

        # The relax gate compares the per-step slope with the threshold
        relax_ok = (slope <= p["relax_eps_slope"])
        relax_gate[s] = 1.0 if relax_ok else 0.0

        # PBC windows (kept as "event openings" / probes)
        if mode == "pbc":
            cong, grad, coh = module_sensors(rho, rho_c_ref, mods, p["L_km"])
            for m in range(M):
                sens = cong[m] + p["grad_w"]*grad[m]
                probe = ((s % p["probe_period_steps"]) == 0) and (sens > p["probe_min_sens"])
                can_open = (t >= until[m]) and (t >= cooldown[m])
                if can_open and (coh[m] > p["th_coh"]) and (sens > p["th_sens"]) and (relax_ok or probe):
                    until[m] = t + p["hold_h"]
                    cooldown[m] = t + p["refrac_h"]
            w_m = (t < until).astype(float)
            open_frac[s] = float(np.mean(w_m))
        else:
            w_m = np.zeros(M)
            open_frac[s] = 0.0

        # ---- Control: choose r_meter for each ramp
        for rid, r in enumerate(ramps):
            meas_cell = min(r["cell"], n-1)
            rho_meas = rho[meas_cell]

            if mode == "openloop":
                r_cmd = p["ramp_free_vph"]
            elif mode == "alinea":
                r_cmd = alinea_meter(r["r_meter"], rho_meas, p["rho_target"], p["alinea_K"], p["rmin"], p["rmax"])
            else:  # pbc (NON-BINARY alpha coupling)
                # Module that owns the measurement cell under the offset partition
                m = min(((meas_cell - offset) % n) // m_size, M-1)

                # Global receptivity from slope(C_ema): slope_h > 0 => low coupling
                s0 = p["alpha_slope_s0"]
                k  = p["alpha_slope_k"]
                alpha_global = 1.0 / (1.0 + np.exp(k * (slope_h - s0)))

                # Local coherence permission
                alpha_local = np.clip((coh[m] - p["th_coh"]) / (1.0 - p["th_coh"] + 1e-9), 0.0, 1.0)

                # Need (soft)
                sens = cong[m] + p["grad_w"] * grad[m]
                alpha_need = np.clip((sens - p["th_sens"]) / (p["th_sens"] + 1e-9), 0.0, 1.0)

                # Optional window modulation (kept weak): windows amplify coupling when open
                win_boost = 0.6 + 0.4 * w_m[m]   # in [0.6, 1.0]

                alpha = alpha_global * alpha_local * (0.3 + 0.7 * alpha_need) * win_boost
                alpha = float(np.clip(alpha, p["alpha_min"], 1.0))

                # Blend "hold current rate" with the ALINEA proposal by alpha
                r_alinea = alinea_meter(r["r_meter"], rho_meas, p["rho_target"], p["alinea_K"], p["rmin"], p["rmax"])
                r_cmd = (1.0 - alpha) * r["r_meter"] + alpha * r_alinea

            # slew limit
            r_new = apply_slew(r["r_meter"], r_cmd, p["slew_vph_per_h"], dt)
            r["r_meter"] = float(np.clip(r_new, p["rmin"], p["rmax"]))

        # ---- Control energy: tracked for every run so U2 is comparable
        #      whether or not the latent mechanism is enabled.
        r_now = np.array([r["r_meter"] for r in ramps], float)
        dr = r_now - r_prev
        r_prev = r_now
        U2 += float(np.sum((dr / max(p["rmax"],1e-12))**2) * dt)

        # ---- Latent accumulation: penalize changing meters when "closed"
        if p["enable_latent"]:
            u_energy = float(np.sum(dr*dr))

            # closed if windows mostly closed (PBC), always closed for the
            # classical controllers
            if mode == "pbc":
                closed = (open_frac[s] < p["open_threshold"])
            else:
                closed = True

            s_lat = p["latent_alpha"]*s_lat + (p["latent_gain"]*u_energy if closed else 0.0)

            # Discharge: start a capacity debuff and release part of the state
            if (s_lat > p["latent_thresh"]) and (t >= latent_active_until):
                latent_active_until = t + p["latent_cap_duration_h"]
                discharge_events += 1
                events_log[s] = 1.0
                s_lat *= p["latent_release"]
            else:
                events_log[s] = 0.0

            latent_log[s] = s_lat
        else:
            latent_log[s] = 0.0
            events_log[s] = 0.0

        # ---- Total capacity scale: incident * heterogeneity * latent debuff
        cap_total = cap_exo.copy()
        cap_total *= p["cap_hetero"]

        if p["enable_latent"] and (t < latent_active_until):
            deb = latent_debuff_vector(n, p["latent_cap_cell"], p["latent_cap_width"],
                                       p["latent_cap_severity"], p["latent_cap_min_scale"])
            cap_total *= deb

        # ---- Plant step
        rho, qout, ramps, _ = ctm_step_with_merges(rho, q_main, ramps, p, cap_total)
        q_out[s] = qout

        throughput += qout * dt
        A_cong += Cnow * dt
        ramp_meter_mean[s] = float(np.mean([r["r_meter"] for r in ramps])) / max(p["rmax"],1e-12)
        ramp_queue_sum[s] = float(np.sum([r["q_queue"] for r in ramps]))

        rho_mean[s] = float(np.mean(rho))
        rho_max[s] = float(np.max(rho))

    # Same tolerance as for `steps`, so e.g. 0.55 h maps to step 396, not 395.
    start_step = int(np.floor(p["tau_relax_start_h"] / dt + 1e-9))
    tauR = tau_relax_first(relax_gate, dt, p["tau_relax_sustain_steps"], start_step=start_step)

    return dict(
        ts=ts,
        rho_max=rho_max, rho_mean=rho_mean,
        q_out=q_out, q_in=q_in_main,
        cum_out=np.cumsum(q_out)*dt,
        meter=ramp_meter_mean,
        queue=ramp_queue_sum,
        C_raw=C_raw, C_ema=C_ema_log,
        slope_per_h=slope_per_h,
        relax_gate=relax_gate,
        open_frac=open_frac,
        latent=latent_log,
        events=events_log,
        discharge_events=discharge_events,
        A_cong=A_cong,
        throughput=throughput,
        U2=U2,
        tau_relax=tauR
    )

# ---------------------------
# Rotations suite
# ---------------------------

@dataclass
class RotationSpec:
    """One point of the rotation grid, consumed by ``build_rotated_params``."""

    # Temporal shift in hours; copied to p["time_shift_h"].
    time_shift_h: float
    # Spatial shift in cells; applied modulo n to the ramp cells, the
    # heterogeneity vector, the incident cell and the latent debuff centre.
    space_shift_cells: int
    # Offset of the module partition in cells; stored modulo n.
    module_offset: int


def _auto_workers(user_workers: int = 0) -> int:
    """Return the number of worker processes to use.

    A positive ``user_workers`` is honoured as given. Otherwise the count is
    the CPU count minus one (so one core stays free), never less than 1.
    """
    if user_workers and user_workers > 0:
        return max(1, int(user_workers))

    try:
        cpu = os.cpu_count() or 1
    except Exception:
        cpu = 1
    return max(1, cpu - 1)

def _run_rotation_job(args):
    """Run all controllers for a single (rotation, seed) job.

    Parameters
    ----------
    args : tuple
        ``(rot_i, rot, seed, p0, modes)``: rotation index, the rotation
        spec, the random seed, the base parameter dict and the list of
        controller modes to run.

    Returns
    -------
    list[dict]
        One CSV row per controller mode.

    The row values are read from the keys that ``run`` and ``compute_kpis``
    actually produce (``discharge_events``, ``tau_relax``,
    ``delay_per_vehicle_min``, ``TT_mean_h``); hour-based quantities are
    converted to minutes to match the ``*_min`` column names.
    """
    rot_i, rot, seed, p0, modes = args
    p = build_rotated_params(p0, rot)
    out = []
    for mode in modes:
        R = run(seed, mode, p)
        kpi = compute_kpis(R, p)
        out.append({
            "rot_id": rot_i,
            "seed": seed,
            "controller": mode,
            "mode": mode,
            "time_shift_h": rot.time_shift_h,
            "space_shift": rot.space_shift_cells,
            "module_offset": rot.module_offset,
            "throughput": float(R.get("throughput", np.nan)),
            "A_cong": float(R.get("A_cong", np.nan)),
            "U2": float(R.get("U2", np.nan)),
            # run() reports the count under "discharge_events".
            "discharges": float(R.get("discharge_events", np.nan)),
            # run() reports tau_relax in hours (NaN when never relaxed).
            "tau_relax_min": float(R.get("tau_relax", np.nan)) * 60.0,
            "delay_total_min": float(kpi.get("delay_total_min", np.nan)),
            "delay_mean_min": float(kpi.get("delay_per_vehicle_min", np.nan)),
            # compute_kpis() reports the mean travel time in hours.
            "tt_mean_min": float(kpi.get("TT_mean_h", np.nan)) * 60.0,
            "fuel_proxy": float(kpi.get("fuel_proxy", np.nan)),
        })
    return out


def build_rotated_params(p_base, rot: RotationSpec):
    """Return a copy of ``p_base`` with the rotation ``rot`` applied.

    - ``time_shift_h`` is stored as a float.
    - ``module_offset`` is stored modulo ``n``.
    - A non-zero spatial shift (modulo ``n``) moves the ramp cells, rolls the
      ``cap_hetero`` vector, and moves the incident cell and the latent debuff
      centre. The incident dict is copied before being changed; its
      ``t0``/``t1`` are left untouched.

    The copy is shallow: values that are not replaced here are shared with
    ``p_base``.
    """
    rotated = dict(p_base)
    n_cells = rotated["n"]

    rotated["time_shift_h"] = float(rot.time_shift_h)
    rotated["module_offset"] = int(rot.module_offset) % n_cells

    shift = int(rot.space_shift_cells) % n_cells
    if shift == 0:
        return rotated

    def _moved(cell):
        # Circular shift of a single cell index.
        return int((cell + shift) % n_cells)

    rotated["ramp_cells"] = [_moved(c) for c in rotated["ramp_cells"]]
    rotated["cap_hetero"] = np.roll(rotated["cap_hetero"], shift)

    incident = dict(rotated["incident"])
    incident["cell"] = _moved(incident["cell"])
    rotated["incident"] = incident

    rotated["latent_cap_cell"] = _moved(rotated["latent_cap_cell"])
    return rotated

def summarize(arr):
    """Summary statistics of a numeric sequence.

    Returns a dict with ``mean``, sample ``std`` (ddof=1; 0.0 for a single
    value) and the 50th/90th/99th percentiles. Every entry is NaN when the
    input is empty.
    """
    x = np.array(arr, float)
    if len(x) == 0:
        return dict.fromkeys(("mean", "std", "p50", "p90", "p99"), np.nan)

    spread = float(np.std(x, ddof=1)) if len(x) > 1 else 0.0
    stats = {"mean": float(np.mean(x)), "std": spread}
    for label, q in (("p50", 0.5), ("p90", 0.9), ("p99", 0.99)):
        stats[label] = float(np.quantile(x, q))
    return stats

def main():
    """Run the rotation suite, write the CSV, make figures and print a summary.

    Command-line options: ``--workers`` (0 = auto), ``--seeds`` (must be at
    least 1) and ``--no-parallel``. Results are written to
    ``pbc_rotations_results.csv`` in the current directory; figures go to
    ``paper/figures`` and are skipped (with a message) if plotting fails.
    """

    # ---------------------------
    # CLI options (optional)
    # ---------------------------
    ap = argparse.ArgumentParser(add_help=True)
    ap.add_argument("--workers", type=int, default=0,
                    help="Number of parallel worker processes (0 = auto).")
    ap.add_argument("--seeds", type=int, default=3,
                    help="Number of random seeds to run (default: 3).")
    ap.add_argument("--no-parallel", action="store_true",
                    help="Disable multiprocessing (debug).")
    args = ap.parse_args()

    # Without at least one seed there are no jobs and no rows to write.
    if args.seeds < 1:
        ap.error("--seeds must be at least 1")

    # ---------------------------
    # Base parameters (paper-ish)
    # ---------------------------
    p0 = {
        "n": 80,
        "L_km": 0.4,
        "dt_h": 5.0/3600.0,
        "T_h": 1.5,

        "vf_kmh": 100.0,
        "w_kmh": 18.0,
        "rho_j": 180.0,
        "qmax_vph": 6000.0,

        "rho_init": 20.0,
        "merge_main_priority": 0.75,

        "day_h": 24.0,

        "q_base": 2600.0,
        "q_amp": 2400.0,
        "q_event_pulses": [
            (0.18, 0.03, 700.0),
            (0.55, 0.05, 500.0),
        ],

        "noise_phi": 0.95,
        "noise_sigma": 0.05,

        "ramp_cells": [12, 28, 46, 63],
        "ramp_free_vph": 1800.0,
        "ramp_r_init": 900.0,
        "rmin": 0.0,
        "rmax": 1800.0,

        "ramp_profiles": [
            {"base": 1100.0, "amp": 1400.0, "pulses": [(0.35, 0.04, 400.0)]},
            {"base":  900.0, "amp": 1300.0, "pulses": []},
            {"base": 1000.0, "amp": 1200.0, "pulses": [(0.62, 0.03, 500.0)]},
            {"base":  900.0, "amp": 1400.0, "pulses": []},
        ],

        "rho_target": 35.0,
        "alinea_K": 140.0,

        # PBC modules
        "M": 8,
        "grad_w": 0.25,
        "th_sens": 3.0,
        "th_coh": 0.03,
        "hold_h": 120.0/3600.0,
        "refrac_h": 40.0/3600.0,
        "probe_period_steps": 60,
        "probe_min_sens": 2.0,
        "open_threshold": 0.2,

        # PBC alpha coupling (non-binary)
        "alpha_slope_s0": 2.0,
        "alpha_slope_k": 0.9,
        "alpha_min": 0.02,

        # relax gate
        "relax_alpha": 0.08,
        "relax_W": 60,
        "relax_eps_slope": 0.020,

        # slew
        "slew_vph_per_h": 20000.0,

        # incident
        "incident": {
            "t0": 0.33, "t1": 0.55,
            "cell": 40, "width": 4.0,
            "severity": 0.45,
            "min_scale": 0.25
        },

        # heterogeneity (filled below)
        "cap_hetero": None,

        # latent mechanism
        "enable_latent": True,
        "latent_alpha": 0.995,
        "latent_gain": 5.0,
        "latent_thresh": 2.0e4,
        "latent_release": 0.45,
        "latent_cap_duration_h": 420.0/3600.0,
        "latent_cap_cell": 40,
        "latent_cap_width": 3.5,
        "latent_cap_severity": 0.55,
        "latent_cap_min_scale": 0.35,

        # decongestion metrics
        "C_target": 0.0,
        "tau_relax_sustain_steps": 10,
        "tau_relax_start_h": 0.55,
    }

    # Fixed heterogeneity: same per-cell capacity pattern for every run
    rng = np.random.default_rng(123)
    cap_hetero = 1.0 + 0.03 * rng.normal(size=p0["n"])
    p0["cap_hetero"] = np.clip(cap_hetero, 0.92, 1.08)

    # ---------------------------
    # Rotation grid
    # ---------------------------
    # temporal shifts: 0..55 min in 5-min steps (wrap on T_h)
    time_shifts = [m/60.0 for m in range(0, 60, 5)]
    # spatial shifts: 0..75 cells in steps of 5
    space_shifts = list(range(0, p0["n"], 5))
    # module offsets: 0..(m_size-1) in steps of 2
    m_size = max(1, p0["n"]//p0["M"])
    module_offsets = list(range(0, m_size, 2))

    rotations = [
        RotationSpec(time_shift_h=tsh, space_shift_cells=k, module_offset=mo)
        for tsh in time_shifts
        for k in space_shifts
        for mo in module_offsets
    ]

    print(f"Rotations: time={len(time_shifts)} * space={len(space_shifts)} * module={len(module_offsets)} "
          f"= {len(rotations)} runs per controller per seed.")

    seeds = list(range(int(args.seeds)))  # number of random seeds
    modes = ["openloop", "alinea", "pbc"]

    # Display names for controllers (paper-facing)
    mode_label = {
        "openloop": "Open-loop",
        "alinea": "ALINEA",
        "pbc": "PBC",
    }

    rows = []

    # ------------------------------------------------------------
    # Parallel execution: each job runs all controllers for (rot,seed)
    # ------------------------------------------------------------
    jobs = [(rot_i, rot, seed, p0, modes) for rot_i, rot in enumerate(rotations) for seed in seeds]
    total_jobs = len(jobs)

    use_parallel = (not args.no_parallel) and (total_jobs > 1)
    workers = _auto_workers(args.workers)

    if use_parallel:
        print(f"Using multiprocessing with {workers} worker(s) on {total_jobs} jobs "
              f"(each job runs {len(modes)} controller(s)).")
        done = 0
        with cf.ProcessPoolExecutor(max_workers=workers) as ex:
            for out_rows in ex.map(_run_rotation_job, jobs, chunksize=1):
                rows.extend(out_rows)
                done += 1
                if done % 25 == 0 or done == total_jobs:
                    print(f"  completed {done}/{total_jobs} jobs...")
    else:
        # Sequential fallback (debug / small runs)
        done = 0
        for jb in jobs:
            rows.extend(_run_rotation_job(jb))
            done += 1
            if done % 25 == 0 or done == total_jobs:
                print(f"  completed {done}/{total_jobs} jobs...")

    # The CSV header is taken from the first row, so bail out cleanly if
    # nothing was produced instead of failing on rows[0].
    if not rows:
        print("No results produced; nothing to save.")
        return

    # ---------------------------
    # Save CSV
    # ---------------------------
    csv_name = "pbc_rotations_results.csv"
    with open(csv_name, "w", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
        w.writeheader()
        w.writerows(rows)
    print(f"\nSaved: {csv_name}")

    # Paper figures (from aggregated CSV). Plotting is best-effort: a failure
    # here must not prevent the summary below from being printed.
    try:
        make_paper_figures(csv_name, outdir="paper/figures")
    except Exception as e:
        print("Plotting skipped:", e)

    # ---------------------------
    # Summaries by mode
    # ---------------------------
    print("\n=== Summary across ALL rotations/seeds ===")
    for mode in modes:
        sub = [r for r in rows if r["controller"] == mode]

        s_delay = summarize([r["delay_total_min"] for r in sub])
        s_A     = summarize([r["A_cong"] for r in sub])
        s_U2    = summarize([r["U2"] for r in sub])
        s_ev    = summarize([r["discharges"] for r in sub])

        bad_rate = 100.0 * np.mean([1.0 if r["discharges"] >= 1 else 0.0 for r in sub])

        print(f"\n{mode_label[mode]}")
        print(f"  Delay total (min): mean={s_delay['mean']:.1f} std={s_delay['std']:.1f} "
              f"P90={s_delay['p90']:.1f} P99={s_delay['p99']:.1f}")
        print(f"  A_cong (h*unit):   mean={s_A['mean']:.3f} std={s_A['std']:.3f} "
              f"P90={s_A['p90']:.3f} P99={s_A['p99']:.3f}")
        print(f"  U2:                mean={s_U2['mean']:.8f} std={s_U2['std']:.8f} "
              f"P90={s_U2['p90']:.8f} P99={s_U2['p99']:.8f}")
        print(f"  Discharges:        mean={s_ev['mean']:.3f} std={s_ev['std']:.3f} "
              f"P90={s_ev['p90']:.3f} P99={s_ev['p99']:.3f}")
        print(f"  Bad-event rate (>=1 discharge): {bad_rate:.1f}%")


def make_paper_figures(csv_path: str, outdir: str = "paper/figures") -> None:
    """Create paper figures from the aggregated rotations CSV.

    Parameters
    ----------
    csv_path : str
        CSV written by ``main`` (one row per rotation, seed and controller).
    outdir : str
        Directory for the PNG files; created if missing.

    The controllers are inferred from the ``controller`` column. Known names
    (``openloop``, ``alinea``, ``pbc``) are matched case-insensitively, are
    plotted in that order and get paper-facing labels; any other controller
    follows in sorted order under its own name. Non-finite metric values
    (e.g. NaN) are left out of every figure.
    """
    import os
    import csv as _csv
    import numpy as _np
    import matplotlib.pyplot as _plt

    os.makedirs(outdir, exist_ok=True)

    # Read the CSV, converting every cell that parses as a number to float.
    rows = []
    with open(csv_path, "r", newline="", encoding="utf-8") as f:
        for row in _csv.DictReader(f):
            rec = {}
            for k, v in row.items():
                try:
                    rec[k] = float(v)
                except (TypeError, ValueError):
                    rec[k] = v
            rows.append(rec)

    if not rows:
        print("No rows found in CSV; skipping plotting.")
        return

    # The CSV stores controller names in lower case, so match them
    # case-insensitively; otherwise the preferred order and the display
    # labels would never apply.
    preferred = ["openloop", "alinea", "pbc"]
    pretty = {"openloop": "Open-loop", "alinea": "ALINEA", "pbc": "PBC"}

    def _rank(c):
        name = str(c).lower()
        pos = preferred.index(name) if name in preferred else len(preferred)
        return (pos, str(c))

    present = {r.get("controller") for r in rows}
    controllers = sorted((c for c in present if c is not None and c != ""), key=_rank)
    mode_label = {c: pretty.get(str(c).lower(), str(c)) for c in controllers}
    tick_names = [mode_label[c] for c in controllers]

    def _values(metric_key: str, controller) -> list:
        """Finite values of one metric for one controller."""
        vals = []
        for r in rows:
            if r.get("controller") != controller:
                continue
            v = r.get(metric_key)
            if isinstance(v, float) and _np.isfinite(v):
                vals.append(v)
        return vals

    # --- CDF of delay_total_min ---
    fig, ax = _plt.subplots(1, 1, figsize=(10, 5))
    for c in controllers:
        sub = sorted(_values("delay_total_min", c))
        if not sub:
            continue
        y = _np.linspace(0, 1, len(sub), endpoint=True)
        ax.plot(sub, y, label=mode_label[c])
    ax.set_xlabel("Total delay vs free-flow (minutes)")
    ax.set_ylabel("CDF")
    ax.set_title("CDF of total delay across rotations")
    ax.legend()
    _plt.tight_layout()
    _plt.savefig(os.path.join(outdir, "fig_delay_cdf_rotations.png"), dpi=300, bbox_inches="tight")
    _plt.close(fig)

    def _boxplot(metric_key: str, ylabel: str, fname: str) -> None:
        data = [_values(metric_key, c) for c in controllers]
        fig, ax = _plt.subplots(1, 1, figsize=(7, 4))
        ax.boxplot(data, showfliers=True)
        # Boxes sit at positions 1..N by default. Labelling the ticks
        # directly avoids the `tick_labels` keyword, which only exists in
        # Matplotlib 3.9 and later.
        ax.set_xticks(range(1, len(controllers) + 1))
        ax.set_xticklabels(tick_names)
        ax.set_ylabel(ylabel)
        _plt.tight_layout()
        _plt.savefig(os.path.join(outdir, fname), dpi=300, bbox_inches="tight")
        _plt.close(fig)

    _boxplot("delay_total_min", "Total delay (min)", "fig_delay_boxplot.png")
    _boxplot("A_cong", "Integrated congestion (h·unit)", "fig_Acong_boxplot.png")
    _boxplot("U2", "Control energy (U2)", "fig_U2_boxplot.png")

    # --- Mean discharge events ---
    fig, ax = _plt.subplots(1, 1, figsize=(6, 4))
    means = []
    for c in controllers:
        vals = _values("discharges", c)
        means.append(float(_np.mean(vals)) if vals else 0.0)
    ax.bar(tick_names, means)
    ax.set_ylabel("Mean # extreme events")
    ax.set_title("Mean extreme-event rate across rotations")
    _plt.tight_layout()
    _plt.savefig(os.path.join(outdir, "fig_extreme_events_mean.png"), dpi=300, bbox_inches="tight")
    _plt.close(fig)

    # --- Bad-event probability ---
    fig, ax = _plt.subplots(1, 1, figsize=(6, 4))
    probs = []
    for c in controllers:
        vals = _values("discharges", c)
        probs.append(float(_np.mean([1.0 if v >= 1 else 0.0 for v in vals])) if vals else 0.0)
    ax.bar(tick_names, probs)
    ax.set_ylim(0, 1)
    ax.set_ylabel("P(discharge ≥ 1)")
    ax.set_title("Probability of bad events across rotations")
    _plt.tight_layout()
    _plt.savefig(os.path.join(outdir, "fig_bad_event_probability.png"), dpi=300, bbox_inches="tight")
    _plt.close(fig)

    print(f"Saved figures to {outdir}/")


# Script entry point: run the rotation suite only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
