"""
End-to-End Experimental Simulation (Monte Carlo)
=================================================

Objective: Simulate the COMPLETE torsion balance experiment from start to finish,
including all noise sources, systematic errors, and data analysis pipeline.

The Monte Carlo results estimate how reliably the signal can be detected and
quantify the false-positive and false-negative rates.

Key Features:
- Realistic noise realizations (thermal, seismic, electronic)
- Systematic errors (fiber drift, temperature fluctuations, field inhomogeneity)
- Data acquisition artifacts (sampling, quantization)
- Statistical analysis (signal extraction, confidence intervals)
- Monte Carlo: 1000+ experimental runs to get probability distributions
"""

import numpy as np
import json
from dataclasses import dataclass
from typing import List, Tuple, Dict
import scipy.signal as signal
from scipy import stats

# Physical Constants
K_B = 1.380649e-23  # Boltzmann constant (J/K)

# Experimental Parameters (from previous optimizations)
# NOTE(review): with M_TEST and R_TEST below, I = m * r**2 = 1e-7 kg·m² and
# sqrt(FIBER_KAPPA / I) / (2π) ≈ 1.7e-3 Hz, which does not match F0 = 0.055 Hz.
# Confirm which of these values is authoritative.
T_CRYO = 4.0  # K (temperature used in the thermal-noise PSD)
M_TEST = 1e-3  # kg (test mass)
R_TEST = 0.01  # m (lever arm of the test mass)
FIBER_KAPPA = 1.2e-11  # N·m/rad (quartz, 50 μm)
Q_FACTOR = 1e7  # mechanical quality factor (dimensionless)
F0 = 0.055  # Hz (resonant frequency)

# Signal Parameters
# A_TARGET is not referenced by the code visible in this file.
A_TARGET = 1e-10  # m/s² (target acceleration to detect)
A_SIGNAL = 9e-11  # m/s² (90% of threshold, from acceleration profile)
MEASUREMENT_DURATION = 3600  # seconds (1 hour)
SAMPLING_RATE = 10  # Hz (data acquisition rate)

# Noise Parameters
# THERMAL_FLUCTUATION is not referenced by the code visible in this file.
THETA_NOISE_DENSITY = 1e-15  # rad/√Hz (total noise floor)
SEISMIC_AMPLITUDE = 1e-7  # m/s²/√Hz at 1 Hz
THERMAL_FLUCTUATION = 0.01  # K (temperature stability)

# Systematic Error Parameters
# MASS_POSITION_UNCERTAINTY is not referenced by the code visible in this file.
FIBER_KAPPA_UNCERTAINTY = 0.05  # 5% uncertainty in torsion constant
MASS_POSITION_UNCERTAINTY = 1e-4  # m (0.1 mm)
FIELD_GRADIENT_UNCERTAINTY = 0.02  # 2% uncertainty in magnetic gradient
TEMPERATURE_DRIFT = 0.001  # K/hour


@dataclass
class ExperimentalRun:
    """Results from a single experimental run.

    Instances are built by ``TorsionBalanceExperiment.run_experiment`` and
    bundle the scalar analysis outputs with the time series they came from.
    """
    # Acceleration estimate (m/s²) from the mean filtered angle in the hold window.
    measured_acceleration: float
    # Spread of the filtered hold-window data, converted to m/s².
    uncertainty: float
    # |measured_acceleration| / uncertainty (0.0 when the uncertainty is zero).
    snr: float
    # Two-sided Gaussian coverage at the SNR: Φ(snr) − Φ(−snr).
    detection_confidence: float
    # SNR exceeded the detection threshold and a signal was injected.
    is_true_positive: bool
    # SNR exceeded the detection threshold although no signal was injected.
    is_false_positive: bool
    # Quantized angular-displacement samples as acquired (rad).
    raw_data: np.ndarray
    # raw_data after the analysis low-pass filter.
    filtered_data: np.ndarray


class TorsionBalanceExperiment:
    """Full end-to-end simulation of the torsion balance experiment.

    Each instance represents one configuration: either a run with the
    trapezoidal signal injected, or a null run without it.  One call to
    ``run_experiment`` synthesizes a time series (signal, systematic scale
    errors, thermal / seismic / electronic noise, ADC quantization) and feeds
    it through the analysis pipeline.

    All random draws use NumPy's global ``np.random`` state; call
    ``np.random.seed`` beforehand for reproducible runs.
    """

    # Analysis low-pass filter: Butterworth order and cutoff frequency (Hz).
    _LOWPASS_ORDER = 4
    _LOWPASS_CUTOFF_HZ = 0.1

    # A run counts as a detection when its SNR exceeds this value (3σ).
    _DETECTION_SNR_THRESHOLD = 3.0

    def __init__(self, true_signal_present: bool = True):
        """
        Args:
            true_signal_present: If True, inject real signal. If False, null test.
        """
        self.true_signal_present = true_signal_present
        self.kappa = FIBER_KAPPA
        self.Q = Q_FACTOR
        self.f0 = F0
        self.omega_0 = 2 * np.pi * self.f0

        # Moment of inertia of a point mass on the lever arm.  Stored for
        # reference; no other method of this class reads it.
        self.I = M_TEST * R_TEST**2

    def generate_signal(self, time_array: np.ndarray) -> np.ndarray:
        """
        Generate the true signal (if present).

        The acceleration follows a trapezoid: a linear ramp up over the first
        20% of MEASUREMENT_DURATION, a plateau at A_SIGNAL for the next 60%,
        then a linear ramp down.  It is converted to a static angular
        deflection via θ = (m * r * a) / κ.

        Args:
            time_array: Sample times in seconds (any numeric dtype).

        Returns:
            Float array of angular displacement (rad), same shape as
            ``time_array``; all zeros for a null experiment.
        """
        # Work in float so an integer time axis cannot truncate the signal.
        t = np.asarray(time_array, dtype=float)

        if not self.true_signal_present:
            return np.zeros_like(t)

        ramp_time = MEASUREMENT_DURATION * 0.2
        hold_time = MEASUREMENT_DURATION * 0.6
        ramp_down_start = ramp_time + hold_time

        # Vectorized piecewise profile (same branch order as a per-sample
        # if / elif / else chain).
        acceleration = np.where(
            t < ramp_time,
            A_SIGNAL * (t / ramp_time),
            np.where(
                t < ramp_down_start,
                A_SIGNAL,
                A_SIGNAL * (1 - (t - ramp_down_start) / ramp_time),
            ),
        )

        # Convert acceleration to angular displacement: θ = (m * r * a) / κ
        torque = M_TEST * R_TEST * acceleration
        return torque / self.kappa

    def generate_thermal_noise(self, time_array: np.ndarray) -> np.ndarray:
        """
        Generate thermal (Brownian) noise.

        Complex Gaussian noise is drawn per rFFT bin, weighted by a
        resonance-shaped spectrum centred on ``omega_0``, and transformed back
        to the time domain.

        Args:
            time_array: Uniformly spaced sample times (s); at least 2 samples.

        Returns:
            Noise time series with ``len(time_array)`` samples.
        """
        dt = time_array[1] - time_array[0]
        n_samples = len(time_array)

        freq = np.fft.rfftfreq(n_samples, dt)
        omega = 2 * np.pi * freq

        # Resonance weighting.
        # NOTE(review): the damping term (omega_0**2 / Q**2) has different
        # units from (omega**2 - omega_0**2)**2; a damped-oscillator response
        # would normally use (omega * omega_0 / Q)**2.  The weighting is also
        # applied to the PSD un-squared below.  Left unchanged so results stay
        # comparable - confirm the intended model.
        lorentzian = 1 / np.sqrt((omega**2 - self.omega_0**2)**2 + (self.omega_0**2 / self.Q**2))

        # Overall thermal noise level
        S_thermal = (4 * K_B * T_CRYO * self.omega_0) / (self.kappa * self.Q)

        # Random complex amplitudes, scaled per bin.
        # NOTE(review): whether the n / (2 * dt) factor gives the intended
        # level depends on the PSD convention (one- vs two-sided) - confirm.
        noise_fft = np.random.normal(0, 1, len(freq)) + 1j * np.random.normal(0, 1, len(freq))
        noise_fft *= np.sqrt(S_thermal * lorentzian * n_samples / (2 * dt))

        # Transform to time domain
        return np.fft.irfft(noise_fft, n_samples)

    def generate_seismic_noise(self, time_array: np.ndarray) -> np.ndarray:
        """
        Generate seismic noise.

        The acceleration amplitude spectrum falls as 1/f², i.e. the PSD falls
        as 1/f⁴, normalized to SEISMIC_AMPLITUDE at 1 Hz.  Frequencies below
        0.001 Hz (including the DC bin) are clamped to 0.001 Hz so the
        spectrum stays finite.

        Args:
            time_array: Uniformly spaced sample times (s); at least 2 samples.

        Returns:
            Angular-displacement noise time series with ``len(time_array)``
            samples.
        """
        dt = time_array[1] - time_array[0]
        n_samples = len(time_array)

        freq = np.fft.rfftfreq(n_samples, dt)

        # Seismic acceleration PSD: S ∝ 1/f⁴, floored at 0.001 Hz
        S_seismic = (SEISMIC_AMPLITUDE**2) * (1.0 / np.maximum(freq, 0.001))**4

        # Convert to angular displacement.
        # Amplitude coupling is θ ≈ (m * r / κ) * a_seismic, squared for a PSD.
        coupling = (M_TEST * R_TEST / self.kappa)**2
        S_theta_seismic = coupling * S_seismic

        # Random complex amplitudes, scaled per bin
        noise_fft = np.random.normal(0, 1, len(freq)) + 1j * np.random.normal(0, 1, len(freq))
        noise_fft *= np.sqrt(S_theta_seismic * n_samples / (2 * dt))

        return np.fft.irfft(noise_fft, n_samples)

    def generate_electronic_noise(self, time_array: np.ndarray) -> np.ndarray:
        """Generate white electronic noise from readout system.

        The per-sample RMS is THETA_NOISE_DENSITY integrated over the Nyquist
        bandwidth (SAMPLING_RATE / 2).
        """
        theta_noise_rms = THETA_NOISE_DENSITY * np.sqrt(SAMPLING_RATE / 2)
        return np.random.normal(0, theta_noise_rms, len(time_array))

    def apply_systematic_errors(self, theta_signal: np.ndarray, time_array: np.ndarray) -> np.ndarray:
        """
        Apply systematic errors to the signal.

        Errors (all multiplicative scale factors on the signal):
        1. Fiber torsion constant uncertainty (one random offset per run)
        2. Temperature drift (linear in time)
        3. Magnetic field gradient drift (random walk)

        The input array is left untouched; a new float array is returned.

        Args:
            theta_signal: Noise-free angular signal (rad).
            time_array: Sample times (s), same length as ``theta_signal``.

        Returns:
            New array holding the distorted signal.
        """
        # Copy first: the caller's array must not be modified, and in-place
        # float scaling would fail on an integer array.
        distorted = np.array(theta_signal, dtype=float)

        # Fiber κ uncertainty (random offset for this run)
        kappa_error = np.random.normal(0, FIBER_KAPPA_UNCERTAINTY)
        distorted *= (1 + kappa_error)

        # Temperature drift (linear drift over measurement)
        temp_drift = TEMPERATURE_DRIFT * (time_array / 3600)  # K
        # Temperature affects κ slightly
        kappa_temp_coeff = 1e-4  # 1/K (typical for quartz)
        distorted *= (1 + kappa_temp_coeff * temp_drift)

        # Magnetic field gradient drift (slow random walk).  The step size is
        # scaled by 1/sqrt(N) so the walk's final spread is about
        # FIELD_GRADIENT_UNCERTAINTY.
        n_samples = len(time_array)
        steps = np.random.normal(0, FIELD_GRADIENT_UNCERTAINTY / np.sqrt(n_samples), n_samples)
        field_drift = np.cumsum(steps)
        distorted *= (1 + field_drift)

        return distorted

    def acquire_data(self, time_array: np.ndarray) -> np.ndarray:
        """
        Simulate the complete data acquisition process.

        Returns raw angular displacement data with all noise and errors.
        """
        # Signal with systematic scale errors applied
        theta_signal = self.generate_signal(time_array)
        theta_signal = self.apply_systematic_errors(theta_signal, time_array)

        # Noise sources (the draw order fixes the random stream for a seed)
        theta_thermal = self.generate_thermal_noise(time_array)
        theta_seismic = self.generate_seismic_noise(time_array)
        theta_electronic = self.generate_electronic_noise(time_array)

        # Total measured signal
        theta_measured = theta_signal + theta_thermal + theta_seismic + theta_electronic

        # ADC quantization: step size of a 16-bit converter spanning 2 μrad
        # (±1 μrad).  Only rounding to the step is modelled; values outside
        # the nominal range are not clipped.
        adc_resolution = 2e-6 / (2**16)  # rad
        return np.round(theta_measured / adc_resolution) * adc_resolution

    def _lowpass_filter(self, raw_data: np.ndarray, time_array: np.ndarray) -> np.ndarray:
        """Return ``raw_data`` after the zero-phase Butterworth low-pass filter."""
        dt = time_array[1] - time_array[0]
        nyquist = 0.5 / dt
        # scipy.signal.butter expects the cutoff as a fraction of Nyquist.
        cutoff = self._LOWPASS_CUTOFF_HZ / nyquist
        b, a = signal.butter(self._LOWPASS_ORDER, cutoff, btype='low')
        return signal.filtfilt(b, a, raw_data)

    def analyze_data(self, raw_data: np.ndarray, time_array: np.ndarray) -> Tuple[float, float, float]:
        """
        Analyze the raw data to extract the signal.

        The data is low-pass filtered, then the mean and standard deviation of
        the hold window (samples from 20% to 80% of the record) are converted
        to acceleration via a = (κ * θ) / (m * r).

        Args:
            raw_data: Measured angular displacement (rad).
            time_array: Uniformly spaced sample times (s).

        Returns:
            (measured_acceleration, uncertainty, snr)
        """
        # Remove high-frequency noise.
        # NOTE(review): with the file's F0 = 0.055 Hz the 0.1 Hz cutoff lies
        # above the resonance, so motion at the resonance is not attenuated.
        filtered_data = self._lowpass_filter(raw_data, time_array)

        # Extract signal during hold period (middle 60% of measurement)
        n_samples = len(time_array)
        start_idx = int(0.2 * n_samples)
        end_idx = int(0.8 * n_samples)
        hold_data = filtered_data[start_idx:end_idx]

        # Measured angular displacement (mean during hold).  The spread is the
        # sample standard deviation of the window, not a standard error.
        theta_measured = np.mean(hold_data)
        theta_std = np.std(hold_data)

        # Convert to acceleration: a = (κ * θ) / (m * r)
        a_measured = (self.kappa * theta_measured) / (M_TEST * R_TEST)
        a_uncertainty = (self.kappa * theta_std) / (M_TEST * R_TEST)

        # Guard against a zero spread (e.g. perfectly constant data)
        if a_uncertainty > 0:
            snr = abs(a_measured) / a_uncertainty
        else:
            snr = 0.0

        return a_measured, a_uncertainty, snr

    def run_experiment(self) -> "ExperimentalRun":
        """Run a single experimental measurement.

        Returns:
            An ExperimentalRun holding the acceleration estimate, its
            uncertainty, SNR, detection flags, and the raw and filtered data.
        """
        time_array = np.arange(0, MEASUREMENT_DURATION, 1.0 / SAMPLING_RATE)

        raw_data = self.acquire_data(time_array)
        a_measured, a_uncertainty, snr = self.analyze_data(raw_data, time_array)

        # Two-sided Gaussian coverage at the SNR: erf(SNR / sqrt(2))
        detection_confidence = float(stats.norm.cdf(snr) - stats.norm.cdf(-snr))

        # Detection threshold: SNR > 3 (3σ).  Convert to plain bool so the
        # flags match the dataclass annotations instead of being np.bool_.
        detected = bool(snr > self._DETECTION_SNR_THRESHOLD)
        signal_injected = bool(self.true_signal_present)
        is_true_positive = detected and signal_injected
        is_false_positive = detected and not signal_injected

        # Filtered copy kept for visualization
        filtered_data = self._lowpass_filter(raw_data, time_array)

        return ExperimentalRun(
            measured_acceleration=a_measured,
            uncertainty=a_uncertainty,
            snr=snr,
            detection_confidence=detection_confidence,
            is_true_positive=is_true_positive,
            is_false_positive=is_false_positive,
            raw_data=raw_data,
            filtered_data=filtered_data
        )


# Default location of the JSON summary written by run_monte_carlo_simulation.
_DEFAULT_OUTPUT_PATH = '/home/shri/Desktop/Tortion Balance/simulations/end_to_end_simulation_results.json'


def _run_phase(true_signal_present: bool, n_runs: int) -> Tuple[List[float], List[float], int]:
    """Run ``n_runs`` experiments and return (accelerations, snrs, detections).

    Only the per-run scalars are kept; each run's raw and filtered time series
    are dropped as soon as the run finishes so memory use stays flat.
    ``detections`` counts true positives when a signal is injected and false
    positives for the null test.
    """
    experiment = TorsionBalanceExperiment(true_signal_present=true_signal_present)
    accelerations: List[float] = []
    snrs: List[float] = []
    detections = 0

    for run_number in range(1, n_runs + 1):
        if run_number % 100 == 0:
            print(f"  Run {run_number}/{n_runs}...")
        result = experiment.run_experiment()
        accelerations.append(result.measured_acceleration)
        snrs.append(result.snr)
        if true_signal_present:
            detections += int(bool(result.is_true_positive))
        else:
            detections += int(bool(result.is_false_positive))

    return accelerations, snrs, detections


def run_monte_carlo_simulation(n_runs: int = 1000, output_path: str = _DEFAULT_OUTPUT_PATH) -> Dict:
    """
    Run Monte Carlo simulation with many experimental runs.

    Two phases of ``n_runs`` experiments each are simulated: one with the
    signal injected (detection power) and one without (false positive rate).
    A summary is printed, written to ``output_path`` as JSON, and returned.

    Args:
        n_runs: Number of simulated experiments per phase (must be >= 1).
        output_path: Destination of the JSON summary.  Missing parent
            directories are created.

    Returns:
        Dictionary with statistical results

    Raises:
        ValueError: If ``n_runs`` is smaller than 1.
        OSError: If the results file cannot be written.
    """
    import os

    # Fail fast: every rate below divides by n_runs.
    if n_runs < 1:
        raise ValueError(f"n_runs must be at least 1, got {n_runs}")

    print("=" * 80)
    print("END-TO-END EXPERIMENTAL SIMULATION (MONTE CARLO)")
    print("=" * 80)
    print(f"\nRunning {n_runs} simulated experiments...")
    print("Signal present: Testing detection capability")
    print(f"Target acceleration: {A_SIGNAL:.2e} m/s²")
    print(f"Measurement duration: {MEASUREMENT_DURATION / 3600:.1f} hours\n")

    # Run experiments with signal present
    print("Phase 1: Signal Present (True Positive Test)")
    print("-" * 80)
    accelerations_signal, snrs_signal, true_positives = _run_phase(True, n_runs)

    # Run experiments with NO signal (null test)
    print("\nPhase 2: No Signal (False Positive Test)")
    print("-" * 80)
    accelerations_null, snrs_null, false_positives = _run_phase(False, n_runs)

    # Analyze results
    print("\n" + "=" * 80)
    print("RESULTS")
    print("=" * 80)

    # Signal present statistics
    false_negatives = n_runs - true_positives
    mean_acceleration_signal = np.mean(accelerations_signal)

    print(f"\n📊 SIGNAL PRESENT (n={n_runs}):")
    print(f"  Mean Measured Acceleration: {mean_acceleration_signal:.2e} m/s²")
    print(f"  True Acceleration: {A_SIGNAL:.2e} m/s²")
    print(f"  Measurement Bias: {(mean_acceleration_signal - A_SIGNAL) / A_SIGNAL * 100:.2f}%")
    print(f"  Mean SNR: {np.mean(snrs_signal):.2f}")
    print(f"  Median SNR: {np.median(snrs_signal):.2f}")
    print(f"  True Positive Rate: {true_positives / n_runs * 100:.2f}% ({true_positives}/{n_runs})")
    print(f"  False Negative Rate: {false_negatives / n_runs * 100:.2f}% ({false_negatives}/{n_runs})")

    # Null test statistics
    print(f"\n📊 NO SIGNAL (NULL TEST, n={n_runs}):")
    print(f"  Mean Measured Acceleration: {np.mean(accelerations_null):.2e} m/s²")
    print("  Expected: ~0 m/s²")
    print(f"  Mean SNR: {np.mean(snrs_null):.2f}")
    print(f"  False Positive Rate: {false_positives / n_runs * 100:.2f}% ({false_positives}/{n_runs})")

    # Overall assessment
    print("\n" + "=" * 80)
    print("ASSESSMENT")
    print("=" * 80)

    detection_power = true_positives / n_runs
    false_positive_rate = false_positives / n_runs

    print(f"\n✓ Detection Power: {detection_power * 100:.2f}%")
    print(f"✓ False Positive Rate: {false_positive_rate * 100:.4f}%")
    print(f"✓ Statistical Significance: {1 - false_positive_rate:.6f} (1 - FPR)")

    if detection_power >= 0.95:
        print("\n✅ EXCELLENT: >95% detection probability")
    elif detection_power >= 0.80:
        print("\n✅ GOOD: >80% detection probability")
    else:
        print("\n⚠️  MARGINAL: <80% detection probability")

    if false_positive_rate <= 0.001:
        print("✅ EXCELLENT: <0.1% false positive rate (>3σ)")
    elif false_positive_rate <= 0.05:
        print("✅ GOOD: <5% false positive rate")
    else:
        print("⚠️  HIGH: >5% false positive rate")

    # Assemble the JSON-serializable summary (plain floats only)
    output = {
        'n_runs': n_runs,
        'signal_present': {
            'mean_acceleration': float(mean_acceleration_signal),
            'std_acceleration': float(np.std(accelerations_signal)),
            'true_acceleration': A_SIGNAL,
            'mean_snr': float(np.mean(snrs_signal)),
            'median_snr': float(np.median(snrs_signal)),
            'true_positive_rate': float(detection_power),
            'false_negative_rate': float(1 - detection_power)
        },
        'null_test': {
            'mean_acceleration': float(np.mean(accelerations_null)),
            'std_acceleration': float(np.std(accelerations_null)),
            'mean_snr': float(np.mean(snrs_null)),
            'false_positive_rate': float(false_positive_rate)
        },
        'overall_assessment': {
            'detection_power': float(detection_power),
            'false_positive_rate': float(false_positive_rate),
            'statistical_significance': float(1 - false_positive_rate)
        }
    }

    # Save results, creating the target directory when it does not exist yet
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, 'w') as f:
        json.dump(output, f, indent=2)

    print(f"\n✓ Results saved to: {output_path}")

    return output


# Script entry point: run both Monte Carlo phases with 1000 runs each and
# keep the returned summary dictionary in `results`.
if __name__ == "__main__":
    results = run_monte_carlo_simulation(n_runs=1000)
