"""
LFM Mercury Perihelion Precession - PROPER LATTICE SIMULATION
==============================================================

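METHODOLOGY (following SELF_IMPROVEMENT.md rules):
- ONLY GOV-01 and GOV-02 are used to evolve the lattice
- NO derived forces, NO Newtonian gravity, NO shortcuts
- The outcome is MEASURED from the wave packet trajectory
- Numerical constraints applied on top of the two equations: E is pinned to
  the Sun amplitude inside the Sun region every step, and χ is clipped to
  [0.01, 10] for stability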

THE ONLY EQUATIONS USED (E is complex, so the source term uses |E|²):
  GOV-01: ∂²E/∂t² = c²∇²E − χ²E
  GOV-02: ∂²χ/∂t² = c²∇²χ − κ(|E|² − E₀²)

Discrete leapfrog form:
  E^{t+1} = 2E^t − E^{t−1} + (Δt)²[c²∇²E^t − (χ^t)²E^t]
  χ^{t+1} = 2χ^t − χ^{t−1} + (Δt)²[c²∇²χ^t − κ(|E^t|² − E₀²)]

WHAT WE MEASURE:
- Center of mass of the |E|² distribution (Sun region excluded) over time
- Perihelion angles (local minima of the distance from the origin) along the trajectory
- Precession rate per orbit

Author: LFM Research
Date: 2026-02-05
Methodology: PURE LATTICE EVOLUTION - no derived physics
"""

import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Tuple, List
import json
from pathlib import Path

# Natural units
c = 1.0


@dataclass
class LatticeParams:
    """Simulation parameters for the 2D lattice.

    Lengths and times are in natural units (c = 1). ``dx`` is derived from
    ``L`` and ``N``; ``CFL_limit`` is the stability bound on ``dt`` for the
    explicit leapfrog update with a 5-point Laplacian in two dimensions.
    """
    # Grid parameters
    N: int = 512                    # Grid points per dimension
    L: float = 300.0                # Physical size of domain

    # LFM parameters
    chi_0: float = 1.0              # Background χ value
    kappa: float = 1.0              # Coupling constant in GOV-02
    E0_squared: float = 0.0         # Background E² (vacuum = 0)

    # Central mass (Sun) - represented as fixed χ source
    # We model the Sun as a region where E² is permanently high
    sun_radius: float = 5.0         # Radius of Sun region
    sun_E_amplitude: float = 10.0   # E amplitude in Sun (creates χ well)

    # Wave packet (Mercury) parameters
    packet_radius: float = 3.0      # Size of wave packet
    packet_amplitude: float = 1.0   # E amplitude of packet

    # Orbit initial conditions
    orbit_radius: float = 100.0     # Initial distance from center
    orbit_velocity: float = 0.1     # Initial tangential velocity

    # Time evolution
    dt: float = 0.1                 # Timestep (2D CFL: dt < dx/(c*sqrt(2)))
    n_steps: int = 100000           # Total steps
    measure_interval: int = 100     # Steps between COM measurements

    @property
    def dx(self) -> float:
        """Grid spacing L/N."""
        return self.L / self.N

    @property
    def CFL_limit(self) -> float:
        """Stability bound on dt for 2D leapfrog with a 5-point Laplacian.

        Von Neumann analysis of the explicit scheme gives c*dt/dx <= 1/sqrt(2)
        in two dimensions; the 1D bound dx/c is too loose here. The χ²E mass
        term in GOV-01 tightens the bound further.
        """
        return self.dx / (c * 2.0 ** 0.5)

    def validate(self):
        """Check dt positivity, the 2D CFL condition and the measure interval.

        Prints a confirmation line when all checks pass.

        Raises:
            ValueError: if ``dt`` is not positive, ``dt`` is at or above
                ``CFL_limit``, or ``measure_interval`` is less than 1.
        """
        if self.dt <= 0:
            raise ValueError(f"dt must be positive, got dt={self.dt}")
        if self.dt >= self.CFL_limit:
            raise ValueError(
                f"CFL violation: dt={self.dt} >= dx/(c*sqrt(2))={self.CFL_limit}"
            )
        if self.measure_interval < 1:
            # run() takes `step % measure_interval`; 0 would divide by zero.
            raise ValueError(
                f"measure_interval must be >= 1, got {self.measure_interval}"
            )
        print(f"CFL check passed: dt={self.dt} < dx/(c*sqrt(2))={self.CFL_limit:.4f}")


class LFMLattice:
    """
    2D periodic lattice evolving ONLY via GOV-01 and GOV-02.

    E is stored as a COMPLEX array; the Klein-Gordon equation
      ∂²E/∂t² = c²∇²E − χ²E
    is applied to E = E_r + i*E_i component-wise. χ is a REAL array.

    Besides the two governing equations the update applies two numerical
    constraints: E is pinned to ``sun_E_amplitude`` inside the Sun disc on
    every step, and χ is clipped to [0.01, 10.0].

    The Laplacian wraps around with ``np.roll``, so the domain is periodic
    with period L; the coordinate grid is laid out to match (spacing exactly
    ``params.dx``, no duplicated end point).
    """

    def __init__(self, params: LatticeParams):
        """Validate ``params``, build the grid, place Sun and Mercury, warm up χ.

        Raises:
            ValueError: propagated from ``params.validate()``.
        """
        self.p = params
        params.validate()

        # Periodic grid of N points spaced exactly dx = L/N. With
        # endpoint=False the coordinate spacing equals the dx used by
        # _laplacian, and for even N index N//2 lies exactly at the origin.
        x = np.linspace(-params.L / 2, params.L / 2, params.N, endpoint=False)
        y = np.linspace(-params.L / 2, params.L / 2, params.N, endpoint=False)
        self.X, self.Y = np.meshgrid(x, y)
        self.R = np.sqrt(self.X**2 + self.Y**2)
        self.PHI = np.arctan2(self.Y, self.X)  # Angle from origin

        # E field: COMPLEX, current and previous time level
        self.E = np.zeros((params.N, params.N), dtype=complex)
        self.E_prev = np.zeros((params.N, params.N), dtype=complex)

        # χ field: REAL, current and previous time level (start at background)
        self.chi = np.full((params.N, params.N), params.chi_0, dtype=float)
        self.chi_prev = np.full((params.N, params.N), params.chi_0, dtype=float)

        # Fixed Sun source (permanently high |E|² inside sun_radius)
        self.sun_mask = self.R < params.sun_radius
        self.E_sun = np.zeros((params.N, params.N), dtype=complex)
        self.E_sun[self.sun_mask] = params.sun_E_amplitude

        # Initialize Mercury wave packet
        self._init_wave_packet()

        # Tracking
        self.t = 0.0
        self.step_count = 0
        self.trajectory: List[Tuple[float, float, float]] = []

        # Let χ respond to the Sun source before Mercury is evolved
        print("Letting χ field equilibrate with Sun source...")
        self._equilibrate_chi(n_steps=1000)

    def _init_wave_packet(self):
        """
        Initialize Mercury as a Gaussian-enveloped, real-valued wave packet.

        The packet is centred at (orbit_radius, 0). Its carrier is
        cos(m*φ_local + k_y*(y − y0)), where φ_local is the angle around the
        packet centre (m = 5) and k_y = 10 * orbit_velocity adds a phase
        gradient along +y. E_prev holds the same pattern phase-shifted by
        ω*dt, so the leapfrog scheme starts with a non-zero time derivative.

        The arrays written here are real; they become complex when the Sun
        field is added in ``_equilibrate_chi``.
        """
        # Packet centre: on the +x axis at the orbit radius
        x0 = self.p.orbit_radius
        y0 = 0.0

        # Offsets from the packet centre
        dx = self.X - x0
        dy = self.Y - y0

        # Gaussian envelope centred at (x0, y0)
        envelope = np.exp(-(dx**2 + dy**2) / (2 * self.p.packet_radius**2))

        # Azimuthal winding around the packet centre (hand-tuned value)
        m = 5
        phase = m * np.arctan2(dy, dx)

        # Phase gradient along +y, the tangential direction at (x0, 0)
        k_y = self.p.orbit_velocity * 10
        # NOTE(review): ω is set equal to k_y (massless dispersion). For
        # GOV-01 with χ ≈ chi_0 a plane wave needs ω² = c²k² + χ², so this is
        # not an exact GOV-01 plane-wave initial condition — confirm intent.
        omega = self.p.orbit_velocity * 10

        carrier = phase + k_y * dy
        self.E = self.p.packet_amplitude * envelope * np.cos(carrier)
        # E ~ cos(θ − ω t) evaluated at t = −dt gives cos(θ + ω dt)
        self.E_prev = self.p.packet_amplitude * envelope * np.cos(carrier + omega * self.p.dt)

    def _equilibrate_chi(self, n_steps: int):
        """Evolve χ alone under the static Sun source, then add Mercury back.

        For ``n_steps`` steps only GOV-02 is advanced, with E held at the Sun
        field. Afterwards E and E_prev are the Mercury packet plus the Sun
        field. ``t`` and ``step_count`` are not advanced.

        NOTE(review): GOV-02 has no damping term, so χ is not guaranteed to
        settle to a steady state within ``n_steps``; treat this as a warm-up.
        """
        # Temporarily store Mercury E field
        E_mercury = self.E.copy()
        E_mercury_prev = self.E_prev.copy()

        # Only Sun source
        self.E = self.E_sun.copy()
        self.E_prev = self.E_sun.copy()

        # Evolve χ only (E is the static Sun field)
        for _ in range(n_steps):
            self._step_chi_only()

        # Restore Mercury on top of the Sun field
        self.E = E_mercury + self.E_sun
        self.E_prev = E_mercury_prev + self.E_sun

        centre = self.p.N // 2
        print(f"χ equilibrated. χ at center: {self.chi[centre, centre]:.4f}")
        orb_idx = centre + int(self.p.orbit_radius / self.p.dx)
        if orb_idx < self.p.N:
            print(f"χ at orbit radius: {self.chi[centre, orb_idx]:.4f}")

    def _step_chi_only(self):
        """Advance χ by one leapfrog step of GOV-02 with E held fixed."""
        # GOV-02: ∂²χ/∂t² = c²∇²χ − κ(|E|² − E₀²), using |E|² for complex E
        laplacian_chi = self._laplacian(self.chi)
        E_squared = np.abs(self.E)**2

        chi_next = (2 * self.chi - self.chi_prev +
                    self.p.dt**2 * (c**2 * laplacian_chi -
                                    self.p.kappa * (E_squared - self.p.E0_squared)))

        # Same stability clamp as the full step
        chi_next = np.clip(chi_next, 0.01, 10.0)

        self.chi_prev = self.chi
        self.chi = chi_next

    def _laplacian(self, field: np.ndarray) -> np.ndarray:
        """Return the periodic 5-point discrete Laplacian of ``field``."""
        # ∇²f = (f[i+1,j] + f[i-1,j] + f[i,j+1] + f[i,j-1] - 4f[i,j]) / dx²
        # np.roll wraps at the edges, i.e. periodic boundary conditions.
        dx2 = self.p.dx**2

        return (np.roll(field, 1, axis=0) + np.roll(field, -1, axis=0) +
                np.roll(field, 1, axis=1) + np.roll(field, -1, axis=1) -
                4 * field) / dx2

    def step(self):
        """
        Evolve BOTH E and χ fields by one timestep.

        ONLY these equations are used:
          GOV-01: E^{t+1} = 2E^t − E^{t−1} + (Δt)²[c²∇²E^t − (χ^t)²E^t]
          GOV-02: χ^{t+1} = 2χ^t − χ^{t−1} + (Δt)²[c²∇²χ^t − κ(|E^t|² − E₀²)]

        E is COMPLEX. χ is REAL. After the update, E is pinned inside the Sun
        disc and χ is clipped to [0.01, 10.0].

        Returns:
            True if the step was applied; False (fields left unchanged) if the
            new fields contain NaN or inf.
        """
        dt2 = self.p.dt**2

        # Compute Laplacians (complex for E, real for χ)
        laplacian_E = self._laplacian(self.E)
        laplacian_chi = self._laplacian(self.chi)

        # GOV-01: E evolution (complex Klein-Gordon)
        E_next = (2 * self.E - self.E_prev +
                  dt2 * (c**2 * laplacian_E - self.chi**2 * self.E))

        # GOV-02: χ evolution sourced by |E|²
        E_squared = np.abs(self.E)**2
        chi_next = (2 * self.chi - self.chi_prev +
                    dt2 * (c**2 * laplacian_chi -
                           self.p.kappa * (E_squared - self.p.E0_squared)))

        # Keep Sun region fixed (boundary condition)
        E_next[self.sun_mask] = self.p.sun_E_amplitude

        # Prevent χ from going negative or too large (stability)
        chi_next = np.clip(chi_next, 0.01, 10.0)

        # Stop on any blow-up: inf overflow is caught as well as NaN
        if not (np.all(np.isfinite(E_next)) and np.all(np.isfinite(chi_next))):
            print(f"WARNING: non-finite value (NaN/inf) detected at step {self.step_count}")
            print(f"  max|E|={np.nanmax(np.abs(self.E)):.2e}")
            print(f"  max|χ|={np.nanmax(np.abs(self.chi)):.2e}")
            return False

        # Update fields
        self.E_prev = self.E
        self.E = E_next
        self.chi_prev = self.chi
        self.chi = chi_next

        self.t += self.p.dt
        self.step_count += 1
        return True

    def measure_center_of_mass(self) -> Tuple[float, float]:
        """
        Return the (x, y) centre of mass of |E|², excluding the Sun disc.

        This locates the "Mercury" wave packet. Returns (0.0, 0.0) when the
        remaining |E|² sums to less than 1e-10.
        """
        # np.abs returns a fresh array, so zeroing the Sun region is safe
        E2 = np.abs(self.E)**2
        E2[self.sun_mask] = 0.0
        total = np.sum(E2)

        if total < 1e-10:
            return 0.0, 0.0

        x_com = float(np.sum(self.X * E2) / total)
        y_com = float(np.sum(self.Y * E2) / total)
        return x_com, y_com

    def run(self, verbose: bool = True) -> dict:
        """Evolve for ``n_steps`` steps, sampling the Mercury COM periodically.

        A sample is recorded at the current time before stepping, then after
        every step whose loop index is a multiple of ``measure_interval``.
        With ``verbose``, a progress line is printed on samples taken roughly
        every tenth of the run. Stops early if ``step()`` returns False.

        Returns:
            The dictionary produced by ``_analyze_trajectory``.
        """
        if verbose:
            print("\n" + "="*70)
            print("LFM LATTICE SIMULATION - PURE GOV-01/GOV-02 EVOLUTION")
            print("="*70)
            print(f"\nGrid: {self.p.N}x{self.p.N}, dx={self.p.dx:.4f}")
            print(f"Time: {self.p.n_steps} steps, dt={self.p.dt}")
            print(f"CFL ratio: {self.p.dt / self.p.CFL_limit:.2f}")
            print("\nEQUATIONS USED (and ONLY these):")
            print("  GOV-01: ∂²E/∂t² = c²∇²E − χ²E")
            print("  GOV-02: ∂²χ/∂t² = c²∇²χ − κ(E² − E₀²)")
            print("\nNO Newtonian gravity. NO GR. NO derived forces.")
            print("Whatever trajectory emerges IS the LFM prediction.")
            print("="*70 + "\n")

        # Initial measurement
        x0, y0 = self.measure_center_of_mass()
        self.trajectory.append((0.0, x0, y0))

        # max(1, ...) avoids a modulo-by-zero when n_steps < 10
        report_every = max(1, self.p.n_steps // 10)

        for step in range(self.p.n_steps):
            if not self.step():
                print("Simulation terminated due to instability")
                break

            if step % self.p.measure_interval != 0:
                continue

            x, y = self.measure_center_of_mass()
            self.trajectory.append((self.t, x, y))

            if verbose and step % report_every == 0:
                r = np.sqrt(x**2 + y**2)
                theta = np.arctan2(y, x)
                E_max = np.max(np.abs(self.E))
                chi_min = np.min(self.chi)
                print(f"Step {step:6d}: t={self.t:.1f}, r={r:.2f}, θ={np.degrees(theta):.1f}°, |E|_max={E_max:.3f}, χ_min={chi_min:.4f}")

        return self._analyze_trajectory()

    def _analyze_trajectory(self) -> dict:
        """Detect perihelia in the recorded trajectory and measure precession.

        Perihelia are strict local minima of r = sqrt(x² + y²) over interior
        samples. Precession per orbit is |Δθ| − 2π between consecutive
        perihelia, with θ the unwrapped polar angle. Using |Δθ| measures the
        advance along the direction of motion for both counter-clockwise and
        clockwise orbits.

        Returns:
            dict with perihelion count/angles/times, per-orbit and mean
            precession (rad, plus mean in arcsec), and the sampled trajectory.
        """
        # (n_samples, 3) array of (t, x, y); reshape keeps an empty list valid
        samples = np.asarray(self.trajectory, dtype=float).reshape(-1, 3)
        times, x, y = samples[:, 0], samples[:, 1], samples[:, 2]

        r = np.sqrt(x**2 + y**2)
        theta = np.unwrap(np.arctan2(y, x))

        # Strict local minima of r among interior samples
        interior = r[1:-1]
        perihelia_idx = np.flatnonzero((interior < r[:-2]) & (interior < r[2:])) + 1

        perihelion_angles = theta[perihelia_idx]
        perihelion_times = times[perihelia_idx]

        if len(perihelion_angles) >= 2:
            angle_changes = np.diff(perihelion_angles)
            precession_per_orbit = np.abs(angle_changes) - 2 * np.pi
            mean_precession = float(np.mean(precession_per_orbit))
        else:
            mean_precession = 0.0
            precession_per_orbit = np.array([])

        mean_precession_arcsec = mean_precession * 180 / np.pi * 3600

        results = {
            "n_perihelia": int(len(perihelia_idx)),
            "perihelion_angles_rad": perihelion_angles.tolist(),
            "perihelion_times": perihelion_times.tolist(),
            "precession_per_orbit_rad": precession_per_orbit.tolist(),
            "mean_precession_rad": float(mean_precession),
            "mean_precession_arcsec": float(mean_precession_arcsec),
            "trajectory_x": x.tolist(),
            "trajectory_y": y.tolist(),
            "trajectory_t": times.tolist()
        }

        print("\n" + "="*70)
        print("RESULTS (from pure GOV-01/GOV-02 lattice evolution)")
        print("="*70)
        print(f"Perihelia detected: {len(perihelia_idx)}")
        if len(precession_per_orbit) > 0:
            print(f"Mean precession: {mean_precession:.6f} rad/orbit")
            print(f"                 {mean_precession_arcsec:.2f} arcsec/orbit")
        else:
            print("Not enough orbits to measure precession")
        print("="*70)

        return results

    def save_results(self, results: dict, filename: str = "lattice_results.json"):
        """Write ``results`` as indented JSON next to this source file."""
        output_path = Path(__file__).parent / filename
        with open(output_path, 'w', encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved to {output_path}")

    def plot_snapshot(self, filename: str = "lattice_snapshot.png"):
        """Save a PNG of |E|², χ and the recorded trajectory next to this file."""
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Cell-centred extent so each pixel is centred on its grid coordinate
        half = self.p.dx / 2
        extent = [float(self.X[0, 0]) - half, float(self.X[0, -1]) + half,
                  float(self.Y[0, 0]) - half, float(self.Y[-1, 0]) + half]

        # |E|² field (magnitude squared of complex E)
        E2 = np.abs(self.E)**2
        im0 = axes[0].imshow(E2, extent=extent, cmap='hot', origin='lower')
        axes[0].set_title('|E|² (energy density)')
        axes[0].set_xlabel('x')
        axes[0].set_ylabel('y')
        plt.colorbar(im0, ax=axes[0])

        # χ field
        im1 = axes[1].imshow(self.chi, extent=extent, cmap='viridis', origin='lower')
        axes[1].set_title('χ field')
        axes[1].set_xlabel('x')
        axes[1].set_ylabel('y')
        plt.colorbar(im1, ax=axes[1])

        # Trajectory
        if self.trajectory:
            x = [t[1] for t in self.trajectory]
            y = [t[2] for t in self.trajectory]
            axes[2].plot(x, y, 'b-', linewidth=0.5)
            axes[2].plot(x[0], y[0], 'go', markersize=10, label='Start')
            axes[2].plot(x[-1], y[-1], 'ro', markersize=10, label='End')
            axes[2].set_xlim(-self.p.L/2, self.p.L/2)
            axes[2].set_ylim(-self.p.L/2, self.p.L/2)
            axes[2].set_aspect('equal')
            axes[2].set_title('Trajectory (COM of E²)')
            axes[2].set_xlabel('x')
            axes[2].set_ylabel('y')
            axes[2].legend()

        plt.tight_layout()
        output_path = Path(__file__).parent / filename
        plt.savefig(output_path, dpi=150)
        plt.close(fig)
        print(f"Snapshot saved to {output_path}")


def main(params: "LatticeParams | None" = None):
    """
    Run the pure LFM lattice simulation.

    Args:
        params: Optional simulation parameters. When omitted, the small,
            fast quick-test configuration below is used.

    Returns:
        The results dictionary from ``LFMLattice.run``. The same results are
        also saved as JSON, and a snapshot PNG is written, next to this file.
    """
    if params is None:
        # Parameters tuned for reasonable computation time
        # while still capturing orbital dynamics
        params = LatticeParams(
            N=128,              # Grid size (smaller for speed)
            L=150.0,            # Domain size (dx = L/N = 150/128 ≈ 1.17)
            chi_0=1.0,          # Background χ
            kappa=0.01,         # GOV-02 coupling (very gentle)
            sun_radius=5.0,     # Sun size
            sun_E_amplitude=1.0,   # Sun E field strength (smaller)
            packet_radius=4.0,  # Mercury packet size
            packet_amplitude=0.2,  # Mercury E amplitude (smaller)
            orbit_radius=50.0,  # Starting distance
            orbit_velocity=0.05, # Initial tangential velocity (slower)
            dt=0.3,             # Timestep (2D leapfrog: dt < dx/(c*sqrt(2)) ≈ 0.83)
            n_steps=5000,       # Fewer steps for quick test
            measure_interval=10 # Measurement frequency
        )

    # Create and run simulation
    lattice = LFMLattice(params)
    results = lattice.run(verbose=True)

    # Save outputs
    lattice.save_results(results)
    lattice.plot_snapshot()

    return results


if __name__ == "__main__":
    main()
