#!/usr/bin/env Rscript

hlp <- " Compute pairwise correlations between assays and cytokines. "


# Input / output locations ----
in.data <- "./data/data.Robj"
out.dir <- file.path("./output/metabolomics/correlations/")
fig.dir <- file.path(out.dir, "img")

# Make sure both output directories exist (no warning if already present)
invisible(lapply(c(out.dir, fig.dir),
                 dir.create, recursive = TRUE, showWarnings = FALSE))


# Analysis parameters ----
assay1 <- "Metabolites"  # Name of the first assay to load
assay2 <- "MSPCore"      # Name of the second assay to load
padj <- 0.05             # Threshold applied to adjusted p-values

# Filtering parameters ----
min.samples <- 15   # Features must be non-zero in more than this many samples
max.samples <- 320  # NOTE(review): not referenced in the visible script
min.species <- 2    # NOTE(review): not referenced in the visible script

# Libs
# source("~/Dev/bcg/experiments/fg300tanzania/scripts/utils.R")
require(pheatmap)
require(reshape2)
require(ggplot2)
require(ggpubr)
require(ubiomeTools)



# Load data object
# The .Robj file is expected to define a variable named `obj`; it is read back
# from the global environment after loading. Fail early with a clear message
# instead of carrying a NULL object into the analysis.
if(!file.exists(in.data)){
  stop(sprintf("Input data file not found: %s", in.data), call. = FALSE)
}
load(in.data)
obj <- .GlobalEnv$obj
if(is.null(obj)){
  stop(sprintf("No object named 'obj' available after loading %s", in.data),
       call. = FALSE)
}
meta <- obj$metadata

# Placeholders filled in by the loading loop below.
# Assay matrices have samples in rows and features in columns; the features of
# the first assay become heatmap rows, those of the second become columns.
R <- NULL     # Row assay
C <- NULL     # Column assay
ar <- NULL    # Row annotation (per-feature metadata of the row assay)
ac <- NULL    # Column annotation (per-feature metadata of the column assay)
lr <- NULL    # Row labels
lc <- NULL    # Column labels
lr.cols <- NULL # Annotation columns shown for rows
lc.cols <- NULL # Annotation columns shown for columns

# Keep only features that are non-zero in more than `min.n` samples.
#   M     : samples x features matrix
#   pf    : per-feature metadata (one row per column of M), or NULL
#   min.n : prevalence threshold (strictly greater than)
# Returns list(M = filtered matrix, pf = filtered metadata with Num.Samples).
# `drop = FALSE` keeps the 2-d shape even if a single feature survives.
filter.prevalent <- function(M, pf, min.n){
  if(is.null(pf)){
    pf <- data.frame(row.names = colnames(M), Feature = colnames(M))
    pf$Num.Species <- 1
  }
  pf$Num.Samples <- colSums(M != 0, na.rm = TRUE)
  keep <- pf$Num.Samples > min.n
  list(M = M[, keep, drop = FALSE], pf = pf[keep, , drop = FALSE])
}

# Load and filter
for(assay in c(assay1, assay2)){
  M <- obj$assays[[assay]]$data
  pf <- obj$assays[[assay]]$metadata

  if(startsWith(assay, "Cytokine")){
    if(assay == "Cytokines"){
      # Only features whose metadata has Time == 1; NA counts as "not 1"
      keep <- !is.na(pf$Time) & pf$Time == 1
      M <- M[, keep, drop = FALSE]
      pf <- pf[keep, , drop = FALSE]
    }
    lab <- pf$Cytokine
    anc <- c("Associations", "Stimulus", "Serum")
  } else if(startsWith(assay, "Metabolites")){
    # lab = pf$name
    # lab = unlist(lapply(lab, function(s) substr(s, 1, max(length(s), 15))))
    # lab = sprintf("%s: %s", pf$class, lab)
    lab <- sprintf("%s: %s", pf$Top.HMDB, pf$name)
    anc <- c("Associations", "super_class", "average_molecular_weight")
  } else {
    # MSPCore and any other assay: drop rarely observed features
    flt <- filter.prevalent(M, pf, min.samples)
    M <- flt$M
    pf <- flt$pf
    if(startsWith(assay, "MSPCore")){
      # NOTE(review): pf$name is NULL when the metadata had to be synthesised
      lab <- pf$name
    } else {
      # Feature names truncated to 25 characters
      lab <- substr(row.names(pf), 1, 25)
    }
    anc <- c("Associations")
  }

  # Set rows and columns: first pass of assay1 fills the rows, anything else
  # (including assay2 == assay1) fills the columns
  if(assay == assay1 && is.null(R)){
    R <- as.matrix(M)
    ar <- data.frame(pf)
    lr <- lab
    lr.cols <- anc
    message(sprintf("Keeping %d rows", ncol(M)))
  } else {
    C <- as.matrix(M)
    ac <- data.frame(pf)
    lc <- lab
    lc.cols <- anc
    message(sprintf("Keeping %d columns", ncol(C)))
  }
}

# Merge
# Restrict both assays to the samples they share, in identical order.
keep <- intersect(row.names(R), row.names(C))
message(sprintf("Keeping %d samples in both assays", length(keep)))
if(length(keep) < 2){
  stop("Fewer than two samples are shared between the assays; cannot correlate.",
       call. = FALSE)
}
# drop = FALSE: stay a matrix even when an assay has a single feature
C <- C[keep, , drop = FALSE]
R <- R[keep, , drop = FALSE]

# Shuffle data
# ("Shuffling rows ...")
# R[sample(row.names(R)),] = R[,] + 0

# Construct matrices (features of R in rows, features of C in columns)
#   S : Spearman correlation estimate
#   P : raw p-value
#   A : adjusted p-value
dn <- list(colnames(R), colnames(C))
S <- matrix(0, ncol = ncol(C), nrow = ncol(R), dimnames = dn)
P <- matrix(1, ncol = ncol(C), nrow = ncol(R), dimnames = dn)
A <- matrix(1, ncol = ncol(C), nrow = ncol(R), dimnames = dn)

# Compute Spearman correlation for every (row feature, column feature) pair,
# using only samples where both values are present
g <- expand.grid(row.names(S), colnames(S), stringsAsFactors = FALSE)
for(i in seq_len(nrow(g))){
  fi <- as.character(g[i, 1])
  ci <- as.character(g[i, 2])
  x <- R[, fi]
  y <- C[, ci]
  inxs <- !is.na(x) & !is.na(y)
  ct <- cor.test(x[inxs], y[inxs], method = "spearman")
  S[fi, ci] <- ct$estimate
  P[fi, ci] <- ct$p.value
}

# Adjust over all tested pairs at once; "holm" is the p.adjust() default,
# spelled out here. Vector assignment fills A column-wise, matching P's layout.
A[, ] <- p.adjust(as.vector(P), method = "holm")

# Augment annotations with the number of significant partners per feature.
# na.rm = TRUE: pairs whose test yielded NA count as "not significant"
# instead of turning the whole count into NA.
ar$Associations <- rowSums(A < padj, na.rm = TRUE)
ac$Associations <- colSums(A < padj, na.rm = TRUE)

# Complete data frame: one line per (row feature, column feature) pair with
# estimate, raw and adjusted p-value, followed by the metadata of both features.
# melt() is applied to equally shaped matrices, so the value columns line up.
df <- melt(S)
colnames(df) <- c("Row", "Column", "spearman")
df$pvalue <- melt(P)$value
df$padj <- melt(A)$value
df$Row <- as.character(df$Row)
df$Column <- as.character(df$Column)
df <- cbind(df, ar[df$Row, , drop = FALSE])
df <- cbind(df, ac[df$Column, , drop = FALSE])

# Flag significant pairs; which() skips NA adjusted p-values, which are not
# allowed as subscripts in a data frame assignment
df$mark <- rep("", nrow(df))
df[which(df$padj < padj), "mark"] <- "*"

# Keep features with at least one significant association (NA counts as none)
kra <- !is.na(ar$Associations) & ar$Associations >= 1
kca <- !is.na(ac$Associations) & ac$Associations >= 1
na  <- sum(kra)
message(sprintf("Keeping %d rows and %d columns", sum(kra), sum(kca)))

# If enough associations found (at least two rows, so rows can be clustered)
if(na > 1){

  # Restrict matrices and annotations to the associated features;
  # drop = FALSE keeps matrices 2-d when only one column is selected
  Sa <- S[kra, kca, drop = FALSE]
  Aa <- A[kra, kca, drop = FALSE]
  if(!is.null(ar)) ar <- ar[kra, , drop = FALSE]
  if(!is.null(ac)) ac <- ac[kca, , drop = FALSE]

  # Cell overlay: "o" on significant pairs (kept separate from the numeric P)
  marks <- matrix("", ncol = ncol(Sa), nrow = nrow(Sa))
  marks[which(Aa < padj)] <- "o"

  # Store heatmap; the page grows with the number of rows and columns
  fname <- file.path(out.dir, "heatmap_assoc.pdf")
  pdf(fname, width = 12.5 + ncol(Sa) * .1, height = 6.5 + nrow(Sa) * .1)
  tryCatch({
    pheatmap(Sa,
             na_col = "white",
             annotation_row = ar[, lr.cols, drop = FALSE],
             annotation_col = ac[, lc.cols, drop = FALSE],
             show_rownames = TRUE,
             show_colnames = TRUE,
             cluster_rows = TRUE,
             # Clustering needs at least two objects
             cluster_cols = ncol(Sa) > 1,
             labels_col = lc[kca],
             labels_row = lr[kra],
             display_numbers = marks)
  }, finally = dev.off())  # Always close the PDF device, even on error
  message(sprintf("Written %s", fname))
}

# Write results: all tested pairs, most significant first (NA p-values last)
df <- df[order(df$pvalue), ]
fname <- file.path(out.dir, "results.csv")
write.csv(df, fname)
message(sprintf("Written %s", fname))

# Write results - filtered
# Same strict threshold as the "mark" column and the association counts.
# which() drops NA adjusted p-values instead of emitting all-NA rows.
dff <- df[which(df$padj < padj), , drop = FALSE]
# rep() rather than a bare "" so the assignment also works with zero rows
dff$path <- rep("", nrow(dff))

# Save
fname <- file.path(out.dir, "results_filtered.csv")
write.csv(dff, fname)
message(sprintf("Written %d rows to %s", nrow(dff), fname))