from argparse import ArgumentParser
import openai
import os
import json
from dotenv import load_dotenv
import pandas as pd
import mlflow
from prompt_utils import *
from utils import log_arguments

# Each prompt style name maps to a pair that main() unpacks as
# (prompt_function, prompt_template_function).
PROMPT_STYLE_DICT = {
    'basic_zero_shot': (
        get_basic_zero_shot_response,
        basic_zero_shot_prompt,
    ),
    'zero_shot_minimal': (
        get_basic_zero_shot_response,
        zero_shot_prompt_minimal,
    ),
    'basic_zero_shot_gene_only': (
        get_basic_zero_shot_response,
        basic_zero_shot_gene_only_prompt,
    ),
}

# Style names in declaration order; used as the allowed values of --prompt_style.
PROMPT_STYLE_LIST = [*PROMPT_STYLE_DICT]

def main():
    """Query a model for every (trait, gene string) row and save the parsed answers.

    Command-line arguments:
        inputfn: tab-separated file that must contain the columns
            ``description`` and ``symbol_gene_string``.
        outputfn: path the collected answers are written to as CSV; the
            file is also logged as an MLflow artifact.
        --model: model name forwarded to the prompt function.
        --max_examples: number of leading rows to query. A negative value
            (the default) means all rows; a value larger than the file is
            clamped to the number of rows available.
        --prompt_style: key into ``PROMPT_STYLE_DICT`` selecting the
            response function and the prompt template function.

    Raises:
        KeyError: if the input file lacks one of the required columns.
    """
    parser = ArgumentParser("Provide an input file with (trait,gene_string) pairs, optional openai model and save API responses to file")
    parser.add_argument("inputfn")
    parser.add_argument("outputfn")
    parser.add_argument("--model", default="gpt-3.5-turbo-0125")
    parser.add_argument("--max_examples", default=-1, type=int)
    parser.add_argument("--prompt_style", choices=PROMPT_STYLE_LIST, default="basic_zero_shot")
    args = parser.parse_args()

    # Using the run as a context manager guarantees it is closed even when
    # reading the input or one of the API calls raises, instead of leaving
    # a dangling active run behind.
    with mlflow.start_run():
        # Presumably records the parsed arguments on the active run;
        # see utils.log_arguments to confirm.
        log_arguments(args)

        # Extra options handed to the prompt function; currently none.
        kwargs = {}

        # Pull credentials from a local .env file (if any) into the environment.
        load_dotenv()
        openai.api_key = os.getenv("OPENAI_API_KEY")
        openai.organization = os.getenv("OPENAI_ORGANIZATION")

        prompt_function, prompt_template_function = PROMPT_STYLE_DICT[args.prompt_style]

        # Log the prompts rendered with placeholder values so the run
        # records which template was used.
        system_prompt_eg, user_prompt_eg = prompt_template_function('<Trait>', '<Gene String>')
        mlflow.log_param('system_prompt_example', system_prompt_eg)
        mlflow.log_param('user_prompt_example', user_prompt_eg)

        input_df = pd.read_csv(args.inputfn, sep="\t")
        query_df = input_df[['description', 'symbol_gene_string']]

        # Negative means "everything"; otherwise never ask for more rows than
        # the file holds (the original index loop raised IndexError there).
        total_rows = len(query_df)
        if args.max_examples < 0:
            num_examples = total_rows
        else:
            num_examples = min(args.max_examples, total_rows)

        v2g_result_list = []
        rows = query_df.head(num_examples).itertuples(index=False, name=None)
        for i, (trait, gene_str) in enumerate(rows):
            print(i)  # progress indicator: zero-based row position
            answer = prompt_function(trait, gene_str, args.model, prompt_template_function, kwargs)
            # Presumably converts the raw response into a JSON-like record;
            # see prompt_utils.parse_response_to_json to confirm.
            v2g_result_list.append(parse_response_to_json(answer))

        v2g_result_df = pd.DataFrame(v2g_result_list)
        v2g_result_df.to_csv(args.outputfn)

        mlflow.log_artifact(args.outputfn)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
