import argparse
import ollama
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
from scripts.utils import (
    apply_diff,
    parse
)

def main(input_path, model):
    """Ask an Ollama model to tune the file's hyperparameters, then apply its diff.

    The file's text is sent to ``model`` with an instruction to decrease the
    relative error estimation and answer with a unified diff. The raw reply
    is printed, parsed with ``parse`` and handed to ``apply_diff`` together
    with ``input_path``.

    Args:
        input_path: Path to the input file (a notebook, per the CLI help).
            It is read here and then passed to ``apply_diff``, which
            presumably rewrites it in place — NOTE(review): confirm against
            scripts.utils.
        model: Name of the Ollama model passed to ``ollama.chat``.
    """
    print(f"Running {model}...")

    # Notebook JSON and Python source are UTF-8 by spec; don't rely on the
    # platform's locale encoding (e.g. cp1252 on Windows) to decode them.
    with open(input_path, "r", encoding="utf-8") as f:
        notebook = f.read()

    response = ollama.chat(
        model=model,
        messages=[
            {
                'role': 'user',
                'content': (
                    "The following code performs offline policy evaluation and calculates the "
                    "relative error estimation with various estimators at the end. "
                    "Modify the hyperparameters of the file to decrease the relative error "
                    "estimation. Generate a unified diff of the changes you suggest: "
                    + notebook
                ),
            },
        ],
    )
    # Extract the model's reply once rather than re-indexing the response.
    reply = response['message']['content']
    print(reply)
    parsed_diff = parse(reply)
    print("parsed diff", parsed_diff)
    apply_diff(input_path, parsed_diff)
    
if __name__ == "__main__":
    # Command-line entry point: two required positional string arguments,
    # forwarded unchanged to main().
    cli = argparse.ArgumentParser(
        description="Modify hyperparameters of a notebook to decrease relative error estimation."
    )
    for arg_name, arg_help in (
        ("input_path", "Path to the input notebook file"),
        ("model", "Specify the model to use"),
    ):
        cli.add_argument(arg_name, type=str, help=arg_help)
    cli_args = cli.parse_args()

    main(input_path=cli_args.input_path, model=cli_args.model)