import numpy as np

# ============================================================
# Utilidades
# ============================================================

def order_param(x):
    """Return the Kuramoto order parameter r = |mean(exp(i*x))| as a float.

    r is 1.0 when all phases in ``x`` coincide and drops towards 0.0 as the
    phases spread around the circle.
    """
    mean_phasor = np.mean(np.exp(1j * x))
    return float(abs(mean_phasor))

def ring_adj(N, k):
    """Return the N x N float adjacency matrix of a ring lattice.

    Node i is linked to every node whose circular distance from i is between
    1 and k (inclusive); there are no self-loops.
    """
    idx = np.arange(N)
    gap = np.abs(np.subtract.outer(idx, idx))
    ring_dist = np.minimum(gap, N - gap)  # shortest way round the ring
    linked = (ring_dist >= 1) & (ring_dist <= k)
    return np.where(linked, 1.0, 0.0)

def local_order(theta, A):
    """Return, for each node, the modulus of its neighbours' mean phasor.

    Row i of ``A`` is normalised by its row sum, so the result is
    |sum_j A[i, j] exp(i*theta[j]) / sum_j A[i, j]|. Rows with a non-positive
    sum yield 0.
    """
    row_sum = A.sum(axis=1, keepdims=True)
    safe_sum = np.maximum(row_sum, 1e-12)  # avoid 0-division; masked below anyway
    weights = np.where(row_sum > 0, A / safe_sum, 0.0)
    return np.abs(weights @ np.exp(1j * theta))

# ============================================================
# Simulación Patron-PC con registro de:
# - energía instantánea: sum((dθ/dt)^2)   (sin ruido)
# - w_mem
# - fiebre
# ============================================================

def simulate_patron_pc(mode, seed, p):
    """Simulate the Patron-PC model and record per-step diagnostics.

    A ring of N phase oscillators (``theta``) interacts with M module phases
    (``phi``) through three gating windows, each open while ``t < *_until``:

    * UP   -- opens when local pre-coherence exceeds ``th_up``; while open,
      every module phase is pulled towards the mean phase of its oscillators.
    * DOWN -- opens when the module order parameter exceeds ``th_down``;
      oscillator coupling is only active while DOWN is open, and is
      modulated by the module phases.
    * MEM  -- opens when UP and DOWN are both open; while open, the coupling
      gains ``G`` are updated from the oscillators' phase differences.

    In ``"relax"`` mode DOWN and MEM additionally require the global disorder
    ``F = 1 - r`` to be below its exponential moving average ("relaxing"),
    and both windows are closed early on any step where it is not.

    Parameters
    ----------
    mode : {"threshold", "relax"}
        Gating rule for the DOWN and MEM windows.
    seed : int
        Seed for ``numpy.random.default_rng``; all randomness of the run
        comes from this generator.
    p : dict
        Model parameters. Keys read: ``N``, ``M``, ``T``, ``dt``, ``k``,
        ``K``, ``omega_std``, ``noise``, ``th_up``, ``hold_up``, ``th_down``,
        ``hold_down``, ``hold_mem``, ``Kphi``, ``tau_phi``, ``lam``, ``eta``,
        ``forget``, ``g_clip`` (a ``(lo, hi)`` pair passed to ``np.clip``)
        and ``relax_alpha``.

    Returns
    -------
    energy_inst : ndarray, shape (steps,)
        ``sum((dθ/dt)**2)`` using the deterministic (noise-free) velocity.
    w_mem : ndarray, shape (steps,)
        1.0 on steps where the MEM window is open, else 0.0.
    fever : ndarray, shape (steps,)
        ``1 - r`` of the post-update phases on steps where DOWN or MEM is
        open, else 0.0.
    dt : float
        The time step, echoed back so callers can integrate the series.

    Raises
    ------
    ValueError
        If ``mode`` is not ``"threshold"``/``"relax"`` or ``M`` is not in
        ``[1, N]``.
    """
    if mode not in ("threshold", "relax"):
        # Any other string used to get the "relax" entry conditions but not
        # the "relax" early cut-off -- an undocumented hybrid. Reject it.
        raise ValueError(f"mode must be 'threshold' or 'relax', got {mode!r}")
    relax = mode == "relax"

    rng = np.random.default_rng(seed)

    N, M = p["N"], p["M"]
    dt = p["dt"]
    if not 1 <= M <= N:
        # M > N leaves empty modules whose mean phase is NaN (which then
        # spreads through phi); M < 1 would divide by zero below.
        raise ValueError(f"need 1 <= M <= N, got M={M}, N={N}")

    theta = rng.uniform(-np.pi, np.pi, N)
    omega = rng.normal(0.0, p["omega_std"], N)

    A0 = ring_adj(N, p["k"])
    G = A0.copy()  # plastic coupling gains, only updated inside MEM

    # Contiguous modules of N // M oscillators; the last absorbs the remainder.
    m_size = N // M
    mods = [np.arange(i * m_size, (i + 1) * m_size) for i in range(M - 1)]
    mods.append(np.arange((M - 1) * m_size, N))
    # Module index of every oscillator, so the per-oscillator module field is
    # simply phi[mod_of] (loop-invariant, so built once).
    mod_of = np.empty(N, dtype=int)
    for m, idx in enumerate(mods):
        mod_of[idx] = m

    phi = rng.uniform(-np.pi, np.pi, M)
    Am = ring_adj(M, 1)

    up_until = down_until = mem_until = -1.0
    F_s = 1.0  # exponential moving average of the global disorder 1 - r

    # round(), not plain int(): T / dt is often a hair below the intended
    # integer (e.g. 0.3 / 0.1 == 2.9999999999999996) and int() would silently
    # drop the last step.
    steps = int(round(p["T"] / dt))

    energy_inst = np.zeros(steps)
    w_mem = np.zeros(steps)
    fever = np.zeros(steps)

    for s in range(steps):
        t = s * dt

        # -------- local pre-coherence: mostly the mean, plus the upper tail
        rloc = local_order(theta, A0)
        pre = 0.75 * rloc.mean() + 0.25 * np.quantile(rloc, 0.8)

        # -------- UP window
        if t >= up_until and pre > p["th_up"]:
            up_until = t + p["hold_up"]
        WU = t < up_until

        # -------- module phases: nearest-neighbour consensus on a ring of M
        diff_phi = phi[None, :] - phi[:, None]  # [i, j] = phi_j - phi_i
        consensus = (Am * np.sin(diff_phi)).sum(1)

        if WU:
            # additionally pull each module towards its oscillators' mean phase
            z = np.exp(1j * theta)
            targets = np.array([np.angle(np.mean(z[idx])) for idx in mods])
            phi += dt * (p["Kphi"] * consensus) + dt / p["tau_phi"] * np.sin(targets - phi)
        else:
            phi += dt * (p["Kphi"] * consensus)

        phi = (phi + np.pi) % (2 * np.pi) - np.pi  # wrap to [-pi, pi)
        rphi = order_param(phi)

        # -------- global relaxation: disorder F below its moving average
        rg_now = order_param(theta)
        F = 1.0 - rg_now
        dF = F - F_s
        F_s = (1.0 - p["relax_alpha"]) * F_s + p["relax_alpha"] * F
        relajando = dF < 0

        # -------- DOWN window
        if relax:
            cond_down = (rphi > p["th_down"]) and relajando
        else:
            cond_down = rphi > p["th_down"]

        if t >= down_until and cond_down:
            down_until = t + p["hold_down"]
        WD = t < down_until

        if relax and WD and not relajando:
            down_until = t  # close the window early
            WD = False

        # -------- MEM window
        if relax:
            cond_mem = WU and WD and relajando
        else:
            cond_mem = WU and WD

        if t >= mem_until and cond_mem:
            mem_until = t + p["hold_mem"]
        WM = t < mem_until

        if relax and WM and not relajando:
            mem_until = t  # close the window early
            WM = False

        # -------- lower (oscillator) dynamics
        diff = theta[None, :] - theta[:, None]  # [i, j] = theta_j - theta_i
        S = A0 * G

        if WD:
            field = phi[mod_of]
            # The (N,) factor broadcasts along the last axis, so entry (i, j)
            # is scaled by oscillator j's alignment with its module phase.
            # NOTE(review): confirm column (source) rather than row scaling
            # is the intended modulation.
            S *= (1.0 + p["lam"] * np.cos(field - theta))
            S = np.clip(S, 0.0, None)

        # Coupling only acts while the DOWN window is open.
        S *= WD
        denom = max(S.sum(1).mean(), 1e-12)
        coup = (S * np.sin(diff)).sum(1) / denom

        # energy = sum((dθ/dt)^2) with the deterministic dθ/dt (no noise)
        dtheta = omega + p["K"] * coup
        energy_inst[s] = float(np.sum(dtheta**2))

        # Euler-Maruyama update, then wrap to [-pi, pi)
        theta += dt * dtheta
        theta += np.sqrt(dt) * p["noise"] * rng.normal(size=N)
        theta = (theta + np.pi) % (2 * np.pi) - np.pi

        # fever: global disorder, counted only while an attempt (WD or WM) is on
        rg = order_param(theta)
        fever[s] = (1.0 - rg) * (1.0 if (WD or WM) else 0.0)

        # memory (only inside WM): reinforce edges whose endpoints were close
        # in phase (cos > 0.2), decay towards A0, then clip. Uses the phase
        # differences from the start of this step.
        if WM:
            C = np.cos(diff)
            G += p["eta"] * (C - 0.2) * A0
            G -= p["forget"] * (G - A0)
            G = np.clip(G, *p["g_clip"])

        w_mem[s] = 1.0 if WM else 0.0

    return energy_inst, w_mem, fever, dt

# ============================================================
# MAIN
# ============================================================

def _run_totals(energy, w_mem, fever, dt):
    """Integrate one run into (total energy, quality-weighted memory, ratio).

    Mw = ∫ w_mem(t) * (1 - fever(t)) dt, and the ratio E / Mw is reported as
    ``inf`` when Mw is numerically zero.
    """
    e_total = energy.sum() * dt
    mw = float(np.sum(w_mem * (1.0 - fever)) * dt)
    eta = e_total / mw if mw > 1e-12 else np.inf
    return e_total, mw, eta


if __name__ == "__main__":

    params = {
        "N": 240, "M": 12, "T": 40.0, "dt": 0.01,
        "k": 10, "K": 3.6,
        "omega_std": 0.6, "noise": 0.01,
        "th_up": 0.28, "hold_up": 1.2,
        "th_down": 0.40, "hold_down": 1.2,
        "hold_mem": 1.5,
        "Kphi": 1.2, "tau_phi": 1.2,
        "lam": 0.7,
        "eta": 0.002, "forget": 0.0008, "g_clip": (0.0, 3.0),
        "relax_alpha": 0.02,
    }

    seeds = range(5)  # switch to range(10) for more seeds

    print("\n======== RESULTADOS POR SEED (η* ponderado por calidad) ========")

    per_seed = []

    for seed in seeds:
        # A: pure threshold gating; B: gating also tied to global relaxation.
        *series_a, step = simulate_patron_pc("threshold", seed, params)
        *series_b, _ = simulate_patron_pc("relax", seed, params)

        e_a, mw_a, eta_a = _run_totals(*series_a, step)
        e_b, mw_b, eta_b = _run_totals(*series_b, step)
        per_seed.append((e_a, e_b, mw_a, mw_b, eta_a, eta_b))

        print(f"\nSeed {seed}")
        for label, e_tot, mw, eta in (("A", e_a, mw_a, eta_a), ("B", e_b, mw_b, eta_b)):
            print(f"  {label} → E={e_tot:.1f}, Mw={mw:.3f}, ηw={eta:.1f}")

    table = np.array(per_seed)
    # columns: E_A, E_B, Mw_A, Mw_B, eta_A, eta_B
    col_mean = [table[:, c].mean() for c in range(table.shape[1])]

    print("\n======== PROMEDIOS ========")
    print(f"E total      A={col_mean[0]:.1f} | B={col_mean[1]:.1f}")
    print(f"Mw (calidad) A={col_mean[2]:.3f} | B={col_mean[3]:.3f}")
    print(f"ηw=E/Mw      A={col_mean[4]:.1f} | B={col_mean[5]:.1f}")
    print(f"Ratio ηw (B/A) = {col_mean[5] / col_mean[4]:.3f}")


