#!/usr/bin/env python3
"""
Enhanced Test Runner for Comprehensive Notebook Testing

Features:
- Phased execution (regular notebooks first, advanced last)
- Automatic result organization within runs/ folder
- Resume capability with checkpoint system
- Intelligent error analysis and classification
- Real-time terminal output analysis
- Enhanced CSV logging with detailed metrics
- Robust error handling and test isolation
"""

import subprocess
import datetime
import csv
import os
import json
import time
import sys
import platform
import glob
import shutil
import re
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Set
import pandas as pd
import argparse
from dataclasses import dataclass, asdict
from enum import Enum

# --- CONFIGURATION ---

# Enhanced notebook batch definitions with phased execution
# Each batch is a dict with the following keys:
#   "name"      - human-readable batch label; copied into TestCombination.batch_name.
#   "category"  - copied into TestCombination.category (printed and logged to CSV).
#   "library"   - copied into TestCombination.library; also the first directory
#                 level under ORGANIZED_DIR when results are organized.
#   "phase"     - "regular" or "advanced"; copied into TestCombination.phase and
#                 used as the leading sort key in generate_test_combinations().
#   "notebooks" - notebook paths relative to the working directory; paths that do
#                 not exist on disk are dropped by generate_test_combinations().
#   "run_last"  - optional flag; NOTE(review): nothing in the visible part of this
#                 file reads it - confirm whether it is used elsewhere.
NOTEBOOK_BATCHES = [
    {
        "name": "OBP Notebooks",
        "category": "OBP",
        "library": "OBP",
        "phase": "regular",
        "notebooks": [
            "notebooks/synthetic.ipynb",
            "notebooks/obd.ipynb",
            "notebooks/multiclass.ipynb",
        ]
    },
    {
        "name": "Basic and RTB Notebooks",
        "category": "Regular",
        "library": "Scope-RL",
        "phase": "regular",
        "notebooks": [
            "notebooks/basic/basic_synthetic_continuous_basic.ipynb",
            "notebooks/basic/basic_synthetic_discrete_basic.ipynb",
            "notebooks/rec/rec_synthetic_discrete_basic.ipynb",
            "notebooks/rtb/rtb_synthetic_continuous_basic.ipynb",
            "notebooks/rtb/rtb_synthetic_discrete_basic.ipynb",
        ]
    },
    {
        "name": "Zoo Notebooks",
        "category": "Regular",
        "library": "Scope-RL",
        "phase": "regular",
        "notebooks": [
            "notebooks/basic/basic_synthetic_continuous_zoo.ipynb",
            "notebooks/basic/basic_synthetic_discrete_zoo.ipynb",
            "notebooks/rec/rec_synthetic_discrete_zoo.ipynb",
            "notebooks/rtb/rtb_synthetic_continuous_zoo.ipynb",
            "notebooks/rtb/rtb_synthetic_discrete_zoo.ipynb",
        ]
    },
    {
        "name": "Advanced Notebooks - RUN LAST",
        "category": "Advanced",
        "library": "Scope-RL",
        "phase": "advanced",
        "run_last": True,
        "notebooks": [
            "notebooks/basic/basic_synthetic_continuous_advanced.ipynb",
            "notebooks/basic/basic_synthetic_discrete_advanced.ipynb",
            "notebooks/rec/rec_synthetic_discrete_advanced.ipynb",
            "notebooks/rtb/rtb_synthetic_continuous_advanced.ipynb",
            "notebooks/rtb/rtb_synthetic_discrete_advanced.ipynb",
        ]
    }
]

# Test configuration
# Frameworks to test. None means no "-fw" flag is passed to start.py; such runs
# are labelled "default" in test keys, CSV rows and organized directory names.
FRAMEWORKS = [None, "autogen", "crewai", "two_agent"]
# Values passed to start.py via "-opt" (not passed for the "two_agent" framework,
# which is only combined with "whole_code").
OPTIONS = ["whole_code", "manual_patch", "agent_applies"]
# Values passed to start.py via "-n" ("two_agent" is only run with 1).
ITERATIONS = [1, 2, 3]
# Model name passed to start.py as the second positional argument.
MODEL = "gemini-1.5-flash"
# Per-test time budget in seconds (7200 s = 2 hours); used by
# run_single_test_with_resilience() when reporting a timeout.
TIMEOUT_SECONDS = 7200

# Output configuration
# Directory that receives the checkpoint file and the CSV log.
OUTPUT_DIR = "tests"
# Directory scanned before/after each test to find the newly created run folder.
RUNS_DIR = "runs"
# Root of the organized result tree that run folders are moved into.
ORGANIZED_DIR = os.path.join(RUNS_DIR, "organized_results")
# JSON file holding completed/failed test keys for the resume feature.
CHECKPOINT_FILE = os.path.join(OUTPUT_DIR, "testing_checkpoint.json")
# The timestamp is evaluated once, when this module is imported, so every
# process start gets its own CSV log file.
CSV_FILENAME = f"comprehensive_test_log_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.csv"
OUTPUT_CSV_PATH = os.path.join(OUTPUT_DIR, CSV_FILENAME)

# --- DATA STRUCTURES ---

class TestStatus(Enum):
    """Test execution status enumeration.

    The string value of each member is what TestResult.to_csv_row() writes
    into the "status" column of the CSV log.
    """
    # Initial status of a freshly constructed TestResult.
    PENDING = "pending"
    # Set right before the start.py subprocess is launched.
    RUNNING = "running"
    # Exit code 0, success marker seen, CSV check passed, no "warning" in output
    # (also used for dry runs).
    COMPLETE = "complete"
    # Success marker seen, but the CSV check failed or the output mentions "warning".
    COMPLETE_WITH_WARNINGS = "complete_with_warnings"
    # Non-zero exit code, or the success marker was never printed.
    FAILED = "failed"
    # Set by the subprocess.TimeoutExpired handler of the test runner.
    TIMEOUT = "timeout"
    # NOTE(review): not assigned anywhere in the visible part of this file.
    ERROR = "error"
    # An unexpected exception was raised while running a test.
    CRITICAL_ERROR = "critical_error"
    # NOTE(review): not assigned anywhere in the visible part of this file.
    SKIPPED = "skipped"

class ErrorCategory(Enum):
    """Error classification categories.

    The string value of each member is what TestResult.to_csv_row() writes
    into the "error_category" column of the CSV log.
    """
    # No error pattern matched; also the initial value of a new TestResult.
    SUCCESS = "success"
    # The next five categories (plus ENVIRONMENT_ERROR and SCOPE_RL_SPECIFIC
    # below) are assigned by analyze_terminal_output() from regex matches
    # against the lower-cased terminal output.
    IMPORT_ERROR = "import_error"
    API_RATE_LIMIT = "api_rate_limit"
    METRIC_PARSING = "metric_parsing"
    FRAMEWORK_ERROR = "framework_error"
    # Set by the subprocess.TimeoutExpired handler of the test runner.
    TIMEOUT_ERROR = "timeout_error"
    ENVIRONMENT_ERROR = "environment_error"
    # Set when the test runner catches an unexpected exception.
    UNKNOWN_ERROR = "unknown_error"
    # Exit code 0 but the success marker was missing and no error pattern matched.
    NO_PLOTS_SAVED = "no_plots_saved"
    SCOPE_RL_SPECIFIC = "scope_rl_specific"

@dataclass
class TestCombination:
    """One cell of the test matrix together with its batch metadata.

    A combination is identified by notebook, framework, option and iteration;
    batch_name, category, library and phase are descriptive fields taken from
    the batch definition the notebook belongs to.
    """
    notebook: str
    framework: Optional[str]  # None means "no framework flag"
    option: str
    iteration: int
    batch_name: str
    category: str
    library: str
    phase: str

    def to_key(self) -> str:
        """Return the '#'-separated identifier used for checkpoint bookkeeping.

        A missing (or empty) framework is rendered as "default".
        """
        framework_label = self.framework if self.framework else "default"
        key_parts = (self.notebook, framework_label, self.option, self.iteration)
        return "#".join(str(part) for part in key_parts)

    def to_dict(self) -> Dict:
        """Return all fields as a plain dict (suitable for JSON serialization)."""
        return asdict(self)

@dataclass
class TestResult:
    """Everything recorded about one executed (or dry-run) test combination."""
    test_combination: TestCombination
    status: TestStatus
    error_category: ErrorCategory
    start_time: datetime.datetime
    end_time: Optional[datetime.datetime]
    duration: Optional[float]  # seconds
    run_id: str
    original_run_path: str
    organized_path: str
    command: str
    return_code: Optional[int]
    plots_saved_marker_found: bool
    error_message: str
    terminal_output_summary: str
    csv_validation_passed: bool
    performance_metrics: Dict[str, str]

    def to_csv_row(self) -> Dict[str, str]:
        """Flatten this result into a dict of strings for csv.DictWriter.

        Performance metrics are merged in last, so a metric whose name equals
        one of the fixed columns replaces that column's value.
        """
        stamp_format = "%Y-%m-%d %H:%M:%S"
        combo = self.test_combination

        started = self.start_time.strftime(stamp_format)
        # Without an end time the start time is reported for both columns.
        if self.end_time:
            ended = self.end_time.strftime(stamp_format)
        else:
            ended = started

        # A missing or zero duration is reported as the fixed string below.
        if self.duration:
            elapsed = str(datetime.timedelta(seconds=int(self.duration)))
        else:
            elapsed = "00:00:00"

        if self.return_code is not None:
            return_code_text = str(self.return_code)
        else:
            return_code_text = ""

        row = {
            "notebook_name": combo.notebook,
            "framework": combo.framework or "default",
            "option": combo.option,
            "iteration": str(combo.iteration),
            "batch_name": combo.batch_name,
            "category": combo.category,
            "library": combo.library,
            "phase": combo.phase,
            "command": self.command,
            "timestamp": self.start_time.strftime("%H:%M"),
            "time_started": started,
            "time_ended": ended,
            "time_taken": elapsed,
            "status": self.status.value,
            "error_category": self.error_category.value,
            "return_code": return_code_text,
            "plots_saved_marker_found": str(self.plots_saved_marker_found),
            "csv_validation_passed": str(self.csv_validation_passed),
            "runs": self.run_id,
            "original_run_path": self.original_run_path,
            "organized_path": self.organized_path,
            # Long free-text fields are truncated to keep the CSV readable.
            "error_message": self.error_message[:500],
            "terminal_output_summary": self.terminal_output_summary[:300],
        }
        row.update(self.performance_metrics)
        return row

class TestingSession:
    """Manages testing session state and resume capability.

    State is kept in two sets of test keys (completed / failed) plus a running
    index, and is persisted to a JSON checkpoint file after every recorded
    test so an interrupted session can be resumed.
    """

    def __init__(self, checkpoint_file: str):
        """Create an empty session that persists to *checkpoint_file*."""
        self.checkpoint_file = checkpoint_file
        self.completed_tests: Set[str] = set()
        self.failed_tests: Set[str] = set()
        self.session_start_time = datetime.datetime.now()
        self.total_tests = 0
        self.current_test_index = 0

    def load_checkpoint(self) -> bool:
        """Load previous session state.

        Returns True when a checkpoint was read and applied, False when no
        checkpoint exists or it could not be read. On failure the in-memory
        state is left untouched.
        """
        if not os.path.exists(self.checkpoint_file):
            return False

        try:
            with open(self.checkpoint_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
            # Parse into locals first so a malformed file cannot leave the
            # session half-updated.
            completed = set(data.get('completed_tests', []))
            failed = set(data.get('failed_tests', []))
            current_index = data.get('current_test_index', 0)
        except Exception as e:
            # Best effort: an unreadable checkpoint just means "start fresh".
            print(f"[WARN] Failed to load checkpoint: {e}")
            return False

        self.completed_tests = completed
        self.failed_tests = failed
        self.current_test_index = current_index
        return True

    def save_checkpoint(self):
        """Save current session state (best effort; failures are only logged)."""
        try:
            # dirname is "" for a bare filename, and os.makedirs("") raises,
            # which previously prevented the checkpoint from ever being saved.
            checkpoint_dir = os.path.dirname(self.checkpoint_file)
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)

            # Write to a sibling temp file and rename it into place so an
            # interruption mid-write cannot corrupt the existing checkpoint.
            tmp_path = f"{self.checkpoint_file}.tmp"
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump({
                    'completed_tests': sorted(self.completed_tests),
                    'failed_tests': sorted(self.failed_tests),
                    'current_test_index': self.current_test_index,
                    'session_start_time': self.session_start_time.isoformat(),
                    'last_updated': datetime.datetime.now().isoformat()
                }, f, indent=2)
            os.replace(tmp_path, self.checkpoint_file)
        except Exception as e:
            print(f"[WARN] Failed to save checkpoint: {e}")

    def mark_test_completed(self, test_key: str):
        """Mark a test as completed and persist the checkpoint."""
        self.completed_tests.add(test_key)
        self.failed_tests.discard(test_key)  # Remove from failed if it was there
        self.current_test_index += 1
        self.save_checkpoint()

    def mark_test_failed(self, test_key: str):
        """Mark a test as failed and persist the checkpoint."""
        self.failed_tests.add(test_key)
        # Keep the two sets disjoint: the latest outcome wins, mirroring
        # mark_test_completed().
        self.completed_tests.discard(test_key)
        self.current_test_index += 1
        self.save_checkpoint()

    def is_test_completed(self, test_key: str) -> bool:
        """Check if test was already completed."""
        return test_key in self.completed_tests

    def is_test_failed(self, test_key: str) -> bool:
        """Check if test was previously failed."""
        return test_key in self.failed_tests

    def get_progress_summary(self) -> Dict[str, int]:
        """Get current progress statistics.

        'remaining' counts each recorded test once even if an older
        checkpoint lists it in both sets, and never drops below zero (a
        checkpoint may contain keys that are no longer part of the matrix).
        """
        recorded = len(self.completed_tests | self.failed_tests)
        return {
            'total_tests': self.total_tests,
            'completed': len(self.completed_tests),
            'failed': len(self.failed_tests),
            'remaining': max(0, self.total_tests - recorded),
            'current_index': self.current_test_index
        }

# --- UTILITY FUNCTIONS ---

def get_notebook_category_info(notebook_path: str) -> Tuple[str, str, str]:
    """
    Extract category information from notebook path.

    Rules:
    - no "notebooks" component in the path -> ("unknown", "unknown", stem)
    - notebook sits directly in notebooks/ (OBP case) -> ("OBP", stem, stem)
    - notebook sits in a sub-directory (Scope-RL case) -> ("Scope", sub_dir, stem),
      with "_advanced" appended to sub_dir when the stem contains "advanced"

    Returns: (main_category, sub_category, notebook_name)
    """
    path = Path(notebook_path)
    path_parts = path.parts
    notebook_name = path.stem

    if 'notebooks' not in path_parts:
        return "unknown", "unknown", notebook_name

    notebook_idx = path_parts.index('notebooks')
    # Everything after the "notebooks" component: either just the file name,
    # or one or more sub-directories followed by the file name.
    trailing_parts = path_parts[notebook_idx + 1:]

    if len(trailing_parts) <= 1:
        # Direct notebook in notebooks/ folder (OBP case). The previous check
        # compared against the wrong length, so the file name itself was
        # treated as a sub-directory and this branch was never taken for
        # paths like "notebooks/synthetic.ipynb".
        return "OBP", notebook_name, notebook_name

    # Scope-RL case with subdirectories
    sub_dir = trailing_parts[0]

    if "advanced" in notebook_name:
        return "Scope", f"{sub_dir}_advanced", notebook_name
    return "Scope", sub_dir, notebook_name

def get_organized_path(test_combination: TestCombination, run_id: str) -> str:
    """
    Build the destination directory for a run inside the organized tree.

    Layout:
    ORGANIZED_DIR/[library]/[main_category]/[sub_category]/[notebook_name]/
        [framework]_[option]_[iteration]_[run_id]

    main_category, sub_category and notebook_name come from
    get_notebook_category_info(); a missing framework is written as "default".
    The path is only computed here, nothing is created on disk.
    """
    category_levels = get_notebook_category_info(test_combination.notebook)

    framework_label = test_combination.framework or "default"
    leaf_dir = "_".join(
        str(piece)
        for piece in (
            framework_label,
            test_combination.option,
            test_combination.iteration,
            run_id,
        )
    )

    return os.path.join(ORGANIZED_DIR, test_combination.library, *category_levels, leaf_dir)

def analyze_terminal_output(output_lines: List[str]) -> Tuple[ErrorCategory, str, Dict[str, str]]:
    """
    Analyze terminal output for error patterns and performance metrics.

    The output is lower-cased and searched with regular expressions:
    - error categories are tried in the order they are listed below and the
      first category with a matching pattern wins (SUCCESS if none match);
    - performance metrics are extracted independently of the error category.

    Args:
        output_lines: captured terminal lines (with or without trailing newlines).

    Returns: (error_category, summary, performance_metrics)
        summary is a "; "-joined list of at most three short findings;
        performance_metrics maps metric name to the captured text and only
        contains the metrics that were found.
    """
    full_output_lower = "\n".join(output_lines).lower()

    # Error pattern detection. Dict order defines the priority of categories.
    error_patterns = {
        ErrorCategory.IMPORT_ERROR: [
            r"importerror", r"modulenotfounderror", r"no module named",
            r"cannot import name", r"import.*failed"
        ],
        ErrorCategory.API_RATE_LIMIT: [
            r"rate limit", r"quota.*exceeded", r"too many requests",
            r"api.*limit",
            # Match 429 only as a standalone number: a bare "429" also hit
            # unrelated digits such as "1429" or "0.4291".
            r"(?<![\d.])429(?!\d)",
            r"rate.*exceeded"
        ],
        ErrorCategory.METRIC_PARSING: [
            r"could not determine best iteration", r"parsing.*error",
            r"invalid.*metric", r"metric.*not found", r"failed.*parse.*metric"
        ],
        ErrorCategory.FRAMEWORK_ERROR: [
            r"autogen.*error", r"crewai.*error", r"framework.*failed",
            r"agent.*error", r"conversation.*failed"
        ],
        ErrorCategory.SCOPE_RL_SPECIFIC: [
            r"scope.*rl.*error",
            # "ope" must not be part of a longer word: the bare substring also
            # matched "open", "proper", "scope", ... followed by "error".
            r"(?<![a-z])ope(?![a-z]).*error",
            r"baseline.*error",
            r"estimator.*error", r"policy.*error"
        ],
        ErrorCategory.ENVIRONMENT_ERROR: [
            r"environment.*error", r"conda.*error", r"pip.*error",
            r"virtual.*environment", r"python.*path"
        ]
    }

    performance_metrics = {}
    summary_parts = []

    # First category with any matching pattern wins.
    error_category = next(
        (
            category
            for category, patterns in error_patterns.items()
            if any(re.search(pattern, full_output_lower) for pattern in patterns)
        ),
        ErrorCategory.SUCCESS,
    )
    if error_category != ErrorCategory.SUCCESS:
        summary_parts.append(f"Detected {error_category.value}")

    # Extract performance metrics where possible
    metric_patterns = {
        'execution_time': r"execution.*time[:\s]+([0-9:]+)",
        'iterations_completed': r"completed.*iteration[s]?[:\s]+(\d+)",
        'plots_generated': r"plots?.*generated[:\s]+(\d+)",
        'memory_usage': r"memory.*usage[:\s]+([0-9.]+[gmk]?b)",
    }

    for metric_name, pattern in metric_patterns.items():
        match = re.search(pattern, full_output_lower)
        if match:
            performance_metrics[metric_name] = match.group(1)

    # Generate summary
    if not summary_parts:
        if "all plots saved in directory:" in full_output_lower:
            summary_parts.append("Success marker found")
        else:
            summary_parts.append("No clear success/error indicators")

    # Add key information to summary
    if "warning" in full_output_lower:
        summary_parts.append("Warnings detected")

    summary = "; ".join(summary_parts[:3])  # Limit summary length

    return error_category, summary, performance_metrics

def verify_results_csv(run_dir: str, expected_iterations: int) -> Tuple[bool, str]:
    """Check that a run directory contains a usable ``*_results.csv`` file.

    The first file matching ``*_results.csv`` in *run_dir* is loaded with
    pandas and must be non-empty, contain at least *expected_iterations*
    distinct values in its ``iteration`` column (a missing column counts as
    zero) and have at least that many rows.

    Returns: (passed, message) - message explains the failure or summarizes
    the row/iteration counts on success. Any exception while reading the
    file is reported as a failed validation rather than raised.
    """
    if not os.path.exists(run_dir):
        return False, "Run directory does not exist"

    candidates = glob.glob(os.path.join(run_dir, "*_results.csv"))
    if len(candidates) == 0:
        return False, "No results CSV file found"

    try:
        frame = pd.read_csv(candidates[0])
        if frame.empty:
            return False, "Results CSV is empty"

        row_count = len(frame)

        # Number of distinct iteration values (0 when the column is absent).
        if 'iteration' in frame.columns:
            iteration_count = len(frame['iteration'].unique())
        else:
            iteration_count = 0

        if iteration_count < expected_iterations:
            return False, f"Expected {expected_iterations} iterations, found {iteration_count}"

        if row_count < expected_iterations:
            return False, f"Insufficient data rows: {row_count} < {expected_iterations}"

        return True, f"CSV validation passed: {row_count} rows, {iteration_count} iterations"

    except Exception as exc:
        return False, f"CSV validation error: {str(exc)}"

def organize_test_results(test_combination: TestCombination, original_run_path: str, run_id: str) -> str:
    """
    Relocate a finished run directory into the organized results tree.

    Returns the new location on success, the unchanged original path when the
    move fails (the error is printed, not raised), and an empty string when
    the original path does not exist.
    """
    source_missing = not os.path.exists(original_run_path)
    if source_missing:
        print(f"[WARN] Original run path does not exist: {original_run_path}")
        return ""

    destination = get_organized_path(test_combination, run_id)

    try:
        # Only the parent chain is created; shutil.move creates the final
        # directory itself when it relocates the run folder.
        parent_dir = os.path.dirname(destination)
        os.makedirs(parent_dir, exist_ok=True)
        shutil.move(original_run_path, destination)

        print(f"[INFO] Organized results: {destination}")
        return destination

    except Exception as exc:
        print(f"[ERROR] Failed to organize results: {exc}")
        return original_run_path

def detect_execution_issues(
    test_combination: TestCombination,
    run_id: str,
    run_dir: str,
    return_code: int,
    plots_saved_marker_found: bool,
    full_output: str,
    expected_iterations: int
) -> Tuple[TestStatus, ErrorCategory, str]:
    """
    Classify the outcome of a finished test run.

    Checks are applied in this order, and the first one that triggers decides:
    1. non-zero exit code            -> FAILED
    2. success marker not seen       -> FAILED
    3. results CSV fails validation  -> COMPLETE_WITH_WARNINGS
    4. "warning" appears in output   -> COMPLETE_WITH_WARNINGS
    5. otherwise                     -> COMPLETE

    Args:
        test_combination: the combination that was run (currently unused here).
        run_id: name of the run directory (currently unused here).
        run_dir: path of the run directory that should hold the results CSV.
        return_code: exit code of the test process.
        plots_saved_marker_found: whether the success marker line was printed.
        full_output: complete captured terminal output.
        expected_iterations: iteration count passed on to verify_results_csv().

    Returns: (status, error_category, message)
    """
    # Check return code first
    if return_code != 0:
        error_category, summary, _ = analyze_terminal_output(full_output.splitlines())
        # A failed run must never be reported under the "success" category;
        # fall back to UNKNOWN_ERROR when no known error pattern matched.
        if error_category == ErrorCategory.SUCCESS:
            error_category = ErrorCategory.UNKNOWN_ERROR
        return TestStatus.FAILED, error_category, f"Non-zero exit code ({return_code}): {summary}"

    # Check for success marker
    if not plots_saved_marker_found:
        error_category, summary, _ = analyze_terminal_output(full_output.splitlines())
        if error_category == ErrorCategory.SUCCESS:
            error_category = ErrorCategory.NO_PLOTS_SAVED
        return TestStatus.FAILED, error_category, f"Success marker not found: {summary}"

    # Verify CSV results
    csv_ok, csv_message = verify_results_csv(run_dir, expected_iterations)
    if not csv_ok:
        return TestStatus.COMPLETE_WITH_WARNINGS, ErrorCategory.SUCCESS, f"Plots saved but CSV issues: {csv_message}"

    # Check for warnings
    if "warning" in full_output.lower():
        return TestStatus.COMPLETE_WITH_WARNINGS, ErrorCategory.SUCCESS, "Completed with warnings"

    return TestStatus.COMPLETE, ErrorCategory.SUCCESS, "Test completed successfully"

# --- MAIN EXECUTION FUNCTIONS ---

def run_single_test_with_resilience(
    test_combination: TestCombination,
    dry_run: bool = False,
    silent: bool = False
) -> TestResult:
    """
    Run individual test with complete error isolation and detailed result tracking.

    Builds the ``python start.py ...`` command for the combination, streams
    the child's combined stdout/stderr line by line, locates the run directory
    the child created under RUNS_DIR, classifies the outcome and moves the run
    directory into the organized tree.

    Args:
        test_combination: what to run.
        dry_run: only print the command and return a COMPLETE placeholder result.
        silent: do not echo the child's output to this process' stdout.

    Returns:
        A TestResult. Exceptions derived from Exception are never propagated:
        a timeout yields status TIMEOUT, any other error CRITICAL_ERROR.
    """
    # Local import: only the timeout watchdog below needs it.
    import threading

    start_time = datetime.datetime.now()
    cmd_parts = ["python", "start.py", test_combination.notebook, MODEL]

    if test_combination.framework:
        cmd_parts.extend(["-fw", test_combination.framework])
    # Do not pass option flag for two_agent; it runs whole code implicitly
    if test_combination.framework != "two_agent":
        cmd_parts.extend(["-opt", test_combination.option])
    # Always pass iteration; two_agent must be invoked with -n 1
    cmd_parts.extend(["-n", str(test_combination.iteration)])
    cmd_str = " ".join(cmd_parts)

    print(f"\n{'='*80}")
    print(f"[{start_time.strftime('%H:%M:%S')}] Running: {cmd_str}")
    print(f"Category: {test_combination.library} > {test_combination.category}")

    # Initialize result structure
    result = TestResult(
        test_combination=test_combination,
        status=TestStatus.PENDING,
        error_category=ErrorCategory.SUCCESS,
        start_time=start_time,
        end_time=None,
        duration=None,
        run_id="",
        original_run_path="",
        organized_path="",
        command=cmd_str,
        return_code=None,
        plots_saved_marker_found=False,
        error_message="",
        terminal_output_summary="",
        csv_validation_passed=False,
        performance_metrics={}
    )

    if dry_run:
        print("--- DRY RUN MODE ---")
        result.status = TestStatus.COMPLETE
        result.end_time = start_time
        result.duration = 1.0
        result.run_id = "dry_run"
        result.error_message = "Dry run mode"
        return result

    # Track directories to detect new run directory
    dirs_before = set(os.listdir(RUNS_DIR)) if os.path.exists(RUNS_DIR) else set()

    process = None
    watchdog = None
    timed_out = threading.Event()

    try:
        result.status = TestStatus.RUNNING

        # Execute test with real-time output streaming
        output_lines = []
        plots_saved_marker_found = False

        process = subprocess.Popen(
            cmd_parts,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding='utf-8',
            errors='replace',
            cwd=os.getcwd()
        )

        def _kill_on_timeout() -> None:
            """Watchdog callback: flag the timeout and stop the child."""
            timed_out.set()
            try:
                process.kill()
            except OSError:
                # The child exited on its own just before the kill.
                pass

        # The read loop below blocks on the pipe, so the time budget has to be
        # enforced from a separate timer thread. Killing the child closes its
        # end of the pipe, which ends the loop. (Grandchildren that inherited
        # the pipe are not killed by this.)
        watchdog = threading.Timer(TIMEOUT_SECONDS, _kill_on_timeout)
        watchdog.daemon = True
        watchdog.start()

        # Stream output and detect markers
        for line in iter(process.stdout.readline, ""):
            if not silent:
                sys.stdout.write(line)
                sys.stdout.flush()

            output_lines.append(line)
            if "All plots saved in directory:" in line:
                plots_saved_marker_found = True

        process.wait()
        watchdog.cancel()

        if timed_out.is_set():
            # Handled by the TimeoutExpired branch below.
            raise subprocess.TimeoutExpired(cmd_parts, TIMEOUT_SECONDS)

        return_code = process.returncode

        # Calculate duration
        end_time = datetime.datetime.now()
        duration = (end_time - start_time).total_seconds()

        # Detect run directory
        time.sleep(1)  # Allow filesystem to settle
        dirs_after = set(os.listdir(RUNS_DIR)) if os.path.exists(RUNS_DIR) else set()
        new_dirs = dirs_after - dirs_before

        if len(new_dirs) == 1:
            run_id = new_dirs.pop()
        elif len(new_dirs) == 0:
            # Fallback: no new directory appeared, so look for an existing one
            # by name. The first 10 characters of "YYYYmmdd_HHMMSS" are the
            # date, the underscore and the first hour digit.
            prefix = start_time.strftime("%Y%m%d_%H%M%S")
            matching_dirs = [d for d in dirs_after if d.startswith(prefix[:10])]
            # max() gives a deterministic pick (set order is arbitrary) and
            # matches the "most recent" rule used for multiple new directories.
            run_id = max(matching_dirs) if matching_dirs else "unknown"
        else:
            # Multiple directories, take most recent
            run_id = max(new_dirs)

        original_run_path = os.path.join(RUNS_DIR, run_id) if run_id != "unknown" else ""

        # Analyze execution results
        full_output = "".join(output_lines)
        status, error_category, error_message = detect_execution_issues(
            test_combination, run_id, original_run_path, return_code,
            plots_saved_marker_found, full_output, test_combination.iteration
        )

        # Analyze terminal output for additional insights
        _, terminal_summary, performance_metrics = analyze_terminal_output(output_lines)

        # Verify CSV (must happen before the run directory is moved)
        csv_ok, _ = verify_results_csv(original_run_path, test_combination.iteration)

        # Organize results
        organized_path = ""
        if original_run_path and os.path.exists(original_run_path):
            organized_path = organize_test_results(test_combination, original_run_path, run_id)

        # Update result
        result.status = status
        result.error_category = error_category
        result.end_time = end_time
        result.duration = duration
        result.run_id = run_id
        result.original_run_path = original_run_path
        result.organized_path = organized_path
        result.return_code = return_code
        result.plots_saved_marker_found = plots_saved_marker_found
        result.error_message = error_message
        result.terminal_output_summary = terminal_summary
        result.csv_validation_passed = csv_ok
        result.performance_metrics = performance_metrics

    except subprocess.TimeoutExpired:
        result.status = TestStatus.TIMEOUT
        result.error_category = ErrorCategory.TIMEOUT_ERROR
        result.end_time = start_time + datetime.timedelta(seconds=TIMEOUT_SECONDS)
        result.duration = TIMEOUT_SECONDS
        result.error_message = f"Test exceeded {TIMEOUT_SECONDS//60} minute timeout"

    except Exception as e:
        result.status = TestStatus.CRITICAL_ERROR
        result.error_category = ErrorCategory.UNKNOWN_ERROR
        result.end_time = datetime.datetime.now()
        result.duration = (result.end_time - start_time).total_seconds()
        result.error_message = f"Critical error: {str(e)[:200]}"

    finally:
        # Never leave the timer, the child or its pipe behind, whatever
        # happened above (including Ctrl-C, which is not caught here).
        if watchdog is not None:
            watchdog.cancel()
        if process is not None:
            if process.poll() is None:
                try:
                    process.kill()
                except OSError:
                    # The child exited between poll() and kill().
                    pass
                process.wait()
            if process.stdout is not None:
                process.stdout.close()

    # Status summary
    duration_str = str(datetime.timedelta(seconds=int(result.duration))) if result.duration else "00:00:00"
    print(f"[INFO] Test completed in {duration_str} - Status: {result.status.value}")
    if result.status != TestStatus.COMPLETE:
        print(f"[INFO] Error: {result.error_message[:100]}...")

    return result

def generate_test_combinations() -> List[TestCombination]:
    """Generate all test combinations with proper phasing.

    For every notebook in NOTEBOOK_BATCHES that exists on disk, one
    combination is created per framework, option and iteration count. The
    "two_agent" framework is only combined with option "whole_code" and
    iteration 1.

    Returns:
        Combinations ordered by phase ("regular" first, "advanced" last), then
        by notebook, framework (None sorts as ""), option and iteration.
    """
    combinations = []

    for batch in NOTEBOOK_BATCHES:
        existing_notebooks = [nb for nb in batch["notebooks"] if os.path.exists(nb)]

        for notebook in existing_notebooks:
            for framework in FRAMEWORKS:
                # two_agent runs once (-n 1) and as whole_code only
                options_list = ["whole_code"] if framework == "two_agent" else OPTIONS
                iterations_list = [1] if framework == "two_agent" else ITERATIONS
                for option in options_list:
                    for iteration in iterations_list:
                        combination = TestCombination(
                            notebook=notebook,
                            framework=framework,
                            option=option,
                            iteration=iteration,
                            batch_name=batch["name"],
                            category=batch["category"],
                            library=batch["library"],
                            phase=batch["phase"]
                        )
                        combinations.append(combination)

    # Sort by phase (regular first, advanced last). The phase must be ranked
    # explicitly: sorting the raw strings puts "advanced" before "regular".
    # Unrecognized phases run between the two known ones.
    phase_rank = {"regular": 0, "advanced": 2}
    combinations.sort(
        key=lambda x: (
            phase_rank.get(x.phase, 1),
            x.notebook,
            x.framework or "",
            x.option,
            x.iteration,
        )
    )

    return combinations

# Column order of the CSV log. Shared by the header writer and the row writer
# so the two can never drift apart. The last four columns are the optional
# performance metrics extracted from the terminal output.
_CSV_FIELDNAMES = [
    "notebook_name", "framework", "option", "iteration", "batch_name",
    "category", "library", "phase", "command", "timestamp", "time_started",
    "time_ended", "time_taken", "status", "error_category", "return_code",
    "plots_saved_marker_found", "csv_validation_passed", "runs",
    "original_run_path", "organized_path", "error_message",
    "terminal_output_summary", "execution_time", "iterations_completed",
    "plots_generated", "memory_usage"
]


def write_csv_header(filename: str):
    """Create (or truncate) *filename* and write the CSV header row."""
    with open(filename, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=_CSV_FIELDNAMES)
        writer.writeheader()


def write_test_result_to_csv(result: TestResult, filename: str):
    """Append one test result to the CSV log.

    The header is written first when the file does not exist yet. Columns
    missing from the row (e.g. metrics that were not found) are left empty by
    csv.DictWriter; a row key that is not a known column raises ValueError.
    """
    if not os.path.exists(filename):
        write_csv_header(filename)

    with open(filename, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=_CSV_FIELDNAMES)
        writer.writerow(result.to_csv_row())

def handle_resume_options(session: TestingSession) -> str:
    """
    Show the stored progress and ask the user how to continue.

    Prompts until one of the four menu numbers is entered. Ctrl-C at the
    prompt prints a message and exits the process with status 0.

    Returns: 'resume', 'retry_failed', 'skip_failed', 'fresh'
    """
    progress = session.get_progress_summary()
    failed_count = progress['failed']

    # Menu number -> resume mode
    mode_by_choice = {
        "1": "resume",
        "2": "retry_failed",
        "3": "skip_failed",
        "4": "fresh",
    }

    print(f"\n{'='*60}")
    print("PREVIOUS TESTING SESSION DETECTED")
    print(f"{'='*60}")
    print("Progress Summary:")
    print(f"  - Total tests: {progress['total_tests']}")
    print(f"  - Completed: {progress['completed']}")
    print(f"  - Failed: {failed_count}")
    print(f"  - Remaining: {progress['remaining']}")
    print(f"  - Current index: {progress['current_index']}")

    if failed_count > 0:
        print(f"\nFailed tests detected: {failed_count} tests")

    print("\nResume Options:")
    print("  1. Resume from where left off (skip completed tests)")
    print("  2. Retry all failed tests + continue with remaining")
    print("  3. Skip failed tests and continue with remaining only")
    print("  4. Start fresh (ignore previous progress)")

    while True:
        try:
            answer = input("\nSelect option (1-4): ").strip()
            if answer in mode_by_choice:
                return mode_by_choice[answer]
            print("Invalid choice. Please enter 1, 2, 3, or 4.")
        except KeyboardInterrupt:
            print("\nExiting...")
            sys.exit(0)

# Short command-line spellings accepted as aliases for the internal mode names.
_RESUME_MODE_ALIASES = {"retry": "retry_failed", "skip": "skip_failed"}
# Every mode the test-selection logic knows how to handle.
_VALID_RESUME_MODES = ("fresh", "resume", "retry_failed", "skip_failed")

def _select_tests_for_resume_mode(
    all_combinations: "List[TestCombination]",
    session: "TestingSession",
    resume_mode: str
) -> "List[TestCombination]":
    """Return the combinations that must run under ``resume_mode``, in input order.

    In ``retry_failed`` mode every key is also removed from
    ``session.failed_tests`` so that retried tests start with a clean status.
    """
    selected = []
    for combination in all_combinations:
        test_key = combination.to_key()

        if resume_mode == "fresh":
            selected.append(combination)
        elif resume_mode == "resume":
            if not session.is_test_completed(test_key):
                selected.append(combination)
        elif resume_mode == "retry_failed":
            if not session.is_test_completed(test_key):
                selected.append(combination)
            # Reset failed status for retry
            session.failed_tests.discard(test_key)
        elif resume_mode == "skip_failed":
            if not session.is_test_completed(test_key) and not session.is_test_failed(test_key):
                selected.append(combination)
    return selected

def execute_comprehensive_testing(
    dry_run: bool = False,
    auto_yes: bool = False,
    silent: bool = False,
    resume_mode: Optional[str] = None
) -> None:
    """
    Execute comprehensive testing with phased approach and resume capability.

    The run loads the checkpoint, decides which test combinations still have
    to run, executes the "regular" phase followed by the "advanced" phase,
    and prints a final summary.

    Args:
        dry_run: Forwarded to the single-test runner; also suppresses the
            per-phase confirmation prompt and the CSV result rows.
        auto_yes: Skip all interactive prompts. With a checkpoint present and
            no explicit ``resume_mode`` this selects "resume".
        silent: Forwarded to the single-test runner.
        resume_mode: One of "fresh", "resume", "retry_failed", "skip_failed"
            (the short spellings "retry" and "skip" are accepted as aliases).
            ``None`` means: ask the user, or default as described above.
            Ignored (forced to "fresh") when no checkpoint could be loaded.

    Raises:
        ValueError: If ``resume_mode`` is not a recognised mode. Previously an
            unrecognised mode silently selected zero tests.
    """
    print("Starting Enhanced Comprehensive Notebook Testing Suite")
    print("Success Detection: 'All plots saved in directory:' marker + CSV validation")
    print(f"Auto-Organization: Results organized in {ORGANIZED_DIR}")

    if dry_run:
        print("[WARN] DRY RUN MODE ENABLED")
    if silent:
        print("[INFO] SILENT MODE: Suppressing real-time test output")

    print("="*80)

    # Initialize session and load checkpoint
    session = TestingSession(CHECKPOINT_FILE)
    checkpoint_loaded = session.load_checkpoint()

    # Generate all test combinations
    all_combinations = generate_test_combinations()
    session.total_tests = len(all_combinations)

    print(f"[INFO] Total tests configured: {session.total_tests}")

    # Handle resume logic
    if not checkpoint_loaded:
        # Nothing to resume from, whatever the caller asked for.
        resume_mode = "fresh"
    elif resume_mode is None:
        # Default for auto mode is "resume"; otherwise ask interactively.
        resume_mode = "resume" if auto_yes else handle_resume_options(session)

    # Accept the short CLI spellings and reject anything unknown up front,
    # instead of falling through the selection logic with an empty test list.
    resume_mode = _RESUME_MODE_ALIASES.get(resume_mode, resume_mode)
    if resume_mode not in _VALID_RESUME_MODES:
        raise ValueError(
            f"Unknown resume_mode {resume_mode!r}; "
            f"expected one of: {', '.join(_VALID_RESUME_MODES)}"
        )

    print(f"[INFO] Resume mode: {resume_mode}")

    # Create output directories
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(ORGANIZED_DIR, exist_ok=True)

    # Initialize CSV (an existing results file is kept and appended to)
    if not os.path.exists(OUTPUT_CSV_PATH):
        write_csv_header(OUTPUT_CSV_PATH)

    # Filter tests based on resume mode
    tests_to_run = _select_tests_for_resume_mode(all_combinations, session, resume_mode)

    print(f"[INFO] Tests to execute: {len(tests_to_run)}")

    if resume_mode == "fresh":
        # Reset session for fresh start
        session.completed_tests.clear()
        session.failed_tests.clear()
        session.current_test_index = 0

    # Group tests by phase for execution
    regular_tests = [t for t in tests_to_run if t.phase == "regular"]
    advanced_tests = [t for t in tests_to_run if t.phase == "advanced"]

    print(f"[INFO] Regular phase tests: {len(regular_tests)}")
    print(f"[INFO] Advanced phase tests: {len(advanced_tests)}")

    # Execution statistics (shared with, and mutated by, run_phase below)
    start_time = datetime.datetime.now()
    stats = {
        'total_run': 0,
        'completed': 0,
        'completed_with_warnings': 0,
        'failed': 0,
        'errors': 0,
        'timeouts': 0
    }

    def run_phase(phase_tests: List[TestCombination], phase_name: str):
        """Execute a phase of tests, updating ``stats`` and ``session``."""
        if not phase_tests:
            print(f"\n[INFO] No tests to run in {phase_name} phase")
            return

        print(f"\n{'#'*80}")
        print(f"### EXECUTING {phase_name.upper()} PHASE ###")
        print(f"### {len(phase_tests)} tests ###")
        print(f"{'#'*80}")

        if not auto_yes and not dry_run:
            response = input(f"Continue with {phase_name} testing phase? (y/n): ").strip().lower()
            if response != 'y':
                print(f"[CANCEL] {phase_name} phase cancelled by user")
                return

        for combination in phase_tests:
            test_key = combination.to_key()
            stats['total_run'] += 1

            print(f"\n[{stats['total_run']}/{len(tests_to_run)}] {phase_name.upper()} PHASE TEST:")
            print(f"   - Notebook:  {combination.notebook}")
            print(f"   - Framework: {combination.framework or 'default'}")
            print(f"   - Option:    {combination.option}")
            print(f"   - Iteration: {combination.iteration}")

            # Execute test with full error isolation: one broken test must
            # not abort the remaining ones.
            try:
                result = run_single_test_with_resilience(combination, dry_run, silent)

                # Update statistics
                # NOTE(review): the session is marked in dry-run mode as well;
                # if mark_test_* persists the checkpoint, a dry run may affect
                # later resumes -- confirm against TestingSession.
                if result.status == TestStatus.COMPLETE:
                    stats['completed'] += 1
                    session.mark_test_completed(test_key)
                elif result.status == TestStatus.COMPLETE_WITH_WARNINGS:
                    stats['completed_with_warnings'] += 1
                    session.mark_test_completed(test_key)
                elif result.status == TestStatus.TIMEOUT:
                    stats['timeouts'] += 1
                    session.mark_test_failed(test_key)
                elif result.status in [TestStatus.ERROR, TestStatus.CRITICAL_ERROR]:
                    stats['errors'] += 1
                    session.mark_test_failed(test_key)
                else:
                    stats['failed'] += 1
                    session.mark_test_failed(test_key)

                # Write result to CSV
                if not dry_run:
                    write_test_result_to_csv(result, OUTPUT_CSV_PATH)

                # Progress update (total_run is at least 1 here)
                success_rate = ((stats['completed'] + stats['completed_with_warnings']) / stats['total_run'] * 100)
                print(f"   - [PROGRESS] {stats['total_run']}/{len(tests_to_run)} | Success Rate: {success_rate:.1f}%")

            except Exception as e:
                print(f"   - [CRITICAL ERROR] Test runner exception: {e}")
                stats['errors'] += 1
                session.mark_test_failed(test_key)

    # Execute phases: regular notebooks first, advanced last
    run_phase(regular_tests, "regular")
    run_phase(advanced_tests, "advanced")

    # Final summary
    end_time = datetime.datetime.now()
    total_duration = end_time - start_time

    print(f"\n{'='*80}")
    print("COMPREHENSIVE TESTING COMPLETE")
    print(f"{'='*80}")
    print("Execution Summary:")
    print(f"  - Total tests run: {stats['total_run']}")
    print(f"  - Successful: {stats['completed']}")
    print(f"  - Successful with warnings: {stats['completed_with_warnings']}")
    print(f"  - Failed: {stats['failed']}")
    print(f"  - Errors: {stats['errors']}")
    print(f"  - Timeouts: {stats['timeouts']}")

    total_success = stats['completed'] + stats['completed_with_warnings']
    if stats['total_run'] > 0:
        success_rate = (total_success / stats['total_run'] * 100)
        print(f"  - Success Rate: {success_rate:.1f}%")

    # str(timedelta) carries microseconds after a '.'; drop them for display
    print(f"  - Total Duration: {str(total_duration).split('.')[0]}")
    print(f"  - Results CSV: {OUTPUT_CSV_PATH}")
    print(f"  - Organized Results: {ORGANIZED_DIR}")
    print(f"  - Checkpoint: {CHECKPOINT_FILE}")
    print("="*80)

def validate_environment() -> bool:
    """Check that the working directory is usable for a test run.

    Verifies that ``start.py`` and the ``notebooks`` directory exist, counts
    how many of the notebooks listed in ``NOTEBOOK_BATCHES`` are present, and
    warns when ``RUNS_DIR`` does not exist yet.

    Returns:
        True when the required items exist and at least one configured
        notebook is available; False otherwise. On failure an ``[ERROR]``
        line is printed and no ``[SUCCESS]`` line is emitted.
    """
    print("[INFO] Validating testing environment...")

    for item in ("start.py", "notebooks"):
        if not os.path.exists(item):
            print(f"[ERROR] Required item not found: {item}")
            return False

    # Check notebook availability
    all_notebooks = [nb for batch in NOTEBOOK_BATCHES for nb in batch["notebooks"]]
    existing_notebooks = [nb for nb in all_notebooks if os.path.exists(nb)]
    missing_count = len(all_notebooks) - len(existing_notebooks)

    print(f"[INFO] Notebooks: {len(existing_notebooks)}/{len(all_notebooks)} available")
    if missing_count > 0:
        print(f"[WARN] {missing_count} configured notebooks not found")

    # Without a single runnable notebook the validation has failed; say so
    # explicitly rather than reporting success and returning False.
    if not existing_notebooks:
        print("[ERROR] None of the configured notebooks were found; nothing to test")
        return False

    # Check runs directory (missing is only a warning)
    if not os.path.exists(RUNS_DIR):
        print(f"[WARN] Runs directory not found: {RUNS_DIR}")
        print("[INFO] Will be created automatically")

    print("[SUCCESS] Environment validation complete")
    return True

def main():
    """Parse command-line arguments and launch the testing suite.

    Exit status:
        1 if environment validation fails or the suite raises an exception,
        0 if the run is interrupted with Ctrl-C; otherwise the function
        returns normally.
    """
    parser = argparse.ArgumentParser(
        description="Enhanced Comprehensive Notebook Testing Suite",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Resume Options:
  --resume-mode fresh       Start fresh (ignore previous progress)  
  --resume-mode resume      Resume from where left off
  --resume-mode retry       Retry failed tests + continue
  --resume-mode skip        Skip failed tests and continue

Examples:
  python enhanced_test_runner.py                    # Interactive mode
  python enhanced_test_runner.py --dry-run          # Test configuration
  python enhanced_test_runner.py -y --silent        # Automated silent mode
  python enhanced_test_runner.py --resume-mode retry # Retry failed tests
        """
    )

    parser.add_argument("--dry-run", action="store_true",
                       help="Print test plan without executing")
    parser.add_argument("-y", "--yes", action="store_true",
                       help="Auto-answer yes to all prompts")
    parser.add_argument("--silent", action="store_true",
                       help="Suppress real-time test output")
    parser.add_argument("--resume-mode", choices=["fresh", "resume", "retry", "skip"],
                       help="Specify resume behavior (skips interactive prompt)")

    args = parser.parse_args()

    # The CLI uses the short spellings "retry"/"skip", while the execution
    # logic matches on "retry_failed"/"skip_failed". Translate here; passing
    # the short names through unchanged would select no tests at all.
    # "fresh", "resume" and None (option not given) pass through as-is.
    cli_to_internal_mode = {"retry": "retry_failed", "skip": "skip_failed"}
    resume_mode = cli_to_internal_mode.get(args.resume_mode, args.resume_mode)

    if not validate_environment():
        print("[ERROR] Environment validation failed. Exiting.")
        sys.exit(1)

    try:
        execute_comprehensive_testing(
            dry_run=args.dry_run,
            auto_yes=args.yes,
            silent=args.silent,
            resume_mode=resume_mode
        )
    except KeyboardInterrupt:
        print("\n[INFO] Testing interrupted by user")
        print("[INFO] Progress saved to checkpoint file")
        sys.exit(0)
    except Exception as e:
        # Top-level boundary: report any failure and signal it via exit status.
        print(f"\n[CRITICAL ERROR] Testing suite failure: {e}")
        sys.exit(1)

# Script entry point: run the suite only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()