import glob
import json
import os
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple

import yaml


class ExecutionMode(Enum):
    """Ways a SCOPE-RL notebook can be executed.

    Members:
        FULL_RUN: no artifacts are available; run everything, training included.
        TRAINING_BYPASS: datasets and models already exist; skip training and
            start from OPE.
    """

    # No artifacts: execute the whole notebook, including training.
    FULL_RUN = "full_run"

    # Datasets + models exist: bypass training and begin at the OPE stage.
    TRAINING_BYPASS = "training_bypass"


def _latest_dataset_file(directory: str, patterns: List[str]) -> Optional[str]:
    """
    Return the absolute path of the newest file for the first pattern that matches anything.

    Patterns are tried in order. As soon as one of them matches at least one
    file in ``directory``, the match with the latest modification time is
    returned and later patterns are not consulted. Returns ``None`` when no
    pattern matches.
    """
    for pattern in patterns:
        candidates = glob.glob(os.path.join(directory, pattern))
        if candidates:
            return os.path.abspath(max(candidates, key=os.path.getmtime))
    return None


def find_datasets(notebook_path: str) -> Dict[str, str]:
    """
    Locate the train/test dataset pickles that belong to a SCOPE-RL notebook.

    The search directory is ``artifacts/<env_type>/<notebook_name>/logs``
    (relative to the current working directory), where ``<env_type>`` is the
    name of the notebook's parent directory and ``<notebook_name>`` is the
    notebook file name with ``".ipynb"`` removed.

    Returns:
        ``{'train': <abs path>, 'test': <abs path>}`` when both datasets are
        found and non-empty; otherwise an empty dict. Any exception raised
        during the search is printed and also results in an empty dict.
    """
    # File-name patterns per split, listed in priority order.
    split_patterns = (
        ("train", [
            "train_dataset_*.pkl",
            "train_logged_dataset_*.pkl",
            "*_train_dataset.pkl",
            "train_*.pkl",
        ]),
        ("test", [
            "test_dataset_*.pkl",
            "test_logged_dataset_*.pkl",
            "*_test_dataset.pkl",
            "test_*.pkl",
        ]),
    )

    try:
        parent_dir, file_name = os.path.split(notebook_path)
        logs_dir = os.path.join(
            "artifacts",
            os.path.basename(parent_dir),
            file_name.replace(".ipynb", ""),
            "logs",
        )
        if not os.path.isdir(logs_dir):
            return {}

        found = {}
        for split, patterns in split_patterns:
            match = _latest_dataset_file(logs_dir, patterns)
            if match is not None:
                found[split] = match

        # Bypass mode needs both splits; a partial result is reported as nothing.
        if len(found) != len(split_patterns):
            return {}

        # Reject the pair if either chosen file has vanished or is empty.
        for split, _ in split_patterns:
            chosen = found[split]
            if not os.path.exists(chosen):
                print(f"Warning: {split} dataset path exists but file not found: {chosen}")
                return {}
            if os.path.getsize(chosen) == 0:
                print(f"Warning: {split} dataset file is empty: {chosen}")
                return {}

        return found
    except Exception as exc:
        print(f"Error finding datasets: {exc}")
        return {}


def find_models(notebook_path: str) -> Dict[str, str]:
    """
    Collect the non-empty ``.pt`` model files saved for a SCOPE-RL notebook.

    Looks in ``artifacts/<env_type>/<notebook_name>/d3rlpy_logs`` (relative to
    the current working directory), where ``<env_type>`` is the name of the
    notebook's parent directory and ``<notebook_name>`` is the notebook file
    name with ``".ipynb"`` removed.

    Returns:
        A mapping from model name (file name without its extension) to the
        absolute path of the file. Zero-byte files are skipped with a printed
        warning. A missing directory yields an empty dict, as does any
        exception, which is printed first.
    """
    try:
        parent_dir, file_name = os.path.split(notebook_path)
        checkpoints_dir = os.path.join(
            "artifacts",
            os.path.basename(parent_dir),
            file_name.replace(".ipynb", ""),
            "d3rlpy_logs",
        )

        found = {}
        if not os.path.isdir(checkpoints_dir):
            return found

        for checkpoint in glob.glob(os.path.join(checkpoints_dir, "*.pt")):
            if os.path.getsize(checkpoint) > 0:
                stem, _ = os.path.splitext(os.path.basename(checkpoint))
                found[stem] = os.path.abspath(checkpoint)
            else:
                # A zero-byte file cannot hold a usable model.
                print(f"Warning: Model file is empty: {checkpoint}")

        return found
    except Exception as exc:
        print(f"Error finding models: {exc}")
        return {}


def validate_artifacts(
    datasets: Dict[str, str],
    models: Dict[str, str],
    min_models: int = 2,
) -> Tuple[bool, List[str]]:
    """
    Check that discovered artifacts meet the minimum requirements for bypass mode.

    Args:
        datasets: Mapping of dataset split name to file path; must contain
            both ``'train'`` and ``'test'`` to be considered complete.
        models: Mapping of model name to file path.
        min_models: Smallest number of models accepted when at least one
            model is present. Defaults to 2.

    Returns:
        ``(is_valid, errors)`` where ``errors`` lists every problem found and
        ``is_valid`` is ``True`` only when that list is empty.
    """
    errors = []

    # Check datasets
    if not datasets:
        errors.append("No datasets found")
    else:
        # Report every missing split rather than stopping at the first one,
        # so the caller sees the complete picture in a single pass.
        for split, label in (("train", "Training"), ("test", "Test")):
            if split not in datasets:
                errors.append(f"{label} dataset not found")

    # Check models
    if not models:
        errors.append("No models found")
    elif len(models) < min_models:
        errors.append(
            f"Only {len(models)} model(s) found, expected at least {min_models}"
        )

    return not errors, errors


def determine_execution_mode(notebook_path: str, verbose: bool = True) -> ExecutionMode:
    """
    Choose between TRAINING_BYPASS and FULL_RUN from the artifacts found on disk.

    Datasets and models are looked up for ``notebook_path`` and validated.
    TRAINING_BYPASS is returned when validation passes, FULL_RUN otherwise.
    With ``verbose`` set, a summary of what was found (and, on failure, each
    validation error) is printed.
    """
    datasets = find_datasets(notebook_path)
    models = find_models(notebook_path)

    if verbose:
        summary = (
            f"Checking artifacts for: {os.path.basename(notebook_path)}",
            f"  Found {len(datasets)} dataset(s)",
            f"  Found {len(models)} model(s)",
        )
        for line in summary:
            print(line)

    is_valid, errors = validate_artifacts(datasets, models)

    # Failure path first: explain what is missing, then fall back to a full run.
    if not is_valid:
        if verbose:
            print(" Missing/invalid artifacts for bypass mode:")
            for problem in errors:
                print(f"    - {problem}")
            print("  → Using FULL_RUN mode")
        return ExecutionMode.FULL_RUN

    if verbose:
        print(" All artifacts valid for TRAINING_BYPASS mode")
    return ExecutionMode.TRAINING_BYPASS


def _artifact_file_stats(path: str) -> Optional[Dict[str, Any]]:
    """Return path, size in MiB and mtime for ``path``, or ``None`` if it cannot be stat'ed."""
    try:
        return {
            'path': path,
            'size_mb': os.path.getsize(path) / (1024 * 1024),
            'modified': os.path.getmtime(path)
        }
    except OSError:
        # Covers a file that is missing or disappears between discovery and stat.
        return None


def get_artifact_info(notebook_path: str) -> Dict[str, Any]:
    """
    Describe the artifacts available for a notebook.

    Returns:
        A dict with the keys:

        * ``'notebook'``: base name of ``notebook_path``.
        * ``'execution_mode'``: value of the mode chosen by
          ``determine_execution_mode`` (called with ``verbose=False``).
        * ``'datasets'``: per-dataset dict with ``'path'``, ``'size_mb'`` and
          ``'modified'`` (the ``os.path.getmtime`` timestamp).
        * ``'models'``: same fields per model, plus ``'metadata'`` when a
          readable ``<model>_meta.json`` file sits next to the model file.

        Files that cannot be stat'ed are left out of the result.
    """
    info: Dict[str, Any] = {
        'notebook': os.path.basename(notebook_path),
        'execution_mode': determine_execution_mode(notebook_path, verbose=False).value,
        'datasets': {},
        'models': {}
    }

    # Get dataset info
    for key, path in find_datasets(notebook_path).items():
        stats = _artifact_file_stats(path)
        if stats is not None:
            info['datasets'][key] = stats

    # Get model info with metadata
    for name, path in find_models(notebook_path).items():
        model_info = _artifact_file_stats(path)
        if model_info is None:
            continue

        # Swap only the file extension: str.replace would also rewrite any
        # ".pt" that happens to appear earlier in the absolute path.
        meta_path = os.path.splitext(path)[0] + '_meta.json'
        if os.path.exists(meta_path):
            try:
                with open(meta_path, 'r') as f:
                    model_info['metadata'] = json.load(f)
            except (OSError, ValueError) as e:
                # Metadata is optional: an unreadable or malformed file
                # (JSON decode errors are ValueError subclasses) must not
                # hide the model itself.
                print(f"Warning: Could not read model metadata {meta_path}: {e}")

        info['models'][name] = model_info

    return info


def cleanup_artifacts(notebook_path: str, dry_run: bool = True) -> List[str]:
    """
    Delete (or, in dry-run mode, just list) every artifact file of a notebook.

    The artifact tree is ``artifacts/<env_type>/<notebook_name>`` relative to
    the current working directory. With ``dry_run=True`` (the default) nothing
    is touched. Otherwise each file is removed, followed by every directory
    that ended up empty, including the tree root.

    Returns:
        The paths of all files found under the artifact tree, i.e. the files
        that were deleted or would be deleted.
    """
    parent_dir, file_name = os.path.split(notebook_path)
    target_root = os.path.join(
        "artifacts",
        os.path.basename(parent_dir),
        file_name.replace(".ipynb", ""),
    )

    collected = []
    if not os.path.exists(target_root):
        return collected

    # Pass 1 (top-down): record every file, deleting it when this is a real run.
    for current, _subdirs, names in os.walk(target_root):
        for name in names:
            victim = os.path.join(current, name)
            collected.append(victim)
            if not dry_run:
                os.remove(victim)

    if dry_run:
        return collected

    # Pass 2 (bottom-up): prune directories, children before parents.
    for current, subdirs, _names in os.walk(target_root, topdown=False):
        for subdir in subdirs:
            try:
                os.rmdir(os.path.join(current, subdir))
            except OSError:
                pass  # Directory not empty

    try:
        os.rmdir(target_root)
    except OSError:
        pass

    return collected