from argparse import ArgumentParser
import openai
import os
import json
from dotenv import load_dotenv
import pandas as pd
import mlflow
from utils import log_arguments
# Disable pandas' chained-assignment (SettingWithCopy) warning for this script.
pd.options.mode.chained_assignment = None 

def get_embedding(text, model="text-embedding-3-large"):
    """Return the embedding vector the OpenAI API produces for ``text``.

    Newlines in ``text`` are replaced with spaces before the request is
    sent. The text is submitted as a single-element batch and the embedding
    of that one element is returned.

    Args:
        text: String to embed.
        model: Name of the OpenAI embedding model to use.
    """
    flattened = text.replace("\n", " ")
    response = openai.embeddings.create(input=[flattened], model=model)
    first_item = response.data[0]
    return first_item.embedding

def get_description(text, model="gpt-3.5-turbo-0125", num_words=300):
    """Ask an OpenAI chat model for a short biological description of ``text``.

    The query is echoed to stdout, wrapped in a fixed prompt and sent as a
    single user message with ``temperature=0``.

    Args:
        text: Query to describe.
        model: Name of the OpenAI chat model to use.
        num_words: Word limit stated in the prompt.

    Returns:
        The message content of the first choice in the API response.
    """
    print(text)
    prompt = (
        "You are an expert in biology and genetics.\n"
        "Your task is to provide biologically relevant information about the query below "
        f"in {num_words} words or less.\n\nQuery: {text}."
    )
    user_message = {"role": "user", "content": prompt}
    completion = openai.chat.completions.create(
        model=model,
        messages=[user_message],
        temperature=0,
    )
    return completion.choices[0].message.content

def main():
    """Embed each entry of a one-column text file via the OpenAI API and save the table.

    Command-line arguments:
        inputfn: Tab-separated file without a header row; must parse into
            exactly one column (one entry per line).
        outputfn: Path the resulting table is written to as CSV; the file is
            also logged as an MLflow artifact.
        --model: Embedding model name passed to ``get_embedding``.
        --description_model: Chat model name passed to ``get_description``.
        --max_examples: Number of leading rows to process; a negative value
            means all rows.
        --include_gpt_description: When set, a description is generated for
            ``"<query_type> <entry>"`` and that description is embedded
            instead of the raw entry.
        --query_type: Prefix prepended to each entry before requesting a
            description.
        --num_words: Word limit requested for the generated description.

    Raises:
        ValueError: If the input file does not parse into exactly one column.
    """
    # The help text must be passed as ``description``; the first positional
    # parameter of ArgumentParser is ``prog`` (the program name in usage).
    parser = ArgumentParser(
        description=(
            "Provide an input file with text to embed, one entry per line, "
            "optional openai model and save API responses to file"
        )
    )
    parser.add_argument("inputfn")
    parser.add_argument("outputfn")
    parser.add_argument("--model", default="text-embedding-3-large")
    parser.add_argument("--description_model", default="gpt-3.5-turbo-0125")
    parser.add_argument("--max_examples", default=-1, type=int)
    parser.add_argument("--include_gpt_description", action="store_true")
    parser.add_argument("--query_type", default="gene")
    parser.add_argument(
        "--num_words",
        default=300,
        type=int,
        help="Number of words to generate in GPT description",
    )
    args = parser.parse_args()

    # Using the run as a context manager guarantees it is closed even when
    # reading the input or an API call raises.
    with mlflow.start_run():
        log_arguments(args)

        load_dotenv()
        openai.api_key = os.getenv("OPENAI_API_KEY")
        openai.organization = os.getenv("OPENAI_ORGANIZATION")

        # Read every entry as a literal string: without this, tokens such as
        # "NA" become NaN and purely numeric entries become numbers, both of
        # which break the string handling downstream.
        input_df = pd.read_csv(
            args.inputfn,
            sep="\t",
            header=None,
            dtype=str,
            keep_default_na=False,
        )
        # Explicit raise instead of assert so the check survives ``python -O``.
        if input_df.shape[1] != 1:
            raise ValueError("Input has multiple columns")

        max_examples = args.max_examples
        if max_examples < 0:
            max_examples = len(input_df)

        # Copy the slice so the new columns are added to an independent frame
        # rather than to a view of ``input_df``.
        query_df = input_df.iloc[:max_examples].copy()
        column = query_df.columns[0]
        if args.include_gpt_description:
            query_df["gpt_description"] = query_df[column].apply(
                lambda x: get_description(
                    args.query_type + " " + x,
                    args.description_model,
                    args.num_words,
                )
            )
            # Embed the generated description rather than the raw entry.
            column = "gpt_description"
        query_df["embedding"] = query_df[column].apply(
            lambda x: get_embedding(x, model=args.model)
        )
        query_df.to_csv(args.outputfn)

        mlflow.log_artifact(args.outputfn)

# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
