#' Plot Category Means
#'
#' Creates a ggplot object that displays the mean and standard error of a
#' y variable for every level of a categorical x variable, optionally
#' restricted to one compound and optionally split by a color column. The
#' number of observations per group (and, when available, the number of unique
#' references per group) is printed above the points.
#'
#' @param df A data frame containing the data to be plotted.
#' @param compound A string specifying the value of \code{compound_column} to
#'   filter the data by. It is also used as the plot subtitle. If \code{NULL},
#'   no filtering is done and the subtitle is "All".
#' @param compound_column A string specifying the column name in df to filter
#'   by compound. If \code{NULL}, no filtering is done.
#' @param x_cat A string specifying the column name in df to use as the x
#'   categorical variable.
#' @param y A string specifying the column name in df to use as the y variable.
#' @param min_obs An integer specifying the minimum number of observations; it
#'   is passed to \code{filter_rare_combinations()} together with \code{y} and
#'   the grouping columns. Default is 1.
#' @param color An optional string specifying the column name in df to use for
#'   coloring the plot. Default is NULL.
#' @param order_x A logical value indicating whether or not to order the x
#'   categories by mean y value. Default is TRUE.
#' @param ylab An optional string specifying the Y axis label.
#' @param guess_y_unit An optional string. If it names a column of the data,
#'   the unique values of that column are prepended to \code{ylab}; otherwise
#'   the string itself is prepended. If the string contains "\%" the y axis is
#'   formatted as percent with a dashed line at 1; if it contains "SD" a dashed
#'   line is drawn at 0. \code{NULL} leaves \code{ylab} untouched.
#'   Default is "conc_units".
#' @param nrColor Color of the observation-count labels when no \code{color}
#'   column is used. Default is "black".
#' @param uniqueCounts An optional string naming a column whose number of
#'   unique values per group is printed above the observation counts.
#'   Default is "ref_number".
#' @param legend.position Position of the legend. Default is "bottom".
#' @param facet An optional string naming a column of df. If that column has
#'   more than one unique value, one plot per (non-missing) value is created
#'   and a list of plots is returned.
#' @param color_palette An optional named vector of colors (names are the
#'   values of the \code{color} column). If NULL and \code{color} is a column
#'   of df, a palette is built from the RColorBrewer "Set1" palette.
#' @param theme An optional ggplot2 theme object for customizing plot theme.
#'   Default is ggplot2::theme_bw().
#'
#' @return A ggplot object displaying the mean and standard error of y grouped
#'   by x_cat; a list of such objects when \code{facet} is used; or
#'   \code{NULL} (with a warning) when the compound is not found or fewer than
#'   two x categories remain after filtering.
plot_category_means <- function(df, compound, compound_column, x_cat, y,
                                min_obs = 1,
                                color = NULL,
                                order_x = TRUE,
                                ylab = "LC50 mean ± SE",
                                guess_y_unit = "conc_units",
                                nrColor = "black",
                                uniqueCounts = "ref_number",
                                legend.position = "bottom",
                                facet = NULL,
                                color_palette = NULL,
                                theme = ggplot2::theme_bw()) {

  # Build a default palette from the full data so that facet panels created
  # below all share the same value -> color mapping.
  if (is.null(color_palette) && !is.null(color) && color %in% colnames(df)) {
    unique_values <- sort(unique(df[[color]]))
    n_values <- length(unique_values)

    # brewer.pal() needs at least 3 colors; 8 are requested when there are
    # more than 9 values.
    if (n_values > 9) {
      nr2get <- 8
    } else if (n_values < 3) {
      nr2get <- 3
    } else {
      nr2get <- n_values
    }
    color_vars <- RColorBrewer::brewer.pal(nr2get, "Set1")

    # rep(length.out =) truncates when there are enough colors and recycles
    # them when there are more values than colors, so no value gets NA.
    color_palette <- stats::setNames(rep(color_vars, length.out = n_values),
                                     unique_values)
  }

  # Facet mode: one recursive call per facet value, returning a list of plots.
  if (!is.null(facet) && facet %in% colnames(df)) {
    items <- unique(df[[facet]])
    if (length(items) > 1) {
      # Missing facet values cannot be matched with `==`; skip them.
      items <- items[!is.na(items)]
      plots <- lapply(items, function(x) {
        # NOTE(review): the facet value is forwarded as `compound` together
        # with the caller's `compound_column`; unless `compound_column` is
        # NULL, the value must also occur in that column or the panel is
        # dropped with a warning - confirm this is intended.
        plot_category_means(df[which(df[[facet]] == x), , drop = FALSE],
                            compound = x,
                            compound_column = compound_column,
                            x_cat = x_cat,
                            y = y,
                            min_obs = min_obs,
                            color = color,
                            order_x = order_x,
                            ylab = ylab,
                            guess_y_unit = guess_y_unit,
                            nrColor = nrColor,
                            uniqueCounts = uniqueCounts,
                            legend.position = legend.position,
                            color_palette = color_palette,
                            facet = NULL,
                            theme = theme)
      })
      # Drop panels for which no plot could be created
      plots <- plots[!vapply(plots, is.null, logical(1))]
      return(plots)
    }
  }

  theme <- theme +
    ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1),
                   legend.position = legend.position,
                   axis.title.x = ggplot2::element_blank())

  pos <- ggplot2::position_dodge(width = 0.9)

  # Summary function placing the number of observations above the highest
  # mean + SE (`ymax`); the relative offset grows for small values of `ymax`.
  ggn <- function(x, ymax) {
    shifter <- 0.05
    if (ymax < 10) {shifter <- 0.1}
    if (ymax < 1) {shifter <- 0.2}
    if (ymax < 0.5) {shifter <- 0.3}
    if (ymax < 0) {shifter <- 1}
    return(c(y = ymax + abs(ymax * shifter), label = length(x)))
  }

  # Restrict the data to the requested compound (if any).
  if (!is.null(compound) && !is.null(compound_column)) {
    if (compound %in% df[[compound_column]]) {
      # which() drops rows whose compound is NA instead of turning them into
      # all-NA rows.
      cdf <- df[which(df[[compound_column]] == compound), , drop = FALSE]
    } else {
      warning(paste0("Compound ", compound, " not found in data frame.\n"), call.=FALSE)
      return(NULL)
    }
  } else {
    if (is.null(compound)) {
      compound <- "All" #set a default name if missing
    }
    cdf <- df
  }

  # Grouping columns: x category plus the color column when it exists.
  if (!is.null(color) && color %in% colnames(cdf)) {
    byCats <- unique(c(x_cat, color))
  } else {
    byCats <- c(x_cat)
  }

  # mean + standard error of the non-missing values of x. A group with a
  # single observation has no SE; it then contributes its mean so that the
  # labels are still placed above it.
  maxse <- function(x) {
    x <- x[!is.na(x)]
    se <- stats::sd(x) / sqrt(length(x))
    if (is.na(se)) {
      se <- 0
    }
    mean(x) + se
  }

  # refactor() and filter_rare_combinations() are defined outside this
  # function; presumably refactor() keeps the previous level order of x_cat
  # and filter_rare_combinations() removes groups with fewer than min_obs
  # observations of y.
  cdf[[x_cat]] <- refactor(cdf[[x_cat]])
  cdf <- filter_rare_combinations(cdf, y, byCats, min_obs)
  cdf[[x_cat]] <- refactor(cdf[[x_cat]])

  if (length(cdf[[x_cat]]) == 0) {
    warning(paste0("No data for ", compound,
                   " if minimum observations (min_obs = ", min_obs,
                   ") is used.\n"), call.=FALSE)
    return(NULL)
  }
  if (length(unique(cdf[[x_cat]])) == 1) {
    warning(paste0("Only one value for ", compound, "(x=",
                   unique(cdf[[x_cat]]),"),",
                   " if minimum observations (min_obs = ", min_obs,
                   ") is used. Nothing to compare\n"), call.=FALSE)
    return(NULL)
  }

  # Reorder x_cat by mean y value (missing y values are ignored)
  if (order_x) {
    cdf[[x_cat]] <- stats::reorder(factor(cdf[[x_cat]]), cdf[[y]],
                                   FUN = mean, na.rm = TRUE)
  }

  color_in_cdf <- !is.null(color) && color %in% colnames(cdf)
  has_refs <- !is.null(uniqueCounts) && uniqueCounts %in% colnames(cdf)

  # Number of unique references per group
  if (has_refs) {
    refs <- stats::aggregate(cdf[[uniqueCounts]], cdf[byCats],
                             FUN = function(x) length(unique(x)))
    colnames(refs) <- c(byCats, uniqueCounts)
  }

  # Y axis label: prepend the unit(s) found in the data, or the literal string
  if (!is.null(guess_y_unit)) {
    if (guess_y_unit %in% colnames(cdf)) {
      unit <- paste0(unlist(unique(as.character(cdf[[guess_y_unit]]))),
                     collapse = ", ")
      ylab <- paste(unit, ylab)
    } else {
      ylab <- paste(guess_y_unit, ylab)
    }
  }
  labs <- ggplot2::labs(subtitle = compound, y = ylab)

  # Highest mean + SE over all groups; used to position the text labels.
  # The aggregated value is always the last column of the result.
  if (color_in_cdf) {
    group_max <- stats::aggregate(cdf[[y]], by = cdf[c(x_cat, color)],
                                  FUN = maxse)
  } else {
    group_max <- stats::aggregate(cdf[[y]], by = cdf[x_cat], FUN = maxse)
  }
  maxC <- max(group_max[[ncol(group_max)]], na.rm = TRUE)

  # Create plot
  p <- ggplot2::ggplot(cdf,
                       ggplot2::aes(x = .data[[x_cat]],
                                    y = .data[[y]])) +
    ggplot2::stat_summary(fun.data = ggplot2::mean_se, position = pos) +
    ggplot2::stat_summary(fun = mean, geom = "point", shape = 19, size = 4,
                          position = pos) +
    theme + labs

  # A manual color scale is only meaningful when a palette is available
  if (!is.null(color_palette)) {
    p <- p + ggplot2::scale_color_manual(values = color_palette)
  }

  # Number of observations per group
  if (color_in_cdf) {
    p <- p + ggplot2::aes(color = .data[[color]])
    p <- p + ggplot2::stat_summary(fun.data = ggn, geom = "text",
                                   fun.args = list(ymax = maxC),
                                   position = pos)
  } else {
    p <- p + ggplot2::stat_summary(fun.data = ggn, geom = "text",
                                   color = nrColor,
                                   fun.args = list(ymax = maxC))
  }

  # The reference counts sit above the observation counts, so their relative
  # offset is twice the one used in ggn().
  adjustRefLab <- 0.1
  if (maxC < 10) {adjustRefLab <- 0.2}
  if (maxC < 1) {adjustRefLab <- 0.4}
  if (maxC < 0.5) {adjustRefLab <- 0.6}
  if (maxC < 0) {adjustRefLab <- 2}

  # Number of unique references per group
  if (has_refs) {
    p <- p + ggplot2::geom_text(
      data = refs,
      ggplot2::aes(x = .data[[x_cat]],
                   y = maxC + abs(maxC * adjustRefLab),
                   label = .data[[uniqueCounts]]),
      position = pos)
  }

  # The checks below look at the guess_y_unit argument itself, not at the
  # unit values read from the data.
  if (!is.null(guess_y_unit) && grepl("%", guess_y_unit, fixed = TRUE)) {
    p <- p + ggplot2::scale_y_continuous(labels = scales::percent) +
      ggplot2::geom_hline(yintercept = 1, linetype = "dashed")
  }

  if (!is.null(guess_y_unit) && grepl("SD", guess_y_unit, fixed = TRUE)) {
    p <- p + ggplot2::geom_hline(yintercept = 0, linetype = "dashed")
  }

  return(p)
}

#' Plot Categories Means
#'
#' This function takes a data frame and creates a list of ggplot objects that
#'  display the mean and standard error of a specified y variable grouped by a
#'  specified x categorical variable for each unique value in the specified
#'  compound_column. The plots can be customized with optional arguments for
#'   color, ordering of x categories, ggplot2 theme, and legend position.
#'
#' @param df A data frame containing the data to be plotted.
#' @param compound_column A string specifying the column name in df to create
#'  separate plots for each unique (non-missing) value.
#' @param x_cat A string specifying the column name in df to use as the
#' x categorical variable.
#' @param y A string specifying the column name in df to use as the y variable.
#' @param min_obs An integer specifying the minimum number of observations.
#'  It is passed on to \code{plot_category_means()}. In addition, when
#'  \code{compound_column} has a lower level (see
#'  \code{get_unique_lowest_level_compounds()}), compounds with fewer than
#'  \code{min_obs} unique lower-level items are skipped.
#' @param color An optional string specifying the column name in df to use for
#'  coloring the plot. Default is NULL.
#' @param order_x A logical value indicating whether or not to order the x
#' categories by mean y value. Default is TRUE.
#' @param theme An optional ggplot2 theme object for customizing plot theme.
#'  Default is ggplot2::theme().
#' @param ylab An optional string specifying the Y axis label
#' @param guess_y_unit An optional string passed on to
#'  \code{plot_category_means()}: a column name holding the y unit, or the
#'  unit itself. Default is "conc_units".
#' @param legend.position An optional string specifying the position of the
#'  legend on the plot. Default is "bottom".
#' @param nrColor Color of the observation-count labels, passed on to
#'  \code{plot_category_means()}. Default is "black".
#' @param ... Further arguments passed on to \code{plot_category_means()}.
#'
#' @return A list of ggplot objects displaying the mean and standard error of
#'  y grouped by x_cat, named by the values of compound_column. Compounds for
#'  which no plot could be created are left out.
plot_categories_means <- function(df, compound_column, x_cat, y,
                                 min_obs = 1,
                                 color = NULL,
                                 order_x = TRUE,
                                 theme = ggplot2::theme(),
                                 ylab = "LC50 mean ± SE",
                                 guess_y_unit = "conc_units",
                                 legend.position = "bottom",
                                 nrColor = "black",
                                 ...) {
  compounds <- unique(df[[compound_column]])
  compounds <- compounds[!is.na(compounds)]

  # Rows without a compound can never be plotted; remove them once instead of
  # on every iteration.
  known <- df[!is.na(df[[compound_column]]), , drop = FALSE]
  plots <- list()

  for (compound in compounds) {
    # Count the lower-level items of THIS compound only; counting them over
    # the whole data frame would give the same answer for every compound.
    compound_rows <- known[known[[compound_column]] == compound, , drop = FALSE]
    lowerCompounds <- get_unique_lowest_level_compounds(compound_rows,
                                                        compound_column)
    # Skip higher level plots that have fewer than min_obs items in the
    # lowest level column
    if (lowerCompounds[["name"]] != compound_column &&
        length(lowerCompounds[["items"]]) < min_obs) {
      next
    }

    # Index by name so that numeric compound values do not address list
    # positions. Assigning NULL (no plot) leaves the list unchanged.
    plots[[as.character(compound)]] <- plot_category_means(
      known,
      compound, compound_column,
      x_cat,
      y,
      min_obs = min_obs,
      color = color,
      order_x = order_x,
      ylab = ylab,
      guess_y_unit = guess_y_unit,
      legend.position = legend.position,
      nrColor = nrColor,
      theme = theme,
      ...
    )
  }

  return(plots)
}


#' Get unique lowest level compounds
#'
#' Returns the unique values of the lowest level column that belongs to
#' \code{compound_column}: "chem_name" for "chemical_class", "species_name"
#' for "Order", "Family" and "Genus", and \code{compound_column} itself for
#' every other column.
#'
#' @param df A data frame containing \code{compound_column} and, where
#'   applicable, the corresponding lowest level column.
#' @param compound_column A string specifying the column name to look up.
#'
#' @return A list with elements \code{name} (the name of the lowest level
#'   column) and \code{items} (the unique values of that column in df).
get_unique_lowest_level_compounds <- function(df, compound_column) {
  if (compound_column == "chemical_class") {
    # the compound column is chemicals
    itemname <- "chem_name"
  } else if (compound_column %in% c("Order", "Family", "Genus")) {
    # the compound column is organisms
    itemname <- "species_name"
  } else {
    itemname <- compound_column
  }

  # `[[` always extracts the column as a vector, whereas `df[, col]` keeps a
  # one-column table for tibble-like inputs, which would make length(items)
  # count columns instead of values.
  list(name = itemname, items = unique(df[[itemname]]))
}


#' Plot BRMS Model probabilities
#'
#' Draws the posterior densities of the population-level coefficients
#' (variables starting with "b_") of a fitted model. If the model has an
#' intercept, the intercept is shown as the reference and is added to every
#' other coefficient before plotting.
#'
#' @param mod1 A fitted model from which posterior draws can be extracted with
#'   \code{as_draws_df()} and whose \code{data} element contains \code{var}.
#' @param var A string specifying the predictor whose values are used to label
#'   the coefficients; it is also the name of the color legend.
#' @param color_labs An optional character vector of labels, one per plotted
#'   coefficient. If NULL or of the wrong length, labels are derived from the
#'   coefficient names.
#' @param exp A logical; if TRUE the draws are exponentiated before plotting.
#'   Default is FALSE.
#'
#' @return A ggplot object with one density curve per coefficient.
plot_bms_prob <- function(mod1, var, color_labs = NULL, exp = FALSE) {
  # Extract the posterior samples
  posterior_samples <- as_draws_df(mod1, variable = "^b_", regex = TRUE)

  varNames <- grep("^b_", colnames(posterior_samples), value = TRUE)
  posterior_samples <- as.data.frame(posterior_samples)

  # Express every coefficient on the scale of the intercept: the intercept
  # itself becomes "reference", all other coefficients are shifted by it.
  if ("b_Intercept" %in% colnames(posterior_samples)) {
    intercept <- posterior_samples[["b_Intercept"]]
    posterior_samples$reference <- intercept
    effects <- varNames[varNames != "b_Intercept"]
    for (effect in effects) {
      posterior_samples[[effect]] <- posterior_samples[[effect]] + intercept
    }
    varNames <- c("reference", effects)
  }

  # Convert to long format
  ps_long <- reshape2::melt(posterior_samples, measure.vars = varNames,
                            value.name = "value")
  ps_long$variable <- as.character(ps_long$variable)

  if (exp) {
    # `exp` is also the name of the logical argument, hence the explicit
    # namespace.
    ps_long$value <- base::exp(ps_long$value)
  }

  if (is.null(color_labs) || length(color_labs) != length(varNames)) {
    color_labs <- varNames
    groupVals <- paste0("b_", var, unique(mod1$data[[var]]))
    # Names in varNames that do not correspond to a value of var
    missingVar <- setdiff(varNames, groupVals)
    # Values of var without a coefficient of their own
    replacementVal <- setdiff(groupVals, varNames)

    # Relabel only when the mapping is unambiguous: exactly one value of var
    # lacks a coefficient, and it replaces either the single unmatched name
    # or, when other predictors are present as well, the reference.
    if (length(replacementVal) == 1L) {
      if (length(missingVar) == 1L) {
        color_labs[color_labs %in% missingVar] <- replacementVal
      } else if ("reference" %in% missingVar) {
        color_labs[color_labs %in% "reference"] <- replacementVal
      }
    }
    color_labs <- gsub(paste0("b_", var), "", color_labs)
  }
  ps_long[[var]] <- factor(ps_long$variable,
                           levels = varNames,
                           labels = color_labs)

  # Draw the posterior distribution of the coefficients with ggplot
  ggplot2::ggplot(ps_long, ggplot2::aes(x = value, color = .data[[var]])) +
    ggplot2::geom_density() +
    ggplot2::theme_bw()
}
