#!/usr/bin/env Rscript

hlp <- "Plot figure with clusters and associations."

# Paths ----
# Input data object and output locations. out.dir is built from components
# (no trailing separator) so derived paths, including the image paths recorded
# in summary.csv, do not contain "//".
in.data <- "./data/data.Robj"
out.dir <- file.path(".", "output", "associations")
dir.create(out.dir, recursive = TRUE, showWarnings = FALSE)
img.dir <- file.path(out.dir, "img")
dir.create(img.dir, recursive = TRUE, showWarnings = FALSE)

# Parameters ----
# Only variables with between 2 and max.vals distinct values are tested.
max.vals <- 10
# Plots are written for variables whose raw chi-squared p-value is <= alpha.
alpha <- 0.1

# Libs ----
# library() fails immediately if ggplot2 is missing, instead of require()'s
# warning followed by a confusing error at the first ggplot() call.
library(ggplot2)

# Load data object ----
# Load into a private environment and check that it provides `obj`; otherwise a
# missing object would only surface later as errors on a NULL `df`.
loaded <- new.env()
load(in.data, envir = loaded)
if (!exists("obj", envir = loaded, inherits = FALSE)) {
  stop(sprintf("'%s' does not contain an object named 'obj'", in.data),
       call. = FALSE)
}
obj <- get("obj", envir = loaded, inherits = FALSE)
# Per-sample metadata; the analysis below tests every column against Cluster.
df <- obj$metadata
if (!is.data.frame(df) || !"Cluster" %in% colnames(df)) {
  stop("obj$metadata must be a data frame with a 'Cluster' column",
       call. = FALSE)
}

# Hardcode age and BMI as categorical variables ----

# Bin a numeric vector at (approximately) its tertiles and impute missing
# values with the middle bin.
#
# x:     numeric vector with at least one non-missing value.
# probs: quantile probabilities used as bin edges (default: tertiles).
# Returns a factor with no NAs; level labels use cut()'s interval notation.
bin_by_quantiles <- function(x, probs = c(0, 0.33, 0.66, 1)) {
  if (all(is.na(x))) {
    stop("cannot bin a vector with no non-missing values", call. = FALSE)
  }
  edges <- quantile(x, probs = probs, na.rm = TRUE, names = FALSE)
  n <- length(edges)
  # Inner edges are rounded for readable labels; the outer edges are rounded
  # outwards so every observed value lands in a bin. With plain round() and
  # cut()'s open lower bound, the minimum (and anything outside the rounded
  # range) became NA and was then wrongly imputed to the middle bin.
  # unique() avoids cut()'s "breaks are not unique" error after rounding.
  breaks <- unique(c(floor(edges[1]), round(edges[-c(1, n)]),
                     ceiling(edges[n])))
  if (length(breaks) < 2) {
    stop("cannot bin a constant vector", call. = FALSE)
  }
  binned <- cut(x, breaks = breaks, include.lowest = TRUE)
  # Only genuinely missing values are NA now; put them in the middle bin
  # (2nd of 3 for tertiles).
  binned[is.na(binned)] <- levels(binned)[ceiling(nlevels(binned) / 2)]
  binned
}

df$Age <- bin_by_quantiles(df$Age)
df$BMI <- bin_by_quantiles(df$BMI)

# Select variables to test ----
# Candidates are all metadata columns except Cluster whose number of distinct
# values (unique() counts NA as one value) lies in [2, max.vals].
variables <- setdiff(colnames(df), "Cluster")
unqs <- vapply(variables, function(v) length(unique(df[[v]])), integer(1),
               USE.NAMES = FALSE)
keep <- unqs <= max.vals & unqs >= 2
variables <- variables[keep]
unqs <- unqs[keep]
n_vars <- length(variables)
# pvalue/FDR start as NA (not 1): variables the test loop skips are then left
# out of the FDR correction instead of inflating it as p = 1. Columns get
# explicit lengths so zero selected variables yields an empty table rather
# than a data.frame() row-count error.
rf <- data.frame(row.names = variables, values = unqs,
                 pvalue = rep(NA_real_, n_vars), FDR = rep(NA_real_, n_vars),
                 pathMDS = rep("", n_vars), pathDist = rep("", n_vars),
                 stringsAsFactors = FALSE)
message(sprintf("Keeping %d variables", nrow(rf)))

# Try all selected variables ----
# For each candidate variable, test its association with Cluster (chi-squared
# test) and record the p-value in rf. When the raw p-value is <= alpha, save an
# MDS scatter plot coloured by the variable and a stacked bar chart of its
# level proportions per cluster, recording both paths in rf.
for (v in row.names(rf)) {
  fac <- as.factor(df[[v]])
  # The test needs at least two observed levels, each seen at least twice.
  if (nlevels(fac) < 2 || min(table(fac)) < 2) next
  cs <- chisq.test(fac, df$Cluster)
  p <- cs$p.value
  rf[v, "pvalue"] <- p
  # chisq.test can return NaN (e.g. a level that only co-occurs with a missing
  # Cluster gives a zero expected count); skip instead of erroring in if().
  if (is.na(p) || p > alpha) next
  # Report the p-value itself: "P < x" with x = rounded p is false whenever
  # rounding goes down.
  label <- if (p > 0.001) sprintf("P = %.3f", p) else sprintf("P = %.1e", p)

  # MDS scatter plot coloured by the variable's levels
  fname <- file.path(img.dir, sprintf("mds_clusters_%s.pdf", v))
  rf[v, "pathMDS"] <- fname
  pl <- ggplot(data = df) +
    geom_point(aes(x = MDS1, y = MDS2, color = fac), show.legend = FALSE) +
    xlab("MDS 1") + ylab("MDS 2") +
    theme(legend.title = element_blank(),
          legend.margin = margin(),
          text = element_text(size = 7)) +
    ggtitle(gsub("_", " ", v))
  ggsave(fname, width = 2.3, height = 2.2, plot = pl)
  message(sprintf("Written %s", fname))

  # Distribution plot: per-cluster proportions of each level
  ok <- !is.na(fac) & !is.na(df$Cluster)
  dfn <- df[ok, , drop = FALSE]
  facn <- fac[ok]
  fname <- file.path(img.dir, sprintf("stack_s_%s.pdf", v))
  rf[v, "pathDist"] <- fname
  pl <- ggplot(data = dfn, aes(x = Cluster, fill = facn)) +
    geom_bar(position = "fill") +
    theme(legend.title = element_blank(), legend.margin = margin(),
          text = element_text(size = 7)) +
    ggtitle(label) +
    # position = "fill" stacks each bar to 1: the axis shows proportions.
    ylab("Fraction")
  # Pass the plot explicitly rather than relying on last_plot().
  ggsave(fname, width = 2.7, height = 2.2, plot = pl)
  message(sprintf("Written %s", fname))
}

# Write results ----
# Add FDR-adjusted p-values (p.adjust method "fdr"), sort the table by raw
# p-value (most significant first) and save it as summary.csv.
rf[["FDR"]] <- p.adjust(rf[["pvalue"]], method = "fdr")
rf <- rf[order(rf[["pvalue"]]), ]
fname <- file.path(out.dir, "summary.csv")
write.csv(rf, file = fname)
message(sprintf("Written %s", fname))

# Compare vaccinations ----
# Chi-squared test of each vaccination column against Outside_Europe. Columns
# are those whose name contains "vaccination" but not "year"; both are matched
# case-insensitively so e.g. "..._Year" columns are excluded consistently.
vac_cols <- grep("vaccination", colnames(df), ignore.case = TRUE, value = TRUE)
vacs <- vac_cols[!grepl("year", vac_cols, ignore.case = TRUE)]
if (length(vacs) > 0 && !"Outside_Europe" %in% colnames(df)) {
  # chisq.test(x, NULL) would fall back to a goodness-of-fit test on x alone.
  stop("column 'Outside_Europe' is required to compare vaccinations",
       call. = FALSE)
}
vac_p <- vapply(vacs, function(v) {
  tryCatch(
    chisq.test(df[[v]], df$Outside_Europe)$p.value,
    # e.g. a constant column (< 2 levels): report NA for that column instead
    # of aborting the whole table.
    error = function(e) {
      warning(sprintf("chisq.test failed for '%s': %s", v,
                      conditionMessage(e)), call. = FALSE)
      NA_real_
    }
  )
}, numeric(1), USE.NAMES = FALSE)
# Built from equal-length vectors so an empty vacs gives an empty table.
rf <- data.frame(Vaccination = vacs, Pvalue = vac_p, row.names = vacs)
fname <- file.path(out.dir, "supp_table_VACC.csv")
# Most significant first; failed tests (NA) sort last.
rf <- rf[order(rf$Pvalue), ]
write.csv(rf, fname)
message(sprintf("Written %s", fname))