#!/usr/bin/env python3
"""
System Resource Monitor and Safety Wrapper
==========================================

This module provides utilities to monitor system resources and prevent crashes
during intensive computational operations. It can be imported and used to wrap
existing code with safety checks.

Usage:
    from system_monitor import ResourceSafetyWrapper

    safety = ResourceSafetyWrapper()

    # Wrap dangerous operations
    with safety.safe_operation("intensive_calculation"):
        result = your_intensive_function()

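    # Safe array allocation (returns None if the array would not fit within
    # the configured memory limit, so check the result before using it)
    array = safety.safe_array_allocation((1000, 1000, 1000))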

    # Safe nested loops
    results = safety.safe_nested_loop_execution(
        ranges=[100, 100, 100],
        operation=lambda i, j, k: expensive_calculation(i, j, k)
    )
"""

import psutil
import numpy as np
import time
import gc
import signal
import logging
import threading
import glob as globmod
from typing import Dict, List, Tuple, Optional, Any, Callable
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Thermal state definitions
# ---------------------------------------------------------------------------

class ThermalState(Enum):
    """Coarse CPU thermal classification used to gate heavy work."""

    # Below 75°C: run at full speed.
    OK = "ok"
    # From 75°C up to 85°C: throttle (halve workers).
    WARM = "warm"
    # Above 85°C: pause all heavy work.
    HOT = "hot"


class SystemHealthMonitor:
    """
    Monitors CPU temperature via /sys/class/thermal/ and provides
    thermal-aware recommendations for parallel workloads.

    Thermal thresholds (°C):
        ok   : temp < 75   — no restrictions
        warm : 75 <= temp < 85 — halve parallel workers, insert cooldown pauses
        hot  : temp >= 85  — pause all heavy computation

    wait_until_cool() blocks until temp drops below 70°C (hysteresis band).

    When no thermal zone can be read, the monitor reports ThermalState.OK
    and never blocks.
    """

    TEMP_OK_MAX = 75.0        # Below this: full speed
    TEMP_WARM_MAX = 85.0      # Below this (but >= OK_MAX): throttle
    # >= WARM_MAX: hot — pause

    COOL_TARGET = 70.0        # wait_until_cool() target
    COOL_POLL_INTERVAL = 2.0  # Seconds between polls while waiting to cool
    COOL_TIMEOUT = 120.0      # Max seconds to wait before giving up

    COOLDOWN_PAUSE = 2.0      # Seconds to pause between heavy steps when warm

    def __init__(self):
        """Discover readable thermal zones and initialise cached readings."""
        self._thermal_zones = self._discover_thermal_zones()
        self._last_temp = None
        self._last_state = ThermalState.OK
        # Guards _last_temp / _last_state, which may be updated from
        # several threads polling the same monitor.
        self._lock = threading.Lock()

        if self._thermal_zones:
            logger.info(
                f"SystemHealthMonitor: found {len(self._thermal_zones)} "
                f"thermal zone(s): {self._thermal_zones}"
            )
        else:
            logger.info(
                "SystemHealthMonitor: no thermal zones found in "
                "/sys/class/thermal/ — temperature monitoring disabled"
            )

    # ------------------------------------------------------------------
    # Thermal zone discovery
    # ------------------------------------------------------------------

    @staticmethod
    def _discover_thermal_zones() -> List[str]:
        """Find all readable thermal zone paths under /sys/class/thermal/.

        A zone is kept only if its ``temp`` file can be opened and its
        content parses as an integer (millidegrees Celsius).
        """
        zones = []
        pattern = "/sys/class/thermal/thermal_zone*/temp"
        for path in sorted(globmod.glob(pattern)):
            try:
                with open(path, "r") as f:
                    int(f.read().strip())  # sanity check only
            except (OSError, ValueError):
                continue
            zones.append(path)
        return zones

    # ------------------------------------------------------------------
    # Temperature reading
    # ------------------------------------------------------------------

    def read_cpu_temp(self) -> Optional[float]:
        """
        Read current CPU temperature in °C.

        Reads all discovered thermal zones and returns the maximum value.
        Returns None if no thermal zones are available or none of them
        could be read this time.

        The sysfs thermal interface reports temperatures in millidegrees
        Celsius (e.g., 72000 = 72.0°C).
        """
        if not self._thermal_zones:
            return None

        temps = []
        for zone_path in self._thermal_zones:
            try:
                with open(zone_path, "r") as f:
                    temps.append(int(f.read().strip()) / 1000.0)
            except (OSError, ValueError) as e:
                logger.debug(f"Failed to read {zone_path}: {e}")

        if not temps:
            return None

        max_temp = max(temps)
        with self._lock:
            self._last_temp = max_temp
        return max_temp

    # ------------------------------------------------------------------
    # Thermal state
    # ------------------------------------------------------------------

    def _classify(self, temp: Optional[float]) -> "ThermalState":
        """Map a temperature in °C (or None when unreadable) to a state."""
        if temp is None:
            return ThermalState.OK
        if temp >= self.TEMP_WARM_MAX:
            return ThermalState.HOT
        if temp >= self.TEMP_OK_MAX:
            return ThermalState.WARM
        return ThermalState.OK

    @staticmethod
    def _scale_for_state(state: "ThermalState") -> float:
        """Worker-count multiplier associated with a thermal state."""
        if state == ThermalState.HOT:
            return 0.0
        if state == ThermalState.WARM:
            return 0.5
        return 1.0

    def _read_state(self) -> Tuple[Optional[float], "ThermalState"]:
        """Take ONE sensor reading and return it with its classification.

        Using a single reading keeps the reported temperature and state
        consistent with each other.
        """
        temp = self.read_cpu_temp()
        state = self._classify(temp)
        if temp is not None:
            with self._lock:
                self._last_state = state
        return temp, state

    @property
    def thermal_state(self) -> "ThermalState":
        """
        Current thermal state based on CPU temperature.

        Returns ThermalState.OK if temperature cannot be read.
        """
        return self._read_state()[1]

    @property
    def last_temp(self) -> Optional[float]:
        """Last temperature reading (°C), or None if never read."""
        with self._lock:
            return self._last_temp

    # ------------------------------------------------------------------
    # Worker scaling recommendation
    # ------------------------------------------------------------------

    def recommended_worker_scale(self) -> float:
        """
        Multiplier for parallel worker count based on thermal state.

        Returns:
            1.0  for OK   — use all workers
            0.5  for WARM — halve workers
            0.0  for HOT  — stop heavy work entirely
        """
        return self._scale_for_state(self.thermal_state)

    # ------------------------------------------------------------------
    # Blocking cooldown
    # ------------------------------------------------------------------

    def wait_until_cool(self, target_temp: Optional[float] = None,
                        timeout: Optional[float] = None) -> bool:
        """
        Block until CPU temperature drops below `target_temp`.

        Parameters
        ----------
        target_temp : float, optional
            Temperature threshold in °C. Default: COOL_TARGET (70°C).
            An explicit 0.0 is honoured.
        timeout : float, optional
            Maximum seconds to wait. Default: COOL_TIMEOUT (120s).
            An explicit 0 performs one re-check without sleeping.

        Returns
        -------
        bool
            True if the temperature is (or dropped) below the target, or
            if temperature monitoring is unavailable (nothing to wait
            for). False if the wait timed out while still too hot.
        """
        # `is None` rather than `or`: 0 / 0.0 are legitimate explicit values.
        target = self.COOL_TARGET if target_temp is None else target_temp
        max_wait = self.COOL_TIMEOUT if timeout is None else max(0.0, timeout)

        temp = self.read_cpu_temp()
        if temp is None or temp < target:
            # Either no thermal data (cannot block) or already cool.
            return True

        logger.warning(
            f"CPU temp {temp:.1f}°C exceeds cool target {target:.1f}°C — "
            f"waiting for cooldown (max {max_wait:.0f}s)..."
        )

        start = time.monotonic()
        while True:
            # Never sleep past the deadline, so the timeout is respected
            # even when it is shorter than the poll interval.
            remaining = max_wait - (time.monotonic() - start)
            time.sleep(max(0.0, min(self.COOL_POLL_INTERVAL, remaining)))

            temp = self.read_cpu_temp()
            elapsed = time.monotonic() - start

            if temp is None:
                logger.info(
                    f"Thermal zone unavailable after {elapsed:.1f}s, proceeding"
                )
                return True

            if temp < target:
                logger.info(f"CPU cooled to {temp:.1f}°C after {elapsed:.1f}s")
                return True

            if elapsed >= max_wait:
                logger.warning(
                    f"Cooldown timeout after {elapsed:.1f}s — "
                    f"CPU still at {temp:.1f}°C (target {target:.1f}°C)"
                )
                return False

            logger.debug(
                f"Cooling: {temp:.1f}°C (target {target:.1f}°C, "
                f"{elapsed:.0f}/{max_wait:.0f}s)"
            )

    # ------------------------------------------------------------------
    # Heavy-step gating
    # ------------------------------------------------------------------

    def pre_heavy_step(self) -> "ThermalState":
        """
        Call before a heavy computation step.

        - OK:   returns immediately
        - WARM: inserts a cooldown pause (COOLDOWN_PAUSE seconds)
        - HOT:  calls wait_until_cool() before proceeding

        Returns the thermal state observed.
        """
        # One reading for both the decision and the log message; HOT/WARM
        # imply the reading succeeded, so `temp` is a float in those branches.
        temp, state = self._read_state()

        if state == ThermalState.HOT:
            logger.warning(
                f"HOT ({temp:.1f}°C) — blocking until cool before heavy step"
            )
            self.wait_until_cool()

        elif state == ThermalState.WARM:
            logger.info(
                f"WARM ({temp:.1f}°C) — inserting {self.COOLDOWN_PAUSE}s "
                f"cooldown pause before heavy step"
            )
            time.sleep(self.COOLDOWN_PAUSE)

        return state

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------

    def status(self) -> Dict[str, Any]:
        """Return current thermal status as a dict.

        Temperature, state and worker scale all derive from a single
        sensor reading, so the three values always agree.
        """
        temp, state = self._read_state()
        return {
            'cpu_temp_c': temp,
            'thermal_state': state.value,
            'worker_scale': self._scale_for_state(state),
            'thermal_zones_count': len(self._thermal_zones),
            'thresholds': {
                'ok_max': self.TEMP_OK_MAX,
                'warm_max': self.TEMP_WARM_MAX,
                'cool_target': self.COOL_TARGET,
            },
        }

@dataclass
class SafetyConfig:
    """Limits and tuning knobs for resource safety checks.

    max_memory_gb             process memory budget in GB
    max_cpu_percent           CPU usage above this is flagged as exceeded
    max_execution_time        time budget in seconds
    memory_check_frequency    run the full health check every N operations
    cpu_yield_threshold       CPU percent above which loops briefly yield
    auto_gc_threshold         trigger GC when memory (GB) exceeds this
    chunk_size_limit          maximum chunk size for operations
    enable_thermal_monitoring enable CPU temperature checks
    """

    max_memory_gb: float = 3.0
    max_cpu_percent: float = 80.0
    max_execution_time: int = 180
    memory_check_frequency: int = 100
    cpu_yield_threshold: float = 70.0
    auto_gc_threshold: float = 2.5
    chunk_size_limit: int = 1_000
    enable_thermal_monitoring: bool = True

class ResourceSafetyWrapper:
    """
    Wrapper class that adds resource monitoring and safety to existing code.

    It tracks process memory, CPU usage, time elapsed since construction
    and (optionally) CPU temperature, and offers guarded helpers for array
    allocation, nested loops and numeric operations.

    Constructing a wrapper on the main thread installs a SIGINT handler
    that sets ``interrupted`` instead of raising KeyboardInterrupt.
    """

    def __init__(self, config: Optional["SafetyConfig"] = None):
        """Create a wrapper using `config` (defaults to SafetyConfig())."""
        self.config = config if config is not None else SafetyConfig()
        self.start_time = time.time()
        self.operation_count = 0
        self.interrupted = False
        self.process = psutil.Process()

        # psutil measures CPU usage relative to the previous call, and the
        # very first call returns a meaningless 0.0. Prime it so that later
        # readings are real.
        try:
            self.process.cpu_percent()
        except psutil.Error as e:
            logger.debug(f"Could not prime CPU usage counter: {e}")

        # Thermal monitoring
        if self.config.enable_thermal_monitoring:
            self.health_monitor = SystemHealthMonitor()
        else:
            self.health_monitor = None

        # Set up interrupt handling. signal.signal() raises ValueError when
        # called outside the main thread; the wrapper is still usable there,
        # it just cannot intercept Ctrl+C.
        try:
            signal.signal(signal.SIGINT, self._handle_interrupt)
        except ValueError:
            logger.debug("Not on the main thread; SIGINT handler not installed")

        logger.info(f"ResourceSafetyWrapper initialized with limits: "
                   f"Memory={self.config.max_memory_gb}GB, "
                   f"CPU={self.config.max_cpu_percent}%, "
                   f"Time={self.config.max_execution_time}s, "
                   f"Thermal={'enabled' if self.health_monitor else 'disabled'}")

    def _handle_interrupt(self, signum, frame):
        """Handle interrupt signals by requesting a cooperative stop."""
        logger.warning("Interrupt signal received, setting stop flag")
        self.interrupted = True

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _element_count(dims) -> int:
        """Product of `dims` using Python ints, which cannot overflow.

        Negative entries count as 0, matching an empty ``range``.
        """
        total = 1
        for dim in dims:
            total *= max(0, int(dim))
        return total

    @staticmethod
    def _iter_indices(ranges: List[int]):
        """Lazily yield index tuples in nested-loop order (last index fastest).

        Nothing is materialised, so very large ranges cost no memory. An
        empty `ranges` yields a single empty tuple.
        """
        if any(r <= 0 for r in ranges):
            return
        current = [0] * len(ranges)
        while True:
            yield tuple(current)
            # Odometer-style increment, carrying leftwards.
            pos = len(ranges) - 1
            while pos >= 0:
                current[pos] += 1
                if current[pos] < ranges[pos]:
                    break
                current[pos] = 0
                pos -= 1
            if pos < 0:
                return

    @staticmethod
    def _format_index(idx: Tuple[int, ...]) -> str:
        """Render an index tuple for log messages."""
        if len(idx) == 1:
            return str(idx[0])
        if len(idx) in (2, 3):
            return "(" + ",".join(str(i) for i in idx) + ")"
        return str(idx)

    # ------------------------------------------------------------------
    # Resource inspection
    # ------------------------------------------------------------------

    def get_current_resource_usage(self) -> Dict[str, Any]:
        """Get current process resource usage.

        Returns a dict with memory_gb, cpu_percent, elapsed_time and the
        flags memory_exceeded, cpu_exceeded, time_exceeded, should_stop.
        If the measurement fails, returns ``{'error': <message>}`` instead.
        """
        try:
            memory_gb = self.process.memory_info().rss / (1024**3)
            cpu_percent = self.process.cpu_percent()
            elapsed_time = time.time() - self.start_time

            return {
                'memory_gb': memory_gb,
                'cpu_percent': cpu_percent,
                'elapsed_time': elapsed_time,
                'memory_exceeded': memory_gb > self.config.max_memory_gb,
                'cpu_exceeded': cpu_percent > self.config.max_cpu_percent,
                'time_exceeded': elapsed_time > self.config.max_execution_time,
                'should_stop': self.interrupted
            }
        except Exception as e:
            # Best effort: monitoring failures must not crash the workload.
            logger.error(f"Error getting resource usage: {e}")
            return {'error': str(e)}

    def should_continue(self) -> bool:
        """Check if operation should continue based on resources and temperature.

        An interrupt stops work on the very next call. The full (more
        expensive) health check runs every `memory_check_frequency` calls
        and may sleep or block while the CPU is warm or hot.
        """
        self.operation_count += 1

        # An interrupt must take effect immediately, not only at the next
        # periodic health check.
        if self.interrupted:
            return False

        # Guard against a zero/negative frequency (modulo by zero).
        check_every = max(1, int(self.config.memory_check_frequency))
        if self.operation_count % check_every != 0:
            return True

        # Combined memory + thermal check
        pressure = self.check_memory_pressure()

        if pressure.get('should_stop', False):
            return False

        if pressure.get('memory_exceeded', False):
            logger.warning(
                f"Memory limit exceeded: {pressure['memory_gb']:.2f}GB"
            )
            return False

        if pressure.get('time_exceeded', False):
            logger.warning(
                f"Time limit exceeded: {pressure['elapsed_time']:.1f}s"
            )
            return False

        # Thermal gating
        thermal = pressure.get('thermal_state')
        if thermal == ThermalState.HOT:
            logger.warning(
                f"CPU HOT ({pressure.get('cpu_temp_c', '?')}°C) — "
                f"waiting for cooldown"
            )
            if self.health_monitor:
                self.health_monitor.wait_until_cool()

        elif thermal == ThermalState.WARM:
            logger.info(
                f"CPU WARM ({pressure.get('cpu_temp_c', '?')}°C) — "
                f"inserting cooldown pause"
            )
            pause_source = self.health_monitor or SystemHealthMonitor
            time.sleep(pause_source.COOLDOWN_PAUSE)

        # Auto garbage collection
        if pressure.get('memory_gb', 0) > self.config.auto_gc_threshold:
            logger.info("Triggering automatic garbage collection")
            gc.collect()

        # CPU yielding
        if pressure.get('cpu_percent', 0) > self.config.cpu_yield_threshold:
            time.sleep(0.01)

        return True

    def check_memory_pressure(self) -> Dict[str, Any]:
        """
        Combined memory AND temperature health check.

        Every call checks:
        - Process memory usage vs configured limits
        - CPU utilization
        - Elapsed execution time
        - CPU temperature and thermal state

        Returns
        -------
        dict
            Merged resource + thermal status with keys:
            memory_gb, cpu_percent, elapsed_time, memory_exceeded,
            cpu_exceeded, time_exceeded, should_stop,
            cpu_temp_c, thermal_state, worker_scale
            (the resource keys are replaced by 'error' if measuring failed)
        """
        usage = self.get_current_resource_usage()

        # Add thermal data
        if self.health_monitor:
            usage['cpu_temp_c'] = self.health_monitor.read_cpu_temp()
            usage['thermal_state'] = self.health_monitor.thermal_state
            usage['worker_scale'] = self.health_monitor.recommended_worker_scale()
        else:
            usage['cpu_temp_c'] = None
            usage['thermal_state'] = ThermalState.OK
            usage['worker_scale'] = 1.0

        return usage

    @contextmanager
    def safe_operation(self, operation_name: str):
        """Context manager for safe operations with monitoring.

        Logs start and completion, logs and re-raises any exception, and
        runs a garbage collection on exit when memory use is above 80% of
        the configured limit (and immediately on MemoryError).
        """
        start_time = time.time()
        logger.info(f"Starting safe operation: {operation_name}")

        try:
            yield self
            elapsed = time.time() - start_time
            logger.info(f"Completed operation '{operation_name}' in {elapsed:.2f}s")

        except KeyboardInterrupt:
            logger.warning(f"Operation '{operation_name}' interrupted by user")
            self.interrupted = True
            raise

        except MemoryError as e:
            logger.error(f"Memory error in operation '{operation_name}': {e}")
            gc.collect()  # Emergency cleanup
            raise

        except Exception as e:
            logger.error(f"Error in operation '{operation_name}': {e}")
            raise

        finally:
            # Check final resource state
            usage = self.get_current_resource_usage()
            if usage.get('memory_gb', 0) > self.config.max_memory_gb * 0.8:
                gc.collect()

    def safe_array_allocation(self, shape: Tuple[int, ...], dtype=np.complex128,
                            fill_value=None) -> Optional[np.ndarray]:
        """Safely allocate numpy array with memory checking.

        Returns the new array (zero-filled, or filled with `fill_value`),
        or None if the array would not fit in the remaining memory budget
        or the allocation fails for any reason.
        """
        try:
            # Check memory requirement. The element count is computed with
            # Python ints: a fixed-width product can wrap around for huge
            # shapes and make an impossible request look tiny.
            dims = (shape,) if isinstance(shape, (int, np.integer)) else tuple(shape)
            element_size = np.dtype(dtype).itemsize
            required_gb = (self._element_count(dims) * element_size) / (1024**3)

            current_memory = self.get_current_resource_usage().get('memory_gb', 0)
            available = self.config.max_memory_gb - current_memory

            if required_gb > available:
                logger.warning(f"Array allocation blocked: need {required_gb:.3f}GB, "
                             f"only {available:.3f}GB available")
                return None

            if required_gb > 1.0:  # Large allocation warning
                logger.info(f"Allocating large array: {shape} ({required_gb:.3f}GB)")

            # Allocate array
            if fill_value is not None:
                return np.full(shape, fill_value, dtype=dtype)
            return np.zeros(shape, dtype=dtype)

        except MemoryError as e:
            logger.error(f"Memory allocation failed for shape {shape}: {e}")
            gc.collect()
            return None
        except Exception as e:
            logger.error(f"Array allocation error: {e}")
            return None

    def safe_nested_loop_execution(self, ranges: List[int], operation: Callable,
                                 max_total_iterations: int = 1000000,
                                 chunk_size: int = None) -> List[Any]:
        """Execute nested loops safely with automatic chunking and monitoring.

        Calls ``operation(i, j, ...)`` for every index combination of
        `ranges` (last index varying fastest) and returns the collected
        results. If the full index space exceeds `max_total_iterations`
        the ranges are scaled down proportionally, and the limit is also
        enforced as a hard cap on attempted calls. Iteration stops early
        when should_continue() returns False. An operation that raises is
        logged and skipped.
        """
        ranges = [int(r) for r in ranges]
        iteration_cap = max(0, int(max_total_iterations))
        total_iterations = self._element_count(ranges)

        if not chunk_size:
            default_chunk = iteration_cap // 10 if total_iterations > 10 else 1
            # At least 1: the chunk size is used as a modulus below.
            chunk_size = max(1, min(self.config.chunk_size_limit, default_chunk))

        if total_iterations > iteration_cap:
            logger.warning(f"Limiting iterations from {total_iterations} to {max_total_iterations}")
            if ranges:
                # Scale down ranges proportionally
                scale = (iteration_cap / total_iterations) ** (1.0 / len(ranges))
                ranges = [max(1, int(r * scale)) for r in ranges]
                total_iterations = self._element_count(ranges)

        logger.info(f"Executing nested loops: {ranges} ({total_iterations} total iterations)")

        results = []
        processed = 0

        try:
            for attempted, idx in enumerate(self._iter_indices(ranges)):
                # Clamping a range to 1 can leave the scaled product above
                # the limit, so the cap is enforced here as well.
                if attempted >= iteration_cap:
                    break

                if not self.should_continue():
                    logger.warning(f"Loop stopped early at iteration {processed}")
                    break

                try:
                    results.append(operation(*idx))
                except Exception as e:
                    logger.error(
                        f"Error in operation at iteration {self._format_index(idx)}: {e}"
                    )
                    continue

                processed += 1
                if processed % chunk_size == 0:
                    time.sleep(0.001)  # Brief yield

        except Exception as e:
            logger.error(f"Critical error in nested loop execution: {e}")
            raise

        logger.info(f"Nested loop completed: {processed}/{total_iterations} iterations processed")
        return results

    def safe_mathematical_operation(self, operation: Callable, *args,
                                  handle_overflow: bool = True, **kwargs) -> Any:
        """Execute mathematical operations with error handling.

        Calls ``operation(*args, **kwargs)``. With `handle_overflow` set:
        OverflowError / ZeroDivisionError / FloatingPointError yield 0.0,
        NaN entries in float/complex arrays become 0.0 and infinities
        become ±1e10, and a NaN/inf scalar becomes 0.0. Without it the
        result is returned unchanged and arithmetic errors are re-raised.
        Any other exception is logged and re-raised.
        """
        try:
            result = operation(*args, **kwargs)
        except (OverflowError, ZeroDivisionError, FloatingPointError) as e:
            logger.warning(f"Mathematical error: {e}")
            if handle_overflow:
                return 0.0
            raise
        except Exception as e:
            logger.error(f"Unexpected error in mathematical operation: {e}")
            raise

        # Check for problematic results. Only float/complex data can hold
        # NaN or inf; integer, string and object data are passed through.
        if isinstance(result, np.ndarray):
            if np.issubdtype(result.dtype, np.inexact):
                has_nan = bool(np.isnan(result).any())
                has_inf = bool(np.isinf(result).any())
                if has_nan:
                    logger.warning("Operation produced NaN values")
                if has_inf:
                    logger.warning("Operation produced infinite values")
                if handle_overflow and (has_nan or has_inf):
                    # One call with all replacements: sanitising NaN first
                    # would also replace inf with the float maximum.
                    result = np.nan_to_num(result, nan=0.0,
                                           posinf=1e10, neginf=-1e10)

        elif isinstance(result, (float, complex, np.floating, np.complexfloating)):
            if np.isnan(result) or np.isinf(result):
                logger.warning(f"Operation produced invalid result: {result}")
                if handle_overflow:
                    result = 0.0

        return result

    def get_safe_grid_size(self, desired_size: int, dtype=np.complex128,
                          dimensions: int = 3) -> int:
        """Calculate safe grid size based on available memory.

        Returns the per-axis size of a `dimensions`-dimensional grid of
        `dtype` that fits in 70% of the remaining memory budget, capped at
        `desired_size` and never less than 8.
        """
        element_size = np.dtype(dtype).itemsize
        current_usage = self.get_current_resource_usage()
        available_gb = self.config.max_memory_gb - current_usage.get('memory_gb', 0)

        # Leave some margin. Clamp at zero: when usage already exceeds the
        # limit the budget is negative, and a fractional power of a negative
        # number is complex, which int() cannot convert.
        usable_gb = max(0.0, available_gb) * 0.7

        # Calculate maximum elements that fit in available memory
        max_elements = int((usable_gb * 1024**3) / element_size)

        # Calculate maximum grid size for given dimensions
        max_grid_size = int(max_elements ** (1.0 / dimensions))

        safe_size = min(desired_size, max_grid_size)

        if safe_size < desired_size:
            logger.warning(f"Reducing grid size from {desired_size} to {safe_size} "
                         f"(available memory: {available_gb:.2f}GB)")

        return max(8, safe_size)  # Minimum useful size

    def status_report(self) -> Dict[str, Any]:
        """Generate a comprehensive status report including thermal state.

        The report is read-only: 'status' is derived from the measured
        limits rather than by calling should_continue(), which would bump
        the operation counter and could sleep or block on thermal gating.
        """
        pressure = self.check_memory_pressure()
        stopped = any(
            pressure.get(flag, False)
            for flag in ('should_stop', 'memory_exceeded', 'time_exceeded')
        )

        return {
            'resource_usage': pressure,
            'operation_count': self.operation_count,
            'config': {
                'max_memory_gb': self.config.max_memory_gb,
                'max_cpu_percent': self.config.max_cpu_percent,
                'max_execution_time': self.config.max_execution_time,
                'thermal_monitoring': self.config.enable_thermal_monitoring,
            },
            'thermal': (
                self.health_monitor.status() if self.health_monitor else None
            ),
            'status': 'stopped' if stopped else 'running',
        }

# Convenience functions for quick integration
def safe_array(shape, dtype=np.complex128, max_memory_gb=2.0):
    """Allocate an array through a throwaway ResourceSafetyWrapper.

    The wrapper is built with a memory limit of `max_memory_gb`; the
    return value is whatever its safe_array_allocation() returns.
    """
    limits = SafetyConfig(max_memory_gb=max_memory_gb)
    return ResourceSafetyWrapper(limits).safe_array_allocation(shape, dtype)

def safe_loops(ranges, operation, max_memory_gb=2.0, max_iterations=100000):
    """Run nested loops through a throwaway ResourceSafetyWrapper.

    The wrapper is built with a memory limit of `max_memory_gb`, and
    `max_iterations` is forwarded as its total-iteration limit.
    """
    limits = SafetyConfig(max_memory_gb=max_memory_gb)
    guard = ResourceSafetyWrapper(limits)
    return guard.safe_nested_loop_execution(
        ranges, operation, max_total_iterations=max_iterations
    )

def monitor_resources():
    """Return a resource + thermal snapshot from a default-configured wrapper."""
    return ResourceSafetyWrapper().check_memory_pressure()

if __name__ == "__main__":
    # Small walkthrough: thermal monitor first, then the safety wrapper.
    print("=== System Resource Monitor Demo ===")

    # --- Standalone thermal monitor ---
    print("\n--- Thermal Monitor ---")
    thermal_monitor = SystemHealthMonitor()
    reading = thermal_monitor.read_cpu_temp()
    if reading is None:
        print("No thermal zones found — monitoring unavailable")
    else:
        print(f"CPU Temperature : {reading:.1f}°C")
        print(f"Thermal State   : {thermal_monitor.thermal_state.value}")
        print(f"Worker Scale    : {thermal_monitor.recommended_worker_scale():.1f}x")
    print(f"Thermal Status  : {thermal_monitor.status()}")

    # --- Safety wrapper with thermal integration ---
    print("\n--- Resource Safety Wrapper ---")
    demo_wrapper = ResourceSafetyWrapper(
        SafetyConfig(max_memory_gb=1.0, max_execution_time=30)
    )

    snapshot = demo_wrapper.check_memory_pressure()
    print(f"Memory          : {snapshot.get('memory_gb', 0):.2f} GB")
    print(f"CPU Temp        : {snapshot.get('cpu_temp_c', 'N/A')}")
    print(f"Thermal State   : {snapshot.get('thermal_state', 'N/A')}")
    print(f"Worker Scale    : {snapshot.get('worker_scale', 1.0):.1f}x")

    # --- Guarded allocation ---
    print("\nTesting safe array allocation...")
    cube = demo_wrapper.safe_array_allocation((100, 100, 100))
    if cube is None:
        print("Array allocation blocked by safety limits")
    else:
        print(f"Successfully allocated array: {cube.shape}")

    # --- Guarded 2-D loop ---
    print("\nTesting safe loop execution...")

    def _multiply_indices(row, col):
        """Toy workload: product of the two loop indices."""
        return row * col

    products = demo_wrapper.safe_nested_loop_execution([10, 10], _multiply_indices)
    print(f"Loop completed with {len(products)} results")

    print("\nFinal status:", demo_wrapper.status_report())