#!/usr/bin/env python3
"""
Gate-2C (phase-current / torque proxy) — reference re-run.

Produces:
  - gate2c_angles.csv (phi_deg, x, y, Ax, Ay, tau_z)
  - gate2c_run_summary.txt

Defaults mirror locked tracker settings: d_ref=10, n_angles=36.

This is provided for reproducibility. It re-implements the Gate-2C class
of calculation (ellipse-tangential current -> Helmholtz solve -> sampling).

Exact bit-for-bit matching with earlier bundles isn't guaranteed
(grid/FFT conventions), but definitions and parameter semantics match.
"""
import argparse
import math
from pathlib import Path

import numpy as np


def fft_helmholtz_solve(src: np.ndarray, m: float, L: float) -> np.ndarray:
    """Solve (∇^2 - m^2) u = -src on a periodic square domain using FFT.

    In Fourier space the equation reads ``(kx^2 + ky^2 + m^2) u_hat = src_hat``,
    so every mode of the source is divided by ``kx^2 + ky^2 + m^2``.

    Args:
        src: Square 2-D array holding the source on an N x N grid.
        m: Coefficient ``m`` of the operator ``(∇^2 - m^2)``.
        L: Side length of the periodic domain; the grid spacing is ``L / N``.

    Returns:
        Real N x N array with the solution. The k = 0 (mean) mode is always
        set to zero, so the returned field has zero mean even when ``m != 0``.

    Raises:
        ValueError: If ``src`` is not a square 2-D array.
    """
    if src.ndim != 2 or src.shape[0] != src.shape[1]:
        raise ValueError(f"src must be a square 2-D array, got shape {src.shape}")

    N = src.shape[0]
    # Angular wavenumbers matching numpy's FFT mode ordering.
    k = 2.0 * math.pi * np.fft.fftfreq(N, d=L / N)
    kx, ky = np.meshgrid(k, k, indexing="ij")
    denom = kx * kx + ky * ky + m * m

    # The k = 0 mode is discarded below, so give it a harmless divisor. This
    # avoids a division by zero (RuntimeWarning, NaN/inf) when m == 0 and does
    # not change the result for any m.
    denom[0, 0] = 1.0

    u_hat = np.fft.fft2(src) / denom
    u_hat[0, 0] = 0.0
    return np.real(np.fft.ifft2(u_hat))


def build_current(N: int, L: float, a: float, b: float, J0: float, width: float):
    """Build a tangential current density peaked on an ellipse.

    The grid has N points per axis on ``[-L/2, L/2)``. The current direction
    is the unit tangent to the level sets of ``s = x^2/a^2 + y^2/b^2``; its
    magnitude is ``J0`` times a Gaussian in ``s`` centred on ``s = 1`` with
    standard deviation ``width``.

    Returns:
        Tuple ``(Jx, Jy)`` of N x N arrays indexed as ``[ix, iy]``.
    """
    axis = np.linspace(-L / 2, L / 2, N, endpoint=False)
    grid_x, grid_y = np.meshgrid(axis, axis, indexing="ij")

    # Gaussian envelope in the level-set value, centred on the ellipse s = 1.
    level = (grid_x / a) ** 2 + (grid_y / b) ** 2
    envelope = np.exp(-0.5 * ((level - 1.0) / width) ** 2)

    # grad(s) = (2x/a^2, 2y/b^2); rotating it by +90 degrees gives the
    # tangent direction (-grad_y, grad_x).
    grad_x = 2.0 * grid_x / (a * a)
    grad_y = 2.0 * grid_y / (b * b)
    # The tiny offset keeps the normalisation finite at the origin, where the
    # gradient vanishes.
    length = np.sqrt(grad_y * grad_y + grad_x * grad_x) + 1e-30
    unit_tx = -grad_y / length
    unit_ty = grad_x / length

    return (J0 * envelope * unit_tx, J0 * envelope * unit_ty)


def sample_bilinear(field: np.ndarray, x: float, y: float, L: float) -> float:
    """Interpolate ``field`` bilinearly at ``(x, y)`` with periodic wrap.

    Grid index 0 corresponds to the physical coordinate ``-L/2`` on both axes
    and the spacing is ``L / N`` with ``N = field.shape[0]``. Indices are
    taken modulo N, so points outside ``[-L/2, L/2)`` wrap around.
    """
    n = field.shape[0]
    spacing = L / n

    # Fractional grid coordinates of the sample point.
    pos_x = (x + L / 2) / spacing
    pos_y = (y + L / 2) / spacing
    cell_x = math.floor(pos_x)
    cell_y = math.floor(pos_y)
    frac_x = pos_x - cell_x
    frac_y = pos_y - cell_y

    # Lower and upper corner indices of the enclosing cell, wrapped.
    lo_i, lo_j = int(cell_x) % n, int(cell_y) % n
    hi_i, hi_j = (lo_i + 1) % n, (lo_j + 1) % n

    w00 = (1 - frac_x) * (1 - frac_y)
    w10 = frac_x * (1 - frac_y)
    w01 = (1 - frac_x) * frac_y
    w11 = frac_x * frac_y
    return w00 * field[lo_i, lo_j] + w10 * field[hi_i, lo_j] + w01 * field[lo_i, hi_j] + w11 * field[hi_i, hi_j]


def run(N: int, L: float, a: float, b: float, m: float, J0: float, width: float, n_angles: int, d_ref: float, outdir: Path) -> None:
    """Run the Gate-2C calculation and write the CSV and summary files.

    Steps: build the ellipse-tangential current, solve the Helmholtz problem
    for each component to obtain ``(Ax, Ay)``, then sample both fields at
    ``n_angles`` points ``(d_ref*a*cos(phi), d_ref*b*sin(phi))`` and compute
    ``tau_z = x*Ay - y*Ax`` at each point.

    Sample points are not clipped to the domain: coordinates outside
    ``[-L/2, L/2)`` are wrapped periodically by ``sample_bilinear``.

    Args:
        N: Grid points per axis.
        L: Side length of the periodic domain.
        a: Ellipse semi-axis along x.
        b: Ellipse semi-axis along y.
        m: Coefficient ``m`` passed to the Helmholtz solve.
        J0: Current amplitude.
        width: Width of the current envelope around the ellipse.
        n_angles: Number of equally spaced sampling angles in [0, 2*pi).
        d_ref: Scale factor applied to the ellipse to get the sampling curve.
        outdir: Output directory; created if missing.

    Raises:
        ValueError: If ``n_angles`` is less than 1 (the statistics in the
            summary are undefined for an empty sample set).

    Side effects:
        Writes ``gate2c_angles.csv`` and ``gate2c_run_summary.txt`` to
        ``outdir``.
    """
    if n_angles < 1:
        raise ValueError(f"n_angles must be >= 1, got {n_angles}")

    outdir.mkdir(parents=True, exist_ok=True)

    Jx, Jy = build_current(N=N, L=L, a=a, b=b, J0=J0, width=width)
    Ax = fft_helmholtz_solve(Jx, m=m, L=L)
    Ay = fft_helmholtz_solve(Jy, m=m, L=L)

    rows = []
    for k in range(n_angles):
        phi = 2.0 * math.pi * k / n_angles
        x = d_ref * a * math.cos(phi)
        y = d_ref * b * math.sin(phi)
        ax = sample_bilinear(Ax, x, y, L)
        ay = sample_bilinear(Ay, x, y, L)
        # z-component of the 2-D cross product (x, y) x (Ax, Ay).
        tau_z = x * ay - y * ax
        rows.append((phi * 180 / math.pi, x, y, ax, ay, tau_z))

    csv_path = outdir / "gate2c_angles.csv"
    np.savetxt(csv_path, np.array(rows), delimiter=",", header="phi_deg,x,y,Ax,Ay,tau_z", comments="")

    tau = np.array([r[-1] for r in rows])
    # Each line ends with an explicit "\n" escape; a plain quoted string
    # literal cannot span physical lines.
    summary = (
        "Gate-2C re-run summary\n"
        f"N={N}, L={L}, a={a}, b={b}, m={m}, J0={J0}, width={width}, n_angles={n_angles}, d_ref={d_ref}\n"
        f"tau_z: min={tau.min():.6e}, median={np.median(tau):.6e}, max={tau.max():.6e}, RMS={math.sqrt(np.mean(tau*tau)):.6e}\n"
        f"Wrote: {csv_path.name}\n"
    )
    (outdir / "gate2c_run_summary.txt").write_text(summary, encoding="utf-8")


def main() -> None:
    """Parse the command-line options and launch the Gate-2C run."""
    parser = argparse.ArgumentParser()

    # (flag name, value type, default) in the order shown by --help.
    option_specs = (
        ("N", int, 512),
        ("L", float, 60.0),
        ("a", float, 6.0),
        ("b", float, 4.0),
        ("m", float, 0.35),
        ("J0", float, 1.0),
        ("width", float, 0.08),
        ("n_angles", int, 36),
        ("d_ref", float, 10.0),
        ("outdir", str, "out_gate2c"),
    )
    for flag, kind, default in option_specs:
        parser.add_argument(f"--{flag}", type=kind, default=default)

    # Every flag name matches a keyword parameter of run(); only the output
    # directory needs converting from str to Path.
    options = dict(vars(parser.parse_args()))
    options["outdir"] = Path(options["outdir"])
    run(**options)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
