"""Phase space analysis module for quantum-classical boundary emergence.

This module provides comprehensive phase space tools including:
- Wigner function computation and analysis
- Quantum potential overlays (Bohmian mechanics)
- Phase space flow visualization
- Classical-quantum correspondence analysis
- Negativity and coherence measures in phase space
"""

import numpy as np
from scipy.integrate import simpson as simps
from scipy.fftpack import fft, ifft, fftshift
from scipy.signal import resample
import matplotlib.pyplot as plt
from matplotlib import cm

class PhaseSpaceAnalysis:
    """Comprehensive phase space analysis tools on a uniform 1-D position grid.

    Conventions shared by every method:

    * ``self.x`` is the uniformly spaced position grid, spacing ``self.dx``.
    * ``self.k`` is used as the *momentum* grid p: it enters phases as
      ``exp(i p x / hbar)`` and is labelled "Momentum p" in plots.
    * Phase-space arrays have shape ``(len(self.k), len(self.x))`` with
      ``W[j, i] = W(x_i, p_j)`` (rows = momentum, columns = position).
    """

    def __init__(self, x_grid, k_grid=None, hbar=1.0):
        """Initialize phase space analysis.

        Args:
            x_grid: Uniformly spaced position grid (at least two points).
            k_grid: Optional momentum grid. Defaults to the momenta conjugate to
                ``x_grid`` under the discrete Fourier transform,
                ``hbar * 2*pi * fftfreq(N, dx)``, sorted in increasing order so
                it can be used directly for integration and plotting.
            hbar: Reduced Planck constant in the chosen units.
        """
        self.x = np.asarray(x_grid, dtype=float)
        self.dx = self.x[1] - self.x[0]
        self.hbar = hbar
        self.grid_size = len(self.x)

        if k_grid is None:
            self.k = self._fft_momentum_grid()
        else:
            self.k = np.asarray(k_grid, dtype=float)

    def _fft_momentum_grid(self):
        """Return the sorted DFT momentum grid conjugate to ``self.x``."""
        wavenumbers = 2 * np.pi * np.fft.fftfreq(self.grid_size, d=self.dx)
        return np.fft.fftshift(wavenumbers) * self.hbar

    def _cell_area(self):
        """Return the area dx*dp of one phase-space cell (assumes uniform grids)."""
        return self.dx * (self.k[1] - self.k[0])

    def _wigner_negativity(self, wigner):
        """Return the negative volume of W, i.e. the sum of |W| over W<0 times dx*dp."""
        return float(np.sum(np.abs(wigner[wigner < 0])) * self._cell_area())

    def compute_wigner_function(self, psi):
        """
        Calculate the Wigner quasi-probability distribution W(x, p).

            W(x,p) = (1/πℏ) ∫ ψ*(x+y) ψ(x-y) exp(2ipy/ℏ) dy
                   = (1/2πℏ) ∫ ψ*(x+s/2) ψ(x-s/2) exp(ips/ℏ) ds

        With this normalisation ∫W dp = |ψ(x)|² and ∫W dx = |ψ̃(p)|².

        ψ is first Fourier-interpolated onto a grid of spacing dx/2. That makes
        ψ(x ± y) available at lag spacing dx/2, so the full momentum range
        |p| < πℏ/dx is free of aliasing. Samples beyond the ends of the grid are
        treated as zero. The interpolation assumes ψ is negligible at the grid
        boundaries.

        On the default (DFT) momentum grid the lag sum is done with one FFT per
        position, O(N² log N). Any other ``k_grid`` is evaluated by direct
        summation, O(N² · len(k)).

        Args:
            psi: Wavefunction sampled on ``self.x``.

        Returns:
            Real array of shape (len(self.k), len(self.x)), W[j, i] = W(x_i, p_j).

        Raises:
            ValueError: If ``psi`` does not match the length of the position grid.
        """
        psi = np.asarray(psi, dtype=complex)
        n_points = self.grid_size
        if psi.shape != (n_points,):
            raise ValueError(
                f"psi has shape {psi.shape}; expected ({n_points},) to match x_grid"
            )

        # Band-limited interpolation: fine[2n] == psi[n], fine[2n+1] lies at x_n + dx/2.
        fine = resample(psi, 2 * n_points)
        n_fine = fine.size

        # kernel[n, m + (N-1)] = ψ*(x_n + y_m) ψ(x_n - y_m), with y_m = m*dx/2.
        # Only lags |m| <= N-1 can keep both samples inside the fine grid.
        centre_col = n_points - 1
        kernel = np.zeros((n_points, 2 * n_points - 1), dtype=complex)
        for n in range(n_points):
            c = 2 * n
            reach = min(c, n_fine - 1 - c)
            window = fine[c - reach:c + reach + 1]
            kernel[n, centre_col - reach:centre_col + reach + 1] = (
                np.conj(window) * window[::-1]
            )

        prefactor = (self.dx / 2) / (np.pi * self.hbar)
        default_grid = self._fft_momentum_grid()
        if self.k.shape == default_grid.shape and np.allclose(self.k, default_grid):
            # On the DFT grid, exp(2i p_j y_m / ℏ) = exp(2πi j m / N) is N-periodic
            # in the lag m. Fold lags modulo N and use a single inverse FFT.
            folded = kernel[:, centre_col:].copy()       # lags 0 .. N-1
            folded[:, 1:] += kernel[:, :centre_col]      # lags -(N-1) .. -1
            spectrum = n_points * np.fft.ifft(folded, axis=1)
            wigner = np.fft.fftshift(spectrum, axes=1).real.T
        else:
            # Arbitrary momentum grid: evaluate the lag sum directly.
            lag_positions = np.arange(-centre_col, n_points) * (self.dx / 2)
            phase = np.exp(2j * np.outer(lag_positions, self.k) / self.hbar)
            # kernel[n, -m] == conj(kernel[n, m]), so the sum is real up to rounding.
            wigner = (kernel @ phase).real.T

        return prefactor * wigner

    def compute_wigner_function_optimized(self, psi):
        """
        Compute the Wigner function. Kept for backward compatibility.

        Delegates to :meth:`compute_wigner_function`, which is already
        vectorised. The earlier quadrature-based version here had the wrong
        phase sign and normalisation, and clamped ψ outside the grid.

        Returns:
            Real array of shape (len(self.k), len(self.x)).
        """
        return self.compute_wigner_function(psi)

    def wigner_with_quantum_potential_overlay(self, psi, quantum_potential):
        """
        Compute the Wigner function together with Q(x) broadcast over phase space.

        Args:
            psi: Wavefunction sampled on ``self.x``.
            quantum_potential: Q(x) sampled on ``self.x``.

        Returns:
            Tuple ``(wigner, Q_extended, x_mesh, k_mesh)``. All four have shape
            (len(self.k), len(self.x)); Q is repeated along the momentum axis.
        """
        wigner = self.compute_wigner_function(psi)

        x_mesh, k_mesh = np.meshgrid(self.x, self.k)

        # Q(x) does not depend on p, so repeat it for every momentum row.
        Q_extended = np.tile(quantum_potential, (len(self.k), 1))

        return wigner, Q_extended, x_mesh, k_mesh

    def husimi_q_function(self, psi, alpha=1.0):
        """
        Calculate the Husimi Q-function (Gaussian-smoothed Wigner function).

            Q(x0, p0) = |⟨x0,p0|ψ⟩|² / (2πℏ)

        Here |x0,p0⟩ is a normalised Gaussian coherent state with position
        variance σ² = ℏ/(2α). The 1/(2πℏ) factor makes ∫∫Q dx dp = 1 for a
        normalised ψ, because the coherent states resolve the identity as
        ∫|x,p⟩⟨x,p| dx dp = 2πℏ. Q is non-negative, unlike the Wigner function.

        The overlaps are Riemann sums over ``self.x``. Unlike Simpson's rule,
        these stay accurate for integrands oscillating close to the Nyquist
        momentum.

        Args:
            psi: Wavefunction sampled on ``self.x``.
            alpha: Coherent-state squeezing parameter (σ² = ℏ/(2α)).

        Returns:
            Array of shape (len(self.k), len(self.x)), Q[j, i] = Q(x_i, p_j).
        """
        psi = np.asarray(psi, dtype=complex)
        sigma = np.sqrt(self.hbar / (2 * alpha))
        norm = 1.0 / np.sqrt(np.sqrt(2 * np.pi) * sigma)

        # envelopes[i, l]: normalised Gaussian centred at x_i, evaluated at x_l.
        separation = self.x[np.newaxis, :] - self.x[:, np.newaxis]
        envelopes = norm * np.exp(-separation**2 / (4 * sigma**2))
        # Complex conjugate of the coherent-state carrier exp(i p x / ℏ): (n_p, n_x).
        carriers = np.exp(-1j * np.outer(self.k, self.x) / self.hbar)

        overlaps = (carriers @ (envelopes * psi[np.newaxis, :]).T) * self.dx
        return np.abs(overlaps)**2 / (2 * np.pi * self.hbar)

    def phase_space_flow_classical(self, x_vals, p_vals, potential_func, mass=1.0, dt=0.01, n_steps=100):
        """
        Integrate classical phase-space trajectories with symplectic Euler.

            p_{n+1} = p_n - V'(x_n) dt
            x_{n+1} = x_n + p_{n+1} / m dt

        Args:
            x_vals, p_vals: Initial positions and momenta (paired element-wise).
            potential_func: Either V sampled on ``self.x`` (array) or a callable
                V(x), which is evaluated on ``self.x``. The force comes from
                ``np.gradient`` and linear interpolation. Outside the grid,
                ``np.interp`` holds the edge value.
            mass: Particle mass.
            dt: Time step.
            n_steps: Number of steps.

        Returns:
            List of arrays, each of shape (n_steps + 1, 2) with columns (x, p).
        """
        if callable(potential_func):
            potential = np.asarray(potential_func(self.x), dtype=float)
        else:
            potential = np.asarray(potential_func, dtype=float)
        # The force field is time-independent: compute it once, not every step.
        force_grid = -np.gradient(potential, self.dx)

        trajectories = []
        for x0, p0 in zip(x_vals, p_vals):
            trajectory = np.empty((n_steps + 1, 2))
            x, p = float(x0), float(p0)
            trajectory[0] = x, p
            for step in range(1, n_steps + 1):
                p += np.interp(x, self.x, force_grid) * dt
                # Use the *updated* momentum: this is what makes the scheme symplectic.
                x += p / mass * dt
                trajectory[step] = x, p
            trajectories.append(trajectory)

        return trajectories

    def quantum_phase_space_flow(self, psi_sequence):
        """
        Track Wigner-function indicators along a sequence of wavefunctions.

        Returns:
            Dict with:
            - ``'wigner_sequence'``: the Wigner arrays.
            - ``'negativity_evolution'``: negative phase-space volume, the
              integral of |W| over W<0 with respect to dx dp.
            - ``'entropy_evolution'``: Shannon entropy (bits) of |W| normalised
              over the grid. This is a heuristic spreading measure, not a
              physical entropy.
        """
        wigner_sequence = []
        negativity_evolution = []
        entropy_evolution = []

        for psi in psi_sequence:
            wigner = self.compute_wigner_function(psi)
            wigner_sequence.append(wigner)

            negativity_evolution.append(self._wigner_negativity(wigner))

            # Regularise so log2 never sees zero.
            magnitude = np.abs(wigner) + 1e-12
            weights = magnitude / np.sum(magnitude)
            entropy_evolution.append(-np.sum(weights * np.log2(weights)))

        return {
            'wigner_sequence': wigner_sequence,
            'negativity_evolution': negativity_evolution,
            'entropy_evolution': entropy_evolution
        }

    def classify_quantum_classical_regions(self, wigner):
        """
        Split the phase-space grid into W < 0 ("quantum") and W >= 0 ("classical") cells.

        W >= 0 alone does not certify classicality; this is a heuristic mask.
        Areas assume uniform x and p grids.

        Returns:
            Dict with boolean masks, their areas, and the classical area fraction.
        """
        quantum_regions = wigner < 0
        classical_regions = ~quantum_regions

        cell_area = self._cell_area()
        quantum_area = np.sum(quantum_regions) * cell_area
        classical_area = np.sum(classical_regions) * cell_area

        return {
            'quantum_mask': quantum_regions,
            'classical_mask': classical_regions,
            'quantum_area': quantum_area,
            'classical_area': classical_area,
            'classical_fraction': classical_area / (quantum_area + classical_area)
        }

    def marginal_distributions(self, wigner):
        """
        Calculate marginal position and momentum distributions.

        P(x) = ∫ W(x,p) dp   (integrate over rows, axis 0)
        P(p) = ∫ W(x,p) dx   (integrate over columns, axis 1)

        Returns:
            Tuple ``(position_marginal, momentum_marginal)``.
        """
        position_marginal = simps(wigner, x=self.k, axis=0)
        momentum_marginal = simps(wigner, x=self.x, axis=1)
        return position_marginal, momentum_marginal

    def wigner_cross_correlations(self, wigner):
        """
        Calculate first/second moments and the x-p correlation from W.

        ⟨f⟩ = ∫∫ f(x,p) W(x,p) dx dp, evaluated by nested Simpson quadrature.

        Returns:
            Dict with means, variances, the covariance, and the correlation
            coefficient (0.0 if either variance is non-positive).
        """
        x_mesh, p_mesh = np.meshgrid(self.x, self.k)

        def expectation(values):
            # Integrate over p first (axis 0), then over x.
            return simps(simps(values * wigner, x=self.k, axis=0), x=self.x)

        mean_x = expectation(x_mesh)
        mean_p = expectation(p_mesh)
        mean_x2 = expectation(x_mesh**2)
        mean_p2 = expectation(p_mesh**2)
        mean_xp = expectation(x_mesh * p_mesh)

        var_x = mean_x2 - mean_x**2
        var_p = mean_p2 - mean_p**2
        cov_xp = mean_xp - mean_x * mean_p

        correlation = cov_xp / np.sqrt(var_x * var_p) if (var_x > 0 and var_p > 0) else 0.0

        return {
            'mean_x': mean_x,
            'mean_p': mean_p,
            'var_x': var_x,
            'var_p': var_p,
            'cov_xp': cov_xp,
            'correlation': correlation
        }

    def wigner_moments(self, wigner, max_order=4):
        """
        Calculate the moments ⟨x^m p^n⟩ of W for all m + n <= max_order.

        Returns:
            Dict keyed ``'x{m}p{n}'``.
        """
        x_mesh, p_mesh = np.meshgrid(self.x, self.k)
        moments = {}

        for total_order in range(max_order + 1):
            for m in range(total_order + 1):
                n = total_order - m
                integrand = (x_mesh**m) * (p_mesh**n) * wigner
                moments[f'x{m}p{n}'] = simps(simps(integrand, x=self.k, axis=0), x=self.x)

        return moments

    def visualize_wigner_with_overlays(self, wigner, quantum_potential=None,
                                       classical_trajectories=None, title="Wigner Function"):
        """
        Plot W(x, p) as a symmetric diverging heatmap with optional overlays.

        Args:
            wigner: Array of shape (len(self.k), len(self.x)).
            quantum_potential: Optional Q(x) on ``self.x``, drawn as contours.
            classical_trajectories: Optional iterable of (n, 2) arrays of (x, p).
            title: Axes title.

        Returns:
            Tuple ``(fig, ax)``.
        """
        fig, ax = plt.subplots(figsize=(12, 8))

        # Symmetric colour limits so that W = 0 maps to the neutral colour.
        limit = np.max(np.abs(wigner))
        im = ax.imshow(wigner, extent=[self.x[0], self.x[-1], self.k[0], self.k[-1]],
                       aspect='auto', origin='lower', cmap='seismic',
                       vmin=-limit, vmax=limit)

        if quantum_potential is not None:
            x_mesh, k_mesh = np.meshgrid(self.x, self.k)
            Q_extended = np.tile(quantum_potential, (len(self.k), 1))
            ax.contour(x_mesh, k_mesh, Q_extended, levels=10, colors='white', alpha=0.5, linewidths=1)

        if classical_trajectories is not None:
            for trajectory in classical_trajectories:
                positions, momenta = trajectory[:, 0], trajectory[:, 1]
                ax.plot(positions, momenta, 'yellow', linewidth=1, alpha=0.8)

        ax.set_xlabel('Position x')
        ax.set_ylabel('Momentum p')
        ax.set_title(title)

        cbar = plt.colorbar(im, ax=ax)
        cbar.set_label('W(x,p)')

        return fig, ax

    def wigner_evolution_analysis(self, wigner_sequence, times):
        """
        Compute time series of phase-space indicators for a Wigner sequence.

        Returns:
            Dict with ``'times'`` and per-frame lists:
            - ``'negativity'``: negative volume, with respect to dx dp.
            - ``'classical_fraction'``
            - ``'position_width'`` and ``'momentum_width'``: from the marginals.
            - ``'correlation'``
        """
        analysis_results = {
            'times': times,
            'negativity': [],
            'classical_fraction': [],
            'position_width': [],
            'momentum_width': [],
            'correlation': []
        }

        for wigner in wigner_sequence:
            analysis_results['negativity'].append(self._wigner_negativity(wigner))

            classification = self.classify_quantum_classical_regions(wigner)
            analysis_results['classical_fraction'].append(classification['classical_fraction'])

            pos_marginal, mom_marginal = self.marginal_distributions(wigner)

            pos_mean = simps(self.x * pos_marginal, x=self.x)
            pos_var = simps(self.x**2 * pos_marginal, x=self.x) - pos_mean**2
            # Clamp: quadrature error can make a tiny variance slightly negative.
            analysis_results['position_width'].append(np.sqrt(max(0, pos_var)))

            mom_mean = simps(self.k * mom_marginal, x=self.k)
            mom_var = simps(self.k**2 * mom_marginal, x=self.k) - mom_mean**2
            analysis_results['momentum_width'].append(np.sqrt(max(0, mom_var)))

            correlations = self.wigner_cross_correlations(wigner)
            analysis_results['correlation'].append(correlations['correlation'])

        return analysis_results