import argparse
from mistralai import Mistral
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from scripts.utils import (
    load_prompts,
    apply_diff,
    get_api_key,
    parse
)

def main(input_path, model, option):
    """Ask a Mistral model to revise a notebook file and apply the result.

    The file at ``input_path`` is read, sent to ``model`` in a single user
    message together with the loaded prompts, and the parsed reply is either
    written back over the file or handed to ``apply_diff``.

    Args:
        input_path: Path of the notebook file to read and modify in place.
        model: Model name passed both to ``get_api_key`` and to the Mistral
            chat completion call.
        option: ``"whole_code"`` selects the ``whole_code`` prompt and
            overwrites the file with the reply. Any other value selects the
            ``diff`` prompt and applies the reply through ``apply_diff``.

    Raises:
        ValueError: In ``whole_code`` mode, when the parsed reply is not a
            non-empty string. The check runs before the file is opened for
            writing, so the original notebook is left untouched.
    """
    print(f"Running {model}... with option={option}")

    prompts = load_prompts()
    whole_code = option == "whole_code"
    prompt = prompts['whole_code'] if whole_code else prompts['diff']

    # Notebook files are UTF-8 JSON; do not depend on the platform default.
    with open(input_path, "r", encoding="utf-8") as f:
        notebook = f.read()

    client = Mistral(api_key=get_api_key(model))

    chat_response = client.chat.complete(
        model=model,
        messages=[
            {
                "role": "user",
                "content": f"{prompts['system_message']} {prompt}: {notebook}",
            },
        ],
    )

    # NOTE(review): `parse` is a project helper whose return type is not
    # visible here; the whole_code branch below requires it to be a string.
    parsed_response = parse(chat_response.choices[0].message.content)
    print("parsed_response:", parsed_response)

    if whole_code:
        # Opening with "w" truncates immediately, so validate first: a
        # missing or blank reply must not wipe the user's notebook.
        if not isinstance(parsed_response, str) or not parsed_response.strip():
            raise ValueError(
                f"Model reply could not be parsed into notebook content; "
                f"{input_path} was left unchanged."
            )
        with open(input_path, "w", encoding="utf-8") as f:
            f.write(parsed_response)
    else:
        apply_diff(input_path, parsed_response)
    
if __name__ == "__main__":
    # Command-line entry point: two positional arguments (notebook path and
    # model name) plus an optional strategy flag, all forwarded to main().
    cli = argparse.ArgumentParser(
        description="Modify hyperparameters of a notebook to decrease relative error estimation.",
    )
    cli.add_argument(
        "input_path",
        type=str,
        help="Path to the input notebook file",
    )
    cli.add_argument(
        "model",
        type=str,
        help="Specify the model to use",
    )
    cli.add_argument(
        "-opt",
        "--option",
        type=str,
        default="manual_patch",
        choices=['manual_patch', 'whole_code', 'agent_applies'],
        help="Choose how to apply the agent's changes",
    )
    cli_args = cli.parse_args()

    main(
        input_path=cli_args.input_path,
        model=cli_args.model,
        option=cli_args.option,
    )