library(reticulate)
library(tidyverse)
## ── Attaching packages ───────────────────────────────────────────────────────────────────────────────────────────── tidyverse 1.3.0 ──
## ✔ ggplot2 3.2.1     ✔ purrr   0.3.3
## ✔ tibble  2.1.3     ✔ dplyr   0.8.3
## ✔ tidyr   1.0.0     ✔ stringr 1.4.0
## ✔ readr   1.3.1     ✔ forcats 0.4.0
## ── Conflicts ──────────────────────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
# Bind reticulate to the "idg-dream" conda environment. `required = TRUE`
# (spelled out, since `T` is an ordinary reassignable variable) makes a
# missing environment an error instead of a silent fallback.
use_condaenv("idg-dream", required = TRUE)

# Python Synapse client, a client instance, and its utils module.
synapse <- import("synapseclient")
syn <- synapse$Synapse()
synutils <- synapse$utils
# No credentials are passed here; presumably they come from a cached login or
# local Synapse config -- TODO confirm.
syn$login()

# NOTE(review): this runs Python code fetched from a remote URL at run time;
# consider pinning to a commit hash or vendoring the script. It presumably
# defines the `spearman()` and `rmse()` functions used below -- confirm.
source_python('https://raw.githubusercontent.com/Sage-Bionetworks/IDG-DREAM-Drug-Kinase-Challenge/master/round1b/score/bin/evaluation_metrics_python2.py')

# Reference table (Synapse file syn18421225). `comp` is a per-row key made of
# the compound InChIKey, UniProt id and DiscoveRx gene symbol, "_"-separated.
gold <- read_csv(syn$get("syn18421225")$path) %>%
  mutate(
    comp = paste0(Compound_InchiKeys, "_", UniProt_Id, "_", DiscoveRx_Gene_Symbol)
  )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]` = col_double()
## )
#' Score predictions with `spearman()` after converting both inputs to NumPy.
#'
#' `spearman()` is not defined in this file's R code; presumably it is one of
#' the Python functions loaded via `source_python()` -- confirm.
#'
#' @param gold Reference values; passed through `np_array()`.
#' @param pred Predicted values; passed through `np_array()`.
#' @return The value returned by `spearman()`.
spearman_py <- function(gold, pred) {
  spearman(np_array(gold), np_array(pred))
}

#' Score predictions with `rmse()` after converting both inputs to NumPy.
#'
#' `rmse()` is not defined in this file's R code; presumably it is one of the
#' Python functions loaded via `source_python()` -- confirm.
#'
#' @param gold Reference values; passed through `np_array()`.
#' @param pred Predicted values; passed through `np_array()`.
#' @return The value returned by `rmse()`.
rmse_py <- function(gold, pred) {
  rmse(np_array(gold), np_array(pred))
}

# Default theme for every plot below: black-and-white with size-15 text.
theme_set(
  theme_bw() + theme(text = element_text(size = 15))
)
# One row per prediction file that is downloaded and scored below.
#   id                Synapse id of the prediction file
#   threshold         numeric threshold (shown as "Tanimoto Threshold" in plots)
#   random_iteration  replicate number for the random controls, NA otherwise
#   is_random         condition label used for facets / fill / shape
#   exp_label         x-axis label for the bar plots
#   no_of_removed     count plotted as removed pairs (divided by 950732 below)
ids <- tibble::tribble(
  ~id,           ~threshold, ~random_iteration, ~is_random,          ~exp_label, ~no_of_removed,
  "syn21363501", 0.2,        1,                 "Random Control",    "0.2",      949995,
  "syn21363502", 0.2,        2,                 "Random Control",    "0.2",      949995,
  "syn21363503", 0.2,        NA,                "Dropout Test",      "0.2",      949995,
  "syn21363504", 0.4,        1,                 "Random Control",    "0.4",      797394,
  "syn21363505", 0.4,        2,                 "Random Control",    "0.4",      797394,
  "syn21363506", 0.4,        NA,                "Dropout Test",      "0.4",      797394,
  "syn21363507", 0.6,        1,                 "Random Control",    "0.6",      17538,
  "syn21363508", 0.6,        2,                 "Random Control",    "0.6",      17538,
  "syn21363509", 0.6,        NA,                "Dropout Test",      "0.6",      17538,
  "syn21363510", 0.8,        1,                 "Random Control",    "0.8",      1034,
  "syn21363511", 0.8,        2,                 "Random Control",    "0.8",      1034,
  "syn21363512", 0.8,        NA,                "Dropout Test",      "0.8",      1034,
  "syn21209595", 1,          NA,                "All Training Data", "AD",       0
)

# Download each prediction file listed in `ids` and parse it as CSV. The
# result is a list of tibbles in the same order as `ids$id`.
data <- lapply(ids$id, function(syn_id) {
  read_csv(syn$get(syn_id)$path)
})
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
## Parsed with column specification:
## cols(
##   Compound_SMILES = col_character(),
##   Compound_InchiKeys = col_character(),
##   Compound_Name = col_character(),
##   UniProt_Id = col_character(),
##   Entrez_Gene_Symbol = col_character(),
##   DiscoveRx_Gene_Symbol = col_character(),
##   `pKd_[M]_pred` = col_double()
## )
# Name each prediction table by its Synapse id so bind_rows() can record the
# source file in the `id` column.
names(data) <- ids$id

# Stack all predictions, build the same `comp` key as in `gold`, keep only
# pairs present in `gold`, then attach the per-file metadata from `ids`.
# The join key is spelled out: a natural join would silently widen the key if
# the prediction files ever shared another column name with `ids`.
data_df <- bind_rows(data, .id = "id") %>%
  mutate(
    comp = paste0(Compound_InchiKeys, "_", UniProt_Id, "_", DiscoveRx_Gene_Symbol)
  ) %>%
  filter(comp %in% gold$comp) %>%
  left_join(ids, by = "id")
## Joining, by = "id"
# Attach the reference `pKd_[M]` to every prediction and score each run
# (one run = one row of `ids`) with Spearman correlation and RMSE.
# Join keys are explicit so the join cannot silently change if either table
# gains another shared column.
data_df_summary <- data_df %>%
  left_join(
    gold %>%
      select(Compound_InchiKeys, UniProt_Id, DiscoveRx_Gene_Symbol, `pKd_[M]`),
    by = c("Compound_InchiKeys", "UniProt_Id", "DiscoveRx_Gene_Symbol")
  ) %>%
  group_by(exp_label, threshold, is_random, random_iteration, no_of_removed) %>%
  summarize(
    spearman = spearman_py(`pKd_[M]`, `pKd_[M]_pred`),
    rmse = rmse_py(`pKd_[M]`, `pKd_[M]_pred`)
  ) %>%
  arrange(threshold)
## Joining, by = c("Compound_InchiKeys", "UniProt_Id", "DiscoveRx_Gene_Symbol")
# Per-run RMSE against Spearman correlation; colour encodes the threshold and
# point shape encodes the condition.
# NOTE(review): "0.2" and "0.3" map to the same colour below -- confirm this
# is intended.
ggplot(data_df_summary) +
  geom_point(
    aes(
      x = spearman,
      y = rmse,
      color = as.character(threshold),
      shape = is_random
    )
  ) +
  labs(x = "Spearman Correlation", y = "RMSE") +
  scale_color_manual(
    name = "Tanimoto Threshold",
    values = c(
      "0.1" = "#A80000",
      "0.2" = "#C60968",
      "0.3" = "#C60968",
      "0.4" = "#C61D72",
      "0.5" = "#C64787",
      "0.6" = "#C884A6",
      "0.7" = "#E8CEE5",
      "0.8" = "#CCD5FF",
      "0.9" = "#8EA3FF",
      "1" = "#637FFF",
      "R2" = "#00BFFF"
    )
  ) +
  scale_shape_manual(
    name = "Random Control",
    values = c(
      "Random Control" = 1,
      "Dropout Test" = 19,
      "All Training Data" = 20
    )
  )

# Mean RMSE per threshold label, faceted by condition; bar labels show the
# removed count as a fraction of 950732.
ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random, threshold, no_of_removed) %>%
    summarize(mean_rmse = mean(rmse), sd_rmse = sd(rmse))
) +
  # geom_col() always uses stat = "identity"; passing `stat` only produced an
  # "Ignoring unknown parameters" warning, so it is omitted.
  geom_col(aes(x = exp_label, y = mean_rmse), fill = "#75B3CE") +
  labs(y = "RMSE", x = "Tanimoto Threshold") +
  theme(axis.text.x = element_text(size = 15, angle = 45, hjust = 1, vjust = 1)) +
  facet_grid(cols = vars(is_random), space = "free", scales = "free") +
  geom_text(
    aes(x = exp_label, y = mean_rmse, label = signif(no_of_removed / 950732, 2)),
    position = position_stack(vjust = 0.5), angle = 90, color = "white"
  ) +
  # Error bars are defined by x/ymin/ymax only (a `y` aesthetic is ignored
  # with a warning). Groups with a single run have sd = NA and therefore no
  # bar; na.rm = TRUE drops them without a warning.
  geom_errorbar(
    aes(x = exp_label, ymax = mean_rmse + sd_rmse, ymin = mean_rmse - sd_rmse),
    na.rm = TRUE
  )
## Warning: Ignoring unknown parameters: stat
## Warning: Ignoring unknown aesthetics: y
## Warning: Removed 5 rows containing missing values (geom_errorbar).

# Mean Spearman correlation per threshold label, faceted by condition; bar
# labels show the removed count as a fraction of 950732.
ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random, threshold, no_of_removed) %>%
    summarize(mean_spearman = mean(spearman), sd_spearman = sd(spearman))
) +
  # geom_col() always uses stat = "identity"; passing `stat` only produced an
  # "Ignoring unknown parameters" warning, so it is omitted.
  geom_col(aes(x = exp_label, y = mean_spearman), fill = "#E5A467") +
  labs(y = "Spearman Correlation", x = "Tanimoto Threshold") +
  theme(axis.text.x = element_text(size = 15, angle = 45, hjust = 1, vjust = 1)) +
  facet_grid(cols = vars(is_random), space = "free", scales = "free") +
  geom_text(
    aes(x = exp_label, y = mean_spearman, label = signif(no_of_removed / 950732, 3)),
    position = position_stack(vjust = 0.5), angle = 90, color = "white"
  ) +
  # Error bars are defined by x/ymin/ymax only (a `y` aesthetic is ignored
  # with a warning). Groups with a single run have sd = NA and therefore no
  # bar; na.rm = TRUE drops them without a warning.
  geom_errorbar(
    aes(
      x = exp_label,
      ymax = mean_spearman + sd_spearman,
      ymin = mean_spearman - sd_spearman
    ),
    na.rm = TRUE
  )
## Warning: Ignoring unknown parameters: stat
## Warning: Ignoring unknown aesthetics: y
## Warning: Removed 5 rows containing missing values (geom_errorbar).

# Mean RMSE per threshold label, faceted by condition; bar labels show the
# raw removed count.
ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random, threshold, no_of_removed) %>%
    summarize(mean_rmse = mean(rmse), sd_rmse = sd(rmse)) %>%
    filter(threshold > 0.1)
) +
  # geom_col() always uses stat = "identity"; passing `stat` only produced an
  # "Ignoring unknown parameters" warning, so it is omitted.
  geom_col(aes(x = exp_label, y = mean_rmse), fill = "#75B3CE") +
  labs(y = "RMSE", x = "Tanimoto Threshold") +
  theme(axis.text.x = element_text(size = 15, angle = 45, hjust = 1, vjust = 1)) +
  facet_grid(cols = vars(is_random), space = "free", scales = "free") +
  geom_text(
    aes(x = exp_label, y = mean_rmse, label = no_of_removed),
    position = position_stack(vjust = 0.5), angle = 90, color = "white"
  ) +
  # Error bars are defined by x/ymin/ymax only (a `y` aesthetic is ignored
  # with a warning). Groups with a single run have sd = NA and therefore no
  # bar; na.rm = TRUE drops them without a warning.
  geom_errorbar(
    aes(x = exp_label, ymax = mean_rmse + sd_rmse, ymin = mean_rmse - sd_rmse),
    na.rm = TRUE
  )
## Warning: Ignoring unknown parameters: stat
## Warning: Ignoring unknown aesthetics: y
## Warning: Removed 5 rows containing missing values (geom_errorbar).

# Mean Spearman correlation per threshold label, faceted by condition; bar
# labels show the raw removed count.
ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random, threshold, no_of_removed) %>%
    summarize(mean_spearman = mean(spearman), sd_spearman = sd(spearman)) %>%
    filter(threshold > 0.1)
) +
  # geom_col() always uses stat = "identity"; passing `stat` only produced an
  # "Ignoring unknown parameters" warning, so it is omitted.
  geom_col(aes(x = exp_label, y = mean_spearman), fill = "#E5A467") +
  labs(y = "Spearman Correlation", x = "Tanimoto Threshold") +
  theme(axis.text.x = element_text(size = 15, angle = 45, hjust = 1, vjust = 1)) +
  facet_grid(cols = vars(is_random), space = "free", scales = "free") +
  geom_text(
    aes(x = exp_label, y = mean_spearman, label = no_of_removed),
    position = position_stack(vjust = 0.5), angle = 90, color = "white"
  ) +
  # Error bars are defined by x/ymin/ymax only (a `y` aesthetic is ignored
  # with a warning). Groups with a single run have sd = NA and therefore no
  # bar; na.rm = TRUE drops them without a warning.
  geom_errorbar(
    aes(
      x = exp_label,
      ymax = mean_spearman + sd_spearman,
      ymin = mean_spearman - sd_spearman
    ),
    na.rm = TRUE
  )
## Warning: Ignoring unknown parameters: stat
## Warning: Ignoring unknown aesthetics: y
## Warning: Removed 5 rows containing missing values (geom_errorbar).

# Divisors that place the removed count on each plot's primary y scale; the
# matching sec_axis() transform converts it back for the secondary axis.
spearman_scal_fact <- 950732 * 2
rmse_scal_fact <- 950732

# Mean RMSE bars with the removed count (divided by rmse_scal_fact) overlaid
# as a line and read off the secondary axis.
ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random, threshold, no_of_removed) %>%
    summarize(mean_rmse = mean(rmse), sd_rmse = sd(rmse)) %>%
    filter(threshold > 0.1)
) +
  # geom_col() always uses stat = "identity"; passing `stat` only produced an
  # "Ignoring unknown parameters" warning, so it is omitted.
  geom_col(aes(x = exp_label, y = mean_rmse), fill = "#75B3CE") +
  labs(y = "RMSE", x = "Tanimoto Threshold") +
  theme(axis.text.x = element_text(size = 15, angle = 45, hjust = 1, vjust = 1)) +
  facet_grid(cols = vars(is_random), space = "free", scales = "free") +
  geom_line(aes(x = exp_label, y = no_of_removed / rmse_scal_fact, group = 1)) +
  # Error bars are defined by x/ymin/ymax only (a `y` aesthetic is ignored
  # with a warning). Groups with a single run have sd = NA and therefore no
  # bar; na.rm = TRUE drops them without a warning.
  geom_errorbar(
    aes(x = exp_label, ymax = mean_rmse + sd_rmse, ymin = mean_rmse - sd_rmse),
    na.rm = TRUE
  ) +
  scale_y_continuous(
    sec.axis = sec_axis(~., name = "Proportion of removed pairs")
  )
## Warning: Ignoring unknown parameters: stat
## Warning: Ignoring unknown aesthetics: y
## geom_path: Each group consists of only one observation. Do you need to adjust
## the group aesthetic?
## Warning: Removed 5 rows containing missing values (geom_errorbar).

# Mean Spearman bars with the removed count (divided by spearman_scal_fact)
# overlaid as a line; the secondary axis multiplies by 2 to undo the extra
# factor of 2 in spearman_scal_fact.
ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random, threshold, no_of_removed) %>%
    summarize(mean_spearman = mean(spearman), sd_spearman = sd(spearman)) %>%
    filter(threshold > 0.1)
) +
  # geom_col() always uses stat = "identity"; passing `stat` only produced an
  # "Ignoring unknown parameters" warning, so it is omitted.
  geom_col(aes(x = exp_label, y = mean_spearman), fill = "#E5A467") +
  labs(y = "Spearman Correlation", x = "Tanimoto Threshold") +
  theme(axis.text.x = element_text(size = 15, angle = 45, hjust = 1, vjust = 1)) +
  facet_grid(cols = vars(is_random), space = "free", scales = "free") +
  geom_line(aes(x = exp_label, y = no_of_removed / spearman_scal_fact, group = 1)) +
  # Error bars are defined by x/ymin/ymax only (a `y` aesthetic is ignored
  # with a warning). Groups with a single run have sd = NA and therefore no
  # bar; na.rm = TRUE drops them without a warning.
  geom_errorbar(
    aes(
      x = exp_label,
      ymax = mean_spearman + sd_spearman,
      ymin = mean_spearman - sd_spearman
    ),
    na.rm = TRUE
  ) +
  scale_y_continuous(
    sec.axis = sec_axis(~ . * 2, name = "Proportion of removed pairs")
  )
## Warning: Ignoring unknown parameters: stat

## Warning: Ignoring unknown aesthetics: y
## geom_path: Each group consists of only one observation. Do you need to adjust
## the group aesthetic?
## Warning: Removed 5 rows containing missing values (geom_errorbar).

library(Cairo)

# 950732 is the total number of pairs, used to convert removed counts to a
# ratio. The extra factors (2 and 0.6) squeeze the line onto each plot's
# primary y range and are undone by the matching sec_axis() transform.
spearman_scal_fact <- 950732 * 2
rmse_scal_fact <- 950732 * 0.6

# Mean RMSE per threshold label with conditions dodged side by side, plus the
# proportion of removed pairs as a line/points read off the secondary axis.
ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random, threshold, no_of_removed) %>%
    summarize(mean_rmse = mean(rmse), sd_rmse = sd(rmse)) %>%
    filter(threshold > 0.1)
) +
  geom_bar(
    aes(x = exp_label, y = mean_rmse, fill = is_random),
    stat = "identity", position = position_dodge(preserve = "single")
  ) +
  labs(y = "RMSE", x = "Tanimoto Threshold") +
  theme(axis.text.x = element_text(size = 15, angle = 45, hjust = 1, vjust = 1)) +
  geom_line(aes(x = exp_label, y = no_of_removed / rmse_scal_fact, group = 1)) +
  geom_point(aes(x = exp_label, y = no_of_removed / rmse_scal_fact, group = 1)) +
  # Error bars are defined by x/ymin/ymax only (a `y` aesthetic is ignored
  # with a warning). Groups with a single run have sd = NA and therefore no
  # bar; na.rm = TRUE drops them without a warning.
  geom_errorbar(
    aes(
      x = exp_label,
      ymax = mean_rmse + sd_rmse,
      ymin = mean_rmse - sd_rmse,
      color = is_random
    ),
    position = "dodge", na.rm = TRUE
  ) +
  scale_y_continuous(
    sec.axis = sec_axis(~ . * 0.6, name = "Proportion of removed pairs")
  ) +
  scale_fill_manual(
    name = "Condition",
    values = c(
      "Random Control" = "#BFBFBF",
      "Dropout Test" = "#66666E",
      "All Training Data" = "#75B3CE"
    )
  ) +
  scale_color_manual(
    name = "Condition",
    values = c(
      "Random Control" = "#BFBFBF",
      "Dropout Test" = "#66666E",
      "All Training Data" = "#75B3CE"
    )
  )
## Warning: Ignoring unknown aesthetics: y
## Warning: Removed 5 rows containing missing values (geom_errorbar).

# Write the most recently created plot to a 9 x 4.135 inch PDF via cairo_pdf.
ggsave(
  filename = "figure_3_rmse_r2_dmis_sim.pdf",
  device = cairo_pdf,
  width = 9,
  height = 4.135,
  units = "in"
)
## Warning: Removed 5 rows containing missing values (geom_errorbar).
# Mean Spearman correlation per threshold label with conditions dodged side by
# side, plus the proportion of removed pairs as a line/points. The secondary
# axis multiplies by 2 to undo the extra factor of 2 in spearman_scal_fact.
ggplot(
  data_df_summary %>%
    group_by(exp_label, is_random, threshold, no_of_removed) %>%
    summarize(mean_spearman = mean(spearman), sd_spearman = sd(spearman)) %>%
    filter(threshold > 0.1)
) +
  geom_bar(
    aes(x = exp_label, y = mean_spearman, fill = is_random),
    stat = "identity", position = position_dodge(preserve = "single")
  ) +
  labs(y = "Spearman correlation", x = "Tanimoto Threshold") +
  theme(axis.text.x = element_text(size = 15, angle = 45, hjust = 1, vjust = 1)) +
  geom_line(aes(x = exp_label, y = no_of_removed / spearman_scal_fact, group = 1)) +
  geom_point(aes(x = exp_label, y = no_of_removed / spearman_scal_fact, group = 1)) +
  # Error bars are defined by x/ymin/ymax only (a `y` aesthetic is ignored
  # with a warning). Groups with a single run have sd = NA and therefore no
  # bar; na.rm = TRUE drops them without a warning.
  geom_errorbar(
    aes(
      x = exp_label,
      ymax = mean_spearman + sd_spearman,
      ymin = mean_spearman - sd_spearman,
      color = is_random
    ),
    position = "dodge", na.rm = TRUE
  ) +
  scale_y_continuous(
    sec.axis = sec_axis(~ . * 2, name = "Proportion of removed pairs")
  ) +
  scale_fill_manual(
    name = "Condition",
    values = c(
      "Random Control" = "#BFBFBF",
      "Dropout Test" = "#66666E",
      "All Training Data" = "#E5A467"
    )
  ) +
  scale_color_manual(
    name = "Condition",
    values = c(
      "Random Control" = "#BFBFBF",
      "Dropout Test" = "#66666E",
      "All Training Data" = "#E5A467"
    )
  )
## Warning: Ignoring unknown aesthetics: y

## Warning: Removed 5 rows containing missing values (geom_errorbar).

# Write the most recently created plot to a 9 x 4.135 inch PDF via cairo_pdf.
ggsave(
  filename = "figure_3_spearman_r2_dmis_sim.pdf",
  device = cairo_pdf,
  width = 9,
  height = 4.135,
  units = "in"
)
## Warning: Removed 5 rows containing missing values (geom_errorbar).

# Make Similarity Histogram

# Download the similarity table (Synapse file syn21363592) and read it. Per
# the recorded column spec it has training_data_compound, test_data_compound
# and tanimoto_similarity columns.
sim_matrix_path <- syn$get("syn21363592")$path

sim_tidy <- read_csv(file = sim_matrix_path)
## Parsed with column specification:
## cols(
##   training_data_compound = col_character(),
##   test_data_compound = col_character(),
##   tanimoto_similarity = col_double()
## )
# Keep only similarity rows whose test compound's SMILES occurs in `gold`.
sim_subset <- sim_tidy %>%
  filter(test_data_compound %in% gold$Compound_SMILES)
# Disabled follow-up steps (not executed), kept for reference:
  # group_by(training_data_compound) %>% 
  # top_n(1, sim) %>%
  # sample_n(1) %>%  ##in case there are some with identical similarities to test set. we just need to pick one 
  # ungroup() %>% 
  # rename(cmpd = train_compounds) %>% 
  # inner_join(qed_pkd_tidy) %>% 
  # filter(!is.na(pKd))

# Horizontal range markers for the histogram below: each row is a segment
# from x = `cutoff` to x = `max`, drawn at y = `height` (negative values put
# the segments below the histogram baseline).
bars <- tibble::tribble(
  ~max, ~cutoff, ~height,
  1,    0.15,    -100000,
  1,    0.2,     -150000,
  1,    0.3,     -200000,
  1,    0.4,     -250000,
  1,    0.5,     -300000,
  1,    0.6,     -350000,
  1,    0.7,     -400000,
  1,    0.8,     -450000,
  1,    0.9,     -500000,
)

# Histogram of the tanimoto_similarity values in `sim_subset`, with the range
# markers from `bars` drawn underneath as double-ended flat-capped segments.
ggplot(sim_subset) +
  geom_histogram(
    aes(x = tanimoto_similarity),
    boundary = 1,
    bins = 30
  ) +
  geom_segment(
    data = bars,
    aes(x = max, y = height, xend = cutoff, yend = height),
    arrow = arrow(
      angle = 90,
      length = unit(0.025, "inches"),
      ends = "both",
      type = "open"
    ),
    size = 1
  ) +
  labs(x = "Minimum Tanimoto Similarity to Test Compounds", y = "Count") +
  scale_x_continuous(breaks = seq(0, 1, by = 0.2))

# Write the most recently created plot to a 6 x 4.135 inch PDF via cairo_pdf.
ggsave(
  filename = "figure_3_sim_hist_dmis.pdf",
  device = cairo_pdf,
  width = 6,
  height = 4.135,
  units = "in"
)