import argparse
import os
import re
import shlex
import shutil
import subprocess
from datetime import datetime

import nbformat

from scripts.utils import (
    convert_notebook_to_script,
    plot_results,
    write_code_diff,
    save_results,
    validate_agent,
    validate_framework,
    combine_results,
    run_and_log,
    load_config
)
from scripts.utils.artifact_utils import determine_execution_mode, ExecutionMode

def detect_notebook_framework(nb_path):
    """
    Detect which framework a notebook uses by scanning its code cells.

    All code cells are joined and searched, case-insensitively, for plain
    substring signatures (import statements, module prefixes and
    characteristic class names). Each framework scores one point per
    signature found. The highest-scoring framework that meets its
    confidence threshold wins; ties go to the framework declared first
    ('scope_rl' before 'obp').

    Args:
        nb_path: Path to notebook file.

    Returns:
        Framework identifier ('scope_rl', 'obp', 'unknown'). 'unknown' is
        returned when no signature matches, or when the notebook cannot be
        read/parsed (a warning is printed in that case).
    """
    # Signatures are plain substrings (not regexes), compared case-insensitively.
    framework_signatures = {
        'scope_rl': {
            'imports': [
                'import scope_rl',
                'from scope_rl',
                'scope_rl.',
                'from basicgym',
                'import basicgym',
                'SyntheticDataset',
                'ContinuousEvalHead',
                'TruncatedGaussianHead'
            ],
            'confidence_threshold': 1  # Need at least 1 match for SCOPE-RL
        },
        'obp': {
            'imports': [
                'import obp',
                'from obp',
                'obp.',
                'OpenBanditDataset',
                'BernoulliTS',
                'LinearUCB',
                'ReplayMethod'
            ],
            'confidence_threshold': 1  # Need at least 1 match for OBP
        }
    }

    try:
        with open(nb_path, 'r', encoding='utf-8') as f:
            nb = nbformat.read(f, as_version=4)

        # Join once and lowercase once, instead of quadratic `+=` concatenation
        # and re-lowercasing the whole notebook for every signature.
        code_content = "\n".join(
            cell.source for cell in nb.cells if cell.cell_type == 'code'
        ).lower()

        # One point per signature present anywhere in the code.
        framework_scores = {
            framework: sum(
                pattern.lower() in code_content for pattern in config['imports']
            )
            for framework, config in framework_signatures.items()
        }

        max_score = max(framework_scores.values(), default=0)
        if max_score == 0:
            return 'unknown'

        # First framework (in declaration order) with the top score that also
        # meets its threshold.
        for framework, score in framework_scores.items():
            if score == max_score and score >= framework_signatures[framework]['confidence_threshold']:
                return framework

        return 'unknown'

    except Exception as e:
        # Best-effort detection: any read/parse failure degrades to 'unknown'.
        print(f"Warning: Failed to detect framework for {nb_path}: {e}")
        return 'unknown'

def pick_interpreter(nb_path, cfg):
    """
    Select the Python interpreter used to execute a notebook's script.

    The framework is first detected from the notebook's content. For a
    detected framework, ``cfg['interpreter_map'][framework]`` is used, falling
    back to ``cfg['notebook_interpreter']``. If no framework was detected, or
    no interpreter resolves for it, each ``interpreter_map`` key is tried as a
    regex against ``nb_path``; the first match wins. If none matches,
    ``cfg['notebook_interpreter']`` is returned.

    Args:
        nb_path: Path to notebook file.
        cfg: Settings dictionary; reads 'interpreter_map' and 'notebook_interpreter'.

    Returns:
        str: Interpreter command, or None if the config provides none.
    """
    framework = detect_notebook_framework(nb_path)

    print(f"Framework detection for {os.path.basename(nb_path)}: {framework}")

    default_interpreter = cfg.get('notebook_interpreter')
    # `or {}` also covers a key that is present but empty/null (e.g. an empty
    # YAML mapping section), which would otherwise break `.get`/`.items()`.
    interpreter_map = cfg.get('interpreter_map') or {}

    if framework != 'unknown':
        interpreter = interpreter_map.get(framework, default_interpreter)
        if interpreter:
            return interpreter

    print(f"Warning: Using fallback regex detection for {os.path.basename(nb_path)}")
    # Fallback: interpreter_map keys double as regex patterns over the path.
    for pattern, exe in interpreter_map.items():
        if re.search(pattern, nb_path):
            return exe
    return default_interpreter

def main(input_path, model, option, framework=None, iterations=1, all=False, time=None):
    """
    Run the agent optimisation loop over one notebook or a folder of notebooks.

    For each notebook:
    - convert it to ``0-<name>.py`` inside the run directory;
    - run it once to record baseline results;
    - for each iteration, copy the previous script to ``<i+1>-<name>.py``, run
      the agent (or agent framework) on the copy, write a diff against the
      previous script, run the new script and save its results.

    Finally the collected results are plotted.

    Args:
        input_path: A ``.ipynb`` file, or a directory containing ``.ipynb`` files.
        model: Name of the LLM passed to the agent.
        option: How the agent applies its changes (forwarded as ``-opt``).
        framework: Optional agent framework; when set, validate_framework(framework)
            supplies the agent script instead of validate_agent(model).
        iterations: Number of agent iterations per notebook.
        all: When True, output goes under ``runs/<time>-all/<model>_<framework>``;
            otherwise under ``runs/<time>``.
        time: Timestamp used for the run directory; generated when None.
    """
    config = load_config()
    config_settings = config['settings']  # Settings consumed by pick_interpreter

    # Create a directory to store the results of the run(s)
    if time is None:
        time = datetime.now().strftime("%Y%m%d_%H%M%S")

    if all:
        run_dir = os.path.join('runs', f'{time}-all', f'{model}_{framework}')
    else:
        run_dir = os.path.join('runs', time)
    os.makedirs(run_dir, exist_ok=True)

    if os.path.isfile(input_path):
        if not input_path.endswith('.ipynb'):
            print(f"Error: {input_path} is not a Jupyter notebook (.ipynb)")
            return
        files = [input_path]
    else:
        # Sorted so notebooks are processed in a deterministic order.
        files = sorted(
            os.path.join(input_path, name)
            for name in os.listdir(input_path)
            if name.endswith('.ipynb') and os.path.isfile(os.path.join(input_path, name))
        )

    # The agent/framework script does not depend on the notebook, so resolve
    # it once; errors from either validator are reported the same way.
    try:
        exec_path = validate_framework(framework) if framework else validate_agent(model)
    except ValueError as e:
        print(f"Error: {e}")
        return

    log_file_path = os.path.join(run_dir, "agent_output.log")

    for file in files:
        # Select appropriate interpreter for this notebook
        notebook_interpreter = pick_interpreter(file, config_settings)

        # Check for SCOPE-RL artifacts and determine execution mode
        if detect_notebook_framework(file) == 'scope_rl':
            execution_mode = determine_execution_mode(file)

            if execution_mode == ExecutionMode.TRAINING_BYPASS:
                print("-" * 60)
                print(f"TRAINING BYPASS MODE: Detected datasets and models for {os.path.basename(file)}")
                print("Will skip expensive training and load pre-computed artifacts")
                print("-" * 60)

        # Each notebook is converted to a script and copied into the run
        # folder. The first run uses the unmodified notebook, so it is
        # prefixed 0- (initial results).
        stem = os.path.splitext(os.path.basename(file))[0]
        filename = f'{stem}.py'
        results_path = os.path.join(run_dir, f'{stem}_results.csv')
        file_to_run = os.path.join(run_dir, f'0-{filename}')

        convert_notebook_to_script(file, file_to_run)

        print("Running notebook for initial results...")
        # shell=True keeps support for multi-word interpreter commands from the
        # config; the script path is quoted so spaces/shell metacharacters are
        # safe. Exit status is not checked here, unlike the per-iteration runs.
        subprocess.run(f"{notebook_interpreter} {shlex.quote(file_to_run)}", shell=True)
        # Collect the results the script wrote to out.csv
        print("Saving initial results...")
        save_results("out.csv", results_path)

        for i in range(iterations):
            print("///////////////////////////////////////////////////////////")
            print("Run no: ", i+1)
            print("///////////////////////////////////////////////////////////")
            # Copy the previous script; the agent applies its changes to the copy.
            # The new name is built from run_dir rather than string-replacing
            # inside the full path, which could also hit a matching substring.
            prev_file = file_to_run
            file_to_run = os.path.join(run_dir, f'{i + 1}-{filename}')
            shutil.copy(prev_file, file_to_run)
            print("Running agent...")
            command = f'python {exec_path} {file_to_run} {model} -opt {option}'
            return_code = run_and_log(command, log_file_path)
            if return_code != 0:
                print(f"Error: Command failed with return code {return_code}")
                return

            # Compare the previous iteration with the current iteration
            write_code_diff(prev_file, file_to_run)
            print("Running notebook")
            try:
                subprocess.run(
                    f"{notebook_interpreter} {shlex.quote(file_to_run)}",
                    shell=True,
                    check=True,
                )
            except (subprocess.CalledProcessError, OSError) as e:
                print("Error: Failed to run notebook", e)
                return
            save_results("out.csv", results_path)

        # Visualize the recorded results
        plot_results(results_path)
    
def run_all_combinations(input_path, option, iterations=1):
    """
    Run main() for every (model, agent framework) pair, then combine results.

    All runs share one timestamp, so their outputs land under
    ``runs/<time>-all/<model>_<framework>`` and are merged by combine_results(time).

    Args:
        input_path: Notebook file or directory of notebooks (forwarded to main()).
        option: How the agent applies its changes (forwarded to main()).
        iterations: Agent iterations per notebook (forwarded to main(); default 1).
    """
    MODELS = ['gemini-1.5-pro', 'gemini-1.5-flash',
              'mistral-large-latest', 'codestral-latest',
              'gpt-4o', 'gpt-4o-mini']
    # None selects the plain agent (validate_agent) rather than a framework.
    FRAMEWORKS = [None, 'autogen', 'crewai']

    time = datetime.now().strftime("%Y%m%d_%H%M%S")

    for model in MODELS:
        for framework in FRAMEWORKS:
            print(f"Running with model: {model}, framework: {framework}")
            # main's third positional parameter is `option`; framework must be
            # passed by keyword, or option would receive two values (TypeError).
            main(input_path, model, option, framework=framework,
                 iterations=iterations, all=True, time=time)

    combine_results(time)
        
if __name__ == "__main__":
    # Command-line entry point: run one model/framework, or every combination with --all.
    parser = argparse.ArgumentParser(description="Use an agent to optimize the hyperparameters of a notebook.")
    parser.add_argument("input_path", type=str, help="Path to the input notebook or a directory of notebooks")
    parser.add_argument("model", type=str, nargs='?', help="Name of the model to run (required unless --all is given)")
    parser.add_argument("-opt", "--option", type=str, default="whole_code", choices=['manual_patch', 'whole_code', 'agent_applies'], help="Choose how to apply the agent's changes")
    parser.add_argument("-fw", "--framework", type=str, help="Name of the agent framework (crewai, autogen)")
    parser.add_argument("-n", "--iterations", type=int, default=1, help="Number of iterations to run the agent (default: 1)")
    parser.add_argument("-a", "--all", action="store_true", help="Run all combinations of models and frameworks")

    args = parser.parse_args()

    if args.all:
        run_all_combinations(args.input_path, args.option)
    else:
        # `model` is optional only so --all can omit it; a single run needs it
        # for the agent command line and output naming.
        if args.model is None:
            parser.error("the model argument is required unless --all is given")
        main(args.input_path, args.model, args.option, args.framework, args.iterations)