#!/usr/bin/env Rscript

# Question addressed by this script (the string is not referenced again in the
# visible part of this file)
hlp <- " Do significant species have a larger effect on metabolism? "

# Input data object and output directory (created on demand, parents included)
in.data <- "./data/data.Robj"
out.dir <- file.path("./output/metabolomics/combined/")
dir.create(out.dir, showWarnings = FALSE, recursive = TRUE)

# Result tables written by earlier steps: metabolomics correlations and the
# trained / specific response tables
in.mbx <- "./output/metabolomics/correlations/results.csv"
in.anova.train <- "./output/response/trained/results.csv"
in.anova.spec <- "./output/response/specific/results.csv"

# Cut-offs
padj.mbx <- .05    # Adjusted p-value for metabolomics correlations
min.prev <- 0.2    # Minimum prevalence

# Restore the saved objects, then take `obj` from the global environment
load(in.data)
obj <- .GlobalEnv$obj

# Colors
TRAINED <- "#0dafdb"
SPECIFIC <- "#db7b14"

# Fill colours keyed by p-value threshold label
THRESHOLDS <- c(
  "N/A" = "white",
  "P < 0.20" = "#8CC1E2",
  "P < 0.10" = "#1982C4",
  "P < 0.05" = "#314A7A"
)

# Annotation colours keyed by immunomodulation label
IMMUNOMODULATORY <- c("N/A" = "white", "Specific" = SPECIFIC, "Trained" = TRAINED)

# Libs
require(ggplot2)
require(ggpubr)
require(pheatmap)

# Read trained and specific response significant species.
# Both tables are re-keyed by their `Feature` column so that they can be
# indexed by feature name further down.
ant <- read.csv(in.anova.train, row.names = 1)
ans <- read.csv(in.anova.spec, row.names = 1)
row.names(ant) <- ant$Feature
row.names(ans) <- ans$Feature

# Load metabolites (correlation results) and flag the significant ones.
# FALSE is spelled out: `F` is an ordinary variable that the load() call
# earlier in this script could overwrite.
mbx <- read.csv(in.mbx, stringsAsFactors = FALSE)
mbx$Significant <- mbx$padj < padj.mbx
# Placeholders; both columns are overwritten from the response tables below
mbx$Mode <- "N/A"
mbx$Threshold <- "N/A"

# Make comparisons with different thresholds.
#
# For each unadjusted p-value cut-off the loop
#   1. flags features that pass the cut-off (and exceed `min.prev`) in the
#      trained and/or the specific response table,
#   2. copies the flags onto the metabolite correlation table `mbx`,
#   3. writes a chi-squared test and contingency table of
#      Significant x Mode, and
#   4. plots the density of the per-MSP maximum absolute Spearman coefficient.
#
# `Mode` is reset on every pass. `Threshold`, `Pvalue` and `Immunomodulation`
# are NOT reset: later passes overwrite earlier ones, and within one pass the
# specific table is applied after the trained table.

# Use the features of BOTH tables. Building `anv` from `ant` alone makes
# features present only in `ans` get appended on the fly with NA Pvalue/Feature.
features <- union(row.names(ant), row.names(ans))
anv <- data.frame(Feature = features, row.names = features,
                  Mode = "N/A", Threshold = "N/A", Pvalue = 1,
                  Immunomodulation = "N/A", stringsAsFactors = FALSE)

# Loop-invariant: absolute correlation used as the effect size
mbx$Effect <- abs(mbx$spearman)

for (p in c(0.2, 0.1, 0.05)) {

  p.label <- sprintf("P < %.2f", p)
  # "%d" requires a whole number; 100 * p is a double and need not be exact
  p.tag <- sprintf("%02d", as.integer(round(100 * p)))

  # Reset mode
  anv$Mode <- "N/A"

  # trained; which() drops rows whose Pvalue or Prev is NA, which would
  # otherwise yield NA row names and abort the assignment
  inxs <- row.names(ant)[which(ant$Pvalue < p & ant$Prev > min.prev)]
  anv[inxs, "Mode"] <- p.label
  anv[inxs, "Threshold"] <- p.label
  anv[inxs, "Pvalue"] <- pmin(anv[inxs, "Pvalue"], ant[inxs, "Pvalue"])
  anv[inxs, "Immunomodulation"] <- "Trained"

  # specific (applied second, so it wins when a feature passes in both tables)
  inxs <- row.names(ans)[which(ans$Pvalue < p & ans$Prev > min.prev)]
  anv[inxs, "Mode"] <- p.label
  anv[inxs, "Threshold"] <- p.label
  anv[inxs, "Pvalue"] <- pmin(anv[inxs, "Pvalue"], ans[inxs, "Pvalue"])
  anv[inxs, "Immunomodulation"] <- "Specific"

  # Combine two results files: look up each correlation's MSP in `anv`
  mbx$Mode <- anv[mbx$Column, "Mode"]
  mbx$Threshold <- anv[mbx$Column, "Threshold"]
  mbx$PvalueImm <- anv[mbx$Column, "Pvalue"]
  mbx$Immunomodulation <- anv[mbx$Column, "Immunomodulation"]

  # Write test results
  ctab <- table(mbx$Significant, mbx$Mode)
  ct <- chisq.test(mbx$Significant, mbx$Mode)
  fname <- file.path(out.dir, sprintf("MSP_effects_metabolites_unadj_%s.txt", p.tag))
  # capture.output() removes its output diversion even if print() fails
  capture.output(print(ct), file = fname)
  message(sprintf("Written %s", fname))
  fname <- file.path(out.dir, sprintf("MSP_effects_metabolites_unadj_%s.csv", p.tag))
  write.csv(ctab, fname)
  message(sprintf("Written %s", fname))

  # Plot distribution of unadjusted effects (max |Spearman| per MSP and mode)
  fname <- file.path(out.dir, sprintf("MSP_effects_metabolites_unadj_%s.pdf", p.tag))
  agg <- aggregate(mbx$Effect, by = list(MSP = mbx$Column, Mode = mbx$Mode), max)
  agg$Mode <- factor(as.character(agg$Mode), levels = names(THRESHOLDS))
  agg$Effect <- agg$x

  # Smallest effect among significant correlations, drawn as a reference line.
  # which() / na.rm guard against NA in `Significant` or `Effect`.
  min.sig.effect <- min(mbx[which(mbx$Significant), "Effect"], na.rm = TRUE)

  plt <- ggplot(agg, aes(x = Effect, fill = Mode)) +
    geom_density(alpha = 0.7) +
    theme(legend.title = element_blank()) +
    geom_vline(xintercept = min.sig.effect,
               linetype = "dashed", color = "gray") +
    ylab("Normalized density") + xlab("Abs. Spearman coef.") + ylim(0, 10) +
    annotate(x = 0.45, y = 9.5, geom = "text",
             label = sprintf("P(Chi-Sq)\n < %.2e",ct$p.value), hjust = 1, size = 2) +
    theme(text = element_text(size = 7)) +
    scale_fill_manual(values = THRESHOLDS, drop = FALSE)
  # Pass the plot explicitly instead of relying on last_plot()
  ggsave(fname, plot = plt, width = 2.6, height = 1.6)
  message(sprintf("Written %s", fname))
}


####################
### Heatmap plot ###
####################


# Output significant species and metabolites attributed to significant species.
# which() keeps only rows whose `Significant` flag is TRUE. Plain logical
# indexing would turn every row with an NA flag (missing padj) into an all-NA
# row, and those rows break the matrix indexing used for the heatmap below.
rf <- mbx[which(mbx$Significant), ]
fname <- file.path(out.dir, "MSP_mbx_filtered.csv")
write.csv(rf, fname)
message(sprintf("Written %d rows, %d MSPs, %d metabolites to %s",
                nrow(rf), length(unique(rf$Column)),
                length(unique(rf$Row)), fname))

# Plot heatmaps with effects on the cytokines
# Assemble the matrix of Spearman coefficients: one row per unique rf$Row,
# one column per unique rf$Column; pairs absent from `rf` stay NA.
comps <- unique(rf$Row)
msps <- unique(rf$Column)
X <- matrix(NA, nrow = length(comps), ncol = length(msps))
row.names(X) <- comps
colnames(X) <- msps
X[cbind(rf$Row, rf$Column)] <- rf$spearman

# First `rf` record of every row / column entity, aligned to the order of X
ar <- rf[!duplicated(rf$Row), ]
row.names(ar) <- ar$Row
ar <- ar[row.names(X), ]
ac <- rf[!duplicated(rf$Column), ]
row.names(ac) <- ac$Column
ac <- ac[colnames(X), ]

# Row labels: ion m/z followed by the first ';'-separated annotation name
names_row <- vapply(strsplit(ar$Top.annotation.name, ";"),
                    function(parts) parts[1], character(1))
labels_row <- sprintf("%.3f %s", ar$ionMz, names_row)

# Label MSP: two hard-coded MSPs get an asterisk appended to their label
starred <- c("msp_112.core", "msp_181.core")
ac$mark <- ""
ac$mark[ac$Column %in% starred] <- "*"
# Column labels come from the `label` column of the trained-response table
labels_col <- sprintf("%s %s", ant[ac$Column, "label"], ac$mark)

# Marks
# rff = mbx[prefilter,]
# Character matrix with the same dimnames as X: "o" where `rf` holds a
# (Row, Column) pair, "" elsewhere
M <- matrix("", nrow = length(comps), ncol = length(msps),
            dimnames = list(comps, msps))
M[cbind(rf$Row, rf$Column)] <- "o"

# Define annotations (row side): two descriptive columns taken from `ar`
arf <- ar[, c("average_molecular_weight", "Majority_class")]
colnames(arf) <- c("Molecular weight (Da)", "Majority class")

# Add pathway figure: one N/Y annotation column per pathway whose column sum
# over the plotted rows exceeds 2.
# NOTE(review): assumes `P` is a numeric or logical membership matrix with
# metabolites in rows and pathway ids in columns, and that `pm` is keyed by
# those pathway ids - confirm against the data object.
pm <- obj$assays$PathwaysKEGG$metadata
P <- obj$assays$Metabolites$pathway
# drop = FALSE keeps P two-dimensional when a single row or column remains;
# without it colnames(P) becomes NULL and the steps below fail
P <- P[row.names(arf), , drop = FALSE]
P <- P[, colSums(P) > 2, drop = FALSE]
pm$Pathway <- gsub("Biosynthesis of", "Biosynth.", pm$Pathway)
pm$Pathway <- gsub("Degradation of", "Degr.", pm$Pathway)
colnames(P) <- pm[colnames(P), "Pathway"]
if (ncol(P) > 0) {
  arf[, colnames(P)] <- P
}
ann_colors <- list()
for (cn in colnames(P)) {
  # Map values explicitly. factor() followed by `levels<-` would label a
  # column containing only members as "N", because its single level is
  # simply renamed to the first new level.
  arf[[cn]] <- factor(ifelse(arf[[cn]] > 0, "Y", "N"), levels = c("N", "Y"))
  # c() coerces 0 to the string "0" (presumably meant as a blank/background
  # colour for non-members)
  ann_colors[[cn]] <- c(N=0, Y="red")
}

# Add custom colors
ann_colors[["Immunomodulation"]] <- IMMUNOMODULATORY
ann_colors[["Threshold"]] <- THRESHOLDS

# Simplify annotations (column side) and order columns by decreasing p-value,
# so that columns without a threshold (P-value 1) come first
acf <- ac[, c("Threshold", "PvalueImm", "Immunomodulation"), drop = FALSE]
colnames(acf) <- c("Threshold", "P-value", "Immunomodulation")
# Same level order as the colour key
acf$Threshold <- factor(acf$Threshold, levels = names(THRESHOLDS))
ord <- order(-acf$`P-value`)
acf <- acf[ord, , drop = FALSE]
# drop = FALSE keeps the matrices two-dimensional when only one column exists
X <- X[, row.names(acf), drop = FALSE]
M <- M[, row.names(acf), drop = FALSE]
labels_col <- labels_col[ord]

# Gap between the "N/A" columns and the first column that passed a threshold.
# No gap when there is no such column (index would be NA, which pheatmap
# rejects) or when it is the very first column (gap position 0).
first.sig <- which(acf$Threshold != "N/A")[1]
gaps_col <- if (is.na(first.sig) || first.sig <= 1) NULL else first.sig - 1


# Compute clustering on imputed data: missing coefficients are set to 0
Y <- replace(X, is.na(X), 0)
hcr <- hclust(dist(Y))
# Column clustering is computed but not passed to pheatmap() below
hcc <- hclust(dist(t(Y)))

# Plot a heatmap
width <- 7.5
height <- 4.8
fname <- file.path(out.dir, "MSP_mbx_heatmap.pdf")
pheatmap(
  X,
  # Layout: rows follow `hcr` (dendrogram hidden), columns keep their order
  cluster_rows = hcr,
  cluster_cols = FALSE,
  treeheight_row = 0,
  gaps_col = gaps_col,
  # Annotations and labels
  annotation_row = arf,
  annotation_col = acf,
  annotation_colors = ann_colors,
  labels_row = labels_row,
  labels_col = labels_col,
  # Appearance and output
  na_col = "white",
  legend = FALSE,
  fontsize = 6,
  filename = fname,
  width = width,
  height = height
)
message(sprintf("Written %s", fname))
