#!/usr/bin/env Rscript

hlp = " Compare presence of genes to predict individual phenotypes.
        Use pathway information to narrow down the list further.
      "


# Command-line handling ----
# One positional argument is required: the pathway identifier used to locate
# "<target>.graph.csv" in the KGML output directory. A copy with "map"
# replaced by "ko" is kept for log messages.
args = commandArgs(trailingOnly = TRUE)
if (length(args) == 0) {
  message(hlp)
  message("Wrong input")
  message("Usage: figure_msps_pathway_multi pathway")
  q(save = "no", status = 1)
}
target = args[[1]]
pathway = gsub("map", "ko", target)

# Libs ----
# library() aborts immediately when a package is missing; require() would only
# warn and let the script fail later with a less obvious error.
library(ggplot2)
library(ggpubr)
library(stringr)
library(pheatmap)
library(reshape2)

# Paths ----
in.data = "./data/data.Robj"
in.kgml = "./output/pathways/kgml/"
out.dir = "./output/pathways/enzymes/"
img.dir = file.path(out.dir, "img")
# Creating img.dir recursively also creates out.dir, which is where every
# figure and table of this script is written.
dir.create(img.dir, recursive = TRUE, showWarnings = FALSE)

# Parameters ----
alpha = 0.05         # significance threshold for the per-enzyme tests
genus = c("Bifidobacterium", "Roseburia", "Ruminococcus", "Eubacterium", "Coprococcus")
COLORS = setNames(c("#F8766D", "#00AEF3", "#E76BF3", "#00BF7D", "#A3A502"), genus)
focus = "Roseburia"  # genus compared against all the others

# Read pathway graph ----
# Table derived from the pathway's KGML; its EnzymeID and Enzyme columns are
# used further down.
kf = read.csv(file.path(in.kgml, sprintf("%s.graph.csv", target)),
              stringsAsFactors = FALSE)

# Load data object ----
# load() restores the saved objects into the global environment; `obj` is then
# taken from there.
load(in.data)
obj = .GlobalEnv[["obj"]]

# Load MSPs ----
# S: MSP abundance matrix (samples x MSPs, judging by how it is melted below);
# sm: per-MSP metadata, including a `genus` column.
S = obj$assays$MSPCore$data
sm = obj$assays$MSPCore$metadata
msps = row.names(sm)[which(sm$genus %in% genus)]
# Collapse the genus vector so a single message is produced (sprintf would
# otherwise vectorise over the genera).
message(sprintf("Found %d MSPs for %s", length(msps), paste(genus, collapse = ", ")))

# Load genes and restrict them to the selected MSPs ----
G = obj$assays$MSPGenes$data
gm = obj$assays$MSPGenes$metadata
gmp = obj$assays$MSPGenes$mapping
# The gene->MSP mapping names MSPs without the ".core" suffix carried by the
# row names of sm, so add it before matching.
gmp$MSPCore = sprintf("%s.core", gmp$MSP)
gmp = gmp[gmp$MSPCore %in% msps, ]
genes = unique(gmp$Feature)
message(sprintf("Found %d genes for %s", length(genes), paste(genus, collapse = ", ")))

# Plot aggregated genus abundance in terms of samples and RPK ----
# Select the MSP columns by name: a logical mask built from the rows of sm
# would silently misalign if S's columns are ordered differently.
Sg = S[, msps, drop = FALSE]
smg = sm[msps, , drop = FALSE]
message(sprintf("Keeping %d MSPs belonging to the genera", ncol(Sg)))
dsg = melt(Sg)
colnames(dsg) = c("Sample", "MSP", "Abundance")
dsg$Sample = as.character(dsg$Sample)
dsg$MSP = as.character(dsg$MSP)
dsg[, colnames(smg)] = smg[dsg$MSP, ]
# Total genus abundance per sample (sum over the genus' MSPs)
dsa = aggregate(dsg$Abundance, by = list(Sample = dsg$Sample, Genus = dsg$genus), sum)
# Genus prevalence: fraction of samples where at least one MSP of the genus
# has non-zero abundance
dsc = aggregate(dsg$Abundance > 0, by = list(Sample = dsg$Sample, Genus = dsg$genus), sum)
dsp = aggregate(dsc$x > 0, by = list(Genus = dsc$Genus), mean)
row.names(dsp) = dsp$Genus
# Index by name even if genus is stored as a factor. The label is kept for
# reference; the plot below does not use it.
dsa$Label = sprintf("%s (%.2f %%)", dsa$Genus, 100 * dsp[as.character(dsa$Genus), "x"])
comparisons = list(c(focus, "Ruminococcus"),
                   c(focus, "Bifidobacterium"))
fname = file.path(out.dir, "genus_total_abundance.pdf")
p = ggplot(data = dsa, mapping = aes(x = reorder(Genus, -x), y = log(1 + x), col = Genus)) +
  geom_violin() +
  geom_jitter(alpha = 0.1, size = 1) + ylab("Log Abundance") +
  theme(text = element_text(size = 7)) + xlab("") +
  theme(legend.position = "none") +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = .5)) +
  stat_compare_means(comparisons = comparisons, method = "wilcox", size = 2)
ggsave(fname, plot = p, width = 1.2, height = 2.5)
message(sprintf("Written %s", fname))

# Filter genes by the KOs of the pathway ----
pathway.kos = unique(kf$EnzymeID)
gm = gm[genes, ]
gene.kos = strsplit(gm$KEGG_KOs, ",")
# Keep genes annotated with at least one KO present in the pathway graph
keep = vapply(gene.kos, function(ks) any(ks %in% pathway.kos), logical(1))
gm = gm[keep, , drop = FALSE]
gene.kos = gene.kos[keep]
kset = intersect(unique(unlist(gene.kos)), pathway.kos)
message(sprintf("Found %d/%d enzymes for %s", length(kset), length(pathway.kos), pathway))
if (nrow(gm) == 0) {
  stop(sprintf("No genes of the selected genera map to %s", pathway), call. = FALSE)
}

# Give each gene a single KO: the first of its KOs that is in the pathway
gm$KO = vapply(gene.kos, function(ks) intersect(ks, pathway.kos)[1], character(1))

# KO -> enzyme name lookup (first entry listed for each KO)
gnames = kf[, c("EnzymeID", "Enzyme")]
gnames = gnames[!duplicated(gnames$EnzymeID), ]
row.names(gnames) = gnames$EnzymeID

# Restrict the gene abundance matrix to the retained genes (keep it a matrix
# even if only one gene remains)
G = G[, row.names(gm), drop = FALSE]

# Annotate genes with MSP/genus/KO and order them by genus ----
gf = gmp[gmp$Feature %in% colnames(G), c("Feature", "MSPCore")]
gf$Genus = sm[gf$MSPCore, "genus"]
gf$KO = gm[gf$Feature, "KO"]
# Order by the genus order given in `genus`, then by MSP
gf$GenusFactor = factor(gf$Genus, levels = genus)
gf = gf[order(gf$GenusFactor, gf$MSPCore), ]
# A gene mapped to several MSPs keeps only its first assignment in this order
gf = gf[!duplicated(gf$Feature), ]
row.names(gf) = gf$Feature
# Reorder G's gene columns to match gf; drop = FALSE keeps G a matrix even
# when a single gene is left
G = G[, row.names(gf), drop = FALSE]
# Column annotation (genus) for the heatmap
gff = gf[, c("Genus"), drop = FALSE]

# Plot heatmaps ----
# Presence/absence matrix: any positive abundance becomes 1
H = G + 0
H[H > 0] = 1
# Column labels must follow H's column order (set from gf above), not the row
# order of gm. They are only drawn if show_colnames is turned on.
labels_col = sprintf("%s (%s)",
                     gm[colnames(H), "predicted_gene_name"],
                     gm[colnames(H), "KEGG_KOs"])
# Insert a gap after every column where the genus changes
gaps_col = which(head(gff$Genus, -1) != tail(gff$Genus, -1))
ann_colors = list(Genus = COLORS)
fname = file.path(out.dir, "heatmap_genes.pdf")
pheatmap(H, cluster_rows = TRUE, cluster_cols = FALSE,
         treeheight_row = 0,
         show_colnames = FALSE,
         labels_col = labels_col,
         annotation_col = gff,
         gaps_col = gaps_col,
         annotation_legend = FALSE,
         annotation_colors = ann_colors,
         color = c("white", "black"),
         legend = FALSE,
         show_rownames = FALSE, filename = fname, width = 1.5, height = 1.5)
message(sprintf("Written %s", fname))




# Pathway completeness RPKM ----
# Long table, one row per (sample, gene): detection flag from H and RPK from G
# (same dimensions, hence same melt order), then the gene's annotation from
# gf and its enzyme name.
dh = local({
  long = melt(H)
  data.frame(Sample = as.character(long[[1]]),
             Gene = as.character(long[[2]]),
             Detected = long[[3]],
             RPK = melt(G)$value,
             stringsAsFactors = FALSE)
})
dh[, colnames(gf)] = gf[dh$Gene, ]
dh$Enzyme = gnames[dh$KO, "Enzyme"]

# Compute data across all samples ----
# Sum gene RPK per (KO, genus) within each sample, then reshape to one row per
# KO x genus x sample, labelled with the enzyme name.
gene.ids = colnames(G)
ko.sums = aggregate(t(G),
                    by = list(KO = gf[gene.ids, "KO"], Genus = gf[gene.ids, "Genus"]),
                    FUN = sum)
dj = melt(ko.sums, id.vars = c("KO", "Genus"),
          variable.name = "Sample", value.name = "RPK")
dj[c("KO", "Genus", "Sample")] = lapply(dj[c("KO", "Genus", "Sample")], as.character)
dj$Enzyme = gnames[dj$KO, "Enzyme"]
dj$Label = sprintf("%s %s", dj$Enzyme, dj$KO)


# Compute exhaustive tests for the focus genus for each enzyme ----
# For every enzyme, the per-sample RPK of the focus genus is compared with each
# other genus using a one-sided Wilcoxon rank-sum test (focus > other). The
# largest of these p-values is kept, together with the genus that produced it,
# so an enzyme passes only if the focus genus exceeds every tested genus.
# Sides with fewer than `min.obs` values are not tested. No multiple-testing
# correction is applied across enzymes.
others = setdiff(genus, focus)
min.obs = 10
dp = data.frame(Enzyme = unique(dj$Enzyme), Genus = focus, Pvalue = NA_real_,
                N = 0, Other = "", Mark = "", stringsAsFactors = FALSE)
row.names(dp) = dp$Enzyme
for (e in dp$Enzyme) {
  # Values of the focus genus (not the pathway ID) for this enzyme
  x = dj[dj$Enzyme == e & dj$Genus == focus, "RPK"]
  dp[e, "N"] = length(x)
  if (length(x) < min.obs) next
  pval = NA_real_
  other = ""
  for (o in others) {
    y = dj[dj$Enzyme == e & dj$Genus == o, "RPK"]
    if (length(y) < min.obs) next
    wt = wilcox.test(x, y, alternative = "greater")
    if (is.na(pval) || isTRUE(wt$p.value > pval)) {
      pval = wt$p.value
      other = o
    }
  }
  # Stays NA when no other genus had enough values to be tested, so such
  # enzymes can never pass the alpha filter
  dp[e, "Pvalue"] = pval
  dp[e, "Other"] = other
}
dpf = dp[!is.na(dp$Pvalue) & dp$Pvalue < alpha, , drop = FALSE]
dpf$PvalueFormat = sprintf("%.2e", dpf$Pvalue)
# Facet label of each enzyme as used in dj (first match per enzyme)
dpf$Label = dj$Label[match(dpf$Enzyme, dj$Enzyme)]
# Constant RPK column so dpf can share the plot's facet_wrap(~reorder(Label, -RPK));
# rep() keeps this valid when no enzyme is significant
dpf$RPK = rep(0, nrow(dpf))
# x position of the p-value text, keyed by the compared genus (presumably
# tuned to the x-axis order of the plot)
positions = list("Bifidobacterium" = 2.4, "Ruminococcus" = 2.5, "Eubacterium" = 3.5, "Coprococcus" = 3)
dpf$Pos = vapply(dpf$Other, function(o) positions[[o]], numeric(1), USE.NAMES = FALSE)

# Plot per-enzyme abundance by genus ----
# Jittered per-sample points with a median / interquartile-range summary, plus
# a bracket and p-value for each significant comparison in dpf.
fname = file.path(out.dir, "rpk_genes_unique_ko_violin.pdf")
iqr.summary = geom_pointrange(col = "black", size = 0.07,
                              stat = "summary",
                              fun.min = function(z) quantile(z, 0.25),
                              fun.max = function(z) quantile(z, 0.75),
                              fun = median)
pval.text = geom_text(data = dpf,
                      aes(x = Pos, y = 9.5, label = PvalueFormat, group = NULL),
                      col = "black", size = 2)
pval.bar = geom_segment(data = dpf,
                        aes(x = Genus, xend = Other, y = 8, yend = 8),
                        col = "black", size = .5)
p = ggplot(dj, aes(x = reorder(Genus, -RPK), fill = Genus, col = Genus, y = log(1 + RPK))) +
  geom_jitter(alpha = .7, size = .2) +
  iqr.summary +
  pval.text +
  pval.bar +
  facet_wrap(~reorder(Label, -RPK)) +
  theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = 0.5)) +
  theme(legend.position = "none", text = element_text(size = 7)) +
  xlab("") + ylab("Log Abundance") + ylim(-0.1, 10)
ggsave(fname, width = 3.6, height = 4, plot = p)
message(sprintf("Written %s", fname))

# Data behind the figure (row names included)
fname = file.path(out.dir, "rpk_genes_unique_ko_violin.csv")
write.csv(dj, fname)
message(sprintf("Written %s", fname))

# Plot genes detected in the focus genus ----
# Total RPK per (MSP, gene) over all samples for MSPs of the focus genus.
drb = dh[dh$Genus %in% focus, ]
drb$Label = sprintf("%s %s", gnames[drb$KO, "Enzyme"], drb$KO)
drg = aggregate(drb$RPK, by = list(MSP = drb$MSPCore, Gene = drb$Label), sum)
# NOTE: uses log(), not log(1 + x) as the other plots do; zero totals give
# -Inf, which presumably renders with the fill scale's NA colour.
drg$LogRPK = log(drg$x)
drg$Label = sm[drg$MSP, "label"]

# Gene order on the y axis. The curated list appears to be specific to the
# pathway this figure was made for; genes it does not list are appended
# alphabetically instead of becoming NA.
gene.order = c(
  "hisC K00817", "mhpE K01666", "paaH K00074",
  "yhdR K11358", "paaF K01692", "hipO K01451",
  "enr K10797", "paaI K02614", "HPD K00457",
  "paaK K01912", "aaaT K03825", "padE K18357",
  "aspB K00812")
gene.order = c(gene.order, setdiff(sort(unique(as.character(drg$Gene))), gene.order))
drg$Gene = factor(as.character(drg$Gene), levels = rev(gene.order))
fname = file.path(out.dir, sprintf("detected_genes_%s.pdf", focus))
p = ggplot(data = drg, mapping = aes(x = reorder(Label, x), y = Gene, fill = LogRPK)) +
  geom_tile() + theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = .5)) +
  theme(legend.position = "none") +
  xlab("") + ylab("") +
  theme(text = element_text(size = 7))
ggsave(fname, plot = p, width = 2, height = 2.7)
message(sprintf("Written %s", fname))

# Figure with legend ----
# The tile plot of drg (MSP x gene, filled by LogRPK), this time keeping the
# fill legend.
fname = file.path(out.dir, sprintf("detected_genes_%s_legend.pdf", focus))
p.legend = ggplot(data = drg, mapping = aes(x = reorder(Label, x), y = Gene, fill = LogRPK)) +
  geom_tile() + theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = .5)) +
  xlab("") + ylab("") +
  theme(text = element_text(size = 7))
ggsave(fname, plot = p.legend, width = 2, height = 2.7)
message(sprintf("Written %s", fname))

# Write summary data ----
# Write `df` as CSV without row names into out.dir and report the path.
write_supp = function(df, name) {
  path = file.path(out.dir, name)
  write.csv(df, path, row.names = FALSE)
  message(sprintf("Written %s", path))
  invisible(path)
}

# Per group of `groups` columns: mean RPK and prevalence (fraction of rows
# with RPK > 0).
summarise_rpk = function(d, groups) {
  aggregate(cbind(MeanRPK = d$RPK, Prevalence = d$RPK > 0),
            by = d[groups], FUN = mean)
}

# Per feature (gene)
write_supp(dh, "supp_table_ENZ_per_feature.csv")
agg = summarise_rpk(dh, c("Genus", "Feature", "KO", "Enzyme"))
write_supp(agg, "supp_table_ENZ_per_feature_agg.csv")

# Per KO
write_supp(dj, "supp_table_ENZ_per_ko.csv")
agg = summarise_rpk(dj, c("Genus", "KO", "Enzyme"))
write_supp(agg, "supp_table_ENZ_per_ko_agg.csv")
