import numpy as np
import healpy as hp
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.gridspec import GridSpec
import seaborn as sns
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
import pickle
from datetime import datetime

@dataclass
class VisualizationConfig:
    """Display settings read by EnhancedCMBVisualizer.

    Every field has a default, so ``VisualizationConfig()`` is usable as-is.
    """
    # Passed as ``figsize`` when the power-spectrum tracker and the
    # entropy/alignment figure are created.
    figure_size: Tuple[int, int] = (16, 10)
    # NOTE(review): not read by any code visible in this file — confirm usage.
    dpi: int = 100
    # Colormap name handed to the healpy Mollweide map views.
    cmap: str = 'RdBu_r'
    # When True, animations are additionally written out as GIF files.
    save_animations: bool = False
    # Used both for the frame interval (1000 / fps milliseconds) and as the
    # fps of saved GIFs.
    animation_fps: int = 10
    # True: step through all frames with FuncAnimation; False: draw one frame.
    real_time_update: bool = True
    # NOTE(review): the two flags below are not read by any code visible in
    # this file — presumably feature toggles for callers; confirm.
    entropy_tracking: bool = True
    stochastic_noise_visualization: bool = True

class EnhancedCMBVisualizer:
    def __init__(self, config: Optional[VisualizationConfig] = None):
        self.config = config if config is not None else VisualizationConfig()
        self.iteration_data = []
        self.power_spectrum_history = []
        self.entropy_history = []
        self.alignment_history = []
        self.stochastic_noise_seeds = []

        # Setup style
        plt.style.use('seaborn-v0_8-darkgrid')
        sns.set_palette("husl")

    def real_time_power_spectrum_tracker(self, alm_history: List[np.ndarray],
                                        lmax: int = 30) -> None:
        """Real-time visualization of power spectrum evolution"""
        fig = plt.figure(figsize=self.config.figure_size)
        gs = GridSpec(2, 2, figure=fig)

        ax_spectrum = fig.add_subplot(gs[0, :])
        ax_ratio = fig.add_subplot(gs[1, 0])
        ax_phase = fig.add_subplot(gs[1, 1])

        # Initialize plots
        ell = np.arange(lmax + 1)
        line_current, = ax_spectrum.semilogy([], [], 'b-', label='Current', linewidth=2)
        line_initial, = ax_spectrum.semilogy([], [], 'r--', label='Initial', alpha=0.5)
        line_theory, = ax_spectrum.semilogy([], [], 'g:', label='Theory', alpha=0.5)

        ax_spectrum.set_xlim(0, lmax)
        ax_spectrum.set_ylim(1e-12, 1e-6)
        ax_spectrum.set_xlabel('Multipole ℓ')
        ax_spectrum.set_ylabel('C_ℓ [μK²]')
        ax_spectrum.set_title('Real-Time Power Spectrum Evolution')
        ax_spectrum.legend()
        ax_spectrum.grid(True, alpha=0.3)

        # Ratio plot
        line_ratio, = ax_ratio.plot([], [], 'o-', markersize=4)
        ax_ratio.set_xlim(0, lmax)
        ax_ratio.set_ylim(0.5, 2.0)
        ax_ratio.axhline(y=1, color='k', linestyle='--', alpha=0.3)
        ax_ratio.set_xlabel('Multipole ℓ')
        ax_ratio.set_ylabel('C_ℓ(current) / C_ℓ(initial)')
        ax_ratio.set_title('Power Spectrum Ratio')
        ax_ratio.grid(True, alpha=0.3)

        # Phase coherence plot
        phase_scatter = ax_phase.scatter([], [], c=[], cmap='twilight', s=50)
        ax_phase.set_xlim(2, min(10, lmax))
        ax_phase.set_ylim(0, 1)
        ax_phase.set_xlabel('Multipole ℓ')
        ax_phase.set_ylabel('Phase Coherence')
        ax_phase.set_title('Multipole Phase Coherence')
        ax_phase.grid(True, alpha=0.3)

        # Store initial spectrum
        initial_cl = hp.alm2cl(alm_history[0]) if alm_history else None

        def update(frame):
            if frame >= len(alm_history):
                return line_current, line_ratio, phase_scatter

            alm = alm_history[frame]
            cl = hp.alm2cl(alm)

            # Update power spectrum
            line_current.set_data(ell[:len(cl)], cl)

            if initial_cl is not None:
                line_initial.set_data(ell[:len(initial_cl)], initial_cl)

                # Update ratio
                ratio = cl / (initial_cl + 1e-15)
                line_ratio.set_data(ell[:len(ratio)], ratio)

            # Compute phase coherence
            coherences = []
            ell_values = []
            for l in range(2, min(11, lmax + 1)):
                phases = []
                for m in range(1, l + 1):
                    idx = hp.Alm.getidx(lmax, l, m)
                    if idx < len(alm):
                        phases.append(np.angle(alm[idx]))

                if phases:
                    phase_array = np.array(phases)
                    coherence = np.abs(np.mean(np.exp(1j * phase_array)))
                    coherences.append(coherence)
                    ell_values.append(l)

            # Update phase scatter
            if coherences:
                phase_scatter.set_offsets(np.c_[ell_values, coherences])
                phase_scatter.set_array(np.array(coherences))

            # Add iteration info
            fig.suptitle(f'Iteration {frame + 1}/{len(alm_history)}', fontsize=14)

            # Store for history
            self.power_spectrum_history.append(cl)

            return line_current, line_ratio, phase_scatter

        if self.config.real_time_update:
            anim = FuncAnimation(fig, update, frames=len(alm_history),
                               interval=1000/self.config.animation_fps, blit=False)

            if self.config.save_animations:
                anim.save('power_spectrum_evolution.gif', fps=self.config.animation_fps)
        else:
            # Just show final state
            update(len(alm_history) - 1)

        plt.tight_layout()
        plt.show()

    def entropy_alignment_correlation(self, entropy_data: List[float],
                                     alignment_data: List[float]) -> None:
        """Visualize correlation between entropy minimization and alignment"""
        fig, axes = plt.subplots(2, 2, figsize=self.config.figure_size)

        iterations = np.arange(len(entropy_data))

        # Entropy evolution
        ax1 = axes[0, 0]
        ax1.plot(iterations, entropy_data, 'b-', linewidth=2, label='Entropy')
        ax1.fill_between(iterations, entropy_data, alpha=0.3)
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('Entropy')
        ax1.set_title('Entropy Minimization Progress')
        ax1.grid(True, alpha=0.3)

        # Alignment evolution
        ax2 = axes[0, 1]
        ax2.plot(iterations, alignment_data, 'r-', linewidth=2, label='Alignment Angle')
        ax2.fill_between(iterations, alignment_data, 90, alpha=0.3, color='red')
        ax2.axhline(y=30, color='g', linestyle='--', label='Target (<30°)')
        ax2.set_xlabel('Iteration')
        ax2.set_ylabel('Alignment Angle (degrees)')
        ax2.set_title('Quadrupole-Octopole Alignment')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Correlation scatter
        ax3 = axes[1, 0]
        scatter = ax3.scatter(entropy_data, alignment_data, c=iterations,
                            cmap='viridis', s=50, alpha=0.7)
        ax3.set_xlabel('Entropy')
        ax3.set_ylabel('Alignment Angle (degrees)')
        ax3.set_title('Entropy vs Alignment Correlation')
        plt.colorbar(scatter, ax=ax3, label='Iteration')

        # Add trendline
        z = np.polyfit(entropy_data, alignment_data, 1)
        p = np.poly1d(z)
        ax3.plot(entropy_data, p(entropy_data), "r--", alpha=0.5,
                label=f'Trend: y={z[0]:.2f}x+{z[1]:.2f}')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Phase space trajectory
        ax4 = axes[1, 1]
        # Normalize for phase space
        entropy_norm = (entropy_data - np.min(entropy_data)) / (np.max(entropy_data) - np.min(entropy_data))
        alignment_norm = (alignment_data - np.min(alignment_data)) / (np.max(alignment_data) - np.min(alignment_data))

        ax4.plot(entropy_norm, alignment_norm, 'b-', alpha=0.5)
        ax4.scatter(entropy_norm[0], alignment_norm[0], color='green', s=100, marker='o', label='Start')
        ax4.scatter(entropy_norm[-1], alignment_norm[-1], color='red', s=100, marker='s', label='End')

        # Add vector field
        for i in range(0, len(entropy_norm)-1, max(1, len(entropy_norm)//20)):
            dx = entropy_norm[i+1] - entropy_norm[i]
            dy = alignment_norm[i+1] - alignment_norm[i]
            ax4.arrow(entropy_norm[i], alignment_norm[i], dx, dy,
                     head_width=0.02, head_length=0.02, fc='gray', ec='gray', alpha=0.3)

        ax4.set_xlabel('Normalized Entropy')
        ax4.set_ylabel('Normalized Alignment')
        ax4.set_title('Phase Space Trajectory')
        ax4.legend()
        ax4.grid(True, alpha=0.3)

        plt.suptitle('Entropy-Alignment Correlation Analysis', fontsize=16)
        plt.tight_layout()
        plt.show()

        # Compute correlation coefficient
        correlation = np.corrcoef(entropy_data, alignment_data)[0, 1]
        print(f"Entropy-Alignment Correlation Coefficient: {correlation:.4f}")

    def stochastic_noise_replay_visualizer(self, noise_seeds: List[Dict]) -> None:
        """Visualize and replay stochastic noise patterns."""
        fig = plt.figure(figsize=(14, 8))
        gs = GridSpec(2, 3, figure=fig)

        ax_noise = fig.add_subplot(gs[0, :2])
        ax_spectrum = fig.add_subplot(gs[0, 2])
        ax_phase = fig.add_subplot(gs[1, :2])
        ax_stats = fig.add_subplot(gs[1, 2])

        def update(frame):
            if frame >= len(noise_seeds):
                return

            seed_data = noise_seeds[frame]

            # Clear axes
            ax_noise.clear()
            ax_spectrum.clear()
            ax_phase.clear()
            ax_stats.clear()

            # Extract noise field
            if 'field' in seed_data:
                noise_field = seed_data['field']

                # Noise map
                hp.mollview(noise_field, ax=ax_noise, title=f'Stochastic Noise Field (Seed {frame})',
                          cmap=self.config.cmap, hold=True)

                # Power spectrum of noise
                cl_noise = hp.anafast(noise_field)
                ell = np.arange(len(cl_noise))
                ax_spectrum.semilogy(ell[1:], cl_noise[1:], 'b-')
                ax_spectrum.set_xlabel('ℓ')
                ax_spectrum.set_ylabel('Power')
                ax_spectrum.set_title('Noise Power Spectrum')
                ax_spectrum.grid(True, alpha=0.3)

            # Phase distribution
            if 'phases' in seed_data:
                phases = seed_data['phases']
                ax_phase.hist(phases, bins=50, alpha=0.7, color='purple', edgecolor='black')
                ax_phase.axvline(x=np.mean(phases), color='red', linestyle='--',
                               label=f'Mean: {np.mean(phases):.2f}')
                ax_phase.set_xlabel('Phase (radians)')
                ax_phase.set_ylabel('Frequency')
                ax_phase.set_title('Stochastic Phase Distribution')
                ax_phase.legend()

            # Statistics
            if 'statistics' in seed_data:
                stats = seed_data['statistics']
                stats_text = '\n'.join([f'{k}: {v:.4f}' for k, v in stats.items()])
                ax_stats.text(0.1, 0.5, stats_text, transform=ax_stats.transAxes,
                            fontsize=10, verticalalignment='center')
                ax_stats.set_title('Noise Statistics')
                ax_stats.axis('off')

            fig.suptitle(f'Stochastic Noise Pattern {frame + 1}/{len(noise_seeds)}',
                        fontsize=14)

        if self.config.real_time_update and len(noise_seeds) > 1:
            anim = FuncAnimation(fig, update, frames=len(noise_seeds),
                               interval=1000/self.config.animation_fps)

            if self.config.save_animations:
                anim.save('stochastic_noise_replay.gif', fps=self.config.animation_fps)
        else:
            update(0)

        plt.tight_layout()
        plt.show()

    def multipole_decomposition_viewer(self, alm: np.ndarray, lmax: int,
                                      l_range: Tuple[int, int] = (2, 10)) -> None:
        """Visualize multipole decomposition and alignment.

        Alignment vectors are computed using the Angular Momentum Dispersion
        (AMD) method (de Oliveira-Costa et al. 2004), not dipole projection.
        """
        from .cmb_generator import CMBGenerator

        l_min, l_max = l_range
        n_multipoles = l_max - l_min + 1

        fig, axes = plt.subplots(2, n_multipoles//2 + 1,
                                figsize=(16, 8),
                                subplot_kw={'projection': 'mollweide'})
        axes = axes.flatten()

        alignment_vectors = []
        nside = 64  # For visualization

        for idx, l in enumerate(range(l_min, min(l_max + 1, lmax + 1))):
            # Extract multipole
            multipole_alm = np.zeros_like(alm)
            for m in range(0, l + 1):
                alm_idx = hp.Alm.getidx(lmax, l, m)
                if alm_idx < len(alm):
                    multipole_alm[alm_idx] = alm[alm_idx]

            # Convert to map
            multipole_map = hp.alm2map(multipole_alm, nside)

            # Plot
            hp.mollview(multipole_map, ax=axes[idx],
                       title=f'ℓ = {l}',
                       cmap=self.config.cmap, hold=True)

            # Compute alignment vector using AMD method
            vec = self._compute_amd_alignment(alm, l, lmax)
            alignment_vectors.append(vec)

        # Remove extra axes
        for idx in range(n_multipoles, len(axes)):
            fig.delaxes(axes[idx])

        plt.suptitle(f'Multipole Decomposition (ℓ = {l_min} to {l_max})', fontsize=14)
        plt.tight_layout()
        plt.show()

        # Plot alignment vectors in 3D
        self._plot_alignment_vectors_3d(alignment_vectors, l_range)

    @staticmethod
    def _compute_amd_alignment(alm: np.ndarray, l: int,
                                lmax: int) -> np.ndarray:
        """Return the AMD preferred axis of multipole ``l`` as a unit vector.

        The 3x3 tensor ``M_ij = a^dag {L_i, L_j}/2 a`` is assembled from the
        angular momentum matrices for ``l`` and the eigenvector belonging to
        its largest eigenvalue is returned, flipped so that its z component
        is non-negative. When the multipole carries (numerically) no power
        the z axis ``[0, 0, 1]`` is returned instead.
        """
        from .cmb_generator import CMBGenerator

        # Coefficient vector ordered m = -l..+l; slot l + m holds a_{l,m}.
        coeffs = np.zeros(2 * l + 1, dtype=complex)
        n_stored = len(alm)
        for m in range(l + 1):
            pos = hp.Alm.getidx(lmax, l, m)
            if pos >= n_stored:
                continue
            coeffs[l + m] = alm[pos]
            if m > 0:
                # Negative m follows from a_{l,-m} = (-1)^m conj(a_{l,m}).
                coeffs[l - m] = (-1)**m * np.conj(alm[pos])

        if np.sum(np.abs(coeffs)**2) < 1e-30:
            return np.array([0.0, 0.0, 1.0])

        Lx, Ly, Lz = CMBGenerator._build_angular_momentum_matrices(l)
        operators = (Lx, Ly, Lz)
        bra = np.conj(coeffs)

        # Fill the upper triangle and mirror it: the tensor is symmetric.
        tensor = np.zeros((3, 3))
        for row in range(3):
            for col in range(row, 3):
                left, right = operators[row], operators[col]
                anticommutator = (left @ right + right @ left) / 2.0
                entry = np.real(bra @ anticommutator @ coeffs)
                tensor[row, col] = entry
                tensor[col, row] = entry

        # eigh sorts eigenvalues ascending, so the last column is the
        # eigenvector of the largest one.
        eigenvectors = np.linalg.eigh(tensor)[1]
        direction = eigenvectors[:, -1].real
        if direction[2] < 0:
            direction = -direction
        return direction / np.linalg.norm(direction)

    def _plot_alignment_vectors_3d(self, vectors: List[np.ndarray],
                                  l_range: Tuple[int, int]) -> None:
        """3D visualization of multipole alignment vectors"""
        from mpl_toolkits.mplot3d import Axes3D

        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')

        l_min, l_max = l_range
        colors = plt.cm.viridis(np.linspace(0, 1, len(vectors)))

        for idx, (vec, l) in enumerate(zip(vectors, range(l_min, l_min + len(vectors)))):
            # Plot vector
            ax.quiver(0, 0, 0, vec[0], vec[1], vec[2],
                     color=colors[idx], arrow_length_ratio=0.1,
                     linewidth=2, label=f'ℓ = {l}')

            # Add text label
            ax.text(vec[0]*1.1, vec[1]*1.1, vec[2]*1.1, f'ℓ={l}', fontsize=9)

        # Add unit sphere
        u = np.linspace(0, 2 * np.pi, 50)
        v = np.linspace(0, np.pi, 50)
        x = np.outer(np.cos(u), np.sin(v))
        y = np.outer(np.sin(u), np.sin(v))
        z = np.outer(np.ones(np.size(u)), np.cos(v))

        ax.plot_surface(x, y, z, alpha=0.1, color='gray')

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title('Multipole Alignment Vectors')
        ax.legend(loc='upper right', fontsize=8)

        # Set equal aspect ratio
        ax.set_box_aspect([1,1,1])

        plt.show()

    def save_replay_data(self, data: Dict, filename: str) -> None:
        """Save simulation data for replay"""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        full_filename = f"{filename}_{timestamp}.pkl"

        replay_data = {
            'iteration_data': self.iteration_data,
            'power_spectrum_history': self.power_spectrum_history,
            'entropy_history': self.entropy_history,
            'alignment_history': self.alignment_history,
            'stochastic_noise_seeds': self.quantum_noise_seeds,
            'simulation_data': data,
            'timestamp': timestamp
        }

        with open(full_filename, 'wb') as f:
            pickle.dump(replay_data, f)

        print(f"Replay data saved to {full_filename}")

    def load_and_replay(self, filename: str) -> None:
        """Load and replay saved simulation"""
        with open(filename, 'rb') as f:
            replay_data = pickle.load(f)

        print(f"Loaded simulation from {replay_data['timestamp']}")

        # Replay visualizations
        if replay_data['power_spectrum_history']:
            print("Replaying power spectrum evolution...")
            # Convert back to alm format if needed

        if replay_data['entropy_history'] and replay_data['alignment_history']:
            print("Replaying entropy-alignment correlation...")
            self.entropy_alignment_correlation(
                replay_data['entropy_history'],
                replay_data['alignment_history']
            )

        if replay_data['stochastic_noise_seeds']:
            print("Replaying stochastic noise patterns...")
            self.stochastic_noise_replay_visualizer(replay_data['stochastic_noise_seeds'])

    def plot_convergence_metrics(self, convergence_data: List[float]) -> None:
        """Plot convergence metrics with analysis"""
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        iterations = np.arange(len(convergence_data))

        # Convergence curve
        ax1 = axes[0, 0]
        ax1.semilogy(iterations, convergence_data, 'b-', linewidth=2)
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('Convergence Metric (log scale)')
        ax1.set_title('Convergence History')
        ax1.grid(True, alpha=0.3)

        # Convergence rate
        ax2 = axes[0, 1]
        if len(convergence_data) > 1:
            rates = np.diff(np.log(convergence_data + 1e-15))
            ax2.plot(iterations[1:], rates, 'r-', linewidth=2)
            ax2.axhline(y=0, color='k', linestyle='--', alpha=0.3)
            ax2.set_xlabel('Iteration')
            ax2.set_ylabel('Log Convergence Rate')
            ax2.set_title('Convergence Rate Evolution')
            ax2.grid(True, alpha=0.3)

        # Running average
        ax3 = axes[1, 0]
        window = min(10, len(convergence_data) // 4)
        if window > 1:
            running_avg = np.convolve(convergence_data,
                                     np.ones(window)/window, mode='valid')
            ax3.plot(iterations[:len(running_avg)], running_avg, 'g-', linewidth=2)
            ax3.set_xlabel('Iteration')
            ax3.set_ylabel('Running Average')
            ax3.set_title(f'Running Average (window={window})')
            ax3.grid(True, alpha=0.3)

        # Convergence prediction
        ax4 = axes[1, 1]
        if len(convergence_data) > 10:
            # Fit exponential decay
            log_conv = np.log(convergence_data + 1e-15)
            z = np.polyfit(iterations, log_conv, 1)
            p = np.poly1d(z)

            # Predict future convergence
            future_iterations = np.arange(len(convergence_data),
                                        len(convergence_data) + 20)
            predicted = np.exp(p(future_iterations))

            ax4.semilogy(iterations, convergence_data, 'b-', label='Actual')
            ax4.semilogy(future_iterations, predicted, 'r--', label='Predicted')
            ax4.set_xlabel('Iteration')
            ax4.set_ylabel('Convergence Metric')
            ax4.set_title('Convergence Prediction')
            ax4.legend()
            ax4.grid(True, alpha=0.3)

        plt.suptitle('Convergence Analysis', fontsize=16)
        plt.tight_layout()
        plt.show()

    def plot_recursion_dynamics(self, recursion_data: List[Dict]) -> None:
        """Visualize dynamics across iterations.

        Expects recursion_data dicts with keys:
          'iteration', 'convergence', 'alignment_angle', 'entropy',
          'fractal_dimension'
        """
        if not recursion_data:
            return

        fig = plt.figure(figsize=(15, 10))
        gs = GridSpec(2, 2, figure=fig)

        iterations = [d.get('iteration', i) for i, d in enumerate(recursion_data)]

        # Convergence evolution
        ax1 = fig.add_subplot(gs[0, 0])
        conv = [d.get('convergence', 0.0) for d in recursion_data]
        ax1.semilogy(iterations, conv, 'b-o', linewidth=2, markersize=3)
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('Convergence Metric')
        ax1.set_title('Convergence Evolution')
        ax1.grid(True, alpha=0.3)

        # Alignment angle evolution
        ax2 = fig.add_subplot(gs[0, 1])
        angles = [d.get('alignment_angle', 0.0) for d in recursion_data]
        ax2.plot(iterations, angles, 'r-s', linewidth=2, markersize=3)
        ax2.set_xlabel('Iteration')
        ax2.set_ylabel('Alignment Angle (deg)')
        ax2.set_title('Q-O Alignment Evolution')
        ax2.set_ylim(0, 90)
        ax2.grid(True, alpha=0.3)

        # Entropy evolution
        ax3 = fig.add_subplot(gs[1, 0])
        entropy = [d.get('entropy', 0.0) for d in recursion_data]
        ax3.plot(iterations, entropy, 'g-^', linewidth=2, markersize=3)
        ax3.set_xlabel('Iteration')
        ax3.set_ylabel('Shannon Entropy')
        ax3.set_title('Power Spectrum Entropy')
        ax3.grid(True, alpha=0.3)

        # Fractal dimension tracking
        ax4 = fig.add_subplot(gs[1, 1])
        if 'fractal_dimension' in recursion_data[0]:
            fractal_dims = [d['fractal_dimension'] for d in recursion_data]
            ax4.plot(iterations, fractal_dims, 'm-o', linewidth=2, markersize=3)
            ax4.axhline(y=2.0, color='k', linestyle='--', alpha=0.3,
                        label='D=2 (Gaussian)')
            ax4.set_xlabel('Iteration')
            ax4.set_ylabel('Fractal Dimension')
            ax4.set_title('Fractal Dimension Evolution')
            ax4.legend()
            ax4.grid(True, alpha=0.3)

        plt.suptitle('Correction Dynamics Analysis', fontsize=16)
        plt.tight_layout()
        plt.show()