#!/usr/bin/env python3
"""
Robust test of the enhanced quantum-classical simulator.

This test verifies that the system can handle errors gracefully and
continues operation even when components fail.
"""

import numpy as np
import warnings
import sys
import traceback
from typing import Dict, Any

def _baseline_physics():
    """Return a fresh 128-point ``QuantumPhysics`` and a normalized Gaussian on its grid.

    Every check that needs simulator state calls this instead of reusing names
    bound by an earlier check. A failure in one check (e.g. an ImportError)
    therefore surfaces with its real cause and cannot cascade into misleading
    ``NameError``s in later checks.
    """
    from physics import QuantumPhysics

    physics = QuantumPhysics(grid_size=128)
    psi = physics.normalize_wavefunction(np.exp(-physics.x**2))
    return physics, psi


def _check_basic_physics():
    """Test 1: quantum potential and one split-step evolution on a Gaussian."""
    print("\n🔬 Test 1: Basic Physics Operations")
    physics, psi = _baseline_physics()

    # Test quantum potential with error handling
    physics.quantum_potential(psi)

    # Test evolution
    potential = physics.double_well_potential()
    physics.split_step_evolution(psi, potential, 0.01)

    print("  ✅ Basic physics operations successful")


def _check_invalid_inputs():
    """Test 2: NaN/inf wavefunctions must be handled without raising."""
    print("\n🛡️ Test 2: Error Handling with Invalid Inputs")
    physics, _ = _baseline_physics()

    # 4 values * 32 = 128 points, matching the baseline grid size.
    psi_invalid = np.array([np.nan, np.inf, 0.0, 1.0] * 32, dtype=complex)
    physics.normalize_wavefunction(psi_invalid)

    # Test quantum potential with invalid wavefunction
    physics.quantum_potential(psi_invalid)

    print("  ✅ Error handling successful - system didn't crash")


def _check_corrections():
    """Test 3: the basic correction accepts both valid and NaN/inf inputs."""
    print("\n⚙️ Test 3: Correction Mechanisms")
    from correction import CorrectionMechanism

    physics, psi = _baseline_physics()
    correction = CorrectionMechanism(physics.x, 0.01)

    # Test basic correction
    correction.basic_correction(psi, 0.0, 0.1)

    # Test with invalid inputs
    correction.basic_correction(np.array([np.nan, np.inf]), np.nan, np.inf)

    print("  ✅ Correction mechanisms robust")


def _check_config_validation():
    """Test 4: ``ConfigValidator`` sanitizes an obviously invalid config."""
    print("\n📝 Test 4: Input Validation")
    from input_validator import ConfigValidator

    bad_config = {
        'grid_size': -100,
        'dt': -1.0,
        'x_range': [5, -5],  # Invalid range
        'lambda_values': 'not_a_list',
        'mass': 0.0  # Invalid mass
    }

    safe_config, warnings_list = ConfigValidator().validate_config(bad_config)

    print(f"  ✅ Validation caught {len(warnings_list)} issues")
    print(f"     Sanitized config: grid_size={safe_config['grid_size']}")


def _check_safe_config():
    """Test 5: ``create_safe_config`` builds a usable configuration."""
    print("\n🔧 Test 5: Safe Configuration Creation")
    from input_validator import create_safe_config

    safe_config = create_safe_config(
        grid_size=64,
        time_steps=50,
        lambda_values=[0.0, 0.05]
    )

    print(f"  ✅ Safe config created: {safe_config['grid_size']}x{safe_config['n_steps']} simulation")


def _check_minimal_simulation():
    """Test 6: evolve for ten steps, renormalizing if the norm drifts.

    Individual step failures are reported and skipped. The check fails only
    if no step completes at all.
    """
    print("\n🚀 Test 6: Minimal Simulation Run")
    physics, psi = _baseline_physics()

    # NumPy 2.0 renamed ``trapz`` to ``trapezoid`` and deprecated the old name.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    n_steps = 10

    results = []
    for step in range(n_steps):  # Very short simulation
        try:
            # Basic evolution
            potential = physics.double_well_potential()
            psi = physics.split_step_evolution(psi, potential, 0.02)

            # Basic diagnostics
            density = np.abs(psi)**2
            norm = np.sqrt(trapezoid(density, physics.x))
            position = np.real(trapezoid(physics.x * density, physics.x))

            results.append({
                'step': step,
                'norm': norm,
                'position': position
            })

            # Check for instability
            if norm < 0.5 or norm > 2.0:
                print(f"    Warning: norm became {norm:.3f} at step {step}")
                psi = physics.normalize_wavefunction(psi)

        except Exception as step_error:
            print(f"    Warning: Step {step} failed: {step_error}")

    # Without this guard, results[-1] below would raise an opaque IndexError.
    if not results:
        raise RuntimeError(f"none of the {n_steps} simulation steps completed")

    print(f"  ✅ Completed {len(results)}/{n_steps} simulation steps")
    print(f"     Final position: {results[-1]['position']:.3f}")


def _check_extreme_parameters():
    """Test 7: a tiny grid with extreme mass/hbar and huge amplitudes."""
    print("\n💥 Test 7: Stress Test with Extreme Parameters")
    from physics import QuantumPhysics

    # Try to break the system with extreme parameters
    extreme_physics = QuantumPhysics(
        grid_size=16,  # Very small
        x_range=(-1000, 1000),  # Very large range
        mass=1e-10,  # Very small mass
        hbar=1e10   # Very large hbar
    )

    # Create extreme wavefunction
    psi_extreme = np.ones(16, dtype=complex) * 1e6  # Very large amplitude
    psi_extreme = extreme_physics.normalize_wavefunction(psi_extreme)

    # Try extreme quantum potential
    extreme_physics.quantum_potential(psi_extreme)

    print("  ✅ System survived extreme parameters")


def test_robust_simulation():
    """Test the enhanced error-handling capabilities.

    Runs seven independent checks. A check passes when it completes without
    raising, and any exception is printed and counted as a failure. Because
    each check builds its own state, one failure cannot cascade into the
    others.

    Returns:
        bool: True if every check passed or at least 80% of them did,
        False otherwise.
    """
    print("🧪 Testing Robust Quantum-Classical Simulator")
    print("=" * 60)

    # (failure message prefix, check) pairs, run in order.
    checks = (
        ("Basic physics test failed", _check_basic_physics),
        ("Error handling test failed", _check_invalid_inputs),
        ("Correction test failed", _check_corrections),
        ("Validation test failed", _check_config_validation),
        ("Safe config test failed", _check_safe_config),
        ("Minimal simulation failed", _check_minimal_simulation),
        ("Stress test failed", _check_extreme_parameters),
    )

    total_tests = len(checks)
    success_count = 0
    for failure_label, check in checks:
        try:
            check()
        except Exception as e:  # per-check boundary: report and keep going
            print(f"  ❌ {failure_label}: {e}")
        else:
            success_count += 1

    # Summary
    print("\n" + "=" * 60)
    print(f"🎯 Test Summary: {success_count}/{total_tests} tests passed")
    print(f"   Success rate: {100*success_count/total_tests:.1f}%")

    if success_count == total_tests:
        print("🎉 All tests passed! System is robust and crash-resistant.")
        return True
    elif success_count >= total_tests * 0.8:
        print("✅ Most tests passed. System is reasonably robust.")
        return True
    else:
        print("⚠️ Several tests failed. System needs more work.")
        return False

def run_safe_demo():
    """Run a safe, minimal demonstration.

    Builds a configuration with ``create_safe_config`` and initializes
    ``QuantumPhysics`` from it. It then evolves a normalized Gaussian through
    ``config['n_steps']`` split steps of size ``config['dt']`` in the
    double-well potential, renormalizing whenever the norm leaves [0.1, 10].
    A failing step is reported and skipped rather than aborting the run.

    Returns:
        list[dict] | None: One ``{'step', 'position', 'norm'}`` record per
        completed step, or ``None`` if setup failed or no step completed.
    """
    print("\n🔄 Running Safe Minimal Demonstration")
    print("-" * 40)

    try:
        # Import and test modules
        from physics import QuantumPhysics
        from input_validator import create_safe_config

        # Create safe configuration
        config = create_safe_config(grid_size=64, time_steps=20)
        print(f"✅ Safe config: {config['grid_size']} grid, {config['n_steps']} steps")

        # Initialize physics
        physics = QuantumPhysics(
            grid_size=config['grid_size'],
            x_range=config['x_range']
        )
        print("✅ Physics initialized")

        # Create initial state
        psi = physics.normalize_wavefunction(np.exp(-physics.x**2))
        print("✅ Initial wavefunction created")

        # NumPy 2.0 renamed ``trapz`` to ``trapezoid`` and deprecated the old name.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz

        # Simple evolution
        results = []
        failed_steps = 0
        for step in range(config['n_steps']):
            try:
                # Evolution step
                potential = physics.double_well_potential()
                psi = physics.split_step_evolution(psi, potential, config['dt'])

                # Check stability
                norm = np.sqrt(trapezoid(np.abs(psi)**2, physics.x))
                if norm < 0.1 or norm > 10:
                    warnings.warn(f"Renormalizing at step {step} (norm={norm:.3f})")
                    psi = physics.normalize_wavefunction(psi)

                # Store results
                position = np.real(trapezoid(physics.x * np.abs(psi)**2, physics.x))
                results.append({'step': step, 'position': position, 'norm': norm})

                if step % 5 == 0:
                    print(f"  Step {step:2d}: position = {position:6.3f}")

            except Exception as step_error:
                failed_steps += 1
                warnings.warn(f"Step {step} failed: {step_error}")

        # Callers may filter UserWarning (this script's entry point ignores
        # it), so also report step failures on stdout.
        if failed_steps:
            print(f"⚠️ {failed_steps} of {config['n_steps']} evolution steps failed")

        if not results:
            print("❌ Safe demo failed: no evolution steps completed")
            return None

        print(f"✅ Completed {len(results)} evolution steps")
        print(f"Final position: {results[-1]['position']:.3f}")

        return results

    except Exception as e:
        print(f"❌ Safe demo failed: {e}")
        traceback.print_exc()
        return None

if __name__ == "__main__":
    # Ignore every UserWarning for the rest of the run to keep the console
    # readable. This also hides the warnings.warn() notices run_safe_demo
    # emits for renormalized or failed steps.
    warnings.filterwarnings("ignore", category=UserWarning)

    suite_passed = test_robust_simulation()
    demo_records = run_safe_demo()

    # Exit non-zero unless the suite passed AND the demo produced records.
    if not (suite_passed and demo_records):
        print("\n⚠️ PARTIAL SUCCESS: System is more robust but may need additional work.")
        sys.exit(1)

    print("\n🎉 SUCCESS: Enhanced quantum-classical simulator is robust and functional!")
    print("The system can now handle errors gracefully without crashing.")
    sys.exit(0)