#!/usr/bin/env Rscript

# Defining color palettes

# Four-colour palettes used by the plotting helpers.
# Presumably one colour per experimental condition -- confirm against callers.
conditions_palette <- c(
  "#CC0033",
  "#EBAE09",
  "#004080",
  "#95D6F6"
)
# 8-digit hex codes (trailing "FF" = fully opaque alpha); these appear to be
# the ggsci "JCO" colours.
d_palette1 <- c(
  "#0073C2FF",
  "#EFC000FF",
  "#CD534CFF",
  "#868686FF"
)
# Appear to be Okabe-Ito orange / sky blue / bluish green plus a neutral grey.
d_palette2 <- c(
  "#E69F00",
  "#56B4E9",
  "#009E73",
  "#999999"
)

# Functions

format_data <- function(df) {
  # Prepare raw data for nested mixed-effects modelling.
  #
  # Adds per-level indicator ids (circuit within condition within study
  # within lab) and globally unique, dot-joined versions of them
  # ("lab.study.condition.circuit"). It then converts the design columns to
  # factors and drops rows whose `value` is NA.
  #
  # Args:
  #   df: data frame with columns lab, study, condition, rep, day, hour and
  #       value.
  # Returns:
  #   The augmented data frame, restricted to rows with a non-NA `value`.
  #
  # A label not listed in a case_when() below gets an NA id. That NA shows
  # up as an "NA" component in the dot-joined unique ids.
  df %>%
    mutate(
      # Nested ids: each is only unique within its parent level.
      lab_id = case_when(
        lab == "TissUse" ~ 1,
        lab == "AstraZeneca" ~ 2
      ),
      # "Study 3" reuses id 1, presumably because it is the first study
      # within its own lab. TODO confirm against the study design.
      study_id = case_when(
        study == "Study 1" ~ 1,
        study == "Study 2" ~ 2,
        study == "Study 3" ~ 1
      ),
      cond_id = case_when(
        condition == "High HCT-HG" ~ 1,
        condition == "High HCT-NG" ~ 2,
        condition == "Low HCT-HG" ~ 3,
        condition == "Low HCT-NG" ~ 4
      ),
      circuit_id = rep,
      # Unique ids: each nested id prefixed by all of its parents' ids.
      lab_id_uniq = lab_id,
      study_id_uniq = paste(lab_id_uniq, study_id, sep = "."),
      cond_id_uniq = paste(study_id_uniq, cond_id, sep = "."),
      circuit_id_uniq = paste(cond_id_uniq, circuit_id, sep = ".")
    ) %>%
    mutate(
      # Explicit level order for lab and study; default (sorted) levels
      # for everything else.
      lab = factor(lab, levels = c("TissUse", "AstraZeneca")),
      study = factor(study, levels = c("Study 1", "Study 2", "Study 3")),
      across(
        c(
          condition, rep, day, hour,
          lab_id_uniq, study_id_uniq, cond_id_uniq, circuit_id_uniq
        ),
        factor
      )
    ) %>%
    filter(!is.na(value))
}

plot_var_comps <- function(my_mixed_mod, my_palette, my_title = "") {
  # Pie chart of each variance component's share of the total variance.
  #
  # Args:
  #   my_mixed_mod: fitted mixed model accepted by VarCorr(). The
  #     as.data.frame() of its VarCorr() must have `grp` and `vcov` columns
  #     (presumably an lme4 merMod; confirm against callers).
  #   my_palette: fill colours, one per component, in component order.
  #   my_title: plot title.
  # Returns:
  #   A ggplot object (stacked bar in polar coordinates).
  #
  # NOTE(review): shares are vcov / sum(vcov) over *all* rows of VarCorr().
  # Models with random slopes would also add covariance rows and repeated
  # `grp` names, which factor() rejects. Confirm that only intercept models
  # are passed.

  components <- as.data.frame(VarCorr(my_mixed_mod))
  # Shorten grouping names, e.g. "study_id_uniq" -> "study".
  components$grp <- tolower(str_remove(components$grp, "_id_uniq"))
  total_variance <- sum(components$vcov)

  components <- components %>%
    mutate(
      pct = vcov / total_variance,
      labels = scales::percent(pct, accuracy = 0.1),
      # Slice/legend order follows the row order of VarCorr().
      grp = factor(grp, levels = grp),
      # A zero share would vanish from the pie, so give it a sliver.
      pct_to_plot = ifelse(pct == 0, 0.001, pct)
    )

  ggplot(components, aes(x = "", y = pct_to_plot, fill = grp)) +
    geom_bar(width = 1, stat = "identity") +
    geom_text(aes(label = labels),
      position = position_stack(vjust = 0.4),
      size = 4, color = "white", fontface = "bold"
    ) +
    coord_polar(theta = "y", start = 0) +
    scale_fill_manual(values = my_palette) +
    labs(
      fill = "Factor",
      x = "",
      y = "Percent of total variation",
      title = my_title
    ) +
    theme_bw() +
    theme(
      axis.ticks = element_blank(),
      axis.text = element_blank(),
      legend.text = element_text(size = 10),
      axis.title.x = element_text(size = 12)
    )
}

plot_test_stats <- function(contrast_study1 = NULL,
                            contrast_study2 = NULL,
                            contrast_study3 = NULL,
                            my_study_names,
                            my_palette,
                            my_title) {
  # Point-and-error-bar plot of one contrast estimate per study.
  #
  # Each non-NULL contrast is passed through summary(). The result must be
  # coercible to a data frame with `estimate` and `SE` columns (presumably
  # an emmeans contrast; confirm against callers). NULL contrasts, and any
  # study whose estimate is NA, are dropped. Each point is annotated with
  # "estimate ± SE".
  #
  # Args:
  #   contrast_study1, contrast_study2, contrast_study3: contrast objects or
  #     NULL, in the same order as `my_study_names`.
  #   my_study_names: x-axis labels, one per contrast row (including NULL
  #     contrasts, which contribute one NA row each).
  #   my_palette: colours passed to scale_color_manual().
  #   my_title: plot title.
  # Returns:
  #   A ggplot object.

  # Turn a contrast into a data frame. A missing contrast becomes a single
  # all-NA row so that the NA filter below removes its study.
  contrast_to_df <- function(contrast) {
    if (is.null(contrast)) {
      return(data.frame(estimate = NA, SE = NA))
    }
    as.data.frame(summary(contrast))
  }

  per_study <- lapply(
    list(contrast_study1, contrast_study2, contrast_study3),
    contrast_to_df
  )

  test_stat_df <- data.frame(study = my_study_names) %>%
    mutate(
      est = unlist(lapply(per_study, function(s) s$estimate)),
      se = unlist(lapply(per_study, function(s) s$SE))
    ) %>%
    filter(!is.na(est))

  # Horizontal offset of the "estimate ± SE" label: tuned as 0.31 for three
  # studies and 0.24 for two. Fewer than two studies reuses the two-study
  # offset, so the value is always defined.
  nudge_x_val <- if (nrow(test_stat_df) >= 3) 0.31 else 0.24

  ggplot(test_stat_df, aes(x = study, color = study)) +
    geom_point(aes(y = est), size = 3) +
    geom_hline(
      yintercept = 0, linewidth = 0.5,
      linetype = "dashed", color = "#999999"
    ) +
    geom_errorbar(aes(ymin = est - se, ymax = est + se),
      width = 0.08, linewidth = 0.75
    ) +
    geom_text(
      aes(
        y = est,
        label = paste(round(est, 3), "±", round(se, 3))
      ),
      nudge_x = nudge_x_val, nudge_y = -0.02, size = 4,
      color = "black", fontface = "bold"
    ) +
    scale_color_manual(values = my_palette) +
    theme_bw() +
    theme(
      legend.position = "none",
      axis.text.x = element_text(size = 12),
      axis.text.y = element_text(size = 10),
      axis.title.y = element_text(size = 12)
    ) +
    labs(
      x = "",
      color = "",
      y = "Estimated effect",
      title = my_title
    )
}

plot_percent_cvs <- function(contrast_study1 = NULL,
                             contrast_study2 = NULL,
                             contrast_study3 = NULL,
                             inter_study = FALSE,
                             my_study_names,
                             my_palette,
                             my_title) {
  # Bar chart of the %CV of each study's contrast estimate.
  #
  # The value plotted is the CV of the *estimate*, |SE / estimate| * 100
  # (SE, not SD), rounded to 3 decimals. Studies whose contrast is NULL or
  # whose estimate is NA are dropped. With inter_study = TRUE an extra
  # "between studies" bar is appended. It uses the mean of the per-study
  # estimates and std.error() of those estimates.
  #
  # Args:
  #   contrast_study1, contrast_study2, contrast_study3: contrast objects or
  #     NULL. summary() of each must give `estimate` and `SE` columns.
  #   inter_study: append the between-study bar?
  #   my_study_names: bar labels, one per contrast row.
  #   my_palette: colours passed to scale_fill_manual().
  #   my_title: plot title.
  # Returns:
  #   A ggplot object.

  summarise_contrast <- function(contrast) {
    if (is.null(contrast)) {
      data.frame(estimate = NA, SE = NA)
    } else {
      as.data.frame(summary(contrast))
    }
  }
  percent_cv <- function(se, est) abs(round(se / est * 100, 3))

  per_study <- lapply(
    list(contrast_study1, contrast_study2, contrast_study3),
    summarise_contrast
  )

  cv_df <- data.frame(study = my_study_names) %>%
    mutate(
      est = unlist(lapply(per_study, function(s) s$estimate)),
      se = unlist(lapply(per_study, function(s) s$SE))
    ) %>%
    filter(!is.na(est)) %>%
    mutate(pct_cv_est = percent_cv(se, est))

  if (inter_study) {
    between <- data.frame(
      study = "between studies",
      est = mean(cv_df$est),
      se = std.error(cv_df$est)
    )
    between$pct_cv_est <- percent_cv(between$se, between$est)
    cv_df <- rbind(cv_df, between)
  }

  # Fix the bar order to the row order built above.
  cv_df$study <- factor(cv_df$study, levels = cv_df$study)

  ggplot(cv_df, aes(x = study, y = pct_cv_est, fill = study)) +
    geom_bar(width = 0.8, stat = "identity") +
    scale_fill_manual(values = my_palette) +
    theme_bw() +
    theme(
      legend.position = "none",
      axis.text.x = element_text(size = 12),
      axis.text.y = element_text(size = 10),
      axis.title.y = element_text(size = 12)
    ) +
    labs(
      x = "",
      color = "",
      y = "%CV of estimate",
      title = my_title
    ) +
    geom_text(aes(label = pct_cv_est), vjust = -0.2)
}


get_inter_study_statistics <- function(contrast_study1 = NULL,
                                       contrast_study2 = NULL,
                                       contrast_study3 = NULL,
                                       my_study_names) {
  # Summarise how much the contrast estimate varies across studies.
  #
  # NULL contrasts, and studies whose estimate is NA, are ignored. summary()
  # of each non-NULL contrast must give `estimate` and `SE` columns.
  #
  # Returns a one-row data frame with columns:
  #   est      mean of the per-study estimates
  #   sd       their standard deviation
  #   cv       sd / est (signed ratio, not a percentage)
  #   se       std.error() of the per-study estimates
  #   cv_mean  se / est (signed ratio, not a percentage)

  summarise_contrast <- function(contrast) {
    if (is.null(contrast)) {
      data.frame(estimate = NA, SE = NA)
    } else {
      as.data.frame(summary(contrast))
    }
  }

  per_study <- lapply(
    list(contrast_study1, contrast_study2, contrast_study3),
    summarise_contrast
  )

  study_df <- data.frame(study = my_study_names) %>%
    mutate(
      est = unlist(lapply(per_study, function(s) s$estimate)),
      se = unlist(lapply(per_study, function(s) s$SE))
    ) %>%
    filter(!is.na(est))
  estimates <- study_df$est

  mean_est <- mean(estimates)
  sd_est <- sd(estimates)
  se_est <- std.error(estimates)

  data.frame(
    est = mean_est,
    sd = sd_est,
    cv = sd_est / mean_est,
    se = se_est,
    cv_mean = se_est / mean_est
  )
}

sig_stars <- function(p) {
  # Map p-values to conventional significance codes (vectorised over p):
  #   p <= 0.001 -> "***", <= 0.01 -> "**", <= 0.05 -> "*", <= 0.1 -> ".",
  #   otherwise "". NA p-values give NA.
  cutoffs <- c(0.1, 0.05, 0.01, 0.001)
  codes <- c(".", "*", "**", "***")

  # Start from the loosest threshold; each stricter one overrides it.
  stars <- ifelse(p <= cutoffs[1], codes[1], "")
  for (k in seq_along(cutoffs)[-1]) {
    stars <- ifelse(p <= cutoffs[k], codes[k], stars)
  }
  stars
}

output_tidy_mod_to_file <- function(my_mixed_mod, my_endpoint) {
  # Tidy a fitted mixed model, annotate it, and save it as a CSV.
  #
  # The file is written to output/models/tidy_<name>.csv under the project
  # root returned by here::here(). <name> is the expression passed as
  # `my_mixed_mod`, usually the model's variable name.
  #
  # Args:
  #   my_mixed_mod: fitted model with a `call` slot whose `formula` element
  #     holds the model formula (presumably an lme4/lmerTest fit; confirm).
  #     broom.mixed::tidy() of it must return `effect`, `estimate` and
  #     `p.value` columns.
  #   my_endpoint: label stored in the `endpoint` column.
  # Returns:
  #   The annotated tidy data frame (also written to disk).

  # Collapse in case deparse() splits a long expression over several strings.
  my_mod_name_txt <- paste(deparse(substitute(my_mixed_mod)), collapse = "")
  my_file_name <- paste0("tidy_", my_mod_name_txt, ".csv")

  # Build a readable "lhs ~ rhs". Every "_id_uniq" suffix is stripped, so
  # all grouping factors are shortened, not only the first one.
  my_call_parts <- as.character(my_mixed_mod@call$formula)
  my_call_txt <- paste(
    my_call_parts[2],
    str_remove_all(my_call_parts[3], "_id_uniq"),
    sep = " ~ "
  )

  my_output <- as.data.frame(broom.mixed::tidy(my_mixed_mod)) %>%
    mutate(
      endpoint = my_endpoint,
      call = my_call_txt,
      # Back-transform fixed effects with exp(). This assumes a log-scale
      # response/link; confirm for each model.
      estim_orig_scale = ifelse(effect == "fixed", exp(estimate), NA),
      stars = sig_stars(p.value)
    )

  # recursive = TRUE also creates "output/" when it is missing. Without it,
  # dir.create() fails (with its warning suppressed) and write.csv() errors.
  model_dir <- here::here("output", "models")
  dir.create(model_dir, recursive = TRUE, showWarnings = FALSE)
  write.csv(as.data.frame(my_output), file = file.path(model_dir, my_file_name))

  my_output
}


get_max_CV <- function(x) {
  # Largest per-row coefficient of variation (sd / mean) of a matrix.
  #
  # x: matrix; rows are timepoints, columns are experimental units (circuits).
  # NAs are dropped inside each row's sd and mean. The final max() does not
  # drop NA, so any row whose CV is NA (e.g. fewer than two non-missing
  # values) makes the result NA.
  row_cv <- apply(x, 1, function(timepoint) {
    sd(timepoint, na.rm = TRUE) / mean(timepoint, na.rm = TRUE)
  })
  max(row_cv)
}
