import numpy as np
import matplotlib.pyplot as plt

# ============================================================
# Utilidades
# ============================================================

def order_param(x):
    """Return the Kuramoto order parameter ``|mean(exp(i*x))|`` of phases ``x``.

    The result is a Python float in ``[0, 1]``: 1 when all phases coincide,
    close to 0 when they cancel out.
    """
    phasors = np.exp(1j * x)
    centroid = np.mean(phasors)
    return float(np.abs(centroid))

def ring_adj(N, k):
    """Build the float adjacency matrix of a ring of ``N`` nodes.

    Node ``i`` is linked to every node whose circular distance to ``i`` lies
    in ``1..k``; there are no self-loops. The matrix is symmetric with 0/1
    entries.
    """
    idx = np.arange(N)
    gap = np.abs(np.subtract.outer(idx, idx))
    # circular (wrap-around) distance between nodes
    ring_gap = np.minimum(gap, N - gap)
    linked = (ring_gap > 0) & (ring_gap <= k)
    adj = np.where(linked, 1.0, 0.0)
    np.fill_diagonal(adj, 0.0)
    return adj

def local_order(theta, A):
    """Return each node's local order parameter over its weighted neighbourhood.

    For node ``i`` this is ``|sum_j A[i, j] * exp(i*theta_j)| / sum_j A[i, j]``;
    rows of ``A`` with no positive total weight yield 0.
    """
    deg = A.sum(axis=1, keepdims=True)
    # row-normalised weights; the maximum() guard only avoids division by zero,
    # the where() then zeroes rows whose degree is not positive
    weights = np.where(deg > 0, A / np.maximum(deg, 1e-12), 0.0)
    return np.abs(weights @ np.exp(1j * theta))

# ============================================================
# Simulación Patron-PC (dos modos de apertura de w_down)
# ============================================================

def simulate_patron_pc(
    mode="relax",           # "relax" or "threshold"
    N=240, M=12, T=40.0, dt=0.01,
    k=10, K=3.6,
    omega_std=0.6, noise=0.01,
    # gating windows
    th_up=0.28, hold_up=1.2,
    th_down=0.40, hold_down=1.2,
    hold_mem=1.5,
    # upper (module-phase) dynamics
    Kphi=1.2, tau_phi=1.2,
    # top-down "seduction" strength
    lam=0.7,
    # memory (plastic coupling gains)
    eta=0.0020, forget=0.0008, g_clip=(0.0, 3.0),
    # relaxation smoothing
    relax_alpha=0.02,
    seed=2
):
    """Simulate the two-level Patron-PC oscillator model and log its time series.

    A lower layer of ``N`` phase oscillators sits on a ring (``ring_adj(N, k)``)
    and is split into ``M`` contiguous modules. An upper layer holds one phase
    ``phi`` per module, coupled to adjacent modules on a ring
    (``ring_adj(M, 1)``). Three gating windows, each held open for a fixed time
    once triggered, control the interaction:

    * UP   -- opens when the local coherence readout exceeds ``th_up``. While
      open, each module phase is also pulled toward its module's mean phase.
    * DOWN -- opens when the module-phase coherence exceeds ``th_down``
      (``mode="threshold"``), and additionally requires the smoothed tension
      ``1 - rG`` to be decreasing (``mode="relax"``). Lower-layer coupling is
      active ONLY while DOWN is open, modulated top-down by ``lam``.
    * MEM  -- opens when UP and DOWN are both open (plus relaxation in relax
      mode). While open, the coupling gains ``G`` are updated.

    In relax mode DOWN and MEM are closed early as soon as relaxation stops.

    Args:
        mode: ``"threshold"`` or ``"relax"``; selects how DOWN/MEM open.
        N, M: number of oscillators and of modules (``1 <= M <= N``).
        T, dt: duration and Euler step; ``floor(T / dt)`` steps are run
            (with a tiny tolerance so e.g. T=0.3, dt=0.1 gives 3 steps).
        k: ring neighbours per side in the lower layer.
        K: lower-layer coupling strength.
        omega_std: std of the natural frequencies (mean 0).
        noise: amplitude of the additive phase noise (scaled by sqrt(dt)).
        th_up, hold_up: UP trigger threshold and hold time.
        th_down, hold_down: DOWN trigger threshold and hold time.
        hold_mem: MEM hold time.
        Kphi, tau_phi: module-phase consensus gain and pull time constant.
        lam: strength of the top-down modulation of the coupling matrix.
        eta, forget, g_clip: memory learning rate, decay toward the base ring,
            and (min, max) clip for the gains.
        relax_alpha: EMA factor for the smoothed tension.
        seed: seed for ``np.random.default_rng``.

    Returns:
        dict with per-step arrays of length ``steps``:
        ``ts`` (time), ``rG`` (global order after the step), ``rL`` (local
        readout), ``rPhi`` (module-phase order), ``fever`` (``1 - rG`` while
        DOWN or MEM is open, else 0), ``w_up``/``w_down``/``w_mem`` (0/1 window
        states), ``energy_inst`` (sum of squared deterministic drifts); plus
        ``dt``.

    Raises:
        ValueError: if ``mode`` is not recognised or ``M`` is not in ``[1, N]``.
    """
    # Validate up front rather than failing mid-loop (or never, if steps == 0).
    if mode not in ("threshold", "relax"):
        raise ValueError("mode debe ser 'threshold' o 'relax'")
    if not 1 <= M <= N:
        raise ValueError("M debe cumplir 1 <= M <= N")
    relax_mode = (mode == "relax")

    rng = np.random.default_rng(seed)

    theta = rng.uniform(-np.pi, np.pi, N)
    omega = rng.normal(0.0, omega_std, N)

    A0 = ring_adj(N, k)
    G = A0.astype(float)  # astype returns a copy; G is mutated by the memory rule

    # modules: M-1 blocks of N // M consecutive nodes, the last takes the remainder
    m_size = N // M
    mods = [np.arange(i*m_size, (i+1)*m_size) for i in range(M-1)]
    mods.append(np.arange((M-1)*m_size, N))
    # node -> module index, used to broadcast module phases onto nodes
    node_mod = np.empty(N, dtype=int)
    for m, idx in enumerate(mods):
        node_mod[idx] = m
    phi = rng.uniform(-np.pi, np.pi, M)
    Am = ring_adj(M, 1)

    # window expiry times: a window is open while t < *_until
    up_until = down_until = mem_until = -1.0

    # smoothed tension: EMA of F = 1 - rG
    F_s = 1.0

    # logs; the tolerance guards against float truncation (int(0.3 / 0.1) == 2)
    steps = int(np.floor(T / dt + 1e-9))
    ts = np.zeros(steps)
    rG = np.zeros(steps)
    rL = np.zeros(steps)
    rPhi = np.zeros(steps)
    fever = np.zeros(steps)
    w_up = np.zeros(steps)
    w_down = np.zeros(steps)
    w_mem = np.zeros(steps)
    energy_inst = np.zeros(steps)   # instantaneous power: sum_i (dtheta_i/dt)^2

    for s in range(steps):
        t = s * dt

        # ---------- local readout: blend of mean and 80th percentile of local order
        rloc = local_order(theta, A0)
        pre = float(0.75 * rloc.mean() + 0.25 * np.quantile(rloc, 0.80))

        # ---------- UP window
        if t >= up_until and pre > th_up:
            up_until = t + hold_up
        WU = 1.0 if t < up_until else 0.0

        # ---------- upper dynamics: phi_diff[i, j] = phi_j - phi_i
        phi_diff = phi[None, :] - phi[:, None]
        consensus = (Am * np.sin(phi_diff)).sum(1)

        if WU:
            # bottom-up: pull each module phase toward its module's mean phase
            z = np.exp(1j * theta)
            targets = np.array([np.angle(np.mean(z[idx])) for idx in mods])
            phi += dt * (Kphi * consensus) + (dt / tau_phi) * np.sin(targets - phi)
        else:
            phi += dt * (Kphi * consensus)

        phi = (phi + np.pi) % (2*np.pi) - np.pi
        rphi = order_param(phi)

        # ---------- global relaxation (used by "relax" mode)
        rg_now = order_param(theta)
        F = 1.0 - rg_now
        F_prev = F_s
        F_s = (1.0 - relax_alpha) * F_s + relax_alpha * F
        dF = F_s - F_prev

        relajando = (dF < 0)  # "relaxing": smoothed tension is decreasing

        # ---------- DOWN window
        cond_open = (rphi > th_down)
        if relax_mode:
            cond_open = cond_open and relajando

        if t >= down_until and cond_open:
            down_until = t + hold_down
        WD = 1.0 if t < down_until else 0.0

        # early close ONLY in relax mode (as soon as relaxation stops)
        if relax_mode and WD and not relajando:
            down_until = t
            WD = 0.0

        # ---------- MEM window: needs bidirectionality; relax mode also needs relaxation
        cond_mem = (WU > 0 and WD > 0)
        if relax_mode:
            cond_mem = cond_mem and relajando

        if t >= mem_until and cond_mem:
            mem_until = t + hold_mem
        WM = 1.0 if t < mem_until else 0.0

        if relax_mode and WM and not relajando:
            mem_until = t
            WM = 0.0

        # ---------- lower coupling: diff[i, j] = theta_j - theta_i
        diff = theta[None, :] - theta[:, None]
        S = A0 * G

        # top-down "seduction": reshapes the coupling landscape (does not impose
        # phases); column j is scaled by 1 + lam*cos(phi_{module(j)} - theta_j)
        if WD:
            field = phi[node_mod]
            S = S * (1.0 + lam * np.cos(field - theta))
            S = np.clip(S, 0.0, None)

        # lower coupling exists only while DOWN is open
        S *= WD
        denom = np.maximum(S.sum(1).mean(), 1e-12)
        coup = (S * np.sin(diff)).sum(1) / denom

        # ---------- dynamics: deterministic drift (kept for the energy) + noise
        dtheta_det = (omega + K * coup)
        theta += dt * dtheta_det
        theta += np.sqrt(dt) * noise * rng.normal(size=N)
        theta = (theta + np.pi) % (2*np.pi) - np.pi

        rg = order_param(theta)

        # "fever" only counts while there is an attempt (DOWN or MEM open)
        fever_t = (1.0 - rg) * (1.0 if (WD > 0 or WM > 0) else 0.0)

        # memory update only while MEM is open (uses pre-step phase differences)
        if WM:
            C = np.cos(diff)
            G += eta * (C - 0.2) * A0
            G -= forget * (G - A0)
            G = np.clip(G, g_clip[0], g_clip[1])

        # instantaneous power sum_i (dtheta_i/dt)^2 over the deterministic drift
        # (noise excluded), used as the internal "computational" cost
        energy_inst[s] = float(np.sum(dtheta_det**2))

        # logs
        ts[s] = t
        rG[s] = rg
        rL[s] = pre
        rPhi[s] = rphi
        fever[s] = fever_t
        w_up[s] = WU
        w_down[s] = WD
        w_mem[s] = WM

    return {
        "ts": ts, "rG": rG, "rL": rL, "rPhi": rPhi,
        "fever": fever, "w_up": w_up, "w_down": w_down, "w_mem": w_mem,
        "energy_inst": energy_inst, "dt": dt
    }

# ============================================================
# Métricas A/B
# ============================================================

def summarize(run):
    """Reduce a ``simulate_patron_pc`` result dict to scalar A/B metrics.

    Time integrals use a rectangle rule (``sum(x) * dt``). ``eff_E_per_dRG`` is
    the total energy divided by the net change in ``rG`` between the first and
    last step, or ``inf`` when that change is not above ``1e-9``.
    """
    step = run["dt"]
    coherence = run["rG"]

    def _integral(key):
        # rectangle-rule integral of a logged series
        return float(np.sum(run[key]) * step)

    energy = _integral("energy_inst")
    delta = float(coherence[-1] - coherence[0])
    if delta > 1e-9:
        efficiency = energy / delta
    else:
        efficiency = float("inf")

    return dict(
        E_total=energy,
        fever_area=_integral("fever"),
        t_down=_integral("w_down"),
        t_mem=_integral("w_mem"),
        d_rG=delta,
        rG_mean=float(np.mean(coherence)),
        rG_max=float(np.max(coherence)),
        eff_E_per_dRG=efficiency,
    )

# ============================================================
# Plot
# ============================================================

def plot_ab(A, B, title="Energía A/B (threshold vs relax)"):
    """Plot two ``simulate_patron_pc`` runs side by side in four stacked panels.

    Panels: coherence (rG, rL*, rPhi), fever, DOWN/MEM window states, and
    cumulative energy ``cumsum(energy_inst) * dt``. Each run is drawn against
    its own ``ts`` so runs with different ``T`` or ``dt`` can be compared
    (the x axis is shared). Shows the figure with ``plt.show()``.
    """
    tsA = A["ts"]
    tsB = B["ts"]

    fig, ax = plt.subplots(4, 1, figsize=(13, 9), sharex=True)
    fig.suptitle(title)

    ax[0].plot(tsA, A["rG"], label="A rG")
    ax[0].plot(tsA, A["rL"], label="A rL*", alpha=0.8)
    ax[0].plot(tsA, A["rPhi"], label="A rPhi", alpha=0.8)
    ax[0].plot(tsB, B["rG"], label="B rG")
    ax[0].plot(tsB, B["rL"], label="B rL*", alpha=0.8)
    ax[0].plot(tsB, B["rPhi"], label="B rPhi", alpha=0.8)
    ax[0].set_ylim(0, 1.02)
    ax[0].set_ylabel("coherencia")
    ax[0].legend(ncol=3, loc="upper right")

    ax[1].plot(tsA, A["fever"], label="A fiebre")
    ax[1].plot(tsB, B["fever"], label="B fiebre")
    ax[1].set_ylim(0, 1.02)
    ax[1].set_ylabel("fiebre")
    ax[1].legend(loc="upper right")

    ax[2].plot(tsA, A["w_down"], label="A w_down")
    ax[2].plot(tsB, B["w_down"], label="B w_down")
    ax[2].plot(tsA, A["w_mem"], label="A w_mem", alpha=0.7)
    ax[2].plot(tsB, B["w_mem"], label="B w_mem", alpha=0.7)
    ax[2].set_ylim(-0.05, 1.05)
    ax[2].set_ylabel("ventanas")
    ax[2].legend(ncol=2, loc="upper right")

    # rectangle-rule running integral of the instantaneous power
    ax[3].plot(tsA, np.cumsum(A["energy_inst"]) * A["dt"], label="A energía acumulada")
    ax[3].plot(tsB, np.cumsum(B["energy_inst"]) * B["dt"], label="B energía acumulada")
    ax[3].set_ylabel("∫ Σ(dθ/dt)^2 dt")
    ax[3].set_xlabel("tiempo")
    ax[3].legend(loc="upper left")

    plt.tight_layout()
    plt.show()

# ============================================================
# MAIN
# ============================================================

if __name__ == "__main__":
    # Shared parameters for both runs (spelled out so the A/B setup is explicit)
    common = dict(
        N=240, M=12, T=40.0, dt=0.01,
        k=10, K=3.6,
        omega_std=0.6, noise=0.01,
        th_up=0.28, hold_up=1.2,
        th_down=0.40, hold_down=1.2,
        hold_mem=1.5,
        Kphi=1.2, tau_phi=1.2,
        lam=0.7,
        eta=0.0020, forget=0.0008, g_clip=(0.0, 3.0),
        relax_alpha=0.02,
        seed=2
    )

    # A: threshold-based opening of w_down
    A = simulate_patron_pc(mode="threshold", **common)

    # B: relaxation-based opening of w_down
    B = simulate_patron_pc(mode="relax", **common)

    SA = summarize(A)
    SB = summarize(B)

    print("\n==================== RESUMEN A/B ====================")
    # The ":.6f" float format already renders non-finite values as
    # "inf", "-inf" or "nan", so every metric can share one format.
    for header, summary in (
        ("A = umbral clásico (sin relajación)", SA),
        ("\nB = apertura por relajación (dF < 0)", SB),
    ):
        print(header)
        for k, v in summary.items():
            print(f"  {k:>14}: {v:.6f}")

    # Simple B/A comparisons; inf when the A value is (near) zero
    def ratio(a, b):
        return (b / a) if abs(a) > 1e-12 else float("inf")

    print("\n-------------------- COMPARATIVAS --------------------")
    print(f"  Energía total (B/A):     {ratio(SA['E_total'], SB['E_total']):.6f}")
    print(f"  Área fiebre (B/A):       {ratio(SA['fever_area'], SB['fever_area']):.6f}")
    print(f"  Tiempo w_down (B/A):     {ratio(SA['t_down'], SB['t_down']):.6f}")
    print(f"  ΔrG (B/A):               {ratio(SA['d_rG'], SB['d_rG']):.6f}")
    print(f"  E por ΔrG (B/A):         {ratio(SA['eff_E_per_dRG'], SB['eff_E_per_dRG']):.6f}")

    plot_ab(A, B, title="Patron-PC A/B: umbral vs relajación — energía y fiebre")
