#!/usr/bin/env Rscript

# One-line description of what this script produces.
hlp <- "MDS with Shannon-Jensen divergence (genus) and dominant phyla. "

# Input data object and output directory; the output directory (and any
# missing parents) is created up front, silently if it already exists.
in.data <- "./data/data.Robj"
out.dir <- file.path("./output/composition/")
dir.create(path = out.dir, showWarnings = FALSE, recursive = TRUE)

# Colors
# Named palette: one hex color per phylum. The names allow lookup by phylum
# (e.g. COLORS["Firmicutes"]); positional use must strip them first.
COLORS <- c(
  Firmicutes      = "#52A675",
  Bacteroidetes   = "#1982C4",
  Actinobacteria  = "#FF595E",
  Proteobacteria  = "#6A4C93",  # was misspelled "Protebacteria"
  Verrucomicrobia = "#FFCA3A",
  Euryarchaeota   = "#45355A",
  Lentisphaerae   = "#8AC926"
)

# Libs 
require(philentropy)
require(ggplot2)

# Load data object
# load() restores the saved objects into the global environment; the rest of
# the script expects one of them to be called `obj`.
if (!file.exists(in.data)) {
  stop(sprintf("Input data file not found: %s", in.data), call. = FALSE)
}
load(in.data)
if (!exists("obj", envir = .GlobalEnv, inherits = FALSE)) {
  stop(sprintf("No object named 'obj' found after loading %s", in.data),
       call. = FALSE)
}
obj <- .GlobalEnv$obj

# Load the MSPCore assay and its per-feature phylum annotation.
# NOTE(review): `phylum` is used below to group the columns of `Sx`, so it is
# presumably one label per column of `Sx` -- confirm against the data object.
Sx <- obj$assays$MSPCore$data
phylum <- obj$assays$MSPCore$metadata$phylum
if (is.null(Sx) || is.null(phylum)) {
  stop("obj$assays$MSPCore is missing `data` or `metadata$phylum`",
       call. = FALSE)
}

# Sum by phylum
# (The aggregation function is `sum`, not a mean.) Columns of `Sx` are grouped
# by their phylum label and summed, giving one row per phylum in `agg`.
agg <- aggregate(t(Sx), by = list(Phylum = phylum), FUN = sum)
row.names(agg) <- agg$Phylum
# Drop the grouping column and transpose so phyla become columns.
# drop = FALSE keeps the data-frame shape (and therefore the row name) even
# when only a single value column is present.
Px <- t(agg[, -1, drop = FALSE])
colnames(Px) <- row.names(agg)

# Plot metadata
# Restrict the sample metadata and the phylum table to the rows they share,
# then append one column per phylum to `meta`.
meta <- obj$metadata
keep <- intersect(row.names(meta), row.names(Px))
if (length(keep) == 0) {
  stop("No rows shared between obj$metadata and the phylum table",
       call. = FALSE)
}
# drop = FALSE: keep matrix / data-frame shape when only one phylum column or
# one shared row remains, so colnames(Px) stays valid below.
meta <- meta[keep, , drop = FALSE]
Px <- Px[keep, , drop = FALSE]
meta[, colnames(Px)] <- Px[row.names(meta), , drop = FALSE]

# Count samples per cluster
# `agg` ends up with one row per cluster (columns Cluster, N, Fraction),
# sorted by decreasing size and keyed by cluster name.
agg <- aggregate(meta$Cluster, by = list(Cluster = meta$Cluster), FUN = length)
colnames(agg) <- c("Cluster", "N")
agg$Fraction <- agg$N / sum(agg$N)
agg <- agg[order(-agg$N), ]
row.names(agg) <- agg$Cluster

# Always look clusters up by *name*. Indexing with a factor or numeric cluster
# id would select rows by position, which no longer matches after the
# re-ordering above.
cluster.key <- as.character(meta$Cluster)
meta$Fraction <- agg[cluster.key, "Fraction"]
meta$N <- agg[cluster.key, "N"]
# Legend/label text: "<cluster> (<number of samples>)".
meta$Label <- sprintf("%s (%d)", cluster.key, meta$N)

# Bounds
# Shared axis limits for the per-phylum plots: `mi` holds the x (MDS1) limits
# and `ma` the y (MDS2) limits, as c(lower, upper).
#
# Each end is pushed outwards by 5% of its magnitude. Multiplying both ends by
# 1.05 only widens the range when the lower end is negative and the upper end
# is positive; otherwise it would cut points off, so the sign is handled here.
pad.limits <- function(v, frac = 0.05) {
  lo <- min(v)
  hi <- max(v)
  c(lo - frac * abs(lo), hi + frac * abs(hi))
}
mi <- pad.limits(meta$MDS1)
ma <- pad.limits(meta$MDS2)

# Plot overall cluster figure
# Scatter of samples in MDS space, colored by cluster label (legend hidden).
fname <- file.path(out.dir, "clusters.pdf")
cluster.plot <- ggplot(data = meta, aes(x = MDS1, y = MDS2, color = Label)) +
  geom_point(alpha = 0.5) +
  xlab("MDS 1") + ylab("MDS 2") +
  ggtitle("Clusters of samples") +
  theme(legend.title = element_blank(),
        legend.position = "none",
        text = element_text(size = 7)) +
  # Offer the whole palette, unnamed so colors are assigned in order; the
  # first colors are the same as before, and more than three clusters no
  # longer fail with "insufficient values" (up to length(COLORS)).
  scale_color_manual(values = unname(COLORS))
# Pass the plot explicitly instead of relying on last_plot().
ggsave(fname, plot = cluster.plot, width = 1.9, height = 1.8)
message(sprintf("Written %s", fname))

# Plot phyla
# One PDF per phylum column of `Px`: samples in MDS space, colored by the
# natural log of that phylum's value. Samples whose log is not finite
# (e.g. zero abundance) are left out of the plot.
for (ent in colnames(Px)) {
  fname <- file.path(out.dir, paste0("enterotypes_", ent, ".pdf"))
  meta$Color <- log(meta[, ent])
  mnz <- meta[is.finite(meta$Color), ]
  phylum.plot <- ggplot(data = mnz) +
    geom_point(aes(x = MDS1, y = MDS2, color = Color),
               show.legend = TRUE,
               alpha = 0.7) +
    # Same axis limits for every phylum so the panels are comparable.
    xlim(mi) + ylim(ma) +
    labs(x = "MDS 1", y = "MDS 2", title = ent) +
    theme(legend.title = element_blank(),
          legend.position = "right",
          legend.key.size = unit(.5, "line"),
          legend.margin = margin(1, 1, 1, 0),
          text = element_text(size = 7))
  ggsave(fname, plot = phylum.plot, width = 2.3, height = 1.8)
  message(sprintf("Written %s", fname))
}
