"""Comprehensive Test Suite for Enhanced Quantum-Classical Boundary Emergence Simulator.

This test suite validates all components of the enhanced simulator to ensure
scientific accuracy and numerical stability. Tests include:

1. Physics module tests (quantum potential, Lindblad evolution, etc.)
2. Diagnostics module tests (entropy calculations, fidelity measures, etc.)
3. Phase space analysis tests (Wigner functions, negativity, etc.)
4. Correction mechanism tests (basic and adaptive corrections, correction history)
5. Integration tests (full simulation workflows)
6. 2D system tests
7. Numerical accuracy and conservation law tests
"""

import numpy as np
import matplotlib.pyplot as plt
import unittest
import sys
import os
import warnings
from scipy.integrate import simpson as simps

# Import all modules
from physics import QuantumPhysics
from diagnostics import QuantumDiagnostics
from phase_space import PhaseSpaceAnalysis
from correction import CorrectionMechanism
from simulation_control import QuantumClassicalSimulator
from system_2d import QuantumSystem2D

class TestQuantumPhysics(unittest.TestCase):
    """Test the physics module components."""

    def setUp(self):
        """Set up test fixtures."""
        self.physics = QuantumPhysics(grid_size=128, x_range=(-5, 5))

    def test_double_well_potential(self):
        """Test double-well potential shape and properties.

        Requires at least two interior local minima, the outermost ones near
        x = -2 and x = +2, and a value at x ≈ 0 higher than every minimum.
        """
        V = self.physics.double_well_potential()
        x = self.physics.x

        # Strict interior local minima (each lower than both neighbours).
        min_indices = np.where((V[1:-1] < V[:-2]) & (V[1:-1] < V[2:]))[0] + 1

        # Assert the minima exist rather than skipping the position checks:
        # a conditional here would let the test pass silently when fewer than
        # two minima are found, and an empty ``min_indices`` would make the
        # ``.max()`` below raise ValueError instead of failing cleanly.
        self.assertGreaterEqual(
            len(min_indices), 2,
            "double-well potential should have at least two local minima")

        # Check that the outermost minima are at approximately ±2
        min_positions = x[min_indices]
        self.assertAlmostEqual(min_positions[0], -2.0, places=1)
        self.assertAlmostEqual(min_positions[-1], 2.0, places=1)

        # Check the barrier at the grid point closest to x = 0; unlike
        # len(x) // 2 this does not assume a grid symmetric about the origin.
        center_idx = int(np.argmin(np.abs(x)))
        self.assertGreater(V[center_idx], V[min_indices].max())

    def test_quantum_potential_calculation(self):
        """Test that the quantum potential of a Gaussian state is finite."""
        # Create a Gaussian wavefunction
        psi_gaussian = np.exp(-self.physics.x**2)
        psi_gaussian = self.physics.normalize_wavefunction(psi_gaussian)

        Q = self.physics.quantum_potential(psi_gaussian)

        # Analytically Q = -(hbar^2 / 2m) * R''/R with amplitude R = |psi|.
        # For psi ∝ exp(-x^2 / (2 sigma^2)) this gives
        #     Q = -(hbar^2 / 2m) * (x^2 / sigma^4 - 1 / sigma^2),
        # i.e. -(hbar^2 / 2m) * (4 x^2 - 2) for the exp(-x^2) state used here.
        # Only finiteness is asserted (np.isfinite also rejects NaN), because
        # the numerical conventions of quantum_potential are not pinned here.
        self.assertTrue(np.isfinite(Q).all())

    def test_split_step_evolution(self):
        """Test unitary evolution and probability conservation."""
        psi_initial = np.exp(-(self.physics.x + 2)**2) + np.exp(-(self.physics.x - 2)**2)
        psi_initial = self.physics.normalize_wavefunction(psi_initial)

        initial_norm = simps(np.abs(psi_initial)**2, x=self.physics.x)

        # Evolve for several time steps
        psi = psi_initial.copy()
        potential = self.physics.double_well_potential()
        dt = 0.01

        for _ in range(10):
            psi = self.physics.split_step_evolution(psi, potential, dt)

        final_norm = simps(np.abs(psi)**2, x=self.physics.x)

        # Check probability conservation
        self.assertAlmostEqual(initial_norm, final_norm, places=6)

    def test_lindblad_evolution(self):
        """Test Lindblad master equation evolution (trace kept, purity not increased)."""
        # Create initial pure state
        psi_initial = np.exp(-self.physics.x**2)
        psi_initial = self.physics.normalize_wavefunction(psi_initial)
        rho_initial = self.physics.wavefunction_to_density_matrix(psi_initial)

        # Simple Hamiltonian: only the diagonal potential term (no kinetic part)
        H = np.diag(self.physics.double_well_potential())
        lindblad_ops = self.physics.create_lindblad_operators(gamma_x=0.01, gamma_p=0.01)

        # Evolve
        rho = rho_initial.copy()
        dt = 0.001

        for _ in range(10):
            rho = self.physics.lindblad_evolution(rho, H, lindblad_ops, dt)

        # Check trace conservation
        self.assertAlmostEqual(np.trace(rho).real, 1.0, places=6)

        # Check that purity does not increase (system becomes mixed)
        initial_purity = np.trace(rho_initial @ rho_initial).real
        final_purity = np.trace(rho @ rho).real
        self.assertLessEqual(final_purity, initial_purity)

    def test_measurement_collapse(self):
        """Test measurement-induced wavefunction collapse."""
        # Create superposition state
        psi = np.exp(-(self.physics.x + 2)**2) + np.exp(-(self.physics.x - 2)**2)
        psi = self.physics.normalize_wavefunction(psi)

        # Apply measurement collapse at x = 2
        psi_collapsed = self.physics.apply_measurement_collapse(psi, 2.0, collapse_width=1.0)

        # Check normalization
        norm = simps(np.abs(psi_collapsed)**2, x=self.physics.x)
        self.assertAlmostEqual(norm, 1.0, places=6)

        # Check that the wavefunction has shifted toward x = 2
        x_mean_initial = simps(self.physics.x * np.abs(psi)**2, x=self.physics.x)
        x_mean_collapsed = simps(self.physics.x * np.abs(psi_collapsed)**2, x=self.physics.x)

        self.assertGreater(x_mean_collapsed, x_mean_initial)


class TestQuantumDiagnostics(unittest.TestCase):
    """Test the diagnostics module."""

    def setUp(self):
        """Set up test fixtures."""
        self.x = np.linspace(-5, 5, 128)
        self.dx = self.x[1] - self.x[0]
        self.diagnostics = QuantumDiagnostics(self.x, self.dx)

    def _normalized(self, psi):
        """Return psi scaled so that the Simpson integral of |psi|^2 over self.x is 1."""
        return psi / np.sqrt(simps(np.abs(psi)**2, x=self.x))

    def test_shannon_entropy(self):
        """Test that a localized state has lower Shannon entropy than a uniform one."""
        # Localized state should have low entropy
        psi_localized = self._normalized(np.exp(-10 * (self.x - 2)**2))

        # Delocalized (uniform) state should have high entropy. It is
        # normalized with the same continuum convention as the localized
        # state: a 1/sqrt(N) scaling gives sum|psi|^2 = 1, i.e. an integral of
        # only dx, so the two entropies would be computed on differently
        # normalized densities.
        psi_delocalized = self._normalized(np.ones_like(self.x))

        entropy_localized = self.diagnostics.shannon_entropy(psi_localized)
        entropy_delocalized = self.diagnostics.shannon_entropy(psi_delocalized)

        self.assertLess(entropy_localized, entropy_delocalized)
        self.assertGreater(entropy_localized, 0)

    def test_von_neumann_entropy(self):
        """Test Von Neumann entropy: zero for a pure state, positive for a mixture."""
        # Pure state should have zero entropy
        psi_pure = self._normalized(np.exp(-self.x**2))
        rho_pure = np.outer(psi_pure, np.conj(psi_pure))

        # NOTE(review): psi_pure satisfies ∫|psi|^2 dx = 1, so
        # trace(rho_pure) = sum|psi|^2 ≈ 1/dx rather than 1, while the identity
        # term mixed in below has trace 1. Whether this mixes conventions
        # depends on how von_neumann_entropy uses dx — confirm against
        # QuantumDiagnostics.
        entropy_pure = self.diagnostics.von_neumann_entropy(rho_pure)
        self.assertAlmostEqual(entropy_pure, 0.0, places=3)

        # Mixed state should have positive entropy
        rho_mixed = 0.5 * rho_pure + 0.5 * np.eye(len(self.x)) / len(self.x)
        entropy_mixed = self.diagnostics.von_neumann_entropy(rho_mixed)
        self.assertGreater(entropy_mixed, 0)

    def test_fidelity_calculations(self):
        """Test fidelity: 1 for identical states, small for well-separated ones."""
        # Centres at ±2 give an overlap <psi1|psi2> = exp(-8) ≈ 3e-4, so the
        # bound below holds whether fidelity is defined as |<psi1|psi2>| or
        # its square. (Centres at ±1 overlap by exp(-2) ≈ 0.135, which would
        # exceed 0.1 under the unsquared convention.)
        psi1 = self._normalized(np.exp(-(self.x - 2)**2))
        psi2 = self._normalized(np.exp(-(self.x + 2)**2))

        # Self-fidelity should be 1
        fidelity_self = self.diagnostics.fidelity_initial_state(psi1, psi1)
        self.assertAlmostEqual(fidelity_self, 1.0, places=6)

        # Nearly orthogonal states should have low fidelity
        fidelity_separated = self.diagnostics.fidelity_initial_state(psi1, psi2)
        self.assertLess(fidelity_separated, 0.1)

    def test_uncertainty_relations(self):
        """Test the Heisenberg uncertainty relation for Gaussians of several widths."""
        for width in [0.5, 1.0, 2.0]:
            psi = self._normalized(np.exp(-self.x**2 / (2 * width**2)))

            uncertainty_product = self.diagnostics.heisenberg_uncertainty_product(psi)

            # Should satisfy ΔxΔp ≥ ℏ/2 (with ℏ = 1); 0.45 allows numerical error
            self.assertGreaterEqual(uncertainty_product, 0.45)


class TestPhaseSpaceAnalysis(unittest.TestCase):
    """Test the phase space analysis module."""

    def setUp(self):
        """Set up test fixtures."""
        self.x = np.linspace(-5, 5, 64)  # Smaller grid for faster tests
        # Angular wavenumbers in unshifted FFT order for the same grid spacing
        self.k = np.fft.fftfreq(len(self.x), d=self.x[1] - self.x[0]) * 2 * np.pi
        self.phase_space = PhaseSpaceAnalysis(self.x, self.k)

    def test_wigner_function_properties(self):
        """Test that the Wigner function is real and its x-marginal matches |psi|^2."""
        # Create Gaussian wavefunction
        psi = np.exp(-self.x**2)
        psi /= np.sqrt(simps(np.abs(psi)**2, x=self.x))

        wigner = self.phase_space.compute_wigner_function(psi)

        # Real up to round-off. np.isreal demands an imaginary part of exactly
        # zero, which a complex-dtype result (e.g. one produced by FFTs) will
        # rarely satisfy. Once that is verified, drop the negligible imaginary
        # part so the marginals below are real.
        self.assertTrue(np.allclose(np.imag(wigner), 0.0, atol=1e-10))
        wigner = np.real(wigner)

        # Check marginal distributions
        pos_marginal, _ = self.phase_space.marginal_distributions(wigner)

        # Position marginal should match |ψ|²
        expected_pos_marginal = np.abs(psi)**2

        # Normalize both for comparison (into new arrays, so the returned
        # marginal is not modified in place)
        pos_marginal = pos_marginal / simps(pos_marginal, x=self.x)
        expected_pos_marginal = expected_pos_marginal / simps(expected_pos_marginal, x=self.x)

        # Should be approximately equal (allowing for numerical errors)
        correlation = np.corrcoef(pos_marginal, expected_pos_marginal)[0, 1]
        self.assertGreater(correlation, 0.95)

    def test_wigner_negativity_coherent_vs_superposition(self):
        """Test that a cat-like superposition is classified as less classical."""
        # Single displaced Gaussian (should look more classical)
        psi_coherent = np.exp(-(self.x - 1)**2)
        psi_coherent /= np.sqrt(simps(np.abs(psi_coherent)**2, x=self.x))

        # Superposition of two separated Gaussians (interference fringes)
        psi_superposition = (np.exp(-(self.x + 2)**2) + np.exp(-(self.x - 2)**2))
        psi_superposition /= np.sqrt(simps(np.abs(psi_superposition)**2, x=self.x))

        wigner_coherent = self.phase_space.compute_wigner_function(psi_coherent)
        wigner_superposition = self.phase_space.compute_wigner_function(psi_superposition)

        regions_coherent = self.phase_space.classify_quantum_classical_regions(wigner_coherent)
        regions_superposition = self.phase_space.classify_quantum_classical_regions(wigner_superposition)

        # Superposition should have a smaller classical fraction
        self.assertLess(regions_superposition['classical_fraction'],
                        regions_coherent['classical_fraction'])


class TestCorrectionMechanisms(unittest.TestCase):
    """Exercise the CorrectionMechanism implementations on simple 1D states."""

    def setUp(self):
        """Build a 128-point grid on [-5, 5] and a fresh correction object."""
        grid = np.linspace(-5, 5, 128)
        self.x = grid
        self.dx = grid[1] - grid[0]
        self.dt = 0.01
        self.correction = CorrectionMechanism(grid, self.dt)

    def test_basic_correction(self):
        """A packet displaced from the classical reference gets a restoring correction."""
        grid = self.x
        wave = np.exp(-(grid - 1)**2)  # packet centred at x = 1
        wave /= np.sqrt(simps(np.abs(wave)**2, grid))

        classical_position = 0.0  # classical reference sits at the origin
        coupling = 0.1

        _, strength = self.correction.basic_correction(wave, classical_position, coupling)

        # <x> differs from the classical position, so the correction must act
        self.assertNotAlmostEqual(strength, 0.0)

        # Its sign must oppose the displacement (restoring force)
        quantum_position = simps(grid * np.abs(wave)**2, grid)
        displacement = quantum_position - classical_position
        self.assertLess(strength * displacement, 0)

    def test_correction_memory(self):
        """Each basic_correction call appends one entry to both history lists."""
        wave = np.exp(-self.x**2)
        wave /= np.sqrt(simps(np.abs(wave)**2, self.x))

        n_calls = 5
        for step in range(n_calls):
            self.correction.basic_correction(wave, step * 0.1, 0.1)

        self.assertEqual(len(self.correction.expectation_history), n_calls)
        self.assertEqual(len(self.correction.classical_history), n_calls)

    def test_adaptive_correction(self):
        """The adaptive correction returns values and exposes an adaptive gain."""
        left_lobe = np.exp(-(self.x + 2)**2)
        right_lobe = np.exp(-(self.x - 2)**2)
        wave = left_lobe + right_lobe
        wave /= np.sqrt(simps(np.abs(wave)**2, self.x))

        # Stand-in quantum potential, peaked where the lobes interfere
        stand_in_q = np.exp(-self.x**2 / 2)

        corrected_potential, strength = self.correction.adaptive_quantum_potential_correction(
            wave, 0.0, stand_in_q, 0.1)

        for returned in (corrected_potential, strength):
            self.assertIsNotNone(returned)
        self.assertTrue(hasattr(self.correction, 'adaptive_gain'))


class TestSystem2D(unittest.TestCase):
    """Test 2D system extension."""

    # Small grid keeps the 2D tests fast; also used by the shape check.
    GRID_SIZE = 32

    def setUp(self):
        """Set up 2D test fixtures."""
        self.system_2d = QuantumSystem2D(grid_size=self.GRID_SIZE)

    def _total_probability(self, psi):
        """Integrate |psi|^2 over the 2D grid with nested Simpson rules.

        The inner integral runs along the last array axis using the y grid
        and the outer one along the remaining axis using the x grid.
        NOTE(review): if psi is laid out as (ny, nx) (np.meshgrid's default
        'xy' indexing) this axis/grid pairing is swapped; that is harmless
        while the x and y grids are identical, but confirm against
        QuantumSystem2D.
        """
        density = np.abs(psi)**2
        return simps(simps(density, x=self.system_2d.y), x=self.system_2d.x)

    def test_2d_potential(self):
        """Test 2D double-well potential."""
        V = self.system_2d.double_well_potential_2d()

        # One potential value per grid point
        self.assertEqual(V.shape, (self.GRID_SIZE, self.GRID_SIZE))

        # Only finiteness is checked here; minima locations are not asserted.
        self.assertTrue(np.isfinite(V).all())

    def test_2d_initial_states(self):
        """Test that every supported 2D initial state is normalized."""
        state_types = ['four_gaussian', 'central_gaussian', 'ring_state', 'separable_double']

        for state_type in state_types:
            # subTest reports each failing state separately instead of
            # stopping at the first one.
            with self.subTest(state_type=state_type):
                psi = self.system_2d.create_initial_state_2d(state_type)
                self.assertAlmostEqual(self._total_probability(psi), 1.0, places=4)

    def test_2d_evolution(self):
        """Test that 2D split-step evolution conserves total probability."""
        psi = self.system_2d.create_initial_state_2d('central_gaussian')

        V = self.system_2d.double_well_potential_2d()
        dt = 0.01

        initial_norm = self._total_probability(psi)

        for _ in range(5):
            psi = self.system_2d.split_step_evolution_2d(psi, V, dt)

        final_norm = self._total_probability(psi)

        # Check probability conservation
        self.assertAlmostEqual(initial_norm, final_norm, places=4)

    def test_2d_expectation_values(self):
        """Test 2D expectation value calculations."""
        # Create state centered away from origin
        x_offset, y_offset = 1.0, -0.5
        psi = np.exp(-((self.system_2d.X - x_offset)**2 + (self.system_2d.Y - y_offset)**2))
        psi = self.system_2d.normalize_wavefunction_2d(psi)

        expectations = self.system_2d.expectation_values_2d(psi)

        # Check that expectation values are near the offset
        self.assertAlmostEqual(expectations['x_mean'], x_offset, places=1)
        self.assertAlmostEqual(expectations['y_mean'], y_offset, places=1)


class TestIntegration(unittest.TestCase):
    """Integration tests for complete simulation workflows."""

    def test_basic_simulation_workflow(self):
        """Test basic simulation from start to finish."""
        import tempfile

        # Send simulator output to a throwaway location that is removed after
        # the test, instead of leaving a 'test_output' folder in whatever the
        # current working directory happens to be. The sub-path does not exist
        # yet, so the simulator may create it exactly as it did before.
        tmp_root = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_root.cleanup)

        config = {
            'grid_size': 64,
            'x_range': (-5, 5),
            'dt': 0.02,
            'n_steps': 50,
            'lambda_values': [0.0, 0.05],
            'correction_type': 'basic',
            'diagnostic_sampling': 10,
            'output_dir': os.path.join(tmp_root.name, 'test_output'),
            'save_raw_data': False,
        }

        simulator = QuantumClassicalSimulator(config)

        # Run single simulation
        result = simulator.run_single_simulation(0.05, save_evolution=False)

        # Check that results contain expected keys
        required_keys = ['lambda', 'diagnostics', 'final_state']
        for key in required_keys:
            self.assertIn(key, result)

        # Check that diagnostics were collected
        self.assertGreater(len(result['diagnostics']), 0)

        # Check that final state is normalized
        final_psi = result['final_state']
        physics = simulator.physics
        norm = simps(np.abs(final_psi)**2, x=physics.x)
        self.assertAlmostEqual(norm, 1.0, places=4)


class TestNumericalAccuracy(unittest.TestCase):
    """Test numerical accuracy and stability."""

    @staticmethod
    def _total_energy(physics, psi, V):
        """Return <T> + <V> for psi on the grid of ``physics``.

        <T> = -(hbar^2 / 2m) * Re ∫ psi* d²psi/dx² dx, with the second
        derivative taken as two successive np.gradient calls; <V> =
        ∫ V |psi|^2 dx. Both integrals use Simpson's rule over physics.x.
        """
        d2psi = np.gradient(np.gradient(psi, physics.dx), physics.dx)
        kinetic = -0.5 * physics.hbar**2 / physics.mass * simps(
            np.conj(psi) * d2psi, x=physics.x).real
        potential = simps(V * np.abs(psi)**2, x=physics.x)
        return kinetic + potential

    def test_energy_conservation(self):
        """Test energy conservation in quantum evolution."""
        physics = QuantumPhysics(grid_size=128, x_range=(-5, 5))

        # Create initial state
        psi = np.exp(-(physics.x + 1)**2)
        psi = physics.normalize_wavefunction(psi)

        V = physics.double_well_potential()
        energy_initial = self._total_energy(physics, psi, V)

        # Evolve without correction (should conserve energy)
        dt = 0.001
        for _ in range(100):
            psi = physics.split_step_evolution(psi, V, dt)

        energy_final = self._total_energy(physics, psi, V)

        # 1% relative tolerance. The tiny floor keeps the bound well defined
        # (no ZeroDivisionError) should the initial energy be ~0.
        tolerance = 0.01 * max(abs(energy_initial), 1e-12)
        self.assertLess(abs(energy_final - energy_initial), tolerance)

    def test_time_step_convergence(self):
        """Test convergence with decreasing time step."""
        physics = QuantumPhysics(grid_size=128, x_range=(-5, 5))
        psi_initial = np.exp(-physics.x**2)
        psi_initial = physics.normalize_wavefunction(psi_initial)

        V = physics.double_well_potential()
        final_time = 0.1
        time_steps = (0.01, 0.005, 0.0025)

        results = {}

        for dt in time_steps:
            psi = psi_initial.copy()
            # round() rather than plain int(): final_time / dt can land just
            # below an integer in floating point (e.g. 0.3 / 0.1 ==
            # 2.9999999999999996), and truncation would silently drop a step.
            n_steps = int(round(final_time / dt))

            for _ in range(n_steps):
                psi = physics.split_step_evolution(psi, V, dt)

            results[dt] = psi.copy()

        coarse, medium, fine = (results[dt] for dt in time_steps)

        # Compare results - smaller time steps should converge
        overlap_1 = abs(simps(np.conj(coarse) * medium, x=physics.x))**2
        overlap_2 = abs(simps(np.conj(medium) * fine, x=physics.x))**2

        # Overlaps should be high (convergence)
        self.assertGreater(overlap_1, 0.99)
        self.assertGreater(overlap_2, 0.995)


def run_performance_benchmarks():
    """Run performance benchmarks for the simulator.

    Prints timings for split-step evolution, a Wigner-function calculation
    and the full diagnostics pass. Returns None.
    """
    import time

    print("\nRunning Performance Benchmarks...")
    print("=" * 50)

    # Benchmark 1: Basic physics operations
    physics = QuantumPhysics(grid_size=512, x_range=(-10, 10))
    psi = np.exp(-physics.x**2)
    psi = physics.normalize_wavefunction(psi)
    V = physics.double_well_potential()  # set-up, kept out of the timed region

    # perf_counter is the monotonic high-resolution clock meant for timing
    # intervals; time.time() follows the wall clock and can jump.
    start_time = time.perf_counter()
    for _ in range(100):
        psi = physics.split_step_evolution(psi, V, 0.01)
    evolution_time = time.perf_counter() - start_time

    print(f"Split-step evolution (100 steps): {evolution_time:.3f} seconds")

    # Benchmark 2: Wigner function on a smaller 128-point window for speed.
    # Use the window centred on the grid (where the wave packet lives) rather
    # than the leftmost points, and build k for that window the same way the
    # phase-space tests do, so it matches the window's size and spacing.
    n_sub = 128
    start_idx = (len(physics.x) - n_sub) // 2
    window = slice(start_idx, start_idx + n_sub)
    x_small = physics.x[window]
    k_small = np.fft.fftfreq(n_sub, d=x_small[1] - x_small[0]) * 2 * np.pi
    phase_space = PhaseSpaceAnalysis(x_small, k_small)
    psi_small = psi[window]

    start_time = time.perf_counter()
    phase_space.compute_wigner_function(psi_small)
    wigner_time = time.perf_counter() - start_time

    print(f"Wigner function calculation: {wigner_time:.3f} seconds")

    # Benchmark 3: Full diagnostic calculation
    diagnostics = QuantumDiagnostics(physics.x, physics.dx)
    rho = physics.wavefunction_to_density_matrix(psi)

    start_time = time.perf_counter()
    diagnostics.comprehensive_diagnostics(psi, psi, 0.0, rho)
    diagnostics_time = time.perf_counter() - start_time

    print(f"Full diagnostics calculation: {diagnostics_time:.3f} seconds")
    print("=" * 50)


def main():
    """Run the complete test suite, then the performance benchmarks.

    Prints the per-test report, benchmark timings and a summary.

    Returns:
        unittest.TestResult: the result of the test run.
    """

    print("Enhanced Quantum-Classical Simulator Test Suite")
    print("=" * 60)

    # Create test suite
    test_classes = [
        TestQuantumPhysics,
        TestQuantumDiagnostics,
        TestPhaseSpaceAnalysis,
        TestCorrectionMechanisms,
        TestSystem2D,
        TestIntegration,
        TestNumericalAccuracy
    ]

    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for test_class in test_classes:
        suite.addTests(loader.loadTestsFromTestCase(test_class))

    # Suppress warnings for cleaner output. TextTestRunner installs its own
    # filter while running ('default' when warnings=None and no -W option is
    # given), which overrides a process-wide 'ignore', so request 'ignore'
    # explicitly. catch_warnings() scopes the suppression (including for the
    # benchmarks) instead of permanently changing the global filters.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')

        runner = unittest.TextTestRunner(verbosity=2, warnings='ignore')
        result = runner.run(suite)

        # Performance benchmarks
        run_performance_benchmarks()

    n_run = result.testsRun
    n_failures = len(result.failures)
    n_errors = len(result.errors)

    # Summary
    print("\nTest Summary:")
    print(f"Tests run: {n_run}")
    print(f"Failures: {n_failures}")
    print(f"Errors: {n_errors}")

    for heading, problems in (("Failures", result.failures), ("Errors", result.errors)):
        if problems:
            print(f"\n{heading}:")
            for test, details in problems:
                print(f"- {test}: {details}")

    # Guard against ZeroDivisionError when nothing ran; clamp at 0 because
    # failing subTests are counted individually in result.failures.
    if n_run:
        success_rate = max(0.0, (n_run - n_failures - n_errors) / n_run * 100)
    else:
        success_rate = 0.0
    print(f"\nSuccess rate: {success_rate:.1f}%")

    # Declare success only when every test passed: a percentage threshold
    # (previously >= 95%) reports PASSED even with a failing test.
    if n_run and result.wasSuccessful():
        print("🎉 Test suite PASSED! Simulator is ready for scientific use.")
    else:
        print("⚠️  Some tests failed. Please review and fix issues before use.")

    return result


if __name__ == "__main__":
    main()