"""Physics module for quantum-classical boundary emergence simulation.

This module contains core physics calculations including:
- Quantum potential computation (Bohmian mechanics)
- Lindblad master equation evolution
- Time evolution operators
- Classical trajectory calculations with enhanced dynamics
"""

import numpy as np
from scipy.integrate import simpson as simps
from scipy.fftpack import fft, ifft, fftfreq
import warnings

class QuantumPhysics:
    """Core physics calculations for quantum systems."""

    def __init__(self, grid_size=512, x_range=(-10, 10), mass=1.0, hbar=1.0):
        """Validate the physical parameters and build the position/momentum grids.

        Parameters
        ----------
        grid_size : int
            Number of spatial grid points. Must be an integer >= 2, because
            the grid spacing is taken from the first two points.
        x_range : sequence of two numbers
            ``(x_min, x_max)`` with ``x_max > x_min``.
        mass : float
            Particle mass, must be positive.
        hbar : float
            Reduced Planck constant in the chosen units, must be positive.

        Raises
        ------
        ValueError
            If any argument fails validation or the resulting grids contain
            non-finite values.
        """
        # Check the type before comparing, so a non-numeric grid_size produces
        # the documented ValueError instead of a TypeError from `<`.
        is_integer = (isinstance(grid_size, (int, np.integer))
                      and not isinstance(grid_size, bool))
        if not is_integer or grid_size < 2:
            raise ValueError(f"grid_size must be an integer >= 2, got {grid_size}")
        # `not (value > 0)` also rejects NaN, which `value <= 0` lets through.
        if not (mass > 0):
            raise ValueError(f"mass must be positive, got {mass}")
        if not (hbar > 0):
            raise ValueError(f"hbar must be positive, got {hbar}")
        if len(x_range) != 2 or not (x_range[1] > x_range[0]):
            raise ValueError(f"x_range must be (x_min, x_max) with x_max > x_min, got {x_range}")

        self.grid_size = int(grid_size)
        self.mass = mass
        self.hbar = hbar

        # Spatial grid
        self.x = np.linspace(x_range[0], x_range[1], self.grid_size)
        if not np.all(np.isfinite(self.x)):
            raise ValueError("Invalid values in spatial grid")
        self.dx = self.x[1] - self.x[0]
        if not (self.dx > 0):
            raise ValueError("Grid spacing dx must be positive")

        # Angular wavenumbers matching the FFT ordering, and the kinetic
        # energy (hbar*k)^2 / (2m) evaluated on that grid.
        self.k = fftfreq(self.grid_size, d=self.dx) * 2 * np.pi
        self.k_squared = (self.hbar * self.k)**2 / (2 * self.mass)
        if not np.all(np.isfinite(self.k_squared)):
            raise ValueError("Invalid values in momentum grid")

    def double_well_potential(self, x=None):
        """Evaluate the double-well potential V(x) = (1/4)x⁴ - 2x².

        When ``x`` is omitted the potential is evaluated on the instance's
        spatial grid ``self.x``.
        """
        points = self.x if x is None else x
        quartic_term = 0.25 * points**4
        quadratic_term = 2 * points**2
        return quartic_term - quadratic_term

    def extended_double_well_2d(self, x, y):
        """Separable 2D potential: the 1D double well applied to x plus that applied to y."""
        well_along_x = self.double_well_potential(x)
        well_along_y = self.double_well_potential(y)
        return well_along_x + well_along_y

    def quantum_potential(self, psi):
        """
        Bohmian quantum potential Q = -ℏ²/(2m) * (d²|Ψ|/dx²) / |Ψ| on the grid.

        The amplitude |Ψ| is floored (|Ψ|² >= 1e-12) before differentiation,
        the second derivative is obtained by applying ``np.gradient`` twice,
        and Q is forced to zero wherever |Ψ|² < 1e-10.

        This routine is best-effort: any validation or numerical failure is
        reported through ``warnings.warn`` and an all-zero array shaped like
        the grid is returned instead of raising.
        """
        try:
            # Input checks; failures are caught by the outer handler below.
            if psi is None or len(psi) == 0:
                raise ValueError("Wavefunction psi cannot be None or empty")
            n_samples = len(psi)
            if n_samples != self.grid_size:
                raise ValueError(f"Wavefunction size {n_samples} doesn't match grid size {self.grid_size}")

            density = np.abs(psi)**2
            if not np.all(np.isfinite(density)):
                warnings.warn("Invalid values in probability density, using fallback")
                return np.zeros_like(self.x)

            # Floor the density so the amplitude never reaches zero.
            floored_density = np.maximum(density, 1e-12)

            try:
                amp = np.sqrt(floored_density)
                if not np.all(np.isfinite(amp)):
                    warnings.warn("Invalid amplitude values, using fallback")
                    return np.zeros_like(self.x)

                # Second derivative as a gradient of a gradient.
                slope = np.gradient(amp, self.dx)
                curvature = np.gradient(slope, self.dx)

                prefactor = -(self.hbar**2 / (2 * self.mass))
                ratio = np.divide(
                    curvature, amp,
                    out=np.zeros_like(amp),
                    where=amp > 1e-12
                )
                Q = prefactor * ratio

                # Suppress Q where there is essentially no probability.
                Q[density < 1e-10] = 0.0

                # Scrub any remaining NaN/inf entries.
                bad_entries = ~np.isfinite(Q)
                if np.any(bad_entries):
                    Q[bad_entries] = 0.0
                    warnings.warn(f"Set {np.sum(bad_entries)} invalid Q values to zero")

                return Q

            except Exception as e:
                warnings.warn(f"Error computing quantum potential: {e}. Returning zeros.")
                return np.zeros_like(self.x)

        except Exception as e:
            warnings.warn(f"Critical error in quantum_potential: {e}. Returning zeros.")
            return np.zeros_like(self.x)

    def split_step_evolution(self, psi, potential, dt):
        """
        Advance a wavefunction by one time step with the split-step Fourier method.

        U(dt) ≈ exp(-iV*dt/2ℏ) exp(-iT*dt/ℏ) exp(-iV*dt/2ℏ)

        Parameters
        ----------
        psi : array-like
            Wavefunction samples. Any sequence convertible to a complex
            array is accepted; the caller's object is never modified.
        potential : array-like
            Potential samples, same length as ``psi``.
        dt : float
            Time step, must be positive.

        Returns
        -------
        numpy.ndarray (complex128)
            The evolved wavefunction. This routine is best-effort and does
            not raise on bad input: problems are reported through
            ``warnings.warn``. If validation or the evolution itself fails,
            the (sanitised, un-evolved) input is returned; if only the FFT
            step fails, the kinetic step is skipped.
        """
        try:
            # Validate inputs
            if psi is None or len(psi) == 0:
                raise ValueError("Wavefunction psi cannot be None or empty")
            if potential is None or len(potential) == 0:
                raise ValueError("Potential cannot be None or empty")
            if len(psi) != len(potential):
                raise ValueError(f"Wavefunction size {len(psi)} doesn't match potential size {len(potential)}")
            if dt <= 0:
                raise ValueError(f"Time step dt must be positive, got {dt}")

            # Coerce to arrays so plain lists/tuples work. np.array copies,
            # which keeps the caller's psi untouched.
            state_in = np.array(psi, dtype=np.complex128)
            potential = np.asarray(potential)

            # Replace NaN/inf entries with zero rather than propagating them.
            if not np.all(np.isfinite(state_in)):
                warnings.warn("Invalid values in input wavefunction")
                state_in = np.nan_to_num(state_in, nan=0.0, posinf=0.0, neginf=0.0)

            if not np.all(np.isfinite(potential)):
                warnings.warn("Invalid values in potential")
                potential = np.nan_to_num(potential, nan=0.0, posinf=0.0, neginf=0.0)

            try:
                potential_factor = -1j * potential * dt / (2 * self.hbar)
                kinetic_factor = -1j * self.k_squared * dt / self.hbar

                # Large phase increments per step usually mean dt is too big.
                if np.any(np.abs(potential_factor.imag) > 100):
                    warnings.warn("Very large potential evolution factors detected")
                if np.any(np.abs(kinetic_factor.imag) > 100):
                    warnings.warn("Very large kinetic evolution factors detected")

                # The same half-step propagator is applied before and after
                # the kinetic step, so compute it once.
                half_step = np.exp(potential_factor)

                # Work on a new array so state_in stays available as the
                # genuine "input wavefunction" fallback below.
                state = state_in * half_step

                # Full-step kinetic evolution in momentum space
                try:
                    state_k = fft(state)
                    if not np.all(np.isfinite(state_k)):
                        warnings.warn("Invalid values after FFT")
                        state_k = np.nan_to_num(state_k, nan=0.0, posinf=0.0, neginf=0.0)

                    # In-place multiply: a length mismatch with the momentum
                    # grid raises here instead of silently broadcasting.
                    state_k *= np.exp(kinetic_factor)
                    state = ifft(state_k)

                    if not np.all(np.isfinite(state)):
                        warnings.warn("Invalid values after IFFT")
                        state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)

                except Exception as fft_error:
                    warnings.warn(f"FFT/IFFT error: {fft_error}. Skipping kinetic evolution.")

                # Second half-step potential evolution
                state = state * half_step

                if not np.all(np.isfinite(state)):
                    warnings.warn("Invalid values in final wavefunction")
                    state = np.nan_to_num(state, nan=0.0, posinf=0.0, neginf=0.0)

                return state

            except Exception as evolution_error:
                warnings.warn(f"Error during evolution: {evolution_error}. Returning input wavefunction.")
                return state_in

        except Exception as e:
            warnings.warn(f"Critical error in split_step_evolution: {e}. Returning input wavefunction.")
            if psi is None:
                return np.zeros(self.grid_size, dtype=np.complex128)
            try:
                return np.array(psi, dtype=np.complex128)
            except (TypeError, ValueError):
                # psi cannot be represented as a complex array at all.
                return np.zeros(self.grid_size, dtype=np.complex128)

    def lindblad_evolution(self, rho, H, lindblad_ops, dt, gamma_rates=None):
        """
        Take one explicit first-order (Euler) step of the Lindblad master equation:
        dρ/dt = -i[H,ρ]/ℏ + Σ_k γ_k (L_k ρ L_k† - 1/2{L_k†L_k, ρ})

        Parameters:
        -----------
        rho : numpy.ndarray (complex)
            Density matrix
        H : numpy.ndarray (complex)
            Hamiltonian matrix
        lindblad_ops : list of numpy.ndarray
            Lindblad operators [L_x, L_p, ...]
        dt : float
            Time step
        gamma_rates : list of floats
            Rate multiplying each operator's dissipator; defaults to 1.0 for
            every operator. Operators and rates are paired with ``zip``, so
            extras in the longer list are ignored.

        Returns:
        --------
        numpy.ndarray (complex128)
            ρ + dt * dρ/dt (a new array; the input is not modified).
        """
        rates = [1.0] * len(lindblad_ops) if gamma_rates is None else gamma_rates

        # Work in complex arithmetic regardless of the input dtype.
        state = rho.astype(np.complex128)

        # Coherent part: -i[H,ρ]/ℏ
        coherent = -1j * (H @ state - state @ H) / self.hbar

        # Incoherent part, accumulated one operator at a time.
        incoherent = np.zeros_like(state, dtype=np.complex128)
        for op, rate in zip(lindblad_ops, rates):
            op_dag = op.conj().T
            jump_term = op @ state @ op_dag
            anticommutator = op_dag @ op @ state + state @ op_dag @ op
            incoherent += rate * (jump_term - 0.5 * anticommutator)

        return state + (coherent + incoherent) * dt

    def create_position_operator(self):
        """Return x̂ as a square matrix with the grid coordinates on its diagonal."""
        grid_points = self.x
        return np.diag(grid_points)

    def create_momentum_operator(self):
        """Create the momentum operator matrix p̂ = -iℏ∇ with central differences.

        Interior rows implement (p̂ψ)_i = -iℏ (ψ_{i+1} - ψ_{i-1}) / (2 dx), so
        the super-diagonal carries -iℏ/(2dx) and the sub-diagonal +iℏ/(2dx).
        The first and last rows only have the one neighbour that exists on
        the grid. The matrix is Hermitian.

        Returns
        -------
        numpy.ndarray (complex), shape (grid_size, grid_size)
        """
        n = self.grid_size
        coeff = 1j * self.hbar / (2 * self.dx)

        p_op = np.zeros((n, n), dtype=complex)
        rows = np.arange(n - 1)
        # Coefficient of psi[i+1] in row i.
        p_op[rows, rows + 1] = -coeff
        # Coefficient of psi[i-1] in row i.
        p_op[rows + 1, rows] = coeff
        return p_op

    def create_lindblad_operators(self, gamma_x=0.01, gamma_p=0.01):
        """
        Build the position and momentum decoherence operators, already
        scaled by the square roots of their rates:
        L_x = √(γ_x) * x̂
        L_p = √(γ_p) * p̂

        Returns a two-element list ``[L_x, L_p]``.
        """
        position_matrix = self.create_position_operator()
        momentum_matrix = self.create_momentum_operator()

        amplitude_x = np.sqrt(gamma_x)
        amplitude_p = np.sqrt(gamma_p)

        return [amplitude_x * position_matrix, amplitude_p * momentum_matrix]

    def classical_trajectory_enhanced(self, x0, v0, t_final, dt, potential_func,
                                   external_force=None, damping=0.0):
        """
        Integrate a classical trajectory with optional external force and damping.

        Equation of motion: m*d²x/dt² = -dV/dx + F_ext - γ*dx/dt

        The scheme is semi-implicit (symplectic) Euler: the velocity is
        updated from the current force, then the position is advanced with
        the *new* velocity.

        Parameters
        ----------
        x0, v0 : float or numpy.ndarray
            Initial position and velocity. Inputs are never modified.
        t_final : float
            Total integration time; ``floor(t_final / dt)`` steps are taken.
        dt : float
            Time step.
        potential_func : callable
            ``potential_func(x_grid)`` returning V sampled on ``self.x``. It
            is sampled once, so it is assumed to depend only on its argument.
        external_force : callable, optional
            ``external_force(x, v, t)`` evaluated at the start of each step.
        damping : float
            Coefficient γ of the friction force ``-γ v``.

        Returns
        -------
        numpy.ndarray
            Stacked ``(x, v)`` pairs, initial state first; shape
            ``(n_steps + 1, 2)`` for scalar inputs.
        """
        x, v = x0, v0
        trajectory = [(x, v)]

        # The small tolerance keeps floating-point round-off from dropping
        # the last step (e.g. 0.3 / 0.1 evaluates to 2.9999999999999996).
        n_steps = int(np.floor(t_final / dt + 1e-9))

        # dV/dx on the grid does not change between steps: tabulate it once
        # and interpolate. np.interp holds the edge value outside the grid.
        grad_table = None
        if n_steps > 0:
            grad_table = np.gradient(potential_func(self.x), self.dx)

        for step in range(n_steps):
            force_potential = -np.interp(x, self.x, grad_table)

            force_external = 0.0
            if external_force is not None:
                force_external = external_force(x, v, step * dt)

            force_damping = -damping * v

            acceleration = (force_potential + force_external + force_damping) / self.mass

            # Rebind instead of using += so array-valued x0/v0 are not
            # mutated and earlier trajectory entries are not aliased.
            v = v + acceleration * dt
            x = x + v * dt

            trajectory.append((x, v))

        return np.array(trajectory)

    def expectation_value(self, psi, operator):
        """Calculate the expectation value ⟨Ψ|Ô|Ψ⟩ on the grid.

        Parameters
        ----------
        psi : array-like
            Wavefunction samples, normalized so that ∫|Ψ|² dx = 1 (the
            convention used by ``normalize_wavefunction``).
        operator : array-like
            Either a 1-D array holding the diagonal of an operator in the
            position basis (e.g. ``self.x``), or a full square matrix.

        Returns
        -------
        float
            Real part of the expectation value. Both branches include the
            grid measure, so a diagonal operator gives the same answer
            (up to quadrature error) whether passed as a vector or as
            ``np.diag(vector)``.
        """
        psi = np.asarray(psi)
        operator = np.asarray(operator)

        if operator.ndim == 1:
            # Diagonal operator: integrate Ψ* O Ψ with Simpson's rule.
            return simps(np.conj(psi) * operator * psi, x=self.x).real

        # Full matrix: Σ_i Ψ*_i (OΨ)_i dx. Without the dx weight the result
        # would be off by a factor 1/dx for integral-normalized Ψ.
        return (np.vdot(psi, operator @ psi) * self.dx).real

    def wavefunction_to_density_matrix(self, psi):
        """Build the pure-state density matrix ρ = |ψ⟩⟨ψ| as an outer product."""
        bra = np.conj(psi)
        return np.outer(psi, bra)

    def density_matrix_to_wavefunction(self, rho, return_phases=False):
        """
        Return the eigenvector of ``rho`` belonging to its largest eigenvalue.

        The matrix is diagonalised with ``np.linalg.eigh`` and the chosen
        eigenvector is passed through ``normalize_wavefunction``. With
        ``return_phases=True`` a tuple ``(state, eigenvalue)`` is returned,
        where the second item is the dominant eigenvalue.

        Best-effort: on invalid input or a failed decomposition a warning is
        issued and the result of ``_fallback_wavefunction`` is returned.
        """
        try:
            if rho is None:
                raise ValueError("Density matrix cannot be None")

            n_rows, n_cols = rho.shape[0], rho.shape[1]
            if n_rows != n_cols:
                raise ValueError(f"Density matrix must be square, got shape {rho.shape}")

            if n_rows != self.grid_size:
                raise ValueError(f"Density matrix size {n_rows} doesn't match grid size {self.grid_size}")

            # Zero out NaN/inf entries instead of failing.
            if not np.all(np.isfinite(rho)):
                warnings.warn("Invalid values in density matrix")
                rho = np.nan_to_num(rho, nan=0.0, posinf=0.0, neginf=0.0)

            try:
                spectrum, vectors = np.linalg.eigh(rho)

                if not np.all(np.isfinite(spectrum)):
                    warnings.warn("Invalid eigenvalues, using fallback")
                    return self._fallback_wavefunction(return_phases)

                if not np.all(np.isfinite(vectors)):
                    warnings.warn("Invalid eigenvectors, using fallback")
                    return self._fallback_wavefunction(return_phases)

                top = np.argmax(np.real(spectrum))
                state = vectors[:, top]

                # Rescale to the grid's integral normalization.
                try:
                    state = self.normalize_wavefunction(state)
                except Exception as norm_error:
                    warnings.warn(f"Normalization failed: {norm_error}")

                return (state, spectrum[top]) if return_phases else state

            except np.linalg.LinAlgError as linalg_error:
                warnings.warn(f"Linear algebra error: {linalg_error}. Using fallback.")
                return self._fallback_wavefunction(return_phases)

        except Exception as e:
            warnings.warn(f"Critical error in density_matrix_to_wavefunction: {e}. Using fallback.")
            return self._fallback_wavefunction(return_phases)

    def _fallback_wavefunction(self, return_phases=False):
        """Produce a stand-in wavefunction for use when a real one cannot be computed.

        The primary choice is exp(-x²) on the grid, passed through
        ``normalize_wavefunction``. If that raises, a complex array that is
        zero everywhere except for a 1 at the middle index is used instead.
        With ``return_phases=True`` the result is paired with the value 1.0.
        """
        try:
            state = self.normalize_wavefunction(np.exp(-(self.x)**2))
        except Exception:
            # Last resort: all weight on the central grid point.
            state = np.zeros(self.grid_size, dtype=np.complex128)
            state[self.grid_size // 2] = 1.0

        if return_phases:
            return state, 1.0
        return state

    def normalize_wavefunction(self, psi):
        """Rescale ``psi`` so that the Simpson-rule integral of |psi|² over the grid is 1.

        Best-effort behaviour (each case emits a warning):
        - ``None`` or empty input returns a complex zero array of grid size.
        - A length different from the grid size returns ``psi`` unchanged.
        - NaN/inf entries are replaced by zero before the norm is computed.
        - A non-finite or vanishing (<= 1e-12) norm returns a normalized
          exp(-x²) instead of the input.
        - If the integration raises, a rectangle-rule norm is tried.
        """
        try:
            if psi is None or len(psi) == 0:
                warnings.warn("Cannot normalize None or empty wavefunction")
                return np.zeros(self.grid_size, dtype=np.complex128)

            n_samples = len(psi)
            if n_samples != self.grid_size:
                warnings.warn(f"Wavefunction size {n_samples} doesn't match grid size {self.grid_size}")
                return psi

            # Sanitise NaN/inf entries first.
            if not np.all(np.isfinite(psi)):
                warnings.warn("Invalid values in wavefunction before normalization")
                psi = np.nan_to_num(psi, nan=0.0, posinf=0.0, neginf=0.0)

            try:
                density = np.abs(psi)**2
                # Squaring can still overflow even after sanitising psi.
                if not np.all(np.isfinite(density)):
                    warnings.warn("Invalid probability density")
                    density = np.nan_to_num(density, nan=0.0, posinf=0.0, neginf=0.0)

                norm = np.sqrt(simps(density, self.x))

                norm_is_usable = np.isfinite(norm) and norm > 1e-12
                if norm_is_usable:
                    return psi / norm

                warnings.warn(f"Invalid norm {norm}, cannot normalize. Returning unit Gaussian.")
                # Substitute a normalized Gaussian for the unusable input.
                bump = np.exp(-(self.x)**2)
                bump_norm = np.sqrt(simps(np.abs(bump)**2, self.x))
                if bump_norm > 1e-12:
                    return bump / bump_norm
                return bump

            except Exception as integration_error:
                warnings.warn(f"Integration error during normalization: {integration_error}")
                # Rectangle-rule norm as a simpler alternative.
                rect_norm = np.sqrt(np.sum(np.abs(psi)**2) * self.dx)
                if rect_norm > 1e-12:
                    return psi / rect_norm
                warnings.warn("Cannot normalize with fallback method")
                return psi

        except Exception as e:
            warnings.warn(f"Critical error in normalize_wavefunction: {e}. Returning input.")
            if psi is not None:
                return psi
            return np.zeros(self.grid_size, dtype=np.complex128)

    def apply_measurement_collapse(self, psi, measurement_position, collapse_width=1.0):
        """
        Localize a wavefunction around a measured position.

        ``psi`` is multiplied pointwise by the Gaussian window
        exp(-(x - measurement_position)² / (2 * collapse_width²)) and the
        product is renormalized with ``normalize_wavefunction``.
        """
        displacement = self.x - measurement_position
        two_sigma_sq = 2 * collapse_width**2
        window = np.exp(-displacement**2 / two_sigma_sq)

        return self.normalize_wavefunction(psi * window)

    def tunneling_probability(self, psi, barrier_position=0.0):
        """
        Estimate the tunneling probability across a barrier.

        |psi|² is integrated with Simpson's rule over the grid points
        strictly left and strictly right of ``barrier_position`` and the
        smaller of the two values is returned. The default of 0.0
        corresponds to the central barrier of the double-well potential.

        A side with fewer than two grid points (for example when the
        barrier lies at or beyond the edge of the grid) spans no interval
        and contributes probability 0.0.
        """
        density = np.abs(np.asarray(psi))**2

        def region_probability(mask):
            # Simpson's rule needs at least two samples to cover an interval.
            if np.count_nonzero(mask) < 2:
                return 0.0
            return simps(density[mask], x=self.x[mask])

        prob_left = region_probability(self.x < barrier_position)
        prob_right = region_probability(self.x > barrier_position)

        return min(prob_left, prob_right)