import numpy as np
import matplotlib.pyplot as plt

# ============================================================
# CTM: Cell Transmission Model (free-flow / congested branches with a capacity cap)
# ============================================================

def ctm_step(rho, r_main_in, r_ramp, params, cap_scale):
    """
    Advance the Cell Transmission Model by one time step.

    rho: density per cell (veh/km), length params["n"]
    r_main_in: upstream mainline inflow demand (veh/h)
    r_ramp: on-ramp inflow per cell (veh/h); zero where there is no ramp
    params: dict with "n", "L_km", "dt_h", "vf_kmh", "w_kmh", "rho_j", "qmax_vph"
    cap_scale: capacity multiplier, scalar or per cell (values < 1 model incidents)

    Returns (rho_next, f, throughput_out):
      rho_next: updated densities (veh/km) as a float array clipped to [0, rho_j]
      f: interface flows (veh/h), length n + 1; f[0] is the admitted upstream
         inflow, f[i + 1] the flow from cell i to i + 1, f[n] the exit flow
      throughput_out: f[n] (veh/h)

    Sending S = min(vf*rho, qmax) and receiving R = min(qmax, w*(rho_j - rho)).
    NOTE(review): this is strictly triangular only if qmax == vf*rho_c ==
    w*(rho_j - rho_c) for some rho_c. With the values in main()
    (w*rho_j = 3600 < qmax = 6000), R is never capped by qmax. Confirm the
    parameters are intended.

    Caveats kept as-is (model choices, not changed here):
      * upstream demand above R[0] is dropped (no upstream queue);
      * ramp inflow is added on top of f[i] without being limited by R[i],
        so the final clip to rho_j can discard vehicles.
    """
    n = params["n"]
    L = params["L_km"]
    dt_h = params["dt_h"]

    vf = params["vf_kmh"]
    w = params["w_kmh"]
    rho_j = params["rho_j"]
    qmax = params["qmax_vph"] * cap_scale  # effective (possibly incident-reduced) capacity

    # sending (demand) and receiving (supply), veh/h
    S = np.minimum(vf * rho, qmax)
    R = np.minimum(qmax, w * (rho_j - rho))

    # interface flows: f[0] = entry, f[i+1] = cell i -> i+1, f[n] = exit
    f = np.zeros(n + 1)
    f[0] = min(r_main_in, R[0])
    f[1:n] = np.minimum(S[:n - 1], R[1:n])
    f[n] = S[n - 1]  # free outflow at the downstream boundary

    # conservation update, computed as a fresh float array so integer input
    # densities are not truncated
    rho_next = rho + (dt_h / L) * ((f[:n] + r_ramp) - f[1:])

    rho_next = np.clip(rho_next, 0.0, rho_j)
    throughput_out = f[n]  # veh/h
    return rho_next, f, throughput_out

# ============================================================
# Demandas y eventos (incidente)
# ============================================================

def demand_profile(t, base=3200.0, amp=1400.0, period=1.0):
    """
    Smooth periodic mainline demand (veh/h).

    t is in hours. Returns base + amp * shape, where shape is a shifted sine
    mapped to [0, 1]: shape = 0 at t = 0 and peaks at t = period / 2.
    """
    phase = 2*np.pi*(t/period - 0.25)
    shape = 0.5 * (1.0 + np.sin(phase))
    return base + amp * shape

def incident_capacity_scale(n, t, t0, t1, cell_center, width, severity, min_scale=0.2):
    """
    Per-cell capacity multiplier for a temporary incident.

    Returns an array of length n equal to 1.0 outside the window [t0, t1].
    Inside the window the capacity is reduced by a Gaussian bump centred on
    cell_center:

        scale = 1 - severity * exp(-d**2 / (2 * width**2)),  d = |cell - cell_center|

    The result is clipped to [min_scale, 1]. For example, severity=0.5 halves
    the capacity at the centre cell.

    A width <= 0 is treated as a point incident: only a cell exactly at
    cell_center is reduced. The Gaussian formula would otherwise divide by
    zero and produce NaN/inf capacities.
    """
    scale = np.ones(n)
    if not (t0 <= t <= t1):
        return scale

    dist = np.abs(np.arange(n) - cell_center)
    if width > 0:
        bump = np.exp(-(dist**2) / (2*(width**2)))
    else:
        bump = (dist == 0).astype(float)
    return np.clip(1.0 - severity * bump, min_scale, 1.0)

# ============================================================
# Control clásico: ALINEA (ramp metering)
# ============================================================

def alinea_control(ramp_rate, rho_meas, rho_target, K=80.0, rmin=0.0, rmax=1800.0):
    """
    One ALINEA integral update of a ramp metering rate.

    ramp_rate: current metering rate (veh/h)
    rho_meas: measured density in the control cell (veh/km)
    rho_target: density set point (veh/km)
    K: integral gain ((veh/h) per (veh/km))

    Computes r <- r + K * (rho_target - rho_meas) and returns it as a
    Python float saturated to [rmin, rmax].
    """
    error = rho_target - rho_meas
    proposed = ramp_rate + K * error
    return float(np.clip(proposed, rmin, rmax))

# ============================================================
# Patron-PC windows: decisión por módulos (sensores)
# ============================================================

def ema_update(x_s, x, alpha):
    """Return the exponential moving average after observing x.

    alpha weights the new sample x; (1 - alpha) weights the previous
    smoothed value x_s.
    """
    keep = 1.0 - alpha
    return alpha * x + keep * x_s

def run_sim(seed=0, mode="baseline", params=None):
    """
    Run one CTM simulation under the chosen control mode.

    mode:
      - "baseline": ALINEA active at every step (continuous classic control)
      - "patronpc": ALINEA only inside per-module windows. A window opens when
        the smoothed global tension is decreasing and the module's congestion
        and coherence sensors exceed their thresholds (episodic control).
        Ramps outside open windows relax toward params["r_cruise"].
      - "openloop": no control; ramps inject min(demand, ramp_free_cap_vph)
    seed: seed for the numpy Generator that drives mainline demand noise.
    params: dict of model/controller parameters (keys as built in main()).

    Returns a dict with:
      scalar metrics
        "TTT": total time spent (veh*h)
        "throughput": exit volume (veh)
        "U2": sum of squared normalized rate changes times dt
        "intervention": fraction of time control is active
      per-step arrays
        "ts", "rho_mean", "rho_max", "outflow", "ctrl_abs", "open_frac",
        "tension"

    Raises:
      TypeError: params is None.
      ValueError: mode is unknown, or params["M"] is outside [1, n].
    """
    if params is None:
        raise TypeError("run_sim() requires a params dict; see main() for the expected keys")
    if mode not in ("baseline", "patronpc", "openloop"):
        raise ValueError(
            f"unknown mode {mode!r}; expected 'baseline', 'patronpc' or 'openloop'"
        )

    rng = np.random.default_rng(seed)
    p = params

    n = p["n"]
    dt_h = p["dt_h"]
    T_h = p["T_h"]
    # round() instead of int(): T_h / dt_h can evaluate a hair below an exact
    # integer, and truncation would silently drop the last step.
    steps = int(round(T_h / dt_h))

    # state: density (veh/km), uniform initial condition
    rho = np.zeros(n) + 25.0

    # on-ramp locations
    ramp_mask = np.zeros(n, dtype=bool)
    ramp_mask[p["ramp_cells"]] = True
    ramp_idx = np.where(ramp_mask)[0]

    # exogenous ramp demand (how many vehicles want to enter)
    ramp_demand = np.zeros(n)
    ramp_demand[ramp_mask] = p["ramp_demand_vph"]

    # metering rate (controller output)
    ramp_rate = np.zeros(n)
    ramp_rate[ramp_mask] = p["ramp_init_vph"]

    # Patron-PC: M contiguous modules; the last absorbs the remainder of n // M
    M = p["M"]
    if not 1 <= M <= n:
        raise ValueError(f"params['M'] must be between 1 and n={n}, got {M}")
    m_size = n // M
    mods = [np.arange(i*m_size, (i+1)*m_size) for i in range(M-1)]
    mods.append(np.arange((M-1)*m_size, n))

    until = np.full(M, -1.0)  # per-module window expiry time (h)
    TTT = 0.0         # total time spent (veh*h)
    U2 = 0.0          # control effort: squared normalized rate changes * dt
    throughput = 0.0  # exit volume (veh)

    # global relaxation signal: EMA of tension = mean(max(0, rho - rho_c_ref))
    tension_s = 0.0

    # Reference critical density for the tension and module sensors.
    # NOTE(review): qmax/vf is the critical density only when the diagram is
    # triangular (qmax == w*(rho_j - qmax/vf)); with the values in main()
    # that identity does not hold. Confirm the intended reference.
    rho_c_ref = (p["qmax_vph"] / p["vf_kmh"])
    r_norm = max(p["rmax"], 1e-12)  # normaliser for control metrics

    # logs
    ts = np.zeros(steps)
    rho_mean = np.zeros(steps)
    rho_max = np.zeros(steps)
    outflow = np.zeros(steps)
    ctrl_abs = np.zeros(steps)
    open_frac = np.zeros(steps)
    tension = np.zeros(steps)

    # previous metering rates, to measure control changes
    ramp_prev = ramp_rate.copy()

    for s in range(steps):
        t = s * dt_h
        ts[s] = t

        # upstream mainline demand with light multiplicative noise
        main_in = demand_profile(t, base=p["main_base"], amp=p["main_amp"], period=p["main_period"])
        main_in = max(0.0, main_in * (1.0 + 0.02 * rng.normal()))

        # incident: temporary capacity reduction
        cap_scale = incident_capacity_scale(
            n, t,
            p["inc_t0"], p["inc_t1"],
            p["inc_cell"], p["inc_width"], p["inc_severity"]
        )

        tens_now = float(np.mean(np.maximum(0.0, rho - rho_c_ref)))
        tension[s] = tens_now
        prev = tension_s
        tension_s = ema_update(tension_s, tens_now, p["relax_alpha"])
        relaxing = (tension_s - prev) < 0.0

        # per-module sensors:
        #   congestion = 0.7*mean + 0.3*0.85-quantile of the excess over rho_c_ref
        #   coherence  = 1 / (1 + std of density)
        m_sensor = np.zeros(M)
        m_coh = np.zeros(M)
        for m, idx in enumerate(mods):
            rr = np.maximum(0.0, rho[idx] - rho_c_ref)
            m_sensor[m] = 0.7 * float(np.mean(rr)) + 0.3 * float(np.quantile(rr, 0.85))
            m_coh[m] = 1.0 / (1.0 + float(np.std(rho[idx])))

        # Patron-PC windows: an expired window reopens only while the system
        # relaxes and the module is both congested and coherent
        if mode == "patronpc":
            for m in range(M):
                if (t >= until[m]) and relaxing and (m_sensor[m] > p["th_cong"]) and (m_coh[m] > p["th_coh"]):
                    until[m] = t + p["hold_local_h"]
            w_m = (t < until).astype(float)
            open_frac[s] = float(np.mean(w_m))
        else:
            w_m = np.zeros(M)
            open_frac[s] = 0.0

        # control output: ramp inflow = min(demand, metering rate)
        r_ramp = np.zeros(n)
        if mode == "openloop":
            r_ramp[ramp_mask] = np.minimum(ramp_demand[ramp_mask], p["ramp_free_cap_vph"])
        else:
            for i in ramp_idx:
                # ALINEA measures the cell just downstream of the ramp
                meas_cell = min(i + 1, n - 1)
                rho_meas = rho[meas_cell]

                # baseline always acts; Patron-PC only if this ramp's module window is open
                act = True
                if mode == "patronpc":
                    m = min(i // m_size, M - 1)
                    act = (w_m[m] > 0.5)

                if act:
                    ramp_rate[i] = alinea_control(
                        ramp_rate[i], rho_meas, p["rho_target"],
                        K=p["alinea_K"], rmin=p["rmin"], rmax=p["rmax"]
                    )
                else:
                    # not acting: relax the rate toward the high "cruise" value
                    ramp_rate[i] = float(ramp_rate[i] + p["relax_rate"] * (p["r_cruise"] - ramp_rate[i]))

                ramp_rate[i] = float(np.clip(ramp_rate[i], p["rmin"], p["rmax"]))
                r_ramp[i] = min(ramp_demand[i], ramp_rate[i])

        # plant: one CTM step
        rho, _, y_out = ctm_step(rho, main_in, r_ramp, p, cap_scale)

        # accumulated metrics; TTT ~ sum(rho_i * L) * dt  (veh/km * km = veh, times h)
        TTT += float(np.sum(rho) * p["L_km"] * dt_h)
        throughput += float(y_out * dt_h)

        # control effort from squared normalized rate changes
        du = ramp_rate - ramp_prev
        U2 += float(np.sum((du[ramp_mask] / r_norm)**2) * dt_h)
        ramp_prev = ramp_rate.copy()

        # logs
        rho_mean[s] = float(np.mean(rho))
        rho_max[s] = float(np.max(rho))
        outflow[s] = float(y_out)
        ctrl_abs[s] = float(np.mean(ramp_rate[ramp_mask]) / r_norm)

    # intervention: fraction of time acting (Patron-PC: mean open-window fraction)
    if mode == "patronpc":
        intervention = float(np.mean(open_frac))
    elif mode == "baseline":
        intervention = 1.0
    else:
        intervention = 0.0

    result = {
        "TTT": TTT,
        "throughput": throughput,
        "U2": U2,
        "intervention": intervention,
        "ts": ts,
        "rho_mean": rho_mean,
        "rho_max": rho_max,
        "outflow": outflow,
        "ctrl_abs": ctrl_abs,
        "open_frac": open_frac,
        "tension": tension
    }
    return result

# ============================================================
# Experimento A/B con seeds + plots
# ============================================================

def main():
    """
    Run the A/B experiment.

    Simulates every mode (open-loop, ALINEA, Patron-PC) over several seeds,
    prints mean ± std of the scalar metrics per mode, then plots the seed-0
    runs of each mode.
    """
    params = {
        # CTM discretisation
        "n": 60,
        "L_km": 0.5,
        "dt_h": 5.0/3600.0,  # 5 seconds
        "T_h": 1.0,          # 1 hour

        # fundamental diagram
        "vf_kmh": 100.0,
        "w_kmh": 20.0,
        "rho_j": 180.0,
        "qmax_vph": 6000.0,

        # mainline demand
        "main_base": 3000.0,
        "main_amp": 2200.0,
        "main_period": 1.0,

        # on-ramps
        "ramp_cells": [10, 25, 40, 52],
        "ramp_demand_vph": 1500.0,
        "ramp_init_vph": 900.0,
        "ramp_free_cap_vph": 1800.0,

        # ALINEA
        "rho_target": 35.0,
        "alinea_K": 90.0,
        "rmin": 0.0,
        "rmax": 1800.0,
        "r_cruise": 1600.0,
        "relax_rate": 0.03,   # relaxation toward r_cruise while idle (Patron-PC)

        # incident
        "inc_t0": 0.35,
        "inc_t1": 0.55,
        "inc_cell": 33,
        "inc_width": 3,
        "inc_severity": 0.55,

        # Patron-PC modules / windows
        "M": 6,
        "relax_alpha": 0.05,
        "th_cong": 6.0,        # local congestion threshold (veh/km above rho_c)
        "th_coh": 0.08,        # minimum coherence, 1/(1+std)
        "hold_local_h": 40.0/3600.0,  # 40 s
    }

    seeds = range(5)

    # run every mode over every seed
    modes = ("openloop", "baseline", "patronpc")
    runs = {
        mode: [run_sim(seed=s, mode=mode, params=params) for s in seeds]
        for mode in modes
    }
    R_open, R_base, R_ppc = (runs[mode] for mode in modes)

    def agg(key, arr):
        # mean and sample std (ddof=1) of one scalar metric across runs
        values = np.asarray([res[key] for res in arr], dtype=float)
        spread = values.std(ddof=1) if values.size > 1 else 0.0
        return values.mean(), spread

    print("\n=== RESULTADOS (media ± std, seeds=%d) ===" % len(seeds))
    for name, arr in (("OPENLOOP", R_open), ("BASELINE_ALINEA", R_base), ("PATRON_PC", R_ppc)):
        ttt_m, ttt_s = agg("TTT", arr)
        thr_m, thr_s = agg("throughput", arr)
        u2_m, u2_s = agg("U2", arr)
        int_m, int_s = agg("intervention", arr)
        print(f"\n{name}")
        print(f"  TTT (↓ mejor):        {ttt_m:.2f} ± {ttt_s:.2f}")
        print(f"  Throughput (↑ mejor): {thr_m:.2f} ± {thr_s:.2f}")
        print(f"  Control U2 (↓):       {u2_m:.4f} ± {u2_s:.4f}")
        print(f"  Intervención (↓):     {int_m:.3f} ± {int_s:.3f}")

    # plots: seed 0 of each mode as the visual example
    ex_open, ex_base, ex_ppc = R_open[0], R_base[0], R_ppc[0]
    ts = ex_ppc["ts"]

    fig, ax = plt.subplots(4, 1, figsize=(12, 10), sharex=True)
    fig.suptitle("Tráfico (CTM) — comparación: Open-loop vs ALINEA vs Patron-PC (ventanas)")

    # panels 0 and 1: one line per mode for the same series
    comparison_panels = (
        (0, "rho_max", "rho_max (veh/km)",
         ("rho_max open", "rho_max ALINEA", "rho_max Patron-PC")),
        (1, "outflow", "outflow (veh/h)",
         ("outflow open", "outflow ALINEA", "outflow Patron-PC")),
    )
    for row, key, ylabel, labels in comparison_panels:
        for res, label in zip((ex_open, ex_base, ex_ppc), labels):
            ax[row].plot(ts, res[key], label=label)
        ax[row].set_ylabel(ylabel)
        ax[row].legend(loc="upper right")

    # panel 2: normalized control level and Patron-PC window fraction
    ax[2].plot(ts, ex_base["ctrl_abs"], label="control ALINEA (norm)")
    ax[2].plot(ts, ex_ppc["ctrl_abs"], label="control Patron-PC (norm)")
    ax[2].plot(ts, ex_ppc["open_frac"], label="ventanas Patron-PC (open_frac)", alpha=0.85)
    ax[2].set_ylabel("control / ventanas")
    ax[2].set_ylim(-0.05, 1.05)
    ax[2].legend(loc="upper right")

    # panel 3: raw global tension sensor
    ax[3].plot(ts, ex_ppc["tension"], label="tensión (Patron-PC sensor global)")
    ax[3].set_ylabel("tensión")
    ax[3].set_xlabel("tiempo (h)")
    ax[3].legend(loc="upper right")

    plt.tight_layout()
    plt.show()

# Run the experiment (simulations, printed summary, plots) only when executed as a script.
if __name__ == "__main__":
    main()
