"""Correction mechanism module for quantum-classical boundary emergence.

This module implements sophisticated correction mechanisms including:
- Time-delayed feedback corrections
- Adaptive correction strength based on quantum potential
- Predictive correction algorithms
- Correction efficiency optimization
- Multi-modal correction strategies
"""

import numpy as np
from scipy.integrate import simpson as simps
from collections import deque
import warnings

class CorrectionMechanism:
    """Advanced correction mechanisms for quantum-classical stabilization.

    Each ``*_correction`` method takes a wavefunction ``psi`` sampled on the
    spatial grid and a classical reference position, and returns a tuple
    ``(V_corr, strength)``: ``V_corr`` is a correction potential evaluated on
    the grid and ``strength`` is a scalar summary of the applied correction.
    Every call also appends ⟨x⟩, the classical position and a scalar strength
    to the history buffers.  The buffers are shared by all methods, so the
    history-based methods (time-delayed, predictive, momentum-aware) see
    whatever the previously called method recorded.
    """

    def __init__(self, x_grid, dt, memory_length=50):
        """
        Initialize correction mechanism.

        Parameters:
        -----------
        x_grid : numpy.ndarray
            Spatial grid points; must be strictly increasing.
        dt : float
            Time step size (finite and positive).
        memory_length : int
            Number of time steps to store in memory for delayed feedback

        Raises:
        -------
        ValueError
            If any input is invalid (the message is also printed).
        """
        try:
            # Validate inputs
            if x_grid is None or len(x_grid) == 0:
                raise ValueError("x_grid cannot be None or empty")
            if not np.isfinite(dt) or dt <= 0:
                raise ValueError(f"dt must be positive, got {dt}")
            # Type check first: comparing a non-number with 0 would raise a
            # TypeError instead of the documented ValueError.
            if not isinstance(memory_length, (int, np.integer)) or memory_length <= 0:
                raise ValueError(f"memory_length must be a positive integer, got {memory_length}")
            memory_length = int(memory_length)

            self.x = np.array(x_grid)
            if len(self.x) < 2:
                raise ValueError("x_grid must have at least 2 points")

            self.dx = self.x[1] - self.x[0]
            # Check every spacing, not only the first one, so a
            # non-monotonic grid is rejected up front.
            if np.any(np.diff(self.x) <= 0):
                raise ValueError("x_grid must have positive spacing")

            self.dt = dt
            self.memory_length = memory_length

            # Memory buffers for time-delayed feedback.  Note: no method in
            # this class appends to position_history.
            self.position_history = deque(maxlen=memory_length)
            self.expectation_history = deque(maxlen=memory_length)
            self.classical_history = deque(maxlen=memory_length)
            self.correction_history = deque(maxlen=memory_length)

            # Adaptive parameters
            self.base_lambda = 0.01
            self.adaptive_gain = 1.0
            self.correction_efficiency_window = deque(maxlen=10)
            # (deviation, |strength|) left by the previous call of
            # efficiency_optimized_correction; scored on the next call.
            self._pending_efficiency_sample = None

        except Exception as e:
            print(f"Error initializing CorrectionMechanism: {e}")
            raise

    def _expectation(self, psi):
        """Return ⟨x⟩ = ∫ x |psi(x)|² dx over the grid (Simpson's rule)."""
        return simps(self.x * np.abs(psi) ** 2, self.x)

    def _record(self, x_quantum, x_classical, correction):
        """Append one step to the expectation, classical and correction histories."""
        self.expectation_history.append(x_quantum)
        self.classical_history.append(x_classical)
        self.correction_history.append(correction)

    def basic_correction(self, psi, x_classical, lambda_strength):
        """
        Basic recursive correction: V_corr = -λ(x - ⟨x⟩)(⟨x⟩ - x_classical).

        This is the original correction mechanism from the base simulator.

        Non-finite entries in ``psi`` and non-finite scalars are replaced by
        0.0 (with a warning).  A missing, empty or wrongly sized ``psi``, or
        any unexpected error, yields ``(zeros, 0.0)`` and a warning.

        Returns
        -------
        (numpy.ndarray, float)
            Correction potential on the grid and the scalar strength
            ``-λ(⟨x⟩ - x_classical)``.
        """
        try:
            # Validate inputs
            if psi is None or len(psi) == 0:
                warnings.warn("Invalid wavefunction in basic_correction")
                return np.zeros_like(self.x), 0.0

            if len(psi) != len(self.x):
                warnings.warn(f"Wavefunction size {len(psi)} doesn't match grid size {len(self.x)}")
                return np.zeros_like(self.x), 0.0

            if np.any(np.isnan(psi)) or np.any(np.isinf(psi)):
                warnings.warn("Invalid values in wavefunction for basic_correction")
                psi = np.nan_to_num(psi, nan=0.0, posinf=0.0, neginf=0.0)

            if np.isnan(x_classical) or np.isinf(x_classical):
                warnings.warn(f"Invalid classical position {x_classical}")
                x_classical = 0.0

            if np.isnan(lambda_strength) or np.isinf(lambda_strength):
                warnings.warn(f"Invalid lambda_strength {lambda_strength}")
                lambda_strength = 0.0

            # Calculate quantum expectation value with error handling
            try:
                prob_density = np.abs(psi)**2
                if np.any(np.isnan(prob_density)) or np.any(np.isinf(prob_density)):
                    warnings.warn("Invalid probability density in basic_correction")
                    prob_density = np.nan_to_num(prob_density, nan=0.0, posinf=0.0, neginf=0.0)

                x_quantum = simps(self.x * prob_density, self.x)

                if np.isnan(x_quantum) or np.isinf(x_quantum):
                    warnings.warn("Invalid quantum expectation value")
                    x_quantum = 0.0

            except Exception as integration_error:
                warnings.warn(f"Integration error in basic_correction: {integration_error}")
                # Fallback to a rectangle-rule sum using the first grid spacing
                prob_density = np.abs(psi)**2
                prob_density = np.nan_to_num(prob_density, nan=0.0, posinf=0.0, neginf=0.0)
                x_quantum = np.sum(self.x * prob_density) * self.dx
                if np.isnan(x_quantum) or np.isinf(x_quantum):
                    x_quantum = 0.0

            # Correction strength proportional to deviation
            deviation = x_quantum - x_classical
            if np.isnan(deviation) or np.isinf(deviation):
                warnings.warn("Invalid deviation in basic_correction")
                deviation = 0.0

            correction_strength = -lambda_strength * deviation

            # Correction potential
            position_offset = self.x - x_quantum
            V_corr = correction_strength * position_offset

            # Check for invalid values in correction
            if np.any(np.isnan(V_corr)) or np.any(np.isinf(V_corr)):
                warnings.warn("Invalid values in correction potential")
                V_corr = np.nan_to_num(V_corr, nan=0.0, posinf=0.0, neginf=0.0)

            # Store in memory with validation (the product above can still
            # overflow even though both factors are finite)
            safe_x_quantum = x_quantum if not (np.isnan(x_quantum) or np.isinf(x_quantum)) else 0.0
            safe_x_classical = x_classical if not (np.isnan(x_classical) or np.isinf(x_classical)) else 0.0
            safe_correction_strength = correction_strength if not (np.isnan(correction_strength) or np.isinf(correction_strength)) else 0.0

            self._record(safe_x_quantum, safe_x_classical, safe_correction_strength)

            return V_corr, correction_strength

        except Exception as e:
            warnings.warn(f"Critical error in basic_correction: {e}. Returning zeros.")
            return np.zeros_like(self.x), 0.0

    def time_delayed_correction(self, psi, x_classical, lambda_strength, delay_steps=5):
        """
        Time-delayed feedback correction with memory.

        V_corr(x,t) = -λ(x - ⟨x(t)⟩)(⟨x(t)⟩ - x_classical(t-τ))

        Models realistic observer memory and measurement delay effects.

        The delayed reference is ``classical_history[-delay_steps]``, i.e. the
        classical position recorded ``delay_steps`` calls ago.  When fewer
        entries are stored, or ``delay_steps <= 0``, the current
        ``x_classical`` is used.  If at least two ⟨x⟩ samples are stored, a
        term ``-0.1·λ·(Δ⟨x⟩/dt)`` from the last two samples is added.

        Returns
        -------
        (numpy.ndarray, float)
            Correction potential on the grid and the scalar strength.
        """
        # Current quantum expectation
        x_quantum_current = self._expectation(psi)

        # Delayed classical reference.  delay_steps <= 0 must not index the
        # deque: history[-0] is the *oldest* entry (IndexError when empty).
        if 0 < delay_steps <= len(self.classical_history):
            x_classical_delayed = self.classical_history[-delay_steps]
        else:
            x_classical_delayed = x_classical  # Use current if not enough history

        # Delayed feedback correction
        correction_strength = -lambda_strength * (x_quantum_current - x_classical_delayed)

        # Also incorporate momentum from historical trend
        if len(self.expectation_history) >= 2:
            momentum = (self.expectation_history[-1] - self.expectation_history[-2]) / self.dt
            correction_strength += -0.1 * lambda_strength * momentum  # Damping term

        # Correction potential
        V_corr = correction_strength * (self.x - x_quantum_current)

        self._record(x_quantum_current, x_classical, correction_strength)

        return V_corr, correction_strength

    def adaptive_quantum_potential_correction(self, psi, x_classical, quantum_potential,
                                            base_lambda=None):
        """
        Adaptive correction with intensity modulated by quantum potential.

        In regions of high quantum potential Q(x), apply stronger corrections
        to suppress non-classical behavior more aggressively.

        ``base_lambda`` defaults to ``self.base_lambda`` (also used when it is
        NaN/inf).  The returned scalar is the base strength multiplied by the
        saturating gain ``1 + 0.5·Q̄/(1 + Q̄)``, where Q̄ is the
        |psi|²-weighted mean of |Q|; the gain is also stored in
        ``self.adaptive_gain``.  Invalid inputs or unexpected errors yield
        ``(zeros, 0.0)`` and a warning.
        """
        try:
            # Input validation
            if psi is None or quantum_potential is None:
                warnings.warn("Invalid inputs to adaptive_quantum_potential_correction")
                return np.zeros_like(self.x), 0.0

            if len(psi) != len(self.x) or len(quantum_potential) != len(self.x):
                warnings.warn("Size mismatch in adaptive_quantum_potential_correction")
                return np.zeros_like(self.x), 0.0

            # Clean inputs
            psi = np.nan_to_num(psi, nan=0.0, posinf=0.0, neginf=0.0)
            quantum_potential = np.nan_to_num(quantum_potential, nan=0.0, posinf=0.0, neginf=0.0)

            if base_lambda is None:
                base_lambda = self.base_lambda

            if np.isnan(base_lambda) or np.isinf(base_lambda):
                base_lambda = self.base_lambda

            try:
                # Calculate quantum expectation with error handling
                prob_density = np.abs(psi)**2
                prob_density = np.nan_to_num(prob_density, nan=0.0, posinf=0.0, neginf=0.0)
                x_quantum = simps(self.x * prob_density, self.x)

                if np.isnan(x_quantum) or np.isinf(x_quantum):
                    x_quantum = 0.0

            except Exception:
                x_quantum = 0.0

            # Base correction strength
            base_correction = -base_lambda * (x_quantum - x_classical)
            if np.isnan(base_correction) or np.isinf(base_correction):
                base_correction = 0.0

            try:
                # Adaptive gain based on local quantum potential
                Q_abs = np.abs(quantum_potential)
                Q_abs = np.nan_to_num(Q_abs, nan=0.0, posinf=0.0, neginf=0.0)
                Q_mean = simps(Q_abs * prob_density, self.x)

                if np.isnan(Q_mean) or np.isinf(Q_mean):
                    Q_mean = 0.0

                # Saturating function: gain in [1, 1.5)
                if Q_mean > 0:
                    adaptive_factor = 1.0 + 0.5 * Q_mean / (1.0 + Q_mean)
                else:
                    adaptive_factor = 1.0

                if np.isnan(adaptive_factor) or np.isinf(adaptive_factor):
                    adaptive_factor = 1.0

            except Exception:
                adaptive_factor = 1.0

            try:
                # Position-dependent weights; they sum to 1 over grid points
                # (uniform 1/N when the weighted |Q| is negligible).
                Q_weighted = Q_abs * prob_density
                Q_sum = np.sum(Q_weighted)
                if Q_sum > 1e-12:
                    Q_normalized = Q_weighted / Q_sum
                else:
                    Q_normalized = np.ones_like(Q_weighted) / len(Q_weighted)

                Q_normalized = np.nan_to_num(Q_normalized, nan=0.0, posinf=0.0, neginf=0.0)

                # Apply stronger correction in high-Q regions.
                # NOTE(review): base_correction already contains -base_lambda,
                # so this field scales with base_lambda**2 and has the opposite
                # sign of basic_correction's V_corr; adaptive_factor only
                # affects the returned scalar.  Confirm this is intended.
                position_dependent_lambda = base_lambda * (1.0 + Q_normalized)
                position_offset = self.x - x_quantum
                correction_field = -base_correction * position_dependent_lambda * position_offset

                # Safety check on correction field
                correction_field = np.nan_to_num(correction_field, nan=0.0, posinf=0.0, neginf=0.0)

            except Exception as field_error:
                warnings.warn(f"Error computing correction field: {field_error}")
                correction_field = np.zeros_like(self.x)

            # Store adaptive gain for monitoring
            self.adaptive_gain = adaptive_factor

            # Store in memory with validation
            safe_x_quantum = x_quantum if not (np.isnan(x_quantum) or np.isinf(x_quantum)) else 0.0
            safe_x_classical = x_classical if not (np.isnan(x_classical) or np.isinf(x_classical)) else 0.0
            safe_base_correction = base_correction if not (np.isnan(base_correction) or np.isinf(base_correction)) else 0.0

            self._record(safe_x_quantum, safe_x_classical, safe_base_correction)

            return correction_field, base_correction * adaptive_factor

        except Exception as e:
            warnings.warn(f"Critical error in adaptive_quantum_potential_correction: {e}")
            return np.zeros_like(self.x), 0.0

    def predictive_correction(self, psi, x_classical, lambda_strength, prediction_steps=3):
        """
        Predictive correction based on extrapolating quantum dynamics.

        Predicts future quantum state evolution and applies correction
        to prevent deviation before it becomes large.

        ⟨x⟩ is extrapolated one ``dt`` ahead from finite differences of the
        last ``prediction_steps`` stored samples (velocity needs >= 2, and
        acceleration >= 3).  With ``prediction_steps < 2`` or too little
        history, no extrapolation is done.  The classical position is
        extrapolated linearly from its last two stored values.  The applied
        strength is ``-λ(0.7·current_error + 0.3·prediction_error)``.

        Returns
        -------
        (numpy.ndarray, float)
            Correction potential on the grid and the combined strength.
        """
        x_quantum = self._expectation(psi)
        history = self.expectation_history

        # Predict future quantum position based on recent history.  Requiring
        # prediction_steps >= 2 also avoids the [-0:] slice (whole history).
        x_predicted = x_quantum
        if prediction_steps >= 2 and len(history) >= prediction_steps:
            recent = list(history)[-prediction_steps:]
            velocity = (recent[-1] - recent[-2]) / self.dt
            acceleration = 0.0
            if len(recent) >= 3:
                previous_velocity = (recent[-2] - recent[-3]) / self.dt
                acceleration = (velocity - previous_velocity) / self.dt

            x_predicted = x_quantum + velocity * self.dt + 0.5 * acceleration * self.dt**2

        # Predict future classical position (assuming continuation of current dynamics)
        if len(self.classical_history) >= 2:
            classical_velocity = (self.classical_history[-1] - self.classical_history[-2]) / self.dt
            x_classical_predicted = x_classical + classical_velocity * self.dt
        else:
            x_classical_predicted = x_classical

        # Blend the current-time error with the predicted error
        prediction_error = x_predicted - x_classical_predicted
        current_error = x_quantum - x_classical
        total_correction = -lambda_strength * (0.7 * current_error + 0.3 * prediction_error)

        V_corr = total_correction * (self.x - x_quantum)

        self._record(x_quantum, x_classical, total_correction)

        return V_corr, total_correction

    def multi_scale_correction(self, psi, x_classical, lambda_strength, scales=(1.0, 0.5, 2.0)):
        """
        Multi-scale correction operating on different length scales simultaneously.

        Applies corrections at multiple spatial scales to address both
        local and global quantum deviations.

        Each scale ``s`` contributes an equal share of the base strength
        ``-λ(⟨x⟩ - x_classical)`` shaped by a Gaussian of width ``s`` centred
        on ⟨x⟩ and normalised to peak 1 on the grid.  A scale whose profile
        has no positive finite peak on the grid (e.g. ``s == 0`` or ⟨x⟩ far
        outside the grid) contributes nothing instead of producing NaNs.

        Returns
        -------
        (numpy.ndarray, float)
            Summed correction potential and the base strength.
        """
        x_quantum = self._expectation(psi)
        offset = self.x - x_quantum
        total_correction = np.zeros_like(self.x)

        base_correction_strength = -lambda_strength * (x_quantum - x_classical)

        scales = tuple(scales)
        if scales:
            scale_strength = base_correction_strength / len(scales)
            for scale in scales:
                with np.errstate(divide="ignore", invalid="ignore"):
                    profile = np.exp(-offset**2 / (2 * scale**2))
                peak = np.max(profile)
                if not np.isfinite(peak) or peak <= 0:
                    continue  # profile vanished on the grid; avoid 0/0
                total_correction += scale_strength * (profile / peak) * offset

        self._record(x_quantum, x_classical, base_correction_strength)

        return total_correction, base_correction_strength

    def efficiency_optimized_correction(self, psi, x_classical, base_lambda, target_efficiency=0.8):
        """
        Dynamically adjust correction strength to optimize efficiency.

        Monitors correction efficiency and adjusts λ to maintain
        optimal performance while minimizing energy input.

        Efficiency of a call is the reduction of |⟨x⟩ - x_classical| seen at
        the *next* call of this method, divided by the |strength| applied.
        The mean of the last 10 samples nudges λ by ±5% relative to
        ``base_lambda`` (clipped to [0.1, 5]·base_lambda) when it is outside
        [0.8, 1.2]·target_efficiency.

        Returns
        -------
        (numpy.ndarray, float)
            Correction potential on the grid and the scalar strength.
        """
        x_quantum = self._expectation(psi)
        deviation = abs(x_quantum - x_classical)

        # Score the previous call now that its effect on the deviation is
        # visible (previously before/after were identical, so always 0).
        pending = getattr(self, "_pending_efficiency_sample", None)
        if pending is not None:
            previous_deviation, previous_strength = pending
            self.correction_efficiency_window.append(
                max(0.0, previous_deviation - deviation) / previous_strength)
        self._pending_efficiency_sample = None

        current_lambda = base_lambda
        if self.correction_efficiency_window:
            avg_efficiency = np.mean(list(self.correction_efficiency_window))

            # Adjust lambda based on efficiency
            if avg_efficiency < target_efficiency * 0.8:  # Too low efficiency
                current_lambda *= 1.05  # Increase correction strength
            elif avg_efficiency > target_efficiency * 1.2:  # Too high efficiency (overcompensating)
                current_lambda *= 0.95  # Decrease correction strength

            # Clamp lambda to reasonable bounds
            current_lambda = np.clip(current_lambda, base_lambda * 0.1, base_lambda * 5.0)

        # Apply correction
        correction_strength = -current_lambda * (x_quantum - x_classical)
        V_corr = correction_strength * (self.x - x_quantum)

        # Remember this step so the next call can measure its effect
        if abs(correction_strength) > 1e-12:
            self._pending_efficiency_sample = (deviation, abs(correction_strength))

        self._record(x_quantum, x_classical, correction_strength)

        return V_corr, correction_strength

    def nonlinear_correction(self, psi, x_classical, lambda_strength, nonlinearity=2.0):
        """
        Nonlinear correction mechanism with power-law dependence on deviation.

        V_corr ∝ sign(deviation) * |deviation|^α

        Provides stronger response for large deviations while being
        gentler for small fluctuations.

        Deviations with magnitude <= 1e-12 give zero strength.

        Returns
        -------
        (numpy.ndarray, float)
            Correction potential on the grid and the scalar strength.
        """
        x_quantum = self._expectation(psi)
        deviation = x_quantum - x_classical

        # Nonlinear correction strength
        if abs(deviation) > 1e-12:
            nonlinear_factor = np.sign(deviation) * (abs(deviation)**nonlinearity)
        else:
            nonlinear_factor = 0.0

        correction_strength = -lambda_strength * nonlinear_factor

        V_corr = correction_strength * (self.x - x_quantum)

        self._record(x_quantum, x_classical, correction_strength)

        return V_corr, correction_strength

    def momentum_aware_correction(self, psi, x_classical, v_classical, lambda_strength, mass=1.0):
        """
        Correction that accounts for both position and momentum deviations.

        Incorporates classical velocity to anticipate future motion
        and apply more sophisticated corrections.

        The quantum velocity is the finite difference of ⟨x⟩ against the last
        stored sample (0 when the history is empty).  The strength is
        ``-λ(0.7·Δx + 0.3·Δv)``.  ``mass`` cancels out of the momentum term
        (m·Δv / m); it is kept for backward compatibility only.

        Returns
        -------
        (numpy.ndarray, float)
            Correction potential on the grid and the scalar strength.
        """
        x_quantum = self._expectation(psi)

        # One previous sample is enough for a finite-difference velocity
        if self.expectation_history:
            v_quantum = (x_quantum - self.expectation_history[-1]) / self.dt
        else:
            v_quantum = 0.0

        # Position and velocity errors (momentum error / mass == velocity error)
        position_error = x_quantum - x_classical
        velocity_error = v_quantum - v_classical

        # Combined correction
        alpha_pos = 0.7  # Weight for position correction
        alpha_mom = 0.3  # Weight for momentum correction

        correction_strength = -lambda_strength * (alpha_pos * position_error +
                                                alpha_mom * velocity_error)

        V_corr = correction_strength * (self.x - x_quantum)

        self._record(x_quantum, x_classical, correction_strength)

        return V_corr, correction_strength

    def get_correction_diagnostics(self):
        """Return diagnostic information about correction mechanism performance.

        Returns
        -------
        dict
            Always contains ``adaptive_gain`` and ``memory_usage`` (number of
            stored ⟨x⟩ samples).  Statistics of the stored strengths, of the
            finite-difference quantum velocity and of the efficiency window
            are included only when enough samples exist.
        """
        diagnostics = {}

        if self.correction_history:
            corrections = list(self.correction_history)
            diagnostics['correction_strength_mean'] = np.mean(corrections)
            diagnostics['correction_strength_std'] = np.std(corrections)
            diagnostics['correction_strength_max'] = np.max(np.abs(corrections))

        if len(self.expectation_history) > 1:
            velocities = np.diff(list(self.expectation_history)) / self.dt
            diagnostics['quantum_velocity_mean'] = np.mean(velocities)
            diagnostics['quantum_velocity_std'] = np.std(velocities)

        if self.correction_efficiency_window:
            efficiencies = list(self.correction_efficiency_window)
            diagnostics['correction_efficiency_mean'] = np.mean(efficiencies)
            diagnostics['correction_efficiency_std'] = np.std(efficiencies)

        diagnostics['adaptive_gain'] = self.adaptive_gain
        diagnostics['memory_usage'] = len(self.expectation_history)

        return diagnostics

    def reset_memory(self):
        """Clear all memory buffers and reset to initial state."""
        self.position_history.clear()
        self.expectation_history.clear()
        self.classical_history.clear()
        self.correction_history.clear()
        self.correction_efficiency_window.clear()
        self._pending_efficiency_sample = None
        self.adaptive_gain = 1.0

    def set_base_parameters(self, base_lambda=None, memory_length=None):
        """Update base correction parameters.

        Parameters
        ----------
        base_lambda : float, optional
            New default strength used by adaptive_quantum_potential_correction.
        memory_length : int, optional
            New history length.  The most recent entries of each history
            buffer are preserved.

        Raises
        ------
        ValueError
            If ``memory_length`` is not a positive integer (nothing is
            modified in that case).
        """
        # Validate before mutating anything so a bad call leaves state intact
        if memory_length is not None:
            if not isinstance(memory_length, (int, np.integer)) or memory_length <= 0:
                raise ValueError(f"memory_length must be a positive integer, got {memory_length}")
            memory_length = int(memory_length)

        if base_lambda is not None:
            self.base_lambda = base_lambda

        if memory_length is not None:
            self.memory_length = memory_length
            # deque(iterable, maxlen) keeps the newest items when shrinking
            self.position_history = deque(self.position_history, maxlen=memory_length)
            self.expectation_history = deque(self.expectation_history, maxlen=memory_length)
            self.classical_history = deque(self.classical_history, maxlen=memory_length)
            self.correction_history = deque(self.correction_history, maxlen=memory_length)