# Predictions run by DMIS-DK
library(reticulate)
library(tidyverse)
library(Cairo)

# Point reticulate at a specific Python 2 interpreter before the first import.
use_python("/usr/local/bin/python2")
# Import the Synapse Python client, create a client object, and keep a handle
# on its `utils` submodule.
synapse <- import("synapseclient")
syn <- synapse$Synapse()
synutils <- synapse$utils
# Log in with no explicit arguments (credentials presumably come from a cached
# session or local config -- TODO confirm).
syn$login()
# Source the challenge's Python scoring script straight from GitHub. The
# `spearman` and `rmse` functions called by the helpers below are presumably
# defined there -- verify against the script.
source_python('https://raw.githubusercontent.com/Sage-Bionetworks/IDG-DREAM-Drug-Kinase-Challenge/master/round1b/score/bin/evaluation_metrics_python2.py')

# Gold-standard table downloaded from Synapse. `comp` is a composite row key
# built from the compound InChIKey, UniProt id and DiscoveRx gene symbol.
gold <- read_csv(syn$get("syn18421225")$path) %>%
  mutate(
    comp = paste(Compound_InchiKeys, UniProt_Id, DiscoveRx_Gene_Symbol, sep = "_")
  )

# Convert both vectors to NumPy arrays and pass them to the Python `spearman`
# function (gold values first, predictions second).
spearman_py <- function(gold, pred) {
  spearman(np_array(gold), np_array(pred))
}

# Convert both vectors to NumPy arrays and pass them to the Python `rmse`
# function (gold values first, predictions second).
rmse_py <- function(gold, pred) {
  rmse(np_array(gold), np_array(pred))
}


# pKd table downloaded from Synapse, reduced to three columns and renamed to
# short names: prot (target id), cmpd (InChIKey) and pKd.
dmis_pkd_tidy <- read_csv(syn$get("syn21445732")$path) %>%
  select(
    prot = target_id,
    cmpd = standard_inchi_key,
    pKd = standard_value_pkd
  )

# Default theme for every plot below: black-and-white with 15 pt text.
theme_set(theme_bw() + theme(text = element_text(size = 15)))

# Pairs removed per pKd window: the total (950732) minus the number of pairs
# used in that experiment. Only the label and the removed count are kept.
math <- tibble::tribble(
  ~exp_label,   ~no_of_used,
  "1≤Kd≤14",    950698,
  "5≤Kd≤14",    950698,
  "1≤Kd≤8",     746609,
  "5.5≤Kd≤8.5", 608206,
  "5≤Kd≤9",     890134,
  "6≤Kd≤8",     430972
) %>%
  transmute(exp_label, no_of_removed = 950732 - no_of_used)


# One row per prediction file on Synapse.
#   id               Synapse id of the prediction file
#   lower_limit      lower bound of the pKd window for the experiment
#   upper_limit      upper bound of the pKd window for the experiment
#   is_random        TRUE for random-control runs, FALSE otherwise
#   exp_label        display label of the pKd window (join key into `math`)
#   random_iteration 0 for the non-random run, 1-3 for the random runs
# `no_of_removed` is attached from `math` by `exp_label`, so every run of the
# same window gets the same removed-pair count.
ids <- tibble::tribble(
  ~id, ~lower_limit, ~upper_limit, ~is_random, ~exp_label, ~random_iteration,
  # IDs left commented out in the original script:
  # syn21217148
  # syn21217149
  # syn21217150
  # syn21217151
  # syn21217152
  # syn21217153
  "syn21209595", 1,   14,  FALSE, "1≤Kd≤14",    0,
  "syn21209598", 1,   14,  TRUE,  "1≤Kd≤14",    1,
  "syn21209599", 1,   14,  TRUE,  "1≤Kd≤14",    2,
  "syn21209600", 1,   14,  TRUE,  "1≤Kd≤14",    3,
  "syn21209585", 1,   8,   FALSE, "1≤Kd≤8",     0,
  "syn21209590", 1,   8,   TRUE,  "1≤Kd≤8",     1,
  "syn21209591", 1,   8,   TRUE,  "1≤Kd≤8",     2,
  "syn21209594", 1,   8,   TRUE,  "1≤Kd≤8",     3,
  "syn21209601", 5.5, 8.5, FALSE, "5.5≤Kd≤8.5", 0,
  "syn21209602", 5.5, 8.5, TRUE,  "5.5≤Kd≤8.5", 1,
  "syn21209603", 5.5, 8.5, TRUE,  "5.5≤Kd≤8.5", 2,
  "syn21209605", 5.5, 8.5, TRUE,  "5.5≤Kd≤8.5", 3,
  "syn21209612", 5,   14,  FALSE, "5≤Kd≤14",    0,
  "syn21209613", 5,   14,  TRUE,  "5≤Kd≤14",    1,
  "syn21209614", 5,   14,  TRUE,  "5≤Kd≤14",    2,
  "syn21209615", 5,   14,  TRUE,  "5≤Kd≤14",    3,
  "syn21209606", 5,   9,   FALSE, "5≤Kd≤9",     0,
  "syn21209607", 5,   9,   TRUE,  "5≤Kd≤9",     1,
  "syn21209610", 5,   9,   TRUE,  "5≤Kd≤9",     2,
  "syn21209611", 5,   9,   TRUE,  "5≤Kd≤9",     3,
  "syn21209616", 6,   8,   FALSE, "6≤Kd≤8",     0,
  "syn21209617", 6,   8,   TRUE,  "6≤Kd≤8",     1,
  "syn21209619", 6,   8,   TRUE,  "6≤Kd≤8",     2,
  "syn21209621", 6,   8,   TRUE,  "6≤Kd≤8",     3
) %>%
  # Explicit key: the join must never silently pick up extra shared columns.
  left_join(math, by = "exp_label")


# Download and read every prediction file; the list is named by Synapse id so
# that bind_rows() can record the source file in an `id` column.
data <- ids$id %>%
  set_names() %>%
  map(function(syn_id) read_csv(syn$get(syn_id)$path))

# Stack all predictions, build the same composite `comp` key used in `gold`,
# keep only pairs present in the gold standard, then attach the per-file
# experiment metadata from `ids`.
data_df <- data %>%
  bind_rows(.id = "id") %>%
  mutate(
    comp = paste(Compound_InchiKeys, UniProt_Id, DiscoveRx_Gene_Symbol, sep = "_")
  ) %>%
  filter(comp %in% gold$comp) %>%
  left_join(ids)

# Score each prediction file against the gold standard: attach the measured
# `pKd_[M]` to every predicted pair, then compute Spearman and RMSE once per
# experiment / random iteration via the Python scoring helpers.
data_df_summary <- data_df %>%
  left_join(
    select(gold, Compound_InchiKeys, UniProt_Id, DiscoveRx_Gene_Symbol, `pKd_[M]`)
  ) %>%
  group_by(
    exp_label, lower_limit, upper_limit,
    is_random, no_of_removed, random_iteration
  ) %>%
  summarise(
    spearman = spearman_py(`pKd_[M]`, `pKd_[M]_pred`),
    rmse = rmse_py(`pKd_[M]`, `pKd_[M]_pred`)
  ) %>%
  ungroup() %>%
  arrange(no_of_removed)

# Total number of pairs; the same constant that `no_of_removed` is derived
# from. Named once here instead of being repeated as a magic number.
n_total_pairs <- 950732

# Divisors that map `no_of_removed` onto the primary y axis of each plot:
# RMSE plot -> removed / total; Spearman plot -> removed / (2 * total).
spearman_scal_fact <- n_total_pairs * 2
rmse_scal_fact <- n_total_pairs

# Human-readable condition label used for fill/colour in the bar plots.
# TRUE/FALSE are spelled out because T and F are ordinary, reassignable
# variables.
data_df_summary <- data_df_summary %>%
  mutate(
    is_random_label = case_when(
      is_random == TRUE ~ "random control",
      is_random == FALSE ~ "pKd dropout"
    )
  )
  
# RMSE per pKd window (bars: mean over runs, error bars: +/- 1 SD), with the
# proportion of removed pairs overlaid as a line on the secondary axis.
# Windows are ordered along x by the number of removed pairs.
p_1 <- ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random_label, no_of_removed) %>%
    summarize(mean_rmse = mean(rmse), sd_rmse = sd(rmse)),
  # One shared x mapping so every layer uses the same ordering.
  aes(x = reorder(exp_label, no_of_removed))
) +
  # geom_col() always uses the identity stat, so no `stat` argument is needed.
  geom_col(
    aes(y = mean_rmse, fill = is_random_label),
    position = position_dodge(preserve = "single")
  ) +
  # The x axis shows pKd windows (exp_label), not Tanimoto thresholds.
  labs(y = "RMSE", x = "pKd threshold") +
  theme(axis.text.x = element_text(size = 10)) +
  # Dodge by the bar width (0.9) so error bars sit on the centre of their bar
  # regardless of the error bar's own default width.
  geom_errorbar(
    aes(
      ymin = mean_rmse - sd_rmse,
      ymax = mean_rmse + sd_rmse,
      group = is_random_label,
      color = is_random_label
    ),
    position = position_dodge(width = 0.9)
  ) +
  # rmse_scal_fact is the total pair count, so this is already a proportion
  # and the identity secondary axis below reads correctly.
  geom_line(aes(y = no_of_removed / rmse_scal_fact, group = 1)) +
  geom_point(aes(y = no_of_removed / rmse_scal_fact, group = 1)) +
  scale_color_manual(
    name = "Condition",
    values = c("random control" = "#BFBFBF", "pKd dropout" = "#66666E")
  ) +
  scale_fill_manual(
    name = "Condition",
    values = c("random control" = "#BFBFBF", "pKd dropout" = "#66666E")
  ) +
  scale_y_continuous(
    sec.axis = sec_axis(~., name = "Proportion of removed pairs")
  )


# Spearman correlation per pKd window (bars: mean over runs, error bars:
# +/- 1 SD), with the proportion of removed pairs overlaid as a line on the
# secondary axis. Windows are ordered along x by the number of removed pairs.
p_2 <- ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random_label, no_of_removed) %>%
    summarize(mean_spearman = mean(spearman), sd_spearman = sd(spearman)),
  # One shared x mapping so every layer uses the same ordering.
  aes(x = reorder(exp_label, no_of_removed))
) +
  # geom_col() always uses the identity stat, so no `stat` argument is needed.
  geom_col(
    aes(y = mean_spearman, fill = is_random_label),
    position = position_dodge(preserve = "single")
  ) +
  # The x axis shows pKd windows (exp_label), not Tanimoto thresholds.
  labs(y = "Spearman correlation", x = "pKd threshold") +
  theme(axis.text.x = element_text(size = 10)) +
  # Dodge by the bar width (0.9) so error bars sit on the centre of their bar
  # regardless of the error bar's own default width.
  geom_errorbar(
    aes(
      ymin = mean_spearman - sd_spearman,
      ymax = mean_spearman + sd_spearman,
      group = is_random_label,
      color = is_random_label
    ),
    position = position_dodge(width = 0.9)
  ) +
  # The overlay is drawn at removed / spearman_scal_fact, i.e. a compressed
  # version of the true proportion, to fit the primary axis range.
  geom_line(aes(y = no_of_removed / spearman_scal_fact, group = 1)) +
  geom_point(aes(y = no_of_removed / spearman_scal_fact, group = 1)) +
  scale_fill_manual(
    name = "Condition",
    values = c("random control" = "#BFBFBF", "pKd dropout" = "#66666E")
  ) +
  scale_color_manual(
    name = "Condition",
    values = c("random control" = "#BFBFBF", "pKd dropout" = "#66666E")
  ) +
  # Undo the compression on the secondary axis so it shows the real
  # proportion: rmse_scal_fact is the total pair count, so
  # spearman_scal_fact / rmse_scal_fact is the factor the overlay was
  # shrunk by.
  scale_y_continuous(
    sec.axis = sec_axis(
      ~ . * spearman_scal_fact / rmse_scal_fact,
      name = "Proportion of removed pairs"
    )
  )


# Display and save the two bar plots. `plot` is passed explicitly: ggsave()
# otherwise writes last_plot(), which is only the intended figure when the
# bare `p_1` / `p_2` lines auto-print (not the case under source()).
p_1

ggsave("figure_5_dmis_pkd_rmse.pdf", plot = p_1, device = cairo_pdf,
       width = 9, height = 4.135, units = "in")


p_2

ggsave("figure_5_dmis_pkd_spearman.pdf", plot = p_2, device = cairo_pdf,
       width = 9, height = 4.135, units = "in")
# Mean prediction per measured pKd value within each experiment and condition
# (averaging over pairs and random iterations that share the same measured
# value).
all_preds <- data_df %>%
  left_join(gold) %>%
  group_by(`pKd_[M]`, exp_label, is_random, no_of_removed) %>%
  summarize(mean_pred = mean(`pKd_[M]_pred`)) %>%
  ungroup()

# Display order of the pKd windows. These must match the `exp_label` values
# in `ids` exactly; any label missing here becomes NA in the factor below.
# Same ordering as before, with the old short labels replaced by the current
# ones ("no limit" -> "1≤Kd≤14", "5≤Kd" -> "5≤Kd≤14", "Kd≤8" -> "1≤Kd≤8").
levs <- c("1≤Kd≤14", "5≤Kd≤14", "5≤Kd≤9", "5.5≤Kd≤8.5", "6≤Kd≤8", "1≤Kd≤8")

all_preds <- all_preds %>%
  mutate(
    exp_label = factor(exp_label, levels = rev(levs), ordered = TRUE),
    is_random_label = case_when(
      is_random == TRUE ~ "random control",
      is_random == FALSE ~ "pKd dropout"
    )
  )


# library(cowplot)

# Smoothed absolute prediction error against measured pKd, one curve per pKd
# window, with one facet row per condition.
p1 <- ggplot(
  all_preds,
  aes(
    x = `pKd_[M]`,
    y = abs(`pKd_[M]` - `mean_pred`),
    color = exp_label,
    fill = exp_label
  )
) +
  geom_smooth(alpha = 0.2) +
  facet_wrap(vars(is_random_label), nrow = 2) +
  scale_color_viridis_d(name = "pKd threshold") +
  scale_fill_viridis_d(name = "pKd threshold") +
  labs(x = "measured pKd", y = "absolute error of prediction")

# Styling values for the range-bracket annotations described below (line
# size, line end and line join). The annotations are currently disabled, so
# nothing in this block reads them.
sz <- 1
end <- "butt"
join <- "mitre"

# Histogram of pKd values in the DMIS table.
# A set of geom_segment() brackets, one per pKd window (1-14, 5-14, 5-9,
# 5.5-8.5, 6-8, 1-8), used to be drawn at negative y values under the bars
# and is disabled; the y range still extends below zero.
p2 <- ggplot(dmis_pkd_tidy, aes(x = pKd)) +
  geom_histogram() +
  scale_y_continuous(limits = c(-100000, 200000)) +
  xlab("measured pKd") +
  ylab("number of compound-kinase training pairs")
# ggplot(all_preds) + 
#   geom_line(aes(y = abs(`pKd_[M]`-`mean_pred`),x =`pKd_[M]`, color = exp_label)) +
#   facet_wrap(~is_random, nrow = 2) +
#   labs(x = "measured pKd", y = "absolute error of prediction")
# 
# ggplot(all_preds) + 
#   geom_point(aes(y = abs(`pKd_[M]`-`mean_pred`),x =`pKd_[M]`, color = exp_label), stat = "identity") +
#   facet_wrap(~is_random, nrow = 2) +
#   labs(x = "measured pKd", y = "absolute error of prediction")

# Display and save the remaining figures. `plot` is passed explicitly:
# ggsave() otherwise writes last_plot(), which is only the intended figure
# when the bare plot line above it auto-prints (not the case under source()).
p1

ggsave("figure_5_dmis_sim_traces.pdf", plot = p1, device = cairo_pdf,
       width = 9, height = 8.5, units = "in")


p2

# This figure caused problems when opened as an SVG in Graphic, so it is
# saved as a PDF instead.
ggsave("figure_5_dmis_sim_hist.pdf", plot = p2, device = cairo_pdf,
       width = 6, height = 4.135, units = "in")


# All four panels stacked in one column, aligned on their left/right axes.
combined_fig <- cowplot::plot_grid(p_1, p_2, p1, p2, ncol = 1, align = 'v', axis = 'rl', rel_heights = c(0.16,0.16,0.25,0.25))
combined_fig

ggsave("pkd_threshold_experiment_dmis.pdf", plot = combined_fig, device = cairo_pdf,
       width = 10, height = 20, units = "in")