library(data.table)
library(ggplot2)
library(ggrepel)
library(plyr)
library(cowplot)
library(jsonlite)
library(Rtsne)
set.seed(23)

# Paths and shared inputs ----
# base_dir is the project root; the input locations below are built from it.
base_dir <- "."
# Loads the evaluation helpers used throughout this script
# (get_dataset_results, create_evaluation_dataset, get_prediction_df,
# prettify, ...) — presumably all defined in utils.R.
source(file.path(base_dir, "scripts/evaluation/utils.R"))
datasetdir <- file.path(base_dir, "data/benchmark_datasets/")
predictiondir <- file.path(base_dir, "results/predictions")

# Per-gene publication counts, passed to the evaluation helpers below.
publication_count_file <- file.path(
  base_dir, "data/helper_datasets/publication_count_by_gene.txt"
)
publication_count <- get_publication_count_by_gene(publication_count_file)

# Main benchmark ----
# Methods evaluated on each benchmark dataset. The list order fixes the row
# order of the combined table (OpenTargets, Pharmaprojects, Weeks et al.,
# GWAS catalog), and each vector's order is passed through unchanged.
main_methods_by_dataset <- list(
  opentargets = c('nearest_gene', 'L2G', 'gpt3_zero_shot',
                  'gpt4_0613_zero_shot', 'text_mining_gene'),
  pharmaprojects = c('nearest_gene', 'text_mining_gene', 'gpt3_zero_shot',
                     'gpt4_0613_zero_shot'),
  weeks_et_al = c('nearest_gene', 'pops', 'gpt3_zero_shot',
                  'gpt4_0613_zero_shot'),
  gwas_catalog = c('nearest_gene', 'gpt3_zero_shot', 'gpt4_0613_zero_shot')
)
main_results <- lapply(names(main_methods_by_dataset), function(prefix) {
  get_dataset_results(datasetdir, prefix, predictiondir,
                      main_methods_by_dataset[[prefix]], publication_count)
})
all_result_df_main <- do.call(rbind, main_results)
# Untouched copy that keeps the raw method/dataset identifiers (the original
# is relabelled for display just below).
all_result_df_main_copy <- copy(all_result_df_main)

# Display labels used in the manuscript.
main_method_labels <- c("gpt3_zero_shot" = "LLM-GPT-3.5",
                        "gpt4_0613_zero_shot" = "LLM-GPT-4",
                        "text_mining_gene" = "Text mining",
                        "pops" = "PoPS",
                        "nearest_gene" = "Nearest gene")
main_dataset_labels <- c('gwas_catalog' = "GWAS catalog",
                         'opentargets' = "OpenTargets",
                         'pharmaprojects' = "Pharmaprojects",
                         "weeks_et_al" = "Weeks et al.")
all_result_df_main$method <- revalue(all_result_df_main$method,
                                     main_method_labels)
all_result_df_main$dataset <- revalue(all_result_df_main$dataset,
                                      main_dataset_labels)

outfile <- "results/manuscript/supp_table_metrics.csv"
write.table(all_result_df_main, outfile, sep = ",", row.names = FALSE)

# Figure: F-score per dataset, one dodged bar per method, with error bars
# spanning the F1.lower..F1.upper columns of the results table.
p1 <- ggplot(all_result_df_main, aes(x = dataset, y = F1, fill = method)) +
  geom_bar(stat = 'identity', position = position_dodge(0.9)) +
  geom_errorbar(aes(x = dataset, ymin = F1.lower, ymax = F1.upper),
                position = position_dodge(0.9), width = 0.5) +
  xlab("Dataset") +
  ylab("F-score") +
  theme_bw()

outfile <- "results/manuscript/f1_score_figure.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p1, xangle = 0))
dev.off()

# Figure: precision vs recall, one point per method x dataset. A small jitter
# keeps overlapping points visible; both axes are fixed to [0, 1].
p2 <- ggplot(all_result_df_main,
             aes(x = precision, y = recall, fill = dataset, colour = dataset,
                 shape = method)) +
  geom_point(size = 4, alpha = 0.7,
             position = position_jitter(width = 0.01, height = 0.01)) +
  xlab("Precision") +
  ylab("Recall") +
  theme_bw() +
  xlim(0, 1) +
  ylim(0, 1)

outfile <- "results/manuscript/precision_recall_figure.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p2, xangle = 0))
dev.off()


# Figure: correlation with gene count per dataset and method, with error bars
# from the gene_count_cor.lower/.upper columns.
p3 <- ggplot(all_result_df_main,
             aes(x = dataset, y = gene_count_cor, fill = method)) +
  geom_bar(stat = 'identity', position = position_dodge(0.9)) +
  geom_errorbar(aes(x = dataset, ymin = gene_count_cor.lower,
                    ymax = gene_count_cor.upper),
                position = position_dodge(0.9), width = 0.5) +
  ylab("Correlation with gene count") +
  xlab("Dataset") +
  theme_bw()

outfile <- "results/manuscript/gene_count_correlation.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p3, xangle = 0))
dev.off()

# Figure: correlation with publication count per dataset and method, with
# error bars from the gene_pub_cor.lower/.upper columns.
p4 <- ggplot(all_result_df_main,
             aes(x = dataset, y = gene_pub_cor, fill = method)) +
  geom_bar(stat = 'identity', position = position_dodge(0.9)) +
  geom_errorbar(aes(x = dataset, ymin = gene_pub_cor.lower,
                    ymax = gene_pub_cor.upper),
                position = position_dodge(0.9), width = 0.5) +
  ylab("Correlation with publication count") +
  xlab("Dataset") +
  theme_bw()

outfile <- "results/manuscript/gene_publication_correlation.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p4, xangle = 0))
dev.off()

# Ablation: standard vs minimal GPT-3.5 prompt ----
ablation_methods <- c('gpt3_zero_shot', 'gpt3_zero_shot_minimal')
ablation_datasets <- c('opentargets', 'pharmaprojects',
                       'weeks_et_al', 'gwas_catalog')

ablation_results <- lapply(ablation_datasets, function(prefix) {
  get_dataset_results(datasetdir, prefix, predictiondir, ablation_methods,
                      publication_count)
})
all_result_df <- do.call(rbind, ablation_results)

all_result_df$method <- revalue(all_result_df$method,
                                c("gpt3_zero_shot" = "Standard",
                                  "gpt3_zero_shot_minimal" = "Minimal"))
all_result_df$dataset <- revalue(all_result_df$dataset,
                                 c('gwas_catalog' = "GWAS catalog",
                                   'opentargets' = "OpenTargets",
                                   'pharmaprojects' = "Pharmaprojects",
                                   "weeks_et_al" = "Weeks et al."))

# Figure: F-score per dataset for the two prompt styles.
p5 <- ggplot(all_result_df, aes(x = dataset, y = F1, fill = method)) +
  geom_bar(stat = 'identity', position = position_dodge(0.9)) +
  geom_errorbar(aes(x = dataset, ymin = F1.lower, ymax = F1.upper),
                position = position_dodge(0.9), width = 0.5) +
  xlab("Dataset") +
  ylab("F-score") +
  theme_bw() +
  scale_fill_discrete(name = "Prompt style")

outfile <- "results/manuscript/ablation_figure.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p5, xangle = 0))
dev.off()

# Embedding baseline vs GPT-3.5 ----
embedding_methods <- c('gpt3_zero_shot', 'zero_shot_embedding')
embedding_datasets <- c('opentargets', 'pharmaprojects',
                        'weeks_et_al', 'gwas_catalog')

embedding_results <- lapply(embedding_datasets, function(prefix) {
  get_dataset_results(datasetdir, prefix, predictiondir, embedding_methods,
                      publication_count)
})
all_result_df <- do.call(rbind, embedding_results)

# gpt3_zero_shot is labelled "LLM-GPT-3.5", the same label the main figures
# and supplementary tables use for this method.
all_result_df$method <- revalue(all_result_df$method,
                                c("gpt3_zero_shot" = "LLM-GPT-3.5",
                                  "zero_shot_embedding" = "Embedding"))
all_result_df$dataset <- revalue(all_result_df$dataset,
                                 c('gwas_catalog' = "GWAS catalog",
                                   'opentargets' = "OpenTargets",
                                   'pharmaprojects' = "Pharmaprojects",
                                   "weeks_et_al" = "Weeks et al."))

# Figure: F-score per dataset, embedding similarity vs GPT-3.5.
p6 <- ggplot(all_result_df, aes(x = dataset, y = F1, fill = method)) +
  geom_bar(stat = 'identity', position = position_dodge(0.9)) +
  geom_errorbar(aes(x = dataset, ymin = F1.lower, ymax = F1.upper),
                position = position_dodge(0.9), width = 0.5) +
  xlab("Dataset") +
  ylab("F-score") +
  theme_bw()

outfile <- "results/manuscript/embedding_figure.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p6, xangle = 0))
dev.off()

# GPT-4 version comparison (1106 vs 0613) ----
gpt4_version_methods <- c('gpt4_1106_zero_shot', 'gpt4_0613_zero_shot')
gpt4_version_datasets <- c('opentargets', 'pharmaprojects',
                           'weeks_et_al', 'gwas_catalog')

gpt4_version_results <- lapply(gpt4_version_datasets, function(prefix) {
  get_dataset_results(datasetdir, prefix, predictiondir, gpt4_version_methods,
                      publication_count)
})
all_result_df <- do.call(rbind, gpt4_version_results)

# Labels follow the "LLM-GPT-4" spelling used for gpt4_0613_zero_shot in the
# main figures and tables.
all_result_df$method <- revalue(all_result_df$method,
                                c("gpt4_1106_zero_shot" = "LLM-GPT-4-Turbo",
                                  "gpt4_0613_zero_shot" = "LLM-GPT-4"))
all_result_df$dataset <- revalue(all_result_df$dataset,
                                 c('gwas_catalog' = "GWAS catalog",
                                   'opentargets' = "OpenTargets",
                                   'pharmaprojects' = "Pharmaprojects",
                                   "weeks_et_al" = "Weeks et al."))

outfile <- "results/manuscript/supp_table_gpt4_version.csv"
write.table(all_result_df, outfile, sep = ",", row.names = FALSE)


# Reliability data for the LLM methods: observed precision of the `total`
# outcome within each distinct reported `confidence` value.
#
# Args:
#   publication_count: per-gene publication counts, forwarded to
#     create_evaluation_dataset().
# Returns: table with columns method, confidence, precision, count, countNA,
#   se and dataset; one row per dataset x method x confidence value.
# Reads the globals base_dir and predictiondir.
calibration_plots <- function(publication_count) {
  benchmark_dir <- file.path(base_dir, 'data/benchmark_datasets/')
  llm_methods <- c('gpt3_zero_shot', 'gpt4_0613_zero_shot')
  benchmark_names <- c('opentargets', 'pharmaprojects',
                       'weeks_et_al', 'gwas_catalog')

  per_dataset <- lapply(benchmark_names, function(prefix) {
    eval_set <- create_evaluation_dataset(benchmark_dir, prefix,
                                          publication_count)
    by_method <- list()
    for (m in llm_methods) {
      pred_path <- file.path(predictiondir,
                             paste(prefix, m, 'csv', sep = "."))
      preds <- get_prediction_df(eval_set, pred_path,
                                 collapse_duplicates = FALSE)
      binned <- preds[, .(precision = mean(total, na.rm = TRUE),
                          count = .N,
                          countNA = sum(is.na(total))),
                      by = confidence]
      # Binomial standard error over the non-missing outcomes of each bin
      binned[, se := sqrt(precision * (1 - precision) / (count - countNA))]
      by_method[[m]] <- binned
    }
    combined <- rbindlist(by_method, idcol = 'method')
    combined$dataset <- prefix
    combined
  })
  # Left fold starting from an empty data.frame: rbind(rbind(df0, d1), d2)...
  Reduce(rbind, per_dataset, data.frame())
}


cal <- calibration_plots(publication_count)
cal$dataset <- revalue(cal$dataset,
                       c('gwas_catalog' = "GWAS catalog",
                         'opentargets' = "OpenTargets",
                         'pharmaprojects' = "Pharmaprojects",
                         "weeks_et_al" = "Weeks et al."))
cal$method <- revalue(cal$method,
                      c("gpt3_zero_shot" = "LLM-GPT-3.5",
                        "gpt4_0613_zero_shot" = "LLM-GPT-4"))

outfile_cal <- "results/manuscript/supp_table_calibration.csv"
write.table(cal, outfile_cal, row.names = FALSE, sep = ",")

# Reliability diagram: bins with fewer than 5 non-missing outcomes are left
# out, error bars are precision +/- 1.96 * se, and the dashed diagonal marks
# perfect calibration.
well_populated_bins <- subset(cal, count - countNA >= 5)
p7 <- ggplot(well_populated_bins,
             aes(x = confidence, y = precision, group = method,
                 fill = method, colour = method)) +
  geom_point() +
  geom_errorbar(aes(x = confidence,
                    ymin = precision - 1.96 * se,
                    ymax = precision + 1.96 * se)) +
  facet_wrap(~dataset) +
  geom_abline(slope = 1, intercept = 0, linetype = 'dashed') +
  theme_bw() +
  ylim(0, 1) +
  xlim(0, 1) +
  xlab("Predicted confidence") +
  ylab("Precision") +
  theme(strip.text.x = element_text(size = 12))

outfile <- "results/manuscript/calibration_figure.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p7, xangle = 45))
dev.off()


# Paired significance tests of GPT-4 (0613) against the best non-LLM
# baseline on each benchmark dataset.
#
# For each dataset the per-example `total` outcomes of the two methods are
# compared with a one-sided paired Wilcoxon signed-rank test (alternative:
# GPT-4 greater) and with McNemar's test.
#
# Args:
#   publication_count: per-gene publication counts, forwarded to
#     create_evaluation_dataset().
# Returns: data.frame(dataset, wilcox.pvalue, mcnemar.pvalue, compared_to,
#   llm_method), one row per dataset.
# Reads the globals base_dir and predictiondir.
pairwise_comparison <- function(publication_count) {
  datasetdir <- file.path(base_dir, 'data/benchmark_datasets/')
  dataset_list <- c('opentargets', 'pharmaprojects', 'weeks_et_al', 'gwas_catalog')
  # Baseline used for each dataset, aligned element-wise with dataset_list
  best_non_llm_method_list <- c('nearest_gene', 'nearest_gene', 'pops', 'nearest_gene')
  method1 <- 'gpt4_0613_zero_shot'
  ans_rows <- vector("list", length(dataset_list))
  for (idx in seq_along(dataset_list)) {
    dataset_prefix <- dataset_list[idx]
    method2 <- best_non_llm_method_list[idx]
    dataset <- create_evaluation_dataset(datasetdir, dataset_prefix, publication_count)
    outcomes <- lapply(c(method1, method2), function(method) {
      prediction_file <- file.path(predictiondir,
                                   paste(dataset_prefix, method, 'csv', sep = "."))
      result_df <- get_prediction_df(dataset, prediction_file,
                                     collapse_duplicates = FALSE)
      as.numeric(result_df$total)
    })
    pred1 <- outcomes[[1]]
    pred2 <- outcomes[[2]]
    # NOTE(review): both tests are paired, so this assumes the two prediction
    # tables list the same examples in the same order — confirm that
    # get_prediction_df guarantees it.
    wilcox.result <- wilcox.test(pred1, pred2, paired = TRUE, correct = TRUE,
                                 alternative = 'greater')
    # Give both factors the same levels (every outcome seen in either
    # method). Without this, mcnemar.test() builds each factor from its own
    # observed values and stops with a level-count mismatch whenever one
    # method never produces one of the outcomes.
    outcome_levels <- sort(unique(c(pred1, pred2)))
    mcnemar.result <- mcnemar.test(factor(pred1, levels = outcome_levels),
                                   factor(pred2, levels = outcome_levels),
                                   correct = TRUE)
    ans_rows[[idx]] <- data.frame(dataset = dataset_prefix,
                                  wilcox.pvalue = wilcox.result$p.value,
                                  mcnemar.pvalue = mcnemar.result$p.value,
                                  compared_to = method2,
                                  llm_method = method1)
  }
  do.call(rbind, ans_rows)
}

pairwise_results <- pairwise_comparison(publication_count)

# Display labels, applied column by column.
pairwise_label_maps <- list(
  dataset = c('gwas_catalog' = "GWAS catalog",
              'opentargets' = "OpenTargets",
              'pharmaprojects' = "Pharmaprojects",
              "weeks_et_al" = "Weeks et al."),
  compared_to = c("nearest_gene" = "Nearest gene", "pops" = "PoPS"),
  llm_method = c("gpt4_0613_zero_shot" = "LLM-GPT-4")
)
for (label_col in names(pairwise_label_maps)) {
  pairwise_results[[label_col]] <- revalue(pairwise_results[[label_col]],
                                           pairwise_label_maps[[label_col]])
}

outfile <- "results/manuscript/supp_table_pairwise_comparison.csv"
write.table(pairwise_results, outfile, row.names = FALSE, sep = ",")

# For each dataset, the fraction of loci whose causal-gene rank (first column
# of results/others/<dataset>.embedding_info.csv, plus one) is <= K for
# K = 1..5, with its binomial standard error.
#
# NOTE(review): the "+ 1" suggests the file stores 0-based ranks, and the file
# is read with read.table's default whitespace separator despite the .csv
# suffix — confirm against the script that writes it.
# Returns: data.frame(dataset, K, prop, se), 5 rows per dataset.
get_rank_plots <- function() {
  benchmark_names <- c('opentargets', 'pharmaprojects',
                       'weeks_et_al', 'gwas_catalog')
  per_dataset <- lapply(benchmark_names, function(ds) {
    print(ds)
    rank_file <- file.path("results", "others",
                           paste0(ds, ".embedding_info.csv"))
    ranks <- read.table(rank_file, header = TRUE)[, 1] + 1
    n_loci <- length(ranks)
    ks <- seq_len(5)
    props <- vapply(ks, function(k) sum(ranks <= k) / n_loci, numeric(1))
    data.frame(dataset = ds, K = ks, prop = props,
               se = sqrt(props * (1 - props) / n_loci))
  })
  Reduce(rbind, per_dataset, data.frame())
}

rank_df <- get_rank_plots()
rank_df$dataset <- revalue(rank_df$dataset,
                           c('gwas_catalog' = "GWAS catalog",
                             'opentargets' = "OpenTargets",
                             'pharmaprojects' = "Pharmaprojects",
                             "weeks_et_al" = "Weeks et al."))

# Figure: share of loci where the causal gene is the single most similar gene
# (K == 1), with +/- 1.96 * se error bars.
rank_df_K_1 <- subset(rank_df, K == 1)
p8 <- ggplot(rank_df_K_1, aes(x = dataset, y = prop)) +
  geom_bar(stat = 'identity', position = position_dodge(0.9),
           colour = "#CE0F69", fill = "#CE0F69") +
  geom_errorbar(aes(x = dataset, ymin = prop - 1.96 * se,
                    ymax = prop + 1.96 * se),
                position = position_dodge(0.9), width = 0.5) +
  theme_bw() +
  xlab("Dataset") +
  ylab("Proportion of loci with \ncausal gene most \nsimilar to phenotype")

outfile <- "results/manuscript/embedding_top_gene_figure.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p8, xangle = 0))
dev.off()

# Figure: share of loci whose causal gene ranks within the top K, one panel
# per dataset, with +/- 1.96 * se error bars.
p9 <- ggplot(rank_df, aes(x = K, y = prop)) +
  geom_bar(stat = 'identity', position = position_dodge(0.9),
           colour = "#CE0F69", fill = "#CE0F69") +
  geom_errorbar(aes(x = K, ymin = prop - 1.96 * se,
                    ymax = prop + 1.96 * se),
                position = position_dodge(0.9), width = 0.5) +
  facet_wrap(~dataset) +
  theme_bw() +
  xlab("K") +
  ylab("Proportion of loci")

outfile <- "results/manuscript/embedding_top_K_figure.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p9, xangle = 0))
dev.off()


# Per-phenotype precision and recall for every method with a prediction file.
#
# precision = mean of `total` over non-missing rows; recall = sum of `total`
# divided by all rows for that phenotype (`description`).
#
# Args:
#   publication_count: per-gene publication counts, forwarded to
#     create_evaluation_dataset().
# Returns: table with columns method, description, precision, recall, count,
#   countNA, dataset.
# Reads the globals base_dir and predictiondir.
get_precision_by_phenotype <- function(publication_count) {
  benchmark_dir <- file.path(base_dir, 'data/benchmark_datasets/')
  benchmark_names <- c('opentargets', 'pharmaprojects',
                       'weeks_et_al', 'gwas_catalog')
  candidate_methods <- c('nearest_gene', 'pops', 'text_mining_gene',
                         'gpt3_zero_shot', 'gpt4_0613_zero_shot',
                         'zero_shot_embedding')

  per_dataset <- lapply(benchmark_names, function(prefix) {
    eval_set <- create_evaluation_dataset(benchmark_dir, prefix,
                                          publication_count)
    by_method <- list()
    for (m in candidate_methods) {
      pred_path <- file.path(predictiondir,
                             paste(prefix, m, 'csv', sep = "."))
      # Not every method has predictions for every dataset; skip absent ones.
      if (!file.exists(pred_path)) next
      preds <- get_prediction_df(eval_set, pred_path,
                                 collapse_duplicates = FALSE)
      by_method[[m]] <- preds[, .(precision = mean(total, na.rm = TRUE),
                                  recall = sum(total, na.rm = TRUE) / .N,
                                  count = .N,
                                  countNA = sum(is.na(total))),
                              by = description]
    }
    combined <- rbindlist(by_method, idcol = TRUE)
    combined$dataset <- prefix
    setnames(combined, '.id', 'method')
    combined
  })
  Reduce(rbind, per_dataset, data.frame())
}

prec_by_pheno <- get_precision_by_phenotype(publication_count)

# Wide table: one row per (description, dataset), with columns named
# <metric>_<method>.
# tidyr is called through its namespace instead of attaching dplyr and tidyr.
# plyr is attached at the top of this script, so attaching dplyr afterwards
# would mask plyr's summarise/mutate/arrange/rename/count for everything run
# below, including the utils.R helpers.
reshaped_df <- tidyr::pivot_wider(prec_by_pheno,
                                  names_from = method,
                                  values_from = c(precision, recall,
                                                  count, countNA))

reshaped_df$dataset <- revalue(reshaped_df$dataset,
                               c('gwas_catalog' = "GWAS catalog",
                                 'opentargets' = "OpenTargets",
                                 'pharmaprojects' = "Pharmaprojects",
                                 "weeks_et_al" = "Weeks et al."))
outfile <- "results/manuscript/supp_table_prec_by_pheno.csv"
write.table(reshaped_df, outfile, sep = ",", row.names = FALSE)

# Supplementary table: GPT-4's stated reasons for the Weeks et al. examples
# whose phenotype is "Total protein".
weeks_step1 <- fread("data/benchmark_datasets/weeks_et_al_step1.tsv",
                     header = TRUE, sep = "\t")
weeks_gpt4_preds <- fread("results/predictions/weeks_et_al.gpt4_0613_zero_shot.csv",
                          header = TRUE, sep = ",")
# NOTE(review): row positions found in the step-1 file are reused to index the
# prediction file, which assumes both list the examples in the same order.
is_total_protein <- weeks_step1$description == "Total protein"
total_protein_idx <- which(is_total_protein)
gpt4_total_protein_reasons <- weeks_gpt4_preds$reason[total_protein_idx]

outfile <- "results/manuscript/supp_table_total_protein_reasons.csv"
# The descriptive sentence below becomes the single column header.
reason <- "GPT-4 provided reasons for examples about the phenotype 'Total protein' from the Weeks et al. dataset"
write.table(gpt4_total_protein_reasons, outfile,
            quote = TRUE, row.names = FALSE, col.names = reason)

# Collect the GPT-4 (0613) predictions scored as correct (total == 1) on
# every benchmark dataset.
#
# Args:
#   publication_count: per-gene publication counts, forwarded to
#     create_evaluation_dataset().
# Returns: the matching rows of get_prediction_df(), with a leading `method`
#   column and a `dataset` column.
# Reads the globals base_dir and predictiondir.
correct_reason_summary <- function(publication_count) {
  benchmark_dir <- file.path(base_dir, 'data/benchmark_datasets/')
  benchmark_names <- c('opentargets', 'pharmaprojects',
                       'weeks_et_al', 'gwas_catalog')
  reason_methods <- c('gpt4_0613_zero_shot')

  per_dataset <- lapply(benchmark_names, function(prefix) {
    eval_set <- create_evaluation_dataset(benchmark_dir, prefix,
                                          publication_count)
    correct_by_method <- list()
    for (m in reason_methods) {
      pred_path <- file.path(predictiondir,
                             paste(prefix, m, 'csv', sep = "."))
      preds <- get_prediction_df(eval_set, pred_path,
                                 collapse_duplicates = FALSE)
      correct_by_method[[m]] <- preds[total == 1]
    }
    combined <- rbindlist(correct_by_method, idcol = 'method')
    combined$dataset <- prefix
    combined
  })
  Reduce(rbind, per_dataset, data.frame())
}

# Most frequent word trigrams in GPT-4's reasons for its correct predictions
library(ngram)
correct_reasons_df <- correct_reason_summary(publication_count)
correct_reasons <- correct_reasons_df$reason

# concatenate() joins all reasons into one string before trigram counting.
trigrams <- ngram(concatenate(correct_reasons), n = 3)
trigram_counts <- get.phrasetable(trigrams)
trigram_counts$prop <- NULL  # drop the `prop` column from the table

# NOTE(review): taking head(200) as the "top 200" presumes get.phrasetable()
# returns phrases ordered by decreasing frequency — confirm.
outfile <- "results/manuscript/supp_table_correct_reason_ngrams_top_200.csv"
write.table(head(trigram_counts, 200), outfile,
            row.names = FALSE, quote = TRUE, sep = ",")

# Score the predictions made on scrambled inputs for every benchmark dataset.
#
# Reads <predictiondir>/scrambled.<dataset>.<method>.csv for each dataset and
# evaluates it with get_evaluation_metrics().
#
# Args:
#   publication_count: per-gene publication counts, forwarded to
#     create_evaluation_dataset().
#   method: model tag in the prediction file name (default 'gpt4_0613').
# Returns: data.table of the per-dataset metric rows with a leading `dataset`
#   column (benchmark prefix) and a `method` column.
# Reads the globals base_dir and predictiondir.
get_scrambled_results <- function(publication_count, method = 'gpt4_0613') {
  datasetdir <- file.path(base_dir, 'data/benchmark_datasets/')
  dataset_list <- c('opentargets', 'pharmaprojects', 'weeks_et_al', 'gwas_catalog')
  resultlist <- list()
  for (dataset_prefix in dataset_list) {
    dataset <- create_evaluation_dataset(datasetdir, dataset_prefix, publication_count)
    prediction_file <- file.path(predictiondir,
                                 paste("scrambled", dataset_prefix, method, 'csv', sep = "."))
    resultlist[[dataset_prefix]] <- get_evaluation_metrics(dataset, prediction_file,
                                                           collapse_duplicates = FALSE)
  }
  # resultlist is keyed by dataset prefix, so the id column holds dataset
  # names (it was previously mislabelled "method").
  scrambled_results_df <- rbindlist(resultlist, idcol = "dataset")
  set(scrambled_results_df, j = "method", value = method)
  scrambled_results_df
}

# Supplementary table: metrics on scrambled inputs
scrambled_results_df <- get_scrambled_results(publication_count)
outfile <- "results/manuscript/supp_table_scrambled_results.csv"
write.table(scrambled_results_df, file = outfile, sep = ",", row.names = FALSE)

# Precision and recall of the LLM methods restricted to predictions whose
# reported confidence is at least `threshold`.
#
# Args:
#   publication_count: per-gene publication counts, forwarded to
#     create_evaluation_dataset().
#   threshold: minimum confidence for a prediction to count as high
#     confidence (default 0.8).
# Returns: table with columns method, precision, count, countNA, recall and
#   dataset. There is at most one row per dataset x method, and none when a
#   method has no prediction at or above the threshold. Recall divides by
#   every row of the method's prediction table, not just the high-confidence
#   ones.
# Reads the globals base_dir and predictiondir.
calibration_confidence <- function(publication_count, threshold = 0.8) {
  datasetdir <- file.path(base_dir, 'data/benchmark_datasets/')
  dataset_list <- c('opentargets', 'pharmaprojects', 'weeks_et_al', 'gwas_catalog')
  cal_df <- data.frame()
  for (dataset_prefix in dataset_list) {
    dataset <- create_evaluation_dataset(datasetdir, dataset_prefix, publication_count)
    methodlist <- c('gpt3_zero_shot', 'gpt4_0613_zero_shot')
    resultlist <- list()
    for (method in methodlist) {
      prediction_file <- file.path(predictiondir,
                                   paste(dataset_prefix, method, 'csv', sep = "."))
      result_df <- get_prediction_df(dataset, prediction_file, collapse_duplicates = FALSE)
      n_all <- nrow(result_df)
      # Evaluated outside data.table's [ so `threshold` cannot be shadowed
      # by a column of the same name.
      is_high <- result_df$confidence >= threshold
      result_df[, high_confidence := is_high]
      strat_precision <- result_df[, .(precision = mean(total, na.rm = TRUE),
                                       count = .N,
                                       countNA = sum(is.na(total)),
                                       recall = sum(total, na.rm = TRUE) / n_all),
                                   by = high_confidence]
      # Keep only the high-confidence stratum (drops the FALSE and NA groups)
      strat_precision <- strat_precision[high_confidence == TRUE]
      strat_precision[, high_confidence := NULL]
      resultlist[[method]] <- strat_precision
    }
    result_df <- rbindlist(resultlist, idcol = TRUE)
    result_df$dataset <- dataset_prefix
    setnames(result_df, '.id', 'method')
    cal_df <- rbind(cal_df, result_df)
  }
  cal_df
}


# High-confidence vs all predictions for GPT-4 (0613) ----
high_conf_preds <- calibration_confidence(publication_count)
select_cols <- c('method', 'dataset', 'precision', 'recall')
hc_preds_df <- high_conf_preds[, ..select_cols]
hc_preds_df[, subset := "High confidence (>=0.8)"]
# Unfiltered metrics from the main evaluation. The copy still carries the raw
# method/dataset identifiers, so it can be matched against 'gpt...' here.
all_preds_df <- all_result_df_main_copy[grepl('gpt', method), ..select_cols]
all_preds_df[, subset := "All"]
plot_df <- rbind(hc_preds_df, all_preds_df)
plot_df <- plot_df[method == 'gpt4_0613_zero_shot']

plot_df$dataset <- revalue(plot_df$dataset,
                           c('gwas_catalog' = "GWAS catalog",
                             'opentargets' = "OpenTargets",
                             'pharmaprojects' = "Pharmaprojects",
                             "weeks_et_al" = "Weeks et al."))

# Figure: one line per dataset joining its "All" and high-confidence points.
p10 <- ggplot(plot_df, aes(x = precision, y = recall, fill = dataset,
                           colour = dataset, shape = subset, group = dataset)) +
  geom_point(size = 4, alpha = 0.7) +
  geom_line() +
  xlab("Precision") +
  ylab("Recall") +
  theme_bw() +
  xlim(0, 1) +
  ylim(0, 1)

outfile <- "results/manuscript/confidence_precision_recall_figure.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p10, xangle = 0))
dev.off()

# Embed one phenotype together with the candidate genes at a locus, and
# project all of them to 2-D with t-SNE.
#
# Args:
#   gene_embedding_file, pheno_embedding_file: CSV files whose column "0"
#     holds the item name and whose `embedding` column holds a JSON-encoded
#     numeric vector.
#   pheno_name: phenotype to look up in column "0" of the phenotype file.
#   gene_str: comma-separated gene names, each optionally wrapped in braces,
#     e.g. "{PCSK9},{DHCR24}".
#   perplexity: t-SNE perplexity passed to Rtsne (default 4). It must be
#     small enough for the number of items embedded.
# Returns: list(raw_data = item x dimension matrix with item names as
#   rownames, phenotype row(s) first; tsne_data = data.frame(X1, X2, labels,
#   is_pheno)).
locus_results <- function(gene_embedding_file, pheno_embedding_file,
                          pheno_name, gene_str, perplexity = 4) {
  gene_emb <- fread(gene_embedding_file, header = TRUE)
  pheno_emb <- fread(pheno_embedding_file, header = TRUE)
  # Base R parsing: stringr (str_remove_all/str_split_1) is never attached by
  # this script.
  gene_str_clean <- gsub("[{}]", "", gene_str)
  gene_names <- trimws(strsplit(gene_str_clean, ",", fixed = TRUE)[[1]])
  pheno_df <- pheno_emb[get('0') == pheno_name]
  if (nrow(pheno_df) == 0) {
    stop("Phenotype not found in embedding file: ", pheno_name, call. = FALSE)
  }
  gene_df <- gene_emb[get('0') %in% gene_names]
  comb_df <- rbind(pheno_df, gene_df, use.names = FALSE)
  item_names <- comb_df[, get('0')]
  emb_list <- lapply(comb_df$embedding, fromJSON)
  if (length(unique(lengths(emb_list))) != 1) {
    stop("Embeddings have inconsistent dimensions", call. = FALSE)
  }
  item_embs <- do.call(rbind, emb_list)
  rownames(item_embs) <- item_names
  tsne_ob <- Rtsne(item_embs, perplexity = perplexity)
  points <- data.frame(tsne_ob$Y)
  points$labels <- item_names
  points$is_pheno <- points$labels == pheno_name
  list(raw_data = item_embs, tsne_data = points)
}

# PCSK9 locus example ----
pheno_embedding_file <- "data/helper_datasets/phenotype_embeddings.csv"
gene_embedding_file <- "data/helper_datasets/gene_embeddings.csv"
pheno_name <- "Low density lipoprotein cholesterol"
gene_str <- "{ACOT11},{BSND},{C1orf177},{DHCR24},{FAM151A},{MROH7},{PARS2},{PCSK9},{TMEM61},{TTC22},{TTC4},{USP24}"
pcsk9_info <- locus_results(gene_embedding_file, pheno_embedding_file,
                            pheno_name, gene_str)

# Figure: t-SNE map of the phenotype and candidate genes, coloured by whether
# the point is the phenotype.
points <- pcsk9_info$tsne_data
p11 <- ggplot(points, aes(x = X1, y = X2, fill = is_pheno, colour = is_pheno)) +
  geom_point() +
  geom_text_repel(aes(label = labels)) +
  theme_bw() +
  theme(legend.position = "none") +
  xlab("T-SNE component 1") +
  ylab("T-SNE component 2")
outfile <- "results/manuscript/tsne_plot_pcsk9.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p11, xangle = 0))
dev.off()

# Row 1 of raw_data is the phenotype, because locus_results binds it ahead of
# the genes. Each gene is scored by its dot product with that row.
# NOTE(review): this is cosine similarity only if the embeddings are
# unit-normalised — confirm.
emb_data <- pcsk9_info$raw_data
gene_rows <- 2:nrow(emb_data)
pheno_vec <- emb_data[1, ]
sim_info <- apply(emb_data[gene_rows, ], 1, function(g) sum(g * pheno_vec))
sim_df <- data.frame(genes = rownames(emb_data)[gene_rows],
                     similarity = sim_info)
# Sort by decreasing similarity and freeze that order on the x axis.
sim_df <- sim_df[order(sim_df$similarity, decreasing = TRUE), ]
sim_df$genes <- factor(sim_df$genes, levels = sim_df$genes)

p12 <- ggplot(sim_df, aes(x = genes, y = similarity)) +
  geom_bar(stat = 'identity', colour = "#CE0F69", fill = "#CE0F69") +
  xlab("Genes") +
  ylab("Similarity to phenotype") +
  theme_bw()
outfile <- "results/manuscript/similarity_barplot_pcsk9.pdf"
pdf(outfile, width = 10, height = 6)
print(prettify(p12, xangle = 90))
dev.off()

# Figure 2: 3x3 grid with panels a-d in the corners. The NULL cells form a
# thin spacer column and row (rel_widths / rel_heights of 0.05 and 0.1).
outfile <- "results/manuscript/figure2.pdf"
pdf(outfile, width = 21, height = 12)
figure_grid <- plot_grid(prettify(p1), NULL, prettify(p2),
                         NULL, NULL, NULL,
                         prettify(p3), NULL, prettify(p4),
                         rel_widths = c(1, 0.05, 1),
                         nrow = 3, ncol = 3, hjust = -1, vjust = 0,
                         rel_heights = c(1, 0.1, 1),
                         labels = c("a", "", "b", "", "", "", "c", "", "d"),
                         label_size = 28) +
  theme(plot.margin = unit(c(1, 0.5, 0.5, 0.5), "cm"))
print(figure_grid)
dev.off()

# Figure 3: t-SNE panel (a) and top-gene proportion panel (b), side by side.
outfile <- "results/manuscript/figure3.pdf"
pdf(outfile, width = 21, height = 6)
figure_grid <- plot_grid(prettify(p11), NULL, prettify(p8),
                         rel_widths = c(1, 0.05, 1),
                         nrow = 1, hjust = -1.5, vjust = 0,
                         label_x = 0, label_y = 1,
                         labels = c("a", "", "b"), label_size = 28) +
  theme(plot.margin = unit(c(1, 0.5, 0.5, 0.5), "cm"))
print(figure_grid)
dev.off()

# Supplementary table: hallucination counts for the LLM methods (rows whose
# relabelled method contains "LLM").
hallucinations_df <- subset(all_result_df_main, grepl("LLM", method),
                            select = c('method', 'dataset',
                                       'num_preds', 'num_hallucinations'))
# NOTE(review): the denominator adds num_hallucinations to num_preds, which
# presumes num_preds excludes hallucinated genes — confirm in utils.R.
hallucinations_df$proportion_hallucinations <- with(
  hallucinations_df,
  num_hallucinations / (num_preds + num_hallucinations)
)
outfile <- "results/manuscript/supp_table_hallucinations.csv"
write.table(hallucinations_df, outfile, sep = ",", row.names = FALSE)
