#!/usr/bin/env Rscript

hlp <- "Plot phylophlan tree"

# Paths ----
# Saved R workspace; it is expected to provide the object `obj`.
in.data <- "./data/data.Robj"
# Directory holding the PhyloPhlAn results; plots are written here too.
# (Trailing slash kept as-is so every derived path is unchanged.)
out.dir <- "./output/phylophlan/"
img.dir <- file.path(out.dir, "img")
dir.create(img.dir, recursive = TRUE, showWarnings = FALSE)

# PhyloPhlAn input: refined RAxML best tree, read later with ape::read.tree()
bestTree <- file.path(out.dir, "RAxML_bestTree.faa_refined.tre")


# Colors ----
# Genus fill colours for the genus-level mini tree: the first four genera are
# highlighted, the rest are drawn in light grey. The vector is NAMED by genus
# so ggplot2 matches colours by name. Unnamed values are matched by position
# to the genera actually present, which shifts colours when one is absent.
levels <- c("Roseburia", "Ruminococcus", "Eubacterium", "Coprococcus",
            "Lachnospira", "Blautia", "Acetatifactor",
            "Clostridium", "Drancourtella")
colors <- setNames(
  c("#00AEF3", "#E76BF3", "#00BF7D", "#A3A502",
    "#CCCCCC", "#CCCCCC", "#CCCCCC",
    "#CCCCCC", "#CCCCCC"),
  levels
)

# Phylum fill colours. The names must match the `Phylum` factor levels
# exactly: an unmatched name leaves that phylum drawn in the scale's
# na.value.
PHYLUM <- c(
  Firmicutes      = "#52A675",
  Bacteroidetes   = "#1982C4",
  Actinobacteria  = "#FF595E",
  Proteobacteria  = "#6A4C93",  # was misspelled "Protebacteria"
  Verrucomicrobia = "#FFCA3A",
  Euryarchaeota   = "#45355A",
  Lentisphaerae   = "#8AC926"
)

# Libs ----
# library() errors immediately if a package is missing. require() only warns
# and returns FALSE, which lets the script fail later with an unrelated error.
library(ggplot2)
library(ggtree)
library(stringr)  # NOTE(review): no stringr function appears to be used below
library(ape)

# Load data object for TZ ----
# `in.data` must contain an object named `obj`. It is loaded explicitly into
# the global environment so the `.GlobalEnv$obj` lookup below always finds it.
loaded <- load(in.data, envir = .GlobalEnv)
if (!("obj" %in% loaded)) {
  stop(sprintf("'%s' does not contain an object named 'obj'", in.data),
       call. = FALSE)
}
obj <- .GlobalEnv$obj

# Per-MSP metadata; row names look like "msp_NNN.core" (see lookups below).
sm <- obj$assays$MSPCore$metadata
if (is.null(sm)) {
  stop("obj$assays$MSPCore$metadata is missing", call. = FALSE)
}

phylum_levels <- c("Firmicutes", "Bacteroidetes", "Actinobacteria",
                   "Proteobacteria", "Verrucomicrobia", "Euryarchaeota",
                   "Lentisphaerae")
sm$Phylum <- factor(sm$phylum, levels = phylum_levels)

# Phyla outside the list above silently become NA; make that visible.
unknown_phyla <- setdiff(unique(sm$phylum[!is.na(sm$phylum)]), phylum_levels)
if (length(unknown_phyla) > 0) {
  warning("Phyla not in the palette (set to NA): ",
          paste(unknown_phyla, collapse = ", "), call. = FALSE)
}


# Make genera consistent with species level annotation ----
# Genus corrections keyed by the row names of `sm`; each one is applied as a
# single-cell assignment, in the listed order.
genus_override <- c(
  msp_150.core = "Lachnospira",
  msp_328.core = "Clostridium",
  msp_329.core = "Clostridium",
  msp_039.core = "Ruminococcus",
  msp_120.core = "Ruminococcus",
  msp_292.core = "Ruminococcus",
  msp_081.core = "Ruminococcus",
  msp_088.core = "Eubacterium",
  msp_046.core = "Eubacterium",
  msp_169.core = "Coprococcus"
)
for (msp_id in names(genus_override)) {
  sm[msp_id, "genus"] <- genus_override[[msp_id]]
}

# Plot tree ----
# Every plot below needs the PhyloPhlAn tree, so stop with a clear message
# when it is missing. (`next` is only valid inside a loop; at top level it
# raises "no loop for break/next" instead.)
if (!file.exists(bestTree)) {
  stop(sprintf("PhyloPhlAn tree not found: %s", bestTree), call. = FALSE)
}
treeC <- read.tree(bestTree)

# Remove msp_337 from the tree used for the full phylum plots
tree_full <- drop.tip(treeC, "msp_337")

# Phylum-coloured tree plots ----

#' Plot a tree with tip labels filled by phylum and save it to a file.
#'
#' @param tree A `phylo` object (as returned by ape::read.tree / drop.tip).
#' @param meta Per-MSP metadata. Rows are looked up by "<tip label>.core" and
#'   every column is copied onto the ggtree plot data; must contain `Phylum`.
#'   NOTE(review): a column named like a ggtree data column (e.g. `label`)
#'   would overwrite it, as in the original code.
#' @param fname Output file path.
#' @param width,height Page size passed to ggsave().
#' @param legend Keep the phylum legend? Defaults to FALSE.
#' @param palette Named fill colours per phylum.
#' @return The ggplot object, invisibly.
save_phylum_tree <- function(tree, meta, fname, width, height,
                             legend = FALSE, palette = PHYLUM) {
  p <- ggtree(tree, layout = "rectangular")
  p$data$MSP <- sprintf("%s.core", p$data$label)
  # Nodes without a matching row in `meta` get NA metadata.
  p$data[, colnames(meta)] <- meta[p$data$MSP, , drop = FALSE]
  p <- p +
    geom_tiplab(aes(fill = Phylum, label = label), geom = "label",
                size = 1.3, label.padding = unit(0.1, "lines"),
                label.size = 0) +
    ggplot2::scale_fill_manual(values = palette)
  if (!legend) {
    p <- p + theme(legend.position = "none")
  }
  ggsave(fname, width = width, height = height, plot = p)
  message(sprintf("Written %s", fname))
  invisible(p)
}

# Small version
save_phylum_tree(tree_full, sm,
                 file.path(out.dir, "tree_plot_phylum_small.pdf"),
                 width = 10, height = 24)

# Full-size tree without legend
save_phylum_tree(tree_full, sm,
                 file.path(out.dir, "tree_plot_phylum.pdf"),
                 width = 22, height = 32)

# Full-size tree with legend
save_phylum_tree(tree_full, sm,
                 file.path(out.dir, "tree_plot_phylum_legend.pdf"),
                 width = 22, height = 32, legend = TRUE)

# Subset tree for main figure ----
# Tips to keep, by MSP number. Numeric literals are decimal in R, so the
# numbers are written without leading zeros to avoid an octal reading.
keep <- sprintf("msp_%03d",
                c(25, 101, 329, 244, 302, 195, 109, 120, 292,
                  12, 162, 169, 81, 150, 213, 196, 328,
                  39, 155, 46, 88, 148, 2, 137, 51, 112, 32))
absent_tips <- setdiff(keep, treeC$tip.label)
if (length(absent_tips) > 0) {
  warning("Requested tips not in tree: ",
          paste(absent_tips, collapse = ", "), call. = FALSE)
}
treetr <- drop.tip(treeC, setdiff(treeC$tip.label, keep))

p <- ggtree(treetr, layout = "rectangular")
p$data$MSP <- sprintf("%s.core", p$data$label)
p$data[, colnames(sm)] <- sm[p$data$MSP, , drop = FALSE]
# Genera outside `levels` become NA and are drawn with the scale's na.value.
p$data$genus <- factor(p$data$genus, levels = levels)

# Name the palette by genus so colours are matched by name. Unnamed values
# are matched by position to the genera present, and shift when a listed
# genus is absent from this subtree.
genus_fill <- setNames(colors[seq_along(levels)], levels)

fname <- file.path(out.dir, "tree_plot_genus_mini.pdf")
p <- p +
  geom_tiplab(aes(fill = genus, label = label), geom = "label",
              size = 2, label.padding = unit(0.1, "lines"), label.size = 0) +
  theme(legend.position = "none",
        plot.margin = ggplot2::margin(0, 0, 0, 0)) +
  ggplot2::scale_fill_manual(values = genus_fill) +
  ggplot2::xlim(0, 0.7)
ggsave(fname, width = 2.6, height = 3.5, plot = p)
message(sprintf("Written %s", fname))
