import numpy as np
import matplotlib.pyplot as plt
import os

# ============================================================
# Synthetic Freeway CTM with PBC / ALINEA / Open-loop
# ============================================================

# ---------------------------
# Utilities
# ---------------------------

def gaussian_pulse(t, mu, sigma, amp):
    """Evaluate a Gaussian bump of peak height ``amp`` centred at ``mu``.

    ``sigma`` is the standard-deviation width. Uses NumPy elementwise ops, so
    ``t`` may be a scalar or an array.
    """
    z = (t - mu) / sigma
    return amp * np.exp(-0.5 * z ** 2)

def ema(x_s, x, a):
    """Return one exponential-moving-average update.

    Blends the previous smoothed value ``x_s`` toward the new sample ``x``
    with weight ``a`` on the new sample (``a = 0`` keeps ``x_s``, ``a = 1``
    returns ``x``).
    """
    keep = 1 - a
    return keep * x_s + a * x

def rolling_slope(y):
    """Return the least-squares slope of ``y`` against its sample index.

    The slope is in units of ``y`` per sample. Fewer than three samples, or a
    degenerate index spread, yields ``0.0``.
    """
    if len(y) < 3:
        return 0.0
    values = np.asarray(y)
    idx = np.arange(len(values))
    dx = idx - idx.mean()
    dy = values - values.mean()
    sxx = np.sum(dx * dx)
    if sxx <= 1e-12:
        return 0.0
    return np.sum(dx * dy) / sxx

# ---------------------------
# Demand models
# ---------------------------

def demand_mainline(t, p):
    diurnal = 0.5 * (1 + np.sin(2 * np.pi * (t / p["day_h"] - 0.25)))
    q = p["q_base"] + p["q_amp"] * diurnal
    for (mu, sigma, amp) in p["q_event_pulses"]:
        q += gaussian_pulse(t, mu, sigma, amp)
    return max(0.0, q)

def demand_ramp(t, p, rid):
    prof = p["ramp_profiles"][rid]
    diurnal = 0.5 * (1 + np.sin(2 * np.pi * (t / p["day_h"] - 0.25)))
    q = prof["base"] + prof["amp"] * diurnal
    for (mu, sigma, amp) in prof.get("pulses", []):
        q += gaussian_pulse(t, mu, sigma, amp)
    return max(0.0, q)

# ---------------------------
# Capacity modifiers
# ---------------------------

def incident_capacity_scale(n, t, inc):
    """Return per-cell capacity multipliers for an incident at time ``t``.

    Outside the window ``inc["t0"] <= t <= inc["t1"]`` every cell keeps its
    full capacity (1.0). Inside the window, the capacity is reduced by a
    Gaussian dip of depth ``inc["severity"]``, centred on ``inc["cell"]``
    with width ``inc["width"]``. The result is clipped to
    ``[inc["min_scale"], 1.0]``.
    """
    if not (inc["t0"] <= t <= inc["t1"]):
        return np.ones(n)
    cells = np.arange(n)
    bump = np.exp(-((cells - inc["cell"]) ** 2) / (2 * inc["width"] ** 2))
    return np.clip(1.0 - inc["severity"] * bump, inc["min_scale"], 1.0)

def latent_debuff_vector(n, center, width, severity, min_scale):
    """Return ``n`` per-cell capacity multipliers with a Gaussian dip.

    The dip is centred on cell ``center`` with width ``width`` and depth
    ``severity``. Values are clipped to ``[min_scale, 1.0]``.
    """
    cells = np.arange(n)
    dist_sq = (cells - center) ** 2
    bump = np.exp(-dist_sq / (2 * width ** 2))
    return np.clip(1.0 - severity * bump, min_scale, 1.0)

# ---------------------------
# CTM core
# ---------------------------

def ctm_step(rho, q_in, p, cap):
    """Advance the cell transmission model (CTM) by one explicit time step.

    Parameters
    ----------
    rho : array-like
        Cell densities. The first ``p["n"]`` entries are updated. The input
        is not modified.
    q_in : float
        Upstream inflow demand. It is admitted up to cell 0's receiving
        capacity.
    p : dict
        Reads ``n``, ``dt_h``, ``L_km``, ``vf_kmh``, ``w_kmh``, ``rho_j`` and
        ``qmax_vph``.
    cap : float or array
        Multiplier applied to ``p["qmax_vph"]`` (scalar or per-cell).

    Returns
    -------
    (numpy.ndarray, float)
        Next-step densities, clipped to ``[0, rho_j]``, and the downstream
        outflow. The downstream boundary is free-flowing: outflow equals the
        last cell's sending flow.
    """
    n = p["n"]
    dt = p["dt_h"]
    L = p["L_km"]
    vf = p["vf_kmh"]
    w = p["w_kmh"]
    rho_j = p["rho_j"]
    qmax = p["qmax_vph"] * cap

    # Coerce to float so fractional density updates are never truncated when
    # the caller passes an integer array.
    rho = np.asarray(rho, dtype=float)

    S = np.minimum(vf * rho, qmax)              # sending (demand) flow
    R = np.minimum(qmax, w * (rho_j - rho))     # receiving (supply) flow

    # f[i] is the flow across the upstream boundary of cell i; f[n] exits.
    f = np.empty(n + 1)
    f[0] = min(q_in, R[0])
    f[1:n] = np.minimum(S[: n - 1], R[1:n])
    f[n] = S[n - 1]

    rho_next = rho.copy()
    rho_next[:n] += (dt / L) * (f[:-1] - f[1:])
    rho_next = np.clip(rho_next, 0, rho_j)

    return rho_next, f[n]

# ---------------------------
# Main simulation
# ---------------------------

def run(seed, mode, p):
    """Simulate the synthetic freeway for ``p["T_h"]`` hours and log diagnostics.

    Each step proceeds as follows:

    - Compute the mean excess density above the nominal critical density,
      ``C``, and its EMA.
    - Compute a rolling-window slope of the EMA (per hour) and a relaxation
      gate from that slope.
    - Optionally update a latent "stress" accumulator. When the accumulator
      crosses ``p["latent_thresh"]``, it triggers a temporary capacity debuff.
    - Advance the CTM one step with the combined incident x heterogeneity x
      latent capacity scaling.

    Parameters
    ----------
    seed : int
        Accepted for interface compatibility. NOTE(review): not consulted by
        this function; any randomness must already be baked into ``p``
        (e.g. ``p["cap_hetero"]``).
    mode : str
        NOTE(review): control-mode label that is not consulted here; no
        control logic is applied in this function.
    p : dict
        Model and diagnostic parameters.

    Returns
    -------
    dict of numpy.ndarray
        Per-step time series, keyed by name.
    """
    from collections import deque

    n = p["n"]
    dt = p["dt_h"]
    # Small tolerance so a quotient like 1079.9999999999998 counts as 1080
    # steps instead of silently dropping the last one.
    steps = int(np.floor(p["T_h"] / dt + 1e-9))

    rho = np.ones(n) * p["rho_init"]
    # Nominal critical density, used as the congestion reference regardless of
    # the time-varying capacity reductions.
    rho_c_ref = p["qmax_vph"] / p["vf_kmh"]

    ts = np.zeros(steps)
    rho_max = np.zeros(steps)
    rho_mean = np.zeros(steps)
    q_out = np.zeros(steps)

    C_raw = np.zeros(steps)
    C_ema = np.zeros(steps)
    relax_gate = np.zeros(steps)
    slope_h = np.zeros(steps)

    latent = np.zeros(steps)
    events = np.zeros(steps)
    cap_min_exo = np.zeros(steps)
    cap_min_total = np.zeros(steps)

    C_ema_s = 0.0
    # Sliding window of the most recent smoothed congestion values; the deque
    # drops the oldest entry in O(1) (list.pop(0) is O(window)).
    hist = deque(maxlen=max(0, int(p["relax_W"])))
    s_lat = 0.0
    latent_until = -1.0
    # The debuff profile depends only on constant parameters, so build it once
    # (lazily, on first use) instead of on every debuffed step.
    latent_debuff = None

    for k in range(steps):
        t = k * dt
        ts[k] = t

        q_in = demand_mainline(t, p)
        cap_exo = incident_capacity_scale(n, t, p["incident"])
        cap_min_exo[k] = cap_exo.min()

        C = np.mean(np.maximum(0.0, rho - rho_c_ref))
        C_raw[k] = C
        C_ema_s = ema(C_ema_s, C, p["relax_alpha"])
        C_ema[k] = C_ema_s

        hist.append(C_ema_s)

        # rolling_slope is per sample; dividing by dt converts to per hour.
        slope = rolling_slope(hist) / dt
        slope_h[k] = slope
        relax_gate[k] = 1.0 if slope <= p["relax_eps_slope"] else 0.0

        cap = cap_exo * p["cap_hetero"]

        if p["enable_latent"]:
            s_lat = p["latent_alpha"] * s_lat + abs(slope)
            # A new event can fire only after the previous debuff has expired.
            if s_lat > p["latent_thresh"] and t > latent_until:
                latent_until = t + p["latent_cap_duration_h"]
                events[k] = 1.0
                s_lat *= p["latent_release"]

            if t < latent_until:
                if latent_debuff is None:
                    latent_debuff = latent_debuff_vector(
                        n,
                        p["latent_cap_cell"],
                        p["latent_cap_width"],
                        p["latent_cap_severity"],
                        p["latent_cap_min_scale"],
                    )
                cap *= latent_debuff

        latent[k] = s_lat
        cap_min_total[k] = cap.min()

        rho, qout = ctm_step(rho, q_in, p, cap)
        q_out[k] = qout

        rho_max[k] = rho.max()
        rho_mean[k] = rho.mean()

    return dict(
        ts=ts,
        rho_max=rho_max,
        rho_mean=rho_mean,
        q_out=q_out,
        C_raw=C_raw,
        C_ema=C_ema,
        relax_gate=relax_gate,
        slope_per_h=slope_h,
        latent=latent,
        events=events,
        cap_min_exo=cap_min_exo,
        cap_min_total=cap_min_total,
    )

# ---------------------------
# MAIN
# ---------------------------

def main():
    """Run the default scenario and save a three-panel diagnostic figure.

    The figure is written to ``paper/figures/fig_realistic_ctm_timeseries.png``.
    Missing directories are created.
    """
    seed = 0
    n_cells = 80
    # Draw per-cell capacity heterogeneity from a seeded generator so the saved
    # figure is reproducible across runs (the unseeded global np.random state
    # made every run differ). Tying the length to n_cells keeps it in sync
    # with "n".
    hetero_rng = np.random.default_rng(seed)

    p = {
        "n": n_cells,
        "L_km": 0.4,
        "dt_h": 5 / 3600,
        "T_h": 1.5,
        "vf_kmh": 100.0,
        "w_kmh": 18.0,
        "rho_j": 180.0,
        "qmax_vph": 6000.0,
        "rho_init": 20.0,
        "day_h": 24.0,
        "q_base": 2600.0,
        "q_amp": 2400.0,
        "q_event_pulses": [(0.18, 0.03, 700), (0.55, 0.05, 500)],
        "relax_alpha": 0.08,
        "relax_W": 60,
        "relax_eps_slope": 0.02,
        "incident": {
            "t0": 0.33,
            "t1": 0.55,
            "cell": 40,
            "width": 4.0,
            "severity": 0.45,
            "min_scale": 0.25,
        },
        "cap_hetero": np.clip(
            1 + 0.03 * hetero_rng.standard_normal(n_cells), 0.92, 1.08
        ),
        "enable_latent": True,
        "latent_alpha": 0.995,
        "latent_thresh": 0.02,
        "latent_release": 0.5,
        "latent_cap_duration_h": 90 / 3600,
        "latent_cap_cell": 40,
        "latent_cap_width": 3.5,
        "latent_cap_severity": 0.55,
        "latent_cap_min_scale": 0.35,
    }

    ex = run(seed, "patronpc", p)

    fig, ax = plt.subplots(3, 1, figsize=(12, 8), sharex=True)

    ax[0].plot(ex["ts"], ex["rho_max"])
    ax[0].set_ylabel("rho_max")

    ax[1].plot(ex["ts"], ex["C_raw"], label="C")
    ax[1].plot(ex["ts"], ex["C_ema"], label="C_ema")
    ax[1].set_ylabel("congestion C")
    ax[1].legend()

    ax[2].plot(ex["ts"], ex["latent"], label="latent")
    ax[2].plot(ex["ts"], ex["events"], label="events")
    ax[2].plot(ex["ts"], ex["cap_min_exo"], label="cap_exo")
    ax[2].plot(ex["ts"], ex["cap_min_total"], label="cap_total")
    ax[2].set_ylabel("latent / capacity")
    ax[2].legend()

    ax[2].set_xlabel("time (h)")

    outdir = os.path.join("paper", "figures")
    os.makedirs(outdir, exist_ok=True)
    outpath = os.path.join(outdir, "fig_realistic_ctm_timeseries.png")

    fig.tight_layout()
    fig.savefig(outpath, dpi=300, bbox_inches="tight")
    plt.close(fig)

    print(f"Saved: {outpath}")

if __name__ == "__main__":
    # Build and save the figure only when executed as a script, not on import.
    main()
