# We ran the following docker commands with the listed container to get these results:

##actual dropouts:

 docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="5" -e HIGH="9" docker.synapse.org/syn18519352/qed-paff

 docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="5.5" -e HIGH="8.5" docker.synapse.org/syn18519352/qed-paff

 docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="6" -e HIGH="8" docker.synapse.org/syn18519352/qed-paff

 docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="1" -e HIGH="14" docker.synapse.org/syn18519352/qed-paff

 docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="5" -e HIGH="14" docker.synapse.org/syn18519352/qed-paff

 docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="1" -e HIGH="8" docker.synapse.org/syn18519352/qed-paff

##control runs:

docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="1" -e HIGH="8" -e CONTROL_MODE="1" -e CONTROL_NUM="5" docker.synapse.org/syn18519352/qed-paff:v2

docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="5" -e HIGH="14" -e CONTROL_MODE="1" -e CONTROL_NUM="5" docker.synapse.org/syn18519352/qed-paff:v2

docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="5.5" -e HIGH="8.5" -e CONTROL_MODE="1" -e CONTROL_NUM="5" docker.synapse.org/syn18519352/qed-paff:v2

docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="5" -e HIGH="9" -e CONTROL_MODE="1" -e CONTROL_NUM="5" docker.synapse.org/syn18519352/qed-paff:v2

docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="6" -e HIGH="8" -e CONTROL_MODE="1" -e CONTROL_NUM="5" docker.synapse.org/syn18519352/qed-paff:v2

docker run -it --rm -v ${PWD}/input:/input -v ${PWD}/output:/output -v ${PWD}/SW_based_prediction:/SW_based_prediction  -e LOW="1" -e HIGH="14" -e CONTROL_MODE="1" -e CONTROL_NUM="5" docker.synapse.org/syn18519352/qed-paff:v2
library(reticulate)
library(tidyverse)
library(Cairo)
library(svglite)

# Point reticulate at a specific Python 2 interpreter before any Python module
# is imported.
use_python("/usr/local/bin/python2")
# Import the Synapse Python client and create a client object; `syn` is used
# below for every file download.
synapse <- import("synapseclient")
syn <- synapse$Synapse()
synutils <- synapse$utils
# Log in to Synapse (no credentials are passed here, so they must come from
# the environment/configuration of the machine running the script).
syn$login()
# Source the challenge scoring script so that its Python functions become
# callable from R. NOTE(review): presumably this is what defines the
# `spearman` and `rmse` functions used by the helpers below -- confirm.
source_python('https://raw.githubusercontent.com/Sage-Bionetworks/IDG-DREAM-Drug-Kinase-Challenge/master/round1b/score/bin/evaluation_metrics_python2.py')

# Gold-standard table downloaded from Synapse, with a composite `comp` key
# (compound InChIKey, UniProt id and DiscoveRx gene symbol joined by "_") that
# is used further down to restrict predictions to gold-standard pairs.
gold <- read_csv(syn$get("syn18421225")$path) %>%
  mutate(
    comp = paste(
      Compound_InchiKeys, UniProt_Id, DiscoveRx_Gene_Symbol,
      sep = "_"
    )
  )

# Spearman score via the Python scoring code.
#
# gold: vector of measured values.
# pred: vector of predicted values.
#
# Both vectors are converted with np_array() and handed to `spearman`
# (presumably provided by the Python script sourced above -- confirm);
# whatever that function returns is returned unchanged.
spearman_py <- function(gold, pred) {
  spearman(np_array(gold), np_array(pred))
}

# RMSE score via the Python scoring code.
#
# gold: vector of measured values.
# pred: vector of predicted values.
#
# Both vectors are converted with np_array() and handed to `rmse`
# (presumably provided by the Python script sourced above -- confirm);
# whatever that function returns is returned unchanged.
rmse_py <- function(gold, pred) {
  rmse(np_array(gold), np_array(pred))
}

# Download the three QED training files from Synapse. None of them is read
# with a header row, so columns get the default names X1, X2, ...
#   qed_pkd  - space-delimited value matrix
#   qed_cmpd - tab-delimited file whose first column labels the matrix rows
#   qed_prot - tab-delimited file whose first column labels the matrix columns
qed_pkd <- syn$get("syn20690323")$path %>%
  read_delim(delim = " ", col_names = FALSE)
qed_cmpd <- syn$get("syn20690840")$path %>%
  read_delim(delim = "\t", col_names = FALSE)
qed_prot <- syn$get("syn20690841")$path %>%
  read_delim(delim = "\t", col_names = FALSE)

# Fail early, with a readable message, if the label files do not match the
# matrix dimensions.
stopifnot(
  ncol(qed_pkd) == length(qed_prot$X1),
  nrow(qed_pkd) == length(qed_cmpd$X1)
)

# Reshape the training matrix into long format: one row per compound/protein
# pair with columns `cmpd`, `prot` and `pKd`.
# The compound labels are attached as an ordinary column rather than as row
# names: setting row names on a tibble is deprecated and fails when labels
# repeat.
qed_pkd_tidy <- qed_pkd %>%
  magrittr::set_colnames(qed_prot$X1) %>%
  mutate(cmpd = qed_cmpd$X1) %>%
  gather(key = "prot", value = "pKd", -cmpd) %>%
  # NOTE(review): values >= 100 are dropped; presumably they are placeholders
  # for unmeasured pairs -- confirm against the file format.
  filter(pKd < 100)


# Global plot theme: black-and-white theme with 15 pt text.
theme_set(theme_bw() +
  theme(text = element_text(size = 15)))

## We describe the cutoffs LOW = 1 and HIGH = 14 as "no limit" because there
## were no training data points outside this range.

# One row per prediction file.
#   id                        - Synapse id of the prediction file
#   lower_limit / upper_limit - pKd cutoffs of the run
#   is_random                 - TRUE for the random-control runs, FALSE for the
#                               actual pKd dropout runs
#   exp_label                 - label shown in the figures
#   no_of_removed             - number of removed training pairs
#   random_iteration          - index (0-4) of the control repetition
ids <- tibble::tribble(
  ~id,           ~lower_limit, ~upper_limit, ~is_random, ~exp_label,   ~no_of_removed, ~random_iteration,
  "syn20825281", 1,            14,           FALSE,      "no limit",   0,              0,
  "syn20825274", 1,            8,            FALSE,      "Kd≤8",       6726,           0,
  "syn20825278", 5.5,          8.5,          FALSE,      "5.5≤Kd≤8.5", 16126,          0,
  "syn20825275", 5,            14,           FALSE,      "5≤Kd",       3096,           0,
  "syn20825279", 5,            9,            FALSE,      "5≤Kd≤9",     4774,           0,
  "syn20825280", 6,            8,            FALSE,      "6≤Kd≤8",     29933,          0,
  "syn20834496", 1,            14,           TRUE,       "no limit",   0,              0,
  "syn20834497", 1,            14,           TRUE,       "no limit",   0,              1,
  "syn20834498", 1,            14,           TRUE,       "no limit",   0,              2,
  "syn20834499", 1,            14,           TRUE,       "no limit",   0,              3,
  "syn20834500", 1,            14,           TRUE,       "no limit",   0,              4,
  "syn20834470", 1,            8,            TRUE,       "Kd≤8",       6726,           0,
  "syn20834471", 1,            8,            TRUE,       "Kd≤8",       6726,           1,
  "syn20834472", 1,            8,            TRUE,       "Kd≤8",       6726,           2,
  "syn20834473", 1,            8,            TRUE,       "Kd≤8",       6726,           3,
  "syn20834474", 1,            8,            TRUE,       "Kd≤8",       6726,           4,
  "syn20834480", 5.5,          8.5,          TRUE,       "5.5≤Kd≤8.5", 16126,          0,
  "syn20834481", 5.5,          8.5,          TRUE,       "5.5≤Kd≤8.5", 16126,          1,
  "syn20834482", 5.5,          8.5,          TRUE,       "5.5≤Kd≤8.5", 16126,          2,
  "syn20834483", 5.5,          8.5,          TRUE,       "5.5≤Kd≤8.5", 16126,          3,
  "syn20834484", 5.5,          8.5,          TRUE,       "5.5≤Kd≤8.5", 16126,          4,
  "syn20834475", 5,            14,           TRUE,       "5≤Kd",       3096,           0,
  "syn20834476", 5,            14,           TRUE,       "5≤Kd",       3096,           1,
  "syn20834477", 5,            14,           TRUE,       "5≤Kd",       3096,           2,
  "syn20834478", 5,            14,           TRUE,       "5≤Kd",       3096,           3,
  "syn20834479", 5,            14,           TRUE,       "5≤Kd",       3096,           4,
  "syn20834486", 5,            9,            TRUE,       "5≤Kd≤9",     4774,           0,
  "syn20834487", 5,            9,            TRUE,       "5≤Kd≤9",     4774,           1,
  "syn20834488", 5,            9,            TRUE,       "5≤Kd≤9",     4774,           2,
  "syn20834489", 5,            9,            TRUE,       "5≤Kd≤9",     4774,           3,
  "syn20834490", 5,            9,            TRUE,       "5≤Kd≤9",     4774,           4,
  "syn20834491", 6,            8,            TRUE,       "6≤Kd≤8",     29933,          0,
  "syn20834492", 6,            8,            TRUE,       "6≤Kd≤8",     29933,          1,
  "syn20834493", 6,            8,            TRUE,       "6≤Kd≤8",     29933,          2,
  "syn20834494", 6,            8,            TRUE,       "6≤Kd≤8",     29933,          3,
  "syn20834495", 6,            8,            TRUE,       "6≤Kd≤8",     29933,          4
)


# Download every prediction file listed in `ids` and parse it as CSV.
# Iterating over a vector that is named by itself yields a list whose element
# names are the Synapse ids.
data <- lapply(
  setNames(ids$id, ids$id),
  function(syn_id) read_csv(syn$get(syn_id)$path)
)

# Stack all prediction tables (the list names become the `id` column), keep
# only compound/target pairs that occur in the gold standard, and attach the
# per-run metadata from `ids`.
# The join key is spelled out so that an unexpected shared column name cannot
# silently change how rows are matched.
data_df <- bind_rows(data, .id = "id") %>%
  mutate(comp = paste0(Compound_InchiKeys, "_", UniProt_Id, "_", DiscoveRx_Gene_Symbol)) %>%
  filter(comp %in% gold$comp) %>%
  left_join(ids, by = "id")

# Score every run: attach the measured pKd from the gold standard to each
# prediction, then compute Spearman correlation and RMSE per run (one run =
# one combination of threshold setting, condition and control iteration).
# The join keys are explicit so that only the pair identifiers are matched.
data_df_summary <- data_df %>%
  left_join(
    gold %>%
      select(Compound_InchiKeys, UniProt_Id, DiscoveRx_Gene_Symbol, `pKd_[M]`),
    by = c("Compound_InchiKeys", "UniProt_Id", "DiscoveRx_Gene_Symbol")
  ) %>%
  group_by(
    exp_label, lower_limit, upper_limit,
    is_random, no_of_removed, random_iteration
  ) %>%
  summarize(
    spearman = spearman_py(`pKd_[M]`, `pKd_[M]_pred`),
    rmse = rmse_py(`pKd_[M]`, `pKd_[M]_pred`)
  ) %>%
  arrange(no_of_removed) %>%
  ungroup()

# Divisors applied to `no_of_removed` when the removed-pair counts are drawn
# as a line on top of the RMSE and Spearman bar charts; the Spearman divisor
# is twice the RMSE divisor.
rmse_scal_fact <- 60462
spearman_scal_fact <- 2 * 60462

# levs <- c("no limit", "5≤Kd", "5≤Kd≤9", "5.5≤Kd≤8.5", "6≤Kd≤8",  "Kd≤8")
# data_df_summary <- mutate(data_df_summary, exp_label = factor(exp_label, levels = levs, ordered=T))



# Human-readable condition label used for the fill/colour legends.
data_df_summary <- data_df_summary %>%
  mutate(
    is_random_label = if_else(is_random, "random control", "pKd dropout")
  )
  
# RMSE per pKd threshold setting. Bars show the mean RMSE of each condition
# and error bars show +/- 1 SD across runs (a condition with a single run has
# an undefined SD, so no error bar is drawn for it). The black line and points
# show no_of_removed / rmse_scal_fact, read off the secondary axis.
#
# Changes from the previous version:
#  * the x axis shows pKd threshold settings, not Tanimoto thresholds
#  * every layer shares one x mapping, ordered by number of removed pairs
#  * geom_col() has no `stat` parameter (it was ignored with a warning)
#  * error bars are dodged by the bar width (0.9) so they sit on their bars
p_1 <- ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random_label, no_of_removed) %>%
    summarize(mean_rmse = mean(rmse), sd_rmse = sd(rmse)),
  aes(x = reorder(exp_label, no_of_removed))
) +
  geom_col(
    aes(y = mean_rmse, fill = is_random_label),
    position = position_dodge(preserve = "single")
  ) +
  geom_errorbar(
    aes(
      ymax = mean_rmse + sd_rmse,
      ymin = mean_rmse - sd_rmse,
      group = is_random_label,
      color = is_random_label
    ),
    position = position_dodge(width = 0.9)
  ) +
  geom_line(aes(y = no_of_removed / rmse_scal_fact, group = 1)) +
  geom_point(aes(y = no_of_removed / rmse_scal_fact, group = 1)) +
  labs(y = "RMSE", x = "pKd threshold") +
  theme(axis.text.x = element_text(size = 10)) +
  scale_color_manual(
    name = "Condition",
    values = c("random control" = "#BFBFBF", "pKd dropout" = "#66666E")
  ) +
  scale_fill_manual(
    name = "Condition",
    values = c("random control" = "#BFBFBF", "pKd dropout" = "#66666E")
  ) +
  scale_y_continuous(
    sec.axis = sec_axis(~., name = "Proportion of removed pairs")
  )


# Spearman correlation per pKd threshold setting. Bars show the mean of each
# condition and error bars show +/- 1 SD across runs (a condition with a
# single run has an undefined SD, so no error bar is drawn for it). The black
# line and points show no_of_removed / spearman_scal_fact.
#
# Changes from the previous version:
#  * the x axis shows pKd threshold settings, not Tanimoto thresholds
#  * every layer shares one x mapping, ordered by number of removed pairs
#  * geom_col() has no `stat` parameter (it was ignored with a warning)
#  * error bars are dodged by the bar width (0.9) so they sit on their bars
#  * the secondary axis undoes the extra scaling of the line, so it reads the
#    same quantity as the secondary axis of the RMSE panel
p_2 <- ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random_label, no_of_removed) %>%
    summarize(mean_spearman = mean(spearman), sd_spearman = sd(spearman)),
  aes(x = reorder(exp_label, no_of_removed))
) +
  geom_col(
    aes(y = mean_spearman, fill = is_random_label),
    position = position_dodge(preserve = "single")
  ) +
  geom_errorbar(
    aes(
      ymax = mean_spearman + sd_spearman,
      ymin = mean_spearman - sd_spearman,
      group = is_random_label,
      color = is_random_label
    ),
    position = position_dodge(width = 0.9)
  ) +
  geom_line(aes(y = no_of_removed / spearman_scal_fact, group = 1)) +
  geom_point(aes(y = no_of_removed / spearman_scal_fact, group = 1)) +
  labs(y = "Spearman correlation", x = "pKd threshold") +
  theme(axis.text.x = element_text(size = 10)) +
  scale_fill_manual(
    name = "Condition",
    values = c("random control" = "#BFBFBF", "pKd dropout" = "#66666E")
  ) +
  scale_color_manual(
    name = "Condition",
    values = c("random control" = "#BFBFBF", "pKd dropout" = "#66666E")
  ) +
  scale_y_continuous(
    # The line is drawn at no_of_removed / spearman_scal_fact so that it fits
    # the Spearman range; map the axis back to no_of_removed / rmse_scal_fact.
    # NOTE(review): assumes rmse_scal_fact is the total number of training
    # pairs, as the identity secondary axis of the RMSE panel implies --
    # confirm.
    sec.axis = sec_axis(
      ~ . * spearman_scal_fact / rmse_scal_fact,
      name = "Proportion of removed pairs"
    )
  )


# Display and save the two bar charts. The plot object is passed to ggsave()
# explicitly: its default is the most recently created/displayed plot, which
# is only the intended one when each plot is actually printed right before
# the call (not the case when the script is run via source()).
p_1

ggsave("figure_5_qed_sim_rmse.pdf",
  plot = p_1, device = cairo_pdf,
  width = 9, height = 4.135, units = "in"
)


p_2

ggsave("figure_5_qed_sim_spearman.pdf",
  plot = p_2, device = cairo_pdf,
  width = 9, height = 4.135, units = "in"
)
# Mean prediction per measured pKd value, threshold setting and condition.
# The gold standard is joined on the pair identifiers only, the same way as
# for the per-run scores, instead of on every column the two tables happen to
# share.
all_preds <- data_df %>%
  left_join(
    gold %>%
      select(Compound_InchiKeys, UniProt_Id, DiscoveRx_Gene_Symbol, `pKd_[M]`),
    by = c("Compound_InchiKeys", "UniProt_Id", "DiscoveRx_Gene_Symbol")
  ) %>%
  group_by(`pKd_[M]`, exp_label, is_random, no_of_removed) %>%
  summarize(mean_pred = mean(`pKd_[M]_pred`)) %>%
  ungroup()

# Display order of the threshold settings.
levs <- c("no limit", "5≤Kd", "5≤Kd≤9", "5.5≤Kd≤8.5", "6≤Kd≤8", "Kd≤8")

# The levels are reversed so that the ordered factor runs from "Kd≤8" to
# "no limit"; also add the human-readable condition label used for facetting.
all_preds <- all_preds %>%
  mutate(
    exp_label = factor(exp_label, levels = rev(levs), ordered = TRUE),
    is_random_label = if_else(is_random, "random control", "pKd dropout")
  )


# library(cowplot)

# Smoothed absolute prediction error as a function of the measured pKd, one
# curve per threshold setting, with one facet row per condition.
p1 <- ggplot(
  all_preds,
  aes(x = `pKd_[M]`, y = abs(`pKd_[M]` - mean_pred))
) +
  geom_smooth(aes(color = exp_label, fill = exp_label), alpha = 0.2) +
  facet_wrap(vars(is_random_label), nrow = 2) +
  scale_color_viridis_d(name = "pKd threshold") +
  scale_fill_viridis_d(name = "pKd threshold") +
  labs(x = "measured pKd", y = "absolute error of prediction")

# Line styling for the range-marker segments; these are only referenced by
# the disabled geom_segment() layers below.
sz <- 1
end <- "butt"
join <- "mitre"

# Histogram of the pKd values in the training data. The y axis starts below
# zero, which leaves room under the bars for the disabled range markers.
# NOTE(review): with ylim(), any bin higher than 12500 would be dropped from
# the plot -- confirm that no bin exceeds it.
p2 <- ggplot(qed_pkd_tidy, aes(x = pKd)) +
  geom_histogram() +
  # geom_segment(aes(x = 1, y = -750, xend = 14, yend = -750), 
  #              lineend = end, linejoin = join,
  #              arrow = arrow(angle = 90, length = unit(0.05, "inches"),
  #                            ends = "both", type = "open"), 
  #              # color = "#FDE725FF", 
  #              size = sz) +
  # geom_segment(aes(x = 5, y = -4500, xend = 14, yend = -4500), 
  #              lineend = end, linejoin = join,
  #              arrow = arrow(angle = 90, length = unit(0.05, "inches"),
  #                            ends = "both", type = "open"), 
  #              # color = "#7AD151FF", 
  #              size = sz) +
  # geom_segment(aes(x = 5, y = -3750, xend = 9, yend = -3750), 
  #              lineend = end, linejoin = join, 
  #              arrow = arrow(angle = 90, length = unit(0.05, "inches"),
  #                            ends = "both", type = "open"),
  #              # color = "#22A884FF", 
  #              size = sz,) +
  # geom_segment(aes(x = 5.5, y = -3000, xend = 8.5, yend = -3000), 
  #              lineend = end, linejoin = join,
  #              arrow = arrow(angle = 90, length = unit(0.05, "inches"),
  #                            ends = "both", type = "open"),
  #              # color = "#2A788EFF", 
  #              size = sz) +
  # geom_segment(aes(x = 6, y = -2250, xend = 8, yend = -2250), 
  #              lineend = end, linejoin = join, 
  #              arrow = arrow(angle = 90, length = unit(0.05, "inches"),
  #                            ends = "both", type = "open"),
  #              # color = "#414487FF", 
  #              size = sz) +
  # geom_segment(aes(x = 1, y = -1500, xend = 8, yend = -1500), 
  #              lineend = end, linejoin = join,
  #              arrow = arrow(angle = 90, length = unit(0.05, "inches"),
  #                            ends = "both", type = "open"),
  #                             # color = "#440154FF", 
  #              size = sz) +
  ylim(-5000, 12500) +
  labs(
    x = "measured pKd",
    y = "number of compound-kinase training pairs"
  )


# ggplot(all_preds) + 
#   geom_line(aes(y = abs(`pKd_[M]`-`mean_pred`),x =`pKd_[M]`, color = exp_label)) +
#   facet_wrap(~is_random, nrow = 2) +
#   labs(x = "measured pKd", y = "absolute error of prediction")
# 
# ggplot(all_preds) + 
#   geom_point(aes(y = abs(`pKd_[M]`-`mean_pred`),x =`pKd_[M]`, color = exp_label), stat = "identity") +
#   facet_wrap(~is_random, nrow = 2) +
#   labs(x = "measured pKd", y = "absolute error of prediction")

# Display and save the remaining figures. Every ggsave() call names the plot
# it writes instead of relying on its default (the most recently
# created/displayed plot), so the files are correct however the script is run.
p1

ggsave("figure_5_qed_sim_traces.pdf",
  plot = p1, device = cairo_pdf,
  width = 9, height = 8.5, units = "in"
)


p2

# The SVG export of this figure caused problems when opened in Graphic, so it
# is saved as a PDF instead.
ggsave("figure_5_qed_sim_hist.pdf",
  plot = p2, device = cairo_pdf,
  width = 6, height = 4.135, units = "in"
)


# All four panels stacked in a single column, aligned on their left and right
# axes.
combined_plot <- cowplot::plot_grid(
  p_1, p_2, p1, p2,
  ncol = 1, align = "v", axis = "rl",
  rel_heights = c(0.16, 0.16, 0.25, 0.25)
)
combined_plot

ggsave("pkd_threshold_experiment_qed.pdf",
  plot = combined_plot, device = cairo_pdf,
  width = 10, height = 20, units = "in"
)