#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import mpmath as mp
import time

# ----------------------------------------------------------------------
# High‑precision configuration (4000 digits)
# ----------------------------------------------------------------------
# NOTE: this sets the precision of mpmath's default context (mp.mp) at import
# time, so it affects every computation in the process that uses that context,
# not only the functions in this module.
mp.mp.dps = 4000
# mpmath's pi constant; it is evaluated at the working precision in effect
# when it is used in arithmetic.
PI = mp.pi

# ----------------------------------------------------------------------
# Analytic Fourier transform of k_H(t) = (6/H^2) sech^4(t/H)
# ----------------------------------------------------------------------
def k_hat_H(xi, H):
    r"""Exact Fourier transform \hat k_H(ξ) of k_H(t) = (6/H²) sech⁴(t/H).

    With the convention \hat f(ξ) = ∫ f(t) e^{-2πiξt} dt,

        \hat k_H(ξ) = 8π²ξ(π²H²ξ² + 1) / sinh(π²Hξ),

    and the removable singularity at ξ = 0 (where the closed form is 0/0) is
    filled in with its limit \hat k_H(0) = 8/H (= ∫ k_H(t) dt).

    Args:
        xi: frequency ξ (real; the result is even in ξ).
        H: kernel width (presumably H > 0 — TODO confirm callers; H = 0
            divides by zero in both branches).

    Returns:
        \hat k_H(ξ) as an mpmath number at the current working precision.
    """
    # The closed form is 0/0 at ξ = 0; return the analytic limit instead.
    if xi == 0:
        return mp.mpf(8) / H
    a = PI**2 * H * xi
    num = mp.mpf(8) * PI**2 * xi * (PI**2 * H**2 * xi**2 + 1)
    return num / mp.sinh(a)

def k_hat_H_at_zero(H):
    """Return the zero-frequency value 8/H of the kernel transform (its total mass)."""
    eight = mp.mpf(8)
    return eight / H

# ----------------------------------------------------------------------
# Exact Q_H via Fourier series – no numerical integration in t
# ----------------------------------------------------------------------
def Q_H(N, H, T0):
    r"""
    Q_H(N,H,T0) = ∫ k_H(t) |Z_N(1/2 + i(T0+2πt))|² dt,   Z_N(s) = Σ_{n≤N} n^{-s}.

    Expanding |Z_N|² and integrating term by term, with the transform
    convention used by k_hat_H (\hat f(ξ) = ∫ f(t) e^{-2πiξt} dt), the factor
    e^{-i·2πt(log n - log m)} integrates to \hat k_H(log n - log m), so

        Q_H = Σ_{n,m≤N} (nm)^(-1/2) e^{-iT0(log n - log m)} \hat k_H(log n - log m).

    This is the quadratic form u* K u of K_matrix(N, H) with u_n = n^{iT0}.
    Because \hat k_H is real and even, the (n, m) and (m, n) terms are complex
    conjugates, so the sum is real and equals

        Σ_n \hat k_H(0)/n + 2 Σ_{n<m} (nm)^(-1/2) cos(T0 log(n/m)) \hat k_H(log(n/m)).

    The Fourier representation is exact; the only truncation is in N.

    Returns:
        Q_H as a real mpmath number (0 when N <= 0).
    """
    if N <= 0:
        return mp.mpf(0)
    logs = [mp.log(mp.mpf(n)) for n in range(1, N + 1)]
    inv_sqrt = [1 / mp.sqrt(n) for n in range(1, N + 1)]

    # Diagonal n = m: zero frequency and zero phase.
    total = k_hat_H(0, H) * mp.fsum(w * w for w in inv_sqrt)

    # Off-diagonal n < m; the factor 2 accounts for the conjugate (m, n) term.
    # The frequency is log(n/m), not log(n/m)/(2π): the 2π already sits in the
    # argument T0 + 2πt of Z_N, and this keeps Q_H consistent with K_matrix.
    for i in range(N):
        for j in range(i + 1, N):
            d = logs[i] - logs[j]
            total += 2 * inv_sqrt[i] * inv_sqrt[j] * mp.cos(T0 * d) * k_hat_H(d, H)
    return total

# ----------------------------------------------------------------------
# Matrix K_H (N×N) using analytic \hat k_H
# ----------------------------------------------------------------------
def K_matrix(N, H):
    r"""Build the N×N kernel matrix from the analytic \hat k_H.

    K_{ij} = \hat k_H(log i - log j) / sqrt(i·j),   1 ≤ i, j ≤ N.

    \hat k_H is even in its argument, so K is symmetric. Only the upper
    triangle (diagonal included) is evaluated and then mirrored. This halves
    the number of off-diagonal \hat k_H (sinh) evaluations and makes K exactly
    symmetric, as the symmetric eigensolver used downstream expects.

    Returns:
        mp.matrix with N rows and N columns.
    """
    logs = [mp.log(mp.mpf(n)) for n in range(1, N + 1)]
    inv_sqrt = [1 / mp.sqrt(n) for n in range(1, N + 1)]
    # mpf values are immutable, so sharing the zero object inside a row is safe.
    mat = [[mp.mpf(0)] * N for _ in range(N)]
    for i in range(N):
        for j in range(i, N):
            value = k_hat_H(logs[i] - logs[j], H) * inv_sqrt[i] * inv_sqrt[j]
            mat[i][j] = value
            mat[j][i] = value
    return mp.matrix(mat)

# ----------------------------------------------------------------------
# Largest eigenvalue – eigsy with power iteration fallback
# ----------------------------------------------------------------------
def lambda_max(K):
    """
    Return the largest eigenvalue of the symmetric matrix K.

    The dense symmetric solver mp.eigsy is tried first, and the algebraically
    largest of its eigenvalues is returned. Any exception raised while doing
    so is reported on stdout, and the power-iteration fallback is used
    instead. That fallback yields the eigenvalue of largest magnitude, which
    coincides with the largest eigenvalue when K is positive semidefinite
    (presumably true for the kernel matrices in this file — TODO confirm).
    """
    try:
        return max(mp.eigsy(K, eigvals_only=True))
    except Exception as err:  # deliberate best-effort: any failure -> fallback
        print(f"eigsy failed ({err}); falling back to power iteration")
        return _lambda_max_power(K)

def _lambda_max_power(K):
    """
    Power method with Rayleigh quotient.

    Uses a relative tolerance based on mp.mp.dps. Still expensive at 4000 dps,
    but converges for reasonable N.
    """
    N = K.rows
    # Ask for ~ (dps - 20) digits, not the full dps (more realistic)
    tol = mp.mpf(10) ** (-(mp.mp.dps - 20))
    max_iter = 10000

    v = [mp.mpf(1) for _ in range(N)]
    norm = mp.sqrt(sum(x * x for x in v))
    v = [x / norm for x in v]

    lam_old = mp.mpf(0)
    for it in range(max_iter):
        # w = K * v
        w = [mp.mpf(0) for _ in range(N)]
        for i in range(N):
            s = mp.mpf(0)
            row_i = K[i, :]
            for j in range(N):
                s += row_i[j] * v[j]
            w[i] = s

        # Rayleigh quotient
        lam = sum(vi * wi for vi, wi in zip(v, w))

        if it > 0:
            if abs(lam - lam_old) <= tol * max(mp.mpf(1), abs(lam)):
                return lam
        lam_old = lam

        norm_w = mp.sqrt(sum(x * x for x in w))
        if norm_w == 0:
            break
        v = [x / norm_w for x in w]

    print("Warning: power iteration did not converge within max_iter")
    return lam_old

# ----------------------------------------------------------------------
# LOCK functional and stationarity probe
# ----------------------------------------------------------------------
def lock_functional(N, H, T0):
    """Evaluate F(H, N, T0) = λ_max(K_H) · Q_H(N, H, T0).

    Prints the elapsed time spent building K_H, computing λ_max, and
    evaluating Q_H. Timing uses time.perf_counter, a monotonic
    high-resolution clock meant for measuring intervals.

    Returns:
        Tuple (F, lam, Q): the product, λ_max(K_matrix(N, H)) and Q_H(N, H, T0).
    """
    t0 = time.perf_counter()
    K = K_matrix(N, H)
    t1 = time.perf_counter()
    lam = lambda_max(K)
    t2 = time.perf_counter()
    Q = Q_H(N, H, T0)
    t3 = time.perf_counter()
    print(f"  Timings: K {t1-t0:.2f}s, eig {t2-t1:.2f}s, Q {t3-t2:.2f}s")
    return lam * Q, lam, Q

def finite_difference_stationarity(N, H, T0, dT):
    """Central finite differences of F(T0) = λ_max(K_H) · Q_H(N, H, T0) w.r.t. T0.

    Uses the three-point stencil T0 - dT, T0, T0 + dT:

        F'(T0)  ≈ (F(T0+dT) - F(T0-dT)) / (2 dT)
        F''(T0) ≈ (F(T0+dT) - 2 F(T0) + F(T0-dT)) / dT²

    Both estimates carry O(dT²) truncation error.

    K_H and λ_max(K_H) depend only on (N, H), not on T0. They are therefore
    computed once and shared by all three evaluations, and only Q_H is
    recomputed. The F values are the same as three independent
    lock_functional calls would give, but the matrix build and eigenvalue
    solve are not repeated.

    Returns:
        dict with keys "F(T0-dT)", "F(T0)", "F(T0+dT)", "F_prime_approx",
        "F_second_approx", plus "lambda_max" and "Q_H" (the latter at T0).
    """
    print("Building K_H and lambda_max (independent of T0)...")
    t_start = time.perf_counter()
    lam = lambda_max(K_matrix(N, H))
    print(f"  Timings: K+eig {time.perf_counter() - t_start:.2f}s")

    def _evaluate(T, label):
        """Return (F, Q) at T, reusing the shared λ_max."""
        print(f"Evaluating {label}...")
        t_q = time.perf_counter()
        Q = Q_H(N, H, T)
        print(f"  Timings: Q {time.perf_counter() - t_q:.2f}s")
        return lam * Q, Q

    Fm, _ = _evaluate(T0 - dT, "F(T0-dT)")
    F0, Q0 = _evaluate(T0, "F(T0)")
    Fp, _ = _evaluate(T0 + dT, "F(T0+dT)")

    Fprime = (Fp - Fm) / (2 * dT)
    Fsecond = (Fp - 2 * F0 + Fm) / (dT * dT)

    return {
        "F(T0-dT)": Fm,
        "F(T0)": F0,
        "F(T0+dT)": Fp,
        "F_prime_approx": Fprime,
        "F_second_approx": Fsecond,
        "lambda_max": lam,
        "Q_H": Q0,
    }

# ----------------------------------------------------------------------
# Main driver
# ----------------------------------------------------------------------
def main():
    """Run the LOCK stationarity probe at fixed settings and print a report.

    Settings: H = 1, N = 20, T0 = 0, dT = 1e-3, at the module's working
    precision. The report contains:
    - the settings and k_hat_H(0);
    - the finite-difference results, to 60 significant digits;
    - a verdict that declares a LOCK when |F'(T0)| / max(1, |F(T0)|) < 1e-50,
      classified by the sign of F''(T0).

    NOTE(review): the central difference for F' has O(dT²) truncation error.
    With dT = 1e-3, the 1e-50 threshold can therefore only be met when
    F(T0+dT) - F(T0-dT) cancels (essentially) exactly. One such case is
    T0 = 0, where Q_H depends on T0 only through cos(T0·log(n/m)) and is
    even in T0. A "LOCK CONFIRMED" at T0 = 0 reflects that symmetry.
    """
    H, N = mp.mpf("1.0"), 20
    T0, dT = mp.mpf("0.0"), mp.mpf("1e-3")

    print("=== Settings ===")
    print(f"mp.dps      = {mp.mp.dps}")
    print(f"H           = {H}")
    print(f"N           = {N}")
    print(f"T0          = {T0}")
    print(f"dT          = {dT}")
    print()
    print(f"k_hat_H(0)  = {k_hat_H_at_zero(H)}")
    print()

    print("=== HPH Lock functional and stationarity in T0 ===")
    started = time.time()
    results = finite_difference_stationarity(N, H, T0, dT)
    elapsed = time.time() - started
    print(f"\nTotal time: {elapsed:.2f} s\n")

    # Show every returned quantity to 60 significant digits.
    for label, value in results.items():
        print(f"{label:>18} = {mp.nstr(value, 60)}")

    F_center = results["F(T0)"]
    slope = results["F_prime_approx"]
    curvature = results["F_second_approx"]

    rel_grad = abs(slope) / max(mp.mpf(1), abs(F_center))
    print("\nLOCK diagnostics:")
    print(f"  Relative gradient |F'(T0)| / max(1, |F(T0)|) = {mp.nstr(rel_grad, 30)}")
    print(f"  Second derivative F''(T0)                    = {mp.nstr(curvature, 30)}")

    # Threshold for declaring a numerical LOCK; tune as needed.
    thresh = mp.mpf("1e-50")
    if not rel_grad < thresh:
        print("\n>>> HPH LOCK NOT CONFIRMED at this tolerance.")
        print("    Consider adjusting dT or mp.dps, or scanning other T0.")
        return

    lock_type = (
        "local maximum" if curvature < 0
        else "local minimum" if curvature > 0
        else "flat / saddle"
    )
    print("\n>>> HPH LOCK CONFIRMED (numerically):")
    print(f"    T0 = {T0} is a stationary point ({lock_type}) of F(H,N,T0).")

# Script entry point: run the driver only when executed directly, not on import.
if __name__ == "__main__":
    main()