#!/usr/bin/env Rscript

hlp <- " Compute CCA between MSPs and metabolite features.
        Emphasize immunomodulatory species.
      "

# Libs: library() stops immediately with a clear error when a package is
# missing, whereas require() only returns FALSE and lets the script fail
# later with an unrelated error.
library(ggplot2)
library(ggrepel)
library(reshape2)
library(Rtsne)

# Paths (output directories are created if they do not exist)
in.data <- "./data/data.Robj"
out.dir <- file.path("./output/metabolomics/cca/")
dir.create(out.dir, recursive = TRUE, showWarnings = FALSE)
img.dir <- file.path(out.dir, "img")
dir.create(img.dir, recursive = TRUE, showWarnings = FALSE)

# Trained immunity and specific response result tables
in.results.train <- "./output/response/trained/results.csv"
in.results.spec <- "./output/response/specific/results.csv"

# Parameters
# min.prev:  an MSP is kept only if it is non-zero in more than this
#            fraction of samples
# rank.data: replace MSP abundances by per-MSP ranks across samples
min.prev <- 0.05
rank.data <- TRUE

# Unified color scheme for the two response modes
TRAINED <- "#0dafdb"
SPECIFIC <- "#db7b14"


# CCA-style decomposition between two blocks of features measured on the
# same samples (here metabolites and MSPs), with permutation p-values.
#
# The components are the singular vectors of the cross-product t(X) %*% Y.
# No within-block covariance matrix is inverted, so the decomposition stays
# defined when X or Y is rank deficient (e.g. more features than samples).
#
# Args:
#   X, Y:  numeric matrices (or data frames coercible to them) with the same
#          number of rows; row i of X and row i of Y are the same sample.
#          Columns are not centered here, so t(X) %*% Y is a
#          cross-covariance only if the caller centered the data.
#   nshuf: number of row permutations of X used to build the null
#          distribution (0 is allowed and yields p-values of 1).
#   rank:  number of leading components to return; must be between 1 and
#          min(ncol(X), ncol(Y)).
#
# Returns a list with:
#   U:      ncol(X) x rank loadings of the X columns (columns CCA1..CCA<rank>)
#   V:      ncol(Y) x rank loadings of the Y columns
#   D:      the `rank` observed singular values
#   Var:    nshuf x rank singular values obtained on permuted data
#   pvalue: per component, (1 + #permutations whose singular value exceeds
#           the observed one) / (1 + nshuf); the +1 keeps it above zero.
cc.comp <- function(X, Y, nshuf = 1000, rank = 2) {
  X <- as.matrix(X)
  Y <- as.matrix(Y)
  stopifnot(
    nrow(X) == nrow(Y),
    rank >= 1,
    rank <= min(ncol(X), ncol(Y))
  )

  # Observed decomposition
  fit <- svd(crossprod(X, Y), nu = rank, nv = rank)
  U <- fit$u
  V <- fit$v
  D <- fit$d[seq_len(rank)]

  # Null distribution: shuffling the rows of X breaks its pairing with Y.
  # Rows are permuted by position, not by name, so unnamed or duplicated
  # row names cannot break the shuffle. Only singular values are needed.
  n <- nrow(X)
  Var <- matrix(0, nrow = nshuf, ncol = rank)
  exceed <- numeric(rank)
  for (i in seq_len(nshuf)) {
    Xi <- X[sample.int(n), , drop = FALSE]
    Dt <- svd(crossprod(Xi, Y), nu = 0, nv = 0)$d[seq_len(rank)]
    Var[i, ] <- Dt
    exceed <- exceed + (Dt > D)
  }
  pvalue <- (exceed + 1) / (nshuf + 1)

  comp.names <- sprintf("CCA%d", seq_len(rank))
  dimnames(U) <- list(colnames(X), comp.names)
  dimnames(V) <- list(colnames(Y), comp.names)
  list(U = U, V = V, D = D, Var = Var, pvalue = pvalue)
}

# Load data object; the .Robj file is expected to define `obj`
load(in.data)
obj <- .GlobalEnv$obj

# Load response results and keep rows flagged by `Keep`. which() drops NA
# flags instead of turning them into all-NA rows (assumes `Keep` is read as
# a logical column).
df.spec <- read.csv(in.results.spec, row.names = 1, stringsAsFactors = FALSE)
df.train <- read.csv(in.results.train, row.names = 1, stringsAsFactors = FALSE)
df.spec <- df.spec[which(df.spec$Keep), , drop = FALSE]
df.train <- df.train[which(df.train$Keep), , drop = FALSE]

# MSP and metabolite data with their feature metadata
S <- obj$assays$MSPCore$data
sm <- obj$assays$MSPCore$metadata
M <- obj$assays$Metabolites$data
mm <- obj$assays$Metabolites$metadata

# Plot labels are read from sm$Label further down
sm$Label <- sm$label

# Filter MSPs by the fraction of samples in which they are non-zero.
# NOTE(review): assumes rows of `sm` are in the same order as columns of
# `S` - TODO confirm.
keep <- colMeans(S > 0) > min.prev
S <- S[, keep, drop = FALSE]
sm <- sm[keep, , drop = FALSE]
message(sprintf("Keeping %d MSPs on prevalence", nrow(sm)))
if (rank.data) {
  message("Using ranked values")
  # Zero-based per-MSP ranks across samples; ties share the lowest rank
  S <- apply(S, 2, rank, ties.method = "min") - 1
}

# Restrict both blocks to the samples they share, in the same order
keep <- intersect(row.names(M), row.names(S))
M <- M[keep, , drop = FALSE]
S <- S[keep, , drop = FALSE]
# Standardize metabolites only after the intersect so every column is
# centered over exactly the samples entering the cross-product in cc.comp;
# scaling before dropping samples leaves a mean offset in the components.
M <- scale(M)
message(sprintf("Keeping %d intersecting samples", nrow(S)))


# Decompose with 40 components and 1000 row permutations to choose how many
# components to keep. Seed the RNG first so the permutation null (and the
# significance stars derived from it) is reproducible between runs.
set.seed(42)
results <- cc.comp(M, S, rank = 40, nshuf = 1000)

# Observed vs permuted log singular values per component
# (plotted under the axis label "Log Variance")
d <- log(results$D)
V <- log(results$Var)
dv <- melt(V)
colnames(dv) <- c("Replicate", "Component", "LogVariance")
dvm <- aggregate(dv$LogVariance, by = list(Component = dv$Component), mean)
dvs <- aggregate(dv$LogVariance, by = list(Component = dv$Component), sd)
dvv <- data.frame(Replicate = 1, Component = dvm$Component,
                  Mean = dvm$x, Sd = dvs$x)
dvv$Type <- "Random"
do <- data.frame(Replicate = 1, Component = seq_along(d), Mean = d, Sd = 0)
do$Type <- "Observed"
dm <- rbind(do, dvv)

# Plot significance: roughly one star per order of magnitude below p = 1
# (p <= 0.1, 0.01, 0.001), capped at three stars
marks <- strrep("*", pmin(3, floor(-log10(results$pvalue))))
fname <- file.path(out.dir, "cca_plot_components.pdf")
p <- ggplot(dm, aes(x = as.factor(Component), y = Mean,
                    group = Type, color = Type)) +
  geom_line() +
  geom_point() +
  geom_errorbar(aes(ymin = Mean - Sd, ymax = Mean + Sd), width = .2,
                position = position_dodge(0.05)) +
  annotate(geom = "text", x = do$Component + 0.3, y = do$Mean,
           label = marks, hjust = 0) +
  xlab("Canonical Component") + ylab("Log Variance") +
  theme(legend.title = element_blank())
ggsave(fname, plot = p, width = 8.5, height = 3)
message(sprintf("Written %s", fname))


# Keep the leading 25 components. They are sliced from the full permutation
# run above rather than recomputed: the leading singular vectors and values
# do not depend on how many are requested. The previous re-run used only 2
# permutations, so the p-values saved below could only be 0, 0.5 or 1.
n.comp <- 25
idx <- seq_len(n.comp)
results <- list(
  U = results$U[, idx, drop = FALSE],
  V = results$V[, idx, drop = FALSE],
  D = results$D[idx],
  Var = results$Var[, idx, drop = FALSE],
  pvalue = results$pvalue[idx]
)

# Standardized loadings: one row per metabolite (du) and per MSP (dv)
comps <- colnames(results$U)
du <- data.frame(scale(results$U))
du$Type <- "Metabolite"
du$Class <- mm[row.names(du), "Majority_class"]
du$Mode <- NA
dv <- data.frame(scale(results$V))
dv$Type <- "MSP"
# Response mode per MSP; "Trained" wins when an MSP is in both tables
dv$Mode <- "N/A"
dv[row.names(dv) %in% df.spec$Feature, "Mode"] <- "Specific"
dv[row.names(dv) %in% df.train$Feature, "Mode"] <- "Trained"
dv$Class <- NA

# Save per-component permutation p-values
fname <- file.path(out.dir, "cca_plot_M_S.txt")
write.table(results$pvalue, fname)
message(sprintf("Written %s", fname))






# Joint 2-D t-SNE embedding of metabolite and MSP loadings in the shared
# CCA space; the RNG is seeded so the layout is reproducible.
df <- rbind(du, dv)
set.seed(42)
embedding <- Rtsne(df[, comps])
df$TSNE1 <- embedding$Y[, 1]
df$TSNE2 <- embedding$Y[, 2]

# Metabolites carry no response mode: give them the same "N/A" level as
# MSPs that are neither specific nor trained.
df$Mode <- factor(ifelse(is.na(df$Mode), "N/A", df$Mode))

# Main figures: only MSPs with a response mode get a text label, taken
# from the MSP metadata
flagged <- row.names(df)[df$Mode != "N/A"]
df$Label <- ""
df[flagged, "Label"] <- sm[flagged, "Label"]


# Large figures

# Color per response mode. Naming the values ties each color to its level,
# so a mode absent from the data cannot shift colors onto the wrong group;
# the modes use the file's unified SPECIFIC/TRAINED colors.
mode.colors <- c("N/A" = "white", "Specific" = SPECIFIC, "Trained" = TRAINED)

# Metabolite classes in the embedding
fname <- file.path(out.dir, "cca_plot_mbx_XL.pdf")
p <- ggplot(data = df) +
  geom_point(mapping = aes(x = TSNE1, y = TSNE2, col = Class, shape = Type)) +
  xlab("t-SNE / CCA 1") + ylab("t-SNE / CCA 2")
ggsave(fname, plot = p, width = 9, height = 6)
message(sprintf("Written %s", fname))

# MSP response modes with labels for all flagged MSPs
fname <- file.path(out.dir, "cca_plot_msp_XL.pdf")
p <- ggplot(data = df) +
  geom_point(mapping = aes(x = TSNE1, y = TSNE2, col = Mode, shape = Type),
             show.legend = FALSE) +
  xlab("t-SNE / CCA 1") + ylab("t-SNE / CCA 2") +
  geom_text_repel(mapping = aes(x = TSNE1, y = TSNE2, label = Label),
                  show.legend = FALSE, force = 1) +
  scale_color_manual(values = mode.colors)
ggsave(fname, plot = p, width = 6, height = 6)
message(sprintf("Written %s", fname))



# Main figures: keep text labels only for the highlighted MSPs
focus <- grepl("MSP 112|MSP 091|MSP 181", df$Label)
df$Label[!focus] <- ""

# Metabolite classes in the embedding (compact, publication size)
fname <- file.path(out.dir, "cca_plot_mbx.pdf")
p <- ggplot(data = df) +
  geom_point(mapping = aes(x = TSNE1, y = TSNE2, col = Class, shape = Type),
             alpha = .7, size = 1) +
  xlab("t-SNE / CCA 1") + ylab("t-SNE / CCA 2") +
  theme(text = element_text(size = 7)) +
  theme(legend.margin = margin(1, 1, 1, 1), legend.title = element_blank()) +
  theme(legend.key.size = unit(1, "line"))
ggsave(fname, plot = p, width = 4.5, height = 2.7)
message(sprintf("Written %s", fname))

# MSP response modes; colors are named so each level keeps its color even
# if a mode is absent from the data
mode.colors <- c("N/A" = "white", "Specific" = SPECIFIC, "Trained" = TRAINED)
fname <- file.path(out.dir, "cca_plot_msp.pdf")
p <- ggplot(data = df) +
  geom_point(mapping = aes(x = TSNE1, y = TSNE2, col = Mode, shape = Type),
             show.legend = FALSE, alpha = 0.7, size = 1) +
  xlab("t-SNE / CCA 1") + ylab("t-SNE / CCA 2") +
  geom_text_repel(mapping = aes(x = TSNE1, y = TSNE2, label = Label),
                  show.legend = FALSE, force = 4.4, size = 2) +
  scale_color_manual(values = mode.colors) +
  theme(text = element_text(size = 7))
ggsave(fname, plot = p, width = 2.7, height = 2.7)
message(sprintf("Written %s", fname))
