"""
Systematic Parameter Optimization for Quantum-Classical Boundary Emergence
=========================================================================

This module implements comprehensive parameter optimization to find optimal
conditions for demonstrating clear quantum-to-classical transitions.
"""

import numpy as np
import warnings
from typing import Dict, List, Tuple, Any, Callable
from dataclasses import dataclass
from itertools import product
import json
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
import os

@dataclass
class OptimizationMetrics:
    """Scores describing how strongly a run pair shows a quantum-to-classical transition.

    Every field defaults to ``0.0``, so ``OptimizationMetrics()`` is an
    all-zero score. Positional order is: transition_strength,
    entropy_reduction, localization_increase, position_drift,
    coherence_loss, classical_score, overall_score.
    """

    # Combined strength of the transition effects listed below.
    transition_strength: float = 0.0

    # Individual effects, each a difference between the two comparison runs.
    entropy_reduction: float = 0.0
    localization_increase: float = 0.0
    position_drift: float = 0.0
    coherence_loss: float = 0.0

    # Composite scores built from the values above.
    classical_score: float = 0.0
    overall_score: float = 0.0

class ParameterOptimizer:
    """Comprehensive parameter optimization for quantum-classical boundary emergence.

    Two search strategies are provided:

    * :meth:`grid_search_optimization` evaluates a set of hand-picked
      priority combinations plus a reproducible random sample of the full
      Cartesian product of :meth:`define_parameter_space`.
    * :meth:`evolutionary_optimization` runs a simple genetic algorithm
      (keep the top half, uniform crossover, random-reset mutation).

    Each candidate configuration is scored by running two simulations that
    differ only in correction strength (``lambda_values[0]`` for the
    "quantum" run, ``lambda_values[1]`` for the "classical" run) and
    comparing their diagnostics.
    """

    # Parameters varied by the searches, in the same order as
    # define_parameter_space() returns them.
    _SEARCH_KEYS = (
        'lambda_values', 'dt', 'n_steps', 'grid_size',
        'x_range', 'correction_type', 'initial_state',
    )

    # Used when a configuration has no usable 'lambda_values' entry.
    _DEFAULT_LAMBDAS = (0.0, 0.1)

    # Metric attributes written to JSON by save_optimization_results().
    _METRIC_FIELDS = (
        'transition_strength', 'entropy_reduction', 'localization_increase',
        'position_drift', 'coherence_loss', 'classical_score', 'overall_score',
    )

    def __init__(self, base_config: Dict[str, Any] = None):
        """Initialize the parameter optimizer.

        Args:
            base_config: Settings shared by every evaluated configuration;
                searched parameters override them per candidate. When ``None``
                (or empty), :meth:`_get_default_config` is used.
        """
        self.base_config = base_config or self._get_default_config()
        self.optimization_results = []
        self.best_parameters = None
        self.best_score = 0.0

    def _get_default_config(self) -> Dict[str, Any]:
        """Get default configuration for optimization."""
        return {
            'grid_size': 256,
            'x_range': (-8, 8),
            'dt': 0.005,
            'n_steps': 300,
            'correction_type': 'basic',
            'use_lindblad': False,
            'use_ensemble': False,
            'diagnostic_sampling': 5,
            'output_dir': 'optimization_results'
        }

    def define_parameter_space(self) -> Dict[str, List]:
        """Define the parameter space to explore for optimization.

        Returns:
            Mapping from parameter name to its candidate values. Key order
            matches ``_SEARCH_KEYS`` and the tuple layout used by
            :meth:`_smart_parameter_sampling`.
        """
        return {
            # Lambda values - systematic exploration
            'lambda_values': [
                [0.0, 0.01], [0.0, 0.02], [0.0, 0.05], [0.0, 0.08],
                [0.0, 0.1], [0.0, 0.15], [0.0, 0.2], [0.0, 0.3],
                [0.0, 0.5], [0.0, 0.8], [0.0, 1.0]
            ],

            # Time step resolution
            'dt': [0.001, 0.002, 0.005, 0.01, 0.02],

            # Simulation length
            'n_steps': [200, 400, 600, 800, 1000, 1500],

            # Grid resolution
            'grid_size': [128, 256, 512],

            # Spatial range
            'x_range': [(-6, 6), (-8, 8), (-10, 10), (-12, 12)],

            # Correction types
            'correction_type': ['basic', 'adaptive', 'nonlinear'],

            # Initial state types
            'initial_state': ['asymmetric_superposition', 'coherent_superposition', 'random_superposition']
        }

    def grid_search_optimization(self, max_evaluations: int = 100) -> List[Dict]:
        """Perform a systematic (sampled) grid search over the parameter space.

        Args:
            max_evaluations: Maximum number of parameter combinations to
                evaluate. Priority combinations come first, followed by a
                reproducible random sample of the remaining space.

        Returns:
            One dict per completed evaluation with keys ``'parameters'``,
            ``'metrics'``, ``'config'`` and ``'evaluation_id'``. The list is
            also stored in ``self.optimization_results``, and
            ``self.best_score`` / ``self.best_parameters`` are updated.
        """
        print("🔍 SYSTEMATIC PARAMETER SPACE OPTIMIZATION")
        print("=" * 60)

        param_space = self.define_parameter_space()
        param_names = list(param_space.keys())

        # Create smart sampling of parameter space
        combinations = self._smart_parameter_sampling(param_space, max_evaluations)

        print(f"📊 Evaluating {len(combinations)} parameter combinations")
        print("🎯 Optimizing for maximum quantum→classical transition strength")

        results = []

        for i, combo in enumerate(combinations):
            print(f"\n🔬 Evaluation {i+1}/{len(combinations)}")

            parameters = dict(zip(param_names, combo))
            # Searched parameters override the shared base configuration.
            config = {**self.base_config, **parameters}

            try:
                metrics = self._evaluate_parameter_set(config)
            except Exception as e:
                print(f"   ❌ Evaluation failed: {e}")
                continue

            results.append({
                'parameters': parameters,
                'metrics': metrics,
                'config': config,
                'evaluation_id': i
            })

            # Track best result
            if metrics.overall_score > self.best_score:
                self.best_score = metrics.overall_score
                self.best_parameters = config.copy()

            print(f"   Overall Score: {metrics.overall_score:.3f}")
            print(f"   Transition Strength: {metrics.transition_strength:.3f}")
            print(f"   Classical Score: {metrics.classical_score:.3f}")

        self.optimization_results = results
        return results

    def _smart_parameter_sampling(self, param_space: Dict, max_samples: int) -> List[Tuple]:
        """Choose at most ``max_samples`` parameter tuples to evaluate.

        Hand-picked "priority" combinations come first (truncated if the
        budget is smaller). Any remaining budget is filled with a random
        sample of the rest of the Cartesian product, drawn without
        replacement after seeding NumPy's global RNG with 42.

        Tuples follow the key order of ``param_space``. The priority
        tuples assume the layout of :meth:`define_parameter_space`.
        """
        # High lambda values with various time resolutions
        priority_combinations = [
            (lambda_vals, dt, n_steps, 256, (-8, 8), 'adaptive', 'asymmetric_superposition')
            for lambda_vals in ([0.0, 0.2], [0.0, 0.5], [0.0, 1.0])
            for dt in (0.001, 0.005, 0.01)
            for n_steps in (600, 1000)
        ]

        max_samples = max(0, int(max_samples))
        # A budget no larger than the priority list must still be honoured
        # (previously a negative sample size made np.random.choice raise).
        if max_samples <= len(priority_combinations):
            return priority_combinations[:max_samples]

        n_remaining = max_samples - len(priority_combinations)

        # Remove priority combinations from full list to avoid duplicates
        remaining = [
            combo for combo in product(*param_space.values())
            if combo not in priority_combinations
        ]

        if len(remaining) > n_remaining:
            np.random.seed(42)  # Reproducible sampling
            indices = np.random.choice(len(remaining), n_remaining, replace=False)
            remaining = [remaining[i] for i in indices]

        return priority_combinations + remaining

    @classmethod
    def _resolve_lambda_values(cls, lambda_values) -> List[float]:
        """Return ``[quantum_lambda, classical_lambda]`` from a config entry.

        Accepts any sequence with at least two numeric entries (list, tuple,
        1-D array) and uses its first two. Anything else falls back to
        ``_DEFAULT_LAMBDAS``.
        """
        try:
            values = [float(v) for v in lambda_values]
        except (TypeError, ValueError):
            return list(cls._DEFAULT_LAMBDAS)
        if len(values) < 2:
            return list(cls._DEFAULT_LAMBDAS)
        return values[:2]

    def _evaluate_parameter_set(self, config: Dict[str, Any]) -> 'OptimizationMetrics':
        """Evaluate a single parameter set for quantum-classical transition strength.

        Two simulations are run from the same initial state: one with
        ``lambda_values[0]`` and one with ``lambda_values[1]``. Their
        diagnostics are then compared.

        Returns:
            The computed metrics. If anything fails after the imports, a
            warning is issued and an all-zero ``OptimizationMetrics()`` is
            returned.

        Raises:
            ImportError: If the project modules ``physics`` / ``correction``
                cannot be imported (these imports are outside the guard).
        """
        # Import here to avoid circular dependencies
        from physics import QuantumPhysics
        from correction import CorrectionMechanism

        try:
            physics = QuantumPhysics(
                grid_size=config['grid_size'],
                x_range=config['x_range']
            )
            correction = CorrectionMechanism(physics.x, config['dt'])

            psi_initial = self._create_initial_state(
                physics, config.get('initial_state', 'asymmetric_superposition')
            )

            quantum_lambda, classical_lambda = self._resolve_lambda_values(
                config.get('lambda_values')
            )

            quantum_result = self._run_evaluation_simulation(
                physics, correction, psi_initial.copy(), quantum_lambda, config
            )
            classical_result = self._run_evaluation_simulation(
                physics, correction, psi_initial.copy(), classical_lambda, config
            )

            return self._calculate_transition_metrics(quantum_result, classical_result)

        except Exception as e:
            warnings.warn(f"Parameter evaluation failed: {e}")
            return OptimizationMetrics()  # Return zeros

    def _create_initial_state(self, physics, state_type: str):
        """Create an initial wavefunction on ``physics.x``.

        Args:
            physics: Object exposing ``x`` and ``normalize_wavefunction``.
            state_type: ``'asymmetric_superposition'``,
                ``'coherent_superposition'`` or ``'random_superposition'``.
                Any other value yields a centred Gaussian.

        Returns:
            The state passed through ``physics.normalize_wavefunction``.
        """
        x = physics.x

        if state_type == 'asymmetric_superposition':
            # Strong asymmetric superposition
            psi_left = 0.8 * np.exp(-(x + 3)**2 / 2)
            psi_right = 0.4 * np.exp(-(x - 3)**2 / 2)
            psi = psi_left + 1j * psi_right

        elif state_type == 'coherent_superposition':
            # Coherent superposition with momentum
            psi = np.exp(-(x + 2)**2) + np.exp(-(x - 2)**2)
            psi *= np.exp(1j * 0.5 * x)  # Add momentum

        elif state_type == 'random_superposition':
            # Random multi-component superposition. A private RNG seeded with
            # 42 gives the same draws as seeding the global RNG did, without
            # resetting the global stream used by the evolutionary search.
            rng = np.random.RandomState(42)
            psi = np.zeros_like(x, dtype=complex)
            for _ in range(4):
                center = rng.uniform(-4, 4)
                phase = rng.uniform(0, 2*np.pi)
                amplitude = rng.uniform(0.3, 1.0)
                psi += amplitude * np.exp(-(x - center)**2) * np.exp(1j * phase)

        else:
            # Default Gaussian
            psi = np.exp(-x**2)

        return physics.normalize_wavefunction(psi)

    @staticmethod
    def _trapezoid(y, x):
        """Trapezoidal integral of ``y`` over ``x``.

        Uses ``np.trapezoid`` (NumPy >= 2.0) when available and falls back
        to the deprecated ``np.trapz`` otherwise.
        """
        integrate = getattr(np, 'trapezoid', None) or np.trapz
        return integrate(y, x)

    @staticmethod
    def _apply_correction(physics, correction, psi, lambda_val, correction_type, dt):
        """Apply one correction step to ``psi``.

        Returns:
            ``(psi, error)``. ``psi`` is the corrected state, or the input
            unchanged when the correction is unusable (``None`` or non-finite
            values). ``error`` is the exception raised by the correction code,
            or ``None``.
        """
        x_classical = 0.0  # Reference point
        try:
            if correction_type == 'adaptive':
                Q = physics.quantum_potential(psi)
                V_corr, _ = correction.adaptive_quantum_potential_correction(
                    psi, x_classical, Q, lambda_val
                )
            elif correction_type == 'nonlinear':
                V_corr, _ = correction.nonlinear_correction(psi, x_classical, lambda_val)
            else:
                # 'basic' and any unrecognised type
                V_corr, _ = correction.basic_correction(psi, x_classical, lambda_val)

            if V_corr is None or not np.all(np.isfinite(V_corr)):
                return psi, None
            correction_factor = np.exp(-1j * V_corr * dt)
            if not np.all(np.isfinite(correction_factor)):
                return psi, None
            # Out-of-place multiply so a real-valued psi is promoted to
            # complex instead of raising a casting error.
            return psi * correction_factor, None
        except Exception as e:
            return psi, e

    def _record_diagnostics(self, physics, psi, current_time, evolution_data):
        """Append position, uncertainty, entropy and |Q| statistics for ``psi``."""
        x = physics.x
        prob_density = np.abs(psi)**2

        # Position statistics
        position = self._trapezoid(x * prob_density, x)
        position_sq = self._trapezoid(x**2 * prob_density, x)
        # Rounding can push the variance slightly below zero; clamp to avoid NaN.
        uncertainty = np.sqrt(max(position_sq - position**2, 0.0))

        # Shannon entropy of the grid-normalized probability distribution
        prob_norm = prob_density / np.sum(prob_density)
        entropy = -np.sum(prob_norm * np.log(prob_norm + 1e-16))

        # Quantum potential
        Q = physics.quantum_potential(psi)
        Q_mean = np.mean(np.abs(Q))

        evolution_data['positions'].append(position)
        evolution_data['uncertainties'].append(uncertainty)
        evolution_data['entropies'].append(entropy)
        evolution_data['quantum_potentials'].append(Q_mean)
        evolution_data['times'].append(current_time)

    def _run_evaluation_simulation(self, physics, correction, psi_initial, lambda_val, config):
        """Run a single evaluation simulation.

        Each step applies split-step evolution in the double-well potential,
        then the configured correction (only when ``lambda_val > 0``), then
        renormalization. Diagnostics are recorded every
        ``config['diagnostic_sampling']`` steps (at least every step).

        Returns:
            Dict of lists keyed ``'positions'``, ``'uncertainties'``,
            ``'entropies'``, ``'quantum_potentials'`` and ``'times'``.
        """
        psi = psi_initial.copy()
        dt = config['dt']
        correction_type = config.get('correction_type', 'basic')
        # Guard against 0 / missing, which would break the modulo below.
        sample_every = max(1, config.get('diagnostic_sampling', 1))

        evolution_data = {
            'positions': [],
            'uncertainties': [],
            'entropies': [],
            'quantum_potentials': [],
            'times': []
        }

        n_failed_corrections = 0
        first_error = None

        for step in range(config['n_steps']):
            current_time = step * dt

            # Quantum evolution
            potential = physics.double_well_potential()
            psi = physics.split_step_evolution(psi, potential, dt)

            # Correction is best-effort: a failing step is skipped, but failures
            # are counted and reported once below rather than silently dropped.
            if lambda_val > 0:
                psi, error = self._apply_correction(
                    physics, correction, psi, lambda_val, correction_type, dt
                )
                if error is not None:
                    n_failed_corrections += 1
                    if first_error is None:
                        first_error = error

            psi = physics.normalize_wavefunction(psi)

            # Collect diagnostics
            if step % sample_every == 0:
                self._record_diagnostics(physics, psi, current_time, evolution_data)

        if n_failed_corrections:
            warnings.warn(
                f"'{correction_type}' correction failed on {n_failed_corrections} "
                f"step(s) and was skipped; first error: {first_error}"
            )

        return evolution_data

    def _calculate_transition_metrics(self, quantum_result, classical_result) -> 'OptimizationMetrics':
        """Quantify the quantum-to-classical transition between two runs.

        Each per-effect metric is ``max(0, classical_effect - quantum_effect)``.
        ``transition_strength`` is their mean. ``classical_score`` and
        ``overall_score`` are weighted combinations (see inline comments).
        On any error a warning is issued and the metrics computed so far are
        returned.
        """
        metrics = OptimizationMetrics()

        def net_displacement(values):
            # |last - first|; 0 for an empty series.
            return abs(values[-1] - values[0]) if len(values) > 0 else 0

        def decrease(values):
            # first - last; 0 when there are fewer than two samples.
            return values[0] - values[-1] if len(values) > 1 else 0

        try:
            q_pos = np.asarray(quantum_result['positions'])
            c_pos = np.asarray(classical_result['positions'])
            q_unc = np.asarray(quantum_result['uncertainties'])
            c_unc = np.asarray(classical_result['uncertainties'])
            q_ent = np.asarray(quantum_result['entropies'])
            c_ent = np.asarray(classical_result['entropies'])
            q_qpot = np.asarray(quantum_result['quantum_potentials'])
            c_qpot = np.asarray(classical_result['quantum_potentials'])

            # 1. Position drift difference (classical should show more directed motion)
            metrics.position_drift = max(0, net_displacement(c_pos) - net_displacement(q_pos))

            # 2. Localization increase (classical should become more localized)
            metrics.localization_increase = max(0, decrease(c_unc) - decrease(q_unc))

            # 3. Entropy reduction (classical should have lower final entropy)
            metrics.entropy_reduction = max(0, decrease(c_ent) - decrease(q_ent))

            # 4. Coherence loss: mean |Q| over the last (up to) 10 samples.
            #    Skipped when either series is empty (mean would be NaN).
            if q_qpot.size and c_qpot.size:
                metrics.coherence_loss = max(
                    0, np.mean(q_qpot[-10:]) - np.mean(c_qpot[-10:])
                )

            # 5. Overall transition strength
            metrics.transition_strength = (
                metrics.position_drift +
                metrics.localization_increase +
                metrics.entropy_reduction +
                metrics.coherence_loss
            ) / 4.0

            # 6. Classical behavior score for the classical result
            if len(c_pos) > 10:
                late_pos_std = np.std(c_pos[-10:])
                position_stability = max(0, 1.0 - late_pos_std)
            else:
                position_stability = 0.0

            final_localization = 1.0 / (1.0 + c_unc[-1]) if len(c_unc) > 0 else 0.0

            metrics.classical_score = (
                position_stability +
                final_localization +
                metrics.localization_increase +
                metrics.entropy_reduction
            ) / 4.0

            # 7. Overall optimization score (weights sum to 5.5)
            metrics.overall_score = (
                2.0 * metrics.transition_strength +
                1.5 * metrics.classical_score +
                1.0 * metrics.position_drift +
                1.0 * metrics.localization_increase
            ) / 5.5

        except Exception as e:
            warnings.warn(f"Metrics calculation failed: {e}")

        return metrics

    def evolutionary_optimization(self, generations: int = 10, population_size: int = 20) -> Dict:
        """Perform evolutionary optimization to find optimal parameters.

        Args:
            generations: Number of evaluate/evolve rounds.
            population_size: Number of individuals per generation.

        Returns:
            Dict with ``'best_parameters'`` (a copy of the best configuration,
            or ``None`` if nothing scored above 0), ``'best_fitness'`` and
            ``'final_population'``.
        """
        print("\n🧬 EVOLUTIONARY PARAMETER OPTIMIZATION")
        print("=" * 60)

        # Initialize population with random parameter sets (seeds NumPy's global RNG)
        population = self._initialize_population(population_size)

        best_individual = None
        best_fitness = 0.0

        for generation in range(generations):
            print(f"\n🧪 Generation {generation + 1}/{generations}")

            # Evaluate population
            fitness_scores = []
            for i, individual in enumerate(population):
                try:
                    fitness = self._evaluate_parameter_set(individual).overall_score
                except Exception as e:
                    warnings.warn(f"Evaluation of individual {i} failed: {e}")
                    fitness = 0.0
                fitness_scores.append(fitness)

                if fitness > best_fitness:
                    best_fitness = fitness
                    best_individual = individual.copy()

            avg_fitness = float(np.mean(fitness_scores)) if fitness_scores else 0.0
            print(f"   Average fitness: {avg_fitness:.3f}")
            print(f"   Best fitness: {best_fitness:.3f}")

            # Selection and reproduction
            if generation < generations - 1:
                population = self._evolve_population(population, fitness_scores)

        return {
            'best_parameters': best_individual,
            'best_fitness': best_fitness,
            'final_population': population
        }

    @staticmethod
    def _sample_choice(options: List) -> Any:
        """Pick one element of ``options`` uniformly via NumPy's global RNG.

        ``np.random.choice`` is not used because it rejects lists of
        lists/tuples (it needs a 1-D array). It also converts scalars to
        NumPy types, such as ``np.int64``, that ``json`` cannot serialize.
        """
        return options[np.random.randint(len(options))]

    def _initialize_population(self, size: int) -> List[Dict]:
        """Initialize a random population for evolutionary optimization.

        Seeds NumPy's global RNG with 42, so the initial population and
        the evolution that follows are reproducible. Each individual is a
        copy of ``base_config`` with every searched parameter drawn
        uniformly from :meth:`define_parameter_space`.
        """
        param_space = self.define_parameter_space()
        np.random.seed(42)  # Reproducible

        population = []
        for _ in range(size):
            individual = self.base_config.copy()
            for key in self._SEARCH_KEYS:
                individual[key] = self._sample_choice(param_space[key])
            population.append(individual)

        return population

    def _evolve_population(self, population: List[Dict], fitness_scores: List[float]) -> List[Dict]:
        """Evolve population through selection, crossover, and mutation.

        The top half by fitness (at least one individual) is kept unchanged.
        The population is then refilled to its original size with mutated
        crossovers of randomly chosen elite parents.
        """
        sorted_indices = np.argsort(fitness_scores)[::-1]
        # At least one elite so tiny populations can still reproduce.
        elite_size = max(1, len(population) // 2)
        elite = [population[i] for i in sorted_indices[:elite_size]]

        new_population = list(elite)
        param_space = self.define_parameter_space()

        while len(new_population) < len(population):
            parent1 = self._sample_choice(elite)
            parent2 = self._sample_choice(elite)
            child = self._crossover(parent1, parent2)
            child = self._mutate(child, param_space)
            new_population.append(child)

        return new_population

    def _crossover(self, parent1: Dict, parent2: Dict) -> Dict:
        """Create offspring through uniform parameter crossover.

        The child starts as a shallow copy of ``parent1``. Each searched
        parameter present in ``parent2`` is taken from it with
        probability 0.5.
        """
        child = parent1.copy()
        for key in self._SEARCH_KEYS:
            if key in parent2 and np.random.random() < 0.5:
                child[key] = parent2[key]
        return child

    def _mutate(self, individual: Dict, param_space: Dict) -> Dict:
        """Apply random-reset mutations to ``individual`` in place and return it.

        Each searched parameter present in ``param_space`` is resampled with
        probability 0.1.
        """
        mutation_rate = 0.1

        for key in self._SEARCH_KEYS:
            if np.random.random() < mutation_rate and key in param_space:
                individual[key] = self._sample_choice(param_space[key])

        return individual

    @staticmethod
    def _json_default(obj):
        """``json`` fallback converting NumPy scalars/arrays to Python objects."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def save_optimization_results(self, filename: str = None):
        """Save grid-search results, best parameters and base config as JSON.

        Args:
            filename: File name inside ``base_config['output_dir']``. Defaults
                to a timestamped ``optimization_results_YYYYmmdd_HHMMSS.json``.

        Returns:
            Path of the written file.
        """
        if filename is None:
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"optimization_results_{timestamp}.json"

        os.makedirs(self.base_config['output_dir'], exist_ok=True)
        filepath = os.path.join(self.base_config['output_dir'], filename)

        # Prepare data for JSON serialization
        serializable_results = [
            {
                'parameters': result['parameters'],
                'metrics': {
                    name: getattr(result['metrics'], name)
                    for name in self._METRIC_FIELDS
                },
                'evaluation_id': result['evaluation_id']
            }
            for result in self.optimization_results
        ]

        data = {
            'optimization_results': serializable_results,
            'best_parameters': self.best_parameters,
            'best_score': self.best_score,
            'base_config': self.base_config
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, default=self._json_default)

        print(f"✅ Optimization results saved to: {filepath}")
        return filepath

    def get_top_parameters(self, n: int = 5) -> List[Dict]:
        """Return the ``n`` grid-search results with the highest overall score."""
        if not self.optimization_results:
            return []

        return sorted(
            self.optimization_results,
            key=lambda result: result['metrics'].overall_score,
            reverse=True
        )[:n]

    def analyze_parameter_sensitivity(self) -> Dict:
        """Summarize overall scores grouped by the value of key parameters.

        Returns:
            ``{param_name: {str(value): {'mean_score', 'std_score', 'count',
            'max_score'}}}`` for ``lambda_values``, ``dt``, ``n_steps`` and
            ``correction_type``, or ``{}`` when there are no results.
        """
        if not self.optimization_results:
            return {}

        sensitivity_analysis = {}
        for param_name in ('lambda_values', 'dt', 'n_steps', 'correction_type'):
            scores_by_value: Dict[str, List[float]] = {}
            for result in self.optimization_results:
                value = str(result['parameters'].get(param_name, 'unknown'))
                scores_by_value.setdefault(value, []).append(result['metrics'].overall_score)

            sensitivity_analysis[param_name] = {
                value: {
                    'mean_score': np.mean(scores),
                    'std_score': np.std(scores),
                    'count': len(scores),
                    'max_score': np.max(scores)
                }
                for value, scores in scores_by_value.items()
            }

        return sensitivity_analysis