#!/usr/bin/env Rscript

# One-line description of this script.
hlp <- " Null distribution of pathway activities. "

# Libs
require(ggplot2)
require(stringr)
require(pheatmap)
require(ggpubr)

# Input and output locations
in.data <- "./data/data.Robj"
in.kgml <- "./output/pathways/kgml/"
out.dir <- file.path("./output/pathways/phenylalanine/")
img.dir <- file.path(out.dir, "img")
# Recursive creation of img.dir also creates out.dir; silent if it exists.
dir.create(img.dir, showWarnings = FALSE, recursive = TRUE)

# Analysis parameters: genus used to select MSPs ("All" keeps every MSP),
# pathway map identifier, and the KO pathway identifier shown in log output
genus <- "Roseburia"
target <- "map00360"
pathway <- "ko00360"

# Pathway graph table of the target map, read from the KGML output folder
in.ko <- file.path(in.kgml, paste0(target, ".graph.csv"))
kf <- read.csv(in.ko, stringsAsFactors = FALSE)

# Restore the saved R objects, then take `obj` from the global environment
# (the data file is expected to define it; NULL is returned otherwise)
load(in.data)
obj <- .GlobalEnv[["obj"]]

# Load the MSP core assay and its metadata, then select the MSPs that belong
# to the requested genus (the special value "All" keeps every MSP).
S <- obj$assays$MSPCore$data
sm <- obj$assays$MSPCore$metadata
if (genus == "All") {
  msps <- row.names(sm)
} else {
  # %in% never returns NA, so MSPs without a genus annotation are excluded
  msps <- row.names(sm)[sm$genus %in% genus]
}
message(sprintf("Found %d MSPs for %s", length(msps), genus))

# Gene-level assay: keep only the genes that map to the selected MSPs
G <- obj$assays$MSPGenes$data
gm <- obj$assays$MSPGenes$metadata
gmp <- obj$assays$MSPGenes$mapping

# `msps` holds names carrying a ".core" suffix, so build the same key here
gmp$MSPCore <- sprintf("%s.core", gmp$MSP)
gmp <- gmp[gmp$MSPCore %in% msps, ]

# gmp is now restricted to `msps`, so its features are the selected genes
genes <- unique(gmp$Feature)
gm <- gm[gm$Feature %in% gmp$Feature, ]
message(sprintf("Found %d genes for %s", length(genes), genus))
G <- G[, row.names(gm)]

# Restrict the gene metadata to the selected genes, then flag every gene that
# is annotated with at least one KO present in the pathway graph
gm <- gm[genes, ]
ko_lists <- strsplit(gm$KEGG_KOs, ",")
keep <- vapply(
  ko_lists,
  function(ks) length(intersect(ks, kf$EnzymeID)) > 0,
  logical(1)
)
gme <- gm[keep, ]

# Distinct pathway enzymes that are covered by at least one selected gene
kset <- intersect(unique(unlist(ko_lists)), kf$EnzymeID)
message(sprintf("Found %d/%d enzymes for %s", length(kset), length(unique(kf$EnzymeID)), pathway))
Ge <- G[, row.names(gme)]

# Build `dk`, a long table with one row per (enzyme, gene) pair, for every
# enzyme ID of the pathway graph that is found in the KO annotation of a gene.
enzyme_ids <- unique(kf$EnzymeID)
# Skip empty and missing identifiers (a missing ID would break the comparison)
enzyme_ids <- enzyme_ids[!is.na(enzyme_ids) & enzyme_ids != ""]

per_enzyme <- lapply(enzyme_ids, function(k) {
  # Literal substring search: enzyme IDs are identifiers, not regex patterns
  feats <- row.names(gm)[grep(k, gm$KEGG_KOs, fixed = TRUE)]
  if (length(feats) == 0) {
    return(NULL)
  }
  data.frame(
    EnzymeID = k,
    Feature = feats,
    # First enzyme name listed for this ID; match() ignores NA identifiers
    Enzyme = kf$Enzyme[match(k, kf$EnzymeID)],
    stringsAsFactors = FALSE
  )
})
per_enzyme <- Filter(Negate(is.null), per_enzyme)
if (length(per_enzyme) == 0) {
  # Fail early with a clear message instead of an obscure aggregate() error
  stop(sprintf("No genes annotated with any enzyme of %s", pathway),
       call. = FALSE)
}
# Bind once instead of growing the data frame inside the loop
dk <- do.call(rbind, per_enzyme)
message(sprintf("Found %d/%d features for %d enzymes",
                nrow(dk), ncol(G), length(unique(dk$EnzymeID))))

# Sum the gene columns of G per enzyme. A gene annotated with several pathway
# enzymes appears once per enzyme in dk and contributes to each of them.
agg <- aggregate(t(G[, dk$Feature]), by = list(EnzymeID = dk$EnzymeID), FUN = sum)
# Back to the orientation of G (rows of G in rows, one column per enzyme);
# drop = FALSE keeps the table shape when G has a single row
Ge <- t(agg[, -1, drop = FALSE])
colnames(Ge) <- agg[, 1]

# Genes not linked to any pathway enzyme; random sets are later drawn from
# these columns to build the null distribution
gene_ids <- colnames(G)
dif <- unique(gene_ids[!(gene_ids %in% dk$Feature)])
Gn <- G[, dif]
message(sprintf("Keeping %d remaining genes", ncol(Gn)))

# Metabolite assay: keep the compounds with a positive entry for the target
# pathway in the compound-by-pathway table P
M <- obj$assays$Metabolites$data
mm <- obj$assays$Metabolites$metadata
P <- obj$assays$Metabolites$pathway
mbx <- names(which(P[, target] > 0))
mm <- mm[mbx, ]
# Center and scale each retained compound column
M <- scale(M[, mbx])
message(sprintf("Keeping %d compounds for %s", length(mbx), target))

# Keep the samples present in both the gene and the metabolite data, and put
# the rows of every matrix in the same order
keep <- intersect(row.names(G), row.names(M))
M <- M[keep, ]
Gn <- Gn[keep, ]
Ge <- Ge[keep, ]
G <- G[keep, ]
message(sprintf("Keeping %d samples", length(keep)))

# Predefined metabolite and enzyme sets. Entry b of `molecules` is paired
# with entry b of `enzymes`; entries 3 and 4 cross the two sets.
mol1 <- c("mb_0157", "mb_0161", "mb_0213")
# mol2 = c("mb_0535", "mb_0120", "mb_0090", "mb_0208")
mol2 <- c("mb_0208", "mb_0535", "mb_0120")
en1 <- c("K00817", "K01666", "K11358", "K01692", "K00074", "K01451")
# NOTE(review): "K03825" used to be listed twice in this vector. The
# duplicate is removed so the enzyme is not selected and summed twice and
# does not inflate the size of the random gene sets. Confirm that no other
# KO was intended in its place.
en2 <- c("K02614", "K00457", "K01912", "K03825")
molecules <- list(
  mol1 = mol1,
  mol2 = mol2,
  mol3 = mol1,
  mol4 = mol2
)
enzymes <- list(
  en1 = en1,
  en2 = en2,
  en3 = en2,
  en4 = en1
)

# Point colors, one per entry of `enzymes`
# NOTE: "#F89646" is an orange tone although the variable is called `green`
green <- "#F89646"
blue <- "#4F81BD"
colors <- c(blue, green, blue, green)

# For every predefined branch: correlate the summed enzyme values with the
# summed metabolite values (Spearman), compare the observed correlation with
# a null distribution obtained from random sets of non-pathway genes of the
# same size, and write three PDF plots to out.dir.
# NOTE(review): no RNG seed is set, so the null draws and the reported
# P values change from run to run.
boot <- 10000
for (b in seq_along(enzymes)) {
  mols <- molecules[[b]]
  ens <- enzymes[[b]]
  n <- length(ens)

  # Observed values per sample; drop = FALSE keeps single-column selections
  # two-dimensional so rowSums() still works
  pa <- rowSums(M[, mols, drop = FALSE])
  ea <- rowSums(log(1 + Ge[, ens, drop = FALSE]))

  # Spearman's rho on complete pairs. Only the estimate is needed, so cor()
  # replaces cor.test(), whose p-value was never used for the output.
  et <- cor(ea, pa, method = "spearman", use = "complete.obs")
  if (is.na(et)) {
    stop(sprintf("Observed correlation is undefined for branch %d", b),
         call. = FALSE)
  }
  prev <- mean(Ge[, ens] > 0)

  # Null distribution: n random non-pathway genes per draw.
  # Result vectors are preallocated instead of grown with c().
  null <- numeric(boot)
  null_prev <- numeric(boot)
  for (i in seq_len(boot)) {
    ngs <- sample(colnames(Gn), size = n)
    sub <- Gn[, ngs, drop = FALSE]
    ean <- rowSums(log(1 + sub))
    # NA when the random set is constant across samples (zero variance)
    null[i] <- cor(ean, pa, method = "spearman", use = "complete.obs")
    null_prev[i] <- mean(sub > 0)
  }

  # Scatter plot of observed enzyme vs. metabolite values
  col <- colors[b]
  fname <- file.path(out.dir, sprintf("correlation_branch_%d.pdf", b))
  p <- ggplot(mapping = aes(x = ea, y = pa)) +
    geom_point(alpha = .2, size = 1, col = col) +
    geom_smooth(method = "lm", col = "black") +
    xlab("") + ylab("") +
    stat_cor(method = "spearman", size = 2, aes(label = ..r.label..)) +
    theme(text = element_text(size = 7),
          plot.margin = unit(c(1, .5, -2, -2), "mm"),
          panel.background = element_blank())
  # Pass the plot explicitly rather than relying on last_plot()
  ggsave(fname, plot = p, width = 0.72, height = 0.72)
  message(sprintf("Written %s", fname))

  # Empirical one-sided P value. Undefined null draws are ignored; before,
  # a single NA made the P value NA and the star computation failed.
  valid <- null[!is.na(null)]
  if (length(valid) == 0) {
    stop(sprintf("No valid null correlations for branch %d", b),
         call. = FALSE)
  }
  if (b <= 2) {
    # Branches 1-2: fraction of null draws above the observed correlation
    pval <- mean(valid > et)
    hjust <- 1.1
  } else {
    # Branches 3-4: fraction of null draws below the observed correlation
    pval <- mean(valid < et)
    hjust <- -0.1
  }
  # Vertical position of the annotation
  v <- if (b %in% c(1, 3)) 6 else 3
  # One star per order of magnitude. An empirical P of 0 only means
  # "below 1 / number of draws"; clamp it there so -log10() stays finite
  # (rep("*", Inf) would otherwise stop the script).
  n_stars <- floor(-log10(max(pval, 1 / length(valid))))
  mark <- paste0(rep("*", n_stars), collapse = "")
  label <- sprintf("P = %.2f%s", pval, mark)
  fname <- file.path(out.dir, sprintf("ep_pa_null_branch_%d.pdf", b))
  p <- ggplot(data = NULL, aes(x = null)) +
    geom_density(col = "gray") +
    geom_vline(xintercept = et, col = "black") +
    xlab("") + ylab("") +
    theme(text = element_text(size = 7),
          plot.margin = unit(c(1, 1.2, -2, -2), "mm"),
          panel.background = element_blank()) +
    annotate(x = et, y = v, label = label, geom = "text", size = 2, hjust = hjust)
  ggsave(fname, plot = p, width = 0.72, height = 0.72)
  message(sprintf("Written %s", fname))

  # Fraction of non-zero entries: random sets vs. the pathway enzymes
  fname <- file.path(out.dir, sprintf("prevalence_branch_%d.pdf", b))
  p <- ggplot(data = NULL, aes(x = null_prev)) +
    geom_density() +
    geom_vline(xintercept = prev) +
    xlab("Prevalence") + ylab("Density") +
    ggtitle(" ")
  ggsave(fname, plot = p, width = 2, height = 2)
  message(sprintf("Written %s", fname))
}