"""2D Quantum-Classical System Extension.

This module extends the quantum-classical boundary emergence simulation
to two spatial dimensions, enabling studies of:
- 2D double-well dynamics V(x,y) = V(x) + V(y)
- Anisotropic decoherence and correction
- Spatial structure in quantum-classical transitions
- 2D tunneling and interference patterns
"""

import numpy as np
from scipy.integrate import simpson as simps
from scipy.fftpack import fft2, ifft2, fftfreq
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

class QuantumSystem2D:
    """2D quantum system for studying spatial quantum-classical transitions.

    Every 2D field handled by this class (wavefunctions, potentials,
    densities) lives on the grid built by ``np.meshgrid(x, y)`` with its
    default ``'xy'`` indexing. Arrays therefore have shape
    ``(len(y), len(x))``: axis 0 runs along y and axis 1 runs along x.
    """

    def __init__(self, grid_size=128, x_range=(-8, 8), y_range=(-8, 8),
                 mass=1.0, hbar=1.0):
        """Initialize 2D quantum system.

        Args:
            grid_size: Number of grid points along each axis (shared by x and y).
            x_range: ``(min, max)`` extent of the x axis.
            y_range: ``(min, max)`` extent of the y axis.
            mass: Particle mass used in the kinetic operator and in the
                classical equations of motion.
            hbar: Value of the reduced Planck constant.

        Raises:
            ValueError: If ``grid_size`` is smaller than 2 (no grid spacing
                can be defined).
        """
        if grid_size < 2:
            raise ValueError("grid_size must be at least 2")

        self.grid_size = grid_size
        self.mass = mass
        self.hbar = hbar

        # Spatial grids
        self.x = np.linspace(x_range[0], x_range[1], grid_size)
        self.y = np.linspace(y_range[0], y_range[1], grid_size)
        self.dx = self.x[1] - self.x[0]
        self.dy = self.y[1] - self.y[0]

        # Create 2D meshgrids ('xy' indexing: X[iy, ix] == x[ix], Y[iy, ix] == y[iy])
        self.X, self.Y = np.meshgrid(self.x, self.y)

        # Angular wavenumber grids matching the FFT frequency ordering
        self.kx = fftfreq(grid_size, d=self.dx) * 2 * np.pi
        self.ky = fftfreq(grid_size, d=self.dy) * 2 * np.pi
        self.KX, self.KY = np.meshgrid(self.kx, self.ky)

        # Kinetic energy operator in momentum space: hbar^2 |k|^2 / (2 m)
        self.kinetic_energy = (self.hbar**2 / (2 * self.mass)) * (self.KX**2 + self.KY**2)

    def _integrate_2d(self, field):
        """Integrate a ``(len(y), len(x))`` field over the grid with Simpson's rule.

        The sample points are passed by keyword because recent SciPy
        releases make ``x`` keyword-only, and each axis is paired explicitly
        with its own coordinate array.
        """
        over_x = simps(field, x=self.x, axis=1)
        return simps(over_x, x=self.y)

    def _bilinear_sample(self, field, x, y):
        """Bilinearly interpolate ``field[iy, ix]`` at the point ``(x, y)``.

        Points outside the grid are clamped to the nearest edge, mirroring
        the end-point behaviour of ``np.interp``.
        """
        xq = np.clip(x, self.x[0], self.x[-1])
        yq = np.clip(y, self.y[0], self.y[-1])

        # Index of the lower-left corner of the cell containing (xq, yq).
        ix = int(np.clip(np.searchsorted(self.x, xq) - 1, 0, self.x.size - 2))
        iy = int(np.clip(np.searchsorted(self.y, yq) - 1, 0, self.y.size - 2))

        tx = (xq - self.x[ix]) / (self.x[ix + 1] - self.x[ix])
        ty = (yq - self.y[iy]) / (self.y[iy + 1] - self.y[iy])

        lower = (1 - tx) * field[iy, ix] + tx * field[iy, ix + 1]
        upper = (1 - tx) * field[iy + 1, ix] + tx * field[iy + 1, ix + 1]
        return (1 - ty) * lower + ty * upper

    def double_well_potential_2d(self, coupling=0.0):
        """
        2D double-well potential: V(x,y) = V(x) + V(y) + coupling*x*y

        Each 1D term is 0.25*q**4 - 2*q**2, which has minima at q = +/-2.
        The coupling parameter allows for interaction between the x and y
        degrees of freedom.

        Returns:
            Array with the same shape as ``self.X`` holding V on the grid.
        """
        well_x = 0.25 * self.X**4 - 2 * self.X**2
        well_y = 0.25 * self.Y**4 - 2 * self.Y**2
        return well_x + well_y + coupling * self.X * self.Y

    def create_initial_state_2d(self, state_type='four_gaussian'):
        """Create various 2D initial states.

        Args:
            state_type: One of ``'four_gaussian'``, ``'central_gaussian'``,
                ``'ring_state'`` or ``'separable_double'``.

        Returns:
            The wavefunction on the grid, normalized with
            ``normalize_wavefunction_2d``. ``'ring_state'`` is complex; the
            other states are real.

        Raises:
            ValueError: If ``state_type`` is not recognised.
        """
        if state_type == 'four_gaussian':
            # One Gaussian wavepacket in each of the four wells at (+/-2, +/-2)
            psi = sum(
                np.exp(-((self.X - cx)**2 + (self.Y - cy)**2))
                for cx, cy in ((-2, -2), (-2, 2), (2, -2), (2, 2))
            )

        elif state_type == 'central_gaussian':
            # Single Gaussian at center
            psi = np.exp(-(self.X**2 + self.Y**2))

        elif state_type == 'ring_state':
            # Ring of radius 3 with a phase that winds once around the origin
            r = np.sqrt(self.X**2 + self.Y**2)
            psi = np.exp(-0.5 * (r - 3)**2) * np.exp(1j * np.arctan2(self.Y, self.X))

        elif state_type == 'separable_double':
            # Separable state: psi(x, y) = psi_x(x) * psi_y(y)
            psi_x = np.exp(-(self.x + 2)**2) + np.exp(-(self.x - 2)**2)
            psi_y = np.exp(-(self.y + 1)**2) + np.exp(-(self.y - 1)**2)
            # Axis 0 is y and axis 1 is x, so psi_y varies down the rows and
            # psi_x along the columns.
            psi = psi_y[:, np.newaxis] * psi_x[np.newaxis, :]

        else:
            raise ValueError(f"Unknown state type: {state_type}")

        return self.normalize_wavefunction_2d(psi)

    def split_step_evolution_2d(self, psi, potential, dt):
        """2D split-step Fourier method evolution.

        Applies half a potential step, a full kinetic step in momentum space,
        and a second half potential step.

        Args:
            psi: Wavefunction on the grid (real or complex); not modified.
            potential: Potential energy on the grid.
            dt: Time step.

        Returns:
            The evolved wavefunction as a new ``complex128`` array.
        """
        # astype makes a complex copy, so the caller's array is left untouched
        psi = psi.astype(np.complex128)

        half_potential_phase = np.exp(-1j * potential * dt / (2 * self.hbar))
        kinetic_phase = np.exp(-1j * self.kinetic_energy * dt / self.hbar)

        psi *= half_potential_phase
        psi = ifft2(fft2(psi) * kinetic_phase)
        psi *= half_potential_phase

        return psi

    def expectation_values_2d(self, psi):
        """Calculate expectation values in 2D.

        Returns:
            Dict with keys ``'x_mean'``, ``'y_mean'``, ``'x_var'``,
            ``'y_var'`` (variances clamped at zero) and ``'xy_corr'``
            (the covariance <xy> - <x><y>). All entries are zero when the
            integrated density is below 1e-12.
        """
        density = np.abs(psi)**2
        norm = self._integrate_2d(density)

        if norm < 1e-12:
            return {'x_mean': 0, 'y_mean': 0, 'x_var': 0, 'y_var': 0, 'xy_corr': 0}

        density = density / norm

        def moment(weight):
            return self._integrate_2d(weight * density)

        # First moments
        x_mean = moment(self.X)
        y_mean = moment(self.Y)

        # Second central moments
        x_var = moment(self.X**2) - x_mean**2
        y_var = moment(self.Y**2) - y_mean**2
        xy_corr = moment(self.X * self.Y) - x_mean * y_mean

        return {
            'x_mean': x_mean,
            'y_mean': y_mean,
            # Round-off can push a tiny variance below zero; clamp so that
            # callers can safely take a square root.
            'x_var': max(0, x_var),
            'y_var': max(0, y_var),
            'xy_corr': xy_corr
        }

    def quantum_potential_2d(self, psi):
        """Calculate 2D quantum potential Q(x,y) = -ℏ²/(2m) * ∇²|ψ|/|ψ|.

        The density is floored at 1e-12 before taking the square root so the
        division is always defined, and Q is set to zero wherever the
        (floored) density is below 1e-10.
        """
        prob_density = np.maximum(np.abs(psi)**2, 1e-12)  # Regularization
        amplitude = np.sqrt(prob_density)

        # Second derivatives: axis 1 is x, axis 0 is y
        d2_dx2 = np.gradient(np.gradient(amplitude, self.dx, axis=1), self.dx, axis=1)
        d2_dy2 = np.gradient(np.gradient(amplitude, self.dy, axis=0), self.dy, axis=0)

        Q = -(self.hbar**2 / (2 * self.mass)) * ((d2_dx2 + d2_dy2) / amplitude)

        # Set to zero where probability density is negligible
        Q[prob_density < 1e-10] = 0.0

        return Q

    def compute_2d_entropy(self, psi):
        """Compute Shannon entropy for 2D probability distribution.

        The density is normalized so that the grid cells sum to one and the
        entropy is measured in bits (``log2``). An identically zero
        wavefunction has no distribution to normalize and yields 0.0.
        """
        cell_prob = np.abs(psi)**2
        total = np.sum(cell_prob)
        if total == 0:
            return 0.0

        cell_prob = cell_prob / total

        # The small offset keeps log2 finite in empty cells
        return -np.sum(cell_prob * np.log2(cell_prob + 1e-12))

    def apply_2d_correction(self, psi, x_classical, y_classical, lambda_x, lambda_y=None):
        """
        Apply 2D correction mechanism with potentially anisotropic strength.

        V_corr(x,y) = -λ_x(x - ⟨x⟩)(⟨x⟩ - x_class) - λ_y(y - ⟨y⟩)(⟨y⟩ - y_class)

        Args:
            psi: Current wavefunction on the grid.
            x_classical: Classical reference position along x.
            y_classical: Classical reference position along y.
            lambda_x: Correction strength along x.
            lambda_y: Correction strength along y; defaults to ``lambda_x``.

        Returns:
            ``(V_corr, (correction_strength_x, correction_strength_y))`` where
            ``V_corr`` is the correction potential on the grid and the
            strengths are its constant slopes along x and y.
        """
        if lambda_y is None:
            lambda_y = lambda_x

        expectations = self.expectation_values_2d(psi)
        x_quantum = expectations['x_mean']
        y_quantum = expectations['y_mean']

        # NOTE(review): the force -dV_corr/dx equals +lambda_x*(<x> - x_class),
        # i.e. it points away from the classical reference. This matches the
        # formula in the docstring; confirm the sign is the intended one.
        correction_strength_x = -lambda_x * (x_quantum - x_classical)
        correction_strength_y = -lambda_y * (y_quantum - y_classical)

        # Correction potential: linear in x and y, zero at (<x>, <y>)
        V_corr = (correction_strength_x * (self.X - x_quantum) +
                  correction_strength_y * (self.Y - y_quantum))

        return V_corr, (correction_strength_x, correction_strength_y)

    def classical_trajectory_2d(self, x0, y0, vx0, vy0, t_final, dt, potential_func):
        """Calculate 2D classical trajectory.

        Velocities are updated from the force first and positions are then
        advanced with the new velocities. Forces come from the numerical
        gradient of the gridded potential, bilinearly interpolated at the
        particle position (clamped to the grid edge outside the domain).

        Args:
            x0, y0: Initial position.
            vx0, vy0: Initial velocity.
            t_final: Total integration time.
            dt: Time step.
            potential_func: Zero-argument callable returning the potential on
                the grid (same shape as ``self.X``).

        Returns:
            Array of shape ``(n_steps + 1, 4)`` with rows ``(x, y, vx, vy)``,
            the first row being the initial condition.
        """
        x, y = x0, y0
        vx, vy = vx0, vy0

        trajectory = [(x, y, vx, vy)]
        # The tolerance stops float round-off from dropping the last step
        # (e.g. 0.3 / 0.1 evaluates to 2.9999999999999996).
        n_steps = int(np.floor(t_final / dt + 1e-9))

        for _ in range(n_steps):
            # Re-evaluated every step so a potential_func whose result changes
            # between calls keeps working.
            V = potential_func()
            grad_V_x = np.gradient(V, self.dx, axis=1)
            grad_V_y = np.gradient(V, self.dy, axis=0)

            # Force = -grad V at the current position
            fx = -self._bilinear_sample(grad_V_x, x, y)
            fy = -self._bilinear_sample(grad_V_y, x, y)

            # Update velocities, then positions
            vx += (fx / self.mass) * dt
            vy += (fy / self.mass) * dt
            x += vx * dt
            y += vy * dt

            trajectory.append((x, y, vx, vy))

        return np.array(trajectory)

    def tunneling_analysis_2d(self, psi):
        """Analyze 2D tunneling between potential wells.

        Integrates |ψ|² over the four open quadrants (grid points lying
        exactly on an axis belong to none of them). The tunneling measure
        along an axis is the smaller of the two half-plane probabilities.
        The density is used as given, without renormalization.

        Returns:
            Dict with ``'quadrant_probabilities'`` (keys ``'left_bottom'``,
            ``'left_top'``, ``'right_bottom'``, ``'right_top'``),
            ``'x_tunneling_probability'`` and ``'y_tunneling_probability'``.
        """
        prob_density = np.abs(psi)**2

        left, right = self.X < 0, self.X > 0
        bottom, top = self.Y < 0, self.Y > 0

        p_left_bottom = self._integrate_2d(prob_density * (left & bottom))
        p_left_top = self._integrate_2d(prob_density * (left & top))
        p_right_bottom = self._integrate_2d(prob_density * (right & bottom))
        p_right_top = self._integrate_2d(prob_density * (right & top))

        # Tunneling measures
        x_tunneling = min(p_left_bottom + p_left_top, p_right_bottom + p_right_top)
        y_tunneling = min(p_left_bottom + p_right_bottom, p_left_top + p_right_top)

        return {
            'quadrant_probabilities': {
                'left_bottom': p_left_bottom,
                'left_top': p_left_top,
                'right_bottom': p_right_bottom,
                'right_top': p_right_top
            },
            'x_tunneling_probability': x_tunneling,
            'y_tunneling_probability': y_tunneling
        }

    def _image_panel(self, ax, field, cmap, title):
        """Draw ``field`` as an image with a colorbar on ``ax`` in grid coordinates."""
        image = ax.imshow(field, extent=[self.x[0], self.x[-1], self.y[0], self.y[-1]],
                          origin='lower', cmap=cmap, aspect='equal')
        ax.set_title(title)
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        plt.colorbar(image, ax=ax)

    def visualize_2d_state(self, psi, title="2D Quantum State", quantum_potential=None):
        """Create comprehensive 2D visualization.

        Panels: probability density, phase, 3D density surface, the
        uncoupled double-well potential, both marginal distributions and,
        when ``quantum_potential`` is given, the quantum potential.

        Returns:
            The matplotlib figure holding the panels.
        """
        fig = plt.figure(figsize=(15, 12))
        prob_density = np.abs(psi)**2

        # Probability density
        self._image_panel(fig.add_subplot(2, 3, 1), prob_density,
                          'viridis', 'Probability Density |ψ|²')

        # Phase
        self._image_panel(fig.add_subplot(2, 3, 2), np.angle(psi),
                          'hsv', 'Phase arg(ψ)')

        # 3D probability surface
        ax3 = fig.add_subplot(2, 3, 3, projection='3d')
        ax3.plot_surface(self.X, self.Y, prob_density, cmap='viridis', alpha=0.8)
        ax3.set_title('3D Probability Surface')
        ax3.set_xlabel('x')
        ax3.set_ylabel('y')
        ax3.set_zlabel('|ψ|²')

        # Potential
        ax4 = fig.add_subplot(2, 3, 4)
        V = self.double_well_potential_2d()
        ax4.contour(self.X, self.Y, V, levels=20, colors='black', alpha=0.6)
        ax4.contourf(self.X, self.Y, V, levels=20, cmap='RdYlBu_r', alpha=0.3)
        ax4.set_title('Double-Well Potential')
        ax4.set_xlabel('x')
        ax4.set_ylabel('y')

        # Marginal distributions: integrate out y (axis 0) for P(x) and
        # x (axis 1) for P(y)
        ax5 = fig.add_subplot(2, 3, 5)
        marginal_x = simps(prob_density, x=self.y, axis=0)
        marginal_y = simps(prob_density, x=self.x, axis=1)
        ax5.plot(self.x, marginal_x, 'b-', label='P(x)', linewidth=2)
        ax5.plot(self.y, marginal_y, 'r-', label='P(y)', linewidth=2)
        ax5.set_title('Marginal Distributions')
        ax5.set_xlabel('Position')
        ax5.set_ylabel('Probability')
        ax5.legend()
        ax5.grid(True, alpha=0.3)

        # Quantum potential (if provided)
        if quantum_potential is not None:
            self._image_panel(fig.add_subplot(2, 3, 6), quantum_potential,
                              'seismic', 'Quantum Potential Q(x,y)')

        fig.suptitle(title, fontsize=16)
        fig.tight_layout()
        return fig

    def normalize_wavefunction_2d(self, psi):
        """Normalize 2D wavefunction.

        Returns ``psi`` scaled so that |ψ|² integrates to one over the grid.
        A wavefunction whose norm is at most 1e-12 is returned unchanged.
        """
        norm = np.sqrt(self._integrate_2d(np.abs(psi)**2))
        return psi / norm if norm > 1e-12 else psi

    def run_2d_simulation(self, lambda_vals=(0.0, 0.05), dt=0.01, n_steps=500,
                         initial_state_type='four_gaussian', anisotropic=False):
        """Run complete 2D simulation with specified parameters.

        For every correction strength the initial state is evolved in the
        uncoupled double-well potential with the split-step method. When the
        strength is positive a correction step towards a classical reference
        fixed at the origin follows each evolution step. The state is
        renormalized after every step and diagnostics are sampled every
        10 steps.

        Args:
            lambda_vals: Iterable of correction strengths to simulate.
            dt: Time step.
            n_steps: Number of time steps per run.
            initial_state_type: Passed to ``create_initial_state_2d``.
            anisotropic: If true, the y correction uses half the x strength.

        Returns:
            Dict keyed by correction strength; each value is a dict with
            ``'final_state'``, ``'diagnostics'`` (list of per-sample dicts)
            and ``'expectation_history'``.
        """
        results = {}

        # The potential does not depend on time or on lambda: build it once.
        V = self.double_well_potential_2d()

        # Simple classical reference at origin for now
        x_classical, y_classical = 0.0, 0.0

        for lambda_val in lambda_vals:
            print(f"Running 2D simulation with λ = {lambda_val}")

            psi = self.create_initial_state_2d(initial_state_type)

            lambda_x = lambda_val
            lambda_y = lambda_val * 0.5 if anisotropic else lambda_val

            diagnostics = []
            expectation_history = []

            for step in range(n_steps):
                # Quantum evolution
                psi = self.split_step_evolution_2d(psi, V, dt)

                # Apply corrections
                if lambda_val > 0:
                    V_corr, _ = self.apply_2d_correction(
                        psi, x_classical, y_classical, lambda_x, lambda_y)
                    psi *= np.exp(-1j * V_corr * dt / self.hbar)

                # Renormalize
                psi = self.normalize_wavefunction_2d(psi)

                # Collect diagnostics every 10 steps
                if step % 10 == 0:
                    expectations = self.expectation_values_2d(psi)
                    tunneling_data = self.tunneling_analysis_2d(psi)

                    diagnostics.append({
                        # The state has already been advanced step + 1 times.
                        'time': (step + 1) * dt,
                        'entropy_2d': self.compute_2d_entropy(psi),
                        'x_mean': expectations['x_mean'],
                        'y_mean': expectations['y_mean'],
                        'x_uncertainty': np.sqrt(expectations['x_var']),
                        'y_uncertainty': np.sqrt(expectations['y_var']),
                        'xy_correlation': expectations['xy_corr'],
                        'x_tunneling': tunneling_data['x_tunneling_probability'],
                        'y_tunneling': tunneling_data['y_tunneling_probability']
                    })
                    expectation_history.append(expectations)

            # Store results
            results[lambda_val] = {
                'final_state': psi,
                'diagnostics': diagnostics,
                'expectation_history': expectation_history
            }

        return results