"""Enhanced Quantum-Classical Boundary Emergence Simulator.

This is the main entry point for the scientifically rigorous and upgraded
quantum-classical boundary emergence simulation. It integrates all enhanced
modules and provides comprehensive analysis capabilities.

FEATURES:
=========
✅ Lindblad Master Equation Mode (Optional Noise Injection)
✅ Time-Delayed Feedback Corrections
✅ Quantum Potential Decomposition (Bohmian Overlay)
✅ Full Entropy Suite (Von Neumann + Mutual Information)
✅ Fidelity Tracking (Initial and Classical States)
✅ Correction Efficiency Metrics
✅ External Measurement Collapse Simulations
✅ Multi-Dimensional Extension (2D Double Well)
✅ Quantum-Classical Boundary Mapping
✅ Ensemble Averaging for Density Matrix Mode

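Run with: python Enhanced_Quantum_Classical_Simulator.py
    [--mode {interactive,batch,examples}] [--config FILE]
    [--preset {basic,advanced,research,lindblad,2d}]
Or import this module, customize one of the get_*_config() dictionaries,
and pass it to QuantumClassicalSimulator to run programmatically.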
"""

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib import cm
import os
import json
from datetime import datetime
import argparse
import sys

from simulation_control import QuantumClassicalSimulator

# === PREDEFINED SIMULATION CONFIGURATIONS ===

def get_basic_config():
    """Return the lightweight preset intended for quick smoke-test runs.

    Uses a 256-point grid, 500 time steps and three coupling strengths.
    Lindblad noise and ensemble averaging are disabled. A fresh dict is
    built on every call, so callers may mutate the result freely.
    """
    return dict(
        grid_size=256,
        x_range=(-8, 8),
        dt=0.01,
        n_steps=500,
        lambda_values=[0.0, 0.05, 0.1],
        correction_type='basic',
        use_lindblad=False,
        use_ensemble=False,
        diagnostic_sampling=10,
        output_dir='results_basic',
    )

def get_advanced_config():
    """Return the preset that switches on every optional feature.

    The dict is assembled from thematic sections: core numerics, Lindblad
    noise, ensemble averaging, measurement collapse, and output/diagnostics.
    Sections are merged in that order, so the key order matches the
    documented layout.
    """
    numerics = {
        'grid_size': 512,
        'x_range': (-10, 10),
        'dt': 0.005,
        'n_steps': 1000,
        'lambda_values': [0.0, 0.01, 0.02, 0.05, 0.1, 0.2],
        'correction_type': 'adaptive',
        'correction_delay_steps': 5,
    }
    lindblad_noise = {
        'use_lindblad': True,
        'gamma_x': 0.01,
        'gamma_p': 0.01,
    }
    ensemble = {
        'use_ensemble': True,
        'ensemble_size': 20,
        'ensemble_phase_randomization': True,
    }
    measurement = {
        'use_measurement_collapse': True,
        'measurement_times': [2.5],
        'measurement_positions': [1.5],
        'collapse_width': 0.8,
    }
    outputs = {
        'diagnostic_sampling': 5,
        'save_wigner_evolution': True,
        'create_boundary_map': True,
        'save_raw_data': True,
        'output_dir': 'results_advanced',
    }
    return {**numerics, **lindblad_noise, **ensemble, **measurement, **outputs}

def get_research_config():
    """Return the high-resolution preset intended for publication runs.

    Uses a 1024-point grid and 2000 steps. It sweeps 20 couplings spaced
    logarithmically from 1e-3 to 1, stored as a numpy array. Measurement
    collapse is disabled here because it is meant to be studied separately.
    """
    numerics = dict(
        grid_size=1024,
        x_range=(-12, 12),
        dt=0.002,
        n_steps=2000,
        lambda_values=np.logspace(-3, 0, 20),  # Logarithmic sampling
        correction_type='predictive',
        correction_delay_steps=10,
    )
    # Maximum feature set
    features = dict(
        use_lindblad=True,
        gamma_x=0.005,
        gamma_p=0.005,
        use_ensemble=True,
        ensemble_size=50,
        ensemble_phase_randomization=True,
        use_measurement_collapse=False,  # Separate study
    )
    outputs = dict(
        diagnostic_sampling=2,
        save_wigner_evolution=True,
        create_boundary_map=True,
        save_raw_data=True,
        output_dir='results_research',
    )
    return {**numerics, **features, **outputs}

def get_lindblad_study_config():
    """Return the preset for comparing Lindblad noise with endogenous decoherence.

    Runs 1500 steps with ensemble averaging (30 members). The position and
    momentum dephasing rates differ (gamma_x=0.02, gamma_p=0.01).
    """
    return dict(
        grid_size=512,
        x_range=(-10, 10),
        dt=0.005,
        n_steps=1500,
        lambda_values=[0.0, 0.02, 0.05, 0.1],
        correction_type='adaptive',
        # Focus on Lindblad comparison
        use_lindblad=True,
        gamma_x=0.02,
        gamma_p=0.01,
        use_ensemble=True,
        ensemble_size=30,
        diagnostic_sampling=5,
        create_boundary_map=True,
        output_dir='results_lindblad_study',
    )

def get_2d_extension_config():
    """Return the preset for the 2D double-well extension.

    ``grid_size`` is per axis, so 128 here means a 128x128 lattice. Wigner
    snapshots are disabled because storing them in 2D is too much data.
    """
    return dict(
        grid_size=128,  # Smaller for 2D (total will be 128x128)
        x_range=(-8, 8),
        dt=0.01,
        n_steps=800,
        lambda_values=[0.0, 0.05, 0.1],
        correction_type='basic',
        use_2d_extension=True,
        diagnostic_sampling=10,
        save_wigner_evolution=False,  # Too much data for 2D
        create_boundary_map=True,
        output_dir='results_2d',
    )

# === SPECIALIZED ANALYSIS FUNCTIONS ===

def compare_correction_mechanisms():
    """Run every correction mechanism at a single coupling and compare them.

    Each mechanism gets its own basic-preset sweep with ``lambda = 0.05`` and
    its own output directory. The first (and only) result of each sweep is
    collected, the side-by-side comparison figure is drawn, and the
    collected results are returned.

    Returns:
        dict mapping correction type name -> that sweep's single result.
    """
    mechanisms = ('basic', 'time_delayed', 'adaptive', 'predictive', 'nonlinear')

    comparison_results = {}
    for mechanism in mechanisms:
        print(f"\nTesting {mechanism} correction...")

        run_config = get_basic_config()
        run_config.update(
            correction_type=mechanism,
            lambda_values=[0.05],  # Single lambda for comparison
            output_dir=f'correction_comparison_{mechanism}',
        )

        sweep_results, _ = QuantumClassicalSimulator(run_config).run_full_parameter_sweep()
        comparison_results[mechanism] = sweep_results[0]  # Single lambda result

    _plot_correction_comparison(comparison_results)
    return comparison_results

def study_measurement_induced_transitions():
    """Run the advanced preset under several external-measurement scenarios.

    Scenario 0 has no measurement. Scenario 1 measures the centre and
    scenario 2 the right well, both at t=1.0. Scenario 3 performs two
    measurements. Each scenario sweeps ``lambda`` over ``[0.0, 0.05]`` into
    its own output directory.

    Returns:
        dict mapping ``'scenario_<i>'`` -> the sweep results of scenario ``i``.
    """
    print("\nStudying measurement-induced transitions...")

    # (measurement times, collapse positions) for each scenario
    scenarios = [
        ([], []),                   # No measurement
        ([1.0], [0.0]),             # Central measurement
        ([1.0], [2.0]),             # Right well measurement
        ([0.5, 1.5], [0.0, 2.0]),   # Multiple measurements
    ]

    results = {}
    for index, (times, positions) in enumerate(scenarios):
        config = get_advanced_config()
        config.update(
            use_measurement_collapse=bool(times),
            measurement_times=times,
            measurement_positions=positions,
            lambda_values=[0.0, 0.05],
            output_dir=f'measurement_study_scenario_{index}',
        )

        sweep_result, _ = QuantumClassicalSimulator(config).run_full_parameter_sweep()
        results[f'scenario_{index}'] = sweep_result

    return results

def generate_publication_figures():
    """Run a reduced research-grade sweep and render publication figures.

    Uses the research preset, but replaces the 20-point logarithmic sweep
    with four couplings so the plotted curves remain distinguishable.

    Returns:
        ``(results, output_dir)`` from the parameter sweep.
    """
    print("\nGenerating publication-quality figures...")

    research_config = get_research_config()
    research_config['lambda_values'] = [0.0, 0.01, 0.05, 0.1]  # Reduced for clarity

    sweep = QuantumClassicalSimulator(research_config).run_full_parameter_sweep()
    results, output_dir = sweep

    # Create enhanced visualizations
    _create_publication_plots(results, output_dir)
    return results, output_dir

def _plot_correction_comparison(comparison_results):
    """Plot diagnostic time series for each correction mechanism side-by-side.

    Draws one subplot per metric, with one line per correction mechanism. The
    figure is saved to ``correction_mechanisms_comparison.png`` in the
    current working directory and then closed.

    Args:
        comparison_results: mapping of correction type -> result dict.
            Results containing a ``'diagnostics'`` sequence of per-sample
            dicts are plotted; others are skipped. Samples without a
            ``'time'`` key are ignored; missing metric values plot as 0.
    """
    metrics = ['shannon_entropy', 'position_uncertainty', 'fidelity_classical',
               'correction_efficiency_mean', 'quantum_potential_mean']

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.flatten()

    for ax, metric in zip(axes, metrics):
        for correction_type, result in comparison_results.items():
            if 'diagnostics' not in result:
                continue
            # Filter once so time stamps and metric values stay paired. The
            # old code filtered only the times, so any sample without 'time'
            # made the two lists differ in length and ax.plot raised.
            samples = [d for d in (result['diagnostics'] or []) if 'time' in d]
            time_points = [d['time'] for d in samples]
            values = [d.get(metric, 0) for d in samples]

            ax.plot(time_points, values, label=correction_type, linewidth=2)

        ax.set_xlabel('Time')
        ax.set_ylabel(metric.replace('_', ' ').title())
        # Avoid matplotlib's "no artists with labels" warning on empty panels.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        ax.grid(True, alpha=0.3)

    # The 2x3 grid has one more panel than there are metrics; hide the spare.
    for ax in axes[len(metrics):]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.savefig('correction_mechanisms_comparison.png', dpi=300, bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(fig)
    print("Correction mechanism comparison saved to: correction_mechanisms_comparison.png")

def _create_publication_plots(results, output_dir):
    """Write publication-quality figures for a parameter sweep.

    Currently produces ``entropy_evolution.pdf`` inside
    ``<output_dir>/publication_figures``, with one entropy-vs-time curve per
    lambda value.

    Args:
        results: iterable of per-run result dicts from the sweep. Entries
            containing a ``'boundary_map'`` key are skipped. Other entries
            must have ``'lambda'`` and may have ``'averaged_diagnostics'``
            (preferred) or ``'diagnostics'``, each a sequence of per-sample
            dicts.
        output_dir: sweep output directory; the figure subdirectory is
            created inside it if missing.
    """
    pub_dir = os.path.join(output_dir, 'publication_figures')
    os.makedirs(pub_dir, exist_ok=True)

    # Plot 1: Entropy evolution for different lambda values
    fig, ax = plt.subplots(figsize=(10, 8))

    for result in results:
        if 'boundary_map' in result:  # Skip boundary map entry
            continue
        lambda_val = result['lambda']
        diagnostics = result.get('averaged_diagnostics', result.get('diagnostics', [])) or []

        # Keep only timestamped samples, and derive both series from the same
        # filtered list. Otherwise they can differ in length and ax.plot raises.
        samples = [d for d in diagnostics if 'time' in d]
        if not samples:
            continue
        times = [d['time'] for d in samples]
        # Prefer von Neumann entropy; fall back to Shannon entropy when the
        # run did not record it (e.g. pure-state mode — TODO confirm).
        entropies = [d.get('von_neumann_entropy', d.get('shannon_entropy', 0))
                     for d in samples]

        ax.plot(times, entropies, label=f'λ = {lambda_val:.3f}', linewidth=2)

    ax.set_xlabel('Time', fontsize=14)
    ax.set_ylabel('Von Neumann Entropy', fontsize=14)
    ax.set_title('Quantum-Classical Transition via Entropy Evolution', fontsize=16)
    # Avoid matplotlib's "no artists with labels" warning when nothing was plotted.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=12)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(pub_dir, 'entropy_evolution.pdf'), dpi=300, bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close(fig)

    print(f"Publication figures saved to: {pub_dir}")

# === MAIN EXECUTION FUNCTIONS ===

def run_interactive_session():
    """Prompt on stdin for a simulation mode and run it.

    Choices 1-5 run a full parameter sweep with the matching preset and
    return ``(results, output_dir)``. Choices 6-8 delegate to the specialised
    study functions and return whatever those return. Any other input falls
    back to the basic preset.
    """
    menu = (
        "Basic (quick test)",
        "Advanced (all features)",
        "Research (publication quality)",
        "Lindblad study",
        "2D extension",
        "Compare correction mechanisms",
        "Measurement transition study",
        "Generate publication figures",
    )
    print("Enhanced Quantum-Classical Boundary Emergence Simulator")
    print("=" * 60)
    print("Available simulation modes:")
    for number, label in enumerate(menu, start=1):
        print(f"{number}. {label}")

    # Menu choice -> callable; studies run themselves, presets build a config.
    studies = {
        '6': compare_correction_mechanisms,
        '7': study_measurement_induced_transitions,
        '8': generate_publication_figures,
    }
    presets = {
        '1': get_basic_config,
        '2': get_advanced_config,
        '3': get_research_config,
        '4': get_lindblad_study_config,
        '5': get_2d_extension_config,
    }

    choice = input("\nSelect simulation mode (1-8): ").strip()

    if choice in studies:
        return studies[choice]()

    make_config = presets.get(choice)
    if make_config is None:
        print("Invalid choice. Using basic configuration.")
        make_config = get_basic_config
    config = make_config()

    # Run simulation
    simulator = QuantumClassicalSimulator(config)
    results, output_dir = simulator.run_full_parameter_sweep()

    print(f"\nSimulation completed! Results saved to: {output_dir}")
    return results, output_dir

def run_batch_mode(config_file=None):
    """Run one parameter sweep non-interactively.

    Args:
        config_file: optional path to a JSON configuration file. If it is
            omitted, or the path does not exist, the advanced preset is used.
            A warning is printed when an explicit path cannot be found.

    Returns:
        ``(results, output_dir)`` as returned by
        ``QuantumClassicalSimulator.run_full_parameter_sweep``.
    """
    if config_file and os.path.exists(config_file):
        with open(config_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
        print(f"Loaded configuration from: {config_file}")
    else:
        if config_file:
            # An explicit path that cannot be found is most likely a typo. Say
            # so instead of silently launching the (expensive) advanced preset.
            print(f"Warning: configuration file not found: {config_file}")
        config = get_advanced_config()
        print("Using default advanced configuration")

    simulator = QuantumClassicalSimulator(config)
    results, output_dir = simulator.run_full_parameter_sweep()

    return results, output_dir

def create_example_configs(config_dir='example_configs'):
    """Write every preset configuration to JSON files users can edit.

    One file is written per preset, including the 2D extension, so the
    examples cover every ``--preset`` choice. The files can be passed back
    via ``--mode batch --config <file>``.

    Args:
        config_dir: directory to write into; created if missing. Defaults to
            ``'example_configs'`` in the current working directory.
    """
    def _to_jsonable(value):
        # json cannot encode numpy arrays (e.g. the research preset's
        # logspace lambdas) or numpy scalar types; convert to Python values.
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        return value

    configs = {
        'basic_config.json': get_basic_config(),
        'advanced_config.json': get_advanced_config(),
        'research_config.json': get_research_config(),
        'lindblad_study_config.json': get_lindblad_study_config(),
        '2d_extension_config.json': get_2d_extension_config(),
    }

    os.makedirs(config_dir, exist_ok=True)

    for filename, config in configs.items():
        json_config = {key: _to_jsonable(value) for key, value in config.items()}

        filepath = os.path.join(config_dir, filename)
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(json_config, f, indent=2)

    print(f"Example configuration files created in: {config_dir}")

if __name__ == "__main__":
    # Preset name -> config factory. Storing factories instead of pre-built
    # dicts means only the selected preset is ever constructed.
    preset_factories = {
        'basic': get_basic_config,
        'advanced': get_advanced_config,
        'research': get_research_config,
        'lindblad': get_lindblad_study_config,
        '2d': get_2d_extension_config,
    }

    # Command line argument parsing
    parser = argparse.ArgumentParser(description='Enhanced Quantum-Classical Boundary Emergence Simulator')
    parser.add_argument('--mode', choices=['interactive', 'batch', 'examples'],
                        default='interactive', help='Execution mode')
    parser.add_argument('--config', type=str, help='Configuration file for batch mode')
    parser.add_argument('--preset', choices=list(preset_factories),
                        help='Use preset configuration')

    args = parser.parse_args()

    if args.config and args.mode != 'batch':
        # --config is only honoured by batch mode; say so rather than ignore it.
        print("Warning: --config is only used with --mode batch; ignoring it.")

    try:
        if args.mode == 'examples':
            create_example_configs()
        elif args.mode == 'batch':
            run_batch_mode(args.config)
        elif args.preset:
            # Run with preset configuration
            config = preset_factories[args.preset]()
            simulator = QuantumClassicalSimulator(config)
            _, output_dir = simulator.run_full_parameter_sweep()
            print(f"\nSimulation completed! Results saved to: {output_dir}")
        else:
            # Interactive mode. Do not tuple-unpack the return value: menu
            # choices 6 and 7 return a dict of results, not a
            # (results, output_dir) pair, so unpacking raised ValueError and
            # reported a failure after a successful run.
            run_interactive_session()

    except KeyboardInterrupt:
        print("\n\nSimulation interrupted by user.")
        sys.exit(1)
    except Exception as e:
        print(f"\nError during simulation: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    print("\n" + "="*70)
    print("Enhanced Quantum-Classical Boundary Emergence Simulation Complete!")
    print("="*70)