#!/usr/bin/env python3
"""
Safe Black Hole Simulation Runner
=================================

This script provides a safe wrapper around the original black hole simulation,
adding comprehensive error handling, resource monitoring, and graceful fallbacks.

Features:
- Automatic grid size adjustment based on available memory
- CPU usage monitoring and yielding
- Memory leak prevention with garbage collection
- Timeout mechanisms for long-running operations
- Graceful degradation when system resources are stressed
- Comprehensive error logging and recovery

Usage:
    python run_safe_blackhole_simulation.py [options]

Options:
    --grid-size SIZE        Initial grid size; may be adjusted to fit available
                            memory (default: 0 = auto-detect)
    --max-memory GB         Maximum memory usage in GB (default: 2.0)
    --max-time SECONDS      Maximum execution time (default: 300)
    --time-steps STEPS      Number of time evolution steps (default: 5)
    --safe-mode             Use ultra-conservative resource limits (memory at
                            most 1 GB, time at most 120 s, CPU limit 60%)
    --debug                 Enable detailed (DEBUG-level) logging
"""

import sys
import argparse
import logging
import traceback
import time
import json
from pathlib import Path

# Import our safety monitoring system
try:
    from system_monitor import ResourceSafetyWrapper, SafetyConfig
except ImportError:
    print("Error: system_monitor.py not found. Please ensure it's in the same directory.")
    sys.exit(1)

# Try to import the original simulation components
try:
    import numpy as np
    import matplotlib
    matplotlib.use('Agg')  # Use non-interactive backend
    import matplotlib.pyplot as plt
except ImportError as e:
    print(f"Error importing required packages: {e}")
    print("Please install required packages: numpy matplotlib")
    sys.exit(1)

class SafeBlackHoleSimulation:
    """
    Safe wrapper around the black hole simulation with comprehensive error handling.

    Each stage (initialization, time evolution, analysis, visualization) runs
    inside ``self.safety.safe_operation(...)`` and polls
    ``self.safety.should_continue()`` so it can stop early when the safety
    wrapper reports that a resource limit has been reached.  Per-stage outcomes
    are recorded in ``self.results``.
    """

    def __init__(self, safety_config: "SafetyConfig"):
        """Create the resource-safety wrapper and configure logging.

        Args:
            safety_config: Resource limits handed to ``ResourceSafetyWrapper``.
        """
        self.safety = ResourceSafetyWrapper(safety_config)
        self.config = safety_config
        # Per-stage outcome; 'error' holds a message when the stage failed.
        self.results = {
            'initialization': {'success': False, 'error': None},
            'simulation': {'success': False, 'error': None, 'steps_completed': 0},
            'analysis': {'success': False, 'error': None},
            'visualization': {'success': False, 'error': None}
        }

        # basicConfig does nothing when the root logger already has handlers.
        # delay=True defers opening the log file until the first record is
        # emitted, so no file handle is opened and then discarded in that case.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.FileHandler('safe_simulation.log', delay=True),
                logging.StreamHandler()
            ]
        )
        self.logger = logging.getLogger(__name__)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _chunk_size(grid_size: int) -> int:
        """Return the edge length of the cubic chunks used during evolution.

        Capped at 16 and at a quarter of the grid, but never below 1: a zero
        step would make ``range`` raise ValueError for grids smaller than 4.
        """
        return max(1, min(16, grid_size // 4))

    @staticmethod
    def _initial_slab(i: int, grid_size: int) -> np.ndarray:
        """Return the initial values for the 2-D slab ``field[i, :, :]``.

        Vectorised form of the per-cell expression
        ``exp(-r^2 / (2 sigma^2)) * exp(1j * pi * r^2 / (4 sigma^2))``
        (a Gaussian envelope with a quadratic phase), where ``r^2`` is the
        squared index distance from ``grid_size // 2`` along all three axes
        and ``sigma = grid_size / 8``.
        """
        center = grid_size // 2
        sigma_sq = (grid_size / 8.0) ** 2
        offsets_sq = (np.arange(grid_size) - center) ** 2
        r_sq = (i - center) ** 2 + offsets_sq[:, None] + offsets_sq[None, :]
        return np.exp(-r_sq / (2 * sigma_sq)) * np.exp(1j * np.pi * r_sq / (4 * sigma_sq))

    @staticmethod
    def _as_float(value) -> float:
        """Convert a metric to ``float``, mapping a missing (None) result to 0.0."""
        return float(value) if value is not None else 0.0

    def _apply_in_chunks(self, field, grid_size: int, factor: complex) -> bool:
        """Multiply ``field`` in place by ``factor``, one cubic chunk at a time.

        ``should_continue()`` is polled before every chunk.

        Returns:
            True if every chunk was updated; False if the safety wrapper asked
            to stop part-way through (the field is then only partially updated).
        """
        chunk = self._chunk_size(grid_size)
        for i in range(0, grid_size, chunk):
            for j in range(0, grid_size, chunk):
                for k in range(0, grid_size, chunk):
                    if not self.safety.should_continue():
                        return False
                    # NumPy clips slices past the end, so the last chunk along
                    # each axis may be smaller than ``chunk``.
                    field[i:i + chunk, j:j + chunk, k:k + chunk] *= factor
        return True

    def _record_stage(self, stage: str, final_results: dict, error=None) -> None:
        """Record the outcome of ``stage``.

        The stage is marked successful and listed in
        ``final_results['stages_completed']`` only when ``error`` is None;
        otherwise the error message is stored in ``self.results[stage]``.
        """
        if error is None:
            self.results[stage]['success'] = True
            final_results['stages_completed'].append(stage)
        else:
            self.results[stage]['error'] = str(error)

    # ------------------------------------------------------------------
    # Stages
    # ------------------------------------------------------------------

    def safe_field_initialization(self, desired_grid_size: int) -> tuple:
        """Allocate and initialise the 3-D complex field.

        The edge length actually used is whatever
        ``self.safety.get_safe_grid_size`` returns for ``desired_grid_size``.

        Args:
            desired_grid_size: Requested edge length of the cubic grid.

        Returns:
            ``(field, entangled_pairs, hawking_emissions, grid_size)``.
            ``entangled_pairs`` is an empty dict and ``hawking_emissions`` an
            empty list (placeholders); ``grid_size`` is the edge length used.

        Raises:
            RuntimeError: If the field array could not be allocated.
        """
        with self.safety.safe_operation("field_initialization"):
            safe_grid_size = self.safety.get_safe_grid_size(
                desired_grid_size, dtype=np.complex128, dimensions=3
            )
            self.logger.info("Initializing field with grid size: %s", safe_grid_size)

            field = self.safety.safe_array_allocation(
                (safe_grid_size,) * 3, dtype=np.complex128
            )
            if field is None:
                raise RuntimeError("Could not allocate quantum field array")

            # Fill one 2-D slab at a time. Extra memory stays O(grid_size**2)
            # instead of a Python list of grid_size**3 tuples (several times
            # larger than the field the memory budget was computed for), and the
            # resource check runs once per slab rather than once per cell.
            self.logger.info("Initializing field values...")
            for i in range(safe_grid_size):
                if not self.safety.should_continue():
                    # Cells that were not initialised are set to zero.
                    field[i:] = 0.0
                    self.logger.warning(
                        "Field initialization stopped early at slab %s/%s",
                        i, safe_grid_size,
                    )
                    break
                field[i] = self._initial_slab(i, safe_grid_size)

            self.logger.info("Field initialized successfully: %s", field.shape)
            return field, {}, [], safe_grid_size

    def safe_time_evolution(self, field, grid_size: int, num_steps: int) -> list:
        """Evolve ``field`` in place for up to ``num_steps`` steps.

        Each step multiplies the field by ``0.005 * (1 + 0.05j * step)`` and
        records summary metrics. Evolution stops early (without recording the
        interrupted step) when ``should_continue()`` returns False; a step that
        raises is logged and skipped.

        Args:
            field: Complex field array; modified in place.
            grid_size: Edge length of the cubic field.
            num_steps: Number of steps requested.

        Returns:
            One dict of metrics per completed step.
        """
        evolution_data = []

        with self.safety.safe_operation("time_evolution"):
            self.logger.info("Starting time evolution: %s steps", num_steps)

            for step in range(num_steps):
                if not self.safety.should_continue():
                    self.logger.warning("Time evolution stopped early at step %s", step)
                    break

                step_start = time.perf_counter()
                try:
                    evolution_factor = 0.005 * (1.0 + 0.05j * step)
                    if not self._apply_in_chunks(field, grid_size, evolution_factor):
                        # The field is only partially updated; do not report
                        # this step as completed.
                        self.logger.warning(
                            "Time evolution stopped early during step %s", step
                        )
                        break

                    # vdot(f, f) == sum(|f|**2), without allocating full-size
                    # temporaries for |f| and |f|**2.
                    total_energy = self.safety.safe_mathematical_operation(
                        lambda: np.vdot(field, field).real
                    )
                    max_amplitude = self.safety.safe_mathematical_operation(
                        lambda: np.max(np.abs(field))
                    )
                    mean_phase = self.safety.safe_mathematical_operation(
                        lambda: np.mean(np.angle(field))
                    )

                    step_data = {
                        'step': step,
                        'time': step * 0.1,
                        'total_energy': self._as_float(total_energy),
                        'max_amplitude': self._as_float(max_amplitude),
                        'mean_phase': self._as_float(mean_phase),
                        'grid_size': grid_size,
                        'step_duration': time.perf_counter() - step_start,
                    }
                except Exception as e:
                    # Best effort: log and continue with the next step.
                    self.logger.error("Error in time evolution step %s: %s", step, e)
                    continue

                evolution_data.append(step_data)
                # Count of steps actually recorded (skipped steps excluded).
                self.results['simulation']['steps_completed'] = len(evolution_data)

                if step % max(1, num_steps // 10) == 0:
                    self.logger.info(
                        "Completed step %s/%s (energy=%.3e)",
                        step + 1, num_steps, step_data['total_energy'],
                    )

            self.logger.info("Time evolution completed: %s steps", len(evolution_data))
            return evolution_data

    def safe_analysis(self, evolution_data: list) -> dict:
        """Summarise the per-step metrics produced by ``safe_time_evolution``.

        Returns:
            Dict with time/energy/amplitude ranges, an energy trend (when
            there are at least two steps) and summary statistics. It is empty
            when there is no data, and contains an 'error' message if the
            analysis raised.
        """
        with self.safety.safe_operation("analysis"):
            analysis_results = {}

            try:
                if not evolution_data:
                    self.logger.warning("No evolution data to analyze")
                    return analysis_results

                times = [d['time'] for d in evolution_data]
                energies = [d['total_energy'] for d in evolution_data]
                amplitudes = [d['max_amplitude'] for d in evolution_data]

                # The guard above ensures these lists are non-empty.
                analysis_results['time_range'] = [min(times), max(times)]
                analysis_results['energy_range'] = [min(energies), max(energies)]
                analysis_results['amplitude_range'] = [min(amplitudes), max(amplitudes)]

                if len(energies) > 1:
                    duration = times[-1] - times[0]
                    analysis_results['energy_trend'] = self.safety.safe_mathematical_operation(
                        lambda: (energies[-1] - energies[0]) / duration if duration != 0 else 0
                    )

                mean_energy = float(np.mean(energies))
                analysis_results['total_steps'] = len(evolution_data)
                analysis_results['mean_energy'] = mean_energy
                # Coefficient of variation; reported as 0.0 for a non-positive mean.
                analysis_results['energy_stability'] = (
                    float(np.std(energies)) / mean_energy if mean_energy > 0 else 0.0
                )

                self.logger.info("Analysis completed successfully")

            except Exception as e:
                self.logger.error("Error in analysis: %s", e)
                analysis_results['error'] = str(e)

            return analysis_results

    def safe_visualization(self, evolution_data: list, output_dir: str = ".") -> dict:
        """Plot energy and amplitude over time to ``<output_dir>/safe_simulation_results.png``.

        Args:
            evolution_data: Per-step metrics from ``safe_time_evolution``.
            output_dir: Directory for the PNG; created if missing.

        Returns:
            ``{'plot_saved': path}`` on success, ``{'error': message}`` on
            failure, or ``{}`` when there is nothing to plot.
        """
        viz_results = {}

        with self.safety.safe_operation("visualization"):
            fig = None
            try:
                if not evolution_data:
                    self.logger.warning("No data to visualize")
                    return viz_results

                times = [d['time'] for d in evolution_data]
                energies = [d['total_energy'] for d in evolution_data]
                amplitudes = [d['max_amplitude'] for d in evolution_data]

                out_dir = Path(output_dir)
                out_dir.mkdir(parents=True, exist_ok=True)

                fig, (ax_energy, ax_amp) = plt.subplots(2, 1, figsize=(10, 8))

                ax_energy.plot(times, energies, 'b-', linewidth=2, label='Total Energy')
                ax_energy.set_xlabel('Time')
                ax_energy.set_ylabel('Energy')
                ax_energy.set_title('Energy Evolution')
                ax_energy.grid(True, alpha=0.3)
                ax_energy.legend()

                ax_amp.plot(times, amplitudes, 'r-', linewidth=2, label='Max Amplitude')
                ax_amp.set_xlabel('Time')
                ax_amp.set_ylabel('Amplitude')
                ax_amp.set_title('Field Amplitude Evolution')
                ax_amp.grid(True, alpha=0.3)
                ax_amp.legend()

                fig.tight_layout()

                output_file = out_dir / "safe_simulation_results.png"
                fig.savefig(output_file, dpi=150, bbox_inches='tight')

                viz_results['plot_saved'] = str(output_file)
                self.logger.info("Visualization saved to: %s", output_file)

            except Exception as e:
                self.logger.error("Error in visualization: %s", e)
                viz_results['error'] = str(e)

            finally:
                # Always release the figure, even when plotting/saving failed.
                if fig is not None:
                    plt.close(fig)

            return viz_results

    def run_complete_simulation(self, grid_size: int = 32, num_steps: int = 5) -> dict:
        """Run initialization, time evolution, analysis and visualization in order.

        Args:
            grid_size: Requested grid edge length (may be adjusted by the
                safety wrapper).
            num_steps: Number of time evolution steps.

        Returns:
            Summary dict with the following keys:

            - 'success': True only if every stage succeeded.
            - 'stages_completed': the stages that succeeded.
            - 'total_duration': elapsed seconds.
            - 'resource_usage': a snapshot taken at the end of the run.
            - 'data'.
            - 'detailed_results': ``self.results``.
            - 'error': present only when something failed.
        """
        self.logger.info("=== Starting Safe Black Hole Simulation ===")
        start_time = time.perf_counter()

        final_results = {
            'success': False,
            'stages_completed': [],
            'total_duration': 0,
            'resource_usage': {},
            'data': {}
        }
        stage = None
        field = None

        try:
            stage = 'initialization'
            self.logger.info("Stage 1: Field Initialization")
            field, _pairs, _emissions, actual_grid_size = self.safe_field_initialization(grid_size)
            self._record_stage(stage, final_results)
            final_results['data']['grid_size_used'] = actual_grid_size

            stage = 'simulation'
            self.logger.info("Stage 2: Time Evolution")
            evolution_data = self.safe_time_evolution(field, actual_grid_size, num_steps)
            # Later stages only need the per-step metrics; release the field now
            # to lower peak memory during analysis and plotting.
            field = None
            sim_error = None
            if num_steps > 0 and not evolution_data:
                sim_error = "No time evolution steps completed"
            self._record_stage(stage, final_results, sim_error)
            final_results['data']['evolution_data'] = evolution_data

            stage = 'analysis'
            self.logger.info("Stage 3: Analysis")
            analysis_results = self.safe_analysis(evolution_data)
            self._record_stage(stage, final_results, analysis_results.get('error'))
            final_results['data']['analysis'] = analysis_results

            stage = 'visualization'
            self.logger.info("Stage 4: Visualization")
            viz_results = self.safe_visualization(evolution_data)
            self._record_stage(stage, final_results, viz_results.get('error'))
            final_results['data']['visualization'] = viz_results

            failed = {
                name: info['error']
                for name, info in self.results.items()
                if not info['success']
            }
            final_results['success'] = not failed
            if failed:
                final_results['error'] = "; ".join(
                    f"{name}: {err}" for name, err in failed.items()
                )
                self.logger.warning(
                    "Simulation finished with stage errors: %s", final_results['error']
                )
            else:
                self.logger.info("=== Simulation Completed Successfully ===")

        except KeyboardInterrupt:
            self.logger.warning("Simulation interrupted by user")
            final_results['error'] = "User interrupted"
            if stage is not None:
                self.results[stage]['error'] = "User interrupted"

        except Exception as e:
            # logger.exception also logs the traceback.
            self.logger.exception("Critical error in simulation: %s", e)
            final_results['error'] = str(e)
            if stage is not None:
                self.results[stage]['error'] = str(e)

        finally:
            # Drop the last reference to the (potentially large) field before
            # collecting, so it is not kept alive through the collection.
            field = None
            final_results['total_duration'] = time.perf_counter() - start_time
            try:
                final_results['resource_usage'] = self.safety.get_current_resource_usage()
            except Exception as e:
                # Reporting problems must not discard the results gathered so far.
                self.logger.warning("Could not read resource usage: %s", e)
            final_results['detailed_results'] = self.results

            import gc
            gc.collect()

        return final_results

def main():
    """Command-line entry point: parse options, run the simulation, save a JSON report.

    Returns:
        Exit status for ``sys.exit``: 0 when the run reports success, 1
        otherwise. Invalid option values exit via ``parser.error`` (status 2).
    """
    parser = argparse.ArgumentParser(description="Safe Black Hole Information Paradox Simulation")
    parser.add_argument('--grid-size', type=int, default=0, help='Initial grid size (0=auto)')
    parser.add_argument('--max-memory', type=float, default=2.0, help='Max memory in GB')
    parser.add_argument('--max-time', type=int, default=300, help='Max execution time in seconds')
    parser.add_argument('--time-steps', type=int, default=5, help='Number of time steps')
    parser.add_argument('--safe-mode', action='store_true', help='Ultra-conservative limits')
    parser.add_argument('--debug', action='store_true', help='Enable debug logging')

    args = parser.parse_args()

    # Reject values that cannot describe a meaningful run up front, instead of
    # letting them surface later as confusing failures inside the simulation.
    if args.grid_size < 0:
        parser.error("--grid-size must be >= 0 (0 selects a size automatically)")
    if args.time_steps < 0:
        parser.error("--time-steps must be >= 0")
    if args.max_memory <= 0:
        parser.error("--max-memory must be > 0")
    if args.max_time <= 0:
        parser.error("--max-time must be > 0")

    # Configure safety limits; safe mode tightens memory, CPU and time limits.
    if args.safe_mode:
        config = SafetyConfig(
            max_memory_gb=min(1.0, args.max_memory),
            max_cpu_percent=60.0,
            max_execution_time=min(120, args.max_time),
            memory_check_frequency=50
        )
    else:
        config = SafetyConfig(
            max_memory_gb=args.max_memory,
            max_cpu_percent=80.0,
            max_execution_time=args.max_time
        )

    if args.debug:
        logging.getLogger().setLevel(logging.DEBUG)

    # Auto-detect grid size if not specified
    if args.grid_size == 0:
        safety_temp = ResourceSafetyWrapper(config)
        args.grid_size = safety_temp.get_safe_grid_size(64)

    print("=== Safe Black Hole Information Paradox Simulation ===")
    print(f"Grid size: {args.grid_size}")
    print(f"Time steps: {args.time_steps}")
    print(f"Memory limit: {config.max_memory_gb}GB")
    print(f"Time limit: {config.max_execution_time}s")
    print(f"Safe mode: {args.safe_mode}")
    print()

    simulation = SafeBlackHoleSimulation(config)
    if args.debug:
        # The SafeBlackHoleSimulation constructor calls
        # logging.basicConfig(level=logging.INFO). When the root logger has no
        # handlers yet, that resets the level set above, so re-apply DEBUG here
        # for --debug to take effect.
        logging.getLogger().setLevel(logging.DEBUG)
    results = simulation.run_complete_simulation(args.grid_size, args.time_steps)

    print("\n=== SIMULATION RESULTS ===")
    print(f"Success: {results['success']}")
    print(f"Stages completed: {', '.join(results['stages_completed'])}")
    print(f"Total duration: {results['total_duration']:.2f}s")

    if 'error' in results:
        print(f"Error: {results['error']}")

    if results.get('resource_usage'):
        usage = results['resource_usage']
        # resource_usage is a snapshot taken when the run finishes, not a peak;
        # `or 0` guards against missing/None entries breaking the format spec.
        print(f"Memory usage: {usage.get('memory_gb') or 0:.2f}GB")
        print(f"CPU usage: {usage.get('cpu_percent') or 0:.1f}%")

    # Save detailed results (default=str stringifies non-JSON values such as paths).
    output_file = "safe_simulation_complete_results.json"
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, default=str)
    print(f"Detailed results saved to: {output_file}")

    return 0 if results['success'] else 1

if __name__ == "__main__":
    # Raising SystemExit with main()'s return value is exactly what sys.exit() does.
    raise SystemExit(main())