#!/usr/bin/env Rscript

# Free-text description of the question this script addresses.
hlp <- " How do metadata variables associate with differential species abundance. 
        Variables where more species change more significantly affect the microbiome. "

# Paths: input object, output directory and the image sub-directory.
# Both output directories are created if they do not exist yet.
in.data <- "./data/data.Robj"
out.dir <- file.path("./output/associations/")
img.dir <- file.path(out.dir, "img")
for (d in c(out.dir, img.dir)) {
  dir.create(d, recursive = TRUE, showWarnings = FALSE)
}

# Assay whose abundance data is used to compute associations
assay <- "MSPCore"

# Parameters
max.vals <- 10        # Maximum number of distinct values for a categorical variable
min.species <- 3      # Minimum species to display
min.samples <- 200    # Minimum samples with known variable
min.level <- 40       # Minimum samples per level
alpha <- 0.25         # q-value threshold below which an association counts

# Colors: one fill color per phylum, named by phylum.
# The order is the same as the Phylum factor levels used for the phylum
# summary plot; the names must match those levels exactly
# (the previous "Protebacteria" spelling did not).
COLORS <- c(
  Firmicutes      = "#52A675",
  Bacteroidetes   = "#1982C4",
  Actinobacteria  = "#FF595E",
  Proteobacteria  = "#6A4C93",
  Verrucomicrobia = "#FFCA3A",
  Euryarchaeota   = "#45355A",
  Lentisphaerae   = "#8AC926"
)

# Libs
require(ggplot2)
require(Maaslin2)
require(pheatmap)

# Load data object
# load() restores the saved objects into the calling (global) environment;
# the script then picks up the one named `obj`.
load(in.data)
obj <- .GlobalEnv[["obj"]]
df <- obj$metadata

# Abundance data and feature metadata of the selected assay
assay.obj <- obj$assays[[assay]]
Sx <- assay.obj$data
sm <- assay.obj$metadata

# Restrict abundance and sample metadata to the rows they have in common,
# in the same order
keep <- intersect(row.names(Sx), row.names(df))
Sx <- Sx[keep, ]
df <- df[keep, ]
message(sprintf("Keeping %d samples", nrow(df)))

# Bin Age and BMI into three quantile-based groups so that they can be
# handled as categorical variables.
#
# bin.terciles(x)
#   x:      numeric vector, may contain NA.
#   return: factor of the same length with (at most) three interval levels.
#
# Breaks are the 0 / 33 / 66 / 100 % quantiles rounded to whole numbers.
#  - Quantiles are computed with na.rm = TRUE; quantile() stops with an error
#    on missing values otherwise.
#  - The outer breaks are widened to floor(min) / ceiling(max) and the lowest
#    break is inclusive, so neither rounding nor the half-open intervals of
#    cut() can push the extreme observations out of range (they would turn
#    into NA and be mistaken for missing values).
#  - Genuinely missing values are assigned to the middle bin.
bin.terciles <- function(x) {
  breaks <- round(quantile(x, probs = c(0, 0.33, 0.66, 1), na.rm = TRUE))
  breaks[1] <- floor(min(x, na.rm = TRUE))
  breaks[length(breaks)] <- ceiling(max(x, na.rm = TRUE))
  binned <- cut(x, breaks = breaks, include.lowest = TRUE)
  binned[is.na(binned)] <- levels(binned)[2]
  binned
}

df$Age <- bin.terciles(df$Age)
df$BMI <- bin.terciles(df$BMI)

# Remove invalid variables
df[["(Shell)fish_allergy"]] <- NULL

# Select variables to test: every metadata column except "Cluster" with at
# least 2 and at most max.vals distinct values (unique() counts NA as a value)
variables <- setdiff(colnames(df), "Cluster")
unqs <- vapply(variables, function(v) length(unique(df[, v])), integer(1),
               USE.NAMES = FALSE)
keep <- unqs >= 2 & unqs <= max.vals
variables <- variables[keep]
unqs <- unqs[keep]

# Summary table with one row per selected variable: display name, number of
# distinct values, number of associated species and Maaslin2 output path
rf <- data.frame(
  variable = gsub("_", " ", variables),
  values = unqs,
  species = 0,
  path = "",
  row.names = variables,
  stringsAsFactors = FALSE
)
message(sprintf("Keeping %d variables", nrow(rf)))

# Detailed per-feature results, appended to while looping over variables
rfd <- data.frame()

# Heatmap matrix (variables x assay features), NA until a value is filled in
X <- matrix(NA, nrow = nrow(rf), ncol = ncol(Sx),
            dimnames = list(row.names(rf), colnames(Sx)))

# Test every selected variable on its own with Maaslin2
for (v in row.names(rf)) {

  # Treat the variable as categorical
  df[, v] <- as.factor(df[, v])

  # Usable samples: known value whose level occurs more than min.level times
  level.counts <- table(df[, v])
  frequent <- names(which(level.counts > min.level))
  usable <- (df[, v] %in% frequent) & !is.na(df[, v])
  n.usable <- sum(usable)
  rf[v, "samples"] <- n.usable

  # Variables without more than min.samples usable samples are not tested
  if (n.usable <= min.samples) next

  # Run Maaslin2 for the selected assay with v as the only fixed effect;
  # its output goes to a per-variable directory below img.dir
  maas.dir <- file.path(img.dir, v)
  maas <- Maaslin2(
    input_data = t(Sx[usable, ]),
    input_metadata = df[usable, ],
    output = maas.dir,
    transform = "LOG",
    fixed_effects = c(v),
    random_effects = c()
  )
  res <- maas$results

  # Missing q-values are set to 1, i.e. never below alpha
  res$qval[is.na(res$qval)] <- 1
  is.hit <- res$qval < alpha
  n.hits <- length(unique(res[is.hit, "feature"]))

  # Record the number of distinct associated features and the output path
  rf[v, "species"] <- n.hits
  rf[v, "path"] <- maas.dir
  if (n.hits == 0) next

  # Keep the lowest q-value row per feature, append it to the details and
  # fill the heatmap row with -log10(q)
  hits <- res[is.hit, ]
  hits <- hits[order(hits$qval), ]
  hits <- hits[!duplicated(hits$feature), ]
  hits$variable <- v
  rfd <- rbind(rfd, hits)
  X[v, as.character(hits$feature)] <- -log10(hits$qval)
}

# Joint figure: number of associated species per variable, restricted to
# variables with more than min.species associated species.
fname <- file.path(out.dir, "summary.pdf")
rff <- rf[rf$species > min.species, ]
# qplot() is deprecated in current ggplot2 releases; build the plot with
# ggplot() and hand it to ggsave() explicitly instead of relying on the
# implicit last_plot().
p.summary <- ggplot(rff, aes(x = reorder(variable, species), y = species)) +
  geom_col() +
  coord_flip() + ylab("Affected microbial species") + xlab("")
# Figure height grows with the number of displayed variables
ggsave(fname, plot = p.summary, height = 3 + 0.1 * nrow(rff), width = 5.5)
message(sprintf("Written %s", fname))

# Filter: keep heatmap rows (variables) and columns (features) that hold at
# least one value. drop = FALSE keeps X a matrix even when only a single row
# or column survives; without it X collapses to a vector and colSums(),
# nrow() and colnames() below no longer work.
kr <- rowSums(!is.na(X)) >= 1
X <- X[kr, , drop = FALSE]
kc <- colSums(!is.na(X)) >= 1
X <- X[, kc, drop = FALSE]
message(sprintf("Keeping %d rows and %d columns", nrow(X), ncol(X)))

# Annotations (row/column labels and a prevalence annotation).
# NOTE(review): labels_row, labels_col and ac are not used further in the
# visible part of the script - presumably meant for a pheatmap() call.
labels_row <- rf[row.names(X), "variable"]
labels_col <- colnames(X)
if (assay == "MSPCore") {
  # Shorten the feature names: remove the literal ".core" and "msp_"
  feature.meta <- obj$assays[[assay]]$metadata[colnames(X), ]
  labels_col <- gsub("msp_", "",
                     gsub(".core", "", feature.meta$name, fixed = TRUE))
}
# Prevalence: fraction of samples in which a feature has abundance > 0
prev <- colSums(Sx > 0) / nrow(Sx)
ac <- data.frame(row.names = colnames(X), Prevalence = prev[colnames(X)])

# Write the per-variable summary, variables with the most associated
# species first
fname <- file.path(out.dir, "summary.csv")
by.species <- rev(order(rf$species))
rf <- rf[by.species, ]
write.csv(rf, file = fname)
message(sprintf("Written %s", fname))

# Summarize significant associations per variable and phylum - run only once
# (rfd$variable is rewritten in place below, presumably the reason a second
# run in the same session would no longer match the row names of rf).
if (nrow(rfd) == 0) {
  # aggregate() stops with an error on zero rows; nothing to plot or write
  message("No significant associations found, skipping phylum summary")
} else {
  phyla <- c("Firmicutes", "Bacteroidetes", "Actinobacteria",
             "Proteobacteria", "Verrucomicrobia", "Euryarchaeota",
             "Lentisphaerae")

  # as.character(): if feature is a factor, indexing sm with it would use the
  # integer codes instead of the feature names
  rfd$Phylum <- factor(sm[as.character(rfd$feature), "phylum"], levels = phyla)

  # Species count per variable is looked up in rf (all variables); the display
  # threshold is applied further down
  rfd$species <- rf[rfd$variable, "species"]
  rfd$variable <- gsub("_", " ", rfd$variable)

  # rank = number of detail rows per variable, used to order the bars
  agg <- aggregate(rfd$species, by = list(variable = rfd$variable), length)
  row.names(agg) <- agg$variable
  rfd$rank <- agg[as.character(rfd$variable), "x"]

  # Filter for plot
  # NOTE(review): the Gender filter is applied to rfd itself, so Gender rows
  # are also missing from summary_details.csv written below - confirm that
  # this is intended and not meant for the plot only.
  rfd <- rfd[rfd$metadata != "Gender", ]
  rfdf <- rfd[!is.na(rfd$species), ]
  rfdf <- rfdf[rfdf$species > min.species, ]

  # Plot summary
  # A single manual fill scale with drop = FALSE: all phylum levels are kept,
  # so the unnamed color vector is matched to the levels by position and each
  # phylum keeps its color even when some phyla are absent from the data.
  # (Previously scale_fill_discrete(drop=FALSE) was replaced by the following
  # scale_fill_manual(), which silently discarded drop = FALSE.)
  fname <- file.path(out.dir, "summary_phylum.pdf")
  p.phylum <- ggplot(rfdf, aes(x = reorder(variable, rank), fill = Phylum)) +
    geom_bar() +
    coord_flip() + ylab("Affected MSPs") + xlab("") +
    scale_fill_manual(values = as.vector(COLORS), drop = FALSE) +
    theme(text = element_text(size = 7)) +
    theme(legend.position = "none")
  ggsave(fname, plot = p.phylum, height = 2.4, width = 2.4)
  message(sprintf("Written %s", fname))

  # Per-feature details of all retained associations
  fname <- file.path(out.dir, "summary_details.csv")
  write.csv(rfd, fname)
  message(sprintf("Written %s", fname))
}
