import numpy as np
import matplotlib.pyplot as plt

# ============================================================
# Utilidades
# ============================================================

def order_param(theta):
    """Return the Kuramoto order parameter r = |mean(exp(i*theta))| as a Python float.

    r is 1.0 when every phase coincides and close to 0 when the phasors cancel.
    """
    phasors = np.exp(1j * theta)
    return float(abs(phasors.mean()))

def ring_adj(N, k):
    """Return the float 0/1 adjacency matrix of an N-node ring.

    Each node is linked to every node whose distance around the ring is
    between 1 and k, so node 0 and node N-1 are neighbours. The diagonal is 0.
    """
    nodes = np.arange(N)
    gap = np.abs(np.subtract.outer(nodes, nodes))
    ring_dist = np.minimum(gap, N - gap)
    # ring_dist is 0 only on the diagonal, so the lower bound excludes self-loops.
    return np.where((ring_dist >= 1) & (ring_dist <= k), 1.0, 0.0)

def local_order(theta, A):
    """Return the per-node local order parameter.

    For node i this is |sum_j W_ij * exp(i*theta_j)|, where W is A with each row
    divided by its row sum. Rows whose sum is not positive get weight 0, so an
    isolated node reports 0.
    """
    phasors = np.exp(1j * theta)
    row_weight = A.sum(axis=1, keepdims=True)
    normalised = np.where(row_weight > 0, A / np.maximum(row_weight, 1e-12), 0.0)
    return np.abs(normalised @ phasors)

def wrap_pi(x):
    """Wrap angle(s) x into the half-open interval [-pi, pi)."""
    two_pi = 2 * np.pi
    shifted = x + np.pi
    return shifted % two_pi - np.pi

# ============================================================
# Input eléctrico base: P_raw(t)
# ============================================================

def power_injection_with_base(t, base_P, rng,
                              wave_scale=0.6, noise_scale=0.15,
                              event_prob=0.002, event_scale=2.0,
                              phases=None, wave_freq=0.03):
    """Return the raw nodal power injection P_raw(t) for one time step.

    P_raw is base_P plus a sinusoidal wave, Gaussian noise and (rarely) a
    localized Gaussian bump on the ring. The sum is shifted to zero mean so
    that total injection is balanced.

    Parameters
    ----------
    t : float
        Current simulation time.
    base_P : array_like, length N
        Static per-node injection.
    rng : numpy.random.Generator
        Source of randomness for noise, events and (when ``phases`` is None)
        the wave phases.
    wave_scale, noise_scale : float
        Amplitudes of the sinusoidal wave and of the Gaussian noise.
    event_prob : float
        Per-call probability that a localized event is added.
    event_scale : float
        Peak amplitude of an event; its sign is chosen at random.
    phases : array_like, optional
        Per-node phase offsets of the wave. Pass the SAME array on every call
        to obtain a wave that actually evolves with ``t``. If None (legacy
        behaviour) fresh phases are drawn on every call, which makes the wave
        independent of ``t`` -- effectively extra noise.
    wave_freq : float
        Wave frequency in cycles per unit of ``t``.

    Returns
    -------
    numpy.ndarray
        Zero-mean injection vector of length N.
    """
    N = len(base_P)

    if phases is None:
        # Legacy path: a new random phase per call decorrelates the wave in time.
        phases = rng.uniform(0, 2*np.pi, N)
    P_wave = wave_scale * np.sin(2*np.pi * wave_freq * t + phases)

    P_noise = noise_scale * rng.normal(size=N)

    P_event = np.zeros(N)
    if rng.random() < event_prob:
        center = rng.integers(0, N)
        width = rng.integers(max(2, N//30), max(3, N//12))
        amp = event_scale * (1.0 if rng.random() < 0.5 else -1.0)
        # Wrap-around (ring) distance from every node to the event center.
        dist = np.abs(np.arange(N) - center)
        dist = np.minimum(dist, N - dist)
        P_event += amp * np.exp(-(dist**2) / (2*(width**2)))

    P = base_P + P_wave + P_noise + P_event
    return P - np.mean(P)  # global balance

# ============================================================
# Patron-PC (closed loop): the OUTPUTS feed back into P_eff below
# ============================================================

def _split_modules(N, M):
    """Split node indices 0..N-1 into M contiguous modules (the last one also takes the remainder)."""
    m_size = N // M
    mods = [np.arange(i*m_size, (i+1)*m_size) for i in range(M - 1)]
    mods.append(np.arange((M - 1)*m_size, N))
    return mods


def patronpc_grid_closedloop(
    N=240, M=12, T=80.0, dt=0.02,
    k=8,

    # network / dynamics
    K_line=1.6,
    noise_theta=0.01,

    # input
    base_scale=0.8, wave_scale=0.7, noise_scale=0.12,
    event_prob=0.003, event_scale=2.2,

    # weak global field
    psi=0.0, psi_bias=0.02,

    # sensors / windows
    th_P=0.20, th_local=0.28,
    hold_local=1.0, relax_alpha=0.03,

    # amplification + memory
    K_boost=2.2,
    mem_eta=0.0012, mem_forget=0.0006,
    g_clip=(0.0, 3.0),

    # ===== Actuator (output fed back down) =====
    droop_gain=0.90,      # per-module response to imbalance (higher = stronger correction)
    ctrl_limit=1.2,       # actuation limit (capacity)
    ctrl_tau=1.0,         # actuator time constant (s); avoids abrupt changes
    ctrl_leak=0.02,       # leak/cost term; keeps the actuator from holding a value forever

    # act harder on less coherent modules
    act_on_incoh=True,

    seed=2
):
    """Simulate a ring power grid whose module-level controller feeds back into the injections.

    The N phases theta follow a first-order swing / Kuramoto equation on a
    ring (``ring_adj(N, k)``) driven by the effective injection ``P_eff``.
    Nodes are split into M contiguous modules. A module's "window" opens
    (and stays open for ``hold_local``) when all of these hold:

    * its previous window has expired;
    * the smoothed global incoherence ``1 - rG`` is decreasing;
    * its phase coherence exceeds ``th_local``;
    * its mean ``|P_eff|`` exceeds ``th_P``.

    While a window is open:

    * intra-module coupling is boosted by ``K_boost``;
    * a first-order actuator ``u`` drives that module's injection toward
      cancelling its mean imbalance;
    * the actuator output is added to ``P_raw`` to form ``P_eff`` on the next
      step, which closes the loop.

    When more than 35 % of the modules are open during global relaxation,
    the line weights ``G`` adapt (structural memory).

    Returns
    -------
    dict
        One entry per time step for "ts", "rG", "rL", "open_frac",
        "w_emerg", "fever", "freq_dev", "P_abs_raw", "P_abs_eff" and "u_abs".
        "wM" has shape (steps, M) and holds the per-module window state.

    Raises
    ------
    ValueError
        If M is not in [1, N] (otherwise modules would be empty or M == 0
        would divide by zero).
    """
    if not 1 <= M <= N:
        raise ValueError(f"M must be in [1, N]; got M={M}, N={N}")

    rng = np.random.default_rng(seed)

    A = ring_adj(N, k)
    G = A.copy().astype(float)  # adaptive line weights, initialised to A

    mods = _split_modules(N, M)
    # mod_of[i] is the module index of node i.
    mod_of = np.empty(N, dtype=int)
    for m, idx in enumerate(mods):
        mod_of[idx] = m
    # same_module[i, j] is True when nodes i and j belong to the same module.
    same_module = mod_of[:, None] == mod_of[None, :]

    theta = rng.uniform(-np.pi, np.pi, N)

    base_P = base_scale * rng.normal(size=N)
    base_P -= np.mean(base_P)

    # Wave phases are drawn ONCE so the input wave genuinely evolves with t.
    wave_phases = rng.uniform(0, 2*np.pi, N)

    until = np.full(M, -1.0, dtype=float)  # time until which each module's window stays open

    F_s = 1.0  # exponentially smoothed global incoherence 1 - rG
    steps = int(T / dt)

    # actuator state applied downstream: u_i(t)
    u = np.zeros(N)

    # logs
    ts = np.zeros(steps)
    rG = np.zeros(steps)
    rL = np.zeros(steps)
    open_frac = np.zeros(steps)
    w_emerg = np.zeros(steps)
    fever = np.zeros(steps)

    P_abs_raw = np.zeros(steps)
    P_abs_eff = np.zeros(steps)
    u_abs = np.zeros(steps)
    freq_dev = np.zeros(steps)

    # per-module window state
    wM = np.zeros((steps, M), dtype=float)

    for s in range(steps):
        t = s * dt
        ts[s] = t

        # ---------- raw input (the "world")
        P_raw = power_injection_with_base(
            t, base_P, rng,
            wave_scale=wave_scale, noise_scale=noise_scale,
            event_prob=event_prob, event_scale=event_scale,
            phases=wave_phases,
        )
        P_abs_raw[s] = np.mean(np.abs(P_raw))

        # ---------- coherence
        rg_now = order_param(theta)
        rG[s] = rg_now

        rloc = local_order(theta, A)
        # Local-coherence summary: mostly the mean, nudged toward the 80th percentile.
        rL[s] = 0.75 * rloc.mean() + 0.25 * np.quantile(rloc, 0.80)

        # ---------- global relaxation: is the smoothed incoherence decreasing?
        F_prev = F_s
        F_s = (1.0 - relax_alpha) * F_s + relax_alpha * (1.0 - rg_now)
        relaxing = (F_s - F_prev) < 0.0

        # ---------- module sensors (on P_eff, not on P_raw)
        # P_eff is the injection after the actuator u, re-balanced to zero mean.
        P_eff = P_raw + u
        P_eff = P_eff - np.mean(P_eff)

        P_abs_eff[s] = np.mean(np.abs(P_eff))
        u_abs[s] = np.mean(np.abs(u))

        z = np.exp(1j * theta)
        Pm = np.array([np.mean(P_eff[idx]) for idx in mods])
        Pm_abs = np.array([np.mean(np.abs(P_eff[idx])) for idx in mods])
        m_coh = np.array([np.abs(np.mean(z[idx])) for idx in mods])

        # ---------- open local windows
        can_open = (t >= until) & (m_coh > th_local) & (Pm_abs > th_P)
        if relaxing:
            until[can_open] = t + hold_local

        w_m = (t < until).astype(float)
        is_open = w_m > 0.5
        wM[s, :] = w_m
        open_frac[s] = np.mean(w_m)
        w_emerg[s] = 1.0 if open_frac[s] > 0.35 else 0.0

        fever[s] = (1.0 - rg_now) if open_frac[s] > 0 else 0.0

        # ============================================================
        # OUTPUT -> DOWNSTREAM (actuator with inertia)
        # In open modules u tries to cancel the module's signed mean P_eff.
        # ============================================================
        if act_on_incoh:
            # Gain ranges from 0.5*droop_gain (fully coherent) to droop_gain (incoherent).
            gain = droop_gain * (0.5 + 0.5 * (1.0 - m_coh))
        else:
            gain = droop_gain
        u_mod = np.where(is_open, -gain * Pm, 0.0)

        # saturation (actuator capacity)
        u_target = np.clip(u_mod[mod_of], -ctrl_limit, ctrl_limit)

        # First-order actuator, explicit Euler: du/dt = (u_target - u)/tau - leak*u.
        # dt should stay well below ctrl_tau for this step to be stable.
        u += dt * ((u_target - u) / max(ctrl_tau, 1e-6) - ctrl_leak * u)

        # ============================================================
        # network dynamics (state)
        # ============================================================
        diff = theta[None, :] - theta[:, None]  # diff[i, j] = theta_j - theta_i
        S = A * G

        # Boost applies to pairs (i, j) inside the same open module.
        boost = same_module & is_open[mod_of][:, None]
        S_eff = S * (K_line + K_boost * boost)

        # NOTE: normalising by the mean row sum makes K_line cancel when no
        # window is open; only the ratio K_boost / K_line changes the dynamics.
        denom = max(S_eff.sum(1).mean(), 1e-12)
        # Synchronising (diffusive) line coupling: +sum_j S_ij sin(theta_j - theta_i),
        # i.e. dtheta_i/dt = P_i - sum_j S_ij sin(theta_i - theta_j).
        coupling = (S_eff * np.sin(diff)).sum(1) / denom

        # Weak global pull of every phase toward psi.
        bias = psi_bias * np.sin(psi - theta)

        # P_eff (already fed back) drives the dynamics.
        dtheta = P_eff + coupling + bias
        freq_dev[s] = np.mean(dtheta)

        theta += dt * dtheta
        theta += np.sqrt(dt) * noise_theta * rng.normal(size=N)
        theta = wrap_pi(theta)

        # structural memory: strengthen in-phase edges (cos > 0.2), weaken the
        # rest, relax toward A and keep weights within g_clip.
        if (w_emerg[s] > 0.5) and relaxing:
            C = np.cos(diff)
            G += mem_eta * (C - 0.2) * A
            G -= mem_forget * (G - A)
            G = np.clip(G, g_clip[0], g_clip[1])

    return {
        "ts": ts, "rG": rG, "rL": rL,
        "open_frac": open_frac, "w_emerg": w_emerg,
        "fever": fever, "freq_dev": freq_dev,
        "P_abs_raw": P_abs_raw, "P_abs_eff": P_abs_eff,
        "u_abs": u_abs, "wM": wM
    }

# ============================================================
# MAIN + plots (una sola ventana)
# ============================================================

if __name__ == "__main__":
    results = patronpc_grid_closedloop(seed=2)
    time_axis = results["ts"]

    # One entry per subplot: (curves, y-limits or None, y-label).
    # Each curve is (results key, legend label, extra Axes.plot kwargs).
    panels = [
        (
            [("rG", "rG (global)", {}),
             ("rL", "rL* (local)", {"alpha": 0.85})],
            (0, 1.02),
            "coherencia",
        ),
        (
            [("P_abs_raw", "|P_raw| medio (mundo)", {}),
             ("P_abs_eff", "|P_eff| medio (tras control)", {"alpha": 0.9})],
            None,
            "desbalance",
        ),
        (
            [("open_frac", "fracción módulos con ventana abierta", {}),
             ("w_emerg", "ventana emergente", {"alpha": 0.8})],
            (-0.05, 1.05),
            "ventanas",
        ),
        (
            [("fever", "fiebre (intento)", {}),
             ("u_abs", "|u| medio (salida aplicada)", {"alpha": 0.9})],
            (-0.05, 1.05),
            "intento / salida",
        ),
        (
            [("freq_dev", "mean(dθ/dt) (desviación)", {})],
            None,
            "desviación",
        ),
    ]

    fig, axes = plt.subplots(len(panels), 1, figsize=(12, 10), sharex=True)
    fig.suptitle("Patron-PC — Red eléctrica (lazo cerrado): SALIDAS retroalimentan el patrón de abajo")

    for axis, (curves, ylim, ylabel) in zip(axes, panels):
        for key, label, extra in curves:
            axis.plot(time_axis, results[key], label=label, **extra)
        if ylim is not None:
            axis.set_ylim(*ylim)
        axis.set_ylabel(ylabel)
        axis.legend(loc="upper right")

    # Shared x axis: only the bottom panel carries the time label.
    axes[-1].set_xlabel("tiempo")

    plt.tight_layout()
    plt.show()
