"""Diagnostics module for quantum-classical boundary emergence simulation.

This module provides comprehensive diagnostic tools including:
- Shannon and Von Neumann entropy calculations
- Fidelity tracking (initial state and classical proximity)
- Mutual information and entanglement measures
- Correction efficiency metrics
- Quantum coherence measures
"""

import numpy as np
from scipy.integrate import simpson as simps
from scipy.linalg import logm, eigvals
import warnings

class QuantumDiagnostics:
    """Diagnostic calculations for quantum states sampled on a 1-D spatial grid.

    Wavefunction methods take ``psi`` sampled at the points of ``x_grid`` and
    integrate with Simpson's rule over that grid.  Density-matrix methods take
    a matrix ``rho``.  Momentum-related quantities use units where ℏ = 1.

    Parameters
    ----------
    x_grid : array_like
        Spatial sample points.  Derivatives are taken with ``np.gradient``
        using the single spacing ``dx``, so the grid is assumed uniform.
    dx : float
        Grid spacing, used for finite-difference derivatives and for
        converting separation distances into grid steps.
    """

    # Norms and eigenvalues at or below this magnitude are treated as zero.
    _EPS = 1e-12

    def __init__(self, x_grid, dx):
        """Store the spatial grid, its spacing, and its number of points."""
        self.x = x_grid
        self.dx = dx
        self.grid_size = len(x_grid)

    # -- internal helpers ---------------------------------------------------

    def _integrate(self, values):
        """Integrate sampled ``values`` over ``self.x`` with Simpson's rule."""
        return simps(values, x=self.x)

    def _norm_squared(self, psi):
        """Return the squared norm ⟨ψ|ψ⟩ = ∫|ψ|² dx as a real float."""
        return float(np.real(self._integrate(np.abs(psi) ** 2)))

    # -- entropies ----------------------------------------------------------

    def shannon_entropy(self, psi):
        """
        Calculate Shannon entropy H = -Σ pᵢ log₂(pᵢ) of the grid occupation.

        ``pᵢ = |ψᵢ|² / Σⱼ|ψⱼ|²`` is the discrete probability of grid point i,
        so the result is in bits and depends on the grid resolution.  Measures
        information content and localization of the probability distribution.

        Returns 0.0 for an all-zero wavefunction.
        """
        prob = np.abs(np.asarray(psi)) ** 2
        total = np.sum(prob)
        if total <= 0:
            return 0.0
        prob = prob / total  # out-of-place so integer input also works

        # Terms with p = 0 contribute exactly 0 (lim p→0 of p log p); drop
        # them instead of adding a regulariser that biases the result.
        occupied = prob[prob > 0]
        return float(-np.sum(occupied * np.log2(occupied)))

    def von_neumann_entropy(self, rho):
        """
        Calculate Von Neumann entropy S_VN = -Tr(ρ log₂ ρ) in bits.

        Quantifies the mixedness of ρ: S_VN = 0 for a pure state and log₂(d)
        for the maximally mixed state of dimension d.  For the reduced state
        of a pure bipartite system it equals the entanglement entropy.

        ρ is treated as Hermitian (as a density matrix is): the eigenvalues of
        its Hermitian part are computed with ``np.linalg.eigvalsh``, so they
        are real by construction.  Eigenvalues ≤ 1e-12 are discarded and the
        rest renormalised to sum to 1, so an unnormalised block of a density
        matrix is handled as the corresponding normalised state.  Returns 0.0
        for an empty matrix or one with no eigenvalue above the threshold.
        """
        rho = np.asarray(rho)
        if rho.size == 0:
            return 0.0

        # Symmetrise to absorb round-off asymmetry before the Hermitian solver.
        hermitian_part = 0.5 * (rho + rho.conj().T)
        eigenvalues = np.linalg.eigvalsh(hermitian_part)

        eigenvalues = eigenvalues[eigenvalues > self._EPS]
        if eigenvalues.size == 0:
            return 0.0
        eigenvalues = eigenvalues / np.sum(eigenvalues)

        return float(-np.sum(eigenvalues * np.log2(eigenvalues)))

    def mutual_information(self, rho_full):
        """
        Calculate mutual information between left and right halves of system.

        I(A:B) = S(ρ_A) + S(ρ_B) - S(ρ_AB), clipped at 0.

        ρ_A and ρ_B are the diagonal blocks of ``rho_full`` on either side of
        its midpoint index; each is renormalised inside
        ``von_neumann_entropy``.

        NOTE(review): for a single particle on a grid these blocks are the
        state projected onto each spatial region, not partial traces over a
        tensor-product factorisation, so this is a region-based correlation
        measure rather than the textbook bipartite mutual information.
        """
        rho_full = np.asarray(rho_full)
        # Split on the matrix's own dimension (equals grid_size for a
        # density matrix expressed on the full grid).
        mid_point = rho_full.shape[0] // 2

        s_left = self.von_neumann_entropy(rho_full[:mid_point, :mid_point])
        s_right = self.von_neumann_entropy(rho_full[mid_point:, mid_point:])
        s_total = self.von_neumann_entropy(rho_full)

        return max(0.0, s_left + s_right - s_total)

    # -- fidelities ---------------------------------------------------------

    def fidelity_initial_state(self, psi_current, psi_initial):
        """
        Calculate fidelity with initial state: F_init = |⟨Ψ(0)|Ψ(t)⟩|².

        The overlap integral uses Simpson's rule over the grid and the states
        are NOT renormalised, so both are expected to satisfy ∫|Ψ|² dx = 1;
        any loss of norm lowers the result.
        F = 1: perfect memory, F = 0: orthogonal to the initial state.
        """
        overlap = self._integrate(np.conj(psi_initial) * psi_current)
        return np.abs(overlap) ** 2

    def fidelity_classical_proximity(self, psi, x_classical, sigma_classical=1.0):
        """
        Calculate proximity to a classical particle at ``x_classical``:

            F_class = exp(-(⟨x⟩ - x_class)² / (2σ²))

        F_class = 1 when ⟨x⟩ coincides with the classical position and decays
        as a Gaussian of width ``sigma_classical`` with the distance.
        """
        offset_sq = (self.expectation_position(psi) - x_classical) ** 2
        return np.exp(-offset_sq / (2 * sigma_classical ** 2))

    # -- position / momentum moments ----------------------------------------

    def expectation_position(self, psi):
        """
        Calculate position expectation value ⟨x⟩ = ∫x|ψ|² dx / ∫|ψ|² dx.

        Dividing by the norm makes the result independent of how ψ is scaled.
        Returns 0.0 for a (near-)zero state.
        """
        norm_sq = self._norm_squared(psi)
        if norm_sq <= self._EPS:
            return 0.0
        x = np.asarray(self.x)
        return self._integrate(x * np.abs(psi) ** 2) / norm_sq

    def expectation_position_squared(self, psi):
        """
        Calculate ⟨x²⟩ = ∫x²|ψ|² dx / ∫|ψ|² dx.

        Returns 0.0 for a (near-)zero state.
        """
        norm_sq = self._norm_squared(psi)
        if norm_sq <= self._EPS:
            return 0.0
        x = np.asarray(self.x)
        return self._integrate(x ** 2 * np.abs(psi) ** 2) / norm_sq

    def position_uncertainty(self, psi):
        """Calculate position uncertainty Δx = √(⟨x²⟩ - ⟨x⟩²), clipped at 0."""
        x_mean = self.expectation_position(psi)
        x2_mean = self.expectation_position_squared(psi)
        # Round-off can make a tiny variance slightly negative.
        return np.sqrt(max(0.0, x2_mean - x_mean ** 2))

    def momentum_uncertainty(self, psi):
        """
        Calculate momentum uncertainty Δp = √(⟨p²⟩ - ⟨p⟩²) with ℏ = 1.

        With p = -i d/dx and ψ' from ``np.gradient`` (central differences):
            ⟨p⟩  = Im ∫ ψ* ψ' dx / ∫|ψ|² dx
            ⟨p²⟩ = ∫ |ψ'|² dx / ∫|ψ|² dx
        The ⟨p²⟩ form follows from integration by parts and assumes ψ
        vanishes at the grid edges.  Returns 0.0 for a (near-)zero state.
        """
        norm_sq = self._norm_squared(psi)
        if norm_sq <= self._EPS:
            return 0.0

        dpsi_dx = np.gradient(psi, self.dx)
        p_mean = np.imag(self._integrate(np.conj(psi) * dpsi_dx)) / norm_sq
        p2_mean = np.real(self._integrate(np.abs(dpsi_dx) ** 2)) / norm_sq

        # Subtracting ⟨p⟩² is essential: without it a moving wave packet
        # would report its mean momentum as "uncertainty".
        variance = p2_mean - p_mean ** 2
        return np.sqrt(max(0.0, variance))

    def heisenberg_uncertainty_product(self, psi):
        """
        Calculate Heisenberg uncertainty product Δx·Δp (ℏ = 1).

        Every state satisfies Δx·Δp ≥ 1/2 (i.e. ℏ/2); Gaussian wave packets
        attain the bound, up to discretisation error.
        """
        return self.position_uncertainty(psi) * self.momentum_uncertainty(psi)

    # -- correction mechanism -----------------------------------------------

    def correction_efficiency(self, correction_applied, deviation_before, deviation_after):
        """
        Calculate correction efficiency: η = (deviation_before - deviation_after) / |correction_applied|.

        Measures how effectively the correction mechanism reduces quantum
        deviations; high efficiency means small corrections produce large
        stabilization.  Returns 0.0 when |correction_applied| < 1e-12 to
        avoid dividing by (near-)zero.
        """
        if abs(correction_applied) < self._EPS:
            return 0.0
        return (deviation_before - deviation_after) / abs(correction_applied)

    # -- coherence and mixedness --------------------------------------------

    def quantum_coherence_l1_norm(self, rho):
        """
        Calculate l1-norm quantum coherence: C_l1 = Σ_{i≠j} |ρ_{ij}|.

        Diagonal (incoherent) matrices give 0; superpositions give larger
        values.
        """
        magnitudes = np.abs(np.asarray(rho))
        # Boolean mask selecting every entry except the main diagonal
        # (also valid for non-square input, like the index definition).
        off_diagonal = ~np.eye(magnitudes.shape[0], magnitudes.shape[1], dtype=bool)
        return float(np.sum(magnitudes[off_diagonal]))

    def quantum_coherence_relative_entropy(self, rho):
        """
        Calculate relative entropy coherence: C_r = S(ρ_diag) - S(ρ).

        ρ_diag keeps only the diagonal of ρ (the fully dephased state);
        both entropies come from ``von_neumann_entropy``.
        """
        rho = np.asarray(rho)
        rho_dephased = np.diag(np.diag(rho))
        return self.von_neumann_entropy(rho_dephased) - self.von_neumann_entropy(rho)

    def wigner_negativity(self, wigner_function):
        """
        Calculate Wigner function negativity: N = Σ_{W<0} |W|.

        The sum is over raw samples (no phase-space area element).  Negative
        values indicate non-classical behaviour; returns 0.0 when no sample
        is negative.
        """
        samples = np.asarray(wigner_function)
        return float(np.sum(np.abs(samples[samples < 0])))

    def participation_ratio(self, psi):
        """
        Calculate participation ratio: PR = 1 / Σᵢ pᵢ² with pᵢ = |ψᵢ|² / Σⱼ|ψⱼ|².

        The effective number of grid points the wavefunction occupies: 1 for
        a state on a single point, N for a uniform state on N points.  The
        distribution is normalised first, so the result does not depend on
        how ψ is scaled.  Returns 0.0 for an all-zero state.
        """
        prob = np.abs(np.asarray(psi)) ** 2
        total = np.sum(prob)
        if total <= 0:
            return 0.0
        prob = prob / total
        return float(1.0 / np.sum(prob ** 2))

    def purity(self, rho):
        """
        Calculate purity: P = Tr(ρ²).

        P = 1 for pure states, P < 1 for mixed states.
        """
        rho = np.asarray(rho)
        # Tr(ρ²) = Σ_ij ρ_ij ρ_ji: O(n²) instead of forming the O(n³) ρ @ ρ.
        return float(np.real(np.einsum('ij,ji->', rho, rho)))

    def linear_entropy(self, rho):
        """
        Calculate linear entropy: S_L = 1 - Tr(ρ²) = 1 - Purity.

        A cheap proxy for the Von Neumann entropy: S_L = 0 for pure states,
        S_L = 1 - 1/d for the maximally mixed state of dimension d.
        """
        return 1.0 - self.purity(rho)

    def quantum_fisher_information(self, psi, parameter_gradient):
        """
        Calculate the pure-state quantum Fisher information for parameter θ.

        F_Q = 4(⟨∂ψ/∂θ|∂ψ/∂θ⟩ - |⟨ψ|∂ψ/∂θ⟩|²)

        ``parameter_gradient`` is ∂ψ/∂θ sampled on the grid.  ψ is
        normalised and the gradient is scaled by the SAME factor; the
        gradient's own magnitude carries the sensitivity and must not be
        normalised away.  Because F_Q is invariant under ∂ψ → ∂ψ + cψ, this
        gives the QFI of the normalised state even when the norm depends on θ.
        """
        psi_norm = np.sqrt(self._norm_squared(psi))
        if psi_norm > self._EPS:
            psi = psi / psi_norm
            gradient = parameter_gradient / psi_norm
        else:
            gradient = parameter_gradient

        grad_grad = np.real(self._integrate(np.abs(gradient) ** 2))
        psi_grad = self._integrate(np.conj(psi) * gradient)

        fisher_info = 4 * (grad_grad - np.abs(psi_grad) ** 2)
        return max(0.0, fisher_info)

    def entanglement_entropy_half_system(self, rho_full):
        """
        Calculate Von Neumann entropy of the left half of ``rho_full``.

        Uses the upper-left diagonal block (indices below the midpoint),
        renormalised inside ``von_neumann_entropy``.
        NOTE(review): this block is a spatial projection, not a partial trace,
        so it equals the entanglement entropy only when the grid indices
        correspond to a genuine bipartition.
        """
        rho_full = np.asarray(rho_full)
        mid_point = rho_full.shape[0] // 2
        return self.von_neumann_entropy(rho_full[:mid_point, :mid_point])

    # -- spatial correlations -----------------------------------------------

    def classical_correlation_function(self, psi, separation_distance=1):
        """
        Calculate spatial correlation function: C(Δx) = ⟨ρ(x)ρ(x+Δx)⟩ / ⟨ρ(x)⟩².

        ρ = |ψ|²; the numerator averages over all point pairs that lie on the
        grid, the denominator uses the mean density over the whole grid.
        Δx is converted to the nearest whole number of grid steps.  Returns
        0.0 for an empty or (near-)zero state, or when Δx spans the grid.
        """
        prob = np.abs(np.asarray(psi)) ** 2
        n_points = len(prob)
        if n_points == 0:
            return 0.0

        mean_density = np.mean(prob)
        if mean_density < self._EPS:
            return 0.0

        # Round rather than truncate: 0.3 / 0.1 evaluates to 2.999..., which
        # int() would turn into 2 steps.  C(Δx) pairs the same points for ±Δx,
        # so a negative separation uses its magnitude.
        steps = abs(int(round(separation_distance / self.dx)))
        if steps >= n_points:
            return 0.0

        pair_products = prob[:n_points - steps] * prob[steps:]
        return float(np.mean(pair_products) / mean_density ** 2)

    # -- aggregate ----------------------------------------------------------

    def comprehensive_diagnostics(self, psi, psi_initial=None, x_classical=None,
                                rho=None, wigner=None, correction_applied=None,
                                deviation_before=None, deviation_after=None):
        """
        Calculate all available diagnostics for a quantum state.

        Wavefunction diagnostics are always computed from ``psi``; the
        fidelity, density-matrix, Wigner and correction-efficiency entries are
        added only when their inputs are provided (correction efficiency
        needs all three correction arguments).

        Returns a dictionary mapping diagnostic names to values.
        """
        diagnostics = {}

        # Basic entropy measures
        diagnostics['shannon_entropy'] = self.shannon_entropy(psi)

        # Position and momentum properties
        delta_x = self.position_uncertainty(psi)
        delta_p = self.momentum_uncertainty(psi)
        diagnostics['position_expectation'] = self.expectation_position(psi)
        diagnostics['position_uncertainty'] = delta_x
        diagnostics['momentum_uncertainty'] = delta_p
        # Same value as heisenberg_uncertainty_product, without recomputing.
        diagnostics['uncertainty_product'] = delta_x * delta_p

        # State complexity measures
        diagnostics['participation_ratio'] = self.participation_ratio(psi)

        # Fidelity measures (if references provided)
        if psi_initial is not None:
            diagnostics['fidelity_initial'] = self.fidelity_initial_state(psi, psi_initial)

        if x_classical is not None:
            diagnostics['fidelity_classical'] = self.fidelity_classical_proximity(psi, x_classical)

        # Density matrix diagnostics (if provided)
        if rho is not None:
            diagnostics['von_neumann_entropy'] = self.von_neumann_entropy(rho)
            diagnostics['mutual_information'] = self.mutual_information(rho)
            diagnostics['purity'] = self.purity(rho)
            diagnostics['linear_entropy'] = self.linear_entropy(rho)
            diagnostics['coherence_l1'] = self.quantum_coherence_l1_norm(rho)
            diagnostics['coherence_relative_entropy'] = self.quantum_coherence_relative_entropy(rho)
            diagnostics['entanglement_entropy'] = self.entanglement_entropy_half_system(rho)

        # Wigner function diagnostics (if provided)
        if wigner is not None:
            diagnostics['wigner_negativity'] = self.wigner_negativity(wigner)

        # Correction efficiency (only when all three inputs are present)
        correction_inputs = (correction_applied, deviation_before, deviation_after)
        if all(value is not None for value in correction_inputs):
            diagnostics['correction_efficiency'] = self.correction_efficiency(
                correction_applied, deviation_before, deviation_after)

        return diagnostics