#!/usr/bin/env python3
"""
Eolisa Space - Sgr A* Image Enhancement Pipeline
=================================================
Author: Eolisa Space Research Division
Contact: sentinelalpha@eolisaspace.com
License: MIT License
DOI: 10.5281/zenodo.16511064

This module implements the core image enhancement and feature extraction
pipeline used in our wormhole signature analysis of Sagittarius A*.

Key Components:
- EHT image reconstruction enhancement
- Morphological feature extraction
- Statistical significance testing
- Adversarial prior testing

Dependencies:
    numpy>=1.21.0
    scipy>=1.7.0
    scikit-image>=0.19.0
    astropy>=5.0
    matplotlib>=3.5.0
"""

import numpy as np
from scipy import ndimage, optimize
from scipy.stats import norm, chi2
from skimage import filters, morphology, measure
from astropy.io import fits
from typing import Tuple, Dict, List, Optional
import warnings

# Suppress RuntimeWarnings (for example NumPy divide-by-zero / invalid-value
# warnings from the ratio computations in this module).
# NOTE(review): this filter is installed at import time and is process-wide, so
# it also silences RuntimeWarnings raised by any other code running in a
# process that imports this module. Consider scoping it with
# `warnings.catch_warnings()` or `np.errstate(...)` around the specific
# computations instead.
warnings.filterwarnings('ignore', category=RuntimeWarning)


class ImageEnhancementPipeline:
    """
    Core pipeline for enhancing EHT images and extracting wormhole signatures.

    This class implements the enhancement methods described in Section 3.2
    of the manuscript, including adversarial testing (Section 3.6).

    Every geometric measurement (radial profiles, annuli, azimuthal sectors)
    is taken about the integer pixel ``(shape[0] // 2, shape[1] // 2)``,
    which is treated as the image centre.
    """

    def __init__(self,
                 pixel_scale: float = 2.0,  # μas per pixel
                 enhancement_strength: float = 1.0,
                 adversarial_mode: bool = False):
        """
        Initialize the enhancement pipeline.

        Parameters
        ----------
        pixel_scale : float
            Pixel scale in microarcseconds per pixel; must be positive.
        enhancement_strength : float
            Gain applied to the multi-scale detail layers
            (0.0 = none, 1.0 = default, 2.0 = strong).
        adversarial_mode : bool
            If True, ``apply_enhancement`` applies anti-wormhole priors
            (smoothing + azimuthal averaging) instead of structure enhancement.

        Raises
        ------
        ValueError
            If ``pixel_scale`` is not positive (it is used as a divisor when
            converting μas to pixels).
        """
        if not pixel_scale > 0:
            raise ValueError(f"pixel_scale must be positive, got {pixel_scale!r}")
        self.pixel_scale = pixel_scale
        self.enhancement_strength = enhancement_strength
        self.adversarial_mode = adversarial_mode

    # ------------------------------------------------------------------
    # Geometry helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _polar_grid(shape: Tuple[int, ...]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """
        Pixel offsets and polar coordinates about the integer centre ``shape // 2``.

        Returns
        -------
        dx, dy, r, theta : np.ndarray
            ``dx`` has shape (1, W) and ``dy`` shape (H, 1) (open grids);
            ``r`` and ``theta = arctan2(dy, dx)`` have shape (H, W), with
            ``theta`` in (-pi, pi].
        """
        cy, cx = shape[0] // 2, shape[1] // 2
        y, x = np.ogrid[:shape[0], :shape[1]]
        dx = x - cx
        dy = y - cy
        r = np.sqrt(dx ** 2 + dy ** 2)
        theta = np.arctan2(dy, dx)
        return dx, dy, r, theta

    def _radial_profile(self, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """
        Azimuthally averaged profile in 1-pixel integer-radius bins.

        Returns
        -------
        profile : np.ndarray
            ``profile[k]`` is the mean of pixels with ``floor(r) == k``
            (NaN for a bin that contains no pixels).
        r_int : np.ndarray
            Per-pixel integer radius map used as the bin labels.
        """
        _, _, r, _ = self._polar_grid(image.shape)
        r_int = r.astype(int)
        profile = ndimage.mean(image, labels=r_int, index=np.arange(r_int.max() + 1))
        return np.asarray(profile), r_int

    # ------------------------------------------------------------------
    # I/O
    # ------------------------------------------------------------------

    def load_eht_image(self, fits_path: str) -> Tuple[np.ndarray, dict]:
        """
        Load an EHT FITS image from the primary HDU.

        Parameters
        ----------
        fits_path : str
            Path to FITS file

        Returns
        -------
        image : np.ndarray
            2D float64 image array, in the units stored in the file
            (BUNIT is not checked here).
        metadata : dict
            ``'frequency'`` (Hz), ``'beam_major'`` / ``'beam_minor'`` (arcsec),
            ``'observation_date'``, and ``'pixel_scale'`` (μas/pixel, taken
            from this pipeline rather than from the header).

        Raises
        ------
        ValueError
            If the primary HDU contains no data, or the data is not 2D after
            removing length-1 axes.
        """
        with fits.open(fits_path) as hdul:
            data = hdul[0].data
            if data is None:
                raise ValueError(f"No image data in the primary HDU of {fits_path!r}")
            # Copy into a native-endian float64 array while the file is open,
            # so the result does not reference the (possibly memory-mapped)
            # file after it is closed.
            image = np.array(data, dtype=np.float64).squeeze()
            header = hdul[0].header

        if image.ndim != 2:
            raise ValueError(f"Expected a 2D image after squeezing, got shape {image.shape}")

        metadata = {
            'frequency': header.get('FREQ', 230e9),  # Hz
            'beam_major': self._beam_axis_arcsec(header, 'BMAJ'),
            'beam_minor': self._beam_axis_arcsec(header, 'BMIN'),
            'observation_date': header.get('DATE-OBS', '2017-04-07'),
            'pixel_scale': self.pixel_scale
        }

        return image, metadata

    @staticmethod
    def _beam_axis_arcsec(header, key: str, default_arcsec: float = 20e-6) -> float:
        """
        Return beam axis ``key`` (BMAJ/BMIN) in arcsec.

        Radio FITS images conventionally store BMAJ/BMIN in degrees, so header
        values are converted (x3600). The default, 20e-6 arcsec (= 20 μas), is
        used when the keyword is absent, keeping both paths in the same unit.
        """
        value = header.get(key)
        if value is None:
            return default_arcsec
        return float(value) * 3600.0

    # ------------------------------------------------------------------
    # Enhancement
    # ------------------------------------------------------------------

    def apply_enhancement(self, image: np.ndarray) -> Tuple[np.ndarray, Dict[str, float]]:
        """
        Apply multi-scale enhancement to reveal fine structure.

        This implements the methodology described in Section 3.2:
        local histogram equalisation, multi-scale detail boosting, then either
        structure enhancement (default) or adversarial smoothing.

        Parameters
        ----------
        image : np.ndarray
            Input EHT image (2D)

        Returns
        -------
        enhanced : np.ndarray
            Enhanced image
        diagnostics : dict
            ``'snr_improvement'`` (SNR(enhanced) / SNR(input), both from
            ``_compute_snr``), ``'feature_contrast'`` (std ratio), and the
            ``'adversarial_mode'`` flag.
            NOTE(review): the ratios are inf/NaN when the input's SNR or std
            is zero.
        """
        # Step 1: Adaptive histogram equalization
        enhanced = self._adaptive_equalization(image)

        # Step 2: Multi-scale decomposition
        scales = [2, 4, 8, 16]  # FWHM in pixels
        multi_scale = self._multi_scale_decomposition(enhanced, scales)

        # Step 3: Structure enhancement
        if self.adversarial_mode:
            # Apply anti-wormhole prior: smooth out multi-lobed structures
            enhanced = self._adversarial_smoothing(multi_scale)
        else:
            # Standard enhancement
            enhanced = self._structure_enhancement(multi_scale)

        # Compute diagnostics
        snr_original = self._compute_snr(image)
        snr_enhanced = self._compute_snr(enhanced)

        diagnostics = {
            'snr_improvement': snr_enhanced / snr_original,
            'feature_contrast': np.std(enhanced) / np.std(image),
            'adversarial_mode': self.adversarial_mode
        }

        return enhanced, diagnostics

    def _adaptive_equalization(self, image: np.ndarray) -> np.ndarray:
        """
        Local histogram equalisation on a disk neighbourhood; output in [0, 1].

        The image is min-max normalised (ignoring NaNs) to 8-bit and passed to
        ``skimage.filters.rank.equalize`` with a disk of radius
        ``max(1, H // 8)``. Constant images and NaN pixels map to 0 before the
        8-bit conversion, because casting NaN to ``uint8`` is undefined.
        """
        lo = np.nanmin(image)
        span = np.nanmax(image) - lo
        if np.isfinite(span) and span > 0:
            img_norm = (image - lo) / span
        else:
            # No dynamic range to normalise (constant or all-NaN image).
            img_norm = np.zeros(image.shape)
        img_norm = np.nan_to_num(img_norm, nan=0.0)

        # A radius of 0 would degenerate to a single-pixel neighbourhood.
        kernel_size = max(1, int(image.shape[0] / 8))
        enhanced = filters.rank.equalize(
            (img_norm * 255).astype(np.uint8),
            morphology.disk(kernel_size)
        ) / 255.0

        return enhanced

    def _multi_scale_decomposition(self, image: np.ndarray, scales: List[int]) -> np.ndarray:
        """
        Boost fine detail at several Gaussian scales (unsharp-mask style).

        For each entry of ``scales`` (a Gaussian FWHM in pixels), the detail
        layer ``image - gaussian_filter(image, sigma)`` is scaled by
        ``enhancement_strength`` and all layers are added back to ``image``.
        Everything stays at full resolution (no pyramid downsampling).
        """
        components = []
        for scale in scales:
            sigma = scale / (2 * np.sqrt(2 * np.log(2)))  # FWHM to sigma
            smoothed = ndimage.gaussian_filter(image, sigma)
            detail = image - smoothed
            components.append(detail * self.enhancement_strength)

        # Reconstruct with enhanced details
        enhanced = image + np.sum(components, axis=0)
        return enhanced

    def _structure_enhancement(self, image: np.ndarray) -> np.ndarray:
        """Add Sobel edge magnitude (x0.5) and bright-ridge Sato response (x0.3)."""
        edges = filters.sobel(image)
        ridges = filters.sato(image, sigmas=range(1, 5), black_ridges=False)

        enhanced = image + 0.5 * edges + 0.3 * ridges
        return enhanced

    def _adversarial_smoothing(self, image: np.ndarray) -> np.ndarray:
        """
        Apply adversarial smoothing to test feature robustness (Section 3.6).

        This applies anti-wormhole priors: a sigma=3 px Gaussian blur removes
        fine structure, then 30% of the result is replaced by its azimuthal
        average about the image centre to promote circular symmetry.
        """
        smoothed = ndimage.gaussian_filter(image, sigma=3.0)

        # Map each pixel to the mean of its integer-radius ring.
        radial_profile, r_int = self._radial_profile(smoothed)
        azimuthal_avg = radial_profile[r_int]

        # Blend: 70% smoothed + 30% azimuthal (strong anti-wormhole prior)
        return 0.7 * smoothed + 0.3 * azimuthal_avg

    def _compute_snr(self, image: np.ndarray) -> float:
        """
        Peak-to-noise ratio: image maximum over the std of the four corner patches.

        Each corner patch is ``max(1, H // 8)`` rows by ``max(1, W // 8)``
        columns. Returns ``np.inf`` when the corner noise estimate is zero.
        """
        signal = np.max(image)
        # The lower bound of 1 matters: with a size of 0, ``image[-0:]`` is the
        # whole array, which would silently turn the noise estimate into the
        # std of the entire frame.
        rows = max(1, image.shape[0] // 8)
        cols = max(1, image.shape[1] // 8)
        corners = [
            image[:rows, :cols],
            image[:rows, -cols:],
            image[-rows:, :cols],
            image[-rows:, -cols:]
        ]
        noise = np.std(np.concatenate([c.ravel() for c in corners]))
        return signal / noise if noise > 0 else np.inf

    # ------------------------------------------------------------------
    # Feature extraction
    # ------------------------------------------------------------------

    def extract_features(self, image: np.ndarray) -> Dict[str, float]:
        """
        Extract quantitative morphological features.

        Returns features described in Section 4.1:
        - Tri-lobed asymmetry (m=3)
        - Fractal dimension
        - Centroid confinement
        - Multi-track topology
        - Ring diameter

        Parameters
        ----------
        image : np.ndarray
            Enhanced image

        Returns
        -------
        features : dict
            Keys ``trilobed_asymmetry``, ``fractal_dimension``,
            ``centroid_variation`` and ``ring_diameter`` (each with a matching
            ``*_err`` key), plus the integer ``num_image_tracks``.
        """
        features = {}

        # 1. Tri-lobed asymmetry (m=3 Fourier mode)
        features['trilobed_asymmetry'], features['trilobed_asymmetry_err'] = \
            self._compute_fourier_asymmetry(image, mode=3)

        # 2. Fractal dimension
        features['fractal_dimension'], features['fractal_dimension_err'] = \
            self._compute_fractal_dimension(image)

        # 3. Centroid confinement
        features['centroid_variation'], features['centroid_variation_err'] = \
            self._compute_centroid_confinement(image)

        # 4. Image track count (proxy for multi-track topology)
        features['num_image_tracks'] = self._count_image_tracks(image)

        # 5. Ring diameter
        features['ring_diameter'], features['ring_diameter_err'] = \
            self._measure_ring_geometry(image)

        return features

    def _compute_fourier_asymmetry(self, image: np.ndarray, mode: int = 3,
                                   n_bootstrap: int = 1000) -> Tuple[float, float]:
        """
        Normalised amplitude of azimuthal Fourier mode ``mode`` in an annulus.

        The annulus is 20 < r < 30 μas in *radius* (i.e. 40-60 μas in
        diameter), converted to pixels with ``self.pixel_scale``. The amplitude
        is ``|sum(I * exp(-i m theta))| / sum(I)`` over the annulus pixels;
        for m=3 it measures tri-lobed structure strength.

        Parameters
        ----------
        image : np.ndarray
            2D image.
        mode : int
            Azimuthal mode number m.
        n_bootstrap : int
            Number of bootstrap resamples of the annulus pixels used for the
            uncertainty; draws come from NumPy's global random state.

        Returns
        -------
        amplitude, amplitude_err : float
            ``(nan, nan)`` when the annulus contains no pixels.
        """
        _, _, r, theta = self._polar_grid(image.shape)

        ring_mask = (r > 20 / self.pixel_scale) & (r < 30 / self.pixel_scale)
        intensity_ring = image[ring_mask]
        n_pix = intensity_ring.size
        if n_pix == 0:
            return float('nan'), float('nan')

        # The phase factor is fixed per pixel; compute it once, not per resample.
        phase = np.exp(-1j * mode * theta[ring_mask])
        m_amplitude = np.abs(np.sum(intensity_ring * phase)) / np.sum(intensity_ring)

        # Bootstrap uncertainty: resample annulus pixels with replacement.
        boot = np.empty(n_bootstrap)
        for b in range(n_bootstrap):
            idx = np.random.choice(n_pix, n_pix, replace=True)
            boot[b] = np.abs(np.sum(intensity_ring[idx] * phase[idx])) / np.sum(intensity_ring[idx])

        return m_amplitude, np.std(boot)

    def _compute_fractal_dimension(self, image: np.ndarray) -> Tuple[float, float]:
        """
        Box-counting dimension of the Otsu-thresholded image.

        Higher D_f indicates more complex, irregular structure. See
        ``_box_counting_dimension`` for the estimator and its uncertainty.
        """
        threshold = filters.threshold_otsu(image)
        return self._box_counting_dimension(image > threshold)

    @staticmethod
    def _box_counting_dimension(binary: np.ndarray,
                                scales: Optional[List[int]] = None) -> Tuple[float, float]:
        """
        Estimate D_f as minus the slope of log N(s) vs log s for a 2D boolean mask.

        N(s) is the number of s x s boxes containing any True pixel (trailing
        rows/columns that do not fill a box are dropped). Scales larger than
        the image, or with N(s) == 0, are skipped because log(0) is undefined.

        Returns
        -------
        D_f, D_f_err : float
            ``D_f_err`` is the ordinary-least-squares standard error of the
            slope. It is NaN with fewer than 3 usable scales; both values are
            NaN with fewer than 2.
        """
        if scales is None:
            scales = 2 ** np.arange(1, 6)  # 2, 4, 8, 16, 32 pixels

        h, w = binary.shape
        used_scales, counts = [], []
        for scale in scales:
            coarse_h, coarse_w = h // scale, w // scale
            if coarse_h == 0 or coarse_w == 0:
                continue  # box larger than the image
            coarse = binary[:coarse_h * scale, :coarse_w * scale].reshape(
                coarse_h, scale, coarse_w, scale
            ).any(axis=(1, 3))
            n_boxes = np.count_nonzero(coarse)
            if n_boxes > 0:
                used_scales.append(scale)
                counts.append(n_boxes)

        if len(used_scales) < 2:
            return float('nan'), float('nan')

        log_s = np.log(np.asarray(used_scales, dtype=float))
        log_n = np.log(np.asarray(counts, dtype=float))
        slope, intercept = np.polyfit(log_s, log_n, 1)
        D_f = -slope
        if len(used_scales) < 3:
            return D_f, float('nan')  # two points: exact fit, no residual dof

        # Standard error of the slope (in D_f units), not the residual spread
        # (which is in log-count units).
        residuals = log_n - (slope * log_s + intercept)
        sxx = np.sum((log_s - log_s.mean()) ** 2)
        D_f_err = np.sqrt(np.sum(residuals ** 2) / (len(used_scales) - 2) / sxx)

        return D_f, D_f_err

    def _compute_centroid_confinement(self, image: np.ndarray) -> Tuple[float, float]:
        """
        Spread of intensity-weighted centroid distances over 8 azimuthal sectors.

        Each sector ``[i*pi/4, (i+1)*pi/4)`` (angles about the image centre)
        with positive total weight contributes the distance of its
        intensity-weighted centroid from the centre. Per the manuscript,
        wormholes show reduced centroid variation compared to black holes.

        Returns
        -------
        centroid_variation, centroid_variation_err : float
            Std of those distances (pixels) and that std divided by
            sqrt(number of contributing sectors); ``(nan, nan)`` when no
            sector has positive total weight.
        """
        dx, dy, _, theta = self._polar_grid(image.shape)
        # arctan2 returns angles in (-pi, pi]; wrap into [0, 2*pi) so the
        # sectors below tile the full circle. Without this, sectors 5-7 are
        # always empty and the dy < 0 half of the image is ignored.
        theta = np.mod(theta, 2 * np.pi)

        n_sectors = 8
        sector_centroids = []
        for i in range(n_sectors):
            theta_start = i * 2 * np.pi / n_sectors
            theta_end = (i + 1) * 2 * np.pi / n_sectors
            sector_mask = (theta >= theta_start) & (theta < theta_end)

            weights = np.where(sector_mask, image, 0.0)
            total = np.sum(weights)
            if total > 0:
                sector_centroids.append((np.sum(dx * weights) / total,
                                         np.sum(dy * weights) / total))

        if not sector_centroids:
            return float('nan'), float('nan')

        distances = np.linalg.norm(np.asarray(sector_centroids), axis=1)
        centroid_variation = np.std(distances)
        centroid_variation_err = centroid_variation / np.sqrt(len(sector_centroids))

        return centroid_variation, centroid_variation_err

    def _count_image_tracks(self, image: np.ndarray) -> int:
        """
        Count connected components above 30% of peak with at least 10 pixels.

        Used as a proxy for image tracks; the manuscript's models predict 4
        tracks for wormholes and typically 2 for black holes.
        """
        threshold = 0.3 * np.max(image)
        binary = image > threshold

        labeled, num_features = ndimage.label(binary)
        if num_features == 0:
            return 0

        # Filter by size (remove noise)
        min_size = 10  # pixels
        sizes = ndimage.sum(binary, labeled, range(1, num_features + 1))
        return int(np.count_nonzero(np.asarray(sizes) >= min_size))

    def _measure_ring_geometry(self, image: np.ndarray) -> Tuple[float, float]:
        """
        Ring diameter from the peak of the azimuthally averaged radial profile.

        Returns
        -------
        ring_diameter_uas : float
            ``2 * peak_radius * pixel_scale`` (μas).
        ring_diameter_err : float
            Spread (in radius bins) of profile values above half the peak,
            times ``pixel_scale / sqrt(12)``; NaN if the peak is not positive.
        """
        radial_profile, _ = self._radial_profile(image)

        # nanargmax: a radius bin with no pixels has a NaN mean, and plain
        # argmax would select that bin.
        peak_idx = int(np.nanargmax(radial_profile))
        ring_diameter_uas = 2 * peak_idx * self.pixel_scale

        half_max = radial_profile[peak_idx] / 2
        above_half = np.flatnonzero(radial_profile > half_max)
        if above_half.size == 0:
            # Only possible when the peak is <= 0; no meaningful width.
            return ring_diameter_uas, float('nan')
        width_pixels = np.ptp(above_half)
        # NOTE(review): width_pixels is a spread in *radius*; since the
        # diameter is 2 * radius, a radius uncertainty would propagate as 2x.
        # Confirm which convention the manuscript intends.
        ring_diameter_err = width_pixels * self.pixel_scale / np.sqrt(12)  # uniform distribution

        return ring_diameter_uas, ring_diameter_err


class BayesianModelComparison:
    """
    Bayesian comparison between black hole and wormhole models.

    Implements the methodology in Section 4.2.
    """

    def __init__(self, n_samples: int = 10000):
        """
        Initialize Bayesian sampler.

        Parameters
        ----------
        n_samples : int
            Number of posterior samples
        """
        # NOTE(review): n_samples is stored but not read by any method of
        # this class.
        self.n_samples = n_samples

    def compute_bayes_factor(self,
                           observed_features: Dict[str, float],
                           bh_predictions: Dict[str, Tuple[float, float]],
                           wh_predictions: Dict[str, Tuple[float, float]]) -> Tuple[float, float]:
        """
        Compute Bayes factor B_WH/BH from per-feature Gaussian likelihoods.

        For each feature ``f`` present in ``observed_features`` (ignoring keys
        ending in ``'_err'``) AND predicted by both models, the likelihood
        under a model is ``N(obs_f | mean_f, sqrt(err_f**2 + std_f**2))``.
        Here ``err_f`` is ``observed_features[f + '_err']`` if present, else
        ``0.1 * obs_f``. Features without a prediction in both models (e.g.
        ``num_image_tracks``) are skipped rather than raising ``KeyError``.

        Parameters
        ----------
        observed_features : dict
            Observed feature values (optionally with ``*_err`` companions)
        bh_predictions : dict
            Black hole model predictions (mean, std) for each feature
        wh_predictions : dict
            Wormhole model predictions (mean, std) for each feature

        Returns
        -------
        log_bayes_factor : float
            log10(B_WH/BH); 0.0 when no feature can be compared.
        log_bayes_factor_err : float
            Heuristic uncertainty ``0.5 / sqrt(N_compared)``; inf when no
            feature can be compared.
        """
        log_likelihood_bh = 0.0
        log_likelihood_wh = 0.0
        n_compared = 0

        for feature_name, obs_value in observed_features.items():
            if feature_name.endswith('_err'):
                continue  # error columns are consumed alongside their value
            if feature_name not in bh_predictions or feature_name not in wh_predictions:
                continue  # no model prediction to compare against

            obs_err = observed_features.get(feature_name + '_err', 0.1 * obs_value)

            # Black hole likelihood
            bh_mean, bh_std = bh_predictions[feature_name]
            total_std_bh = np.sqrt(obs_err**2 + bh_std**2)
            log_likelihood_bh += norm.logpdf(obs_value, bh_mean, total_std_bh)

            # Wormhole likelihood
            wh_mean, wh_std = wh_predictions[feature_name]
            total_std_wh = np.sqrt(obs_err**2 + wh_std**2)
            log_likelihood_wh += norm.logpdf(obs_value, wh_mean, total_std_wh)

            n_compared += 1

        if n_compared == 0:
            return 0.0, float('inf')

        # Log Bayes factor: natural-log likelihood ratio converted to log10.
        ln_bayes_factor = log_likelihood_wh - log_likelihood_bh
        log10_bayes_factor = ln_bayes_factor / np.log(10)

        # NOTE(review): a fixed heuristic (0.5 dex / sqrt(N)), not propagated
        # from the feature uncertainties.
        log10_bayes_factor_err = 0.5 / np.sqrt(n_compared)

        return log10_bayes_factor, log10_bayes_factor_err

    def significance_test(self,
                         observed: float,
                         expected: float,
                         uncertainty: float) -> float:
        """
        Compute significance of deviation in units of sigma.

        Parameters
        ----------
        observed : float
            Observed value
        expected : float
            Expected value under null hypothesis
        uncertainty : float
            Combined uncertainty (a zero value divides by zero)

        Returns
        -------
        n_sigma : float
            ``|observed - expected| / uncertainty``
        """
        return np.abs(observed - expected) / uncertainty


# Example usage and testing
if __name__ == "__main__":
    print("Eolisa Space - Image Enhancement Pipeline")
    print("==========================================")
    print()

    # Seed NumPy's global random state. It drives both the synthetic noise
    # added to the validation image and the bootstrap resampling performed
    # inside extract_features, so seeding makes the whole run reproducible.
    np.random.seed(42)

    # Create pipeline instances
    print("1. Standard enhancement mode:")
    pipeline_standard = ImageEnhancementPipeline(
        pixel_scale=2.0,
        enhancement_strength=1.0,
        adversarial_mode=False
    )
    print("   ✓ Standard pipeline initialized")

    print("\n2. Adversarial testing mode:")
    pipeline_adversarial = ImageEnhancementPipeline(
        pixel_scale=2.0,
        enhancement_strength=1.0,
        adversarial_mode=True
    )
    print("   ✓ Adversarial pipeline initialized")

    # Create validation test image (simulated EHT observation)
    print("\n3. Creating validation test image...")
    image_size = 128
    center = image_size // 2
    y, x = np.ogrid[:image_size, :image_size]

    # Ring + tri-lobed asymmetry (wormhole-like)
    r = np.sqrt((x - center)**2 + (y - center)**2)
    theta = np.arctan2(y - center, x - center)

    ring_radius = 25  # pixels
    ring_width = 5
    ring = np.exp(-((r - ring_radius) ** 2) / (2 * ring_width ** 2))

    # Add m=3 asymmetry
    asymmetry = 1 + 0.2 * np.cos(3 * theta)
    test_image = ring * asymmetry + 0.05 * np.random.randn(image_size, image_size)

    print("   ✓ Validation image created (ring + m=3 asymmetry + noise)")

    # Test enhancement
    print("\n4. Applying standard enhancement...")
    enhanced_standard, diag_standard = pipeline_standard.apply_enhancement(test_image)
    print(f"   SNR improvement: {diag_standard['snr_improvement']:.2f}x")
    print(f"   Feature contrast: {diag_standard['feature_contrast']:.2f}x")

    print("\n5. Applying adversarial enhancement (anti-wormhole prior)...")
    enhanced_adversarial, diag_adversarial = pipeline_adversarial.apply_enhancement(test_image)
    print(f"   SNR improvement: {diag_adversarial['snr_improvement']:.2f}x")
    print(f"   Feature contrast: {diag_adversarial['feature_contrast']:.2f}x")

    # Extract features
    print("\n6. Extracting morphological features...")
    features_standard = pipeline_standard.extract_features(enhanced_standard)
    features_adversarial = pipeline_adversarial.extract_features(enhanced_adversarial)

    print("\n   Standard mode:")
    print(f"   - Tri-lobed asymmetry: {features_standard['trilobed_asymmetry']:.4f} ± {features_standard['trilobed_asymmetry_err']:.4f}")
    print(f"   - Fractal dimension: {features_standard['fractal_dimension']:.3f} ± {features_standard['fractal_dimension_err']:.3f}")
    print(f"   - Centroid variation: {features_standard['centroid_variation']:.2f} ± {features_standard['centroid_variation_err']:.2f} pixels")
    print(f"   - Image tracks: {features_standard['num_image_tracks']}")

    print("\n   Adversarial mode (features should persist if real):")
    print(f"   - Tri-lobed asymmetry: {features_adversarial['trilobed_asymmetry']:.4f} ± {features_adversarial['trilobed_asymmetry_err']:.4f}")
    print(f"   - Fractal dimension: {features_adversarial['fractal_dimension']:.3f} ± {features_adversarial['fractal_dimension_err']:.3f}")
    print(f"   - Centroid variation: {features_adversarial['centroid_variation']:.2f} ± {features_adversarial['centroid_variation_err']:.2f} pixels")
    print(f"   - Image tracks: {features_adversarial['num_image_tracks']}")

    # Bayesian comparison
    print("\n7. Bayesian model comparison example:")
    bayesian = BayesianModelComparison()

    # Test predictions
    bh_predictions = {
        'trilobed_asymmetry': (0.08, 0.02),  # BH prediction lower
        'fractal_dimension': (1.45, 0.10),
        'centroid_variation': (15.0, 3.0)
    }

    wh_predictions = {
        'trilobed_asymmetry': (0.15, 0.03),  # WH prediction higher
        'fractal_dimension': (1.75, 0.10),
        'centroid_variation': (8.0, 2.0)
    }

    # Pass only the features that both models predict (plus their '_err'
    # companions); num_image_tracks and ring_diameter have no model
    # predictions here, and an unmatched feature name would otherwise be
    # looked up in the prediction dicts.
    predicted = set(bh_predictions) & set(wh_predictions)
    compared_features = {
        name: value
        for name, value in features_standard.items()
        if name in predicted
        or (name.endswith('_err') and name[:-len('_err')] in predicted)
    }

    log_B, log_B_err = bayesian.compute_bayes_factor(
        compared_features, bh_predictions, wh_predictions
    )

    print(f"   log10(B_WH/BH) = {log_B:.2f} ± {log_B_err:.2f}")

    if log_B > 1.0:
        print("   → Strong evidence for wormhole (B > 10)")
    elif log_B > 0.5:
        print("   → Moderate evidence for wormhole (B > 3)")
    else:
        print("   → Inconclusive")

    print("\n" + "="*60)
    print("Pipeline test completed successfully!")
    print("="*60)
    print("\nFor actual EHT data analysis:")
    print("  pipeline = ImageEnhancementPipeline()")
    print("  image, metadata = pipeline.load_eht_image('sgra_eht_image.fits')")
    print("  enhanced, diagnostics = pipeline.apply_enhancement(image)")
    print("  features = pipeline.extract_features(enhanced)")
    print("\nSee documentation for full API reference.")
