import os
import shutil
import yaml
import argparse
import subprocess
from datetime import datetime
import pandas as pd
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
import scripts.utils.notebook_converter as nc
from scripts.utils.visualize import plot_results
from scripts.utils.diff import write_code_diff
from scripts.utils.save_run_results import save_results
from scripts.utils.parse_agent_output import parse
from scripts.utils.config_loader import load_config
from scripts.utils.get_api_key import get_api_key
from scripts.utils.artifact_utils import determine_execution_mode, ExecutionMode

import google.generativeai as genai
from openai import OpenAI

def detect_notebook_framework(nb_path):
    """
    Guess which OPE framework a notebook uses from the text of its code cells.

    The source of every code cell is concatenated and searched,
    case-insensitively, for known substrings. SCOPE-RL signatures are tried
    before OBP signatures, so a notebook matching both is reported as
    'scope_rl'.

    Args:
        nb_path: Path to the .ipynb file to inspect.

    Returns:
        'scope_rl', 'obp', or 'unknown'. Any failure (nbformat unavailable,
        unreadable file, malformed notebook) prints a warning and also
        yields 'unknown'.
    """
    # Ordered (framework, signatures) pairs; the first match wins.
    framework_signatures = (
        ('scope_rl', ['import scope_rl', 'from scope_rl', 'scope_rl.', 'SyntheticDataset']),
        ('obp', ['import obp', 'from obp', 'obp.', 'OpenBanditDataset']),
    )

    try:
        # Imported lazily so a missing nbformat degrades to 'unknown'
        # instead of breaking module import.
        import nbformat
        with open(nb_path, 'r', encoding='utf-8') as handle:
            notebook = nbformat.read(handle, as_version=4)

        # Lowercase once; every signature is compared in lowercase too.
        searchable = "".join(
            cell.source + "\n"
            for cell in notebook.cells
            if cell.cell_type == 'code'
        ).lower()

        for framework, signatures in framework_signatures:
            for signature in signatures:
                if signature.lower() in searchable:
                    return framework

        return 'unknown'
    except Exception as e:
        print(f"Warning: Failed to detect framework for {nb_path}: {e}")
        return 'unknown'

def pick_interpreter(nb_path, cfg):
    """
    Choose the interpreter command used to run a notebook.

    The notebook's framework is detected and reported on stdout. A
    framework-specific entry in cfg['interpreter_map'] is used when one
    exists for 'scope_rl' or 'obp'; in every other case the value of
    cfg['notebook_interpreter'] is returned.

    Args:
        nb_path: Path to the notebook being inspected.
        cfg: Settings mapping; read only through .get().

    Returns:
        The configured interpreter value, or None when neither key is set.
    """
    detected = detect_notebook_framework(nb_path)
    notebook_name = os.path.basename(nb_path)
    print(f"Framework detection for {notebook_name}: {detected}")

    fallback = cfg.get('notebook_interpreter')
    overrides = cfg.get('interpreter_map', {})

    # Only these frameworks may carry a dedicated interpreter; anything else
    # (including 'unknown') resolves to the fallback.
    per_framework = {
        name: overrides.get(name, fallback)
        for name in ('scope_rl', 'obp')
    }
    per_framework['unknown'] = fallback

    return per_framework.get(detected, fallback)

def _collect_notebooks(input_path):
    """
    Return the notebook paths selected by `input_path`.

    A single file must be a .ipynb notebook. A directory contributes every
    .ipynb file directly inside it (no recursion), in sorted order.

    Raises:
        ValueError: `input_path` is an existing file that is not a notebook.
    """
    if os.path.isfile(input_path):
        if not input_path.endswith('.ipynb'):
            raise ValueError(f"Input file {input_path} is not a .ipynb notebook")
        return [input_path]

    # os.listdir raises (e.g. FileNotFoundError) if input_path does not exist.
    candidates = (os.path.join(input_path, name) for name in sorted(os.listdir(input_path)))
    return [path for path in candidates if path.endswith('.ipynb') and os.path.isfile(path)]


def _create_model_client(model):
    """
    Build the API client for `model`, selected by the model name.

    Returns:
        A (gemini_model, chat_client) pair in which exactly one entry is not
        None: a genai.GenerativeModel for "gemini*" names, otherwise an
        OpenAI-style client exposing chat.completions.create.

    Raises:
        ValueError: the model name matches no supported provider.
    """
    if model.startswith("gemini"):
        genai.configure(api_key=get_api_key(model))
        return genai.GenerativeModel(model_name=model), None
    if model.startswith("gpt"):
        return None, OpenAI(api_key=get_api_key(model))
    if model.endswith("latest"):  # mistral-large-latest or codestral-latest
        # No Mistral SDK is imported by this file, and the request code below
        # uses the OpenAI-style chat.completions.create call, so the OpenAI
        # client is pointed at Mistral's endpoint instead.
        # NOTE(review): relies on Mistral's chat-completions endpoint being
        # OpenAI-compatible - confirm against the Mistral API reference.
        return None, OpenAI(api_key=get_api_key(model), base_url="https://api.mistral.ai/v1")
    raise ValueError(f"Model {model} not supported")


def _build_run_command(interpreter, script_path):
    """
    Return the shell command line that runs `script_path` with `interpreter`.

    The interpreter string is passed through untouched (it may itself be a
    multi-word command); only the script path is quoted so that notebook
    names containing spaces or shell metacharacters still work.
    """
    import shlex

    if os.name == 'nt':
        # cmd.exe does not understand POSIX single-quote quoting.
        quoted_path = f'"{script_path}"'
    else:
        quoted_path = shlex.quote(script_path)
    return f"{interpreter} {quoted_path}"


def main(input_path, model, iterations=1):
    """
    Iteratively ask an LLM to tune the hyperparameters of one or more notebooks.

    For every selected notebook the function converts it to a script inside
    runs/<timestamp>/, runs it once for baseline results, and then, for
    `iterations` rounds, sends the current script plus the recorded results
    to the model, writes the returned code as the next numbered script,
    diffs it against the previous one, runs it and records the new results.
    The recorded results are plotted once all rounds for a notebook finish.

    Args:
        input_path: A .ipynb file, or a directory whose .ipynb files are all
            processed.
        model: Model name; "gemini*", "gpt*" and "*latest" names are accepted.
        iterations: Number of improvement rounds per notebook.

    Raises:
        ValueError: `input_path` is a non-notebook file, `model` is not
            supported, or no interpreter is configured for a notebook.

    Note:
        If a generated script fails to run, an error is printed and the whole
        function returns: remaining notebooks are skipped and no plot is
        produced for the current one.
    """
    # Validate the cheap inputs first so a bad path or model name fails
    # before any client, config or run directory is created.
    notebooks = _collect_notebooks(input_path)
    if not notebooks:
        print(f"No .ipynb notebooks found in {input_path}")

    gemini_model, client = _create_model_client(model)

    config = load_config()
    config_settings = config['settings']

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_dir = os.path.join('runs', timestamp)
    os.makedirs(run_dir, exist_ok=True)

    for file in notebooks:
        # Select appropriate interpreter for this notebook
        notebook_interpreter = pick_interpreter(file, config_settings)
        if not notebook_interpreter:
            raise ValueError(f"No interpreter configured for notebook {file}")

        # Check for SCOPE-RL artifacts and determine execution mode
        if detect_notebook_framework(file) == 'scope_rl':
            execution_mode = determine_execution_mode(file)

            if execution_mode == ExecutionMode.TRAINING_BYPASS:
                print("-" * 60)
                print(f"TRAINING BYPASS MODE: Detected datasets and models for {os.path.basename(file)}")
                print("Will skip expensive training and load pre-computed artifacts")
                print("-" * 60)

        # Strip only the trailing extension; scripts are named "<round>-<stem>.py".
        stem = os.path.splitext(os.path.basename(file))[0]
        filename = f'{stem}.py'
        file_to_run = os.path.join(run_dir, f'0-{filename}')
        csv_file_path = os.path.join(run_dir, f'{stem}_results.csv')

        # Convert our Jupyter notebook to .py file so it can be run externally
        nc.convert_notebook_to_script(file, file_to_run)

        # Run the code once for the initial set of results.
        # NOTE(review): the exit status of this baseline run is not checked.
        print("Running notebook for initial results...")
        subprocess.run(_build_run_command(notebook_interpreter, file_to_run), shell=True)
        # Record the results the script wrote to out.csv (save_results is
        # expected to delete out.csv afterwards - confirm).
        print("Saving initial results...")
        save_results("out.csv", csv_file_path)
        initial_df = pd.read_csv(csv_file_path)
        print("Initial CSV File Contents:\n", initial_df)

        # Conversation state is per notebook: a fresh message history for
        # OpenAI-style clients and a fresh chat session for Gemini, so one
        # notebook's exchange never leaks into the next.
        message_history = []
        chat = gemini_model.start_chat() if gemini_model is not None else None

        # Initial prompt message
        initial_prompt = """
            The following code performs offline policy evaluation and
            calculates the relative error estimation with various estimators at
            the end. The results that it generates are in the csv file and 
            are also provided. Please modify the hyperparameters of the file to 
            decrease the relative error estimation.
            Please output the entire code I give you and do it in one codeblock. 
            It is absolutely imperative that you output the entire code with the
            changes you've suggested applied to it even if the rest of it remains
            unchanged. Output the entire code even if it is unchanged instead of 
            putting a comment that says 'no changes' or something similar. Do NOT
            say that the rest of the code remains unchanged. Only output the code
            and nothing else. No other information.
        """

        # Prompt for all further iterations
        improvement_prompt = """
            Implementing the changes you suggested yields the following results. Keeping in mind 
            what changes you suggested earlier and the results that were generated, make some
            further changes to the code that could help decrease the relative error estimation
            even further. Please output the entire code I give you and do it in one codeblock.
        """

        for i in range(iterations):
            print("///////////////////////////////////////////////////////////")
            print("Run no: ", i+1)
            print("///////////////////////////////////////////////////////////")
            prev_file = file_to_run
            # Build the next path from its parts rather than by string
            # substitution on the previous path.
            file_to_run = os.path.join(run_dir, f'{i+1}-{filename}')
            shutil.copy(prev_file, file_to_run)
            with open(file_to_run, "r", encoding="utf-8") as f:
                notebook_content = f.read()

            message = initial_prompt if i == 0 else improvement_prompt
            if chat is not None:
                results_csv = genai.upload_file(csv_file_path, mime_type="text/csv")
                response = chat.send_message([message, notebook_content, results_csv])
                code = parse(response.candidates[0].content.parts[0].text)
            else:
                # Only the text-based clients need the results inlined.
                csv_string = pd.read_csv(csv_file_path).to_csv(index=False)
                message_history.append({"role": "user", "content": 
                    f"""
                    {message}
                    {notebook_content}
                    Run results:
                    {csv_string}
                    """})
                chat_response = client.chat.completions.create(
                    model=model,
                    messages=message_history
                )
                agent_message = chat_response.choices[0].message.content
                print(agent_message)
                message_history.append({"role": "assistant", "content": agent_message})

                code = parse(agent_message)

            # Model output may contain non-ASCII text; write it as UTF-8.
            with open(file_to_run, "w", encoding="utf-8") as f:
                f.write(code)
            print("File generated with agent's changes")
            # Compare the previous iteration with the current iteration
            write_code_diff(prev_file, file_to_run)
            print("Running notebook")
            try:
                subprocess.run(
                    _build_run_command(notebook_interpreter, file_to_run),
                    shell=True,
                    check=True,
                )
            except (subprocess.CalledProcessError, OSError) as e:
                # A broken generated script ends the whole run (see docstring).
                print("Error: Failed to run notebook", e)
                return
            save_results("out.csv", csv_file_path)

        # Visualize the recorded results
        plot_results(csv_file_path)
    
if __name__ == "__main__":
    # Command-line entry point: two positional arguments (input path and
    # model name) plus an optional iteration count, forwarded to main().
    cli = argparse.ArgumentParser(
        description="Use a model to optimize hyperparameters in a notebook or py file."
    )
    cli.add_argument(
        "input_path",
        type=str,
        help="Path to the input notebook or .py file",
    )
    cli.add_argument(
        "model",
        type=str,
        help="Name of the model to run",
    )
    cli.add_argument(
        "-n",
        "--iterations",
        type=int,
        default=1,
        help="Number of iterations to run the agent (default: 1)",
    )

    options = cli.parse_args()

    main(options.input_path, options.model, iterations=options.iterations)