library(stringr)
library(boot)

# Build the gold-standard evaluation table for one dataset.
#
# Reads <basedir>/<dataset_prefix>_step1.tsv and
# <basedir>/<dataset_prefix>_step2.for_llm.tsv, column-binds them row by row
# (dropping step2's duplicate row_number and its description), and adds:
#   locus_ngenes     - number of comma-separated entries in symbol_gene_string
#   causal_in_window - whether the gold `symbol` occurs in symbol_gene_string
#   gene_pub_count   - publication_count$count for `gene` (0 when absent)
#
# Args:
#   basedir: directory holding the two TSV files.
#   dataset_prefix: file-name prefix shared by both files.
#   publication_count: table with `gene` and `count` columns
#     (e.g. from get_publication_count_by_gene()).
#
# Returns:
#   data.table with one row per gold-standard row.
create_evaluation_dataset <- function(basedir, dataset_prefix, publication_count){
  dataset_step1_fname <- file.path(basedir, paste(dataset_prefix, "step1.tsv", sep = "_"))
  dataset_step2_fname <- file.path(basedir, paste(dataset_prefix, "step2.for_llm.tsv", sep = "_"))
  dataset_step1 <- fread(dataset_step1_fname)
  dataset_step2 <- fread(dataset_step2_fname)
  # The files are joined by position, so they must list the same rows in the
  # same order. Check the row counts first so `==` cannot silently recycle.
  # Check that the column exists so all(logical(0)) cannot pass vacuously.
  stopifnot(
    nrow(dataset_step1) == nrow(dataset_step2),
    "row_number" %in% names(dataset_step1),
    "row_number" %in% names(dataset_step2),
    all(dataset_step1$row_number == dataset_step2$row_number)
  )
  dataset_step2[, `:=`(row_number = NULL, description = NULL)]
  dataset <- data.table(cbind(dataset_step1, dataset_step2))
  dataset[, locus_ngenes := str_count(symbol_gene_string, ",") + 1]
  # fixed = TRUE: gene symbols are literal text, not regular expressions
  # (e.g. a "." in a symbol must not match any character).
  # NOTE(review): this is still a substring test, so "ABC1" also matches
  # "ABC10" -- confirm whether token matching on "," is wanted.
  dataset[, causal_in_window := mapply(grepl, symbol, symbol_gene_string,
                                       MoreArgs = list(fixed = TRUE),
                                       USE.NAMES = FALSE)]
  dataset[, gene_pub_count := publication_count$count[match(gene, publication_count$gene)]]
  # Genes missing from the publication table count as zero publications.
  dataset[is.na(gene_pub_count), gene_pub_count := 0]
  dataset
}

# Score one prediction file against the gold-standard dataset and bootstrap
# the summary metrics.
#
# A prediction is a "hallucination" when it does not occur in the locus gene
# string. Hallucinated predictions are blanked to NA (treated as "no
# prediction") before correctness is computed.
#
# Args:
#   gold_dataset: data.table from create_evaluation_dataset(). It is copied,
#     so the caller's table is NOT modified.
#   prediction_file: comma-separated file with one row per gold row, in the
#     same order. It must have either a `causal_gene` column (compared with
#     `symbol`) or a `causal_gene_id` column (compared with `gene`).
#   collapse_duplicates: if TRUE, rows sharing efo/chromosome/position/gene
#     are merged. `total` is 1 if any row is correct, 0 if all scorable rows
#     are wrong, and NA if none is scorable. The other columns come from the
#     best-scoring row (the first row when none is scorable).
#
# Returns:
#   data.frame from get_bootstrap_results() plus `num_hallucinations`, the
#   number of uncollapsed rows flagged as hallucinations.
get_evaluation_metrics <- function(gold_dataset, prediction_file, collapse_duplicates = FALSE){
  # Work on a private copy: `:=` below would otherwise add columns to the
  # caller's table by reference.
  gold_dataset <- copy(gold_dataset)
  predictions <- fread(prediction_file, header = TRUE, sep = ",")
  stopifnot(nrow(predictions) == nrow(gold_dataset))
  if ("causal_gene" %in% colnames(predictions)) {
    pred_col <- "causal_gene"
    gold_col <- "symbol"
    gene_string_col <- "symbol_gene_string"
  } else {
    pred_col <- "causal_gene_id"
    gold_col <- "gene"
    gene_string_col <- "ensembl_gene_string"
  }
  if (!(pred_col %in% colnames(predictions))) {
    stop("prediction file ", prediction_file,
         " has neither a causal_gene nor a causal_gene_id column", call. = FALSE)
  }
  gold_dataset[, use_prediction := predictions[[pred_col]]]
  # fixed = TRUE: predicted names are matched literally, not as regexes.
  # NOTE(review): still a substring test ("TP5" is found in "TP53").
  gold_dataset[, is_hallucination := !mapply(grepl, use_prediction, get(gene_string_col),
                                             MoreArgs = list(fixed = TRUE),
                                             USE.NAMES = FALSE)]
  gold_dataset[is_hallucination == TRUE, use_prediction := NA]
  gold_dataset[, iscorrect := (use_prediction == get(gold_col))]

  if (collapse_duplicates) {
    group_cols <- unique(c("efo", "chromosome", "position", gold_col, "gene"))
    # Columns needed downstream that are not grouping keys. They are taken
    # from the representative row.
    carry_cols <- c("row_number", "locus_ngenes", "gene_pub_count")
    final <- gold_dataset[, {
      best <- which.max(iscorrect)          # ignores NA; empty if all NA
      if (length(best) == 0L) best <- 1L
      c(list(total = if (all(is.na(iscorrect))) NA_integer_
                     else max(iscorrect, na.rm = TRUE),
             is_hallucination = max(is_hallucination)),
        lapply(.SD, `[`, best))
    }, by = group_cols, .SDcols = carry_cols]
  } else {
    final <- gold_dataset
    setnames(final, "iscorrect", "total")
  }

  metric_cols <- c("total", "locus_ngenes", "gene_pub_count", "gene")
  results <- get_bootstrap_results(final[, ..metric_cols])
  results$num_hallucinations <- sum(gold_dataset$is_hallucination, na.rm = TRUE)
  results
}

# Load the per-gene publication-count table stored at `fname`.
# create_evaluation_dataset() later looks up its `gene` and `count` columns.
get_publication_count_by_gene <- function(fname){
  fread(fname)
}

# Bootstrap statistic for boot::boot(): evaluation metrics on one resample.
#
# Args:
#   df: data.table with columns `total` (1/TRUE = correct, 0/FALSE = wrong,
#     NA = no usable prediction), `locus_ngenes`, `gene_pub_count`, `gene`.
#   idx: row indices of the resample, supplied by boot().
#
# Returns:
#   Named numeric vector: precision, recall, F1, gene_count_cor (correctness
#   vs. genes in the locus), gene_pub_cor (per-gene accuracy vs. publication
#   count), num_preds, num_correct.
result_function <- function(df, idx){
  df_sample <- df[idx, ]
  ntotal <- nrow(df_sample)
  scored <- !is.na(df_sample$total)
  npreds <- sum(scored)
  ncorrect <- sum(scored & (df_sample$total == 1))
  precision <- ncorrect / npreds
  recall <- ncorrect / ntotal
  # Algebraically equal to 2*P*R/(P+R) when that is defined, but gives 0
  # rather than NaN when nothing is correct. boot.ci() discards non-finite
  # replicates, so a NaN here would silently drop the worst resamples and
  # bias the F1 interval upward.
  F1 <- 2 * ncorrect / (npreds + ntotal)

  # Pearson correlation over complete pairs. Returns NA instead of erroring
  # when fewer than two complete pairs exist in this resample.
  cor_complete <- function(x, y) {
    ok <- complete.cases(x, y)
    if (sum(ok) < 2) return(NA_real_)
    cor(x[ok], y[ok])
  }

  gene_count_cor <- cor_complete(df_sample$total, df_sample$locus_ngenes)
  gene_pub_df <- df_sample[, .(accuracy = mean(total),
                               pub_count = mean(gene_pub_count)),
                           by = gene]
  gene_pub_cor <- cor_complete(gene_pub_df$accuracy, gene_pub_df$pub_count)
  c(precision = precision, recall = recall, F1 = F1,
    gene_count_cor = gene_count_cor, gene_pub_cor = gene_pub_cor,
    num_preds = npreds, num_correct = ncorrect)
}

# Bootstrap result_function() over the rows of `df` and summarise each metric.
#
# Args:
#   df: data.table accepted by result_function() (columns total,
#     locus_ngenes, gene_pub_count, gene).
#   R: number of bootstrap replicates (default 1000, the previous fixed value).
#   conf: confidence level of the percentile intervals.
#
# Returns:
#   One-row data.frame. Each metric whose name starts with "num_" gets only
#   its full-sample value. Every other metric gets its full-sample value plus
#   `<metric>.lower` / `<metric>.upper` percentile bounds, which are NA when
#   no interval could be computed.
get_bootstrap_results <- function(df, R = 1000, conf = 0.95){
  df.boot <- boot(df, result_function, R = R, sim = "ordinary")
  metric_names <- names(df.boot$t0)
  result_list <- list()
  for (i in seq_along(metric_names)) {
    metric_name <- metric_names[i]
    result_list[[metric_name]] <- df.boot$t0[i]
    if (startsWith(metric_name, "num_")) {
      next  # counts are reported without an interval
    }
    # boot.ci() returns NULL for degenerate replicates (e.g. all identical)
    # and may error on others. Either way, emit NA bounds so every result
    # has the same columns and can be rbindlist()-ed across methods.
    df.ci <- tryCatch(
      boot.ci(df.boot, conf = conf, type = "perc", index = i),
      error = function(e) {
        warning("No bootstrap CI for ", metric_name, ": ",
                conditionMessage(e), call. = FALSE)
        NULL
      }
    )
    lower_name <- paste(metric_name, "lower", sep = ".")
    upper_name <- paste(metric_name, "upper", sep = ".")
    if (is.null(df.ci)) {
      result_list[[lower_name]] <- NA_real_
      result_list[[upper_name]] <- NA_real_
    } else {
      # `percent` columns: conf level, two order-statistic indices, lower, upper.
      result_list[[lower_name]] <- df.ci$percent[4]
      result_list[[upper_name]] <- df.ci$percent[5]
    }
  }
  data.frame(result_list)
}

# Evaluate every method's predictions for one dataset.
#
# Builds the gold-standard table once with create_evaluation_dataset(), then
# scores <predictiondir>/<dataset_prefix>.<method>.csv with
# get_evaluation_metrics() for each entry of `methodlist`.
#
# Returns:
#   data.table binding the per-method results, with a `method` id column and
#   a `dataset` column holding `dataset_prefix`.
get_dataset_results <- function(datasetdir, dataset_prefix, predictiondir, methodlist, publication_count){
  gold <- create_evaluation_dataset(datasetdir, dataset_prefix, publication_count)
  print(dataset_prefix)
  metrics_by_method <- list()
  for (method_name in methodlist) {
    print(paste(".    ", method_name))
    csv_name <- paste(dataset_prefix, method_name, "csv", sep = ".")
    metrics_by_method[[method_name]] <-
      get_evaluation_metrics(gold, file.path(predictiondir, csv_name))
  }
  combined <- rbindlist(metrics_by_method, idcol = TRUE)
  combined$dataset <- dataset_prefix
  setnames(combined, ".id", "method")
  combined
}

# Apply large, bold axis titles and larger tick/legend text to a ggplot.
#
# Args:
#   p: a ggplot object.
#   xangle: rotation angle (degrees) for the x-axis tick labels.
#
# Returns:
#   `p` with the theme settings added.
prettify <- function(p, xangle = 0){
  bold_title <- element_text(face = "bold", size = 20)
  p + theme(
    axis.title.x = bold_title,
    axis.text.x  = element_text(size = 16, angle = xangle, vjust = 0.5, colour = "black"),
    axis.title.y = bold_title,
    axis.text.y  = element_text(size = 16, colour = "black"),
    legend.title = element_text(size = 16),
    legend.text  = element_text(size = 16)
  )
}

# Build a per-row table of scored predictions for inspection.
#
# Uses the same scoring as get_evaluation_metrics(): a prediction that does
# not occur in the locus gene string is a hallucination and is blanked to NA.
# The function returns the scored rows instead of bootstrap summaries.
#
# Args:
#   dataset: data.table from create_evaluation_dataset(). It is copied, not
#     modified.
#   prediction_file: comma-separated file with one row per dataset row, in the
#     same order. It has either a `causal_gene` column (compared with `symbol`)
#     or a `causal_gene_id` column (compared with `gene`). Optional
#     `confidence` and `reason` columns are carried through.
#   collapse_duplicates: if TRUE, rows sharing efo/chromosome/position/gene
#     are merged. `total` is 1 if any row is correct, 0 if all scorable rows
#     are wrong, and NA if none is scorable. The other columns come from the
#     best-scoring row (the first row when none is scorable).
#
# Returns:
#   data.table with row_number, total, locus_ngenes, gene_pub_count, gene,
#   is_hallucination, use_prediction, description, plus confidence/reason
#   when the prediction file has them.
get_prediction_df <- function(dataset, prediction_file, collapse_duplicates = FALSE){
  gold_dataset <- copy(dataset)
  predictions <- fread(prediction_file, header = TRUE, sep = ",")
  stopifnot(nrow(predictions) == nrow(gold_dataset))
  if ("causal_gene" %in% colnames(predictions)) {
    pred_col <- "causal_gene"
    gold_col <- "symbol"
    gene_string_col <- "symbol_gene_string"
  } else {
    pred_col <- "causal_gene_id"
    gold_col <- "gene"
    gene_string_col <- "ensembl_gene_string"
  }
  if (!(pred_col %in% colnames(predictions))) {
    stop("prediction file ", prediction_file,
         " has neither a causal_gene nor a causal_gene_id column", call. = FALSE)
  }
  optional_cols <- intersect(c("confidence", "reason"), colnames(predictions))
  gold_dataset[, use_prediction := predictions[[pred_col]]]
  if ("confidence" %in% optional_cols) gold_dataset[, confidence := predictions$confidence]
  if ("reason" %in% optional_cols) gold_dataset[, reason := predictions$reason]
  # fixed = TRUE: predicted names are matched literally, not as regexes.
  # NOTE(review): still a substring test ("TP5" is found in "TP53").
  gold_dataset[, is_hallucination := !mapply(grepl, use_prediction, get(gene_string_col),
                                             MoreArgs = list(fixed = TRUE),
                                             USE.NAMES = FALSE)]
  gold_dataset[is_hallucination == TRUE, use_prediction := NA]
  gold_dataset[, iscorrect := (use_prediction == get(gold_col))]

  if (collapse_duplicates) {
    group_cols <- unique(c("efo", "chromosome", "position", gold_col, "gene"))
    # Output columns that are not grouping keys. They are taken from the
    # representative row.
    carry_cols <- intersect(
      c("row_number", "locus_ngenes", "gene_pub_count", "use_prediction",
        "description", "confidence", "reason"),
      names(gold_dataset))
    final <- gold_dataset[, {
      best <- which.max(iscorrect)          # ignores NA; empty if all NA
      if (length(best) == 0L) best <- 1L
      c(list(total = if (all(is.na(iscorrect))) NA_integer_
                     else max(iscorrect, na.rm = TRUE),
             is_hallucination = max(is_hallucination)),
        lapply(.SD, `[`, best))
    }, by = group_cols, .SDcols = carry_cols]
  } else {
    # gold_dataset is already a private copy; no second copy is needed.
    final <- gold_dataset
    setnames(final, "iscorrect", "total")
  }

  select_cols <- c("row_number", "total", "locus_ngenes", "gene_pub_count",
                   "gene", "is_hallucination", "use_prediction", "description",
                   optional_cols)
  final[, ..select_cols]
}
