import openai
from typing import Callable
import time
import json

def split_gene_string(input_str, chunk_size):
    """Split a comma-separated gene list into comma-joined chunks.

    Args:
        input_str: Genes separated by ``","`` (e.g. ``"A,B,C"``). Entries are
            kept verbatim; whitespace around them is not stripped.
        chunk_size: Maximum number of genes per chunk; must be at least 1.

    Returns:
        A list of strings, each holding up to ``chunk_size`` genes joined by
        ``","``; the last chunk may be shorter. An empty ``input_str`` yields
        ``[""]`` because ``"".split(",")`` is ``[""]``.

    Raises:
        ValueError: If ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        # range() raises an opaque error for 0 and silently yields nothing
        # for negative steps, which would drop every gene.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    genes = input_str.split(",")
    return [
        ",".join(genes[start:start + chunk_size])
        for start in range(0, len(genes), chunk_size)
    ]

def basic_zero_shot_prompt(trait: str, gene_str: str) -> tuple[str, str]:
    """Return the (system, user) messages for the full zero-shot prompt.

    The system message frames the model as a biology/genetics expert and asks
    for a JSON object with the keys ``causal_gene``, ``confidence`` and
    ``reason``. The user message names the phenotype and the candidate genes.

    Args:
        trait: GWAS phenotype, interpolated verbatim into the user message.
        gene_str: Genes in the locus, interpolated verbatim.

    Returns:
        A ``(system_message, user_message)`` tuple.
    """
    # Adjacent literals make the exact text explicit, including the leading
    # newline and the trailing space after 'reason'.
    system_message = (
        "\n"
        "You are an expert in biology and genetics.\n"
        "Your task is to identify likely causal genes within a locus for a "
        "given GWAS phenotype based on literature evidence.\n"
        "From the list, provide the likely causal gene (matching one of the "
        "given genes), confidence (0: very unsure to 1: very confident), and "
        "a brief reason (50 words or less) for your choice.\n"
        "Return your response in JSON format, excluding the GWAS phenotype "
        "name and gene list in the locus. JSON keys should be "
        "'causal_gene','confidence','reason'. \n"
        "Your response must start with '{' and end with '}'.\n"
    )
    user_message = "\n".join(
        (
            "Identify the causal gene.",
            f"GWAS phenotype: {trait}",
            f"Genes in locus: {gene_str}",
            "",  # empty final element yields the trailing newline
        )
    )
    return system_message, user_message

def zero_shot_prompt_minimal(trait: str, gene_str: str) -> tuple[str, str]:
    """Return the (system, user) messages for the minimal zero-shot prompt.

    The system message contains only the output instructions (pick a gene,
    give ``confidence`` and ``reason``, answer as JSON) and no persona or
    task-framing sentences.

    Args:
        trait: GWAS phenotype, interpolated verbatim into the user message.
        gene_str: Genes in the locus, interpolated verbatim.

    Returns:
        A ``(system_message, user_message)`` tuple.
    """
    instructions = (
        "From the list, provide the likely causal gene (matching one of the "
        "given genes), confidence (0: very unsure to 1: very confident), and "
        "a brief reason (50 words or less) for your choice.",
        "Return your response in JSON format, excluding the GWAS phenotype "
        "name and gene list in the locus. JSON keys should be "
        "'causal_gene','confidence','reason'.",
        "Your response must start with '{' and end with '}'.",
    )
    # The message starts and ends with a newline.
    system_message = "\n" + "\n".join(instructions) + "\n"
    user_message = (
        "Identify the causal gene.\n"
        f"GWAS phenotype: {trait}\n"
        f"Genes in locus: {gene_str}\n"
    )
    return system_message, user_message

def get_basic_zero_shot_response(
    trait: str,
    gene_str: str,
    MODEL: str,
    prompt_style_fn: Callable[..., tuple[str, str]],
    kwargs: "dict | None" = None,
) -> str:
    """Send one zero-shot chat request and return the model's raw reply text.

    Args:
        trait: GWAS phenotype, passed to ``prompt_style_fn``.
        gene_str: Genes in the locus, passed to ``prompt_style_fn``.
        MODEL: Model name forwarded to ``openai.chat.completions.create``.
        prompt_style_fn: Called as ``prompt_style_fn(trait, gene_str)``; must
            return ``(system_message, user_message)``, e.g.
            ``basic_zero_shot_prompt``.
        kwargs: Accepted but not read by this function. NOTE(review):
            presumably kept so response functions share one call signature;
            confirm with callers. Defaults to ``None`` instead of a shared
            mutable ``{}``.

    Returns:
        ``response.choices[0].message.content`` of the first choice, unparsed.
        For JSON-returning prompts, pass it to ``parse_response_to_json``.
        NOTE(review): the OpenAI SDK types message content as optional;
        confirm callers cope with ``None``.
    """
    system_message, user_message = prompt_style_fn(trait, gene_str)
    response = openai.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ],
        temperature=0,
    )
    return response.choices[0].message.content

def get_basic_zero_shot_gene_only_response(
    trait: str,
    gene_str: str,
    MODEL: str,
    prompt_style_fn: Callable[..., tuple[str, str]],
    kwargs: "dict | None" = None,
) -> str:
    """Query the model for a bare gene name and wrap it as a JSON string.

    The reply text is treated as the gene itself and wrapped as
    ``{"causal_gene":"<reply>"}``, presumably so it can be parsed like the
    JSON replies of the other prompts. If ``prompt_style_fn`` makes the model
    answer in JSON already, that JSON text ends up nested as the string value.

    Args:
        trait: GWAS phenotype, passed to ``prompt_style_fn``.
        gene_str: Genes in the locus, passed to ``prompt_style_fn``.
        MODEL: Model name forwarded to the chat completion call.
        prompt_style_fn: Called as ``prompt_style_fn(trait, gene_str)``; must
            return ``(system_message, user_message)``.
        kwargs: Accepted but not used. Defaults to ``None`` instead of a
            shared mutable ``{}``.

    Returns:
        A compact JSON object string with a single ``causal_gene`` key. Its
        value is the reply text with surrounding whitespace removed, or
        ``""`` if the reply content is ``None``.
    """
    answer = get_basic_zero_shot_response(
        trait, gene_str, MODEL, prompt_style_fn, kwargs
    )
    # Trailing newlines/spaces from the model would otherwise leak into the
    # gene symbol (a raw newline also made the old concatenated JSON invalid).
    gene = (answer or "").strip()
    # json.dumps escapes quotes, backslashes and control characters. The
    # compact separators and ensure_ascii=False keep the output byte-identical
    # to plain '{"causal_gene":"' + gene + '"}' for ordinary gene symbols.
    return json.dumps(
        {"causal_gene": gene}, separators=(",", ":"), ensure_ascii=False
    )

def basic_zero_shot_gene_only_prompt(trait: str, gene_str: str) -> tuple[str, str]:
    """Return the (system, user) messages for the gene-only zero-shot prompt.

    The system message frames the model as a biology/genetics expert and asks
    for a JSON object with a single ``causal_gene`` key (no confidence or
    reason). The user message names the phenotype and the candidate genes.

    Args:
        trait: GWAS phenotype, interpolated verbatim into the user message.
        gene_str: Genes in the locus, interpolated verbatim.

    Returns:
        A ``(system_message, user_message)`` tuple.
    """
    lines = [
        "",  # produces the leading newline
        "You are an expert in biology and genetics.",
        "Your task is to identify likely causal genes within a locus for a "
        "given GWAS phenotype based on literature evidence.",
        "From the list, provide only the likely causal gene (matching one of "
        "the given genes).",
        "Return your response in JSON format, excluding the GWAS phenotype "
        "name and gene list in the locus. JSON key should be 'causal_gene'.",
        "Your response must start with '{' and end with '}'.",
    ]
    system_message = "".join(f"{line}\n" for line in lines)
    user_message = "Identify the causal gene.\nGWAS phenotype: {}\nGenes in locus: {}\n".format(
        trait, gene_str
    )
    return system_message, user_message

def parse_response_to_json(response):
    """Parse a model reply as JSON, tolerating a Markdown code fence.

    Surrounding whitespace, wrapping backticks and an optional ``json``
    language tag are removed before ``json.loads`` is called. Both a bare
    ``{...}`` reply and a fenced ```` ```json ... ``` ```` reply therefore
    parse.

    Args:
        response: Raw reply text from the model.

    Returns:
        Whatever ``json.loads`` decodes from the cleaned text. For the prompts
        in this module that is expected to be a ``dict``.

    Raises:
        json.JSONDecodeError: If the cleaned text is not valid JSON. The
            cleaned text and the error are printed before re-raising.
    """
    # Strip whitespace first so a fence followed by a newline is still removed.
    text = response.strip().strip("`").strip()
    # str.strip("json") treats its argument as a character set and would eat
    # j/s/o/n from both ends (e.g. "null" -> "ull"); drop the tag as a prefix.
    text = text.removeprefix("json").strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError as e:
        print(text)
        print(e)
        # Propagate the real parse error; there is no decoded value to return.
        raise
