#!/usr/bin/env Rscript

# Short description of what this script does.
hlp = "Cluster metabolites."

# Input: serialized data object. Output: directory receiving all plots and
# tables written by this script (created up front if it does not exist).
in.data = "./data/data.Robj"
out.dir = file.path("./output/metabolomics/clustering/")
dir.create(out.dir, showWarnings = FALSE, recursive = TRUE)

# Libs
require(ggplot2)
require(Rtsne)
require(mclust)
require(mltools)
require(data.table)
require(reshape2)

# Clustering parameters
#   externals  : metadata columns used as reference annotations
#   tsne_cols  : extra metadata columns used to colour the t-SNE plots
#   nc_range   : candidate numbers of k-means clusters
#   min_levels : minimum number of members an annotation level needs to be kept
#   replicates : k-means restarts per candidate number of clusters
tsne_cols = c("average_molecular_weight", "ionMz")
externals = c("super_class", "class", "sub_class")
nc_range = seq(from = 2, to = 50, by = 1)
replicates = 30
min_levels = 5

# Fix the RNG state so that k-means and t-SNE runs are reproducible.
set.seed(41)

# Read the serialized data object; load() places its contents in the global
# environment, from which `obj` is then picked up.
load(in.data)
obj = .GlobalEnv[["obj"]]

# Raw metabolite matrix and its column-standardized transpose.
# NOTE(review): presumably the raw matrix is samples x metabolites, so that M
# has one row per metabolite -- confirm against the data object.
Mt = as.matrix(obj$assays$Metabolites$data)
M = t(scale(Mt))

# Metabolite-level metadata; receives the t-SNE coordinates below.
df = obj$assays$Metabolites$metadata

# Embed the rows of M with t-SNE and keep both coordinates in the metadata.
tsne = Rtsne(M)
df[["tsne1"]] = tsne$Y[, 1]
df[["tsne2"]] = tsne$Y[, 2]

# For every external annotation column:
#   1. run k-means on M for each candidate number of clusters and replicate,
#      scoring the agreement with the annotation by the adjusted Rand index;
#   2. pick the number of clusters with the highest mean ARI and, for that
#      number, the single best-scoring replicate;
#   3. store the clustering and its majority annotation in `df`, write the
#      diagnostic plots, and add a cluster-averaged assay to `obj`.
for(col in externals){
  
  # Treat empty strings and levels with fewer than `min_levels` members as
  # missing. Levels are counted from the observed values so that the message
  # is also correct for character (non-factor) columns.
  z = df[,col]
  z[!is.na(z) & z == ""] = NA
  level_counts = table(z)
  n_levels = sum(level_counts > 0)
  frequent = names(level_counts)[level_counts >= min_levels]
  z[!(z %in% frequent)] = NA
  message(sprintf("Keeping %d/%d levels for %s in %d samples", 
                  length(frequent), n_levels, col, sum(!is.na(z))))
  
  # Score every (number of clusters, replicate) combination; the clustering
  # produced for row i of `msel` is kept in solutions[[i]].
  msel = expand.grid(Clusters=nc_range, Replicate=seq_len(replicates), ARI=0)
  solutions = vector("list", nrow(msel))
  annotated = !is.na(z)
  for(i in seq_len(nrow(msel))){
    nc = msel[i, "Clusters"]
    r = msel[i, "Replicate"]
    km = kmeans(M, centers = nc)
    cli = as.factor(sprintf("C%02d", km$cluster))
    ari = adjustedRandIndex(cli[annotated], z[annotated])
    msel[i, "ARI"] = ari
    solutions[[i]] = cli
    message(sprintf("Clustering for k=%d/%d (%d/%d), score: %.2f",
                    nc, max(nc_range), r, replicates, ari))
  }
  
  # Select the number of clusters with the highest mean ARI. which.max()
  # returns exactly one index, so ties resolve to the smallest candidate
  # instead of yielding a vector of cluster counts.
  ari_mean = aggregate(msel$ARI, by=list(Clusters=msel$Clusters), mean)
  nc = ari_mean[which.max(ari_mean$x), "Clusters"]
  if(length(nc) != 1){
    stop(sprintf("No valid ARI score for %s; cannot select clusters", col))
  }
  message(sprintf("Setting %d clusters from purity", nc))
  
  # Select the best replicate for this number of clusters (first on ties).
  candidates = which(msel$Clusters == nc)
  best = candidates[which.max(msel$ARI[candidates])]
  sol = solutions[[best]]
  message(sprintf("Setting %d clusters", nlevels(sol)))
  
  # Set in data frame
  cname = sprintf("Cluster_%s", col)
  df[[cname]] = sol
  
  # Determine the majority annotation of each cluster. Zero-count factor
  # levels are dropped so that a cluster without annotated members gets NA
  # rather than an arbitrary level.
  mname = sprintf("Majority_%s", col)
  df[[mname]] = ""
  for(clu in sort(unique(df[[cname]]))){
    rows = df[[cname]] == clu
    vals = df[rows, col]
    vals = vals[!is.na(vals) & (vals != "")]
    cats = rev(sort(table(vals)))
    cats = cats[cats > 0]
    maj = if(length(cats) > 0) names(cats)[1] else NA_character_
    df[rows, mname] = maj
    message(sprintf("Cluster %s, majority: %s", clu, maj))
  }
  
  # Plot ARI against the number of clusters: one gray line per replicate,
  # the mean in black on top.
  fname = file.path(out.dir, sprintf("%s_clusters_purity.pdf", col))
  pdf(fname, width = 5, height = 4)
  plot(ari_mean$Clusters, ari_mean$x, xlab="Num. clusters", 
       ylab="Mean cluster purity (ARI)", type="l")
  grid()
  for(r in seq_len(replicates)){
    m = msel[msel$Replicate == r, ]
    lines(m$Clusters, m$ARI, col="gray")
  }
  lines(ari_mean$Clusters, ari_mean$x, col="black")
  dev.off()
  message(sprintf("Written %s", fname))
  
  # Plot class distribution of the retained annotation levels
  rstz = rev(sort(table(z[annotated])))
  rstz = rstz[rstz > 0]
  rf = data.frame(Class=as.factor(names(rstz)),
                  Count=as.vector(rstz)) 
  fname = file.path(out.dir, sprintf("%s_distribution.pdf", col))
  p = ggplot(data=rf, aes(x=reorder(Class, -Count), y=Count)) + 
    geom_bar(stat="identity") + xlab("") + 
    theme(axis.text.x = element_text(angle = 90, hjust = 1, vjust=.5))
  # Pass the plot explicitly instead of relying on last_plot()
  ggsave(fname, plot = p, width = 2 + length(rstz) * 0.1, height = 3.5)
  message(sprintf("Written %s", fname))
    
  # Plot cluster t-SNE, coloured by cluster label (with cluster size), the
  # annotation itself, the majority annotation and the extra columns
  dfc = df[[cname]]
  tt = table(dfc)
  N = tt[dfc]
  df$ClusterLabel = sprintf("%s (%d)", dfc, N)
  for(tc in c("ClusterLabel", col, mname, tsne_cols)){
    cf = df[,tc]
    fname = file.path(out.dir, sprintf("%s_clusters_%s.pdf", col, tc))
    p = ggplot(data=df, aes(x=tsne1, y=tsne2, color=cf)) + 
      geom_point() + xlab("t-SNE 1") + ylab("t-SNE 2") + 
      ggtitle("") + 
      theme(legend.title = element_blank())
    # Wider page when the legend has many discrete entries
    width = if(is.numeric(cf)) 7 else if(length(unique(cf)) > 20) 9 else 8
    ggsave(fname, plot = p, width = width, height = 6)
    message(sprintf("Written %s", fname))
  }
  
  # Set a new assay based on averaging z-scores for this external
  # annotation; metadata holds information about the cluster.
  # `data` must be read before the consistency check below: previously the
  # check ran first and, on the first iteration, compared against the
  # utils::data function, which made it pass vacuously.
  data = obj$assays$Metabolites$data
  Mm = as.matrix(one_hot(as.data.table(df[,cname])))
  colnames(Mm) = levels(df[[cname]])
  row.names(Mm) = row.names(df)
  stopifnot(all(row.names(Mm) == colnames(data)))
  
  # Aggregate intensities: sum of z-scores per cluster divided by cluster size
  Mc = scale(data) %*% Mm
  Ms = sweep(Mc, 2, colSums(Mm), "/")

  # Compute t-SNE on aggregated matrix
  mx = obj$metadata
  stopifnot(all(row.names(mx)==row.names(Ms)))
  tname_x = sprintf("tsne_%s_x", cname)
  tname_y = sprintf("tsne_%s_y", cname)
  Mtsne = Rtsne(Ms)
  mx[[tname_x]] = Mtsne$Y[,1]
  mx[[tname_y]] = Mtsne$Y[,2]

  # Store metadata: one row per cluster with its majority annotation
  dfm = df[,c(cname, mname)]
  dfm = dfm[!duplicated(dfm),]
  dfm = dfm[order(dfm[,cname]),]
  row.names(dfm) = dfm[,cname]
  
  # Store cluster composition (share of each annotation level per cluster).
  # aggregate() drops rows with a missing annotation, so a cluster without
  # any annotated member has no column after dcast(); such columns are added
  # as zeros (their normalized composition is NaN, i.e. undefined).
  counts_long = aggregate(z, by=list(External=z, Cluster=df[,cname]), length)
  comp = reshape2::dcast(counts_long, External~Cluster, 
                         value.var = "x", fill = 0)
  for(absent in setdiff(colnames(Mm), colnames(comp))){
    comp[[absent]] = 0
  }
  cm = as.matrix(comp[,colnames(Mm), drop = FALSE])
  cmn = sweep(cm, 2, colSums(cm), "/")
  row.names(cmn) = as.character(comp$External)
  
  # Add assays to object
  aname = sprintf("Metabolites_%s", cname)
  obj$assays[[aname]] = list(data=Ms, metadata=dfm, composition=cmn)

  # Update metabolite metadata
  annotation = obj$assays$Metabolites$annotation
  obj$assays$Metabolites = list(data=data, 
                                metadata=df, 
                                annotation=annotation)
  # Update metadata
  obj$metadata = mx
}

# Export for use in Orange: the scaled metabolite matrix, the metabolite
# metadata (including cluster assignments) and the sample metadata.
mx = obj$metadata
exports = list("data.csv" = M,
               "metadata.csv" = df,
               "metadata-samples.csv" = mx)
for(out_name in names(exports)){
  fname = file.path(out.dir, out_name)
  write.csv(exports[[out_name]], fname)
}