"""Input validation and configuration checking module.

This module provides comprehensive validation for all simulation parameters
to prevent crashes and ensure safe operation of the quantum-classical
boundary emergence simulator.
"""

import numpy as np
import warnings
import os
from typing import Dict, Any, List, Union, Tuple


class ConfigValidator:
    """Comprehensive configuration validator for the quantum-classical simulator.

    Each ``_validate_*`` stage takes the (copied) config dict, fills in missing
    keys with defaults, replaces or clamps invalid values, and returns
    ``(config, warnings)``. Problems are reported as warning strings rather
    than raised, so a bad config degrades to safe values instead of crashing.

    Numeric checks reject ``bool`` (which subclasses ``int``) and non-finite
    floats (NaN/inf); NumPy integer and floating scalars are accepted.
    """

    def __init__(self):
        """Initialize the validator with default limits and constraints."""
        # Physical constraints
        self.min_grid_size = 16
        self.max_grid_size = 4096
        self.min_dt = 1e-6
        self.max_dt = 1.0
        self.min_steps = 1
        self.max_steps = 100000
        self.min_physical_constant = 1e-12
        self.max_physical_constant = 1e12

        # Numerical stability limits
        self.max_lambda = 10.0
        self.max_ensemble_size = 1000
        self.max_memory_length = 10000  # not read by any check in this class

        # Valid correction types
        self.valid_correction_types = [
            'basic', 'time_delayed', 'adaptive', 'predictive', 'nonlinear'
        ]

    # ------------------------------------------------------------------
    # Small type helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _is_integer(value: Any) -> bool:
        """Return True for Python/NumPy integers; bool is rejected even though it subclasses int."""
        return (isinstance(value, (int, np.integer))
                and not isinstance(value, (bool, np.bool_)))

    @classmethod
    def _is_finite_number(cls, value: Any) -> bool:
        """Return True for a finite real scalar (int/float or NumPy equivalent, bool excluded)."""
        if cls._is_integer(value):
            # Integers are always finite; skipping np.isfinite avoids overflow on huge ints.
            return True
        return isinstance(value, (float, np.floating)) and bool(np.isfinite(value))

    @staticmethod
    def _trapezoid(y: np.ndarray, x: np.ndarray) -> float:
        """Trapezoid-rule integral of ``y`` over ``x``.

        Uses ``np.trapezoid`` (NumPy >= 2.0) when available and falls back to
        the older, now-deprecated ``np.trapz``.
        """
        integrate = getattr(np, "trapezoid", None) or np.trapz
        return integrate(y, x)

    @classmethod
    def _gaussian_wavefunction(cls, grid_size: int,
                               x_range: Tuple[float, float] = (-5, 5)) -> np.ndarray:
        """Return a complex128 Gaussian ``exp(-(x - x_c)**2)`` on ``grid_size`` points.

        ``x`` spans ``x_range`` inclusively and ``x_c`` is its midpoint (0 for
        the default range). The result is normalized to unit L2 norm with the
        trapezoid rule; normalization is skipped if the integral is zero
        (e.g. a single-point grid).
        """
        x_min, x_max = x_range
        x = np.linspace(x_min, x_max, grid_size)
        center = 0.5 * (x_min + x_max)
        psi = np.exp(-(x - center) ** 2)
        norm = np.sqrt(cls._trapezoid(np.abs(psi) ** 2, x))
        if norm > 0:
            psi = psi / norm
        return psi.astype(np.complex128)

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def validate_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        """
        Validate and sanitize a complete configuration dictionary.

        The caller's dict is not modified; a shallow copy is sanitized. If any
        stage raises unexpectedly, the error is recorded and the minimal safe
        config is returned instead.

        Args:
            config: Configuration mapping. ``None`` is treated as an empty
                config (every value defaulted).

        Returns:
            Tuple of (sanitized_config, list_of_warnings)
        """
        from collections.abc import Mapping

        warnings_list: List[str] = []

        if config is None:
            config = {}
        if not isinstance(config, Mapping):
            warnings_list.append(
                f"Critical error during validation: config must be a mapping, "
                f"got {type(config).__name__}"
            )
            return self._get_minimal_safe_config(), warnings_list

        sanitized_config = config.copy() if isinstance(config, dict) else dict(config)

        # Order matters: the correction and cross-validation stages read values
        # (n_steps, grid_size, dt, ...) that earlier stages have already
        # defaulted or repaired.
        stages = (
            self._validate_grid_parameters,
            self._validate_physical_parameters,
            self._validate_simulation_parameters,
            self._validate_correction_parameters,
            self._validate_ensemble_parameters,
            self._validate_io_parameters,
            self._cross_validate_parameters,
        )

        try:
            for stage in stages:
                sanitized_config, stage_warnings = stage(sanitized_config)
                warnings_list.extend(stage_warnings)
        except Exception as e:  # last-resort guard: validation must never crash the caller
            warnings_list.append(f"Critical error during validation: {e}")
            sanitized_config = self._get_minimal_safe_config()

        return sanitized_config, warnings_list

    # ------------------------------------------------------------------
    # Validation stages
    # ------------------------------------------------------------------

    def _validate_grid_parameters(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        """Validate ``grid_size`` (clamped to the allowed range) and ``x_range``."""
        warnings_list = []

        # Grid size: an integer within [min_grid_size, max_grid_size]
        if 'grid_size' in config:
            grid_size = config['grid_size']
            if not self._is_integer(grid_size) or grid_size < self.min_grid_size:
                warnings_list.append(f"Invalid grid_size {grid_size}, using {self.min_grid_size}")
                config['grid_size'] = self.min_grid_size
            elif grid_size > self.max_grid_size:
                warnings_list.append(f"Grid size {grid_size} too large, using {self.max_grid_size}")
                config['grid_size'] = self.max_grid_size
            else:
                config['grid_size'] = int(grid_size)  # normalize NumPy integers to int
        else:
            config['grid_size'] = 256
            warnings_list.append("Missing grid_size, using default 256")

        # X range: two finite numbers with x_min < x_max
        default_range = (-5, 5)
        if 'x_range' in config:
            x_range = config['x_range']
            if not isinstance(x_range, (list, tuple)) or len(x_range) != 2:
                warnings_list.append(f"Invalid x_range format {x_range}, using (-5, 5)")
                config['x_range'] = default_range
            else:
                x_min, x_max = x_range
                if not (self._is_finite_number(x_min) and self._is_finite_number(x_max)):
                    warnings_list.append("Non-numeric or non-finite values in x_range, using (-5, 5)")
                    config['x_range'] = default_range
                elif x_max <= x_min:
                    warnings_list.append(f"Invalid x_range {x_range} (x_max <= x_min), using (-5, 5)")
                    config['x_range'] = default_range
                elif abs(x_max - x_min) > 100:
                    warnings_list.append(f"Very large x_range {x_range}, may cause memory issues")
        else:
            config['x_range'] = default_range
            warnings_list.append("Missing x_range, using default (-5, 5)")

        return config, warnings_list

    def _validate_physical_parameters(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        """Validate ``mass`` and the reduced Planck constant ``hbar`` (finite, > 0, default 1.0)."""
        warnings_list = []
        for key in ('mass', 'hbar'):
            self._validate_positive_constant(config, key, warnings_list)
        return config, warnings_list

    def _validate_positive_constant(self, config: Dict[str, Any], key: str,
                                    warnings_list: List[str], default: float = 1.0) -> None:
        """Sanitize ``config[key]`` in place as a finite, strictly positive constant.

        A missing key is silently set to ``default``; an invalid value is
        replaced by ``default`` with a warning. Values outside
        [min_physical_constant, max_physical_constant] are kept but flagged.
        """
        if key not in config:
            config[key] = default
            return
        value = config[key]
        if not self._is_finite_number(value) or value <= 0:
            warnings_list.append(f"Invalid {key} {value}, using {default}")
            config[key] = default
        elif value < self.min_physical_constant or value > self.max_physical_constant:
            warnings_list.append(f"Extreme {key} value {value}, may cause numerical issues")

    def _validate_simulation_parameters(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        """Validate ``dt``, ``n_steps`` and ``diagnostic_sampling``."""
        warnings_list = []

        # Time step: must be finite and positive. Values above max_dt are
        # clamped; values below min_dt are kept but flagged.
        if 'dt' in config:
            dt = config['dt']
            if not self._is_finite_number(dt) or dt <= 0:
                warnings_list.append(f"Invalid dt {dt}, using 0.01")
                config['dt'] = 0.01
            elif dt < self.min_dt:
                warnings_list.append(f"Very small dt {dt}, may be slow")
            elif dt > self.max_dt:
                warnings_list.append(f"Large dt {dt}, may be unstable")
                config['dt'] = self.max_dt
        else:
            config['dt'] = 0.01

        # Number of steps: values above max_steps are only flagged
        if 'n_steps' in config:
            n_steps = config['n_steps']
            if not self._is_integer(n_steps) or n_steps < self.min_steps:
                warnings_list.append(f"Invalid n_steps {n_steps}, using 100")
                config['n_steps'] = 100
            else:
                config['n_steps'] = int(n_steps)
                if n_steps > self.max_steps:
                    warnings_list.append(f"Very large n_steps {n_steps}, may take very long")
        else:
            config['n_steps'] = 100

        # Diagnostic sampling interval: must not exceed n_steps. The default
        # (10) is reduced to 1 for very short runs so it obeys the same rule.
        n_steps = config['n_steps']
        default_sampling = 10 if n_steps >= 10 else 1
        if 'diagnostic_sampling' in config:
            diag_sampling = config['diagnostic_sampling']
            if not self._is_integer(diag_sampling) or diag_sampling < 1:
                warnings_list.append(f"Invalid diagnostic_sampling {diag_sampling}, using {default_sampling}")
                config['diagnostic_sampling'] = default_sampling
            elif diag_sampling > n_steps:
                warnings_list.append("diagnostic_sampling larger than n_steps, adjusting")
                config['diagnostic_sampling'] = max(1, n_steps // 10)
            else:
                config['diagnostic_sampling'] = int(diag_sampling)
        else:
            config['diagnostic_sampling'] = default_sampling

        return config, warnings_list

    def _validate_correction_parameters(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        """Validate ``correction_type``, ``lambda_values`` and ``correction_delay_steps``."""
        warnings_list = []

        # Correction type: must be one of valid_correction_types
        if 'correction_type' in config:
            corr_type = config['correction_type']
            # isinstance guard: `array in list` would raise on ambiguous truth values
            if not isinstance(corr_type, str) or corr_type not in self.valid_correction_types:
                warnings_list.append(f"Invalid correction_type {corr_type}, using 'basic'")
                config['correction_type'] = 'basic'
        else:
            config['correction_type'] = 'basic'

        # Lambda values: finite numbers, clamped to [-max_lambda, max_lambda]
        if 'lambda_values' in config:
            lambda_values = config['lambda_values']
            if not isinstance(lambda_values, (list, tuple)):
                warnings_list.append("Invalid lambda_values format, using [0.0, 0.05]")
                config['lambda_values'] = [0.0, 0.05]
            else:
                sanitized_lambdas = []
                for lam in lambda_values:
                    if not self._is_finite_number(lam):
                        warnings_list.append(f"Invalid lambda value {lam} removed")
                    elif abs(lam) > self.max_lambda:
                        warnings_list.append(f"Large lambda value {lam} clamped to {self.max_lambda}")
                        # Comparison (not np.sign) keeps this safe for huge ints
                        clamped = self.max_lambda if lam > 0 else -self.max_lambda
                        sanitized_lambdas.append(float(clamped))
                    else:
                        sanitized_lambdas.append(float(lam))

                if not sanitized_lambdas:
                    warnings_list.append("No valid lambda values, using [0.0]")
                    sanitized_lambdas = [0.0]

                config['lambda_values'] = sanitized_lambdas
        else:
            config['lambda_values'] = [0.0, 0.05]

        # Correction delay steps: non-negative integer; large values only flagged
        if 'correction_delay_steps' in config:
            delay_steps = config['correction_delay_steps']
            if not self._is_integer(delay_steps) or delay_steps < 0:
                warnings_list.append(f"Invalid correction_delay_steps {delay_steps}, using 5")
                config['correction_delay_steps'] = 5
            else:
                config['correction_delay_steps'] = int(delay_steps)
                if delay_steps > config.get('n_steps', 100) // 10:
                    warnings_list.append("correction_delay_steps very large compared to n_steps")
        else:
            config['correction_delay_steps'] = 5

        return config, warnings_list

    def _validate_ensemble_parameters(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        """Validate ``ensemble_size`` and the ensemble/physics boolean flags."""
        warnings_list = []

        # Ensemble size: positive integer; values above max_ensemble_size only flagged
        if 'ensemble_size' in config:
            ens_size = config['ensemble_size']
            if not self._is_integer(ens_size) or ens_size < 1:
                warnings_list.append(f"Invalid ensemble_size {ens_size}, using 5")
                config['ensemble_size'] = 5
            else:
                config['ensemble_size'] = int(ens_size)
                if ens_size > self.max_ensemble_size:
                    warnings_list.append(f"Large ensemble_size {ens_size}, may use lots of memory")
        else:
            config['ensemble_size'] = 5

        self._validate_bool_flags(
            config,
            ('use_ensemble', 'use_lindblad', 'use_measurement_collapse',
             'use_2d_extension', 'ensemble_phase_randomization'),
            warnings_list,
        )

        return config, warnings_list

    @staticmethod
    def _validate_bool_flags(config: Dict[str, Any], names: Tuple[str, ...],
                             warnings_list: List[str]) -> None:
        """Sanitize optional boolean flags in place.

        Flags that are absent are left absent. ``bool``/``np.bool_`` values are
        normalized to ``bool``; anything else is replaced by ``False`` with a
        warning.
        """
        for name in names:
            if name not in config:
                continue
            value = config[name]
            if isinstance(value, (bool, np.bool_)):
                config[name] = bool(value)
            else:
                warnings_list.append(f"Invalid {name} {value}, using False")
                config[name] = False

    def _validate_io_parameters(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        """Validate ``output_dir`` and the boolean I/O flags."""
        import re

        warnings_list = []

        # Output directory: non-empty string
        if 'output_dir' in config:
            output_dir = config['output_dir']
            if not isinstance(output_dir, str) or not output_dir.strip():
                warnings_list.append("Invalid output_dir, using 'safe_results'")
                config['output_dir'] = 'safe_results'
            else:
                # Replace path separators and characters reserved in Windows
                # file names, so the result contains no '/' or '\'.
                sanitized_dir = re.sub(r'[<>:"/\\|?*]', '_', output_dir.strip())
                if sanitized_dir != output_dir:
                    warnings_list.append(f"Sanitized output_dir from '{output_dir}' to '{sanitized_dir}'")
                    config['output_dir'] = sanitized_dir
        else:
            config['output_dir'] = 'safe_results'

        self._validate_bool_flags(
            config,
            ('save_raw_data', 'save_wigner_evolution', 'create_boundary_map'),
            warnings_list,
        )

        return config, warnings_list

    def _cross_validate_parameters(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
        """Cross-check already-sanitized parameters; only adds warnings, never edits values."""
        warnings_list = []

        # Time step stability (heuristic)
        try:
            x_min, x_max = config['x_range']
            dx = (x_max - x_min) / config['grid_size']
            # Rough stability criterion for split-step method
            critical_dt = 0.1 * config['mass'] * dx**2 / config['hbar']

            if config['dt'] > critical_dt:
                warnings_list.append(
                    f"Time step {config['dt']} may be too large for stability (critical: {critical_dt:.6f})"
                )
        except Exception as e:  # best-effort diagnostic; report and continue
            warnings_list.append(f"Could not check time step stability: {e}")

        # Memory usage estimation: reads config directly so it does not depend
        # on the stability check above having succeeded.
        try:
            grid_size = config['grid_size']
            total_grid_points = grid_size**2 if config.get('use_2d_extension', False) else grid_size
            n_steps = config['n_steps']
            ensemble_size = config['ensemble_size'] if config.get('use_ensemble', False) else 1

            # Rough memory estimation (in MB), 16 bytes per complex128 sample
            estimated_memory = (total_grid_points * n_steps * ensemble_size * 16) / (1024**2)

            if estimated_memory > 1000:  # More than ~1 GB
                warnings_list.append(f"Estimated memory usage: {estimated_memory:.1f} MB - may be excessive")
        except Exception as e:  # best-effort diagnostic; report and continue
            warnings_list.append(f"Could not estimate memory usage: {e}")

        return config, warnings_list

    def _get_minimal_safe_config(self) -> Dict[str, Any]:
        """Return a minimal safe configuration as fallback."""
        return {
            'grid_size': 64,
            'x_range': (-3, 3),
            'dt': 0.05,
            'n_steps': 50,
            'mass': 1.0,
            'hbar': 1.0,
            'use_lindblad': False,
            'use_ensemble': False,
            'use_measurement_collapse': False,
            'use_2d_extension': False,
            'correction_type': 'basic',
            'lambda_values': [0.0],
            'correction_delay_steps': 1,
            'ensemble_size': 1,
            'ensemble_phase_randomization': False,
            'diagnostic_sampling': 10,
            'save_wigner_evolution': False,
            'create_boundary_map': False,
            'output_dir': 'minimal_safe_results',
            'save_raw_data': False
        }

    def validate_wavefunction(self, psi: np.ndarray, grid_size: int,
                              x_range: Tuple[float, float] = (-5, 5)) -> Tuple[np.ndarray, List[str]]:
        """Validate and sanitize a wavefunction array.

        Args:
            psi: Sampled wavefunction (array-like) or ``None``.
            grid_size: Expected number of samples. Longer inputs are truncated;
                shorter inputs are replaced by a Gaussian.
            x_range: Spatial extent ``(x_min, x_max)`` of the grid, used for
                the trapezoid-rule norm and for any replacement Gaussian.
                Defaults to ``(-5, 5)``, the previously hard-coded extent.

        Returns:
            Tuple of (complex128 wavefunction, list_of_warnings). NaN/inf
            samples are zeroed; a zero-norm input is replaced by a Gaussian; a
            norm deviating from 1 by more than 0.1 is renormalized.
        """
        warnings_list: List[str] = []

        try:
            if psi is None:
                warnings_list.append("Wavefunction is None, creating unit Gaussian")
                return self._gaussian_wavefunction(grid_size, x_range), warnings_list

            # Convert to numpy array if needed
            psi = np.array(psi, dtype=np.complex128)

            # Check size
            if len(psi) != grid_size:
                warnings_list.append(f"Wavefunction size {len(psi)} doesn't match grid size {grid_size}")
                if len(psi) > grid_size:
                    psi = psi[:grid_size]
                else:
                    psi = self._gaussian_wavefunction(grid_size, x_range)
                    warnings_list.append("Created new Gaussian wavefunction")

            # Check for invalid values (NaN or +/-inf in either component)
            if not np.all(np.isfinite(psi)):
                warnings_list.append("Invalid values in wavefunction, cleaning")
                psi = np.nan_to_num(psi, nan=0.0, posinf=0.0, neginf=0.0)

            # Check normalization on the grid spanned by x_range
            x = np.linspace(x_range[0], x_range[1], len(psi))
            norm = np.sqrt(self._trapezoid(np.abs(psi) ** 2, x))
            if norm < 1e-12:
                warnings_list.append("Wavefunction has zero norm, creating unit Gaussian")
                psi = self._gaussian_wavefunction(grid_size, x_range)
            elif abs(norm - 1.0) > 0.1:
                warnings_list.append(f"Wavefunction poorly normalized (norm={norm:.6f}), renormalizing")
                psi = psi / norm

            return psi, warnings_list

        except Exception as e:  # never crash: fall back to a fresh Gaussian
            warnings_list.append(f"Critical error validating wavefunction: {e}, creating fallback")
            return self._gaussian_wavefunction(grid_size, x_range), warnings_list


def validate_and_sanitize_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Validate and sanitize a configuration dictionary.

    Convenience wrapper around ``ConfigValidator().validate_config`` that
    emits every validation message through :func:`warnings.warn` (as a
    ``UserWarning``) instead of returning them.

    Args:
        config: Configuration dictionary to validate

    Returns:
        Sanitized configuration dictionary
    """
    validator = ConfigValidator()
    sanitized_config, warnings_list = validator.validate_config(config)

    # stacklevel=2 attributes each warning to the line that called this
    # helper rather than to this loop, so the reported location is useful.
    for message in warnings_list:
        warnings.warn(message, UserWarning, stacklevel=2)

    return sanitized_config


def create_safe_config(
    grid_size: int = 128,
    time_steps: int = 100,
    lambda_values: Union[List[float], None] = None
) -> Dict[str, Any]:
    """
    Create a safe configuration with validated parameters.

    The config is built from the arguments plus fixed defaults
    (``x_range=(-5, 5)``, ``dt=0.02``, ``correction_type='basic'``, no
    ensemble) and then passed through ``validate_and_sanitize_config``, so
    any validation messages are emitted as warnings.

    Args:
        grid_size: Spatial grid size
        time_steps: Number of time evolution steps
        lambda_values: List of correction strengths to test; defaults to
            ``[0.0, 0.05]`` when omitted

    Returns:
        Safe configuration dictionary
    """
    if lambda_values is None:
        lambda_values = [0.0, 0.05]

    # Aim for ~10 diagnostic samples per run. Guard the integer division so a
    # non-int time_steps reaches the validator (which reports and replaces
    # it) instead of raising TypeError here.
    if isinstance(time_steps, int) and not isinstance(time_steps, bool):
        diagnostic_sampling = max(1, time_steps // 10)
    else:
        diagnostic_sampling = 10

    config = {
        'grid_size': grid_size,
        'x_range': (-5, 5),
        'dt': 0.02,
        'n_steps': time_steps,
        'lambda_values': lambda_values,
        'correction_type': 'basic',
        'use_ensemble': False,
        'diagnostic_sampling': diagnostic_sampling,
        'output_dir': 'safe_simulation_results'
    }

    return validate_and_sanitize_config(config)