"""Simulation control module for quantum-classical boundary emergence.

This module provides comprehensive simulation control including:
- Toggle modes: open system, Lindblad evolution, collapse injection
- Ensemble averaging for density matrix mode
- Parameter sweeps and boundary mapping
- Integration of all physics modules
- Advanced visualization and analysis
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
import os
from datetime import datetime
import json
from multiprocessing import Pool
import warnings

from physics import QuantumPhysics
from diagnostics import QuantumDiagnostics
from phase_space import PhaseSpaceAnalysis
from correction import CorrectionMechanism

class QuantumClassicalSimulator:
    """Main simulation controller with all enhanced capabilities.

    Wires the project's physics, diagnostics, phase-space and correction
    modules together, runs single or ensemble time evolutions over a range of
    correction strengths (lambda), and saves / plots the results.
    """

    # Correction types dispatched by run_single_simulation; anything else is
    # replaced by 'basic' during config validation.
    _CORRECTION_TYPES = ('basic', 'time_delayed', 'adaptive', 'predictive', 'nonlinear')

    def __init__(self, config=None):
        """Initialize the enhanced quantum-classical boundary simulator.

        Args:
            config: Optional dict of overrides merged on top of the defaults
                from ``_default_config``. Invalid values are replaced by safe
                fallbacks (with a warning) in ``_validate_config``.

        If validation or module construction raises, the simulator falls back
        to ``_setup_minimal_config`` (discarding the overrides) and retries
        module setup once; a second failure propagates as ``RuntimeError``.
        """
        try:
            self.config = self._default_config()
            if config:
                self.config.update(config)

            # Validate and sanitize configuration
            self._validate_config()

            # Initialize modules with error handling
            self._setup_modules()

        except Exception as e:
            print(f"Error initializing QuantumClassicalSimulator: {e}")
            print("Falling back to minimal configuration...")
            self._setup_minimal_config()
            self._setup_modules()

    @staticmethod
    def _default_config():
        """Return a fresh default configuration dict (safe to mutate)."""
        return {
            # Grid parameters
            'grid_size': 512,
            'x_range': (-10, 10),
            'dt': 0.005,
            'n_steps': 1000,

            # Physical parameters
            'mass': 1.0,
            'hbar': 1.0,

            # Simulation modes
            'use_lindblad': False,
            'use_ensemble': False,
            'use_measurement_collapse': False,
            'use_2d_extension': False,

            # Correction parameters
            'correction_type': 'basic',  # 'basic', 'time_delayed', 'adaptive', 'predictive', 'nonlinear'
            'lambda_values': [0.0, 0.01, 0.02, 0.05, 0.1],
            'correction_delay_steps': 5,

            # Lindblad parameters
            'gamma_x': 0.01,
            'gamma_p': 0.01,

            # Measurement parameters
            'measurement_times': [0.5],
            'measurement_positions': [0.0],
            'collapse_width': 1.0,

            # Ensemble parameters
            'ensemble_size': 10,
            'ensemble_phase_randomization': True,

            # Analysis parameters
            'diagnostic_sampling': 10,
            'save_wigner_evolution': True,
            'create_boundary_map': True,

            # Output parameters
            'output_dir': 'enhanced_quantum_classical_results',
            'save_raw_data': True
        }

    @staticmethod
    def _try_build(label, factory):
        """Call ``factory()``; on failure print the error and return None."""
        try:
            return factory()
        except Exception as e:
            print(f"Error initializing {label}: {e}")
            return None

    def _setup_modules(self):
        """Initialize all physics and analysis modules.

        QuantumPhysics is mandatory: if it cannot be built from the current
        config, a 128-point grid on (-5, 5) with mass = hbar = 1 is tried and
        the config is updated to match, so later code (and the saved config)
        agrees with the grid actually in use. The diagnostics, phase-space and
        correction modules are optional: on failure each is set to None.

        Raises:
            RuntimeError: if QuantumPhysics cannot be constructed at all.
        """
        try:
            self.physics = QuantumPhysics(
                grid_size=self.config['grid_size'],
                x_range=self.config['x_range'],
                mass=self.config['mass'],
                hbar=self.config['hbar']
            )
        except Exception as e:
            print(f"Error initializing QuantumPhysics: {e}")
            fallback = {'grid_size': 128, 'x_range': (-5, 5), 'mass': 1.0, 'hbar': 1.0}
            try:
                self.physics = QuantumPhysics(**fallback)
            except Exception as fallback_error:
                print(f"Critical error in _setup_modules: {fallback_error}")
                raise RuntimeError("Failed to initialize simulation modules") from fallback_error
            self.config.update(fallback)
            print("Using minimal physics configuration")

        self.diagnostics = self._try_build(
            'QuantumDiagnostics',
            lambda: QuantumDiagnostics(self.physics.x, self.physics.dx))
        self.phase_space = self._try_build(
            'PhaseSpaceAnalysis',
            lambda: PhaseSpaceAnalysis(self.physics.x, self.physics.k, self.config['hbar']))
        self.correction = self._try_build(
            'CorrectionMechanism',
            lambda: CorrectionMechanism(self.physics.x, self.config['dt']))

    def _validate_config(self):
        """Validate and sanitize ``self.config`` in place.

        Missing keys are filled from ``_default_config``. Every check is
        independent: a bad or wrongly-typed value for one key is replaced by
        its fallback (with a warning) without stopping the remaining checks.
        NumPy integer/float scalars are accepted; bools and non-finite
        numbers are rejected.
        """
        cfg = self.config
        for key, value in self._default_config().items():
            cfg.setdefault(key, value)

        def is_int(value):
            return isinstance(value, (int, np.integer)) and not isinstance(value, bool)

        def is_finite_real(value):
            return (isinstance(value, (int, float, np.integer, np.floating))
                    and not isinstance(value, bool)
                    and bool(np.isfinite(value)))

        # Positive integer parameters: (key, fallback)
        for key, fallback in (('grid_size', 256), ('n_steps', 100),
                              ('ensemble_size', 5), ('diagnostic_sampling', 10)):
            value = cfg[key]
            if is_int(value) and value > 0:
                cfg[key] = int(value)
            else:
                warnings.warn(f"Invalid {key} {value}, using {fallback}")
                cfg[key] = fallback

        # Positive finite real parameters: (key, fallback)
        for key, fallback in (('dt', 0.01), ('mass', 1.0), ('hbar', 1.0)):
            value = cfg[key]
            if not (is_finite_real(value) and value > 0):
                warnings.warn(f"Invalid {key} {value}, using {fallback}")
                cfg[key] = fallback

        x_range = cfg['x_range']
        range_ok = (isinstance(x_range, (list, tuple)) and len(x_range) == 2
                    and all(is_finite_real(v) for v in x_range)
                    and x_range[1] > x_range[0])
        if not range_ok:
            warnings.warn(f"Invalid x_range {x_range}, using (-5, 5)")
            cfg['x_range'] = (-5, 5)

        lambdas = cfg['lambda_values']
        if isinstance(lambdas, np.ndarray):
            lambdas = lambdas.tolist()
        if isinstance(lambdas, (list, tuple)) and lambdas and all(is_finite_real(v) for v in lambdas):
            cfg['lambda_values'] = [float(v) for v in lambdas]
        else:
            warnings.warn("Invalid lambda_values, using [0.0, 0.05]")
            cfg['lambda_values'] = [0.0, 0.05]

        if cfg['correction_type'] not in self._CORRECTION_TYPES:
            warnings.warn(f"Unknown correction_type {cfg['correction_type']!r}, using 'basic'")
            cfg['correction_type'] = 'basic'

    def _setup_minimal_config(self):
        """Set up a small, complete configuration for fallback operation.

        Starts from the full defaults so every key read elsewhere (Lindblad,
        measurement and ensemble parameters) is present, then shrinks the grid
        and run length and disables optional features.
        """
        self.config = self._default_config()
        self.config.update({
            'grid_size': 128,
            'x_range': (-5, 5),
            'dt': 0.02,
            'n_steps': 50,
            'mass': 1.0,
            'hbar': 1.0,
            'use_lindblad': False,
            'use_ensemble': False,
            'use_measurement_collapse': False,
            'use_2d_extension': False,
            'correction_type': 'basic',
            'lambda_values': [0.0],
            'diagnostic_sampling': 10,
            'save_wigner_evolution': False,
            'create_boundary_map': False,
            'output_dir': 'minimal_results',
            'save_raw_data': False
        })

    def create_initial_state(self, state_type='double_gaussian', **kwargs):
        """Create a normalized initial wavefunction on ``self.physics.x``.

        Args:
            state_type: One of
                - 'double_gaussian': exp(-(x + 2)^2) + e^{i phi} exp(-(x - 2)^2),
                  where phi is the optional kwarg ``relative_phase`` (default 0).
                - 'coherent_state': Gaussian centred at ``x0`` (0.0) of width
                  ``sigma`` (1.0) carrying momentum ``p0`` (0.0).
                - 'random_superposition': ``n_components`` (5) Gaussians with
                  centres in [-5, 5], phases in [0, 2pi) and amplitudes in
                  [0.5, 1.5], drawn from the global NumPy RNG.

        Returns:
            The wavefunction after ``self.physics.normalize_wavefunction``.

        Raises:
            ValueError: for an unknown ``state_type``.
        """
        x = self.physics.x
        if state_type == 'double_gaussian':
            # Standard double Gaussian superposition
            psi_left = np.exp(-(x + 2) ** 2)
            psi_right = np.exp(-(x - 2) ** 2)
            relative_phase = kwargs.get('relative_phase', 0.0)
            if relative_phase:
                psi_right = psi_right * np.exp(1j * relative_phase)
            psi = psi_left + psi_right

        elif state_type == 'coherent_state':
            x0 = kwargs.get('x0', 0.0)
            p0 = kwargs.get('p0', 0.0)
            sigma = kwargs.get('sigma', 1.0)
            envelope = np.exp(-(x - x0) ** 2 / (2 * sigma ** 2))
            # Out-of-place multiply: the envelope is real and an in-place
            # `*=` with a complex factor raises a casting error.
            psi = envelope * np.exp(1j * p0 * x / self.config['hbar'])

        elif state_type == 'random_superposition':
            n_components = kwargs.get('n_components', 5)
            psi = np.zeros_like(x, dtype=complex)
            for _ in range(n_components):
                # Draw order (centre, phase, amplitude) kept for reproducibility.
                x_center = np.random.uniform(-5, 5)
                phase = np.random.uniform(0, 2 * np.pi)
                amplitude = np.random.uniform(0.5, 1.5)
                psi += amplitude * np.exp(-(x - x_center) ** 2) * np.exp(1j * phase)

        else:
            raise ValueError(f"Unknown state type: {state_type}")

        return self.physics.normalize_wavefunction(psi)

    def run_single_simulation(self, lambda_val, save_evolution=True, initial_state=None):
        """Run a single time evolution with correction strength ``lambda_val``.

        Args:
            lambda_val: Correction strength; values <= 0 disable the correction.
                Non-finite or non-numeric values are replaced by 0.0 with a warning.
            save_evolution: When True, store Wigner snapshots at diagnostic
                steps (if ``save_wigner_evolution`` is enabled) and wavefunction
                snapshots every ``2 * diagnostic_sampling`` steps.
            initial_state: Optional initial wavefunction sampled on
                ``self.physics.x``; defaults to ``create_initial_state()``.

        Returns:
            dict with 'lambda', 'time_points', 'diagnostics',
            'wavefunction_evolution', 'wigner_evolution',
            'correction_diagnostics', 'final_state', 'final_wigner',
            'successful_steps', 'total_steps', 'failure_count'. An 'error'
            key is added if the run aborted outside per-step error handling.

        Time labels refer to the state *after* a step: step ``s`` is reported
        at time ``(s + 1) * dt``.
        """
        try:
            value = float(lambda_val)
        except (TypeError, ValueError):
            value = float('nan')
        if not np.isfinite(value):
            warnings.warn(f"Invalid lambda_val {lambda_val}, using 0.0")
            value = 0.0
        lambda_val = value

        results = {
            'lambda': lambda_val,
            'time_points': [],
            'diagnostics': [],
            'wavefunction_evolution': [],
            'wigner_evolution': [],
            'correction_diagnostics': []
        }
        try:
            return self._run_time_evolution(results, lambda_val, save_evolution, initial_state)
        except Exception as e:
            print(f"Critical error in run_single_simulation: {e}")
            results['error'] = str(e)
            return results

    def _run_time_evolution(self, results, lambda_val, save_evolution, initial_state):
        """Time-stepping loop behind run_single_simulation; fills and returns ``results``."""
        cfg = self.config
        dt = cfg['dt']
        n_steps = cfg['n_steps']
        sampling = cfg['diagnostic_sampling']
        keep_wigner = save_evolution and cfg.get('save_wigner_evolution', True)
        correcting = lambda_val > 0
        # The adaptive correction consumes the quantum potential every step.
        needs_qp_every_step = correcting and cfg['correction_type'] == 'adaptive'

        psi = self._prepare_initial_state(initial_state)
        psi_initial = psi.copy()

        measurements = self._measurement_schedule() if cfg['use_measurement_collapse'] else {}
        if 0 in measurements:
            psi, _ = self._apply_collapse(psi, measurements[0])

        # double_well_potential() takes no time argument, so it is evaluated
        # once and reused for every step.
        potential = self.physics.double_well_potential()

        rho = None
        lindblad_ops = None
        hamiltonian = None
        if cfg['use_lindblad']:
            try:
                rho = self.physics.wavefunction_to_density_matrix(psi)
                lindblad_ops = self.physics.create_lindblad_operators(cfg['gamma_x'], cfg['gamma_p'])
                hamiltonian = self._lindblad_hamiltonian(potential)
            except Exception as setup_error:
                warnings.warn(f"Lindblad setup failed: {setup_error}. Using Schrödinger evolution.")
                rho = None

        x_classical_initial, v_classical_initial = 0.0, 0.0  # classical reference starts at rest at x = 0
        successful_steps = 0
        failure_count = 0
        max_failures = max(1, min(10, n_steps // 10))  # tolerate a few failed steps
        wigner_snapshots = []

        for step in range(n_steps):
            time_index = step + 1
            current_time = time_index * dt
            diag_step = step % sampling == 0
            try:
                # NOTE(review): the classical reference is re-integrated from t = 0
                # each step (O(n_steps^2) overall); caching it would require knowing
                # how classical_trajectory_enhanced maps rows to times.
                trajectory = self.physics.classical_trajectory_enhanced(
                    x_classical_initial, v_classical_initial, current_time, dt,
                    self.physics.double_well_potential
                )
                x_classical = trajectory[-1, 0] if len(trajectory) > 0 else x_classical_initial

                psi, rho = self._advance_state(psi, rho, hamiltonian, lindblad_ops, potential, dt)
                if not self._check_wavefunction_stability(psi):
                    warnings.warn(f"Wavefunction became unstable at step {step}, attempting recovery")
                    psi = self._recover_wavefunction(psi, psi_initial)
                    if rho is not None:
                        rho = self.physics.wavefunction_to_density_matrix(psi)

                quantum_potential = (self._safe_quantum_potential(psi)
                                     if diag_step or needs_qp_every_step else None)

                if correcting:
                    factor = self._correction_factor(psi, x_classical, lambda_val, quantum_potential)
                    if factor is not None:
                        psi = self.physics.normalize_wavefunction(psi * factor)
                        if rho is not None:
                            rho = self._apply_phase_to_rho(rho, factor)

                if time_index in measurements:
                    psi, collapsed = self._apply_collapse(psi, measurements[time_index])
                    if collapsed and rho is not None:
                        # The collapse acts on psi; rebuilding rho from it discards
                        # any mixedness accumulated before the measurement.
                        rho = self.physics.wavefunction_to_density_matrix(psi)

                if diag_step:
                    diag, wigner, corr_diag = self._step_diagnostics(
                        psi, psi_initial, x_classical, rho, quantum_potential,
                        lambda_val, current_time, step)
                    if keep_wigner and wigner is not None:
                        wigner_snapshots.append(wigner)
                    if corr_diag is not None:
                        results['correction_diagnostics'].append(corr_diag)
                    results['diagnostics'].append(diag)
                    results['time_points'].append(current_time)

                if save_evolution and step % (2 * sampling) == 0:
                    results['wavefunction_evolution'].append(psi.copy())

                successful_steps += 1

            except Exception as step_error:
                failure_count += 1
                warnings.warn(f"Error in evolution step {step}: {step_error}")
                if failure_count > max_failures:
                    print(f"Too many failures ({failure_count}), stopping simulation early")
                    break

        results['wigner_evolution'] = wigner_snapshots
        results['final_state'] = psi.copy()
        results['successful_steps'] = successful_steps
        results['total_steps'] = n_steps
        results['failure_count'] = failure_count
        results['final_wigner'] = None
        if self.phase_space is not None:
            try:
                results['final_wigner'] = self.phase_space.compute_wigner_function(psi)
            except Exception as e:
                warnings.warn(f"Error computing final Wigner function: {e}")
        return results

    def _prepare_initial_state(self, initial_state=None):
        """Return a normalized initial wavefunction, falling back to exp(-x^2).

        A caller-supplied ``initial_state`` is copied, checked against the
        grid shape and renormalized.
        """
        try:
            if initial_state is None:
                psi = self.create_initial_state()
            else:
                psi = np.array(initial_state, dtype=complex)
                expected = np.shape(self.physics.x)
                if psi.shape != expected:
                    raise ValueError(f"initial_state has shape {psi.shape}, expected {expected}")
                psi = self.physics.normalize_wavefunction(psi)
            if psi is None:
                raise ValueError("Failed to create initial state")
            return psi
        except Exception as e:
            print(f"Error creating initial state: {e}")
            # Fallback to simple Gaussian
            return self.physics.normalize_wavefunction(np.exp(-self.physics.x ** 2))

    def _measurement_schedule(self):
        """Return ``{time_index: position}`` for the configured measurement collapses.

        Each time t is snapped to the nearest step boundary k = round(t / dt),
        avoiding exact float comparison of accumulated times. The i-th time
        uses ``measurement_positions[i]`` (or the last position when fewer
        positions than times are given). Negative or non-numeric times are
        skipped with a warning.
        """
        dt = self.config['dt']
        times = list(self.config.get('measurement_times', []))
        positions = list(self.config.get('measurement_positions', []))
        if times and not positions:
            warnings.warn("measurement_times given without measurement_positions; skipping measurements")
            return {}
        schedule = {}
        for i, raw_time in enumerate(times):
            try:
                t = float(raw_time)
            except (TypeError, ValueError):
                t = float('nan')
            if not np.isfinite(t) or t < 0:
                warnings.warn(f"Skipping invalid measurement time {raw_time!r}")
                continue
            schedule[int(round(t / dt))] = positions[min(i, len(positions) - 1)]
        return schedule

    def _apply_collapse(self, psi, position):
        """Apply a measurement collapse at ``position``; returns (psi, succeeded)."""
        try:
            collapsed = self.physics.apply_measurement_collapse(psi, position, self.config['collapse_width'])
            return collapsed, True
        except Exception as measurement_error:
            warnings.warn(f"Error applying measurement collapse: {measurement_error}")
            return psi, False

    def _lindblad_hamiltonian(self, potential):
        """Dense position-basis Hamiltonian H = -hbar^2/(2m) d^2/dx^2 + V(x).

        The Laplacian uses second-order central differences on the grid, with
        the wavefunction taken as zero beyond the edges. Note: ``diag(V + k^2)``
        would be diagonal in position and therefore commute with x, so it
        would produce no kinetic motion.
        """
        n_points = len(self.physics.x)
        coeff = self.config['hbar'] ** 2 / (2.0 * self.config['mass'] * self.physics.dx ** 2)
        off_diagonal = np.full(n_points - 1, -coeff)
        return (np.diag(2.0 * coeff + np.asarray(potential))
                + np.diag(off_diagonal, 1)
                + np.diag(off_diagonal, -1))

    def _advance_state(self, psi, rho, hamiltonian, lindblad_ops, potential, dt):
        """Advance one step; returns (psi, rho). ``rho`` is None in Schrödinger mode.

        If a Lindblad step fails, the step falls back to split-step Schrödinger
        evolution and rho is rebuilt from the evolved psi so both stay in sync.
        """
        if rho is None:
            return self.physics.split_step_evolution(psi, potential, dt), None
        try:
            rho = self.physics.lindblad_evolution(rho, hamiltonian, lindblad_ops, dt)
            return self.physics.density_matrix_to_wavefunction(rho), rho
        except Exception as lindblad_error:
            warnings.warn(f"Lindblad evolution failed: {lindblad_error}. Falling back to Schrödinger.")
            psi = self.physics.split_step_evolution(psi, potential, dt)
            return psi, self.physics.wavefunction_to_density_matrix(psi)

    def _safe_quantum_potential(self, psi):
        """Return the quantum potential of psi, or zeros if it is invalid or fails."""
        try:
            quantum_potential = self.physics.quantum_potential(psi)
            if not np.all(np.isfinite(quantum_potential)):
                warnings.warn("Invalid quantum potential values detected")
                return np.zeros_like(self.physics.x)
            return quantum_potential
        except Exception as qp_error:
            warnings.warn(f"Error computing quantum potential: {qp_error}")
            return np.zeros_like(self.physics.x)

    def _correction_factor(self, psi, x_classical, lambda_val, quantum_potential):
        """Return exp(-i V_corr dt / hbar) for the configured correction, or None.

        None means the correction is skipped this step (module missing,
        correction raised, or non-finite potential/factor); a warning is issued.
        """
        correction = self.correction
        if correction is None:
            warnings.warn("Correction mechanism unavailable, skipping correction")
            return None
        cfg = self.config
        dispatch = {
            'basic': lambda: correction.basic_correction(psi, x_classical, lambda_val),
            'time_delayed': lambda: correction.time_delayed_correction(
                psi, x_classical, lambda_val, cfg.get('correction_delay_steps', 5)),
            'adaptive': lambda: correction.adaptive_quantum_potential_correction(
                psi, x_classical, quantum_potential, lambda_val),
            'predictive': lambda: correction.predictive_correction(psi, x_classical, lambda_val),
            'nonlinear': lambda: correction.nonlinear_correction(psi, x_classical, lambda_val),
        }
        try:
            V_corr, _corr_strength = dispatch.get(cfg['correction_type'], dispatch['basic'])()
        except Exception as correction_error:
            warnings.warn(f"Error applying correction: {correction_error}")
            return None

        if V_corr is None or not np.all(np.isfinite(V_corr)):
            warnings.warn("Invalid correction potential, skipping correction")
            return None
        factor = np.exp(-1j * np.asarray(V_corr) * cfg['dt'] / cfg['hbar'])
        if not np.all(np.isfinite(factor)):
            warnings.warn("Invalid correction factor, skipping correction")
            return None
        return factor

    @staticmethod
    def _apply_phase_to_rho(rho, factor):
        """Conjugate rho by diag(factor) (rho_ij -> f_i rho_ij conj(f_j)), preserving its trace."""
        trace_before = np.trace(rho)
        rho = rho * np.outer(factor, np.conj(factor))
        trace_after = np.trace(rho)
        if abs(trace_after) > 0:
            rho = rho * (trace_before / trace_after)
        return rho

    def _step_diagnostics(self, psi, psi_initial, x_classical, rho, quantum_potential,
                          lambda_val, current_time, step):
        """Diagnostics for one sampled step; returns (diag, wigner, corr_diag).

        Uses the comprehensive diagnostics when both the diagnostics and
        phase-space modules exist (with the evolved rho in Lindblad mode),
        otherwise the tolerant ``_safe_diagnostics_calculation`` subset. On
        failure ``diag`` is a minimal {'time', 'step', 'error'} entry.
        """
        wigner = None
        corr_diag = None
        try:
            if self.diagnostics is None or self.phase_space is None:
                diag = self._safe_diagnostics_calculation(psi, psi_initial, x_classical, current_time, step)
            else:
                rho_current = rho if rho is not None else self.physics.wavefunction_to_density_matrix(psi)
                wigner = self.phase_space.compute_wigner_function(psi)
                diag = dict(self.diagnostics.comprehensive_diagnostics(
                    psi, psi_initial, x_classical, rho_current, wigner))

            qp_abs = np.abs(quantum_potential)
            diag['quantum_potential_mean'] = np.mean(qp_abs)
            diag['quantum_potential_max'] = np.max(qp_abs)

            if lambda_val > 0 and self.correction is not None:
                corr_diag = self.correction.get_correction_diagnostics()
                diag.update({f'correction_{k}': v for k, v in corr_diag.items()})

            diag['time'] = current_time
            diag['step'] = step
        except Exception as diag_error:
            warnings.warn(f"Error collecting diagnostics at step {step}: {diag_error}")
            diag = {'time': current_time, 'step': step, 'error': str(diag_error)}
        return diag, wigner, corr_diag

    def run_ensemble_simulation(self, lambda_val, ensemble_size=None):
        """Run ``ensemble_size`` realizations and average their diagnostics.

        With ``ensemble_phase_randomization`` enabled, realization ``r`` starts
        from a double Gaussian whose right-hand component carries a relative
        phase drawn from ``np.random.default_rng(r * 42)``: the global NumPy
        RNG is left untouched and the same phases are reused for every lambda.
        Otherwise every realization uses the default initial state.

        Returns:
            dict with 'lambda', 'ensemble_size', 'individual_results',
            'averaged_diagnostics' (numeric keys averaged with a '<key>_std'
            companion, over the time indices present in every realization)
            and 'time_points'.

        Raises:
            ValueError: if ``ensemble_size`` < 1.
        """
        if ensemble_size is None:
            ensemble_size = self.config['ensemble_size']
        if ensemble_size < 1:
            raise ValueError(f"ensemble_size must be >= 1, got {ensemble_size}")

        ensemble_results = []
        for realization in range(ensemble_size):
            initial_state = None
            if self.config['ensemble_phase_randomization']:
                rng = np.random.default_rng(realization * 42)
                relative_phase = rng.uniform(0, 2 * np.pi)
                try:
                    initial_state = self.create_initial_state('double_gaussian', relative_phase=relative_phase)
                except Exception as e:
                    warnings.warn(f"Could not build randomized initial state: {e}")
            ensemble_results.append(
                self.run_single_simulation(lambda_val, save_evolution=False, initial_state=initial_state))

        averaged_diagnostics = self._average_diagnostics(
            [result.get('diagnostics', []) for result in ensemble_results])
        time_points = list(ensemble_results[0].get('time_points', []))[:len(averaged_diagnostics)]

        return {
            'lambda': lambda_val,
            'ensemble_size': ensemble_size,
            'individual_results': ensemble_results,
            'averaged_diagnostics': averaged_diagnostics,
            'time_points': time_points
        }

    @staticmethod
    def _average_diagnostics(all_diagnostics):
        """Average per-time diagnostics across realizations.

        Only the time indices present in every realization are used. Numeric
        values get their mean and a '<key>_std'; non-numeric values are taken
        from the first realization.
        """
        n_common = min((len(diags) for diags in all_diagnostics), default=0)
        averaged = []
        for time_idx in range(n_common):
            entries = [diags[time_idx] for diags in all_diagnostics]
            time_slice = {}
            for key, first_value in entries[0].items():
                if isinstance(first_value, (int, float, np.integer, np.floating)):
                    values = [entry[key] for entry in entries if key in entry]
                    time_slice[key] = np.mean(values)
                    time_slice[key + '_std'] = np.std(values)
                else:
                    time_slice[key] = first_value  # Keep non-numeric values
            averaged.append(time_slice)
        return averaged

    def create_boundary_map(self, lambda_range=None, metrics=None, results=None):
        """Create quantum-classical boundary phase-diagram data.

        Args:
            lambda_range: Correction strengths; defaults to ``lambda_values``.
            metrics: Diagnostic keys to extract; defaults to
                ['shannon_entropy', 'von_neumann_entropy'].
            results: Optional precomputed results aligned with ``lambda_range``
                (single or ensemble); when given, no simulations are re-run.

        Returns:
            dict with 'lambda_values', 'metrics' (metric -> one time series per
            lambda) and 'final_metrics' (metric -> last value per lambda, NaN if
            the metric never appeared). Lists stay aligned with lambda_values.

        Raises:
            ValueError: if ``results`` does not match ``lambda_range`` in length.
        """
        if lambda_range is None:
            lambda_range = self.config['lambda_values']
        if metrics is None:
            metrics = ['shannon_entropy', 'von_neumann_entropy']
        lambda_range = list(lambda_range)
        if results is not None and len(results) != len(lambda_range):
            raise ValueError("results must contain one entry per lambda value")

        boundary_data = {
            'lambda_values': lambda_range,
            'metrics': {metric: [] for metric in metrics},
            'final_metrics': {metric: [] for metric in metrics}
        }

        for idx, lambda_val in enumerate(lambda_range):
            if results is not None:
                result = results[idx]
            else:
                print(f"Computing boundary map for λ = {lambda_val:.3f}")
                if self.config['use_ensemble']:
                    result = self.run_ensemble_simulation(lambda_val)
                else:
                    result = self.run_single_simulation(lambda_val, save_evolution=False)

            if 'averaged_diagnostics' in result:
                diagnostics = result['averaged_diagnostics']
            else:
                diagnostics = result.get('diagnostics', [])

            for metric in metrics:
                time_series = [diag[metric] for diag in diagnostics if metric in diag]
                boundary_data['metrics'][metric].append(time_series)
                boundary_data['final_metrics'][metric].append(time_series[-1] if time_series else np.nan)

        return boundary_data

    def visualize_boundary_map(self, boundary_data, save_path=None):
        """Create boundary map visualizations.

        Figure 1 shows up to four metrics over time, one line per lambda.
        Figure 2 is a heatmap of the first metric (rows = lambda values,
        columns = time). Unequal-length series are NaN-padded, and rows are
        labelled with the actual lambda values, so non-uniform grids display
        correctly.

        Args:
            boundary_data: Output of ``create_boundary_map``.
            save_path: Optional path for figure 1; the heatmap is saved next to
                it with '_boundary_map' inserted before the extension.

        Returns:
            (fig, ax) of the heatmap figure.

        Raises:
            ValueError: if there are no metrics or lambda values to plot.
        """
        lambda_vals = list(boundary_data['lambda_values'])
        metrics = list(boundary_data['final_metrics'].keys())
        if not metrics or not lambda_vals:
            raise ValueError("boundary_data has no metrics or lambda values to plot")
        total_time = self.config['n_steps'] * self.config['dt']

        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        axes = axes.flatten()
        for ax, metric in zip(axes, metrics):  # Plot up to 4 metrics
            for lambda_val, time_series in zip(lambda_vals, boundary_data['metrics'][metric]):
                time_points = np.linspace(0, total_time, len(time_series))
                ax.plot(time_points, time_series, label=f'λ = {lambda_val:.3f}')
            ax.set_xlabel('Time')
            ax.set_ylabel(metric.replace('_', ' ').title())
            if ax.lines:
                ax.legend()
            ax.grid(True, alpha=0.3)
        for ax in axes[len(metrics):]:
            ax.set_visible(False)
        fig.tight_layout()
        if save_path:
            fig.savefig(save_path, dpi=300, bbox_inches='tight')

        # Heatmap of the first metric, NaN-padded to a rectangle.
        series_list = boundary_data['metrics'][metrics[0]]
        n_rows = min(len(series_list), len(lambda_vals))
        width = max((len(s) for s in series_list[:n_rows]), default=0)
        metric_matrix = np.full((n_rows, max(width, 1)), np.nan)
        for row, time_series in enumerate(series_list[:n_rows]):
            metric_matrix[row, :len(time_series)] = time_series

        fig, ax = plt.subplots(figsize=(12, 8))
        im = ax.imshow(metric_matrix, aspect='auto', origin='lower',
                       extent=[0, total_time, -0.5, n_rows - 0.5],
                       cmap='viridis')
        ax.set_yticks(range(n_rows))
        ax.set_yticklabels([f'{lam:.3f}' for lam in lambda_vals[:n_rows]])
        ax.set_xlabel('Time')
        ax.set_ylabel('Correction Strength λ')
        ax.set_title(f'{metrics[0].replace("_", " ").title()} Evolution')
        fig.colorbar(im, ax=ax)

        if save_path:
            root, ext = os.path.splitext(save_path)
            fig.savefig(f"{root}_boundary_map{ext or '.png'}", dpi=300, bbox_inches='tight')

        return fig, ax

    @staticmethod
    def _json_default(obj):
        """``json.dump`` fallback: NumPy arrays/scalars to Python, complex to {'real', 'imag'}."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, complex):
            return {'real': obj.real, 'imag': obj.imag}
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def save_results(self, results, output_dir=None):
        """Save simulation results under ``<output_dir>/session_<timestamp>``.

        Writes config.json. A list is saved one sub-directory per result
        ('lambda_<value>'), except ``{'boundary_map': ...}`` entries, which go
        to boundary_map.json. A single result is saved to 'single_run'.

        Returns:
            The session directory path.
        """
        if output_dir is None:
            output_dir = self.config['output_dir']

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        session_dir = os.path.join(output_dir, f"session_{timestamp}")
        os.makedirs(session_dir, exist_ok=True)

        # Save configuration
        config_path = os.path.join(session_dir, "config.json")
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(self.config, f, indent=2, default=self._json_default)

        if isinstance(results, list):  # Multiple simulations
            for result in results:
                if 'boundary_map' in result and 'lambda' not in result:
                    boundary_path = os.path.join(session_dir, "boundary_map.json")
                    with open(boundary_path, 'w', encoding='utf-8') as f:
                        json.dump(result['boundary_map'], f, indent=2, default=self._json_default)
                    continue
                self._save_single_result(result, session_dir, f"lambda_{result['lambda']:.3f}")
        else:  # Single result
            self._save_single_result(results, session_dir, "single_run")

        return session_dir

    def _save_single_result(self, result, base_dir, run_name):
        """Save one (single or ensemble) result to ``base_dir/run_name``.

        Writes diagnostics.csv and/or averaged_diagnostics.csv. When
        ``save_raw_data`` is set, also writes .npy files for the final state,
        the final Wigner function and the Wigner evolution, skipping any that
        are None or empty.
        """
        run_dir = os.path.join(base_dir, run_name)
        os.makedirs(run_dir, exist_ok=True)

        if 'diagnostics' in result:
            self._save_diagnostics_csv(result['diagnostics'], os.path.join(run_dir, "diagnostics.csv"))
        if 'averaged_diagnostics' in result:
            self._save_diagnostics_csv(result['averaged_diagnostics'],
                                       os.path.join(run_dir, "averaged_diagnostics.csv"))

        if self.config['save_raw_data']:
            if result.get('final_state') is not None:
                np.save(os.path.join(run_dir, "final_wavefunction.npy"), result['final_state'])

            if result.get('final_wigner') is not None:
                np.save(os.path.join(run_dir, "final_wigner.npy"), result['final_wigner'])

            if result.get('wigner_evolution'):
                wigner_stack = np.array(result['wigner_evolution'])
                np.save(os.path.join(run_dir, "wigner_evolution.npy"), wigner_stack)

    def _save_diagnostics_csv(self, diagnostics, filename):
        """Write a list of diagnostic dicts to ``filename`` as CSV.

        Columns are the sorted union of all keys, and missing values are left
        empty. Values are written as ``str(value)`` and quoted by the csv module,
        so commas or newlines in values (e.g. error messages) cannot corrupt
        the row structure.
        """
        import csv

        if not diagnostics:
            return

        header = sorted({str(key) for diag in diagnostics for key in diag})
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.DictWriter(f, fieldnames=header, restval='')
            writer.writeheader()
            for diag in diagnostics:
                writer.writerow({str(key): str(value) for key, value in diag.items()})

    def run_full_parameter_sweep(self):
        """Run the complete parameter sweep with all configured options.

        Runs one simulation (or ensemble) per ``lambda_values`` entry. If
        ``create_boundary_map`` is set, the boundary map is built from those
        same results rather than re-running them. Everything is then saved.

        Returns:
            (all_results, output_dir)
        """
        print("Starting Enhanced Quantum-Classical Boundary Emergence Simulation")
        print("=" * 70)
        print(f"Grid size: {self.config['grid_size']}")
        print(f"Time steps: {self.config['n_steps']}")
        print(f"Correction type: {self.config['correction_type']}")
        print(f"Lambda values: {self.config['lambda_values']}")
        print(f"Use Lindblad: {self.config['use_lindblad']}")
        print(f"Use ensemble: {self.config['use_ensemble']}")
        print("=" * 70)

        all_results = []
        lambda_values = self.config['lambda_values']
        for i, lambda_val in enumerate(lambda_values, start=1):
            print(f"\nRunning simulation {i}/{len(lambda_values)}: λ = {lambda_val:.3f}")

            if self.config['use_ensemble']:
                result = self.run_ensemble_simulation(lambda_val)
            else:
                result = self.run_single_simulation(lambda_val)

            all_results.append(result)

        if self.config['create_boundary_map']:
            print("\nCreating quantum-classical boundary map...")
            boundary_data = self.create_boundary_map(results=list(all_results))
            all_results.append({'boundary_map': boundary_data})

        output_dir = self.save_results(all_results)
        print(f"\nAll results saved to: {output_dir}")

        return all_results, output_dir

    @staticmethod
    def _integrate(values, x):
        """Trapezoidal integral of ``values`` over ``x`` (np.trapezoid, or np.trapz on NumPy < 2)."""
        integrate = getattr(np, 'trapezoid', None) or np.trapz
        return integrate(values, x)

    def _check_wavefunction_stability(self, psi: np.ndarray) -> bool:
        """Return True if psi is finite, has norm in [1e-12, 100] and max |psi| <= 1e6."""
        try:
            if not np.all(np.isfinite(psi)):
                return False

            norm = np.sqrt(self._integrate(np.abs(psi) ** 2, self.physics.x))
            if not np.isfinite(norm) or not (1e-12 <= norm <= 100):
                return False

            return bool(np.max(np.abs(psi)) <= 1e6)

        except Exception:
            return False

    def _recover_wavefunction(self, psi: np.ndarray, psi_initial: np.ndarray) -> np.ndarray:
        """Attempt to recover a stable wavefunction.

        Tries, in order:
        1. Zero out non-finite entries of psi, renormalize, and use it if stable.
        2. Otherwise return a copy of psi_initial.
        3. If anything raises, return a normalized exp(-x^2).
        """
        try:
            psi_clean = np.nan_to_num(psi, nan=0.0, posinf=0.0, neginf=0.0)

            norm = np.sqrt(self._integrate(np.abs(psi_clean) ** 2, self.physics.x))
            if norm > 1e-12:
                psi_clean = psi_clean / norm
                if self._check_wavefunction_stability(psi_clean):
                    return psi_clean

            warnings.warn("Reverting to initial wavefunction")
            return psi_initial.copy()

        except Exception:
            # Ultimate fallback: create a simple Gaussian
            warnings.warn("Creating fallback Gaussian wavefunction")
            gaussian = np.exp(-self.physics.x ** 2)
            return self.physics.normalize_wavefunction(gaussian)

    def _safe_diagnostics_calculation(self, psi, psi_initial, x_classical, current_time, step):
        """Compute a small, failure-tolerant diagnostics subset.

        Always returns 'time', 'step' and 'position_expectation' (0.0 on
        error). When the diagnostics module is available it also adds
        'shannon_entropy', 'position_uncertainty' and 'von_neumann_entropy',
        each 0.0 if its computation fails. ``psi_initial`` and
        ``x_classical`` are accepted for signature parity but unused.
        """
        diag = {
            'time': current_time,
            'step': step
        }

        try:
            diag['position_expectation'] = float(np.real(
                self._integrate(self.physics.x * np.abs(psi) ** 2, self.physics.x)))
        except Exception:
            diag['position_expectation'] = 0.0

        diagnostics = getattr(self, 'diagnostics', None)
        if diagnostics is None:
            return diag

        calculators = (
            ('shannon_entropy', lambda: diagnostics.shannon_entropy(psi)),
            ('position_uncertainty', lambda: diagnostics.position_uncertainty(psi)),
            ('von_neumann_entropy', lambda: diagnostics.von_neumann_entropy(
                self.physics.wavefunction_to_density_matrix(psi))),
        )
        for key, compute in calculators:
            try:
                diag[key] = compute()
            except Exception:
                diag[key] = 0.0

        return diag