################################################################################
# This script extracts high-weight (HW) genes from each node of an ADAGE model.
# It saves the HW gene names of all nodes into one combined file and also
# writes each node's HW genes (as PA IDs, positive and negative sides
# separately) into their own files.
#
# Usage:
#     Rscript write_HWG.R networkFile sd_cutoff expressionFile HWFile HWG_folder
#
#     networkFile: file path to the network file of an ADAGE model
#     sd_cutoff: number of standard deviations from a node's mean weight that
#                a gene's weight must reach to be called a HW gene
#     expressionFile: file path to the training expression compendium
#     HWFile: path of the output file that contains the HW genes of every node
#     HWG_folder: folder in which to write the HW gene list of each node
#
################################################################################

source("../data_collection/IDconverter.R")

############ load in arguments

# Parse the command line once and fail fast with a usage message instead of
# letting NA paths or an NA cutoff propagate into read.table() and the output.
cli_args <- commandArgs(trailingOnly = TRUE)
if (length(cli_args) < 5) {
  stop("Usage: Rscript write_HWG.R networkFile sd_cutoff expressionFile ",
       "HWFile HWG_folder", call. = FALSE)
}

networkFile <- cli_args[1]
# A non-numeric cutoff would make every comparison NA; reject it explicitly.
sd_cutoff <- suppressWarnings(as.numeric(cli_args[2]))
if (is.na(sd_cutoff)) {
  stop("sd_cutoff must be a number, got: ", cli_args[2], call. = FALSE)
}
expressionFile <- cli_args[3]
HWFile <- cli_args[4]
HWG_folder <- cli_args[5]

############ load in data

# The expression compendium is only used here for its gene IDs (row names)
# and its gene count. NOTE(review): this assumes the first rows of the network
# file follow the same gene order as the compendium -- confirm.
express <- read.table(expressionFile, header = TRUE, row.names = 1, sep = "\t")
geneID <- rownames(express)
# convert PA numbers to gene names
# NOTE(review): IDFile is not defined in this script; presumably it is
# provided by the sourced IDconverter.R -- confirm.
geneName <- IDconverter(IDFile, geneID)
genesize <- nrow(express)

# Skip the first two lines of the network file; fill = TRUE tolerates rows
# with fewer fields. The first genesize rows are taken as the gene weights.
weight <- read.table(networkFile, header = FALSE, skip = 2, sep = "\t",
                     stringsAsFactors = FALSE, fill = TRUE)
# Indexing past the last row would silently produce NA weights, so refuse a
# network file that has fewer rows than there are genes.
if (nrow(weight) < genesize) {
  stop("Network file has ", nrow(weight), " rows after the first two lines ",
       "but the expression file has ", genesize, " genes.", call. = FALSE)
}
weight <- data.matrix(weight[seq_len(genesize), , drop = FALSE])
rownames(weight) <- geneName
netsize <- ncol(weight)  # one column per node
# Do not warn when the folder already exists; create missing parents too.
dir.create(HWG_folder, showWarnings = FALSE, recursive = TRUE)

############ extract high-weight genes

# For every node, genes whose weight is at least sd_cutoff standard deviations
# above (positive side) or below (negative side) the node's mean weight are
# high-weight genes. Each side's PA IDs are written to its own file in
# HWG_folder, and its gene names are collected in HWGs_all, two entries per
# node in the order Node1 pos, Node1 neg, Node2 pos, Node2 neg, ...
HWGs_all <- vector("list", 2 * netsize)
for (node in seq_len(netsize)) {
  node_weight <- weight[, node]
  node_mean <- mean(node_weight)
  node_sd <- sd(node_weight)

  # positive side: order HW genes from high to low weight
  pos_cutoff <- node_mean + sd_cutoff * node_sd
  pos_order <- order(node_weight, decreasing = TRUE)
  pos_keep <- node_weight[pos_order] >= pos_cutoff
  HWG_pos <- geneName[pos_order][pos_keep]
  HWG_pos_PAID <- geneID[pos_order][pos_keep]
  write.table(HWG_pos_PAID,
              file.path(HWG_folder, paste0("Node", node, "pos.txt")),
              quote = FALSE, row.names = FALSE, col.names = FALSE)
  HWGs_all[[2 * node - 1]] <- HWG_pos

  # negative side: order HW genes from low to high weight. A separate
  # ascending order() (not a reversed descending one) keeps ties in their
  # original gene order.
  neg_cutoff <- node_mean - sd_cutoff * node_sd
  neg_order <- order(node_weight, decreasing = FALSE)
  neg_keep <- node_weight[neg_order] <= neg_cutoff
  HWG_neg <- geneName[neg_order][neg_keep]
  HWG_neg_PAID <- geneID[neg_order][neg_keep]
  write.table(HWG_neg_PAID,
              file.path(HWG_folder, paste0("Node", node, "neg.txt")),
              quote = FALSE, row.names = FALSE, col.names = FALSE)
  HWGs_all[[2 * node]] <- HWG_neg
}

# Combine all nodes' HW gene names into one table with one column per node
# side ("Node 1 Pos", "Node 1 Neg", ...). Indexing each vector with
# seq_len(max_len) pads it with NA past its end; write.table emits those NA
# cells as empty strings. Building the columns with lapply() (rather than a
# simplifying mapply()) keeps one column per node side even when the longest
# gene list has 0 or 1 genes.
max_len <- max(c(0L, lengths(HWGs_all)))  # 0L guards against no nodes
HWG_columns <- lapply(HWGs_all, function(genes) genes[seq_len(max_len)])
names(HWG_columns) <- c(rbind(paste("Node", seq_len(netsize), "Pos"),
                              paste("Node", seq_len(netsize), "Neg")))
HWGs_all <- data.frame(HWG_columns, check.names = FALSE,
                       stringsAsFactors = FALSE)
write.table(HWGs_all, HWFile, sep = "\t", quote = FALSE, col.names = TRUE,
            row.names = FALSE, na = "")