############################################################################
# Description:
# This script iterates through a directory of ADAGE model weight matrices and
# determines the pathways significant in each model after crosstalk is removed.
#
# Usage:
#      Rscript pathway_coverage_no_crosstalk.R -n netfolder -o outfolder
#       -k keggfile -d datafile [-r replace] [-c cores] [-s std]
#
#      netfolder: a directory containing ADAGE models
#      outfolder: a directory to store the output files
#      keggfile: file path to 'pseudomonas_KEGG_terms.txt'
#      datafile: the training compendium
#      replace (-r): the resulting file will keep the naming convention
#         of the input file, save for this substring (replaced by
#         'SigPathways')
#        [default: 'network']
#      cores (-c): number of cores used to run crosstalk removal on models
#        in parallel.
#        [default: (# available cores) - 1]
#      std (-s): number of standard deviations from the mean a gene weight
#        must be to be considered high weight (HW)
#        [default: 2.5]
#
# Output:
# .txt files specifying the significant pathways found in the model.
# Each row will contain the following columns:
#   model size, node #, side (pos/neg), pathway, p-value, q-value
############################################################################

library(parallel)
library(optparse)

##############################
# GLOBAL VARIABLES
##############################

# Command-line interface. The four path options (-n, -o, -k, -d) have no
# default and are required; -r, -c and -s fall back to the defaults below.
option_list <- list(
    make_option(c("-n", "--netfolder"), type = "character",
                help = "path to directory containing ADAGE models"),
    make_option(c("-o", "--outfolder"), type = "character",
                help = "path to the directory that will store output files"),
    make_option(c("-k", "--keggfile"), type = "character",
                help = paste("path to the KEGG pathway definitions",
                             "'pseudomonas_KEGG_terms.txt'")),
    make_option(c("-d", "--datafile"), type = "character",
                help = "path to the training compendium"),
    make_option(c("-r", "--replace"), type = "character", default = "network",
                help = paste("substring to replace in the input file when",
                             "generating the output file name")),
    make_option(c("-c", "--cores"), type = "integer",
                # detectCores() returns 1 on a single-core machine and NA when
                # the count is unknown; never let the default drop below 1.
                default = max(1L, detectCores() - 1L, na.rm = TRUE),
                help = paste("run crosstalk removal on models in parallel;",
                             "defaults to 1 less than the number of available",
                             "cores [default %default]")),
    make_option(c("-s", "--std"), type = "double", default = 2.5,
                help = paste("number of standard deviations from the mean a",
                             "gene weight must be to be considered high weight",
                             "(HW) [default %default]"))
    )

opt <- parse_args(OptionParser(option_list = option_list))

# Fail early with a readable message instead of a cryptic error when the
# first file is read with a NULL path.
required_opts <- c("netfolder", "outfolder", "keggfile", "datafile")
missing_opts <- required_opts[
  vapply(required_opts, function(o) is.null(opt[[o]]), logical(1))]
if (length(missing_opts) > 0) {
  stop("missing required option(s): ",
       paste0("--", missing_opts, collapse = ", "), call. = FALSE)
}
if (is.na(opt$cores) || opt$cores < 1) {
  stop("--cores must be a positive integer", call. = FALSE)
}

netfolder <- opt$netfolder
KEGG_file <- opt$keggfile
data_file <- opt$datafile
outfolder <- opt$outfolder
replace_substring <- opt$replace
cores_N <- opt$cores
std_HW <- opt$std


# read in the gene IDs from data file: only the first column is kept,
# every other column is skipped by giving it colClasses "NULL"
col_n <- count.fields(data_file, sep = "\t")[1]
gene_ids <- read.table(data_file, sep = "\t", header = TRUE,
  colClasses = c("character", rep("NULL", col_n - 1)))
gene_n <- nrow(gene_ids)

# KEGG pathway definitions: the first file column becomes the row names,
# and KEGG[, 2] holds each pathway's gene IDs separated by ";"
KEGG <- read.table(KEGG_file, sep = "\t", header = FALSE,
  row.names = 1, stringsAsFactors = FALSE)
krows <- nrow(KEGG)
KEGG_genes <- unique(unlist(strsplit(KEGG[, 2], ";")))

# `pattern` is a regular expression, not a glob: escape the dot and anchor
# the end so that only names ending in "_network_ADAGE.txt" are selected
files <- list.files(path = netfolder, pattern = "_network_ADAGE\\.txt$",
  full.names = FALSE, recursive = FALSE)

##############################
# HELPER FUNCTIONS
##############################

UpdateStep <- function(X, pi_old) {
  # Updates probability vector for each iteration of EM in MaxImpactEst.
  # All sizes are taken from X itself, so the function does not depend on
  # any global variable.
  #
  # Args:
  #   X: observed gene-to-pathway membership matrix (n-by-k)
  #   pi_old: current vector of probabilities. an element at index j
  #           corresponds to the probability that, given a gene g_i, g_i gives
  #           the greatest fraction of its impact to a pathway j. (length k)
  # Returns:
  #   pi_new: updated probability vector (length k, unnamed, sums to one)

  # x_wts_vec[i]: weighted presence/absence of gene i over all k pathways,
  # where the weight is the current probability vector.
  # NOTE: a row of X with no membership at all gives a weight of 0 and
  # therefore NaN below (0/0).
  x_wts_vec <- as.vector(X %*% pi_old)

  # impact_share[i, j] = X[i, j] * pi_old[j] / x_wts_vec[i]
  # sweep() scales column j by pi_old[j]; dividing the matrix by a length-n
  # vector recycles it down the rows, so row i is divided by x_wts_vec[i].
  impact_share <- sweep(X, 2, pi_old, `*`) / x_wts_vec

  # summing over genes gives the new expected value for each pathway
  new_numerator <- unname(colSums(impact_share))

  # divide by the sum of all k elements in the vector
  # so that the values in pi_new (the probabilities) add up to one.
  pi_new <- new_numerator / sum(new_numerator)
  return(pi_new)
}


MaxImpactEst <- function(gene_set, max_iter = 10000L) {
  # Determines the underlying pathway impact matrix: maps each gene
  # to the pathway it has the most impact in.
  #
  # Reads the globals KEGG (pathway table, genes in column 2 separated by
  # ";") and KEGG_genes (all unique KEGG gene IDs), and calls UpdateStep.
  #
  # Args:
  #   gene_set: vector of genes
  #   max_iter: safety cap on the number of EM iterations; a warning is
  #             raised if it is reached before convergence
  # Returns:
  #   out_map: a named list keyed by pathway (sorted by name). Only pathways
  #   that received at least one gene from gene_set appear. Each value is the
  #   pathway's full KEGG gene list minus the genes of gene_set that the EM
  #   algorithm assigned to a different pathway.

  n_kegg_genes <- length(KEGG_genes)
  n_pathways <- nrow(KEGG)
  pathway_genes <- strsplit(KEGG[, 2], ";")

  # observed membership matrix (n by k): X[i, j] is 1 if gene i is annotated
  # to pathway j, 0 otherwise. matrix() keeps the n-by-k shape even when
  # there is a single gene.
  # NOTE: X and the EM estimate below depend only on KEGG, not on gene_set,
  # so they are recomputed identically on every call.
  X <- matrix(
    vapply(pathway_genes, function(pw_genes) {
      as.numeric(KEGG_genes %in% pw_genes)
    }, numeric(n_kegg_genes)),
    nrow = n_kegg_genes, ncol = n_pathways,
    dimnames = list(KEGG_genes, rownames(KEGG)))

  # initialization step: weight vector of length k
  pi_0 <- colSums(X) / sum(X)
  pi_1 <- UpdateStep(X, pi_0)

  # convergence threshold: 1/100 of the size of the first update
  epsilon <- sqrt(sum((pi_1 - pi_0)^2)) / 100

  # If the first update did not move at all (epsilon == 0), pi_1 is already
  # a fixed point. Iterating with the test `step >= 0` would never stop, so
  # the loop is skipped in that case.
  pi_final <- pi_1
  if (epsilon > 0) {
    converged <- FALSE
    for (iter in seq_len(max_iter)) {
      pi_new <- UpdateStep(X, pi_final)
      step_size <- sqrt(sum((pi_new - pi_final)^2))
      pi_final <- pi_new
      if (step_size < epsilon) {
        converged <- TRUE
        break
      }
    }
    if (!converged) {
      warning("MaxImpactEst: EM did not converge within ", max_iter,
              " iterations", call. = FALSE)
    }
  }

  # assign every KEGG gene present in gene_set to the single pathway with
  # the largest conditional probability (first one wins on ties)
  set_genes <- KEGG_genes[KEGG_genes %in% gene_set]
  assigned_pw <- vapply(set_genes, function(gene) {
    X_i <- X[gene, ]
    denominator <- sum(X_i * pi_final)
    c_i <- (X_i * pi_final) / denominator  # conditional probabilities
    idx <- which.max(c_i)
    if (length(idx) == 0) NA_character_ else colnames(X)[idx]
  }, character(1))
  assigned_pw <- assigned_pw[!is.na(assigned_pw)]

  keys <- sort(unique(assigned_pw))
  out_map <- lapply(keys, function(pw) {
    # genes of gene_set that were mapped to this pathway
    pw_genes <- names(assigned_pw)[assigned_pw == pw]
    # genes of gene_set that were mapped elsewhere are removed from the
    # pathway definition; everything else in the definition is kept
    de_diff <- setdiff(gene_set, pw_genes)
    all_pw_genes <- pathway_genes[[match(pw, rownames(KEGG))]]
    setdiff(all_pw_genes, de_diff)
  })
  names(out_map) <- keys

  return(out_map)
}


PathwayEnrichment <- function(EM_map, HW_genes, node, side, total_genes) {
  # Runs a one-sided Fisher's exact test for every pathway in EM_map against
  # one set of high-weight genes.
  #
  # Args:
  #   EM_map: named list of pathway gene vectors (as built by MaxImpactEst)
  #   HW_genes: high-weight genes on one side of the node
  #   node: node index, used to label the rows
  #   side: label for the side ("pos" or "neg")
  #   total_genes: number of genes in the background
  # Returns:
  #   data frame with columns Node, Side, KEGG_terms, pvalue (one row per
  #   pathway)

  pvalues <- vapply(EM_map, function(term_genes) {
    # build contingency table
    high_in <- length(intersect(HW_genes, term_genes))
    low_in <- length(term_genes) - high_in
    high_out <- length(HW_genes) - high_in
    low_out <- total_genes - high_in - high_out - low_in
    contingency_table <- matrix(c(high_in, low_in, high_out, low_out),
      nrow = 2)
    fisher.test(contingency_table, alternative = "greater")$p.value
  }, numeric(1))

  n_terms <- length(EM_map)
  data.frame(
    Node = rep(paste("Node", node), n_terms),
    Side = rep(side, n_terms),
    KEGG_terms = as.character(names(EM_map)),
    pvalue = unname(pvalues),
    stringsAsFactors = FALSE)
}


ModelProcess <- function(netfile) {
  # Given an ADAGE model weight matrix, determine the pathways significant
  # in each of the nodes within the model. Write this information to a file.
  #
  # Reads the globals netfolder, outfolder, replace_substring, gene_ids,
  # gene_n, std_HW and KEGG_genes, and calls MaxImpactEst.
  #
  # Args:
  #   netfile: filename for an ADAGE model weight matrix
  # Returns:
  #   the printed coverage message (invisibly, as returned by print)

  cat(netfile, "\n")
  # file.path() inserts the separator, so netfolder does not need to end
  # with a slash
  path_to_netfile <- file.path(netfolder, netfile)
  weight <- read.table(path_to_netfile, header = FALSE, skip = 2,
    sep = "\t", stringsAsFactors = FALSE, fill = TRUE)
  # the first gene_n rows are kept: one row per gene, one column per node
  weight_matrix <- data.matrix(weight[seq_len(gene_n), , drop = FALSE])
  rownames(weight_matrix) <- gene_ids[, 1]
  netsize <- ncol(weight_matrix)

  node_tables <- vector("list", netsize)

  for (node in seq_len(netsize)) {
    # get high-weight genes in a node
    gene_wts <- weight_matrix[, node]
    wt_mean <- mean(gene_wts)
    wt_spread <- std_HW * sd(gene_wts)
    HW_genes_pos <- rownames(weight_matrix)[gene_wts >= wt_mean + wt_spread]
    HW_genes_neg <- rownames(weight_matrix)[gene_wts <= wt_mean - wt_spread]

    gene_set <- intersect(KEGG_genes, c(HW_genes_pos, HW_genes_neg))
    if (length(gene_set) > 0) {
      # crosstalk removal: pathway definitions adjusted for this gene set
      EM_map <- MaxImpactEst(gene_set)

      # KEGG enrichment analysis, separately for each side of the node
      this_table <- rbind(
        PathwayEnrichment(EM_map, HW_genes_pos, node, "pos", gene_n),
        PathwayEnrichment(EM_map, HW_genes_neg, node, "neg", gene_n))

      # FDR correction is applied per node, across both sides
      this_table$padjust <- p.adjust(this_table$pvalue, method = "fdr")
      node_tables[[node]] <- this_table
    }
  }

  # nodes without any high-weight KEGG gene contribute no rows
  node_tables <- Filter(Negate(is.null), node_tables)
  if (length(node_tables) > 0) {
    pvalue_table <- do.call(rbind, node_tables)
  } else {
    pvalue_table <- data.frame(
      Node = character(0), Side = character(0), KEGG_terms = character(0),
      pvalue = numeric(0), padjust = numeric(0), stringsAsFactors = FALSE)
  }

  # only keep significant pathways
  significant <- pvalue_table[which(pvalue_table$padjust <= 0.05), ]
  sig_pathways <- data.frame(
    netsize = rep(netsize, nrow(significant)),
    node = significant$Node,
    side = significant$Side,
    pathway = as.character(significant$KEGG_terms),
    pvalue = significant$pvalue,
    qvalue = significant$padjust,
    stringsAsFactors = FALSE)

  outfile <- gsub(replace_substring, "SigPathways", netfile)
  outfile_path <- file.path(outfolder, outfile)
  write.table(sig_pathways, outfile_path,
    row.names = FALSE, col.names = TRUE, quote = FALSE, sep = "\t")
  coverage <- length(unique(sig_pathways$pathway))

  print(paste("This ADAGE model covers", coverage, "unique KEGG pathways."))
}

# Process every model file, running up to cores_N of them in parallel.
if (length(files) == 0) {
  warning("no '*_network_ADAGE.txt' model files found in ", netfolder,
          call. = FALSE)
}
model_results <- mclapply(files, ModelProcess, mc.cores = cores_N)

# mclapply() stores an error raised in a forked worker as a "try-error"
# element of the result instead of stopping the script, so report the models
# that failed rather than dropping them silently.
failed <- vapply(model_results, inherits, logical(1), what = "try-error")
if (any(failed)) {
  warning("ModelProcess failed for: ",
          paste(files[failed], collapse = ", "), call. = FALSE)
}
