#!/usr/bin/env Rscript

hlp = " Compute the effect of species on overall cytokine response for each specie on cytokines
        multiple times. Is there a systematic trend? "

# Read pipeline arguments.
# The first positional argument selects the analysis mode. It is validated
# before any output directory is created so that a missing or misspelled mode
# neither crashes later on `NA == ...` nor leaves stray directories behind.
args = commandArgs(trailingOnly = TRUE)
valid.modes = c("trained", "specific")
if (length(args) < 1 || !(args[1] %in% valid.modes)) {
  message(hlp)
  message(sprintf("Unknown or missing mode; expected one of: %s",
                  paste(valid.modes, collapse = ", ")))
  quit(save = "no", status = 1)
}
mode = args[1]

# Paths
in.data = "./data/data.Robj"

# Outputs go to <out.dir>/ (tables, summary plots) and <out.dir>/img/
# (per-feature plots).
out.dir = file.path("./output/response/scfas/", mode)
dir.create(out.dir, recursive = TRUE, showWarnings = FALSE)
img.dir = file.path(out.dir, "img")
dir.create(img.dir, recursive = TRUE, showWarnings = FALSE)


# Analysis parameters ----------------------------------------------------------
assay = "Metabolites"  # entry of obj$assays whose features are tested
rank.data = FALSE      # TRUE: replace feature values by 0-based (min-tie) ranks
log.data = TRUE        # TRUE: transform feature values with log10(1 + x)
alpha = 0.2            # FDR cut-off for calling a feature significant
min.prev = 0.2         # features need prevalence > min.prev to be FDR-tested

# Volcano plot palette, one colour per effect sign
NEGATIVE = "#51BAFF"
NEUTRAL  = "#FFFFFF"
POSITIVE = "#EE6845"

# Mode-specific parameters.
# `template` is the covariate part of every per-feature model (the feature is
# appended later); only the listed cytokines and stimuli are analysed.
mode.configs = list(
  trained = list(template = "Value ~ Age + Gender",
                 cytokines = c("IL1b", "IL6", "TNF"),
                 stimuli = c("S. aureus")),
  specific = list(template = "Value ~ Age + Gender",
                  cytokines = c("IFNg"),
                  stimuli = c("M. tuberculosis"))
)
# %in% is NA-safe, so a missing mode also ends up in the error branch.
if (length(mode) != 1 || !(mode %in% names(mode.configs))) {
  message(sprintf("Unknown mode: %s", mode))
  # quit() only accepts save = "no"/"yes"/"ask"/"default"; "n" is an error.
  quit(save = "no", status = 1)
}
template = mode.configs[[mode]]$template
cytokines = mode.configs[[mode]]$cytokines
stimuli = mode.configs[[mode]]$stimuli

# Libs
require(pheatmap)
require(reshape2)
require(ggplot2)
require(ggpubr)
require(ggrepel)

# Load data object into a private environment: load() into the global
# environment would let any object saved in the .Robj file overwrite variables
# of this script (e.g. `mode`, `alpha`, `data`).
data.env = new.env()
load(in.data, envir = data.env)
if (!exists("obj", envir = data.env, inherits = FALSE)) {
  stop(sprintf("%s does not contain an object named 'obj'", in.data))
}
obj = get("obj", envir = data.env, inherits = FALSE)

# Keep morning vaccinations only. %in% treats a missing Vacc_time as "not
# morning"; `==` would turn such rows into all-NA rows.
meta = obj$metadata
meta = meta[meta$Vacc_time %in% "Morning", ]
message(sprintf("Keeping %d morning vaccinations", nrow(meta)))

# Combine responses ------------------------------------------------------------
# Responses at the two follow-up timepoints: rows are samples, and the
# CytokineFold2 / CytokineFold3 columns are placed side by side. Their
# per-response metadata tables are stacked and tagged with Time = 2 / 3.
C2 = obj$assays$CytokineFold2$data
C3 = obj$assays$CytokineFold3$data
cm2 = obj$assays$CytokineFold2$metadata
cm3 = obj$assays$CytokineFold3$metadata
cm2[["Time"]] = 2
cm3[["Time"]] = 3
cm = rbind(cm2, cm3)

# Standardise every response column (centre and scale across samples)
C = as.matrix(scale(as.matrix(cbind(C2, C3))))

# Load assay -------------------------------------------------------------------
# A: feature values, indexed below as A[sample, feature].
# am: per-feature annotation, indexed below as am[feature, ].
A = obj$assays[[assay]]$data
am = obj$assays[[assay]]$metadata
# Feature names are later used as formula terms, so spaces become "_". The same
# renaming is applied to the annotation row names; otherwise the lookup
# am[rf$Feature, ] returns all-NA rows for every feature containing a space.
colnames(A) = gsub(" ", "_", colnames(A))
if (!is.null(am)) {
  row.names(am) = gsub(" ", "_", row.names(am))
}
if (rank.data) {
  message(sprintf("Using ranked values"))
  # Rank each feature (column) across samples; ties share the lowest rank and
  # ranks start at 0.
  A = apply(A, 2, rank, ties.method = "min") - 1
}
if (log.data) {
  message(sprintf("Using log(1+X) scale"))
  A = log10(1 + A)
}

# Long format ------------------------------------------------------------------
# One row per (sample, response) pair with a non-missing value.
data = melt(C, na.rm = TRUE)
names(data) = c("Sample", "Response", "Value")
data[["Sample"]] = as.character(data[["Sample"]])
data[["Response"]] = as.character(data[["Response"]])

# Keep samples present in the long table, the morning metadata, the response
# matrix and the assay at the same time.
keep = Reduce(intersect, list(data$Sample, row.names(meta), row.names(C), row.names(A)))
data = data[data$Sample %in% keep, ]
message(sprintf("Keeping %d samples", length(keep)))

# Merge data frame: attach response metadata (by response), assay values and
# sample metadata (by sample) to every long-format row.
data[, colnames(cm)] = cm[data$Response, ]
data[, colnames(A)] = A[data$Sample, ]
data[, colnames(meta)] = meta[data$Sample, ]
stopifnot(length(unique(data$Vacc_time)) == 1)
# Map Time codes to labels explicitly. Relabelling levels() by position would
# call the only present timepoint "2 weeks" if the other one were missing.
data$Timepoint = factor(data$Time, levels = c(2, 3),
                        labels = c("2 weeks", "3 months"))

# Keep only the cytokines and stimuli selected for this mode
selected = data$Cytokine %in% cytokines & data$Stimulus %in% stimuli
data = data[selected, ]
report = c("Keeping cytokines", paste(unique(data$Cytokine), collapse = " "),
           "Keeping stimuli", paste(unique(data$Stimulus), collapse = " "))
invisible(lapply(report, message))

# Create results data frames ---------------------------------------------------
# rf: one row per assay feature with placeholder statistics (Pvalue = Padj = 1)
# to be filled in per feature; the annotation columns of am are appended,
# looked up by feature name. ff: container for per-cytokine correlations.
rf = data.frame(
  Feature = colnames(A),
  N = 0, Prev = 0,
  Estimate = 0, Std = 0, Tvalue = 0,
  Pvalue = 1, Padj = 1,
  Path = "",
  stringsAsFactors = FALSE
)
rf[, colnames(am)] = am[rf$Feature, ]
ff = data.frame()

# Add label: the feature name; for MSPCore "<feature> / <taxon>", where the
# taxon is the species, falling back to genus and then phylum when missing.
if (assay != "MSPCore") {
  rf$Label = rf$Feature
} else {
  taxon = rf$species
  for (fallback in list(rf$genus, rf$phylum)) {
    unknown = is.na(taxon)
    taxon[unknown] = fallback[unknown]
  }
  rf$Label = sprintf("%s / %s", rf$Feature, taxon)
}

# Test each feature ------------------------------------------------------------
# For every feature f in rf:
#   * fit  <template> + f  with lm() on `data` and store the coefficient of f
#     (Estimate, Std, Tvalue, Pvalue) plus its non-zero count / prevalence in A;
#   * per cytokine x stimulus pair, correlate f with the residuals of the
#     covariate-only model (collected in `ff`);
#   * for the rows listed in plot.features, draw a response-vs-feature plot.

# Rows of rf for which a scatter plot is written to img.dir.
plot.features = c(2, 7)

# Back-quote a feature name so it is a valid formula term even when it is not a
# syntactic R name (e.g. contains "-" or starts with a digit).
quote_term = function(f) sprintf("`%s`", f)

# Covariate-only model. It does not involve the feature, so fit it once.
# na.exclude pads the residuals with NA for rows lm() drops because of missing
# covariates, keeping them aligned with the rows of `data`.
frm.null = as.formula(template)
model.null = lm(frm.null, data, na.action = na.exclude)
data$Residual = residuals(model.null)

# Figure size in inches; the "trained" mode has three cytokine facets.
plot.width = if (mode == "trained") 2.5 else 1.3
plot.height = 2.2

ff.rows = list()
for (i in seq_len(nrow(rf))) {
  f = rf$Feature[i]
  frm = as.formula(sprintf("%s + %s", template, quote_term(f)))

  # Fit model. lm() names the coefficient f or `f` (back-quoted, for
  # non-syntactic names); summary() omits it when aliased (e.g. constant f).
  model = lm(frm, data)
  coef.tab = summary(model)$coefficients
  colnames(coef.tab) = c("Estimate", "Std", "Tvalue", "Pvalue")
  coef.name = intersect(c(f, quote_term(f)), rownames(coef.tab))
  if (length(coef.name) > 0) {
    rf[i, colnames(coef.tab)] = coef.tab[coef.name[1], ]
  } else {
    warning(sprintf("No coefficient estimated for %s; keeping default results", f))
  }
  rf[i, "N"] = sum(A[, f] > 0)
  rf[i, "Prev"] = round(mean(A[, f] > 0), 2)

  # Correlate the feature with the covariate-corrected response per pair
  for (cy in unique(data$Cytokine)) {
    for (st in unique(data$Stimulus)) {
      da = data[data$Cytokine == cy & data$Stimulus == st, ]
      # cor.test() stops with fewer than 3 complete (x, y) pairs
      if (sum(complete.cases(da[, f], da$Residual)) < 3) next
      ct = cor.test(da[, f], da$Residual)
      ff.rows[[length(ff.rows) + 1]] = data.frame(
        Cytokine = cy, Stimulus = st, Feature = f,
        Correlation = ct$estimate, Pvalue = ct$p.value)
    }
  }

  # Plot a figure
  if (i %in% plot.features) {
    fname = file.path(img.dir, sprintf("effects_%s.pdf", f))
    rf[i, "Path"] = fname
    # Title with the annotation name when the assay metadata provides one
    title.name = if ("Top.annotation.name" %in% colnames(rf)) {
      rf[i, "Top.annotation.name"]
    } else {
      rf[i, "Label"]
    }
    p = ggplot(data = data, aes(x = .data[[f]], y = Value)) +
      geom_point(show.legend = TRUE, col = alpha("#777777", 0.3), size = .5) +
      facet_grid(Timepoint ~ Cytokine) +
      xlab("Log10 Intensity") +
      ylab("Fold change vs. baseline") +
      theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust = .5)) +
      theme(legend.title = element_blank(),
            text = element_text(size = 7)) +
      stat_cor(aes(label = ..r.label..), method = "pearson", size = 2) +
      geom_smooth(method = "lm", col = "black") +
      ggtitle(sprintf("%s P=%.4f", title.name, rf[i, "Pvalue"]))
    ggsave(fname, plot = p, width = plot.width, height = plot.height)
    message(sprintf("Written %s", fname))
  }
}
ff = if (length(ff.rows) > 0) do.call(rbind, ff.rows) else data.frame()

# Correct p-values -------------------------------------------------------------
# Benjamini-Hochberg FDR over the features whose prevalence exceeds min.prev;
# all other features keep the Padj they already have.
tested = !is.na(rf$Prev) & rf$Prev > min.prev
rf$Padj[tested] = p.adjust(rf$Pvalue[tested], method = "fdr")
rf = rf[order(rf$Pvalue), ]

# A feature passes when it was tested and its FDR is below alpha. The same rule
# applies to every mode; a missing FDR never passes.
pass = !is.na(rf$Padj) & !is.na(rf$Prev) & rf$Padj < alpha & rf$Prev > min.prev
rf$Keep = pass
keep = rf$Feature[pass]

# Name feature effects: sign of the estimate for passing features, else Neutral
rf$Sign = factor(rep("Neutral", nrow(rf)),
                 levels = c("Negative", "Neutral", "Positive"))
rf$Sign[pass & rf$Estimate > 0] = "Positive"
rf$Sign[pass & rf$Estimate < 0] = "Negative"

# Filter
rff = rf[rf$Keep, ]
message(sprintf("Keeping %d features", nrow(rff)))

# Draw a volcano ---------------------------------------------------------------
# Label the n.top most negative and n.top most positive passing features,
# ranked by sign(Estimate) * -log(FDR). head() avoids the NA rows that
# `[1:n, ]` produces when fewer than n features pass.
n.top = 3
score = sign(rff$Estimate) * -log(rff$Padj)
rn = head(rff[order(score), ], n.top)
rp = head(rff[rev(order(score)), ], n.top)

# Explicit levels so the labels stay attached to the right logical value even
# when all (or no) features exceed the prevalence cut-off.
rf$Prevalence = factor(rf$Prev > min.prev, levels = c(FALSE, TRUE),
                       labels = c("", sprintf(">%.0f %%", 100 * min.prev)))
rf$LabelTmp = ifelse(rf$Feature %in% c(rp$Feature, rn$Feature), rf$Label, "")

# Colours are matched to Sign by name, so an absent Sign level cannot shift
# the palette; drop = FALSE keeps both prevalence shapes in fixed order.
sign.colors = c(Negative = NEGATIVE, Neutral = NEUTRAL, Positive = POSITIVE)
fname = file.path(out.dir, "results_volcano.pdf")
p = ggplot(rf, aes(x = Estimate, y = -log10(Padj), col = Sign, shape = Prevalence)) +
  geom_point() +
  geom_hline(yintercept = -log10(alpha), linetype = "dashed", col = "gray") +
  ylab("-Log10 (FDR)") + xlab("Effect size") +
  geom_text_repel(aes(label = LabelTmp), show.legend = FALSE, force = 5, size = 2) +
  scale_shape_manual(values = c(1, 16), drop = FALSE) +
  # NOTE: ylim() removes (with a warning) points with -log10(FDR) > 4.2
  ylim(c(-0.2, 4.2)) +
  scale_fill_manual(values = sign.colors) +
  scale_color_manual(values = sign.colors) +
  theme(legend.margin = margin(1, 1, 1, 1)) +
  theme(text = element_text(size = 7))
ggsave(fname, plot = p, width = 2.8, height = 2.2)
message(sprintf("Written %s", fname))
rf$LabelTmp = NULL

# Write results: the per-feature table with every feature of rf
fname = file.path(out.dir, "results.csv")
write.csv(x = rf, file = fname)
message("Written ", fname)

# Histogram of the unadjusted per-feature P-values
fname = file.path(out.dir, "pvalues.pdf")
p.hist = ggplot(rf, aes(x = Pvalue)) +
  geom_histogram() +
  labs(x = "P-value", y = "Count") +
  theme(text = element_text(size = 7),
        panel.background = element_blank(),
        axis.line = element_line(colour = "black", size = .1))
ggsave(fname, plot = p.hist, width = 2., height = 1.1)
message(sprintf("Written %s", fname))
