from collections import deque

import numpy as np
import matplotlib.pyplot as plt

# ============================================================
# CTM: Cell Transmission Model (triangular FD)
# ============================================================

def ctm_step(rho, r_main_in, r_ramp, params, cap_scale_vec=None):
    """Advance the Cell Transmission Model (CTM) by one time step.

    Sending and receiving functions, with per-cell capacity
    ``qmax_cell = qmax_vph * cap_scale_vec``:

        S = min(vf * rho, qmax_cell)
        R = min(qmax_cell, w * (rho_j - rho))

    On-ramp inflow is merged with mainline priority: the ramp flow that actually
    enters cell ``i`` is limited to the receiving capacity left over after the
    mainline flow ``f[i]``.  Without this limit, ramp flow could push a cell
    above ``rho_j``, and the final clip would silently delete those vehicles.
    The clip is kept only as a numerical safety net.

    Parameters
    ----------
    rho : array-like, shape (n,)
        Cell densities (veh/km, given the km / vph units of ``params``).
    r_main_in : float
        Mainline upstream demand (veh/h).
    r_ramp : array-like, shape (n,)
        Requested on-ramp inflow per cell (veh/h).
    params : dict
        Must contain ``n``, ``L_km``, ``dt_h``, ``vf_kmh``, ``w_kmh``,
        ``rho_j`` and ``qmax_vph``.
    cap_scale_vec : array-like, shape (n,), optional
        Per-cell capacity multiplier; ``None`` means all ones.

    Returns
    -------
    rho_next : ndarray, shape (n,)
        Densities after the step.
    f : ndarray, shape (n + 1,)
        Mainline interface flows (``f[0]`` upstream boundary, ``f[n]`` exit).
    throughput_out : float
        Exit flow ``f[n]`` (free outflow at the downstream end).
    rho_c_cell : ndarray, shape (n,)
        Per-cell critical density ``qmax_cell / vf``.
    """
    n = params["n"]
    L = params["L_km"]
    dt_h = params["dt_h"]

    vf = params["vf_kmh"]
    w = params["w_kmh"]
    rho_j = params["rho_j"]
    qmax = params["qmax_vph"]

    rho = np.asarray(rho, dtype=float)
    r_ramp = np.asarray(r_ramp, dtype=float)

    if cap_scale_vec is None:
        cap_scale_vec = np.ones(n, dtype=float)
    cap_scale_vec = np.asarray(cap_scale_vec, dtype=float)

    qmax_cell = qmax * cap_scale_vec
    rho_c_cell = qmax_cell / np.maximum(vf, 1e-12)

    S = np.minimum(vf * rho, qmax_cell)
    R = np.minimum(qmax_cell, w * (rho_j - rho))

    # Mainline interface flows.
    f = np.zeros(n + 1)
    f[0] = min(r_main_in, R[0])
    f[1:n] = np.minimum(S[:-1], R[1:])
    f[n] = S[n - 1]  # free outflow at the downstream boundary

    # Ramp merge with mainline priority: only the receiving capacity not used
    # by the mainline inflow f[i] is available to the ramp of cell i.
    r_eff = np.minimum(r_ramp, np.maximum(0.0, R - f[:-1]))

    rho_next = rho + (dt_h / L) * (f[:-1] + r_eff - f[1:])

    # Safety net only; with the merge limit above it no longer drops vehicles.
    rho_next = np.clip(rho_next, 0.0, rho_j)
    throughput_out = f[n]
    return rho_next, f, throughput_out, rho_c_cell


# ============================================================
# Demanda y incidente
# ============================================================

def demand_profile(t, base=3200.0, amp=1400.0, period=1.0):
    """Return a smooth periodic mainline demand (veh/h) at time ``t`` (h).

    The profile oscillates between ``base`` and ``base + amp``.  It starts at
    its minimum at t = 0 and peaks at half a period.
    """
    phase = 2 * np.pi * (t / period - 0.25)
    level = 0.5 * (1.0 + np.sin(phase))  # normalized to [0, 1]
    return base + amp * level

def incident_capacity_scale(n, t, t0, t1, cell_center, width, severity, min_scale=0.2):
    """Return per-cell capacity multipliers (shape (n,)) for an exogenous incident.

    Outside the active window ``t0 <= t <= t1`` every multiplier is 1.  Inside
    the window, capacity is reduced by a Gaussian dip:

        scale = 1 - severity * exp(-d**2 / (2 * width**2)),  d = |cell - cell_center|

    The result is clipped to ``[min_scale, 1]``.  ``width`` is measured in cells.
    A non-positive ``width`` is treated as the zero-width limit: only the cell
    exactly at ``cell_center`` is affected.  Without this, 0/0 would produce NaN.

    ``min_scale`` defaults to 0.2, the previously hard-coded floor.
    """
    scale = np.ones(n, dtype=float)
    if not (t0 <= t <= t1):
        return scale

    idx = np.arange(n, dtype=float)
    dist = np.abs(idx - float(cell_center))
    width = float(width)
    if width > 0.0:
        bump = np.exp(-(dist**2) / (2.0 * (width ** 2)))
    else:
        # Zero-width limit of the Gaussian: an indicator of the center cell.
        bump = (dist == 0.0).astype(float)

    scale = 1.0 - float(severity) * bump
    return np.clip(scale, float(min_scale), 1.0)


# ============================================================
# ALINEA
# ============================================================

def alinea_control(ramp_rate, rho_meas, rho_target, K=80.0, rmin=0.0, rmax=1800.0):
    """One ALINEA-style integral update of a ramp metering rate.

    Moves the rate by ``K`` times the density error (target minus
    measurement) and saturates the result to ``[rmin, rmax]``.
    Returns a Python float.
    """
    error = rho_target - rho_meas
    proposed = ramp_rate + K * error
    return float(np.clip(proposed, rmin, rmax))


# ============================================================
# Helpers
# ============================================================

def ema_update(x_s, x, alpha):
    """Return the exponential moving average after observing ``x``.

    ``alpha`` is the weight of the new sample; ``x_s`` is the previous average.
    """
    keep = 1.0 - alpha
    return keep * x_s + alpha * x

def apply_slew(r_old, r_cmd, slew_vph_per_h, dt_h):
    """Rate-limit a change from ``r_old`` toward ``r_cmd``.

    The change per step is bounded by ``slew_vph_per_h * dt_h`` in either
    direction.
    """
    step_limit = slew_vph_per_h * dt_h
    requested = r_cmd - r_old
    return r_old + np.clip(requested, -step_limit, step_limit)

def module_sensors(rho, rho_c_ref, mods, L_km):
    """Compute per-module congestion, gradient and coherence indicators.

    Parameters
    ----------
    rho : ndarray, shape (n,)
        Cell densities.
    rho_c_ref : float
        Reference critical density; only the excess above it counts as congestion.
    mods : sequence of index arrays
        Cell indices belonging to each module.
    L_km : float
        Cell length, used to turn density differences into spatial gradients.

    Returns
    -------
    cong, grad, coh : ndarray, shape (len(mods),)
        ``cong``: 0.7 * mean + 0.3 * 85th percentile of the excess density.
        ``grad``: mean per-cell |d rho / dx| (central average of neighbour
        differences, one-sided at the ends).
        ``coh``: 1 / (1 + std(rho)) within the module.

        Modules with no cells are skipped and keep 0 for all three indicators,
        instead of crashing in ``np.quantile``.
    """
    M = len(mods)
    cong = np.zeros(M)
    grad = np.zeros(M)
    coh = np.zeros(M)

    dr = np.abs(np.diff(rho)) / max(L_km, 1e-12)
    grad_cell = np.zeros_like(rho, dtype=float)
    if len(dr) > 0:
        grad_cell[1:-1] = 0.5 * (dr[0:-1] + dr[1:])
        grad_cell[0] = dr[0]
        grad_cell[-1] = dr[-1]

    for m, idx in enumerate(mods):
        if np.size(idx) == 0:
            continue  # empty module: no data, leave zeros
        rr = np.maximum(0.0, rho[idx] - rho_c_ref)
        cong[m] = 0.7 * float(np.mean(rr)) + 0.3 * float(np.quantile(rr, 0.85))
        grad[m] = float(np.mean(grad_cell[idx]))
        coh[m] = 1.0 / (1.0 + float(np.std(rho[idx])))

    return cong, grad, coh

def relax_gate_from_history(hist, eps_slope):
    """Least-squares trend gate over a window of samples.

    Fits a straight line to ``hist`` against the sample index and returns
    ``(slope <= eps_slope, slope)``.  The slope is in units per sample.
    With fewer than three samples the gate is closed: ``(False, 0.0)``.
    A degenerate index spread yields slope 0.0.
    """
    if len(hist) < 3:
        return False, 0.0

    samples = np.array(hist, dtype=float)
    k = np.arange(len(samples), dtype=float)
    k_c = k - k.mean()
    s_c = samples - samples.mean()

    sxx = np.sum(k_c * k_c)
    if sxx > 1e-12:
        slope = float(np.sum(k_c * s_c) / sxx)
    else:
        slope = 0.0
    return (slope <= eps_slope), slope


# ============================================================
# Métricas de descongestión
# ============================================================

def compute_congestion_C(rho, rho_c_ref):
    """Return the mean density excess above ``rho_c_ref`` (negative excess counts as 0)."""
    excess = np.maximum(0.0, rho - rho_c_ref)
    return float(np.mean(excess))

def find_tau_relax(relax_gate, dt_h, sustain_steps=20, start_step=0):
    """Return the time at which the relax gate first stays open long enough.

    Scans ``relax_gate`` from ``start_step`` onward (entries >= 0.5 count as
    open).  It looks for the first run of ``sustain_steps`` consecutive open
    samples and returns ``first_index_of_that_run * dt_h``.  Returns NaN if no
    such run exists.
    """
    gate = np.asarray(relax_gate, dtype=float)
    first = max(0, int(start_step))
    streak = 0
    for k in range(first, len(gate)):
        streak = streak + 1 if gate[k] >= 0.5 else 0
        if streak > 0 and streak >= sustain_steps:
            return (k - sustain_steps + 1) * dt_h
    return np.nan


# ============================================================
# Latente + descarga como PÉRDIDA TEMPORAL DE CAPACIDAD
# ============================================================

def latent_update(s_latent, u_energy, is_closed, alpha, closed_gain):
    """Leaky accumulator for the latent state.

    The state decays by ``alpha`` every step.  ``closed_gain * u_energy`` is
    injected only while ``is_closed`` is true.
    """
    injection = closed_gain * u_energy if is_closed else 0.0
    return float(alpha * s_latent + injection)

def latent_cap_debuff_vector(n, center, width, severity, min_scale=None):
    """Return a per-cell capacity multiplier, shape (n,), for a latent discharge.

        debuff = 1 - severity * exp(-d**2 / (2 * width**2)),  d = |cell - center|

    By default the result is NOT clamped (unchanged behaviour); callers that
    want a floor can pass ``min_scale`` to clip the result to ``[min_scale, 1]``.
    A non-positive ``width`` is treated as the zero-width limit: only the cell
    exactly at ``center`` is affected.  Without this, 0/0 would produce NaN.
    """
    idx = np.arange(n, dtype=float)
    dist = np.abs(idx - float(center))
    width = float(width)
    if width > 0.0:
        bump = np.exp(-(dist**2) / (2.0 * (width ** 2)))
    else:
        # Zero-width limit of the Gaussian: an indicator of the center cell.
        bump = (dist == 0.0).astype(float)
    debuff = 1.0 - float(severity) * bump
    if min_scale is not None:
        debuff = np.clip(debuff, float(min_scale), 1.0)
    return debuff

def maybe_trigger_discharge_capacity(s_latent, t, state, p):
    """Start a temporary capacity-loss event when the latent state crosses its threshold.

    ``state`` is a mutable dict with keys ``active_until`` (end time of the
    current event) and ``event_count``.  A new event starts only when
    ``s_latent > p["latent_thresh"]`` and no event is active
    (``t >= state["active_until"]``).  When it starts:

    - ``active_until`` is set to ``t + p["latent_cap_duration_h"]``;
    - ``event_count`` is incremented;
    - the latent state is scaled by ``p["latent_release"]``.

    Returns ``(new_s_latent, 1)`` if an event started, else ``(s_latent, 0)``.
    """
    blocked = (s_latent <= p["latent_thresh"]) or (t < state["active_until"])
    if blocked:
        return s_latent, 0

    state["active_until"] = t + p["latent_cap_duration_h"]
    state["event_count"] += 1

    released = float(s_latent * p["latent_release"])
    return released, 1


# ============================================================
# Simulación
# ============================================================

def run_sim(seed=0, mode="alinea", p=None):
    """Run one CTM simulation under a given ramp-control strategy.

    Parameters
    ----------
    seed : int
        Seed for the demand-noise generator (``np.random.default_rng``).
    mode : {"openloop", "alinea", "patronpc"}
        - "openloop": each ramp injects ``min(demand, ramp_free_cap_vph)``.
        - "alinea": ALINEA updates every ramp at every step.
        - "patronpc": ALINEA only inside per-module windows opened by the
          Patron-PC gating logic; outside windows the ramp rate relaxes toward
          ``r_cruise``.
    p : dict
        Parameter dict (the keys built in ``main``).  Required.

    Returns
    -------
    dict
        Aggregates:
        - ``TTT``, ``throughput``, ``U2``, ``Uabs``;
        - ``intervention``: mean window fraction for patronpc, 1 for alinea,
          0 for openloop;
        - ``discharge_events``, ``A_cong``, ``tau_relax``.

        Per-step logs: ``ts``, ``rho_mean``, ``rho_max``, ``outflow``,
        ``ctrl_abs``, ``open_frac``, ``C_raw``, ``C_ema``, ``relax_gate``,
        ``slope_per_h``, ``decong_rate``, ``tau_clear``, ``latent``,
        ``events``, ``cap_min``, ``cap_min_total``.

        ``tau_clear`` is NaN at steps where congestion is above target but not
        decreasing (clearance time undefined).

    Raises
    ------
    TypeError
        If ``p`` is not given.
    ValueError
        If ``mode`` is not one of the supported strategies.
    """
    if p is None:
        raise TypeError("run_sim requires a parameter dict `p` (see main()).")
    if mode not in ("openloop", "alinea", "patronpc"):
        raise ValueError(
            f"unknown mode {mode!r}; expected 'openloop', 'alinea' or 'patronpc'"
        )

    rng = np.random.default_rng(seed)

    n = p["n"]
    dt_h = p["dt_h"]
    steps = int(p["T_h"] / dt_h)

    rho = np.zeros(n) + p["rho_init"]

    ramp_mask = np.zeros(n, dtype=bool)
    ramp_mask[p["ramp_cells"]] = True
    ramp_idx = np.flatnonzero(ramp_mask)  # ramp cells in ascending order

    ramp_demand = np.zeros(n)
    ramp_demand[ramp_mask] = p["ramp_demand_vph"]

    ramp_rate = np.zeros(n)
    ramp_rate[ramp_mask] = p["ramp_init_vph"]
    ramp_prev = ramp_rate.copy()

    # Modules: M contiguous blocks of m_size cells; any leftover cells are
    # appended to the last module so every cell is covered.
    M = p["M"]
    m_size = max(1, n // M)
    mods = [np.arange(i * m_size, min((i + 1) * m_size, n)) for i in range(M)]
    covered = np.concatenate(mods)
    if len(np.unique(covered)) < n:
        missing = np.setdiff1d(np.arange(n), covered)
        mods[-1] = np.unique(np.concatenate([mods[-1], missing]))

    until = np.full(M, -1.0)     # per-module window end time (h)
    cooldown = np.full(M, -1.0)  # per-module earliest re-open time (h)

    # Aggregate metrics.
    TTT = 0.0
    throughput = 0.0
    U2 = 0.0
    Uabs = 0.0

    # Observable "tension" (EMA of congestion) and its sliding window for the gate.
    rho_c_ref = (p["qmax_vph"] / p["vf_kmh"])
    tension_s = 0.0
    tension_hist = deque(maxlen=max(0, int(p["relax_W"])))

    # Latent state.
    s_latent = 0.0
    discharge_events = 0
    discharge_state = {"active_until": -1.0, "event_count": 0}

    # Per-step logs.
    ts = np.zeros(steps)
    rho_mean = np.zeros(steps)
    rho_max = np.zeros(steps)
    outflow = np.zeros(steps)
    ctrl_abs = np.zeros(steps)
    open_frac = np.zeros(steps)

    C_raw = np.zeros(steps)
    C_ema = np.zeros(steps)
    relax_gate = np.zeros(steps)
    slope_per_h = np.zeros(steps)
    decong_rate = np.zeros(steps)
    tau_clear = np.zeros(steps)

    latent_log = np.zeros(steps)
    events_log = np.zeros(steps)

    cap_min = np.zeros(steps)
    cap_min_total = np.zeros(steps)  # with the latent debuff applied

    for s in range(steps):
        t = s * dt_h
        ts[s] = t

        # Mainline demand with multiplicative noise.
        main_in = demand_profile(t, base=p["main_base"], amp=p["main_amp"], period=p["main_period"])
        main_in = max(0.0, main_in * (1.0 + p["demand_noise"] * rng.normal()))

        # Exogenous incident.
        cap_inc = incident_capacity_scale(
            n, t,
            p["inc_t0"], p["inc_t1"],
            p["inc_cell"], p["inc_width"], p["inc_severity"]
        )
        cap_min[s] = float(np.min(cap_inc))

        # Global congestion and its smoothed trend.
        C_now = compute_congestion_C(rho, rho_c_ref)
        C_raw[s] = C_now
        tension_s = ema_update(tension_s, C_now, p["relax_alpha"])
        C_ema[s] = tension_s
        tension_hist.append(tension_s)  # deque drops the oldest beyond relax_W

        relax_ok, slope = relax_gate_from_history(tension_hist, p["relax_eps_slope"])
        relax_gate[s] = 1.0 if relax_ok else 0.0
        slope_per_h[s] = slope / max(dt_h, 1e-12)

        # Estimated decongestion rate: only a decreasing trend counts.
        decong_rate[s] = max(0.0, -slope_per_h[s])

        # Time to clear the excess congestion at the current rate.  When the
        # rate is zero the time is undefined (NaN).  A tiny epsilon would turn
        # it into ~1e9-scale values that dominate the plots.
        excess = C_now - p["C_target"]
        if excess <= 0.0:
            tau_clear[s] = 0.0
        elif decong_rate[s] > 0.0:
            tau_clear[s] = excess / decong_rate[s]
        else:
            tau_clear[s] = np.nan

        # Patron-PC windows (module sensors are only needed here).
        if mode == "patronpc":
            cong, grad, coh = module_sensors(rho, rho_c_ref, mods, p["L_km"])
            for m in range(M):
                sens = cong[m] + p["grad_w"] * grad[m]
                probe = ((s % p["probe_period_steps"]) == 0) and (sens > p["probe_min_sens"])
                can_open = (t >= until[m]) and (t >= cooldown[m])
                if can_open and (coh[m] > p["th_coh"]) and (sens > p["th_sens"]) and (relax_ok or probe):
                    until[m] = t + p["hold_local_h"]
                    cooldown[m] = t + p["refrac_h"]
            w_m = (t < until).astype(float)
            open_frac[s] = float(np.mean(w_m))
        else:
            w_m = np.zeros(M)
            open_frac[s] = 0.0

        # Ramp inflow.
        r_ramp = np.zeros(n)

        if mode == "openloop":
            r_ramp[ramp_mask] = np.minimum(ramp_demand[ramp_mask], p["ramp_free_cap_vph"])
        else:
            for i in ramp_idx:
                meas_cell = min(i + 1, n - 1)  # measure just downstream of the ramp
                rho_meas = rho[meas_cell]

                act = True
                if mode == "patronpc":
                    m = min(i // m_size, M - 1)
                    act = (w_m[m] > 0.5)

                if act:
                    r_cmd = alinea_control(
                        ramp_rate[i], rho_meas, p["rho_target"],
                        K=p["alinea_K"], rmin=p["rmin"], rmax=p["rmax"]
                    )
                else:
                    # Outside a window: relax toward the cruise rate.
                    r_cmd = float(ramp_rate[i] + p["relax_rate"] * (p["r_cruise"] - ramp_rate[i]))

                r_slewed = apply_slew(ramp_rate[i], r_cmd, p["slew_vph_per_h"], dt_h)
                ramp_rate[i] = float(np.clip(r_slewed, p["rmin"], p["rmax"]))
                r_ramp[i] = min(ramp_demand[i], ramp_rate[i])

        # Latent state: accumulates the squared ramp-rate changes while control
        # is "closed" (always for non-Patron-PC modes).
        if p["enable_latent"]:
            du_vec = (ramp_rate - ramp_prev)
            u_energy = float(np.sum((du_vec[ramp_mask]) ** 2))

            if mode == "patronpc":
                is_closed = (open_frac[s] < p["open_is_open_threshold"])
            else:
                is_closed = True

            s_latent = latent_update(
                s_latent=s_latent,
                u_energy=u_energy,
                is_closed=is_closed,
                alpha=p["latent_alpha"],
                closed_gain=p["latent_closed_gain"]
            )

            # Discharge trigger -> starts a temporary capacity debuff.
            s_latent, ev = maybe_trigger_discharge_capacity(s_latent, t, discharge_state, p)
            discharge_events += ev
            events_log[s] = ev
            latent_log[s] = s_latent
        else:
            events_log[s] = 0.0
            latent_log[s] = 0.0

        # Total capacity = exogenous incident * latent debuff (while active).
        cap_total = cap_inc.copy()

        if p["enable_latent"] and (t < discharge_state["active_until"]):
            debuff = latent_cap_debuff_vector(
                n=n,
                center=p["latent_cap_cell"],
                width=p["latent_cap_width"],
                severity=p["latent_cap_severity"]
            )
            debuff = np.clip(debuff, p["latent_cap_min_scale"], 1.0)
            cap_total *= debuff

        cap_min_total[s] = float(np.min(cap_total))

        # Plant.
        rho, flows, y_out, rho_c_cell = ctm_step(rho, main_in, r_ramp, p, cap_scale_vec=cap_total)

        # Metrics.
        TTT += float(np.sum(rho) * p["L_km"] * dt_h)
        throughput += float(y_out * dt_h)

        du = ramp_rate - ramp_prev
        U2 += float(np.sum((du[ramp_mask] / max(p["rmax"], 1e-12)) ** 2) * dt_h)
        Uabs += float(np.sum(np.abs(du[ramp_mask]) / max(p["rmax"], 1e-12)) * dt_h)

        ramp_prev = ramp_rate.copy()

        rho_mean[s] = float(np.mean(rho))
        rho_max[s] = float(np.max(rho))
        outflow[s] = float(y_out)
        ctrl_abs[s] = float(np.mean(ramp_rate[ramp_mask]) / max(p["rmax"], 1e-12))

    # Estimate 1: congestion area (time integral of C).
    A_cong = float(np.sum(C_raw) * dt_h)

    # Estimate 3: tau_relax (first sustained opening of the relax gate).
    start_step = int(p["tau_relax_start_h"] / dt_h)
    tau_relax = find_tau_relax(relax_gate, dt_h, sustain_steps=p["tau_relax_sustain_steps"], start_step=start_step)

    intervention = float(np.mean(open_frac)) if mode == "patronpc" else (1.0 if mode == "alinea" else 0.0)

    return dict(
        TTT=TTT,
        throughput=throughput,
        U2=U2,
        Uabs=Uabs,
        intervention=intervention,
        discharge_events=discharge_events,

        A_cong=A_cong,
        tau_relax=tau_relax,

        ts=ts,
        rho_mean=rho_mean,
        rho_max=rho_max,
        outflow=outflow,
        ctrl_abs=ctrl_abs,
        open_frac=open_frac,

        C_raw=C_raw,
        C_ema=C_ema,
        relax_gate=relax_gate,
        slope_per_h=slope_per_h,
        decong_rate=decong_rate,
        tau_clear=tau_clear,

        latent=latent_log,
        events=events_log,

        cap_min=cap_min,
        cap_min_total=cap_min_total
    )


# ============================================================
# Experimento
# ============================================================

def _default_params():
    """Return the experiment's parameter dict (units are encoded in the key suffixes)."""
    return dict(
        # CTM
        n=60,
        L_km=0.5,
        dt_h=5.0 / 3600.0,
        T_h=1.0,

        vf_kmh=100.0,
        w_kmh=20.0,
        rho_j=180.0,
        qmax_vph=6000.0,

        rho_init=25.0,

        # mainline demand
        main_base=3000.0,
        main_amp=2200.0,
        main_period=1.0,
        demand_noise=0.02,

        # ramps
        ramp_cells=[10, 25, 40, 52],
        ramp_demand_vph=1500.0,
        ramp_init_vph=900.0,
        ramp_free_cap_vph=1800.0,

        # ALINEA
        rho_target=35.0,
        alinea_K=90.0,
        rmin=0.0,
        rmax=1800.0,
        r_cruise=1600.0,
        relax_rate=0.03,

        # exogenous incident
        inc_t0=0.35,
        inc_t1=0.55,
        inc_cell=33,
        inc_width=3,
        inc_severity=0.55,

        # Patron-PC
        M=6,
        relax_alpha=0.05,
        hold_local_h=(40.0 / 3600.0),
        refrac_h=(60.0 / 3600.0),
        th_coh=0.08,

        grad_w=0.25,
        th_sens=6.0,

        relax_W=80,
        relax_eps_slope=0.002,

        probe_period_steps=120,
        probe_min_sens=4.0,

        slew_vph_per_h=4500.0,

        # Latent state (Problem 1)
        # NOTE(review): the slew limit caps each ramp change at
        # slew_vph_per_h * dt_h = 6.25 vph per step, so u_energy <= 4 * 6.25**2
        # per step.  The latent state is then bounded by
        # u_energy / (1 - latent_alpha) ~= 3.1e4, far below latent_thresh, so
        # discharges cannot fire with these values.  Confirm the threshold.
        enable_latent=True,
        latent_alpha=0.995,
        latent_closed_gain=1.0,
        latent_thresh=2.5e6,
        latent_release=0.4,
        open_is_open_threshold=0.2,

        # discharge -> temporary capacity debuff
        latent_cap_duration_h=(90.0 / 3600.0),  # 90 s
        latent_cap_cell=33,                     # debuff center (could reuse inc_cell)
        latent_cap_width=3.0,
        latent_cap_severity=0.25,               # max local reduction (0.25 -> -25%)
        latent_cap_min_scale=0.5,               # floor (never below 50% capacity)

        # Decongestion metrics
        C_target=0.0,
        tau_relax_sustain_steps=20,
        tau_relax_start_h=0.0,
    )


def _agg(key, runs):
    """Return (mean, sample std) of ``key`` across runs; std is 0.0 for a single run."""
    v = np.array([r[key] for r in runs], dtype=float)
    return v.mean(), (v.std(ddof=1) if len(v) > 1 else 0.0)


def _pretty_h(x):
    """Format a duration in hours, or 'nan' when undefined."""
    return "nan" if np.isnan(x) else f"{x:.3f} h"


def _print_summary(results, n_seeds):
    """Print mean ± std of the aggregate metrics for each (name, runs) pair."""
    print(f"\n=== RESULTADOS (media ± std, seeds={n_seeds}) ===")
    for name, arr in results:
        ttt_m, ttt_s = _agg("TTT", arr)
        thr_m, thr_s = _agg("throughput", arr)
        u2_m, u2_s   = _agg("U2", arr)
        ua_m, ua_s   = _agg("Uabs", arr)
        int_m, int_s = _agg("intervention", arr)
        ev_m, ev_s   = _agg("discharge_events", arr)

        A_m, A_s     = _agg("A_cong", arr)
        tauR_m, tauR_s = _agg("tau_relax", arr)

        print(f"\n{name}")
        print(f"  TTT (↓):                {ttt_m:.2f} ± {ttt_s:.2f}")
        print(f"  Throughput (↑):         {thr_m:.2f} ± {thr_s:.2f}")
        print(f"  Control U2 (↓):         {u2_m:.4f} ± {u2_s:.4f}")
        print(f"  Control Uabs (↓):       {ua_m:.4f} ± {ua_s:.4f}")
        print(f"  Intervención (↓):       {int_m:.3f} ± {int_s:.3f}")
        print(f"  Descargas (↓):          {ev_m:.2f} ± {ev_s:.2f}")
        print(f"  A_cong=∫Cdt (↓):        {A_m:.3f} ± {A_s:.3f}")
        print(f"  tau_relax (↓):          {_pretty_h(tauR_m)} ± {tauR_s:.3f} h")


def _plot_runs(ex_open, ex_aln, ex_ppc, dt_h):
    """Plot the time series of one run per strategy (nine stacked panels)."""
    ts = ex_ppc["ts"]

    cum_open = np.cumsum(ex_open["outflow"]) * dt_h
    cum_aln  = np.cumsum(ex_aln["outflow"])  * dt_h
    cum_ppc  = np.cumsum(ex_ppc["outflow"])  * dt_h

    fig, ax = plt.subplots(9, 1, figsize=(12, 16), sharex=True)
    fig.suptitle(
        "CTM — Open-loop vs ALINEA vs Patron-PC\n"
        "(descarga latente como pérdida temporal de capacidad + métricas de descongestión)"
    )

    ax[0].plot(ts, ex_open["rho_max"], label="rho_max open")
    ax[0].plot(ts, ex_aln["rho_max"],  label="rho_max ALINEA")
    ax[0].plot(ts, ex_ppc["rho_max"],  label="rho_max Patron-PC")
    ax[0].set_ylabel("rho_max")
    ax[0].legend(loc="upper right")

    ax[1].plot(ts, ex_open["outflow"], label="outflow open")
    ax[1].plot(ts, ex_aln["outflow"],  label="outflow ALINEA")
    ax[1].plot(ts, ex_ppc["outflow"],  label="outflow Patron-PC")
    ax[1].set_ylabel("outflow")
    ax[1].legend(loc="upper right")

    ax[2].plot(ts, cum_open, label="cum thr open")
    ax[2].plot(ts, cum_aln,  label="cum thr ALINEA")
    ax[2].plot(ts, cum_ppc,  label="cum thr Patron-PC")
    ax[2].set_ylabel("cum thr")
    ax[2].legend(loc="upper left")

    ax[3].plot(ts, ex_aln["ctrl_abs"], label="control ALINEA (norm)")
    ax[3].plot(ts, ex_ppc["ctrl_abs"], label="control Patron-PC (norm)")
    ax[3].plot(ts, ex_ppc["open_frac"], label="ventanas Patron-PC", alpha=0.9)
    ax[3].set_ylabel("control / win")
    ax[3].set_ylim(-0.05, 1.05)
    ax[3].legend(loc="upper right")

    ax[4].plot(ts, ex_ppc["C_raw"], label="C(t) raw")
    ax[4].plot(ts, ex_ppc["C_ema"], label="C_ema(t)")
    ax[4].set_ylabel("C")
    ax[4].legend(loc="upper right")

    ax[5].plot(ts, ex_ppc["relax_gate"], label="relax_gate")
    ax[5].plot(ts, ex_ppc["slope_per_h"], label="slope(C_ema) per hour")
    ax[5].set_ylabel("gate/slope")
    ax[5].legend(loc="upper right")

    ax[6].plot(ts, ex_open["decong_rate"], label="D_hat open")
    ax[6].plot(ts, ex_aln["decong_rate"],  label="D_hat ALINEA")
    ax[6].plot(ts, ex_ppc["decong_rate"],  label="D_hat Patron-PC")
    ax[6].plot(ts, ex_ppc["tau_clear"],    label="tau_clear PPC (h)", alpha=0.8)
    ax[6].set_ylabel("D_hat / tau")
    ax[6].legend(loc="upper right")

    ax[7].plot(ts, ex_aln["latent"], label="latente ALINEA")
    ax[7].plot(ts, ex_ppc["latent"], label="latente Patron-PC")
    ax[7].plot(ts, ex_aln["events"], label="eventos ALINEA", alpha=0.8)
    ax[7].plot(ts, ex_ppc["events"], label="eventos Patron-PC", alpha=0.8)
    ax[7].set_ylabel("latente/events")
    ax[7].legend(loc="upper left")

    ax[8].plot(ts, ex_ppc["cap_min"], label="cap min (incidente exógeno)")
    ax[8].plot(ts, ex_ppc["cap_min_total"], label="cap min total (incidente * latente)")
    ax[8].set_ylabel("cap_min")
    ax[8].set_xlabel("tiempo (h)")
    ax[8].legend(loc="upper right")

    plt.tight_layout()
    plt.show()


def main():
    """Run every strategy over several seeds, print a summary and plot seed 0."""
    p = _default_params()

    seeds = range(5)
    R_open = [run_sim(seed=s, mode="openloop", p=p) for s in seeds]
    R_aln  = [run_sim(seed=s, mode="alinea",   p=p) for s in seeds]
    R_ppc  = [run_sim(seed=s, mode="patronpc", p=p) for s in seeds]

    _print_summary(
        [("OPENLOOP", R_open), ("ALINEA", R_aln), ("PATRON_PC", R_ppc)],
        len(seeds),
    )

    # Plots use seed 0 of each strategy.
    _plot_runs(R_open[0], R_aln[0], R_ppc[0], p["dt_h"])


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
