#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BASIC FCE USAGE EXAMPLES

This file demonstrates basic usage patterns for the Fractal Correction Engine.
These examples show how to integrate FCE into your own research projects.
"""

import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Add src directory to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))

from quick_demo_fce import QuickFractalCorrectionEngine


def example_1_basic_prediction():
    """Example 1: forecast the second half of an ellipse from its first half.

    An ellipse with semi-axes 3.0 and 2.0 is sampled at 100 points. The first
    half is handed to ``QuickFractalCorrectionEngine.predict_with_fce`` and
    the forecast is compared against the held-out second half, then all three
    curves are plotted.

    Returns:
        tuple: ``(rmse, signature)`` -- the root-mean-square Euclidean
        distance between predicted and true points, and the signature dict
        returned by ``predict_with_fce`` (this example reads its
        ``'harmonics'`` and ``'stability'`` entries).
    """
    print("Example 1: Basic FCE Prediction")
    print("-" * 40)

    # Parametric ellipse: semi-major axis along x, semi-minor along y.
    theta = np.linspace(0, 2*np.pi, 100)
    semi_major, semi_minor = 3.0, 2.0
    xs = semi_major * np.cos(theta)
    ys = semi_minor * np.sin(theta)

    # First half is "observed"; the remainder is the ground truth to predict.
    half = len(theta) // 2
    obs_x, obs_y = xs[:half], ys[:half]
    truth_x, truth_y = xs[half:], ys[half:]

    engine = QuickFractalCorrectionEngine()
    pred_x, pred_y, signature = engine.predict_with_fce(obs_x, obs_y, len(truth_x))

    # RMS of the point-wise Euclidean distance.
    squared_dist = (pred_x - truth_x)**2 + (pred_y - truth_y)**2
    rmse = np.sqrt(np.mean(squared_dist))

    print(f"Prediction RMSE: {rmse:.4f}")
    print(f"Harmonics found: {len(signature['harmonics'])}")
    print(f"Stability index: {signature['stability']:.6f}")

    plt.figure(figsize=(10, 8))
    curves = (
        (obs_x, obs_y, 'b-', 2, 'Observed'),
        (pred_x, pred_y, 'r--', 2, 'FCE Prediction'),
        (truth_x, truth_y, 'k:', 1.5, 'Ground Truth'),
    )
    for cx, cy, fmt, width, label in curves:
        plt.plot(cx, cy, fmt, linewidth=width, label=label)
    plt.xlabel('X Position')
    plt.ylabel('Y Position')
    plt.title('Example 1: Elliptical Orbit Prediction')
    plt.legend()
    plt.axis('equal')
    plt.grid(True, alpha=0.3)
    plt.show()

    return rmse, signature


def example_2_curvature_analysis():
    """Example 2: curvature profile of a figure-8 path.

    Samples the Lissajous curve x = sin(t), y = sin(2t) at 200 points and
    passes it to ``calculate_curvature``. It then derives a signature from the
    curvature with ``extract_pi_signature``, prints summary statistics, and
    plots the path next to its curvature profile.

    Returns:
        tuple: ``(curvature, signature)`` as returned by the engine's
        ``calculate_curvature`` and ``extract_pi_signature`` methods.
    """
    print("\nExample 2: Curvature Analysis")
    print("-" * 40)

    phase = np.linspace(0, 2*np.pi, 200)
    path_x, path_y = np.sin(phase), np.sin(2*phase)

    engine = QuickFractalCorrectionEngine()
    curvature = engine.calculate_curvature(path_x, path_y)
    signature = engine.extract_pi_signature(curvature)

    # Each statistic is computed right before it is printed.
    summary = (
        ("Average curvature", np.mean),
        ("Maximum curvature", np.max),
        ("Curvature variation", np.std),
    )
    for label, reducer in summary:
        print(f"{label}: {reducer(curvature):.4f}")

    _fig, (path_ax, curv_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: the trajectory itself.
    path_ax.plot(path_x, path_y, 'b-', linewidth=2)
    path_ax.set(xlabel='X Position', ylabel='Y Position',
                title='Figure-8 Trajectory')
    path_ax.axis('equal')
    path_ax.grid(True, alpha=0.3)

    # Right panel: curvature indexed by sample position along the path.
    curv_ax.plot(curvature, 'r-', linewidth=2)
    curv_ax.set(xlabel='Path Point', ylabel='Curvature',
                title='Curvature Profile')
    curv_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return curvature, signature


def example_3_parameter_sensitivity():
    """Example 3: how prediction error and stability vary with horizon length.

    Builds a perturbed circular orbit with radius 2.0 + 0.3*sin(3t), covering
    two revolutions in 300 samples. The first 200 samples are observed, and
    the engine is asked for forecasts of several horizon lengths.

    Only ``len(x) - split`` ground-truth samples (100 here) follow the
    observation window. A horizon longer than that is therefore scored only on
    the samples that exist; the printed line states how many points were
    actually scored, so the reported RMSE is not mistaken for a full-horizon
    error.

    Returns:
        list[tuple]: one ``(length, rmse, stability)`` tuple per requested
        horizon, in the order of ``prediction_lengths``. ``rmse`` is NaN if
        no prediction/ground-truth pair overlapped.
    """
    print("\nExample 3: Parameter Sensitivity")
    print("-" * 40)

    # Generate test data
    t = np.linspace(0, 4*np.pi, 300)
    radius = 2.0 + 0.3 * np.sin(3*t)  # Perturbed circular orbit
    x = radius * np.cos(t)
    y = radius * np.sin(t)

    # Observation window; everything after it is ground truth.
    split = 200
    x_obs = x[:split]
    y_obs = y[:split]

    prediction_lengths = [20, 50, 100, 200]
    results = []

    fce = QuickFractalCorrectionEngine()

    for length in prediction_lengths:
        # Slicing past the end just truncates, so x_true may be shorter than
        # `length` when the horizon exceeds the available ground truth.
        x_true = x[split:split+length]
        y_true = y[split:split+length]

        x_pred, y_pred, signature = fce.predict_with_fce(x_obs, y_obs, length)

        # Score only where both a prediction and a ground-truth sample exist.
        n_eval = min(len(x_pred), len(x_true))
        if n_eval == 0:
            rmse = float('nan')  # avoid np.mean([]) warning on empty overlap
        else:
            rmse = np.sqrt(np.mean((x_pred[:n_eval] - x_true[:n_eval])**2 +
                                   (y_pred[:n_eval] - y_true[:n_eval])**2))

        results.append((length, rmse, signature['stability']))
        note = "" if n_eval == length else f" (scored on {n_eval} points)"
        print(f"Prediction length {length:3d}: RMSE = {rmse:.4f}, "
              f"Stability = {signature['stability']:.6f}{note}")

    # Plot sensitivity analysis
    lengths = [r[0] for r in results]
    rmse_values = [r[1] for r in results]
    stability_values = [r[2] for r in results]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    ax1.plot(lengths, rmse_values, 'ro-', linewidth=2, markersize=8)
    ax1.set_xlabel('Prediction Length')
    ax1.set_ylabel('RMSE')
    ax1.set_title('Prediction Accuracy vs Length')
    ax1.grid(True, alpha=0.3)

    ax2.plot(lengths, stability_values, 'bo-', linewidth=2, markersize=8)
    ax2.set_xlabel('Prediction Length')
    ax2.set_ylabel('Stability Index')
    ax2.set_title('Pattern Stability vs Length')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    return results


def example_4_comparison_study():
    """Example 4: FCE versus the engine's traditional predictor on three orbits.

    For each of a circular, an elliptical and a radially perturbed orbit, the
    orbit is sampled at 200 points over one revolution. The first 100 points
    are observed, the rest are forecast with both ``predict_with_fce`` and
    ``predict_traditional``, and the two RMSEs are plotted side by side on a
    log scale.

    Returns:
        dict: maps each orbit name to a dict with keys ``'fce_rmse'``,
        ``'traditional_rmse'``, ``'improvement'`` and ``'stability'``.
        ``'improvement'`` is the percentage by which the FCE RMSE is below
        the traditional RMSE, or 0 when the traditional RMSE is not positive.
    """
    print("\nExample 4: Method Comparison")
    print("-" * 40)

    def _perturbed(theta):
        # Radius oscillates five times per revolution around 2.
        r = 2 + 0.2*np.sin(5*theta)
        return r*np.cos(theta), r*np.sin(theta)

    def _rmse(px, py, tx, ty):
        # RMS of the point-wise Euclidean distance.
        return np.sqrt(np.mean((px - tx)**2 + (py - ty)**2))

    # Insertion order fixes the order of processing and of the bar chart.
    orbit_types = {
        'circular': lambda theta: (2*np.cos(theta), 2*np.sin(theta)),
        'elliptical': lambda theta: (3*np.cos(theta), 1.5*np.sin(theta)),
        'perturbed': _perturbed,
    }

    engine = QuickFractalCorrectionEngine()
    results = {}

    # Same sampling grid and split for every orbit.
    grid = np.linspace(0, 2*np.pi, 200)
    cut = 100

    for name, trajectory in orbit_types.items():
        print(f"\nTesting {name} orbit:")

        xs, ys = trajectory(grid)
        obs_x, obs_y = xs[:cut], ys[:cut]
        truth_x, truth_y = xs[cut:], ys[cut:]
        horizon = len(truth_x)

        fce_x, fce_y, signature = engine.predict_with_fce(obs_x, obs_y, horizon)
        fce_err = _rmse(fce_x, fce_y, truth_x, truth_y)

        trad_x, trad_y = engine.predict_traditional(obs_x, obs_y, horizon)
        trad_err = _rmse(trad_x, trad_y, truth_x, truth_y)

        if trad_err > 0:
            gain = (trad_err - fce_err) / trad_err * 100
        else:
            gain = 0

        results[name] = {
            'fce_rmse': fce_err,
            'traditional_rmse': trad_err,
            'improvement': gain,
            'stability': signature['stability'],
        }

        print(f"  FCE RMSE: {fce_err:.4f}")
        print(f"  Traditional RMSE: {trad_err:.4f}")
        print(f"  Improvement: {gain:.1f}%")

    # Grouped bar chart: FCE bar left of each tick, traditional bar right.
    names = list(results)
    fce_heights = [results[key]['fce_rmse'] for key in names]
    trad_heights = [results[key]['traditional_rmse'] for key in names]
    ticks = np.arange(len(names))
    bar_w = 0.35

    plt.figure(figsize=(10, 6))
    plt.bar(ticks - bar_w/2, fce_heights, bar_w, label='FCE', color='red', alpha=0.7)
    plt.bar(ticks + bar_w/2, trad_heights, bar_w, label='Traditional', color='orange', alpha=0.7)

    plt.xlabel('Orbit Type')
    plt.ylabel('RMSE')
    plt.title('FCE vs Traditional Methods - Comparison Study')
    plt.xticks(ticks, names)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.yscale('log')  # the two methods' errors may differ by orders of magnitude
    plt.show()

    return results


def main():
    """Run all four examples in order and print a summary.

    Errors raised while running an example are reported instead of being
    silently swallowed:

    - An ``ImportError`` prints the dependency-install hint.
    - Any other exception prints its traceback.

    Note that a failure of the module-level ``from quick_demo_fce import ...``
    happens at import time, before this function runs, and is not caught here.

    Returns:
        int: 0 if every example completed, 1 if one of them raised.
    """
    import traceback

    print("FRACTAL CORRECTION ENGINE - USAGE EXAMPLES")
    print("=" * 50)

    # Keep the try body to the example calls: those are what can fail.
    try:
        rmse1, _ = example_1_basic_prediction()
        curv2, _ = example_2_curvature_analysis()
        results3 = example_3_parameter_sensitivity()
        results4 = example_4_comparison_study()
    except ImportError as e:
        # e.g. a plotting backend or optional dependency imported lazily.
        print(f"Error running examples: {e}")
        print("Please ensure all dependencies are installed:")
        print("  pip install -r requirements.txt")
        return 1
    except Exception as e:
        # Not a dependency problem: show where it actually failed.
        print(f"Error running examples: {e}")
        traceback.print_exc()
        return 1

    print("\n" + "=" * 50)
    print("ALL EXAMPLES COMPLETED SUCCESSFULLY")
    print("=" * 50)

    # results3 holds (length, rmse, stability) tuples; pick the lowest RMSE.
    best_length = min(results3, key=lambda r: r[1])[0]
    mean_improvement = np.mean([r['improvement'] for r in results4.values()])

    print("\nSUMMARY:")
    print(f"Example 1 - Basic prediction RMSE: {rmse1:.4f}")
    print(f"Example 2 - Average curvature: {np.mean(curv2):.4f}")
    print(f"Example 3 - Best prediction length: {best_length} points")
    print(f"Example 4 - Average FCE improvement: {mean_improvement:.1f}%")

    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status so callers/CI see
    # failures; sys.exit(None) exits with status 0.
    sys.exit(main())