import numpy as np
import matplotlib.pyplot as plt
import healpy as hp
from mpl_toolkits.mplot3d import Axes3D
from typing import Dict, List, Optional, Tuple
import matplotlib.gridspec as gridspec

class CMBVisualizer:
    """Matplotlib/healpy plotting helpers for CMB maps, spectra and alignment diagnostics.

    Every ``plot_*`` method builds its own figure, optionally writes it to
    ``save_path`` (300 dpi, tight bounding box) and then calls ``plt.show()``.
    """

    def __init__(self, figsize: Tuple[int, int] = (15, 10)):
        """Store a default figure size.

        Args:
            figsize: ``(width, height)`` in inches. NOTE(review): this is kept on
                the instance, but the ``plot_*`` methods below each use their own
                hard-coded sizes and never read it.
        """
        self.figsize = figsize

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _finalize(fig, save_path: Optional[str]) -> None:
        """Write ``fig`` to ``save_path`` when one is given, then show it."""
        if save_path:
            # Save the figure this method built rather than whatever pyplot
            # currently considers "current".
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    @staticmethod
    def _dl_from_cl(cl) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(ell, D_ell)`` for multipoles ell >= 2.

        ``cl[i]`` is treated as C_ell at ell = i, and
        ``D_ell = ell (ell + 1) C_ell / (2 pi)``. The first two entries
        (ell = 0, 1) are dropped.
        """
        cl = np.asarray(cl)
        ell = np.arange(len(cl))[2:]
        return ell, cl[2:] * ell * (ell + 1) / (2 * np.pi)

    @staticmethod
    def _ell_axis_upper(cl_dict: Dict[str, np.ndarray]) -> Optional[float]:
        """Return the upper ell-axis limit covering the longest spectrum.

        The limit is never below 4. Returns ``None`` when ``cl_dict`` is empty.
        """
        if not cl_dict:
            return None
        longest = max(len(cl) for cl in cl_dict.values())
        return max(4, longest - 0.5)

    @staticmethod
    def _metric_to_scalar(val) -> float:
        """Collapse one attractor metric into a bar height.

        Real scalars, whether Python or NumPy, are used directly. Arrays
        contribute their mean. Anything else, including an empty array, plots
        as 0.
        """
        # NumPy scalars such as np.float32 / np.int64 are not subclasses of
        # Python int/float, so they must be listed explicitly.
        if isinstance(val, (int, float, np.integer, np.floating, np.bool_)):
            return float(val)
        if isinstance(val, np.ndarray):
            # np.mean of an empty array is nan and emits a RuntimeWarning.
            return float(np.mean(val)) if val.size else 0.0
        return 0.0

    # ------------------------------------------------------------------
    # Public plotting API
    # ------------------------------------------------------------------

    def plot_cmb_map(self, map_data: np.ndarray, title: str = "CMB Map",
                     cmap: str = 'RdBu_r', save_path: Optional[str] = None):
        """Draw a Mollweide projection of a HEALPix map with ``hp.mollview``.

        Args:
            map_data: HEALPix pixel array handed to ``hp.mollview``.
            title: Plot title.
            cmap: Matplotlib colormap name.
            save_path: Optional output file path.
        """
        fig = plt.figure(figsize=(12, 6))
        # healpy documents ``fig`` as a figure *number*. Passing the Figure
        # object relies on plt.figure() accepting Figure instances, so pass
        # the number instead.
        hp.mollview(map_data, title=title, cmap=cmap, fig=fig.number)
        self._finalize(fig, save_path)

    def plot_alignment_vectors(self, vectors: Dict[str, np.ndarray],
                               save_path: Optional[str] = None):
        """Draw labelled 3-D direction vectors from the origin.

        Args:
            vectors: Mapping of label -> length-3 vector. Entries that are
                ``None`` or not of length 3 are skipped. The colour cycle still
                advances for skipped entries.
            save_path: Optional output file path.
        """
        fig = plt.figure(figsize=(10, 10))
        ax = fig.add_subplot(111, projection='3d')

        # Faint unit basis vectors for reference.
        ax.quiver(0, 0, 0, 1, 0, 0, color='gray', alpha=0.3, arrow_length_ratio=0.1)
        ax.quiver(0, 0, 0, 0, 1, 0, color='gray', alpha=0.3, arrow_length_ratio=0.1)
        ax.quiver(0, 0, 0, 0, 0, 1, color='gray', alpha=0.3, arrow_length_ratio=0.1)

        colors = ['red', 'blue', 'green', 'orange']
        plotted_any = False
        for idx, (label, vector) in enumerate(vectors.items()):
            if vector is None or len(vector) != 3:
                continue
            ax.quiver(0, 0, 0, vector[0], vector[1], vector[2],
                      color=colors[idx % len(colors)],
                      arrow_length_ratio=0.1,
                      linewidth=3,
                      label=label)
            plotted_any = True

        ax.set_xlim([-1.2, 1.2])
        ax.set_ylim([-1.2, 1.2])
        ax.set_zlim([-1.2, 1.2])
        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title('CMB Multipole Alignment Vectors')
        if plotted_any:
            # With no labelled artists, legend() only emits a warning.
            ax.legend()

        # Dashed coordinate axes through the origin.
        ax.plot([-1, 1], [0, 0], [0, 0], 'k--', alpha=0.3)
        ax.plot([0, 0], [-1, 1], [0, 0], 'k--', alpha=0.3)
        ax.plot([0, 0], [0, 0], [-1, 1], 'k--', alpha=0.3)

        self._finalize(fig, save_path)

    def plot_power_spectrum(self, cl_dict: Dict[str, np.ndarray],
                            save_path: Optional[str] = None):
        """Plot ``ell(ell+1)C_ell/2pi`` for each labelled spectrum, from ell = 2.

        Args:
            cl_dict: Mapping of label -> C_ell array indexed by ell.
            save_path: Optional output file path.
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        for label, cl in cl_dict.items():
            ell, dl = self._dl_from_cl(cl)
            ax.plot(ell, dl, 'o-', label=label, markersize=8)

        ax.set_xlabel(r'$\ell$')
        ax.set_ylabel(r'$\ell(\ell+1)C_\ell/2\pi$')
        ax.set_title('CMB Power Spectrum')
        ax.grid(True, alpha=0.3)

        # Size the ell axis to the *longest* spectrum, not just the first one.
        upper = self._ell_axis_upper(cl_dict)
        if upper is not None:
            ax.legend()
            ax.set_xlim(1.5, upper)

        self._finalize(fig, save_path)

    def plot_comparison(self, uncorrected_data: Dict, corrected_data: Dict,
                        save_path: Optional[str] = None):
        """Compare uncorrected vs corrected results in one figure.

        Layout: the two sky maps side by side on top, then a bar chart of the
        quadrupole-octopole alignment angles, then a bar chart of
        ``corrected_data['attractor_metrics']`` if that entry is present.

        Args:
            uncorrected_data: Must provide ``['map']`` and
                ``['alignment']['alignment_angle']``.
            corrected_data: Same keys, plus an optional
                ``'attractor_metrics'`` dict.
            save_path: Optional output file path.
        """
        fig = plt.figure(figsize=(20, 15))
        gs = gridspec.GridSpec(3, 2, height_ratios=[2, 1, 1], figure=fig)

        # Placeholder axes for the maps. healpy replaces them later, once
        # tight_layout() has settled every subplot position.
        map_axes = (fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1]))

        ax3 = fig.add_subplot(gs[1, :])
        angles = [uncorrected_data['alignment']['alignment_angle'],
                  corrected_data['alignment']['alignment_angle']]
        ax3.bar(['Uncorrected', 'Corrected'], angles, color=['blue', 'red'])
        ax3.set_ylabel('Alignment Angle (degrees)')
        ax3.set_title('Quadrupole-Octopole Alignment')
        ax3.set_ylim(0, 90)
        for i, angle in enumerate(angles):
            ax3.text(i, angle + 5, f'{angle:.1f}°', ha='center')

        ax4 = fig.add_subplot(gs[2, :])
        metrics = corrected_data.get('attractor_metrics', {})
        if metrics:
            labels = list(metrics.keys())
            values = [self._metric_to_scalar(val) for val in metrics.values()]
            x = np.arange(len(labels))
            ax4.bar(x, values, color='green')
            ax4.set_xticks(x)
            ax4.set_xticklabels(labels, rotation=45, ha='right')
            ax4.set_ylabel('Value')
            ax4.set_title('Attractor Correction Metrics')

        # Lay out while every axes is still a gridspec subplot. Mollweide axes
        # added by healpy are not tight_layout-compatible and would stay put
        # while the bar panels moved.
        fig.tight_layout()

        panels = zip(map_axes,
                     (uncorrected_data['map'], corrected_data['map']),
                     ("Uncorrected CMB", "Corrected CMB"))
        for slot, sky_map, map_title in panels:
            plt.sca(slot)
            # hold=True: healpy reads the current axes' position, removes those
            # axes and draws the map in their place.
            hp.mollview(sky_map, title=map_title, cmap='RdBu_r', hold=True)

        self._finalize(fig, save_path)

    def plot_phase_evolution(self, phase_history: List[Dict[str, float]],
                             save_path: Optional[str] = None):
        """Plot alignment angle and phase coherence against iteration index.

        Args:
            phase_history: One dict per iteration. Missing
                ``'alignment_angle'`` / ``'phase_coherence'`` keys plot as 0.
            save_path: Optional output file path.
        """
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))

        iterations = list(range(len(phase_history)))
        alignment_angles = [p.get('alignment_angle', 0) for p in phase_history]
        phase_coherences = [p.get('phase_coherence', 0) for p in phase_history]

        ax1.plot(iterations, alignment_angles, 'b-o', linewidth=2, markersize=6)
        ax1.set_xlabel('Iteration')
        ax1.set_ylabel('Alignment Angle (degrees)')
        ax1.set_title('Evolution of Quadrupole-Octopole Alignment')
        ax1.grid(True, alpha=0.3)
        ax1.set_ylim(0, 90)

        ax2.plot(iterations, phase_coherences, 'r-s', linewidth=2, markersize=6)
        ax2.set_xlabel('Iteration')
        ax2.set_ylabel('Phase Coherence')
        ax2.set_title('Evolution of Phase Coherence')
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(0, 1)

        fig.tight_layout()
        self._finalize(fig, save_path)