#!/usr/bin/env python3
"""
System Resource Monitor and Safety Wrapper
==========================================

This module provides utilities to monitor system resources and prevent crashes
during intensive computational operations. It can be imported and used to wrap
existing code with safety checks.

Usage:
    from system_monitor import ResourceSafetyWrapper

    safety = ResourceSafetyWrapper()

    # Wrap dangerous operations
    with safety.safe_operation("intensive_calculation"):
        result = your_intensive_function()

    # Safe array allocation (returns None when the array would not fit in the memory budget)
    array = safety.safe_array_allocation((1000, 1000, 1000))

    # Safe nested loops
    results = safety.safe_nested_loop_execution(
        ranges=[100, 100, 100],
        operation=lambda i, j, k: expensive_calculation(i, j, k)
    )
"""

import psutil
import numpy as np
import time
import gc
import signal
import logging
import threading
from typing import Dict, List, Tuple, Optional, Any, Callable
from contextlib import contextmanager
from dataclasses import dataclass

logger = logging.getLogger(__name__)

@dataclass
class SafetyConfig:
    """Configuration for resource safety limits.

    Raises:
        ValueError: if ``memory_check_frequency`` or ``chunk_size_limit`` is
            smaller than 1. Both values are used as modulo divisors by the
            code that consumes this configuration, so zero would otherwise
            surface later as an unexplained ``ZeroDivisionError``.
    """
    # Resident memory ceiling for the current process, in GiB.
    max_memory_gb: float = 3.0
    # CPU usage (percent) above which usage is reported as exceeded.
    max_cpu_percent: float = 80.0
    # Wall-clock budget in seconds.
    max_execution_time: int = 180
    # Resource limits are re-evaluated once every this many operations.
    memory_check_frequency: int = 100
    # CPU usage (percent) above which the monitored loop briefly sleeps.
    cpu_yield_threshold: float = 70.0
    auto_gc_threshold: float = 2.5  # Trigger GC when memory exceeds this
    chunk_size_limit: int = 1000  # Maximum chunk size for operations

    def __post_init__(self):
        # Fail fast on values that would later be used as a zero divisor.
        for field_name in ("memory_check_frequency", "chunk_size_limit"):
            value = getattr(self, field_name)
            if value < 1:
                raise ValueError(f"{field_name} must be >= 1, got {value!r}")

class ResourceSafetyWrapper:
    """
    Wrapper class that adds resource monitoring and safety to existing code.

    An instance tracks the resident memory and CPU usage of the current
    process (via ``psutil``) and the wall-clock time elapsed since the
    instance was created, and compares them with the limits in its
    ``SafetyConfig``. When constructed on the main thread it installs a
    SIGINT handler that sets ``self.interrupted`` instead of raising
    ``KeyboardInterrupt``, so monitored loops can stop cooperatively.
    """

    # Number of bytes in one GiB; every "GB" figure in this class uses it.
    _BYTES_PER_GB = 1024 ** 3
    # Smallest grid edge that get_safe_grid_size() returns.
    _MIN_GRID_SIZE = 8

    def __init__(self, config: Optional["SafetyConfig"] = None):
        """Create a wrapper.

        Args:
            config: Limits to enforce; a default ``SafetyConfig`` is used
                when omitted.
        """
        self.config = config or SafetyConfig()
        self.start_time = time.time()
        self.operation_count = 0
        self.interrupted = False
        self.process = psutil.Process()

        # The first non-blocking cpu_percent() call only establishes a
        # baseline and returns 0.0; take it here so later readings are real.
        try:
            self.process.cpu_percent()
        except psutil.Error as e:
            logger.debug(f"Could not prime CPU measurement: {e}")

        # Set up interrupt handling
        self._install_interrupt_handler()

        logger.info(f"ResourceSafetyWrapper initialized with limits: "
                   f"Memory={self.config.max_memory_gb}GB, "
                   f"CPU={self.config.max_cpu_percent}%, "
                   f"Time={self.config.max_execution_time}s")

    def _install_interrupt_handler(self):
        """Route SIGINT to ``_handle_interrupt`` when that is permitted."""
        # signal.signal() raises ValueError outside the main thread, which
        # would make the wrapper unusable from worker threads.
        if threading.current_thread() is not threading.main_thread():
            logger.debug("Not on the main thread; SIGINT handler not installed")
            return
        try:
            signal.signal(signal.SIGINT, self._handle_interrupt)
        except (ValueError, OSError) as e:
            logger.debug(f"Could not install SIGINT handler: {e}")

    def _handle_interrupt(self, signum, frame):
        """Handle interrupt signals"""
        logger.warning("Interrupt signal received, setting stop flag")
        self.interrupted = True

    def get_current_resource_usage(self) -> Dict[str, Any]:
        """Get current resource usage of this process.

        Returns:
            A dict with ``memory_gb``, ``cpu_percent``, ``elapsed_time`` and
            the boolean flags ``memory_exceeded``, ``cpu_exceeded``,
            ``time_exceeded`` and ``should_stop``. If the measurement fails
            the dict contains only an ``error`` message.
        """
        try:
            memory_gb = self.process.memory_info().rss / self._BYTES_PER_GB
            cpu_percent = self.process.cpu_percent()
            elapsed_time = time.time() - self.start_time

            return {
                'memory_gb': memory_gb,
                'cpu_percent': cpu_percent,
                'elapsed_time': elapsed_time,
                'memory_exceeded': memory_gb > self.config.max_memory_gb,
                'cpu_exceeded': cpu_percent > self.config.max_cpu_percent,
                'time_exceeded': elapsed_time > self.config.max_execution_time,
                'should_stop': self.interrupted
            }
        except Exception as e:
            # Best effort: monitoring must never crash the monitored code.
            logger.error(f"Error getting resource usage: {e}")
            return {'error': str(e)}

    def _stop_reason(self, usage: Dict[str, Any]) -> Optional[str]:
        """Return why work must stop for this usage snapshot, or None."""
        if self.interrupted or usage.get('should_stop', False):
            return "Stop requested"
        if usage.get('memory_exceeded', False):
            return f"Memory limit exceeded: {usage['memory_gb']:.2f}GB"
        if usage.get('time_exceeded', False):
            return f"Time limit exceeded: {usage['elapsed_time']:.1f}s"
        return None

    def should_continue(self) -> bool:
        """Check if operation should continue based on resources.

        Increments ``operation_count`` on every call. The interrupt flag is
        honoured immediately; memory and time limits are evaluated once every
        ``config.memory_check_frequency`` calls, at which point garbage
        collection and a short CPU yield may also be triggered.

        Returns:
            False when the caller should stop, True otherwise.
        """
        self.operation_count += 1

        # Checking a flag is free, so do not wait for the next sampling tick.
        if self.interrupted:
            logger.warning("Stop requested")
            return False

        # Clamp so a zero frequency cannot cause a ZeroDivisionError.
        frequency = max(1, int(self.config.memory_check_frequency))
        if self.operation_count % frequency != 0:
            return True

        usage = self.get_current_resource_usage()
        reason = self._stop_reason(usage)
        if reason is not None:
            logger.warning(reason)
            return False

        # Auto garbage collection
        if usage.get('memory_gb', 0) > self.config.auto_gc_threshold:
            logger.info("Triggering automatic garbage collection")
            gc.collect()

        # CPU yielding
        if usage.get('cpu_percent', 0) > self.config.cpu_yield_threshold:
            time.sleep(0.01)

        return True

    @contextmanager
    def safe_operation(self, operation_name: str):
        """Context manager for safe operations with monitoring.

        Logs start and completion, logs and re-raises any exception raised in
        the managed block, and runs a garbage collection on exit when memory
        usage is above 80% of the configured limit.

        Args:
            operation_name: Label used in log messages.

        Yields:
            This wrapper instance.
        """
        started = time.time()
        logger.info(f"Starting safe operation: {operation_name}")

        try:
            yield self

        except KeyboardInterrupt:
            logger.warning(f"Operation '{operation_name}' interrupted by user")
            self.interrupted = True
            raise

        except MemoryError as e:
            logger.error(f"Memory error in operation '{operation_name}': {e}")
            gc.collect()  # Emergency cleanup
            raise

        except Exception as e:
            logger.error(f"Error in operation '{operation_name}': {e}")
            raise

        else:
            elapsed = time.time() - started
            logger.info(f"Completed operation '{operation_name}' in {elapsed:.2f}s")

        finally:
            # Check final resource state
            usage = self.get_current_resource_usage()
            if usage.get('memory_gb', 0) > self.config.max_memory_gb * 0.8:
                gc.collect()

    @staticmethod
    def _element_count(dims) -> int:
        """Multiply extents with Python ints, which cannot overflow."""
        total = 1
        for extent in dims:
            total *= int(extent)
        return total

    def safe_array_allocation(self, shape: Tuple[int, ...], dtype=np.complex128,
                            fill_value=None) -> Optional[np.ndarray]:
        """Safely allocate numpy array with memory checking.

        Args:
            shape: Array shape (a tuple of ints, or a single int).
            dtype: Element type of the array.
            fill_value: Initial value for every element; zeros when None.

        Returns:
            The new array, or None when the array would not fit in the
            remaining memory budget or the allocation itself fails.
        """
        try:
            # Check memory requirement
            dims = (shape,) if isinstance(shape, (int, np.integer)) else tuple(shape)
            element_size = np.dtype(dtype).itemsize
            # np.prod would wrap around for very large shapes and make the
            # request look small; exact integer arithmetic does not.
            required_gb = (self._element_count(dims) * element_size) / self._BYTES_PER_GB

            current_memory = self.get_current_resource_usage().get('memory_gb', 0)
            available = self.config.max_memory_gb - current_memory

            if required_gb > available:
                logger.warning(f"Array allocation blocked: need {required_gb:.3f}GB, "
                             f"only {available:.3f}GB available")
                return None

            if required_gb > 1.0:  # Large allocation warning
                logger.info(f"Allocating large array: {shape} ({required_gb:.3f}GB)")

            # Allocate array
            if fill_value is not None:
                return np.full(shape, fill_value, dtype=dtype)
            return np.zeros(shape, dtype=dtype)

        except MemoryError as e:
            logger.error(f"Memory allocation failed for shape {shape}: {e}")
            gc.collect()
            return None
        except Exception as e:
            logger.error(f"Array allocation error: {e}")
            return None

    @staticmethod
    def _iter_indices(extents):
        """Lazily yield index tuples for nested ``range`` loops, last axis fastest."""
        extents = [int(extent) for extent in extents]
        # Any empty (or negative) range means the nested loops never run.
        if any(extent <= 0 for extent in extents):
            return
        current = [0] * len(extents)
        while True:
            yield tuple(current)
            # Advance like an odometer: bump the last axis, carry leftwards.
            for axis in range(len(extents) - 1, -1, -1):
                current[axis] += 1
                if current[axis] < extents[axis]:
                    break
                current[axis] = 0
            else:
                return

    def safe_nested_loop_execution(self, ranges: List[int], operation: Callable,
                                 max_total_iterations: int = 1000000,
                                 chunk_size: int = None) -> List[Any]:
        """Execute nested loops safely with automatic chunking and monitoring.

        ``operation`` is called as ``operation(i, j, ...)`` for every index
        combination of ``range(ranges[0]) x range(ranges[1]) x ...`` with the
        last index varying fastest. Exceptions raised by ``operation`` are
        logged and that iteration is skipped.

        Args:
            ranges: Extent of each nested loop.
            operation: Callable taking one integer per loop level.
            max_total_iterations: Hard cap on the number of calls. When the
                full index space is larger, every range is scaled down
                proportionally and the cap is enforced on top of that.
            chunk_size: Number of processed iterations between brief sleeps;
                derived from the configuration when omitted.

        Returns:
            The results of the successful calls, in iteration order. The list
            is partial when ``should_continue()`` requests a stop.
        """
        iteration_cap = max(0, max_total_iterations)
        # Exact integer product: np.prod silently wraps on large ranges, which
        # would let an enormous loop slip past the iteration cap.
        total_iterations = self._element_count(ranges)

        if not chunk_size:
            default_chunk = iteration_cap // 10 if total_iterations > 10 else 1
            # Never below 1: chunk_size is used as a modulo divisor.
            chunk_size = max(1, min(self.config.chunk_size_limit, default_chunk))

        if ranges and total_iterations > iteration_cap:
            logger.warning(f"Limiting iterations from {total_iterations} to {max_total_iterations}")
            # Scale down ranges proportionally
            scale = (iteration_cap / total_iterations) ** (1.0 / len(ranges))
            ranges = [max(1, int(r * scale)) for r in ranges]
            total_iterations = self._element_count(ranges)

        logger.info(f"Executing nested loops: {ranges} ({total_iterations} total iterations)")

        results = []
        processed = 0

        try:
            for attempted, idx in enumerate(self._iter_indices(ranges)):
                # Clamping ranges to at least 1 can leave the scaled index
                # space above the cap, so enforce the cap explicitly.
                if attempted >= iteration_cap:
                    logger.warning(f"Iteration cap of {max_total_iterations} reached")
                    break

                if not self.should_continue():
                    logger.warning(f"Loop stopped early at iteration {processed}")
                    break

                try:
                    results.append(operation(*idx))
                except Exception as e:
                    # Deliberate best effort: one failing item must not abort the sweep.
                    label = idx[0] if len(idx) == 1 else "(" + ",".join(map(str, idx)) + ")"
                    logger.error(f"Error in operation at iteration {label}: {e}")
                    continue

                processed += 1
                if processed % chunk_size == 0:
                    time.sleep(0.001)  # Brief yield

        except Exception as e:
            logger.error(f"Critical error in nested loop execution: {e}")
            raise

        logger.info(f"Nested loop completed: {processed}/{total_iterations} iterations processed")
        return results

    def _sanitize_result(self, result: Any, handle_overflow: bool) -> Any:
        """Warn about NaN/inf in a result and replace them when requested."""
        if isinstance(result, np.ndarray):
            # isnan/isinf are only defined for float and complex data; integer,
            # string and object arrays are returned untouched.
            if not np.issubdtype(result.dtype, np.inexact):
                return result

            has_nan = bool(np.isnan(result).any())
            has_inf = bool(np.isinf(result).any())
            if has_nan:
                logger.warning("Operation produced NaN values")
            if has_inf:
                logger.warning("Operation produced infinite values")

            if handle_overflow and (has_nan or has_inf):
                # Replace everything in one call: nan_to_num with default
                # arguments maps inf to the largest finite float, which would
                # defeat the intended +/-1e10 clamp.
                result = np.nan_to_num(result, nan=0.0, posinf=1e10, neginf=-1e10)
            return result

        # Integers are always finite, so only float-like scalars need a check.
        if isinstance(result, (float, complex, np.floating, np.complexfloating)):
            if not np.isfinite(result):
                logger.warning(f"Operation produced invalid result: {result}")
                if handle_overflow:
                    result = 0.0

        return result

    def safe_mathematical_operation(self, operation: Callable, *args,
                                  handle_overflow: bool = True, **kwargs) -> Any:
        """Execute mathematical operations with error handling.

        Args:
            operation: Callable to run as ``operation(*args, **kwargs)``.
            handle_overflow: When True, arithmetic errors yield ``0.0`` and
                NaN/inf values in the result are replaced (NaN -> 0,
                +/-inf -> +/-1e10 in arrays; invalid scalars -> 0.0). When
                False, problems are only logged and arithmetic errors
                propagate.

        Returns:
            The (possibly sanitized) result of ``operation``.

        Raises:
            OverflowError, ZeroDivisionError, FloatingPointError: from
                ``operation`` when ``handle_overflow`` is False.
            Exception: any other exception raised by ``operation``.
        """
        try:
            result = operation(*args, **kwargs)
        except (OverflowError, ZeroDivisionError, FloatingPointError) as e:
            logger.warning(f"Mathematical error: {e}")
            if handle_overflow:
                return 0.0
            raise
        except Exception as e:
            logger.error(f"Unexpected error in mathematical operation: {e}")
            raise

        # Check for problematic results
        return self._sanitize_result(result, handle_overflow)

    def get_safe_grid_size(self, desired_size: int, dtype=np.complex128,
                          dimensions: int = 3) -> int:
        """Calculate safe grid size based on available memory.

        Args:
            desired_size: Requested number of points per dimension.
            dtype: Element type of the grid.
            dimensions: Number of grid dimensions.

        Returns:
            ``desired_size`` if a ``desired_size ** dimensions`` grid fits in
            70% of the remaining memory budget, otherwise the largest edge
            that does. The result is never smaller than 8, even when the
            budget is exhausted or ``desired_size`` itself is below 8.

        Raises:
            ValueError: if ``dimensions`` is smaller than 1.
        """
        if dimensions < 1:
            raise ValueError(f"dimensions must be a positive integer, got {dimensions}")

        element_size = np.dtype(dtype).itemsize
        current_usage = self.get_current_resource_usage()
        available_gb = self.config.max_memory_gb - current_usage.get('memory_gb', 0)

        # Leave some margin. Clamp at zero: a negative budget raised to a
        # fractional power is a complex number, which int() rejects.
        usable_gb = max(0.0, available_gb * 0.7)

        # Calculate maximum elements that fit in available memory
        max_elements = int((usable_gb * self._BYTES_PER_GB) / element_size)

        # Calculate maximum grid size for given dimensions
        max_grid_size = int(max_elements ** (1.0 / dimensions))

        safe_size = min(desired_size, max_grid_size)

        if safe_size < desired_size:
            logger.warning(f"Reducing grid size from {desired_size} to {safe_size} "
                         f"(available memory: {available_gb:.2f}GB)")

        return max(self._MIN_GRID_SIZE, safe_size)  # Minimum useful size

    def status_report(self) -> Dict[str, Any]:
        """Generate a comprehensive status report.

        Takes a fresh usage snapshot and derives the status from it directly,
        so calling this method neither advances ``operation_count`` nor
        depends on the sampling frequency.

        Returns:
            A dict with ``resource_usage``, ``operation_count``, the main
            ``config`` limits, and ``status`` (``'running'`` or ``'stopped'``).
        """
        usage = self.get_current_resource_usage()
        is_running = self._stop_reason(usage) is None

        return {
            'resource_usage': usage,
            'operation_count': self.operation_count,
            'config': {
                'max_memory_gb': self.config.max_memory_gb,
                'max_cpu_percent': self.config.max_cpu_percent,
                'max_execution_time': self.config.max_execution_time
            },
            'status': 'running' if is_running else 'stopped'
        }

# Convenience functions for quick integration
def safe_array(shape, dtype=np.complex128, max_memory_gb=2.0):
    """Allocate an array through a throwaway safety wrapper.

    Builds a ResourceSafetyWrapper limited to ``max_memory_gb`` and returns
    whatever its ``safe_array_allocation(shape, dtype)`` returns.
    """
    limits = SafetyConfig(max_memory_gb=max_memory_gb)
    return ResourceSafetyWrapper(limits).safe_array_allocation(shape, dtype)

def safe_loops(ranges, operation, max_memory_gb=2.0, max_iterations=100000):
    """Run nested loops through a throwaway safety wrapper.

    Builds a ResourceSafetyWrapper limited to ``max_memory_gb`` and returns
    the result list of its ``safe_nested_loop_execution`` call, with
    ``max_iterations`` as the iteration cap.
    """
    limits = SafetyConfig(max_memory_gb=max_memory_gb)
    guard = ResourceSafetyWrapper(limits)
    return guard.safe_nested_loop_execution(
        ranges, operation, max_total_iterations=max_iterations
    )

def monitor_resources():
    """Return a resource-usage snapshot from a default-configured wrapper."""
    return ResourceSafetyWrapper().get_current_resource_usage()

if __name__ == "__main__":
    # Small end-to-end demonstration of the safety wrapper
    print("=== System Resource Monitor Demo ===")

    demo_limits = SafetyConfig(max_memory_gb=1.0, max_execution_time=30)
    monitor = ResourceSafetyWrapper(demo_limits)

    print("Current resources:", monitor.get_current_resource_usage())

    # Allocation goes through the memory check and may come back as None
    print("\nTesting safe array allocation...")
    cube = monitor.safe_array_allocation((100, 100, 100))
    if cube is None:
        print("Array allocation blocked by safety limits")
    else:
        print(f"Successfully allocated array: {cube.shape}")

    # A 10 x 10 sweep of a trivial two-argument operation
    print("\nTesting safe loop execution...")

    def multiply_indices(row, col):
        return row * col

    products = monitor.safe_nested_loop_execution([10, 10], multiply_indices)
    print(f"Loop completed with {len(products)} results")

    print("\nFinal status:", monitor.status_report())