"""
Extended Simulation Capabilities for Long-Time Quantum-Classical Evolution
=========================================================================

Advanced simulation framework designed for long-time quantum evolution studies
to capture quantum-classical boundary emergence over extended time scales.
"""

import numpy as np
import warnings
import time
import os
import json
import pickle
from typing import Dict, List, Tuple, Optional, Callable, Any
from dataclasses import dataclass, asdict
from collections import deque
import h5py
from datetime import datetime, timedelta

@dataclass
class SimulationCheckpoint:
    """Snapshot of the simulation state that is pickled to disk as a checkpoint.

    Instances are created by ``ExtendedQuantumSimulation._save_checkpoint``.
    """
    # Index of the simulation step at which the snapshot was taken.
    step: int
    # Simulation time passed in for that step.
    time: float
    # Copy of the wavefunction array at the checkpoint.
    psi: np.ndarray
    # Diagnostics dictionary handed to the checkpoint writer.
    diagnostics: Dict[str, Any]
    # Result of the correction object's ``get_enhanced_diagnostics()`` when it
    # offers that method, otherwise an empty dict.
    correction_state: Dict[str, Any]
    # Wall-clock creation time formatted as ``%Y%m%d_%H%M%S``.
    timestamp: str

@dataclass
class TransitionEvent:
    """Record of one transition event flagged while monitoring the simulation.

    Instances are created by ``ExtendedQuantumSimulation._detect_transition_events``.
    """
    # Category label; the detector in this file emits 'localization',
    # 'classical_approach' and 'decoherence'.
    event_type: str
    # Simulation step at which the event was detected.
    step: int
    # Simulation time at which the event was detected.
    time: float
    # Size of the effect: the uncertainty drop, the classical deviation, or the
    # inverse coherence length, depending on ``event_type``.
    magnitude: float
    # The detector sets 5 for localization events and 1 otherwise
    # (presumably a count of monitoring samples -- TODO confirm the unit).
    duration: int
    # Human-readable summary that is printed when the event is reported.
    description: str

class ExtendedQuantumSimulation:
    """Advanced simulation framework for extended quantum-classical evolution."""

    def __init__(self, physics, enhanced_correction, config: Dict[str, Any]):
        """Initialize extended simulation framework."""
        self.physics = physics
        self.correction = enhanced_correction
        self.config = config

        # Extended simulation parameters
        self.max_steps = config.get('max_steps', 10000)
        self.checkpoint_interval = config.get('checkpoint_interval', 1000)
        self.adaptive_timestep = config.get('adaptive_timestep', False)
        self.min_dt = config.get('min_dt', 0.0001)
        self.max_dt = config.get('max_dt', 0.01)

        # Memory management
        self.max_memory_snapshots = config.get('max_memory_snapshots', 1000)
        self.compression_level = config.get('compression_level', 6)

        # Real-time monitoring
        self.monitoring_interval = config.get('monitoring_interval', 100)
        self.transition_threshold = config.get('transition_threshold', 0.1)

        # Storage
        self.storage_format = config.get('storage_format', 'hdf5')  # 'hdf5' or 'pickle'
        self.output_dir = config.get('output_dir', 'extended_simulation_results')

        # State tracking
        self.checkpoints = []
        self.transition_events = []
        self.diagnostics_history = deque(maxlen=self.max_memory_snapshots)
        self.performance_metrics = {}

        # Adaptive parameters
        self.current_dt = config.get('dt', 0.005)
        self.stability_monitor = deque(maxlen=50)
        self.convergence_monitor = deque(maxlen=20)

        os.makedirs(self.output_dir, exist_ok=True)

    def run_extended_simulation(self, initial_state: np.ndarray, lambda_schedule: Callable[[float], float],
                               stopping_criteria: Optional[Dict[str, float]] = None) -> Dict[str, Any]:
        """
        Run extended quantum simulation with advanced monitoring and checkpointing.

        Args:
            initial_state: Initial quantum state
            lambda_schedule: Function that returns lambda(t) for time-dependent corrections
            stopping_criteria: Optional early stopping criteria
        """
        print("🚀 EXTENDED QUANTUM-CLASSICAL SIMULATION")
        print("=" * 60)

        start_time = time.time()
        psi = initial_state.copy()
        psi = self.physics.normalize_wavefunction(psi)

        # Initialize tracking
        simulation_data = {
            'config': self.config,
            'start_time': datetime.now().isoformat(),
            'diagnostics_timeline': [],
            'transition_events': [],
            'performance_metrics': {},
            'checkpoints_saved': []
        }

        print(f"📊 Target steps: {self.max_steps}")
        print(f"🕐 Initial dt: {self.current_dt}")
        print(f"💾 Checkpoint interval: {self.checkpoint_interval}")
        print(f"📈 Adaptive timestep: {self.adaptive_timestep}")

        # Main simulation loop
        for step in range(self.max_steps):
            current_time = step * self.current_dt

            try:
                # Get current lambda value
                lambda_val = lambda_schedule(current_time)

                # Perform evolution step with monitoring
                psi, step_diagnostics = self._evolution_step_with_monitoring(
                    psi, current_time, lambda_val, step
                )

                # Adaptive time stepping
                if self.adaptive_timestep:
                    self.current_dt = self._adjust_timestep(step_diagnostics)

                # Store diagnostics
                if step % self.monitoring_interval == 0:
                    extended_diagnostics = self._comprehensive_diagnostics(
                        psi, current_time, lambda_val, step, step_diagnostics
                    )
                    self.diagnostics_history.append(extended_diagnostics)
                    simulation_data['diagnostics_timeline'].append(extended_diagnostics)

                    # Print progress
                    if step % (self.monitoring_interval * 10) == 0:
                        self._print_progress(step, current_time, extended_diagnostics)

                # Check for transition events
                if step % self.monitoring_interval == 0:
                    events = self._detect_transition_events(step, current_time, extended_diagnostics)
                    for event in events:
                        self.transition_events.append(event)
                        simulation_data['transition_events'].append(asdict(event))
                        print(f"🎯 Transition Event: {event.description} at t={event.time:.3f}")

                # Checkpointing
                if step % self.checkpoint_interval == 0 and step > 0:
                    checkpoint_file = self._save_checkpoint(psi, step, current_time, extended_diagnostics)
                    simulation_data['checkpoints_saved'].append(checkpoint_file)
                    print(f"💾 Checkpoint saved: {checkpoint_file}")

                # Check stopping criteria
                if stopping_criteria and self._check_stopping_criteria(extended_diagnostics, stopping_criteria):
                    print(f"🏁 Early stopping criteria met at step {step}")
                    break

                # Memory management
                if step % (self.checkpoint_interval * 2) == 0:
                    self._manage_memory()

            except Exception as e:
                warnings.warn(f"Error at step {step}: {e}")
                # Try to recover by renormalizing
                psi = self.physics.normalize_wavefunction(psi)
                continue

        # Final analysis
        total_time = time.time() - start_time
        simulation_data['end_time'] = datetime.now().isoformat()
        simulation_data['total_runtime'] = total_time
        simulation_data['final_step'] = step
        simulation_data['final_time'] = current_time

        # Performance metrics
        self.performance_metrics = {
            'total_runtime': total_time,
            'steps_per_second': step / total_time,
            'average_dt': self.current_dt,
            'checkpoints_created': len(simulation_data['checkpoints_saved']),
            'transition_events_detected': len(self.transition_events),
            'memory_snapshots': len(self.diagnostics_history)
        }

        simulation_data['performance_metrics'] = self.performance_metrics

        # Save final results
        final_results_file = self._save_final_results(simulation_data, psi)
        print(f"✅ Extended simulation completed")
        print(f"📁 Results saved to: {final_results_file}")

        return simulation_data

    def _evolution_step_with_monitoring(self, psi: np.ndarray, current_time: float,
                                      lambda_val: float, step: int) -> Tuple[np.ndarray, Dict]:
        """Perform single evolution step with stability monitoring."""

        # Store initial state for stability monitoring
        psi_initial = psi.copy()
        initial_norm = np.sqrt(np.trapz(np.abs(psi)**2, self.physics.x))

        # Quantum evolution
        potential = self.physics.double_well_potential()
        psi = self.physics.split_step_evolution(psi, potential, self.current_dt)

        # Apply enhanced correction
        if lambda_val > 0:
            x_classical = self._calculate_classical_reference(current_time)

            # Choose correction method based on configuration
            correction_type = self.config.get('enhanced_correction_type', 'coherence_sensitive')

            try:
                if correction_type == 'coherence_sensitive':
                    V_corr, corr_strength = self.correction.coherence_sensitive_correction(
                        psi, x_classical, lambda_val
                    )
                elif correction_type == 'energy_based':
                    V_corr, corr_strength = self.correction.energy_based_correction(
                        psi, x_classical, lambda_val
                    )
                elif correction_type == 'feedback_driven':
                    V_corr, corr_strength = self.correction.feedback_driven_correction(
                        psi, x_classical, lambda_val
                    )
                elif correction_type == 'spectral':
                    V_corr, corr_strength = self.correction.spectral_correction(
                        psi, x_classical, lambda_val
                    )
                elif correction_type == 'decoherence_enhanced':
                    V_corr, corr_strength = self.correction.decoherence_enhanced_correction(
                        psi, x_classical, lambda_val
                    )
                elif correction_type == 'classical_limit':
                    V_corr, corr_strength = self.correction.classical_limit_correction(
                        psi, x_classical, lambda_val
                    )
                else:
                    # Fallback to basic correction
                    from correction import CorrectionMechanism
                    basic_correction = CorrectionMechanism(self.physics.x, self.current_dt)
                    V_corr, corr_strength = basic_correction.basic_correction(psi, x_classical, lambda_val)

                # Apply correction safely
                if V_corr is not None and not (np.any(np.isnan(V_corr)) or np.any(np.isinf(V_corr))):
                    correction_factor = np.exp(-1j * V_corr * self.current_dt / self.physics.hbar)
                    if not (np.any(np.isnan(correction_factor)) or np.any(np.isinf(correction_factor))):
                        psi *= correction_factor

            except Exception as e:
                warnings.warn(f"Enhanced correction failed at step {step}: {e}")
                corr_strength = 0.0

        else:
            corr_strength = 0.0

        # Renormalize
        psi = self.physics.normalize_wavefunction(psi)

        # Calculate step diagnostics
        final_norm = np.sqrt(np.trapz(np.abs(psi)**2, self.physics.x))
        norm_change = abs(final_norm - initial_norm)

        # Stability metrics
        overlap = abs(np.vdot(psi, psi_initial))**2
        stability_measure = overlap * (1.0 - norm_change)

        step_diagnostics = {
            'norm_change': norm_change,
            'stability_measure': stability_measure,
            'correction_strength': corr_strength,
            'lambda_val': lambda_val
        }

        # Update stability monitor
        self.stability_monitor.append(stability_measure)

        return psi, step_diagnostics

    def _comprehensive_diagnostics(self, psi: np.ndarray, current_time: float,
                                 lambda_val: float, step: int, step_diagnostics: Dict) -> Dict:
        """Calculate comprehensive diagnostics for long-time analysis."""

        diagnostics = {
            'step': step,
            'time': current_time,
            'lambda': lambda_val,
            'dt': self.current_dt
        }

        try:
            # Basic quantum state properties
            prob_density = np.abs(psi)**2
            position = np.trapz(self.physics.x * prob_density, self.physics.x)
            position_sq = np.trapz(self.physics.x**2 * prob_density, self.physics.x)
            uncertainty = np.sqrt(position_sq - position**2)

            # Entropy measures
            prob_norm = prob_density / np.sum(prob_density)
            shannon_entropy = -np.sum(prob_norm * np.log(prob_norm + 1e-16))

            # Quantum potential
            Q = self.physics.quantum_potential(psi)
            Q_mean = np.mean(np.abs(Q))
            Q_max = np.max(np.abs(Q))

            # Advanced measures for long-time analysis
            # Participation ratio (measure of localization)
            participation_ratio = 1.0 / np.sum(prob_density**2)

            # Coherence length
            coherence_length = self._calculate_coherence_length(psi)

            # Energy
            energy = self._calculate_total_energy(psi)

            # Classical comparison
            x_classical = self._calculate_classical_reference(current_time)
            classical_deviation = abs(position - x_classical)

            # Long-term trends (if enough history)
            trends = self._calculate_trends()

            diagnostics.update({
                'position': position,
                'uncertainty': uncertainty,
                'shannon_entropy': shannon_entropy,
                'quantum_potential_mean': Q_mean,
                'quantum_potential_max': Q_max,
                'participation_ratio': participation_ratio,
                'coherence_length': coherence_length,
                'total_energy': energy,
                'classical_reference': x_classical,
                'classical_deviation': classical_deviation,
                'stability_measure': step_diagnostics['stability_measure'],
                'norm_change': step_diagnostics['norm_change'],
                'correction_strength': step_diagnostics['correction_strength'],
                'trends': trends
            })

            # Enhanced correction diagnostics
            if hasattr(self.correction, 'get_enhanced_diagnostics'):
                enhanced_diag = self.correction.get_enhanced_diagnostics()
                diagnostics.update({f'enhanced_{k}': v for k, v in enhanced_diag.items()})

        except Exception as e:
            warnings.warn(f"Comprehensive diagnostics calculation failed: {e}")

        return diagnostics

    def _calculate_coherence_length(self, psi: np.ndarray) -> float:
        """Calculate coherence length of the wavefunction."""
        try:
            # Spatial correlation function
            phase = np.angle(psi)
            phase_unwrapped = np.unwrap(phase)

            # Calculate correlation length from phase fluctuations
            phase_var = np.var(np.gradient(phase_unwrapped, self.physics.dx))
            coherence_length = 1.0 / np.sqrt(phase_var + 1e-12)

            return min(coherence_length, np.max(self.physics.x) - np.min(self.physics.x))

        except Exception:
            return 0.0

    def _calculate_total_energy(self, psi: np.ndarray) -> float:
        """Calculate total energy of the quantum state."""
        try:
            # Kinetic energy
            psi_k = np.fft.fft(psi)
            k = np.fft.fftfreq(len(psi), d=self.physics.dx) * 2 * np.pi
            kinetic_energy = np.sum(np.abs(psi_k)**2 * k**2) * self.physics.hbar**2 / (2 * self.physics.mass)

            # Potential energy
            V = self.physics.double_well_potential()
            potential_energy = np.trapz(V * np.abs(psi)**2, self.physics.x)

            return kinetic_energy + potential_energy

        except Exception:
            return 0.0

    def _calculate_classical_reference(self, current_time: float) -> float:
        """Calculate classical reference position at given time."""
        # Simple harmonic motion approximation for classical reference
        # This can be made more sophisticated based on the specific system
        omega = 1.0  # Characteristic frequency
        amplitude = 2.0  # Characteristic amplitude
        return amplitude * np.cos(omega * current_time) * np.exp(-0.1 * current_time)

    def _calculate_trends(self) -> Dict[str, float]:
        """Calculate long-term trends from diagnostic history."""
        trends = {}

        try:
            if len(self.diagnostics_history) >= 10:
                recent_data = list(self.diagnostics_history)[-10:]

                # Position trend
                positions = [d.get('position', 0) for d in recent_data]
                position_trend = np.polyfit(range(len(positions)), positions, 1)[0] if len(positions) > 1 else 0.0

                # Uncertainty trend
                uncertainties = [d.get('uncertainty', 0) for d in recent_data]
                uncertainty_trend = np.polyfit(range(len(uncertainties)), uncertainties, 1)[0] if len(uncertainties) > 1 else 0.0

                # Energy trend
                energies = [d.get('total_energy', 0) for d in recent_data]
                energy_trend = np.polyfit(range(len(energies)), energies, 1)[0] if len(energies) > 1 else 0.0

                trends = {
                    'position_trend': position_trend,
                    'uncertainty_trend': uncertainty_trend,
                    'energy_trend': energy_trend
                }

        except Exception:
            pass

        return trends

    def _detect_transition_events(self, step: int, current_time: float, diagnostics: Dict) -> List[TransitionEvent]:
        """Detect significant transition events in the simulation."""
        events = []

        try:
            # Check for significant localization event
            uncertainty = diagnostics.get('uncertainty', 0)
            if len(self.diagnostics_history) >= 5:
                recent_uncertainties = [d.get('uncertainty', 0) for d in list(self.diagnostics_history)[-5:]]
                uncertainty_drop = max(recent_uncertainties) - uncertainty

                if uncertainty_drop > self.transition_threshold:
                    events.append(TransitionEvent(
                        event_type='localization',
                        step=step,
                        time=current_time,
                        magnitude=uncertainty_drop,
                        duration=5,
                        description=f'Localization event: Δσ = {uncertainty_drop:.3f}'
                    ))

            # Check for classical behavior emergence
            classical_deviation = diagnostics.get('classical_deviation', float('inf'))
            if classical_deviation < 0.5:  # Close to classical trajectory
                events.append(TransitionEvent(
                    event_type='classical_approach',
                    step=step,
                    time=current_time,
                    magnitude=classical_deviation,
                    duration=1,
                    description=f'Classical approach: deviation = {classical_deviation:.3f}'
                ))

            # Check for coherence loss
            coherence_length = diagnostics.get('coherence_length', float('inf'))
            if coherence_length < 2.0:  # Short coherence length
                events.append(TransitionEvent(
                    event_type='decoherence',
                    step=step,
                    time=current_time,
                    magnitude=1.0/coherence_length,
                    duration=1,
                    description=f'Decoherence: ξ = {coherence_length:.3f}'
                ))

        except Exception as e:
            warnings.warn(f"Event detection failed: {e}")

        return events

    def _adjust_timestep(self, step_diagnostics: Dict) -> float:
        """Adaptive timestep adjustment based on stability."""
        try:
            stability = step_diagnostics.get('stability_measure', 1.0)
            norm_change = step_diagnostics.get('norm_change', 0.0)

            # Adjust timestep based on stability
            if stability > 0.99 and norm_change < 1e-6:
                # Very stable, can increase timestep
                new_dt = min(self.current_dt * 1.05, self.max_dt)
            elif stability < 0.95 or norm_change > 1e-4:
                # Unstable, decrease timestep
                new_dt = max(self.current_dt * 0.9, self.min_dt)
            else:
                # Stable, keep current timestep
                new_dt = self.current_dt

            return new_dt

        except Exception:
            return self.current_dt

    def _check_stopping_criteria(self, diagnostics: Dict, criteria: Dict[str, float]) -> bool:
        """Check if early stopping criteria are met."""
        try:
            for criterion, threshold in criteria.items():
                if criterion in diagnostics:
                    if diagnostics[criterion] < threshold:
                        return True
            return False
        except Exception:
            return False

    def _print_progress(self, step: int, current_time: float, diagnostics: Dict):
        """Print simulation progress."""
        progress = step / self.max_steps * 100
        uncertainty = diagnostics.get('uncertainty', 0)
        classical_dev = diagnostics.get('classical_deviation', 0)
        energy = diagnostics.get('total_energy', 0)

        print(f"⏱️  Step {step:6d} ({progress:5.1f}%) | t={current_time:7.3f} | "
              f"σ={uncertainty:5.3f} | δ_cl={classical_dev:5.3f} | E={energy:7.3f}")

    def _save_checkpoint(self, psi: np.ndarray, step: int, current_time: float, diagnostics: Dict) -> str:
        """Pickle a ``SimulationCheckpoint`` for the current state.

        The file is written to ``output_dir`` as
        ``checkpoint_step_<step>_<timestamp>.pkl``.  A write failure is
        reported as a warning; the intended path is returned either way.

        Returns:
            Path of the checkpoint file.
        """
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        target_path = os.path.join(self.output_dir, f"checkpoint_step_{step}_{stamp}.pkl")

        state_copy = psi.copy()
        if hasattr(self.correction, 'get_enhanced_diagnostics'):
            correction_state = self.correction.get_enhanced_diagnostics()
        else:
            correction_state = {}

        snapshot = SimulationCheckpoint(
            step=step,
            time=current_time,
            psi=state_copy,
            diagnostics=diagnostics,
            correction_state=correction_state,
            timestamp=stamp,
        )

        try:
            with open(target_path, 'wb') as handle:
                pickle.dump(snapshot, handle, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception as e:
            warnings.warn(f"Checkpoint save failed: {e}")

        return target_path

    def _save_final_results(self, simulation_data: Dict, final_psi: np.ndarray) -> str:
        """Save final simulation results."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        if self.storage_format == 'hdf5':
            results_file = os.path.join(self.output_dir, f"extended_simulation_{timestamp}.h5")
            self._save_hdf5_results(results_file, simulation_data, final_psi)
        else:
            results_file = os.path.join(self.output_dir, f"extended_simulation_{timestamp}.pkl")
            self._save_pickle_results(results_file, simulation_data, final_psi)

        return results_file

    def _save_hdf5_results(self, filename: str, simulation_data: Dict, final_psi: np.ndarray):
        """Save results in HDF5 format for efficient storage.

        Layout of the file:

        * group ``config`` -- one attribute per configuration entry;
        * dataset ``final_wavefunction`` -- gzip-compressed final state;
        * group ``diagnostics`` -- one gzip-compressed dataset per numeric
          diagnostic, taken across the whole timeline (missing values are
          stored as 0; nested or non-numeric diagnostics such as ``trends``
          are skipped);
        * group ``performance`` -- one attribute per performance metric.

        Values that HDF5 cannot store as an attribute are stored as a JSON
        string instead, so a single awkward entry does not abort the save.
        Any other failure is reported as a warning.
        """
        def store_attribute(group, key, value):
            # Write one attribute; fall back to a JSON string for values with
            # no native HDF5 representation (None, dicts, objects, ...).
            try:
                group.attrs[key] = value
            except (TypeError, ValueError):
                # default=str is deliberate: this is a best-effort textual
                # record of the setting, not a round-trippable encoding.
                group.attrs[key] = json.dumps(value, default=str)

        try:
            with h5py.File(filename, 'w') as f:
                # Save configuration
                config_group = f.create_group('config')
                for key, value in simulation_data['config'].items():
                    store_attribute(config_group, key, value)

                # Save final state
                f.create_dataset('final_wavefunction', data=final_psi, compression='gzip', compression_opts=self.compression_level)

                # Save diagnostics timeline
                timeline = simulation_data['diagnostics_timeline']
                if timeline:
                    diag_group = f.create_group('diagnostics')

                    # Collect keys from every entry (first-seen order): an
                    # entry whose diagnostics failed carries only a few keys.
                    all_keys = list(dict.fromkeys(key for entry in timeline for key in entry))

                    for key in all_keys:
                        values = [entry.get(key, 0) for entry in timeline]
                        try:
                            column = np.asarray(values)
                        except ValueError:
                            continue  # ragged values cannot form a dataset
                        if column.dtype.kind not in 'biufc':
                            continue  # nested dicts, strings, objects
                        diag_group.create_dataset(key, data=column, compression='gzip', compression_opts=self.compression_level)

                # Save performance metrics
                perf_group = f.create_group('performance')
                for key, value in simulation_data['performance_metrics'].items():
                    store_attribute(perf_group, key, value)

        except Exception as e:
            warnings.warn(f"HDF5 save failed: {e}")

    def _save_pickle_results(self, filename: str, simulation_data: Dict, final_psi: np.ndarray):
        """Save results in pickle format."""
        try:
            simulation_data['final_wavefunction'] = final_psi
            with open(filename, 'wb') as f:
                pickle.dump(simulation_data, f, protocol=pickle.HIGHEST_PROTOCOL)
        except Exception as e:
            warnings.warn(f"Pickle save failed: {e}")

    def _manage_memory(self):
        """Manage memory usage during long simulations."""
        # Compress older diagnostics if memory is getting full
        if len(self.diagnostics_history) >= self.max_memory_snapshots * 0.9:
            # Keep only essential data for older entries
            compressed_history = deque(maxlen=self.max_memory_snapshots)

            for i, diag in enumerate(self.diagnostics_history):
                if i < len(self.diagnostics_history) // 2:
                    # Compress older entries
                    compressed_diag = {
                        'step': diag.get('step', 0),
                        'time': diag.get('time', 0),
                        'position': diag.get('position', 0),
                        'uncertainty': diag.get('uncertainty', 0),
                        'shannon_entropy': diag.get('shannon_entropy', 0)
                    }
                    compressed_history.append(compressed_diag)
                else:
                    # Keep full data for recent entries
                    compressed_history.append(diag)

            self.diagnostics_history = compressed_history