import subprocess
import datetime
import csv
import os
import time
import sys
import platform
import glob
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import pandas as pd
import argparse

# --- CONFIGURATION ---

# Notebook batch definitions for phased execution.
# Each entry is one testing phase: "name" is the label printed for the phase,
# and "notebooks" lists notebook paths relative to the current working
# directory. Paths that do not exist on disk are reported by
# validate_environment() and skipped by main().
NOTEBOOK_BATCHES = [
    {
        "name": "OBP Notebooks",
        "notebooks": [
            "notebooks/synthetic.ipynb",
            "notebooks/obd.ipynb",
            "notebooks/multiclass.ipynb",
        ]
    },
    {
        "name": "Basic and RTB Notebooks",
        "notebooks": [
            "notebooks/basic/basic_synthetic_continuous_basic.ipynb",
            "notebooks/basic/basic_synthetic_discrete_basic.ipynb",
            "notebooks/rec/rec_synthetic_discrete_basic.ipynb",
            "notebooks/rtb/rtb_synthetic_continuous_basic.ipynb",
            "notebooks/rtb/rtb_synthetic_discrete_basic.ipynb",
        ]
    },
    {
        "name": "Zoo Notebooks",
        "notebooks": [
            "notebooks/basic/basic_synthetic_continuous_zoo.ipynb",
            "notebooks/basic/basic_synthetic_discrete_zoo.ipynb",
            "notebooks/rec/rec_synthetic_discrete_zoo.ipynb",
            "notebooks/rtb/rtb_synthetic_continuous_zoo.ipynb",
            "notebooks/rtb/rtb_synthetic_discrete_zoo.ipynb",
        ]
    },
    {
        "name": "Advanced Notebooks",
        "notebooks": [
            "notebooks/basic/basic_synthetic_continuous_advanced.ipynb",
            "notebooks/basic/basic_synthetic_discrete_advanced.ipynb",
            "notebooks/rec/rec_synthetic_discrete_advanced.ipynb",
            "notebooks/rtb/rtb_synthetic_continuous_advanced.ipynb",
            "notebooks/rtb/rtb_synthetic_discrete_advanced.ipynb",
        ]
    }
]

# Frameworks - None means default (no -fw flag is passed to start.py)
FRAMEWORKS = [None, "autogen", "crewai"]
# Values passed to start.py via the -opt flag
OPTIONS = ["whole_code", "manual_patch", "agent_applies"]
# Values passed to start.py via the -n flag
ITERATIONS = [1, 2, 3]
# Model name passed to start.py as the positional argument after the notebook path
MODEL = "gemini-1.5-flash"
# Per-test time budget in seconds, referenced by run_test()'s timeout handling
TIMEOUT_SECONDS = 7200

# Output Configuration
# The log file name embeds a timestamp taken once, when this module is loaded,
# so every result row of a single invocation lands in the same CSV file.
OUTPUT_DIR = "tests"
CSV_FILENAME = f"comprehensive_test_log_{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}.csv"
OUTPUT_CSV_PATH = os.path.join(OUTPUT_DIR, CSV_FILENAME)


# --- ENVIRONMENT VALIDATION ---

def validate_environment() -> bool:
    """Check that the current working directory is usable for the test suite.

    The check requires ``start.py`` and a ``notebooks`` entry to exist in the
    current working directory. Configured notebooks that are missing are only
    reported as a warning (at most five are listed); they do not fail the
    validation.

    Returns:
        True when both required paths exist, False otherwise.
    """
    print("[INFO] Validating testing environment...")

    # Hard prerequisites, checked in order; the first missing one aborts.
    prerequisites = (
        ("start.py", "[ERROR] start.py not found. Please run this script from the 'version-1-stable' root directory."),
        ("notebooks", "[ERROR] 'notebooks' directory not found."),
    )
    for required_path, complaint in prerequisites:
        if not os.path.exists(required_path):
            print(complaint)
            return False

    # Flatten every batch into one list of configured notebook paths.
    configured = []
    for batch in NOTEBOOK_BATCHES:
        configured.extend(batch["notebooks"])
    absent = [path for path in configured if not os.path.exists(path)]

    configured_count = len(configured)
    print(f"[INFO] Found {configured_count - len(absent)}/{configured_count} configured notebooks.")

    if absent:
        preview_limit = 5
        print(f"[WARN] Missing {len(absent)} notebooks:")
        for path in absent[:preview_limit]:
            print(f"       - {path}")
        not_listed = len(absent) - preview_limit
        if not_listed > 0:
            print(f"       ... and {not_listed} more")

    print("[SUCCESS] Environment validation complete.")
    return True

# --- SUCCESS DETECTION FUNCTIONS ---

def verify_results_csv(run_dir: str, expected_iterations: int) -> Tuple[bool, str]:
    """
    Sanity-check the ``*_results.csv`` file written into a run directory.
    This provides additional validation beyond the plots saved marker.

    The check passes when the CSV exists, has at least one data row, has an
    ``iteration`` column and, if more than one iteration was requested,
    contains at least one row whose ``iteration`` value is greater than 0.
    It does not verify that every requested iteration is present.

    Args:
        run_dir: Directory expected to contain the results CSV. An empty
            string or a non-existent directory is reported as "not found".
        expected_iterations: Number of iterations requested for the run.

    Returns:
        A ``(ok, message)`` tuple; ``message`` explains the outcome.
    """
    # Guard against an empty run_dir: os.path.join("", pattern) would make the
    # glob search the current working directory and could pick up an
    # unrelated results file.
    if not run_dir or not os.path.isdir(run_dir):
        return False, "Results CSV not found."

    # Sort so the chosen file is deterministic if several match.
    candidates = sorted(glob.glob(os.path.join(run_dir, "*_results.csv")))
    if not candidates:
        return False, "Results CSV not found."
    results_csv_path = candidates[0]

    try:
        df = pd.read_csv(results_csv_path)
        if df.empty:
            return False, "Results CSV is empty."

        if 'iteration' not in df.columns:
            return False, "Results CSV has no 'iteration' column."

        # Rows beyond the initial (iteration 0) entry.
        has_improvement_rows = bool((df['iteration'] > 0).any())
    except Exception as e:
        # Parsing and comparison problems (malformed file, non-numeric
        # iteration values, ...) are reported to the caller, not raised.
        return False, f"Error reading results CSV: {e}"

    if expected_iterations > 1 and not has_improvement_rows:
        return False, "No improvement iterations found in results."

    return True, "Results CSV validation passed."

def detect_execution_issues(
    run_id: str, 
    run_dir: str, 
    return_code: int, 
    plots_saved_marker_found: bool,
    full_output: str,
    expected_iterations: int,
    framework: Optional[str]
) -> Tuple[str, str]:
    """
    Classify a finished run as failed, complete, or complete with warnings.

    Checks are applied in this order:
      1. A non-zero exit code is a failure.
      2. A missing "All plots saved in directory:" marker is a failure; the
         last output line that looks like an error is quoted with up to two
         lines of context on either side.
      3. A results CSV that fails verify_results_csv() downgrades the run to
         "complete_with_warnings".
      4. Warning-like text in the (lower-cased) output also yields
         "complete_with_warnings".

    Args:
        run_id: Identifier of the run (not used by the checks).
        run_dir: Directory holding the run's results CSV.
        return_code: Exit code of the test command.
        plots_saved_marker_found: Whether the success marker was seen.
        full_output: Combined captured output of the command.
        expected_iterations: Iteration count forwarded to the CSV check.
        framework: Framework name, or None for the default.

    Returns:
        A tuple of (status, error_message); error_message is "no_errors"
        for a clean run.
    """
    # Step 1: the exit code decides first.
    if return_code != 0:
        return "failed", f"Command returned non-zero exit code: {return_code}"

    # Step 2: without the success marker the run failed; try to explain why.
    if not plots_saved_marker_found:
        error_markers = (
            "Traceback (most recent call last):",
            "NameError:",
            "Error:",
            "Exception:",
            "Failed to run notebook",
        )
        output_rows = full_output.splitlines()
        failure_detail = "Success marker 'All plots saved in directory:' not found in output."

        # Walk backwards so the last error-looking line in the output wins.
        for idx in reversed(range(len(output_rows))):
            if any(marker in output_rows[idx] for marker in error_markers):
                snippet = "\n".join(output_rows[max(0, idx - 2):idx + 3])
                failure_detail = f"Found error in output:\n---\n{snippet.strip()}\n---"
                break

        return "failed", failure_detail

    # Step 3: marker present, so cross-check the results CSV.
    csv_ok, csv_message = verify_results_csv(run_dir, expected_iterations)
    if not csv_ok:
        return "complete_with_warnings", f"Plots saved successfully but CSV validation failed: {csv_message}"

    # Step 4: strip framework-specific text that must not count as a problem.
    framework_safe_patterns = {
        'autogen': ['terminate']
    }
    scrubbed = full_output.lower()
    for harmless in framework_safe_patterns.get(framework, []):
        scrubbed = scrubbed.replace(harmless, '')

    # Step 5: warning-like text downgrades the result without failing it.
    first_hit = next(
        (needle for needle in ("warning:", "deprecation") if needle in scrubbed),
        None,
    )
    if first_hit is not None:
        return "complete_with_warnings", f"Completed successfully with warnings: {first_hit} detected"

    return "complete", "no_errors"


def _detect_run_dir(dirs_before: set, run_id_prefix: str) -> Tuple[str, str]:
    """Work out which entry under 'runs' belongs to the test that just ran.

    Args:
        dirs_before: Names present in 'runs' before the test started.
        run_id_prefix: Timestamp prefix used as a fallback match.

    Returns:
        A ``(run_id, run_dir)`` tuple; ``("unknown", "")`` when nothing matches.
    """
    dirs_after = set(os.listdir('runs')) if os.path.exists('runs') else set()
    new_dirs = dirs_after - dirs_before

    if new_dirs:
        # With several new entries take the lexicographically greatest name
        # (presumably the most recent if names are timestamp-based - confirm).
        run_id = max(new_dirs)
        return run_id, os.path.join('runs', run_id)

    # Fallback: an existing entry whose name starts with our start timestamp.
    prefixed = sorted(d for d in dirs_after if d.startswith(run_id_prefix))
    if prefixed:
        return prefixed[0], os.path.join('runs', prefixed[0])

    return "unknown", ""


def run_test(notebook: str, framework: Optional[str], option: str, iteration: int, dry_run: bool = False, silent: bool = False) -> Dict[str, str]:
    """Run a single test with real-time output streaming and success detection.

    Builds the ``python start.py ...`` command, streams its combined
    stdout/stderr line by line, and classifies the outcome with
    detect_execution_issues(). A watchdog kills the child process once
    TIMEOUT_SECONDS have elapsed, in which case the status is "timeout".

    Args:
        notebook: Path of the notebook passed to start.py.
        framework: Value for the ``-fw`` flag; None omits the flag.
        option: Value for the ``-opt`` flag.
        iteration: Value for the ``-n`` flag.
        dry_run: When True, nothing is executed and a canned "complete"
            record is returned.
        silent: When True, the child's output is captured but not echoed.

    Returns:
        A dict with the keys "notebook_name" (the full command string),
        "timestamp", "time_started", "time_ended", "time_taken", "status",
        "runs" and "error_message". Unexpected exceptions are reported via
        the status "error" rather than raised.
    """
    # Function-scope import: threading is needed only for the timeout watchdog.
    import threading

    cmd_parts = ["python", "start.py", notebook, MODEL]
    if framework:
        cmd_parts.extend(["-fw", framework])
    cmd_parts.extend(["-opt", option, "-n", str(iteration)])
    cmd_str = " ".join(cmd_parts)

    time_started = datetime.datetime.now()
    time_started_str = time_started.strftime("%Y-%m-%d %H:%M:%S")

    print(f"\n{'='*80}")
    print(f"Running test: {cmd_str}")
    print(f"Started at:   {time_started_str}")

    if dry_run:
        print("--- DRY RUN MODE ---")
        print("   - Command execution skipped.")
        print("   - Simulating successful run for flow demonstration.")
        print("="*80)
        return {
            "notebook_name": cmd_str, "timestamp": time_started.strftime("%H:%M"),
            "time_started": time_started_str, "time_ended": time_started_str,
            "time_taken": "00:00:01", "status": "complete", "runs": "dry_run_id",
            "error_message": "no_errors"
        }

    start_time = time.time()
    run_id_prefix = time_started.strftime("%Y%m%d_%H%M%S")

    # Track directories before execution to detect new run directory
    dirs_before = set(os.listdir('runs')) if os.path.exists('runs') else set()

    plots_saved_marker_found = False
    output_lines = []
    timed_out = threading.Event()

    def _kill_on_timeout(proc):
        # Only flag a timeout when the child is genuinely still running, so a
        # run that finishes right at the deadline is not misreported.
        if proc.poll() is None:
            timed_out.set()
            proc.kill()

    try:
        # The context manager closes the pipe and reaps the child even when
        # reading is interrupted by an exception.
        with subprocess.Popen(
            cmd_parts,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            encoding='utf-8',
            errors='replace',
            cwd=os.getcwd()
        ) as process:
            # Reading line by line has no timeout of its own, so a timer
            # thread enforces TIMEOUT_SECONDS by killing the child.
            # NOTE(review): if the child spawns grandchildren that keep the
            # pipe open, the read loop may outlive the kill - confirm whether
            # start.py does that.
            watchdog = threading.Timer(TIMEOUT_SECONDS, _kill_on_timeout, args=(process,))
            watchdog.daemon = True
            watchdog.start()
            try:
                # Stream output in real-time and detect success marker
                for line in process.stdout:
                    # Echo to console unless the caller asked for silence
                    if not silent:
                        sys.stdout.write(line)
                        sys.stdout.flush()
                    output_lines.append(line)
                    if "All plots saved in directory:" in line:
                        plots_saved_marker_found = True
                process.wait()
            finally:
                watchdog.cancel()

        if timed_out.is_set():
            raise subprocess.TimeoutExpired(cmd_parts, TIMEOUT_SECONDS)

        return_code = process.returncode
        duration_seconds = int(time.time() - start_time)

        # Detect the run directory created by start.py
        time.sleep(1)  # Allow filesystem a moment
        run_id, run_dir = _detect_run_dir(dirs_before, run_id_prefix)

        full_output = "".join(output_lines)

        status, error_message = detect_execution_issues(
            run_id, run_dir, return_code, plots_saved_marker_found,
            full_output, iteration, framework
        )

    except subprocess.TimeoutExpired:
        duration_seconds = TIMEOUT_SECONDS
        status = "timeout"
        error_message = f"Test exceeded {TIMEOUT_SECONDS//60} minute timeout"
        run_id = run_id_prefix

    except Exception as e:
        # Best-effort boundary: one broken test must not abort the whole
        # suite, so the failure is recorded in the result row instead.
        duration_seconds = int(time.time() - start_time)
        status = "error"
        error_message = str(e)[:300]
        run_id = run_id_prefix

    time_ended_str = (time_started + datetime.timedelta(seconds=duration_seconds)).strftime("%Y-%m-%d %H:%M:%S")
    time_taken = str(datetime.timedelta(seconds=duration_seconds))

    print(f"[INFO] Test completed in {time_taken} (Final Status: {status})")
    if status != "complete":
        print(f"[INFO] Error details: {error_message}")

    return {
        "notebook_name": cmd_str, "timestamp": time_started.strftime("%H:%M"),
        "time_started": time_started_str, "time_ended": time_ended_str,
        "time_taken": time_taken, "status": status, "runs": run_id,
        "error_message": error_message
    }


def write_to_csv(results: Dict[str, str], filename: str):
    """Append one result row to a CSV log file.

    The header row is written when the file does not exist yet or exists but
    is empty, so a pre-created empty log still ends up with a header.

    Args:
        results: Mapping containing exactly the log columns "notebook_name",
            "timestamp", "time_started", "time_ended", "time_taken",
            "status", "runs" and "error_message" (missing keys are written
            as empty cells).
        filename: Path of the CSV file to append to.

    Raises:
        ValueError: If ``results`` contains a key that is not a log column
            (raised by csv.DictWriter).
    """
    fieldnames = ["notebook_name", "timestamp", "time_started", "time_ended",
                  "time_taken", "status", "runs", "error_message"]

    # Decide before opening: append mode creates the file if it is missing.
    needs_header = not os.path.exists(filename) or os.path.getsize(filename) == 0

    with open(filename, 'a', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow(results)

# --- MAIN EXECUTION ---

def main():
    """Drive the batched, interactive test suite from the command line.

    Parses ``--dry-run``, ``-y/--yes`` and ``--silent``, validates the
    environment, then walks NOTEBOOK_BATCHES phase by phase. Every existing
    notebook in a phase is run against each combination of FRAMEWORKS,
    OPTIONS and ITERATIONS via run_test(); outside dry-run mode each result
    is appended to OUTPUT_CSV_PATH. Unless ``--yes`` or ``--dry-run`` is
    given, the user is asked to confirm each phase. A per-phase summary and
    an overall summary are printed.
    """
    cli = argparse.ArgumentParser(description="Run comprehensive notebook testing suite in batches.")
    cli.add_argument("--dry-run", action="store_true", help="Print commands without executing them.")
    cli.add_argument("-y", "--yes", action="store_true", help="Automatically answer yes to all batch prompts.")
    cli.add_argument("--silent", action="store_true", help="Suppress real-time test output for cleaner logs.")
    args = cli.parse_args()

    rule = "=" * 80
    hash_rule = "#" * 80

    print("Starting Comprehensive Notebook Testing Suite")
    print("Success Detection: Tests are considered successful when 'All plots saved in directory:' appears in output")
    if args.dry_run:
        print("[WARN] DRY RUN MODE ENABLED: No tests will be executed.")
    if args.silent:
        print("[INFO] SILENT MODE: Real-time test output will be suppressed.")
    print(rule)

    if not validate_environment():
        print("[ERROR] Environment validation failed. Exiting.")
        return

    # Make sure the log directory is there before any result is written.
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)
        print(f"[INFO] Created output directory: {OUTPUT_DIR}")

    # Each existing notebook is run once per framework/option/iteration combo.
    combos_per_notebook = len(FRAMEWORKS) * len(OPTIONS) * len(ITERATIONS)
    existing_count = sum(
        1
        for batch in NOTEBOOK_BATCHES
        for nb in batch["notebooks"]
        if os.path.exists(nb)
    )
    total_tests = existing_count * combos_per_notebook

    print(f"[INFO] Total potential tests: {total_tests}")

    executed = 0
    failures = 0
    suite_started = datetime.datetime.now()
    phase_total = len(NOTEBOOK_BATCHES)
    passing_statuses = ("complete", "complete_with_warnings")

    for phase_no, batch_info in enumerate(NOTEBOOK_BATCHES, start=1):
        batch_name = batch_info["name"]
        notebooks_in_batch = [nb for nb in batch_info["notebooks"] if os.path.exists(nb)]

        if not notebooks_in_batch:
            print(f"\n[INFO] Skipping batch '{batch_name}' as no notebooks were found.")
            continue

        print(f"\n{hash_rule}")
        print(f"### PHASE {phase_no}/{phase_total}: {batch_name} ###")
        print(hash_rule)

        batch_test_count = len(notebooks_in_batch) * combos_per_notebook
        print(f"This phase involves {batch_test_count} tests.")

        # Ask for confirmation only in an interactive, real run.
        if not (args.yes or args.dry_run):
            answer = input("Continue with this testing phase? (y/n): ").strip().lower()
            if answer != 'y':
                print("[CANCEL] Testing phase cancelled by user.")
                continue

        phase_failed = 0
        phase_started = datetime.datetime.now()

        # Same ordering as four nested loops: notebook is the outermost axis.
        combinations = (
            (nb, fw, opt, count)
            for nb in notebooks_in_batch
            for fw in FRAMEWORKS
            for opt in OPTIONS
            for count in ITERATIONS
        )
        for notebook, framework, option, iteration in combinations:
            executed += 1
            print(f"\n[{executed}/{total_tests}] Testing combination:")
            print(f"   - Notebook:  {notebook}")
            print(f"   - Framework: {framework or 'default'}")
            print(f"   - Option:    {option}")
            print(f"   - Iteration: {iteration}")

            results = run_test(notebook, framework, option, iteration, args.dry_run, args.silent)

            if not args.dry_run:
                write_to_csv(results, OUTPUT_CSV_PATH)

            outcome = results["status"]
            if outcome in passing_statuses:
                print(f"   - [STATUS-OK] {outcome}")
            else:
                phase_failed += 1
                failures += 1
                print(f"   - [STATUS-FAIL] {outcome}")
                if results.get('error_message') != 'no_errors':
                    print(f"   - [ERROR-MSG] {results['error_message'][:120]}...")
            print(f"   - [TIME] {results['time_taken']}")

        print(f"\n--- PHASE {phase_no} SUMMARY: {batch_name} ---")
        if batch_test_count > 0:
            success_rate = (batch_test_count - phase_failed) / batch_test_count * 100
        else:
            success_rate = 0
        phase_elapsed = str(datetime.datetime.now() - phase_started).split('.')[0]
        print(f"   - Tests in phase: {batch_test_count}")
        print(f"   - Success Rate: {success_rate:.1f}%")
        print(f"   - Duration: {phase_elapsed}")
        print(f"   - Results saved to: {OUTPUT_CSV_PATH}")

    print(f"\n{rule}")
    print("ALL TESTING PHASES COMPLETE")
    print(rule)
    print(f"   - Total tests run: {executed}")
    print(f"   - Successful (incl. warnings): {executed - failures}")
    print(f"   - Failed/Timeout/Error: {failures}")
    if executed > 0:
        overall_success_rate = (executed - failures) / executed * 100
    else:
        overall_success_rate = 0
    print(f"   - Overall Success Rate: {overall_success_rate:.1f}%")
    suite_elapsed = str(datetime.datetime.now() - suite_started).split('.')[0]
    print(f"   - Total duration: {suite_elapsed}")
    print(f"   - Final log file: {OUTPUT_CSV_PATH}")
    print(rule)

# Run the suite only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()