"""
Environmental Coupling and Decoherence Effects
==============================================

Realistic environmental coupling mechanisms that drive quantum-to-classical
transitions through decoherence, dissipation, and noise effects.
"""

import numpy as np
import warnings
from typing import Dict, List, Tuple, Optional, Callable, Any
from collections import deque
from scipy.integrate import simpson as simps
from scipy.special import erf
from scipy.fft import fft, ifft, fftfreq

class EnvironmentalCoupling:
    """Comprehensive environmental coupling for quantum-classical transitions.

    Applies one time step (``dt``) of environmental effects -- thermal and
    Lindblad-type decoherence, stochastic noise, dissipation, weak
    measurement and a confining "classical" potential -- to a 1-D
    wavefunction sampled on ``x_grid``.

    The public per-step methods do not raise on numerical failure: they emit a
    ``warnings.warn`` message and return the input wavefunction unchanged.

    Random numbers come from a per-instance ``numpy.random.RandomState``
    seeded with ``noise_seed`` (42), so runs are reproducible without
    resetting NumPy's global random state.
    """

    # Parameters that ``set_environmental_parameters`` may assign directly.
    _SETTABLE_PARAMETERS = (
        'temperature',
        'position_decoherence_rate',
        'momentum_decoherence_rate',
        'energy_decoherence_rate',
        'damping_coefficient',
        'friction_coefficient',
        'correlation_time',
    )

    def __init__(self, x_grid: np.ndarray, dt: float, hbar: float = 1.0, mass: float = 1.0):
        """Initialize environmental coupling system.

        Args:
            x_grid: 1-D position grid. ``dx`` is taken from its first two
                points, so the grid is presumably uniform.
            dt: time step used by every per-step method.
            hbar: reduced Planck constant (natural units by default).
            mass: particle mass.

        Raises:
            ValueError: if ``x_grid`` is not 1-D or has fewer than two points.
        """
        x_grid = np.asarray(x_grid)
        if x_grid.ndim != 1 or x_grid.size < 2:
            raise ValueError("x_grid must be a 1-D array with at least two points")

        self.x = x_grid
        self.dx = x_grid[1] - x_grid[0]
        self.dt = dt
        self.hbar = hbar
        self.mass = mass

        # Environment parameters
        self.temperature = 1.0  # Environmental temperature
        self.kB = 1.0  # Boltzmann constant (in natural units)

        # Noise generators: bounded history of colored-noise fields
        self.noise_history = deque(maxlen=1000)
        self.correlation_time = 10.0  # Environmental correlation time

        # Decoherence parameters
        self.position_decoherence_rate = 0.01
        self.momentum_decoherence_rate = 0.01
        # NOTE: not read by any method of this class; kept for callers/diagnostics.
        self.energy_decoherence_rate = 0.005

        # Dissipation parameters
        self.damping_coefficient = 0.1
        self.friction_coefficient = 0.05

        # Random seed for reproducible noise. A private generator is used
        # instead of np.random.seed so constructing an instance does not reset
        # the global NumPy random state for the rest of the program.
        self.noise_seed = 42
        self._rng = np.random.RandomState(self.noise_seed)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _random_state(self) -> np.random.RandomState:
        """Return this instance's RNG, creating it lazily if ``__init__`` was bypassed."""
        rng = getattr(self, '_rng', None)
        if rng is None:
            rng = np.random.RandomState(getattr(self, 'noise_seed', 42))
            self._rng = rng
        return rng

    def _normalized_density(self, psi: np.ndarray) -> np.ndarray:
        """Return ``|psi|**2`` rescaled to integrate to one over ``self.x``.

        Several steps are not norm-preserving, so moments such as ``<x>`` must
        be taken against the normalized density. A (near-)zero wavefunction
        is returned unscaled.
        """
        density = np.abs(psi)**2
        total = simps(density, x=self.x)
        if total > 1e-12:
            density = density / total
        return density

    def _normalize(self, psi: np.ndarray) -> np.ndarray:
        """Scale ``psi`` to unit L2 norm on ``self.x``; a (near-)zero ``psi`` is returned as-is."""
        norm = np.sqrt(simps(np.abs(psi)**2, x=self.x))
        if norm > 1e-12:
            return psi / norm
        return psi

    # ------------------------------------------------------------------
    # Decoherence channels
    # ------------------------------------------------------------------

    def thermal_decoherence(self, psi: np.ndarray, temperature: Optional[float] = None) -> np.ndarray:
        """
        Apply thermal decoherence effects that drive quantum-classical transition.

        Two effects are applied for one step ``self.dt``:

        1. a real Gaussian envelope ``exp(-(x - <x>)**2 * dt / lambda_th)``,
           centred on the mean position. Here ``lambda_th`` is the thermal de
           Broglie wavelength ``sqrt(2*pi*hbar**2 / (m*kB*T))``.
        2. independent random phase kicks at every grid point, whose amplitude
           grows with ``sqrt(T)``.

        The envelope is not unitary and the result is NOT renormalized.

        Args:
            psi: wavefunction sampled on ``self.x`` (real or complex).
            temperature: bath temperature; defaults to ``self.temperature``.
                ``T == 0`` is the limit in which both effects vanish, so
                ``psi`` is returned unchanged. A negative temperature also
                returns ``psi``, with a warning.

        Returns:
            The decohered (complex) wavefunction, or ``psi`` on failure.
        """
        try:
            if temperature is None:
                temperature = self.temperature
            if temperature <= 0:
                if temperature < 0:
                    warnings.warn(f"Thermal decoherence skipped: negative temperature {temperature}")
                return psi

            # Thermal de Broglie wavelength sets the damping length scale.
            lambda_thermal = np.sqrt(2 * np.pi * self.hbar**2 / (self.mass * self.kB * temperature))
            decoherence_rate = 1.0 / lambda_thermal

            x_center = simps(self.x * self._normalized_density(psi), x=self.x)
            thermal_damping = np.exp(-decoherence_rate * (self.x - x_center)**2 * self.dt)
            psi_decohered = psi * thermal_damping

            # NOTE(review): the noise std already carries sqrt(dt) and the
            # phase multiplies by dt again, so the phase variance scales as
            # dt**3 -- confirm this scaling is intended.
            thermal_noise_amplitude = np.sqrt(self.kB * temperature * self.dt / self.hbar)
            thermal_noise = self._random_state().normal(0, thermal_noise_amplitude, len(self.x))
            thermal_phase = thermal_noise * self.dt / self.hbar

            # Out-of-place multiply: an in-place `*=` would fail to cast the
            # complex phase factor into a real-valued psi.
            return psi_decohered * np.exp(1j * thermal_phase)

        except Exception as e:
            warnings.warn(f"Thermal decoherence failed: {e}")
            return psi

    def position_momentum_decoherence(self, psi: np.ndarray) -> np.ndarray:
        """
        Apply Lindblad-type decoherence in position and momentum.

        Builds the pure-state density matrix ``rho = |psi><psi|``. It then takes
        one explicit Euler step of the Lindblad dissipator
        ``sum_k (L_k rho L_k^dag - {L_k^dag L_k, rho} / 2)``, first with
        ``L_x = sqrt(gamma_x) x`` and then with ``L_p = sqrt(gamma_p) p``.
        Finally it returns the eigenvector of the largest eigenvalue of the
        result.

        The returned state is normalized on ``self.x``. Its global phase is
        aligned with ``psi`` (``<result|psi>`` real and positive) instead of
        being whatever the eigensolver chose.

        Cost: dense ``N x N`` matrices and an ``O(N**3)`` eigendecomposition,
        where ``N = len(self.x)``.
        """
        try:
            n_points = len(self.x)
            rho = np.outer(psi, np.conj(psi))

            # Position channel. L_x is real and diagonal, so the dissipator
            # reduces element-wise to -gamma_x/2 * (x_i - x_j)**2 * rho_ij.
            gamma_x = self.position_decoherence_rate
            separation_sq = (self.x[:, None] - self.x[None, :])**2
            rho_new = rho - self.dt * 0.5 * gamma_x * separation_sq * rho

            # Momentum channel. Central-difference p = -i*hbar*d/dx gives
            # p[i, i+1] = -i*hbar/(2dx) and p[i, i-1] = +i*hbar/(2dx).
            gamma_p = self.momentum_decoherence_rate
            hop = 1j * self.hbar * np.sqrt(gamma_p) / (2 * self.dx)
            off_diagonal = np.full(n_points - 1, hop, dtype=complex)
            L_p = np.diag(-off_diagonal, k=1) + np.diag(off_diagonal, k=-1)
            L_p_dag = L_p.conj().T
            LdL = L_p_dag @ L_p
            rho_new = rho_new + self.dt * (
                L_p @ rho_new @ L_p_dag - 0.5 * (LdL @ rho_new + rho_new @ LdL)
            )

            # Dominant eigenstate (eigh returns real eigenvalues).
            eigenvals, eigenvecs = np.linalg.eigh(rho_new)
            psi_decohered = eigenvecs[:, np.argmax(eigenvals)]

            # Remove the solver's arbitrary global phase.
            overlap = np.vdot(psi_decohered, psi)
            if abs(overlap) > 1e-12:
                psi_decohered = psi_decohered * (overlap / abs(overlap))

            return self._normalize(psi_decohered)

        except Exception as e:
            warnings.warn(f"Position-momentum decoherence failed: {e}")
            return psi

    # ------------------------------------------------------------------
    # Stochastic noise
    # ------------------------------------------------------------------

    def environmental_noise_coupling(self, psi: np.ndarray, noise_type: str = 'white') -> np.ndarray:
        """
        Apply various types of environmental noise coupling.

        Args:
            psi: wavefunction sampled on ``self.x``.
            noise_type: one of ``'white'``, ``'colored'``, ``'telegraph'``,
                ``'thermal'``. Unknown types warn and return ``psi``.

        Returns:
            ``psi`` multiplied by a random phase field, or ``psi`` on failure.
        """
        try:
            handlers = {
                'white': self._white_noise_coupling,
                'colored': self._colored_noise_coupling,
                'telegraph': self._telegraph_noise_coupling,
                'thermal': self._thermal_noise_coupling,
            }
            handler = handlers.get(noise_type)
            if handler is None:
                warnings.warn(f"Unknown noise type: {noise_type}")
                return psi
            return handler(psi)

        except Exception as e:
            warnings.warn(f"Environmental noise coupling failed: {e}")
            return psi

    def _white_noise_coupling(self, psi: np.ndarray) -> np.ndarray:
        """White noise environmental coupling: independent N(0, 0.1) potential at each grid point."""
        noise_strength = 0.1
        white_noise = self._random_state().normal(0, noise_strength, len(self.x))

        # NOTE(review): the potential carries sqrt(dt) and the phase another
        # dt, so the phase std scales as dt**1.5 -- confirm intended.
        noise_potential = white_noise * np.sqrt(self.dt)
        phase_fluctuation = -1j * noise_potential * self.dt / self.hbar

        return psi * np.exp(phase_fluctuation)

    def _colored_noise_coupling(self, psi: np.ndarray) -> np.ndarray:
        """Exponentially correlated (AR(1) / Ornstein-Uhlenbeck-type) noise coupling.

        Each call draws a new noise field that keeps a fraction
        ``exp(-dt / correlation_time)`` of the previous field, so successive
        steps stay correlated over ``correlation_time``. This is
        exponentially correlated noise, not 1/f noise. The field is appended
        to ``noise_history`` and applied as a phase ``-noise * dt / hbar``.
        """
        rng = self._random_state()
        n_points = len(self.x)

        if not self.noise_history:
            # Seed the history with 50 independent fields. Only the most recent
            # one feeds the update below.
            for _ in range(50):
                self.noise_history.append(rng.normal(0, 0.1, n_points))

        correlation_factor = np.exp(-self.dt / self.correlation_time)
        innovation = rng.normal(0, 0.1, n_points)
        new_noise = (correlation_factor * self.noise_history[-1]
                     + np.sqrt(1 - correlation_factor**2) * innovation)
        self.noise_history.append(new_noise)

        phase_fluctuation = -1j * new_noise * self.dt / self.hbar
        return psi * np.exp(phase_fluctuation)

    def _telegraph_noise_coupling(self, psi: np.ndarray) -> np.ndarray:
        """Telegraph (random switching) noise coupling.

        Each grid point carries a +/-1 state, stored in ``_telegraph_state``
        and initialised to +1. Each state flips with probability
        ``0.1 * dt`` per call. The state is applied as a potential of
        amplitude 0.2.
        """
        current_noise = getattr(self, '_telegraph_state', None)
        if current_noise is None:
            current_noise = np.ones(len(self.x))

        switch_probability = 0.1 * self.dt
        switches = self._random_state().random_sample(len(self.x)) < switch_probability
        current_noise[switches] *= -1
        self._telegraph_state = current_noise

        noise_amplitude = 0.2
        noise_potential = noise_amplitude * current_noise
        phase_fluctuation = -1j * noise_potential * self.dt / self.hbar

        return psi * np.exp(phase_fluctuation)

    def _thermal_noise_coupling(self, psi: np.ndarray) -> np.ndarray:
        """Thermal noise: Gaussian potential with std ``sqrt(kB*T)``, applied over ``sqrt(dt)``."""
        thermal_amplitude = np.sqrt(self.kB * self.temperature)
        thermal_noise = self._random_state().normal(0, thermal_amplitude, len(self.x))

        phase_fluctuation = -1j * thermal_noise * np.sqrt(self.dt) / self.hbar
        return psi * np.exp(phase_fluctuation)

    # ------------------------------------------------------------------
    # Dissipation, measurement and spatial effects
    # ------------------------------------------------------------------

    def dissipative_coupling(self, psi: np.ndarray) -> np.ndarray:
        """
        Apply dissipative coupling that drives the system toward classical states.

        1. Damp each momentum component by ``exp(-damping_coefficient*|k|*dt)``
           in Fourier space. This is not norm-preserving.
        2. Apply a harmonic phase ``friction_coefficient/2 * (x - <x>)**2``
           centred on the damped state's mean position.

        The result is NOT renormalized.
        """
        try:
            k = fftfreq(len(psi), d=self.dx) * 2 * np.pi
            damping_factor = np.exp(-self.damping_coefficient * np.abs(k) * self.dt)
            psi_damped = ifft(fft(psi) * damping_factor)

            # Mean position of the (unnormalized) damped state.
            x_mean = simps(self.x * self._normalized_density(psi_damped), x=self.x)

            friction_potential = 0.5 * self.friction_coefficient * (self.x - x_mean)**2
            friction_phase = -1j * friction_potential * self.dt / self.hbar

            return psi_damped * np.exp(friction_phase)

        except Exception as e:
            warnings.warn(f"Dissipative coupling failed: {e}")
            return psi

    def measurement_induced_decoherence(self, psi: np.ndarray, measurement_strength: float = 0.1) -> np.ndarray:
        """
        Simulate continuous weak measurement that induces decoherence.

        A position readout ``<x> + noise`` is drawn, with
        ``noise ~ N(0, 1/sqrt(s))`` where ``s`` is ``measurement_strength``.
        ``psi`` is multiplied by the square root of a Gaussian of width
        ``1/sqrt(s)`` centred on that readout, then renormalized.

        Args:
            psi: wavefunction sampled on ``self.x``.
            measurement_strength: ``s > 0``. ``s == 0`` means no measurement
                and returns ``psi`` unchanged. Negative or NaN values warn and
                return ``psi``.

        Returns:
            The normalized post-measurement wavefunction, or ``psi``.
        """
        try:
            if measurement_strength == 0:
                return psi  # no measurement, no back-action
            if not measurement_strength > 0:
                warnings.warn(
                    f"Measurement-induced decoherence skipped: invalid measurement strength {measurement_strength}"
                )
                return psi

            x_measured = simps(self.x * self._normalized_density(psi), x=self.x)

            readout_std = 1.0 / np.sqrt(measurement_strength)
            x_measured_noisy = x_measured + self._random_state().normal(0, readout_std)

            # Back-action: localize around the noisy readout.
            localization_width = 1.0 / np.sqrt(measurement_strength)
            measurement_operator = np.exp(-(self.x - x_measured_noisy)**2 / (2 * localization_width**2))

            return self._normalize(psi * np.sqrt(measurement_operator))

        except Exception as e:
            warnings.warn(f"Measurement-induced decoherence failed: {e}")
            return psi

    def spatial_correlation_decoherence(self, psi: np.ndarray, correlation_length: float = 2.0) -> np.ndarray:
        """
        Apply spatially correlated decoherence that preserves local coherence.

        Multiplies ``psi`` by ``exp(-rate(x) * |x - <x>| * dt)``, where
        ``rate(x) = 0.05 * (1 + noise(x))`` and ``noise`` is spatially
        correlated Gaussian noise. Rates are clipped at zero, so the step can
        only damp, never amplify. The result is NOT renormalized.
        """
        try:
            spatial_noise = self._generate_correlated_noise(correlation_length)

            # A decoherence rate must not be negative: noise below -1 would
            # otherwise turn damping into exponential growth.
            local_decoherence_rate = np.maximum(0.05 * (1.0 + spatial_noise), 0.0)

            x_center = simps(self.x * self._normalized_density(psi), x=self.x)
            spatial_damping = np.exp(-local_decoherence_rate * np.abs(self.x - x_center) * self.dt)

            return psi * spatial_damping

        except Exception as e:
            warnings.warn(f"Spatial correlation decoherence failed: {e}")
            return psi

    def _generate_correlated_noise(self, correlation_length: float) -> np.ndarray:
        """Generate spatially correlated Gaussian noise.

        White N(0, 1) noise is convolved with a unit-sum Gaussian kernel whose
        width is ``correlation_length`` in position units. A zero or
        non-finite width returns the white noise itself (no correlation).
        """
        rng = self._random_state()
        n_points = len(self.x)
        try:
            white_noise = rng.normal(0, 1, n_points)

            width = abs(correlation_length / self.dx)  # kernel width in grid steps
            if width == 0 or not np.isfinite(width):
                return white_noise

            offsets = np.arange(n_points) - n_points // 2
            kernel = np.exp(-0.5 * (offsets / width)**2)
            kernel = kernel / np.sum(kernel)

            return np.convolve(white_noise, kernel, mode='same')

        except Exception:
            return rng.normal(0, 1, n_points)

    def environment_driven_classical_force(
        self,
        psi: np.ndarray,
        target_classical_state: float = 0.0,
        force_strength: float = 0.2,
        localization_coefficient: float = 0.1,
    ) -> np.ndarray:
        """
        Environmental force that drives the system toward a classical state.

        Applies, as a phase over one step, the sum of two potentials:

        * ``force_strength/2 * (x - target_classical_state)**2``, which
          attracts toward the target; and
        * ``localization_coefficient * sigma_x * (x - <x>)**2``, which is
          stronger for more delocalized states (``sigma_x`` is the position
          spread).

        ``<x>`` and ``sigma_x`` use the normalized density, so the result
        scales linearly with ``psi``.
        """
        try:
            density = self._normalized_density(psi)
            x_quantum = simps(self.x * density, x=self.x)
            variance = simps((self.x - x_quantum)**2 * density, x=self.x)
            x_spread = np.sqrt(max(variance, 0.0))

            attractive_potential = 0.5 * force_strength * (self.x - target_classical_state)**2
            localization_strength = localization_coefficient * x_spread
            localization_potential = localization_strength * (self.x - x_quantum)**2

            total_env_potential = attractive_potential + localization_potential
            env_phase = -1j * total_env_potential * self.dt / self.hbar
            return psi * np.exp(env_phase)

        except Exception as e:
            warnings.warn(f"Environment-driven classical force failed: {e}")
            return psi

    # ------------------------------------------------------------------
    # Orchestration, diagnostics and configuration
    # ------------------------------------------------------------------

    def comprehensive_environmental_evolution(self, psi: np.ndarray, env_config: Dict[str, Any]) -> np.ndarray:
        """
        Apply comprehensive environmental evolution with multiple effects.

        Effects enabled by boolean keys in ``env_config`` are applied in this
        fixed order: thermal decoherence, position-momentum decoherence,
        environmental noise, dissipative coupling, measurement decoherence,
        spatial correlation decoherence, classical force. The final state is
        normalized on ``self.x``.
        """
        try:
            psi_evolved = np.array(psi)

            if env_config.get('thermal_decoherence', False):
                temperature = env_config.get('temperature', self.temperature)
                psi_evolved = self.thermal_decoherence(psi_evolved, temperature)

            if env_config.get('position_momentum_decoherence', False):
                psi_evolved = self.position_momentum_decoherence(psi_evolved)

            if env_config.get('environmental_noise', False):
                noise_type = env_config.get('noise_type', 'white')
                psi_evolved = self.environmental_noise_coupling(psi_evolved, noise_type)

            if env_config.get('dissipative_coupling', False):
                psi_evolved = self.dissipative_coupling(psi_evolved)

            if env_config.get('measurement_decoherence', False):
                measurement_strength = env_config.get('measurement_strength', 0.1)
                psi_evolved = self.measurement_induced_decoherence(psi_evolved, measurement_strength)

            if env_config.get('spatial_correlation_decoherence', False):
                correlation_length = env_config.get('correlation_length', 2.0)
                psi_evolved = self.spatial_correlation_decoherence(psi_evolved, correlation_length)

            if env_config.get('classical_force', False):
                target_state = env_config.get('target_classical_state', 0.0)
                psi_evolved = self.environment_driven_classical_force(psi_evolved, target_state)

            return self._normalize(psi_evolved)

        except Exception as e:
            warnings.warn(f"Comprehensive environmental evolution failed: {e}")
            return psi

    def get_environmental_diagnostics(self) -> Dict[str, Any]:
        """Get diagnostic information about environmental effects.

        Returns the current parameter values and, once colored noise has been
        generated, the std and mean of the most recent noise field.
        """
        diagnostics = {}

        try:
            diagnostics['temperature'] = self.temperature
            diagnostics['position_decoherence_rate'] = self.position_decoherence_rate
            diagnostics['momentum_decoherence_rate'] = self.momentum_decoherence_rate
            diagnostics['energy_decoherence_rate'] = self.energy_decoherence_rate
            diagnostics['damping_coefficient'] = self.damping_coefficient
            diagnostics['friction_coefficient'] = self.friction_coefficient
            diagnostics['correlation_time'] = self.correlation_time
            diagnostics['noise_history_length'] = len(self.noise_history)

            if self.noise_history:
                recent_noise = self.noise_history[-1]
                diagnostics['current_noise_amplitude'] = np.std(recent_noise)
                diagnostics['current_noise_mean'] = np.mean(recent_noise)

        except Exception as e:
            warnings.warn(f"Environmental diagnostics calculation failed: {e}")

        return diagnostics

    def set_environmental_parameters(self, **kwargs):
        """Set environmental parameters.

        Accepts any name in ``_SETTABLE_PARAMETERS`` plus ``noise_seed``, which
        also reseeds this instance's random generator. Unknown names are
        ignored with a warning, so typos are not silently lost.
        """
        for name, value in kwargs.items():
            if name in self._SETTABLE_PARAMETERS:
                setattr(self, name, value)
            elif name == 'noise_seed':
                self.noise_seed = value
                self._rng = np.random.RandomState(value)
            else:
                warnings.warn(f"Unknown environmental parameter ignored: {name}")

    def create_strong_decoherence_profile(self) -> Dict[str, Any]:
        """Create a configuration profile for strong decoherence effects."""
        return {
            'thermal_decoherence': True,
            'temperature': 2.0,  # High temperature
            'position_momentum_decoherence': True,
            'environmental_noise': True,
            'noise_type': 'colored',
            'dissipative_coupling': True,
            'measurement_decoherence': True,
            'measurement_strength': 0.2,  # Strong measurement
            'spatial_correlation_decoherence': True,
            'correlation_length': 1.0,  # Short correlation length
            'classical_force': True,
            'target_classical_state': 0.0
        }

    def create_moderate_decoherence_profile(self) -> Dict[str, Any]:
        """Create a configuration profile for moderate decoherence effects."""
        return {
            'thermal_decoherence': True,
            'temperature': 1.0,
            'position_momentum_decoherence': True,
            'environmental_noise': True,
            'noise_type': 'white',
            'dissipative_coupling': False,
            'measurement_decoherence': True,
            'measurement_strength': 0.05,
            'spatial_correlation_decoherence': False,
            'classical_force': True,
            'target_classical_state': 0.0
        }

    def create_minimal_decoherence_profile(self) -> Dict[str, Any]:
        """Create a configuration profile for minimal decoherence effects."""
        return {
            'thermal_decoherence': True,
            'temperature': 0.5,
            'position_momentum_decoherence': False,
            'environmental_noise': False,
            'dissipative_coupling': False,
            'measurement_decoherence': False,
            'spatial_correlation_decoherence': False,
            'classical_force': False
        }