import numpy as np

# ============================================================
# Utilidades
# ============================================================

def order_param(x):
    """Return the Kuramoto order parameter of the phases ``x``.

    Each phase is mapped to the unit phasor ``exp(1j * phase)``; the result is
    the modulus of the mean phasor, returned as a plain Python float.
    """
    phasors = np.exp(1j * x)
    mean_phasor = phasors.mean()
    return float(np.abs(mean_phasor))

def ring_adj(N, k):
    """Build the adjacency matrix of a ring lattice with ``N`` nodes.

    Two distinct nodes are linked when their circular index distance
    ``min(|i - j|, N - |i - j|)`` is at most ``k``.

    Returns an ``(N, N)`` float array holding 1.0 for linked pairs and 0.0
    elsewhere, with a zero diagonal.
    """
    nodes = np.arange(N)
    gap = np.abs(nodes.reshape(-1, 1) - nodes.reshape(1, -1))
    # Distance measured the short way around the ring.
    ring_dist = np.minimum(gap, N - gap)
    linked = np.logical_and(ring_dist > 0, ring_dist <= k)
    adj = linked.astype(float)
    np.fill_diagonal(adj, 0.0)
    return adj

def local_order(theta, A):
    """Return the local order parameter of every node.

    Row ``i`` of ``A`` is normalised by its row sum (floored at 1e-12) and
    applied to the unit phasors ``exp(1j * theta)``; the modulus of that
    weighted phasor sum is the node's local order. Rows whose sum is not
    positive produce 0.
    """
    phasors = np.exp(1j * theta)
    row_sum = A.sum(1)
    col = row_sum[:, None]
    # The floor only guards the division; rows without positive weight are
    # zeroed explicitly by the np.where below.
    normalised = A / np.maximum(col, 1e-12)
    weights = np.where(col > 0, normalised, 0.0)
    return np.abs(weights @ phasors)

def circ_similarity(a, b):
    """
    Circular similarity between two phase vectors, invariant to a global
    rotation.

    Both vectors are mapped to unit phasors ``exp(1j * phase)``. The modulus
    of their Hermitian inner product divided by ``len(a)`` is returned; taking
    the modulus discards any constant phase offset between ``a`` and ``b``.
    For real phase vectors of equal length the score lies in [0, 1].
    """
    phasors_a = np.exp(1j * a)
    phasors_b = np.exp(1j * b)
    # np.vdot conjugates its first argument: sum(conj(phasors_a) * phasors_b).
    overlap = np.abs(np.vdot(phasors_a, phasors_b))
    return float(overlap / len(a))

def episodes_from_mask(mask, dt, min_len=0.5):
    """
    Find the runs of "on" samples in ``mask`` that last at least ``min_len``.

    A sample switches a run on when it is > 0.5 and off when it is <= 0.5.
    ``dt`` is the duration of one sample, so a run of ``n`` samples lasts
    ``n * dt`` (same units as ``min_len``).

    Returns a list of ``(start_idx, end_idx)`` tuples with ``end_idx``
    inclusive, in order of appearance.
    """
    runs = []
    begin = None  # index where the currently open run started, if any
    for idx, value in enumerate(mask):
        if begin is None:
            if value > 0.5:
                begin = idx
        elif value <= 0.5:
            runs.append((begin, idx - 1))
            begin = None

    # A run still open at the end of the mask extends to the last sample.
    if begin is not None:
        runs.append((begin, len(mask) - 1))

    return [(first, last) for first, last in runs
            if (last - first + 1) * dt >= min_len]

# ============================================================
# Núcleo Patron-PC (A/B) con grabación de memorias
# ============================================================

def run_learning(mode, seed, p):
    """
    Run the learning phase and record one memory snapshot per memory episode.

    Parameters
    ----------
    mode : str
        "threshold" opens the DOWN and MEM windows on their threshold
        conditions alone. Any other value additionally requires the system to
        be "relaxing" (see below) to open them; only the exact value "relax"
        also closes an open DOWN/MEM window as soon as relaxation stops.
    seed : int
        Seed passed to ``np.random.default_rng``.
    p : dict
        Parameters. Keys read here: N, M, dt, omega_std, k, T_learn, th_up,
        hold_up, Kphi, tau_phi, relax_alpha, th_down, hold_down, hold_mem,
        lam, K, noise, eta, forget, g_clip, min_mem_len.

    Returns
    -------
    tuple
        ``(theta, phi, G, omega, mem_list, A0, mods, Am)`` where ``theta`` /
        ``phi`` are the final lower/upper-level phases, ``G`` the learned
        weight matrix, ``omega`` the natural frequencies, ``A0`` / ``Am`` the
        ring adjacencies of both levels, ``mods`` the index arrays of the
        modules, and ``mem_list`` a list of dicts with keys ``"t_idx"`` (last
        step index of the episode), ``"theta"`` and ``"phi"`` (copies of the
        state right after that step).
    """
    rng = np.random.default_rng(seed)
    N, M = p["N"], p["M"]
    dt = p["dt"]

    theta = rng.uniform(-np.pi, np.pi, N)
    omega = rng.normal(0.0, p["omega_std"], N)

    A0 = ring_adj(N, p["k"])
    G = A0.copy()

    # Split the N oscillators into M contiguous modules; the last module
    # absorbs the remainder when N is not a multiple of M.
    m_size = N // M
    mods = [np.arange(i*m_size, (i+1)*m_size) for i in range(M-1)]
    mods.append(np.arange((M-1)*m_size, N))

    phi = rng.uniform(-np.pi, np.pi, M)
    Am = ring_adj(M, 1)

    # Expiry times of the UP / DOWN / MEM windows; -1 means "closed".
    up_until = down_until = mem_until = -1.0
    # Exponentially smoothed global tension (1 - order parameter).
    F_s = 1.0

    steps = int(p["T_learn"] / dt)

    # 1.0 on every step where the MEM window is open.
    w_mem = np.zeros(steps)

    # Snapshot bookkeeping: while the MEM window is open we keep a copy of the
    # state after the latest step; when the window closes that copy is stored
    # under the index of the episode's last step. This yields exact
    # end-of-episode snapshots without logging every step.
    end_snapshots = {}
    last_theta = last_phi = None
    prev_WM = False

    for s in range(steps):
        t = s * dt

        # UP window: opened by a blend of mean and upper-quantile local order.
        rloc = local_order(theta, A0)
        pre = 0.75 * rloc.mean() + 0.25 * np.quantile(rloc, 0.8)

        if t >= up_until and pre > p["th_up"]:
            up_until = t + p["hold_up"]
        WU = t < up_until

        # Upper level: ring consensus between modules, plus (while UP is open)
        # attraction of each module phase toward its members' mean phase.
        diff_phi = phi[None, :] - phi[:, None]
        consensus = (Am * np.sin(diff_phi)).sum(1)

        if WU:
            z = np.exp(1j * theta)
            targets = np.array([np.angle(np.mean(z[idx])) for idx in mods])
            phi += dt * (p["Kphi"] * consensus) + dt / p["tau_phi"] * np.sin(targets - phi)
        else:
            phi += dt * (p["Kphi"] * consensus)

        phi = (phi + np.pi) % (2*np.pi) - np.pi
        rphi = order_param(phi)

        # Global relaxation: "relaxing" means the current tension is below its
        # smoothed value from the previous step.
        rg_now = order_param(theta)
        F = 1.0 - rg_now
        dF = F - F_s
        F_s = (1.0 - p["relax_alpha"]) * F_s + p["relax_alpha"] * F
        relajando = dF < 0

        # DOWN window
        if mode == "threshold":
            cond_down = rphi > p["th_down"]
        else:
            cond_down = (rphi > p["th_down"]) and relajando

        if t >= down_until and cond_down:
            down_until = t + p["hold_down"]
        WD = t < down_until

        if mode == "relax" and WD and not relajando:
            down_until = t
            WD = False

        # MEM window
        if mode == "threshold":
            cond_mem = WU and WD
        else:
            cond_mem = WU and WD and relajando

        if t >= mem_until and cond_mem:
            mem_until = t + p["hold_mem"]
        WM = t < mem_until

        if mode == "relax" and WM and not relajando:
            mem_until = t
            WM = False

        # Lower-level dynamics
        diff = theta[None, :] - theta[:, None]
        S = A0 * G

        if WD:
            field = np.zeros(N)
            for m, idx in enumerate(mods):
                field[idx] = phi[m]
            # The length-N factor broadcasts over the last axis, so column j
            # of S is scaled by the alignment of oscillator j with its module.
            S *= (1.0 + p["lam"] * np.cos(field - theta))
            S = np.clip(S, 0.0, None)

        # Coupling acts only while the DOWN window is open; otherwise S is all
        # zeros and the oscillators drift with their natural frequencies.
        S *= WD
        denom = max(S.sum(1).mean(), 1e-12)
        coup = (S * np.sin(diff)).sum(1) / denom

        dtheta = omega + p["K"] * coup
        theta += dt * dtheta
        theta += np.sqrt(dt) * p["noise"] * rng.normal(size=N)
        theta = (theta + np.pi) % (2*np.pi) - np.pi

        w_mem[s] = 1.0 if WM else 0.0

        if WM:
            # Plasticity on existing links. `diff` was computed before this
            # step's theta update.
            C = np.cos(diff)
            G += p["eta"] * (C - 0.2) * A0
            G -= p["forget"] * (G - A0)
            G = np.clip(G, *p["g_clip"])
            # Candidate end-of-episode snapshot (overwritten while MEM stays open).
            last_theta, last_phi = theta.copy(), phi.copy()
        elif prev_WM:
            # MEM just closed: the previous step was the episode's last one.
            end_snapshots[s - 1] = (last_theta, last_phi)
        prev_WM = WM

    # An episode still open at the end of learning ends on the last step.
    if prev_WM:
        end_snapshots[steps - 1] = (last_theta, last_phi)

    # Keep only episodes that are long enough, each with its own snapshot.
    mem_list = []
    eps = episodes_from_mask(w_mem, dt, min_len=p["min_mem_len"])
    for (a, b) in eps:
        snap_theta, snap_phi = end_snapshots[b]
        mem_list.append({
            "t_idx": b,
            "theta": snap_theta,
            "phi": snap_phi
        })

    return theta, phi, G, omega, mem_list, A0, mods, Am

# ============================================================
# Fase de estabilidad: congelamos memoria (G no cambia)
# ============================================================

def run_stability(theta, phi, G, omega, A0, mods, Am, seed, p):
    """
    Run the stability phase with frozen memory and a noise shock.

    The noise amplitude is multiplied by ``p["shock_mult"]`` for
    ``t_shock <= t < t_shock + shock_dur`` and is at its base level otherwise.
    The module phases ``phi`` follow ring consensus only (no targets), and the
    lower-level coupling is deliberately switched off, so ``theta`` evolves
    under ``omega`` and noise alone.

    NOTE(review): with the coupling switched off, ``G``, ``A0`` and ``mods``
    do not influence the result; they are kept for interface compatibility.

    Parameters
    ----------
    theta, phi : array-like
        Initial lower/upper-level phases. They are copied, never modified.
    G, A0 : ndarray
        Learned weights and base adjacency (only used if coupling is enabled).
    omega : array-like
        Natural frequencies of the lower-level oscillators.
    mods : list
        Module index arrays (unused here).
    Am : ndarray
        Adjacency between modules.
    seed : int
        The generator is seeded with ``seed + 10_000``.
    p : dict
        Keys read: N, dt, T_stab, t_shock, shock_dur, shock_mult, Kphi_stab,
        noise, log_every (and K only if coupling is enabled).

    Returns
    -------
    tuple of ndarray
        ``(t_axis, theta_t, phi_t)``: logged times and the corresponding
        ``theta`` / ``phi`` snapshots, one row every ``log_every`` steps.
    """
    rng = np.random.default_rng(seed + 10_000)
    N = p["N"]
    dt = p["dt"]
    steps = int(p["T_stab"] / dt)

    # Work on private copies so the caller's arrays are never modified.
    theta = np.array(theta, dtype=float)
    phi = np.array(phi, dtype=float)
    omega = np.asarray(omega, dtype=float)

    # Shock interval and base noise amplitude
    t_shock = p["t_shock"]
    shock_end = t_shock + p["shock_dur"]
    base_noise = np.sqrt(dt) * p["noise"]

    # Stability is probed WITHOUT active coupling: the question is whether the
    # pattern holds under natural frequencies and noise alone.
    WD = False

    theta_t = []
    phi_t = []
    t_axis = []

    for s in range(steps):
        t = s * dt
        noise_mult = p["shock_mult"] if (t_shock <= t < shock_end) else 1.0

        # Upper level: internal consensus only (no targets; memory is frozen)
        diff_phi = phi[None, :] - phi[:, None]
        consensus = (Am * np.sin(diff_phi)).sum(1)
        phi = phi + dt * (p["Kphi_stab"] * consensus)
        phi = (phi + np.pi) % (2*np.pi) - np.pi

        # Lower level: the O(N^2) coupling term is computed only if enabled;
        # when disabled it is identically zero.
        if WD:
            diff = theta[None, :] - theta[:, None]
            S = A0 * G
            denom = max(S.sum(1).mean(), 1e-12)
            coup = (S * np.sin(diff)).sum(1) / denom
            dtheta = omega + p["K"] * coup
        else:
            dtheta = omega

        theta = theta + dt * dtheta
        theta = theta + base_noise * noise_mult * rng.normal(size=N)
        theta = (theta + np.pi) % (2*np.pi) - np.pi

        if s % p["log_every"] == 0:
            theta_t.append(theta.copy())
            phi_t.append(phi.copy())
            t_axis.append(t)

    return np.array(t_axis), np.array(theta_t), np.array(phi_t)

# ============================================================
# Métrica de robustez
# ============================================================

def robustness_score(t_axis, theta_t, phi_t, mem_theta, mem_phi, p):
    """
    Compare similarity to the stored memory before and after the shock.

    For every logged snapshot the circular similarity of ``theta`` to
    ``mem_theta`` and of ``phi`` to ``mem_phi`` is blended with weight
    ``p["w_theta"]`` on the theta term. The blend is averaged over the
    ``p["avg_window"]`` time units right before the shock starts and right
    after it ends. If a window contains no samples, the first (pre) or last
    (post) five samples are averaged instead.

    Returns
    -------
    tuple of float
        ``(sim_pre, sim_post, ratio)`` with ``ratio = sim_post / sim_pre``,
        or 0.0 when ``sim_pre`` is not above 1e-12.
    """
    theta_sim = np.array([circ_similarity(row, mem_theta) for row in theta_t])
    phi_sim = np.array([circ_similarity(row, mem_phi) for row in phi_t])
    sim = p["w_theta"] * theta_sim + (1.0 - p["w_theta"]) * phi_sim

    shock_start = p["t_shock"]
    shock_stop = shock_start + p["shock_dur"]
    window = p["avg_window"]

    before = (t_axis >= (shock_start - window)) & (t_axis < shock_start)
    after = (t_axis >= shock_stop) & (t_axis < (shock_stop + window))

    if np.any(before):
        sim_pre = float(sim[before].mean())
    else:
        sim_pre = float(sim[:5].mean())

    if np.any(after):
        sim_post = float(sim[after].mean())
    else:
        sim_post = float(sim[-5:].mean())

    # Guard against dividing by a vanishing pre-shock similarity.
    if sim_pre > 1e-12:
        ratio = sim_post / sim_pre
    else:
        ratio = 0.0

    return sim_pre, sim_post, ratio

# ============================================================
# MAIN: A/B con seeds
# ============================================================

if __name__ == "__main__":

    # Experiment configuration shared by both modes.
    p = {
        # sizes and durations
        "N": 240, "M": 12, "dt": 0.01,
        "T_learn": 40.0,
        "T_stab": 30.0,

        # network and dynamics
        "k": 10, "K": 3.6,
        "omega_std": 0.6,
        "noise": 0.01,

        # learning windows
        "th_up": 0.28, "hold_up": 1.2,
        "th_down": 0.40, "hold_down": 1.2,
        "hold_mem": 1.5,
        "relax_alpha": 0.02,

        # modules
        "Kphi": 1.2, "tau_phi": 1.2,
        "Kphi_stab": 0.6,       # stability phase: slower modules

        # top-down modulation strength
        "lam": 0.7,

        # memory matrix G
        "eta": 0.002, "forget": 0.0008, "g_clip": (0.0, 3.0),

        # memory episode detection
        "min_mem_len": 0.5,     # minimum duration for a run to count as an episode

        # shock
        "t_shock": 10.0,
        "shock_dur": 1.0,
        "shock_mult": 8.0,      # noise multiplier during the shock

        # stability logging
        "log_every": 20,        # store a snapshot every 20 steps (0.2 with dt=0.01)

        # comparison
        "avg_window": 5.0,      # averaging window before/after the shock
        "w_theta": 0.7,         # weight of theta similarity versus phi similarity
    }

    seeds = range(5)  # switch to range(10) for more runs

    def one_mode(mode_name):
        """Run learning + stability for every seed in the given mode.

        Prints one line per seed and returns an array whose rows are
        ``(sim_pre, sim_post, ratio)``.
        """
        label = mode_name.upper()
        rows = []
        for seed in seeds:
            theta, phi, G, omega, mems, A0, mods, Am = run_learning(mode_name, seed, p)

            if mems:
                # For simplicity, use the most recent detected memory.
                latest = mems[-1]
                mem_theta, mem_phi = latest["theta"], latest["phi"]
            else:
                # No episode was detected: fall back to the final learning state.
                mem_theta, mem_phi = theta.copy(), phi.copy()

            t_axis, theta_t, phi_t = run_stability(
                theta.copy(), phi.copy(), G.copy(), omega.copy(),
                A0, mods, Am, seed, p,
            )
            scores = robustness_score(t_axis, theta_t, phi_t, mem_theta, mem_phi, p)
            rows.append(scores)

            sim_pre, sim_post, ratio = scores
            print(f"{label} seed {seed}: sim_pre={sim_pre:.3f}, sim_post={sim_post:.3f}, ratio={ratio:.3f}")
        return np.array(rows)

    print("\n=== ESTABILIDAD A (threshold) ===")
    A = one_mode("threshold")

    print("\n=== ESTABILIDAD B (relax) ===")
    B = one_mode("relax")

    print("\n======== PROMEDIOS ========")
    for tag, res in (("A", A), ("B", B)):
        print(f"{tag} sim_pre  mean={res[:,0].mean():.3f} | sim_post mean={res[:,1].mean():.3f} | ratio mean={res[:,2].mean():.3f}")

    # Relative robustness of mode B versus mode A (denominator floored).
    ratio_a = A[:, 2].mean()
    ratio_b = B[:, 2].mean()
    print(f"Ratio robustez (B/A) = {ratio_b / max(ratio_a, 1e-12):.3f}")
