################################################################################
# This script plots gene expression heatmaps of a node's HW genes for the input
# dataset.
#
# Usage:
#     Rscript make_HWG_expression_heatmaps.R networkFile dataFile dataName
#     out_folder selected_samples HW_cutoff use_annotation [annotation_file]
#
#     networkFile: file path to the network file of an ADAGE/eADAGE model
#     dataFile: file path to the training expression compendium
#     dataName: name of the input data, will be appended at the end of each
#               output plot
#     out_folder: output folder for HWG expression heatmaps
#     selected_samples: if not "all", then only plot the selected samples using
#                       the provided sample names separated by ","
#     HW_cutoff: number of standard deviations from the mean to be counted as
#                high-weight
#     use_annotation: logical, whether to use an annotation file to help label
#                     samples in the heatmap
#     annotation_file: (optional) file path to an annotation file that stores
#                      the medium type of each sample, used to label each sample
#                      with its medium type in heatmaps
################################################################################

pacman::p_load("gplots")
source("../data_collection/IDconverter.R")

# Positional command-line arguments:
#   1 networkFile, 2 dataFile, 3 dataName, 4 out_folder, 5 selected_samples,
#   6 HW_cutoff, 7 use_annotation, 8 annotation_file (only read when
#   use_annotation is TRUE)
comArgs <- commandArgs(trailingOnly = TRUE)
if (length(comArgs) < 7) {
  stop("Expected at least 7 arguments: networkFile dataFile dataName ",
       "out_folder selected_samples HW_cutoff use_annotation ",
       "[annotation_file]", call. = FALSE)
}
networkFile <- comArgs[1]
dataFile <- comArgs[2]
dataName <- comArgs[3]
out_folder <- comArgs[4]
selected_samples <- comArgs[5]

# fail early with a readable message instead of silently carrying an NA
# cutoff into every high-weight comparison
HW_cutoff <- suppressWarnings(as.numeric(comArgs[6]))
if (is.na(HW_cutoff)) {
  stop("HW_cutoff must be a number, got '", comArgs[6], "'", call. = FALSE)
}

# as.logical() returns NA for anything it cannot interpret, and `if (NA)`
# is an error, so reject such values explicitly
use_annotation <- as.logical(comArgs[7])
if (is.na(use_annotation)) {
  stop("use_annotation must be TRUE or FALSE, got '", comArgs[7], "'",
       call. = FALSE)
}
if (use_annotation) {
  annotation_file <- comArgs[8]
  if (is.na(annotation_file)) {
    stop("use_annotation is TRUE but no annotation_file was provided",
         call. = FALSE)
  }
}

# recursive = TRUE also creates missing parent folders; an output folder
# that already exists is not a problem, so the warning is suppressed
dir.create(out_folder, showWarnings = FALSE, recursive = TRUE)
if (!dir.exists(out_folder)) {
  stop("Could not create output folder '", out_folder, "'", call. = FALSE)
}

########## Load in data

# expression table: the first column provides the row names (gene IDs) and
# the column names are kept exactly as written in the file
data <- read.table(dataFile, header = TRUE, row.names = 1, sep = "\t",
                   check.names = FALSE)
gene_num <- nrow(data)
# convert PA numbers to gene names
# NOTE(review): IDFile is not defined in this script; presumably it is
# provided by the sourced IDconverter.R -- confirm
rownames(data) <- IDconverter(IDFile, rownames(data))

# The network file is read as text with ragged lines padded (fill = TRUE);
# presumably the lines after the weight matrix are not purely numeric --
# TODO confirm. Only the first gene_num rows are used as the weight matrix.
weight <- read.table(networkFile, header = FALSE, skip = 2, fill = TRUE,
                     sep = "\t", colClasses = "character")
weight <- weight[seq_len(gene_num), , drop = FALSE]
# Convert the text to real numbers. data.matrix() must not be used here: it
# turns character columns into factors and returns their integer level codes
# instead of the numeric values, which would corrupt every weight.
weight <- matrix(as.numeric(as.matrix(weight)), nrow = gene_num,
                 ncol = ncol(weight),
                 dimnames = list(NULL, colnames(weight)))
net_size <- ncol(weight)

# preprocess the data to only contain the required samples; selected_samples
# is either "all" or a comma-separated list of sample names (column names of
# the expression table) or of 1-based column indices
if (selected_samples != "all") {
  requested <- trimws(unlist(strsplit(selected_samples, ",")))
  if (length(requested) == 0) {
    stop("selected_samples is empty; use \"all\" or a comma-separated list",
         call. = FALSE)
  }
  # a list made only of integers is treated as column indices (the behavior
  # this script always had); anything else is treated as sample names
  as_index <- suppressWarnings(as.integer(requested))
  if (!anyNA(as_index)) {
    selected_samples <- as_index
  } else {
    unknown <- setdiff(requested, colnames(data))
    if (length(unknown) > 0) {
      stop("Selected samples not found in the data: ",
           paste(unknown, collapse = ", "), call. = FALSE)
    }
    selected_samples <- requested
  }
  # drop = FALSE keeps a data frame even when a single sample is selected,
  # so that ncol() below and later row subsetting keep working
  data <- data[, selected_samples, drop = FALSE]
}
sample_count <- ncol(data)

# label every sample with its medium type: "<sample name>_<medium>"
if (use_annotation) {
  annotations <- read.table(file = annotation_file, header = TRUE,
                            sep = "\t", quote = "", fill = TRUE,
                            stringsAsFactors = FALSE)

  # keep only the first annotation row of each CEL file
  first_occurrence <- !duplicated(annotations$cel_file)
  annotations_noDup <- annotations[first_occurrence, , drop = FALSE]

  # look up the medium of each sample; a sample that is absent from the
  # annotation file gets NA as its medium
  sample_names <- colnames(data)
  annotation_row <- match(sample_names, annotations_noDup$cel_file)
  medium_type <- annotations_noDup$medium[annotation_row]

  medium_celfile <- paste(sample_names, medium_type, sep = "_")
  colnames(data) <- medium_celfile
}

############ Plot heatmaps of HW genes for each node side

# the gene expression heatmap color panel
mycol <- colorpanel(n = 50, low = "green", mid = "black", high = "red")

# heatmap.2 needs a matrix with at least two rows and two columns, so a
# single-sample selection can never be plotted; say so up front
if (NCOL(data) < 2) {
  stop("At least two samples are needed to draw the heatmaps", call. = FALSE)
}

# Z-score the expression of the given high-weight genes and save the
# clustered heatmap as a pdf.
#
#   hw_expression: expression table of the high-weight genes (genes in rows,
#                  samples in columns)
#   output_file:   path of the pdf to write
#   palette:       colors passed to heatmap.2
#
# Returns (invisibly) TRUE when a heatmap was written and FALSE when it was
# skipped because there are not more than one high-weight genes.
plot_HWG_heatmap <- function(hw_expression, output_file, palette) {
  # only plot the heatmap if there are more than one HW genes
  if (nrow(hw_expression) <= 1) {
    return(invisible(FALSE))
  }

  # z-score the expression values for each gene across samples
  zscores <- t(apply(hw_expression, 1, function(x) (x - mean(x)) / sd(x)))
  # can generate NA if standard deviation is 0, convert NA to 0
  zscores[is.na(zscores)] <- 0

  # the page grows with the number of samples, which are drawn as rows
  pdf(output_file, width = 30, height = ncol(hw_expression) / 10 + 5)
  # close the device even if heatmap.2 fails, so no device is leaked
  on.exit(dev.off(), add = TRUE)
  heatmap.2(t(zscores), col = palette, trace = "none", margins = c(5, 60),
            cexRow = 0.6, cexCol = 1, symbreaks = TRUE, symkey = TRUE)
  invisible(TRUE)
}

for (i in seq_len(net_size)) {
  node_weight <- weight[seq_len(gene_num), i]

  # a gene is high-weight on a side when its weight lies at least HW_cutoff
  # standard deviations away from the node's mean weight on that side
  weight_mean <- mean(node_weight)
  weight_sd <- sd(node_weight)
  pos_HW_cutoff <- weight_mean + HW_cutoff * weight_sd
  neg_HW_cutoff <- weight_mean - HW_cutoff * weight_sd

  # which() drops comparisons that are NA, so genes with a missing weight
  # are never selected (a bare logical index would add all-NA rows)

  # positive side
  high_weight_gene_pos <- data[which(node_weight >= pos_HW_cutoff), ,
                               drop = FALSE]
  plot_HWG_heatmap(high_weight_gene_pos,
                   file.path(out_folder,
                             paste0("Node", i, "pos_", dataName, ".pdf")),
                   mycol)

  # negative side
  high_weight_gene_neg <- data[which(node_weight <= neg_HW_cutoff), ,
                               drop = FALSE]
  plot_HWG_heatmap(high_weight_gene_neg,
                   file.path(out_folder,
                             paste0("Node", i, "neg_", dataName, ".pdf")),
                   mycol)
}