#!/usr/bin/env Rscript

hlp <- "MDS with Shannon-Jensen divergence in species level to determine number of clusters."


# Input / output locations
in.data <- "./data/data.Robj"
out.dir <- file.path("./output/composition/")
# Create the output directory (including missing parents); stay quiet if it
# already exists
dir.create(out.dir, showWarnings = FALSE, recursive = TRUE)

# Tunables
method.hclust <- "ward.D"  # Agglomeration method used for clustering
nboot <- 100               # Number of bootstrap iterations

# Libs
require(philentropy)
require(ggplot2)
require(pvclust)
require(pheatmap)
require(entropy)
require(reshape2)
require(ggtree)

# Restore the saved workspace object(s) and pick up `obj` from the global
# environment
load(in.data)
obj <- .GlobalEnv[["obj"]]

# Extract the MSPCore assay data and the metadata from the object
Sx <- obj$assays$MSPCore$data
meta <- obj$metadata

# Per-row summary table keyed by the row names of Sx; Cluster starts out empty
df <- data.frame(
  sample = row.names(Sx),
  row.names = row.names(Sx),
  Cluster = "",
  stringsAsFactors = FALSE
)

# Diversity summaries, one value per row of Sx
df$DiversityChao <- rowSums(Sx > 0)  # number of non-zero entries in the row
df$DiversityChaoShen <- unlist(apply(Sx, 1, entropy.ChaoShen))
df$DiversityShannon <- unlist(apply(Sx, 1, entropy.empirical))

# Hierarchical clustering with bootstrap resampling (pvclust), using a custom
# distance function that wraps JSD() into a dist object.
# NOTE(review): pvclust is documented as clustering the columns of its input
# and resampling its rows, whereas JSD() compares rows. Confirm that the
# bootstrap support values computed here are meaningful for this orientation.
hc <- pvclust(
  Sx,
  method.hclust = method.hclust,
  method.dist = function(X) as.dist(JSD(X)),
  nboot = nboot,
  iseed = 42
)

# Write the pvclust dendrogram (tip labels suppressed) to a PDF
fname <- file.path(out.dir, "jsd_dendrogram.pdf")
pdf(fname, width = 14, height = 5)
plot(hc, labels = FALSE)
dev.off()
message(sprintf("Written %s", fname))

# Classical MDS: embed the pairwise JSD matrix of Sx in two dimensions and
# keep the coordinates alongside the other per-row columns of df
D <- JSD(Sx)
mds <- cmdscale(D, k = 2, eig = FALSE)
df$MDS1 <- mds[, 1]
df$MDS2 <- mds[, 2]

# Mean between-cluster and within-cluster distance for one clustering.
#
# Arguments:
#   D   Square, symmetric matrix of pairwise distances (anything as.matrix()
#       can convert).
#   cl  Cluster assignment per item, e.g. the result of cutree(). When D has
#       unique row names and cl is named, cl is aligned to D by name;
#       otherwise it is matched by position and must have one entry per row
#       of D.
#
# Returns:
#   list(Between = , Within = ): the mean distance over unordered pairs of
#   items that fall in different clusters and in the same cluster. A mean
#   over zero pairs is NaN; items whose label is missing from cl yield NA.
dst <- function(D, cl) {
  D <- as.matrix(D)
  n <- nrow(D)
  stopifnot(ncol(D) == n)

  labels <- rownames(D)
  by.name <- !is.null(labels) && !is.null(names(cl)) &&
    anyDuplicated(labels) == 0L
  if (by.name) {
    cl <- cl[labels]
  } else {
    # No usable labels: fall back to positional alignment
    stopifnot(length(cl) == n)
  }
  cl <- unname(cl)

  # Strict upper triangle, so every unordered pair is counted exactly once
  pairs <- upper.tri(D)
  same <- outer(cl, cl, "==")

  list(Between = mean(D[pairs & !same]),
       Within = mean(D[pairs & same]))
}

# Decide on the number of clusters.
# For every candidate k the tree is cut into k groups and the gap between the
# mean between-cluster and mean within-cluster distance is recorded; the k
# with the largest gap is selected.
crange <- 2:5
dl <- lapply(crange, function(k) dst(D, cutree(hc$hclust, k)))
be <- vapply(dl, function(x) x$Between, numeric(1))
wi <- vapply(dl, function(x) x$Within, numeric(1))
gap <- be - wi

# Plot the gap against the number of clusters
fname <- file.path(out.dir, sprintf("jsd_distance.pdf"))
gg.dist <- ggplot(data.frame(k = crange, gap = gap), aes(x = k, y = gap)) +
  geom_point() +
  geom_line(linetype = "dashed") +
  xlab("Number of clusters") + ylab("Distance difference") +
  theme(text = element_text(size = 7))
# Pass the plot explicitly instead of relying on last_plot()
ggsave(fname, plot = gg.dist, width = 1.8, height = 1.4)
message(sprintf("Written %s", fname))

# which.max() yields exactly one index (the first on ties) and skips NA/NaN
# gaps, so nc is always a scalar
best <- which.max(gap)
if (length(best) == 0L) {
  stop("No candidate number of clusters has a finite distance difference",
       call. = FALSE)
}
nc <- crange[best]
message(sprintf("Decided for %d clusters", nc))
cl <- cutree(hc$hclust, nc)

# Draw the hclust tree with ggtree, add a dashed vertical marker and save it
# NOTE(review): xintercept = 300 is a fixed value, not derived from nc —
# confirm it corresponds to the intended cut position
fname <- file.path(out.dir, "jsd_tree.pdf")
ggtree(hc$hclust) +
  geom_vline(linetype = "dashed", xintercept = 300)
ggsave(filename = fname, plot = last_plot(), width = 1.2, height = 2.3)
message(sprintf("Written %s", fname))

# Count samples per cluster and the fraction of all samples each one holds
df$Cluster <- cl
agg <- aggregate(df$Cluster, by = list(Cluster = df$Cluster), FUN = length)
colnames(agg) <- c("Cluster", "N")
agg$Fraction <- agg$N / sum(agg$N)
agg <- agg[order(-agg$N), ]
row.names(agg) <- agg$Cluster

# agg is sorted by descending N, so row position no longer equals the cluster
# id: look clusters up by row name (character index), never by integer index
df$Fraction <- agg[as.character(df$Cluster), "Fraction"]
df$N <- agg[as.character(df$Cluster), "N"]
# Legend label: "<cluster> (<number of samples>)"
df$Label <- sprintf("%s (%d)", df$Cluster, df$N)

# Scatter plot of the two MDS coordinates, coloured by cluster label
fname <- file.path(out.dir, sprintf("jsd_mds_clusters_%s.pdf", nc))
gg.mds <- ggplot(data = df, aes(x = MDS1, y = MDS2, color = Label)) +
  geom_point() +
  xlab("MDS 1") + ylab("MDS 2") +
  ggtitle("") +
  theme(legend.title = element_blank(),
        text = element_text(size = 7))
# ggsave() opens and closes its own graphics device, so no dev.off() follows;
# a bare dev.off() here fails when no other device happens to be open
ggsave(fname, plot = gg.mds, width = 4.2, height = 3.1)
message(sprintf("Written %s", fname))

# Heatmap of Sx with rows arranged in dendrogram order; row clustering is
# switched off so that this order is kept. Rows are annotated with the
# cluster label and the Shannon diversity column of df.
fname <- file.path(out.dir, sprintf("jsd_mds_heatmap_%s.pdf", nc))
o <- hc$hclust$order
pdf(fname, width = 10, height = 10)
pheatmap(
  Sx[o, ],
  cluster_rows = FALSE,
  show_rownames = FALSE,
  show_colnames = FALSE,
  annotation_row = df[o, c("Label", "DiversityShannon")]
)
dev.off()
message(sprintf("Written %s", fname))
