#!/usr/bin/env Rscript

hlp = " Visualization of cohorts in a common MDS plot. "


# Paths
# Input files are configured per cohort in `cohorts`; only the output
# directory is set here. Built from components (no trailing slash) so that
# file.path(out.dir, "x.pdf") does not produce a double slash.
out.dir = file.path(".", "output", "composition")
dir.create(out.dir, recursive = TRUE, showWarnings = FALSE)

# Parameters
# Name of the assay (feature table) read from every cohort object.
assay = "Species"

# Libs 
require(philentropy)
require(ggplot2)
require(Rtsne)
require(Maaslin2)
require(scales)

# Colors: palette for the embedding plots.
COLORS = c(
  "#52A675", "#1982C4", "#FF595E", "#6A4C93",
  "#FFCA3A", "#45355A", "#8AC926"
)

# Cohorts to combine: name -> path of the saved R object holding `obj`.
# Order matters: the loading loop leaves `obj` bound to the object of the
# LAST entry, and the BCG300 cluster / shared-sample steps read that object,
# so BCG300 must stay last.
cohorts = as.list(c(
  FG500 = "./data/data.500fg.Robj",
  BCG300 = "./data/data.Robj"
))


# Load data on cohorts.
# For each cohort: read `obj` from its .Robj file, row-normalize the `assay`
# table, and stack all samples into `M`, keeping only the feature columns
# shared by every cohort loaded so far. `meta` gets one row per sample with
# its name and cohort label.
# After the loop, `obj` is the object of the LAST cohort in `cohorts`; later
# steps rely on that.
M = NULL
meta = NULL
for (coh in names(cohorts)) {
  in.data = cohorts[[coh]]

  # Load into a private environment so that objects stored in the file
  # cannot silently overwrite script variables (M, meta, cohorts, ...), and
  # so a file lacking `obj` cannot silently reuse the previous cohort's one.
  coh.env = new.env()
  load(in.data, envir = coh.env)
  obj = coh.env$obj
  if (is.null(obj)) {
    stop(sprintf("File %s does not contain an object named 'obj'", in.data),
         call. = FALSE)
  }

  Sx = obj$assays[[assay]]$data
  if (is.null(Sx)) {
    stop(sprintf("Assay %s not found in %s", assay, in.data), call. = FALSE)
  }
  # Row-normalize: each sample's abundances sum to 1.
  Sx = Sx / rowSums(Sx)

  df = data.frame(row.names = row.names(Sx),
                  Sample = row.names(Sx),
                  Cohort = coh)
  if (is.null(M)) {
    M = Sx
    meta = df
  } else {
    # Restrict both tables to the common features; drop = FALSE keeps a
    # 2-D table even when only a single feature is shared.
    keep = intersect(colnames(M), colnames(Sx))
    M = rbind(M[, keep, drop = FALSE], Sx[, keep, drop = FALSE])
    meta = rbind(meta, df)
  }
  message(sprintf("Adding cohort %s, keeping %d features in %s",
                  coh, ncol(M), assay))
}

# Re-normalize to the kept (shared) species so each sample sums to 1 again.
message(sprintf("Renormalizing %d features", ncol(M)))
totals = rowSums(M)
# A sample with no abundance in any shared feature turns into a NaN row
# (0 / 0) that propagates into the distance computation; make that visible.
n.empty = sum(totals == 0, na.rm = TRUE)
if (n.empty > 0) {
  warning(sprintf(paste("%d sample(s) have zero total abundance over the",
                        "shared features; their profiles will be NaN"),
                  n.empty),
          call. = FALSE)
}
M = M / totals

# Add cluster from BCG: every sample defaults to the "FG500" label; samples
# also present in the BCG300 metadata are labelled "BCG300 C<cluster>".
# NOTE(review): `obj` is whatever the last loading iteration left behind,
# which is the BCG300 object only because BCG300 is listed last in `cohorts`.
# NOTE(review): samples with a missing cluster are folded into cluster 1 -
# confirm this is intended.
bcg.md = obj$metadata
bcg.md$Cluster[is.na(bcg.md$Cluster)] = 1
bcg.md$ClusterBCG = sprintf("BCG300 C%d", bcg.md$Cluster)
obj$metadata = bcg.md

meta$Cluster = "FG500"
keep = intersect(row.names(meta), row.names(bcg.md))
meta[keep, "Cluster"] = bcg.md[keep, "ClusterBCG"]

# Remove shared samples: BCG300 rows that carry an ID500FG link, together
# with the 500FG rows those links point to, are both removed from `meta`;
# `M` is then re-aligned (and reordered) to the remaining rows of `meta`.
has.fg.id = !is.na(obj$metadata$ID500FG)
shared.bcg = as.vector(row.names(obj$metadata[has.fg.id, ]))
shared.fg = obj$metadata[shared.bcg, ]$ID500FG

present = row.names(meta)
shared.bcg = intersect(shared.bcg, present)
shared.fg = intersect(shared.fg, present)
shared = union(shared.bcg, shared.fg)

meta = meta[!(row.names(meta) %in% shared), ]
M = M[row.names(meta), ]

# Convert the cluster labels to a factor (levels = sorted unique labels).
meta[["Cluster"]] = as.factor(meta[["Cluster"]])

# Compute pairwise Jensen-Shannon divergence between samples (rows of M)
# and embed them in two dimensions with classical MDS.
# `D` is reused by the t-SNE step.
set.seed(41)
D = JSD(M)
mds = cmdscale(D, k = 2, eig = FALSE)
meta[["MDS1"]] = mds[, 1]
meta[["MDS2"]] = mds[, 2]

# Compute t-SNE on the same pairwise JSD values used for MDS.
# D is a sample-by-sample matrix of divergences, so Rtsne must be told it is
# a distance matrix: by default Rtsne treats every row as a feature vector
# (and runs PCA on it), which would embed the rows of D instead of using the
# divergences as distances. as.matrix() also accepts a `dist` object.
set.seed(41)
tsne = Rtsne(as.matrix(D), is_distance = TRUE)
meta$TSNE1 = tsne$Y[, 1]
meta$TSNE2 = tsne$Y[, 2]

# Scatter-plot a 2-D embedding of the samples coloured by Cluster, save it as
# a PDF and report the written path.
#
# data              data frame with the coordinate columns and a `Cluster`
#                   column
# x, y              names of the coordinate columns in `data`
# x.label, y.label  axis titles
# fname             output PDF path
# Returns the ggplot object invisibly.
plot_embedding = function(data, x, y, x.label, y.label, fname) {
  p = ggplot(data = data,
             aes(x = .data[[x]], y = .data[[y]], color = Cluster)) +
    geom_point(alpha = 0.7) +
    xlab(x.label) +
    ylab(y.label) +
    ggtitle("") +
    # Unnamed manual values are matched to the Cluster levels in order.
    scale_color_manual(values = c(COLORS[1:3], "#FFDDD2")) +
    theme(legend.title = element_blank(),
          text = element_text(size = 7),
          panel.background = element_blank(),
          axis.line = element_line(colour = "black"))
  # Save the plot object explicitly instead of relying on auto-printing and
  # last_plot(); ggsave opens and closes its own device, so no dev.off()
  # is needed (a bare dev.off() errors when no device is open).
  ggsave(fname, plot = p, width = 2.7, height = 1.8)
  message(sprintf("Written %s", fname))
  invisible(p)
}

# MDS embedding
plot_embedding(meta, "MDS1", "MDS2", "MDS 1", "MDS 2",
               file.path(out.dir, "jsd_mds.pdf"))

# t-SNE embedding
plot_embedding(meta, "TSNE1", "TSNE2", "t-SNE 1", "t-SNE 2",
               file.path(out.dir, "jsd_tsne.pdf"))
