library(tidyverse)
library(hues)
library(ggforce)
library(patchwork)
library(ggh4x)
library(ggsankey)
library(ggstance)
# Record the session info (R version, platform, attached packages) for
# reproducibility. The explicit print() makes sure it is emitted when the
# script is run through source() as well as through Rscript.
print(sessionInfo())

# Comparison tables, one CSV per data source, keyed by the short source name
# that is later used in the figure file names.
comparison_files <- c(
  bacdive = "benchmarks/playground_output/bacdive_lpsn_comparison.csv",
  ncbi = "benchmarks/playground_output/ncbi_lpsn_comparison.csv",
  lpsn = "benchmarks/playground_output/lpsn_lpsn_comparison.csv",
  ratatoskr = "benchmarks/playground_output/ratatoskr_lpsn_comparison.csv"
)

# Read every table once; lapply() keeps the names of `comparison_files`.
datasets <- lapply(comparison_files, read_csv)

# Individual handles used by the sequence-availability analysis further down.
bacdive_results <- datasets$bacdive
ncbi_results <- datasets$ncbi
lpsn_results <- datasets$lpsn
ratatoskr_benchmarks <- datasets$ratatoskr

# Filled by the Sankey loop below: one ggplot object per data source.
graphs <- list()
# Build, save and keep one Sankey (alluvial) plot per data source.
for (name in names(datasets)) {
  # Reshape the per-rank columns into the long x / node / next_x / next_node
  # layout that the ggsankey geoms consume.
  flows <- make_long(
    datasets[[name]],
    domain, kingdom, phylum, class, order, family,
    genus, species, subspecies, strain
  )

  # For every rank, compute the share of TRUE and of FALSE nodes among the
  # non-missing nodes. Missing nodes get weight 0, so they contribute to
  # neither the numerator nor the denominator, and they are dropped once the
  # shares have been computed.
  flows <- flows %>%
    mutate(node_label = case_when(
      is.na(node) ~ "N/A",
      node == TRUE ~ "TRUE",
      node == FALSE ~ "FALSE"
    )) %>%
    group_by(x, next_x, node_label) %>%
    mutate(count_label = sum(as.integer(node_label %in% c("TRUE", "FALSE")))) %>%
    group_by(x, next_x) %>%
    mutate(count_group = sum(as.integer(node_label %in% c("TRUE", "FALSE")))) %>%
    ungroup() %>%
    mutate(percentage = paste0(round(100 * count_label / count_group, 1), "%")) %>%
    filter(!is.na(node))

  sankey <- ggplot(
    flows,
    aes(
      x = x, next_x = next_x,
      node = node, next_node = next_node,
      fill = factor(node), label = percentage
    )
  ) +
    geom_alluvial(flow.alpha = 0.6) +
    geom_alluvial_text(size = 2) +
    scale_fill_manual(
      values = c("TRUE" = "#014701", "FALSE" = "#990101"),
      name = "Correct retrieval"
    ) +
    theme_sankey(base_size = 8) +
    # Stagger the rank labels over three rows so they do not overlap.
    scale_x_discrete(guide = guide_axis(n.dodge = 3)) +
    labs(x = NULL)

  ggsave(
    paste0("benchmarks/figures/sankey_", name, ".png"),
    sankey,
    width = 7, height = 4, dpi = 300
  )
  graphs[[name]] <- sankey
}

# Arrange the four per-database Sankey plots in a 2 x 2 grid, tag each panel
# with its database name, and share a single collected legend. `&` applies the
# theme tweak to every panel of the patchwork.
combined_plot <- (graphs$lpsn + graphs$bacdive) / (graphs$ncbi + graphs$ratatoskr) +
  plot_annotation(tag_levels = list(c("LPSN", "BacDive", "NCBI", "Ratatoskr"))) +
  plot_layout(guides = "collect") &
  theme(plot.title = element_text(size = 16, face = "bold"))

# Pass the plot explicitly: without it ggsave() falls back to last_plot(),
# which silently depends on whatever plot happened to be created last.
ggsave(
  "benchmarks/figures/sankey_all_databases.png",
  combined_plot,
  width = 9, height = 4, dpi = 300
)

# Runtime benchmark: time in seconds against the number of members, coloured
# by taxonomic level. `level` is turned into a factor with an explicit order
# so that the legend follows the taxonomic hierarchy.
benchmark_times <- read_csv("benchmarks/playground_output/benchmark_times.csv") %>%
  mutate(
    level = factor(
      level,
      levels = c("domain", "kingdom", "phylum", "class", "order", "family", "genus")
    )
  )

b1 <- ggplot(benchmark_times, aes(x = members, y = time_seconds)) +
  geom_smooth(method = "lm", colour = "gray40") +
  geom_point(aes(color = level)) +
  labs(x = "Number of members", y = "Time (seconds)", title = "Runtimes") +
  # drop = FALSE keeps every factor level in the scale, matching the scale of
  # the "including downloads" plot. This plot supplies the only legend of the
  # combined figure, so both panels must map levels to the same colours even
  # if a level is absent from one of the two CSV files.
  scale_color_iwanthue(cmin = 30, cmax = 80, lmin = 35, lmax = 80, drop = FALSE) +
  guides(color = guide_legend(nrow = 1, byrow = TRUE)) +
  theme_bw()

ggsave("benchmarks/figures/benchmark_times.png", b1, width = 7, height = 4, dpi = 300)
# Same benchmark, but with download time included. Note that this reuses the
# `benchmark_times` name; the first runtime plot has already captured its own
# data at this point.
benchmark_times <- read_csv("benchmarks/playground_output/benchmark_download_results.csv")
benchmark_times <- mutate(
  benchmark_times,
  level = factor(
    level,
    levels = c("domain", "kingdom", "phylum", "class", "order", "family", "genus")
  )
)

b2 <- ggplot(benchmark_times, aes(x = members, y = time_seconds)) +
  geom_smooth(method = "lm", colour = "gray40") +
  geom_point(aes(color = level)) +
  labs(
    x = "Number of members",
    y = "Time (seconds)",
    title = "Runtimes including downloads"
  ) +
  # No legend here: the combined figure takes its legend from the first plot.
  scale_color_iwanthue(
    cmin = 30, cmax = 80, lmin = 35, lmax = 80,
    drop = FALSE, guide = "none"
  ) +
  theme_bw()

ggsave(
  "benchmarks/figures/benchmark_download_times.png",
  b2,
  width = 7, height = 4, dpi = 300
)

# Place the two runtime plots side by side, tagged A and B, with shared axes
# and one collected legend placed below the panels. `&` applies the legend
# position to every panel of the patchwork.
# (The stray trailing comma inside theme() passed an empty argument and has
# been removed.)
patchwork_times <- b1 + b2 +
  plot_annotation(tag_levels = "A") +
  plot_layout(guides = "collect", axes = "collect") &
  theme(legend.position = "bottom")

ggsave(
  "benchmarks/figures/benchmark_times_combined.png",
  patchwork_times,
  width = 8, height = 4, dpi = 300
)

# Classify sequence availability per row:
#   "Correct"   - a sequence is present and the strain flag is TRUE
#   "Incorrect" - a sequence is present but the strain flag is not TRUE
#   "Missing"   - no sequence is present
# NA inputs propagate as in the nested ifelse(): e.g. a present sequence with
# an NA strain flag yields NA.
# NOTE(review): confirm that NA is the intended status for that case.
sequence_status <- function(has_sequence, strain_ok) {
  ifelse(
    has_sequence & strain_ok,
    "Correct",
    ifelse(has_sequence, "Incorrect", "Missing")
  )
}

ncbi_results <- mutate(
  ncbi_results,
  source = "NCBI",
  genome = sequence_status(genome_seq, strain),
  rRNA = sequence_status(rRNA_seq, strain)
)

bacdive_results <- mutate(
  bacdive_results,
  source = "BacDive",
  genome = sequence_status(genome_seq, strain),
  rRNA = sequence_status(rRNA_seq, strain)
)

# For LPSN the genome status is not derived from `genome_seq`; it is set to
# the literal "N/A" for every row.
lpsn_results <- mutate(
  lpsn_results,
  source = "LPSN",
  genome = "N/A",
  rRNA = sequence_status(rRNA_seq, strain)
)

# The annotated Ratatoskr table is stored under `ratatoskr`;
# `ratatoskr_benchmarks` itself is left unmodified.
ratatoskr <- mutate(
  ratatoskr_benchmarks,
  source = "Ratatoskr",
  genome = sequence_status(genome_seq, strain),
  rRNA = sequence_status(rRNA_seq, strain)
)


# Stack the four annotated tables and reshape to one row per
# (source, sequence type) with its status.
stacked_results <- bind_rows(ncbi_results, bacdive_results, lpsn_results, ratatoskr)
all_results <- stacked_results %>%
  select(genome, rRNA, source) %>%
  pivot_longer(
    cols = c(genome, rRNA),
    names_to = "type",
    values_to = "status"
  )

# Stacked 100% bars of sequence status, one bar per data source, nested under
# the sequence type (genome / rRNA).
g1 <- ggplot(all_results, aes(x = source, fill = status)) +
  geom_bar(position = "fill") +
  scale_y_continuous(labels = scales::percent_format()) +
  facet_nested(~ type + source, scales = "free_x", space = "free") +
  labs(
    x = "Data Source",
    y = "Proportion of strains",
    fill = "Sequence status",
    title = "Sequence availability across data sources"
  ) +
  scale_fill_manual(
    values = c(
      "Correct" = "#1b9e77",
      "Incorrect" = "#c20202",
      "Missing" = "#7570b3",
      "N/A" = "gray20"
    )
  ) +
  theme_bw()

ggsave(
  "benchmarks/figures/genome_sequence_availability.png",
  g1,
  width = 8, height = 4, dpi = 300
)

# Print a table with the percentage of each rRNA / genome status per source:
# one row per (source, type), one `percentage_<status>` column per status.
# count() returns an ungrouped tally (no summarise() grouping message), and
# the explicit print() shows the table under source() as well as Rscript.
all_results %>%
  count(source, type, status, name = "count") %>%
  group_by(source, type) %>%
  mutate(percentage = round(100 * count / sum(count), 2)) %>%
  ungroup() %>%
  # Only the percentages are reported, so drop the raw counts before pivoting.
  select(-count) %>%
  pivot_wider(
    names_from = status,
    values_from = percentage,
    names_prefix = "percentage_"
  ) %>%
  arrange(source, type) %>%
  print()

# ggplot(all_results, aes(x=source, fill=rRNA_status)) +
#   geom_bar(position='fill') +
#   scale_y_continuous(labels=scales::percent_format()) +
#   labs(x="Data Source", y="Proportion of strains", fill="rRNA sequence status", title="rRNA sequence availability across data sources") +
#   theme_bw() -> g2
# ggsave("benchmarks/figures/rRNA_sequence_availability.png", g2, width=7, height=4, dpi=300)