#!/usr/bin/env Rscript

# Help text describing what this script does.
# NOTE(review): hlp is not referenced elsewhere in this file.
hlp = " Compare presence of genes to predict individual phenotypes.
        Use pathway information to narrow down the list further.
      "

# Libs
require(ggplot2)
require(stringr)
require(pheatmap)
require(ggpubr)

# Paths. Directories carry no trailing slash so that file.path() joins them
# without producing "//" in the derived paths.
in.data = "./data/data.Robj"
in.kgml = "./output/pathways/kgml"
out.dir = "./output/pathways/phenylalanine"
img.dir = file.path(out.dir, "img")
# recursive = TRUE also creates out.dir, where the outputs below are written
dir.create(img.dir, recursive = TRUE, showWarnings = FALSE)

# Parameters
alpha = 0.05              # p-value threshold for keeping a correlation
genus = "Roseburia"       # genus whose MSPs are analysed; "All" keeps every MSP
target = "map00360"       # map id: names the KGML graph file and the metabolite pathway column
pathway = "ko00360"       # NOTE(review): not referenced elsewhere in this file
assay = "CytokineFold2"   # phenotype assay taken from obj$assays
cytokines = c("TNF", "IL1b")
stimuli = c("S. aureus")

# Read pathway graph exported for the target map (columns used below include
# EnzymeID, Enzyme, SubstrateID, Substrate, ProductID and Product).
in.ko = file.path(in.kgml, sprintf("%s.graph.csv", target))
kf = read.csv(in.ko, stringsAsFactors = FALSE)

# Load data object. load() restores objects into the current environment and
# returns their names, so `obj` is available directly afterwards; fail early if
# the file does not contain it instead of silently picking up a NULL or stale
# `obj` from the global environment.
loaded = load(in.data)
if(!("obj" %in% loaded)){
  stop(sprintf("%s does not contain an object named 'obj'", in.data), call. = FALSE)
}


# Compound mapping: lookup table of every compound that appears as a substrate
# or product in the pathway graph, one row per CompoundID (first occurrence
# wins) and row names equal to the id.
cps = local({
  all.ids = c(as.character(kf$SubstrateID), as.character(kf$ProductID))
  all.names = c(as.character(kf$Substrate), as.character(kf$Product))
  first.seen = !duplicated(all.ids)
  data.frame(CompoundID = all.ids[first.seen],
             Compound = all.names[first.seen],
             stringsAsFactors = F)
})
row.names(cps) = cps$CompoundID

# Load MSPs and select those belonging to the configured genus
# ("All" keeps every MSP).
# NOTE(review): S is not referenced elsewhere in this file.
S = obj$assays$MSPCore$data
sm = obj$assays$MSPCore$metadata
if(genus == "All"){
  msps = row.names(sm)
} else {
  # which() skips MSPs whose genus annotation is NA
  msps = row.names(sm)[which(sm$genus == genus)]
}
message(sprintf("Found %d MSPs for %s", length(msps), genus))

# Load genes and restrict them to the MSPs of the selected genus/species:
# gmp maps gene features to MSPs, gm holds the gene metadata.
G = obj$assays$MSPGenes$data
gm = obj$assays$MSPGenes$metadata
gmp = obj$assays$MSPGenes$mapping
# msps are MSPCore row names ending in ".core"; build the same key for the mapping
gmp$MSPCore = sprintf("%s.core", gmp$MSP)
gmp = gmp[is.element(gmp$MSPCore, msps), ]
# gmp is already restricted to the selected MSPs, so its features are the genus' genes
genes = unique(gmp$Feature)
gm = gm[is.element(gm$Feature, genes), ]
message(sprintf("Found %d genes for %s", length(genes), genus))

# Extract all genes on the pathway: for every enzyme (KO id) of the pathway
# graph, collect the gene features (row names of gm) whose KEGG_KOs annotation
# contains that id. KO ids are matched literally (fixed = TRUE), not as regexes.
enzyme.ids = unique(kf$EnzymeID)
enzyme.ids = enzyme.ids[!is.na(enzyme.ids) & enzyme.ids != ""]
dk = do.call(rbind, lapply(enzyme.ids, function(k){
  feats = row.names(gm)[grep(k, gm$KEGG_KOs, fixed = TRUE)]
  if(length(feats) == 0) return(NULL)
  # Enzyme name of the first graph row annotated with this KO
  name = kf$Enzyme[match(k, kf$EnzymeID)]
  data.frame(EnzymeID = k, Feature = feats, Enzyme = name, stringsAsFactors = FALSE)
}))
if(is.null(dk) || nrow(dk) == 0){
  stop(sprintf("No %s genes map to enzymes of %s", genus, target), call. = FALSE)
}
message(sprintf("Found %d/%d features for %d enzymes",
                nrow(dk), ncol(G), length(unique(dk$EnzymeID))))

# Sum gene abundances per enzyme: Ge is samples x enzymes with columns sorted
# by KO id. drop = FALSE keeps G a matrix when a single feature is on the pathway.
G = G[, dk$Feature, drop = FALSE]
Ge = t(rowsum(t(G), group = dk$EnzymeID))

# Filter gene metadata to the features mapped to pathway enzymes
# (gm rows are gene features, not enzyme ids).
gm = gm[unique(dk$Feature), , drop = FALSE]
message(sprintf("Found %d/%d enzymes for %s", ncol(Ge), length(unique(dk$EnzymeID)), genus))



# Load phenotypes: keep the assay columns for the selected cytokines and
# stimuli, then drop samples (rows) whose row sum is missing.
C = obj$assays[[assay]]$data
cm = obj$assays[[assay]]$metadata
cm = cm[cm$Cytokine %in% cytokines & cm$Stimulus %in% stimuli, , drop = FALSE]
if(nrow(cm) == 0){
  stop(sprintf("No %s columns match the selected cytokines and stimuli", assay), call. = FALSE)
}
# drop = FALSE keeps C two-dimensional when only one column is selected,
# so rowSums() below still works
C = C[, row.names(cm), drop = FALSE]
keep = !is.na(rowSums(C))
C = C[keep, , drop = FALSE]

# Filter metabolites by pathway: keep compounds with a positive value in the
# target map's column of the pathway table, then scale each compound column.
M = obj$assays$Metabolites$data
mm = obj$assays$Metabolites$metadata
P = obj$assays$Metabolites$pathway
# row.names() works for both a matrix and a data.frame P; names(P[, target])
# is NULL for a data.frame column, which would silently select nothing.
mbx = row.names(P)[which(P[, target] > 0)]
M = M[, mbx, drop = FALSE]
mm = mm[mbx, , drop = FALSE]
M = scale(M)
message(sprintf("Keeping %d compounds for %s", length(mbx), target))

# Filter and order samples: keep samples present in the gene, phenotype and
# metabolite tables so all three share the same rows.
# drop = FALSE keeps every table two-dimensional even with a single column.
keep = Reduce(intersect, list(row.names(G), row.names(C), row.names(M)))
C = C[keep, , drop = FALSE]
Ge = Ge[keep, , drop = FALSE]
M = M[keep, , drop = FALSE]
message(sprintf("Keeping %d samples", length(keep)))

# Reorder samples by increasing total phenotype value
ord = row.names(C)[order(rowSums(C))]
C = C[ord, , drop = FALSE]
M = M[ord, , drop = FALSE]
Ge = Ge[ord, , drop = FALSE]

# Estimate pathway activity: per sample, pa sums the scaled pathway metabolites
# and ep sums the log1p enzyme abundances; plot pa against ep with a Spearman
# correlation and a linear fit.
pa = rowSums(M)
ep = rowSums(log1p(Ge))
fname = file.path(out.dir, "ep_pa_box.pdf")
p = ggplot(data.frame(ep = ep, pa = pa), aes(x = ep, y = pa)) +
  geom_jitter(alpha = .5) +
  xlab("Enzyme abundance (log)") + ylab("Phenylalanine metabolism activity") +
  stat_cor(method = "spearman") + geom_smooth(method = "lm", formula = y ~ x) +
  ggtitle(genus)
# Pass the plot explicitly rather than relying on last_plot()
ggsave(fname, plot = p, width = 3, height = 3)
message(sprintf("Written %s", fname))


# Rename compounds to show: override the top annotation name of selected
# metabolite ids. Only ids already present in mm are touched; assigning to a
# missing row name would add a new, otherwise-empty row to mm.
compound.names = c(mb_0535 = "Phenylacetylglutamine",
                   mb_0120 = "Phenylpropanoate",
                   mb_0157 = "Phenylpyruvate")
for(mb.id in intersect(names(compound.names), row.names(mm))){
  mm[mb.id, "Top.annotation.name"] = compound.names[[mb.id]]
}


# Plot a combined figure of enzyme presence and pathway activity.
# First correlate every enzyme with every pathway metabolite across samples
# (Spearman): df has one row per enzyme/metabolite pair.
df = expand.grid(EnzymeID=colnames(Ge), Mb=colnames(M), stringsAsFactors = FALSE)
# Preallocate result columns; rep() also works when df has zero rows
df$Spearman = rep(NA_real_, nrow(df))
df$Pvalue = rep(NA_real_, nrow(df))
# seq_len() avoids the 1:0 iteration that 1:nrow(df) produces for an empty grid
for(i in seq_len(nrow(df))){
  e = df$EnzymeID[i]
  m = df$Mb[i]
  ct = cor.test(Ge[, e], M[, m], method = "spearman")
  df$Spearman[i] = unname(ct$estimate)
  df$Pvalue[i] = ct$p.value
}

# Rename enzymes: label every pair as "<enzyme name> <KO id>" and fix the
# y-axis order with the curated list below (applied in reverse as factor levels).
enzyme.order = c(
  "hisC K00817", "mhpE K01666", "paaH K00074",
  "yhdR K11358", "paaF K01692", "hipO K01451",
  "enr K10797", "paaI K02614", "HPD K00457",
  "paaK K01912", "aaaT K03825", "padE K18357",
  "aspB K00812")
# Name recorded on the first dk row of each enzyme
first.hit = match(df$EnzymeID, dk$EnzymeID)
df$Enzyme = dk$Enzyme[first.hit]
df$Label = factor(sprintf("%s %s", df$Enzyme, df$EnzymeID),
                  levels = rev(enzyme.order))

# Order of molecules: curated x-axis order, applied in reverse as factor levels.
molecule.levels = c(
  "164.072 L-Phenylalanine",
  "180.066 L-Tyrosine",
  "163.040 Phenylpyruvate",
  "87.009 Pyruvate",
  
  "115.004 Fumarate",
  "117.019 Succinate",
  "119.050 Phenylacetaldehyde",
  "165.056 Phenyllactate",
  "206.082 N-Acetyl-L-phenylalanine",
  
  "151.040 2-Hydroxyphenylacetate",
  "121.066 Phenylethyl alcohol",
  "134.061 2-Phenylacetamide",
  "178.051 Hippurate",
  "263.104 Phenylacetylglutamine",
  "149.061 Phenylpropanoate"
)
# Label each metabolite as "<ionMz> <compound>" using the first of its
# annotation ids that is also a compound of the pathway graph; metabolites
# without such an id get "". Computed once per metabolite, since df repeats
# each metabolite for every enzyme.
mb.labels = vapply(unique(df$Mb), function(m){
  ids = gsub(" ", "", unlist(strsplit(mm[m, ]$Top.annotation.ids, ";")))
  matches = intersect(ids, row.names(cps))
  if(length(matches) == 0) return("")
  sprintf("%.3f %s", mm[m, ]$ionMz, cps[matches[1], "Compound"])
}, character(1))
df$Rename = unname(mb.labels[df$Mb])
# Literal replacement so the Hippurate label matches its entry in molecule.levels
df$Rename = gsub("192.067 Hippurate", "178.051 Hippurate", df$Rename, fixed = TRUE)
df$Molecule = factor(df$Rename, levels = rev(molecule.levels))


# Filter and plot: keep enzyme/metabolite pairs whose Spearman p-value is below
# alpha. which() discards pairs with an NA p-value (cor.test returns NA e.g.
# when a column is constant); plain logical indexing would turn them into
# all-NA rows in the plot and the CSV.
dff = df[which(df$Pvalue < alpha), ]

# Tile plot of the kept correlations: molecules on x, enzyme labels on y
# (all label levels kept), fill = Spearman rho. legend.position is passed to
# theme(legend.position = ...).
plot.correlations = function(d, legend.position){
  ggplot(data = d) + geom_tile(mapping = aes(x = Molecule, y = Label, fill = Spearman)) +
    theme(axis.text.x = element_text(angle = 90, vjust = .5, hjust = 1)) + xlab("") + ylab("") +
    # scale_fill_gradient(low = "blue", high = "firebrick", na.value = NA) +
    theme(text = element_text(size = 7)) +
    theme(legend.margin = ggplot2::margin(0, 0, 0, 0)) +
    theme(legend.position = legend.position) +
    scale_y_discrete(drop = FALSE)
}

fname = file.path(out.dir, sprintf("correlations_cps_%s.pdf", genus))
ggsave(fname, plot = plot.correlations(dff, "none"), width = 2., height = 2.7)
message(sprintf("Written %s", fname))
fname = file.path(out.dir, sprintf("correlations_cps_%s_legend.pdf", genus))
ggsave(fname, plot = plot.correlations(dff, "right"), width = 2.6, height = 2.8)
message(sprintf("Written %s", fname))
fname = file.path(out.dir, sprintf("correlations_cps_%s.csv", genus))
write.csv(x=dff, file = fname, row.names = FALSE)
message(sprintf("Written %s", fname))



