library(CMDdemux)
library(Seurat)
library(scran)
library(scater)
library(deMULTIplex2)
library(demuxmix)
library(cellhashR)
library(DropletUtils)
library(stringr)
library(tidyverse)
library(clValid)

# Load the project helper script. It presumably defines the helper functions
# used below that the attached packages do not provide -- TODO confirm.
# The path has to be a quoted string: a bare ~/bench.R is not valid R syntax
# and prevents the whole file from being parsed.
source("~/bench.R")

# Mouse Vehicle
# Load the saved vehicle inputs. The code below expects these files to define
# `vehicle.hash.count` and `vehicle.gex.count` (presumably hashtag and gene
# expression counts with cells in columns -- TODO confirm).
load("~/vehicle.hash.count.rdata")
load("~/vehicle.gex.count.rdata")

# 1. CMDdemux
# Each call consumes the output of the previous ones, so the order matters.
# What each step does is inferred from the function names only; see the
# CMDdemux documentation for the exact semantics.
vehicle.clr.norm <- LocalCLRNorm(vehicle.hash.count)
vehicle.kmed.cl <- KmedCluster(vehicle.clr.norm)
vehicle.cl.dist <- EuclideanClusterDist(vehicle.clr.norm, vehicle.kmed.cl)
# Three cut-offs, presumably one per cluster and tuned for this dataset -- TODO confirm.
vehicle.noncore <- DefineNonCore(vehicle.cl.dist, vehicle.kmed.cl, c(0.86, 0.91, 0.876))
vehicle.cluster.assign <- LabelClusterHTO(vehicle.clr.norm, vehicle.kmed.cl, vehicle.noncore, "medoids")
vehicle.md.mat <- CalculateMD(vehicle.clr.norm, vehicle.noncore, vehicle.kmed.cl, vehicle.cluster.assign)
# 0.91 is a dataset-specific cut-off; its meaning is not visible here.
vehicle.outlier.assign <- AssignOutlierDrop(vehicle.md.mat, 0.91)
# The positional arguments TRUE, 3 and 2 are settings whose meaning is not
# visible here -- check the CMDdemuxClass() signature.
vehicle.cmddemux.assign <- CMDdemuxClass(vehicle.md.mat, vehicle.hash.count, vehicle.outlier.assign, TRUE, vehicle.gex.count, 3, 2)
# Start the per-cell results table with the CMDdemux global class. Its row
# names are taken from the CMDdemux output and are used further down to align
# the results of the other methods.
vehicle.demux.result <- data.frame("CMDdemux" = vehicle.cmddemux.assign$demux_global_class)
rownames(vehicle.demux.result) <- rownames(vehicle.cmddemux.assign)

# 2. HTOdemux
# The Seurat object is created from the hashtag counts, and the same counts
# are added again as a separate "HTO" assay, which is the assay that gets
# CLR-normalised and passed to HTODemux().
vehicle.obj <- CreateSeuratObject(counts = vehicle.hash.count)
vehicle.obj[["HTO"]] <- CreateAssayObject(counts = vehicle.hash.count)
vehicle.obj <- NormalizeData(vehicle.obj, assay = "HTO", normalization.method = "CLR")
vehicle.obj <- HTODemux(vehicle.obj, assay = "HTO", positive.quantile = 0.99)
# Pull hash.ID by name so it follows the row order of the results table.
vehicle.demux.result$HTODemux <- vehicle.obj$hash.ID[rownames(vehicle.demux.result)]

# 3. GMM-Demux
# Prepare input data: GMM-demux is run outside R on a CSV of the transposed
# hashtag count matrix.
vehicle.gmm.input <- t(vehicle.hash.count)
write.csv(vehicle.gmm.input, "~/vehicle.gmm.input.csv", quote = FALSE)
# Command: GMM-demux -c ~/vehicle.gmm.input.csv mouse1,mouse2,mouse3 -x mouse1,mouse2,mouse3 -f .
# NOTE(review): the treated section of this script reads the same
# ~/GMM_full.csv and ~/GMM_full.config paths, so the command above has to be
# run for the vehicle sample right before the lines below. The command writes
# to "." -- presumably it is run from the home directory; confirm.
vehicle.gmm.output <- read.csv("~/GMM_full.csv")
vehicle.gmm.config <- read.table(file = "~/GMM_full.config", header = FALSE, sep = ",")
vehicle.gmm.demux <- GMM_demux_class(vehicle.gmm.output, vehicle.gmm.config, vehicle.hash.count)
# Stored as returned, without re-indexing by row name; this assumes
# GMM_demux_class() returns one value per row of vehicle.demux.result, in the
# same order -- TODO confirm.
vehicle.demux.result$`GMM-Demux` <- vehicle.gmm.demux

# 4. deMULTIplex2
# Reuses `vehicle.gmm.input`, the transposed hashtag count matrix created in
# the GMM-Demux step. The seed is fixed at 2024.
vehicle.demultiplex2.output <- demultiplexTags(vehicle.gmm.input, plot.diagnostics = FALSE, seed = 2024)
vehicle.demultiplex2.assign <- deMULTIplex2_class(vehicle.demultiplex2.output)
# Stored as returned, without re-indexing by row name; this assumes
# deMULTIplex2_class() keeps the row order of vehicle.demux.result -- TODO confirm.
vehicle.demux.result$deMULTIplex2 <- vehicle.demultiplex2.assign

# 5. demuxEM
# demuxEM is run outside R: write the hashtag counts as CSV and the GEX counts
# as a 10x HDF5 file, run the command, then read its two output tables back.
vehicle.hash.write <- as.data.frame(vehicle.hash.count) %>% rownames_to_column('Antibody')
# Replace the hashtag names with a running index. The index is derived from
# the number of rows instead of a hard-coded 1:3, so it stays correct (and
# cannot be silently recycled) when the number of hashtags changes.
vehicle.hash.write$Antibody <- seq_len(nrow(vehicle.hash.write))
write.csv(vehicle.hash.write, "~/vehicle.demuxEM.input.csv", quote = FALSE)
write10xCounts("~/vehicle.demuxEM.gex.h5", vehicle.gex.count, version='3')
#demuxEM -p 8 --random-state 2024 ~/vehicle.demuxEM.gex.h5 ~/vehicle.demuxEM.input.csv Vehicle
vehicle.demuxEM.out1 <- read.table(file = "~/vehicle.demuxEM.demux.txt", header = TRUE)
vehicle.demuxEM.out2 <- read.table(file = "~/vehicle.demuxEM.assign.txt", sep = "\t", header = TRUE)
vehicle.demuxEM.assign <- demuxEM_class(vehicle.demuxEM.out1, vehicle.demuxEM.out2, vehicle.hash.count, TRUE)
# Index by name so the calls follow the row order of the results table.
vehicle.demux.result$demuxEM <- vehicle.demuxEM.assign[rownames(vehicle.demux.result)]

# 6. demuxmix
# Number of genes with a non-zero count in each column of the GEX matrix;
# passed to demuxmix() as its `rna` argument.
vehicle.gex.genes <- colSums(vehicle.gex.count > 0)
vehicle.demuxmix.model <- demuxmix(hto = vehicle.hash.count, rna = vehicle.gex.genes)
vehicle.demuxmix.labels <- dmmClassify(vehicle.demuxmix.model)
vehicle.demuxmix.assign <- demuxmix_class(vehicle.demuxmix.labels)
# Index by name so the calls follow the row order of the results table.
vehicle.demux.result$demuxmix <- vehicle.demuxmix.assign[rownames(vehicle.demux.result)]

# 7. hashedDrops
# hashedDrops() is run with its default settings on the hashtag counts.
vehicle.hasheddrops.output <- hashedDrops(vehicle.hash.count)
vehicle.hasheddrops.assign <- hashedDrops_class(vehicle.hasheddrops.output, vehicle.hash.count)
vehicle.demux.result$hashedDrops <- vehicle.hasheddrops.assign[rownames(vehicle.demux.result)]

# 8. BFF
# Both BFF variants are requested, but only the "bff_raw" calls are stored for
# this sample (see the note below).
vehicle.bff.output <- GenerateCellHashingCalls(barcodeMatrix = vehicle.hash.count, methods = c("bff_raw", "bff_cluster"))
# Stored as returned, without re-indexing by row name; this assumes the output
# rows are in the order of vehicle.demux.result -- TODO confirm.
vehicle.demux.result$BFF_raw <- vehicle.bff.output$bff_raw
# Cannot get results from "Cluster" based method
#vehicle.demux.result$BFF_cluster <- vehicle.bff.output$bff_cluster

# Dimensional reduction for visualization
# t-SNE and UMAP are both computed from the "clr" assay, i.e. the CMDdemux
# normalised values stored next to the raw hashtag counts.
# NOTE(review): no seed is set in this file before these calls, so the
# embeddings may differ between runs -- confirm whether that matters.
vehicle.sce <- SingleCellExperiment(assays = list(hto = vehicle.hash.count, clr = vehicle.clr.norm))
vehicle.sce <- runTSNE(vehicle.sce,exprs_values = "clr")
vehicle.sce <- runUMAP(vehicle.sce,exprs_values = "clr")

# Silhouette score
# BenchMetricsNoGT() gets the hashtag counts and the table of calls from all
# methods; only its `Silhouette score` element is kept here.
vehicle.bench.metrics <- BenchMetricsNoGT(vehicle.hash.count, vehicle.demux.result)
vehicle.silhouette.score <- vehicle.bench.metrics$`Silhouette score`

# Library size
# Log of the per-column totals. A pseudo-count of 1 is added for the hashtag
# totals only.
# NOTE(review): the GEX totals are logged without a pseudo-count, so a column
# with zero GEX counts would give -Inf -- confirm that no such column exists.
# `vehicle.hto.lib` is not used again in this part of the file.
vehicle.hto.lib <- log(colSums(vehicle.hash.count)+1)
vehicle.gex.lib <- log(colSums(vehicle.gex.count))

# Library size ratios of doublets vs. singlets and singlets vs. negatives across different methods
# Computed from the GEX (not the hashtag) library sizes.
vehicle.lib.ratio <- LibRatio(vehicle.gex.lib, vehicle.demux.result)

# Summary of library size ratio
vehicle.ratio.sum <- LibRatioSum(vehicle.gex.lib, vehicle.demux.result)


# Mouse treated
# Load the saved treated inputs. The code below expects these files to define
# `treated.hash.count` and `treated.gex.count` (presumably hashtag and gene
# expression counts with cells in columns -- TODO confirm).
load("~/treated.hash.count.rdata")
load("~/treated.gex.count.rdata")

# 1. CMDdemux
# Same sequence of calls as for the vehicle sample, with different cut-offs.
# Each call consumes the output of the previous ones, so the order matters.
# What each step does is inferred from the function names only; see the
# CMDdemux documentation for the exact semantics.
treated.clr.norm <- LocalCLRNorm(treated.hash.count)
treated.kmed.cl <- KmedCluster(treated.clr.norm)
treated.cl.dist <- EuclideanClusterDist(treated.clr.norm, treated.kmed.cl)
# Three cut-offs, presumably one per cluster and tuned for this dataset -- TODO confirm.
treated.noncore <- DefineNonCore(treated.cl.dist, treated.kmed.cl, c(0.93, 0.94, 0.91))
treated.cluster.assign <- LabelClusterHTO(treated.clr.norm, treated.kmed.cl, treated.noncore, "medoids")
treated.md.mat <- CalculateMD(treated.clr.norm, treated.noncore, treated.kmed.cl, treated.cluster.assign)
# 0.93 is a dataset-specific cut-off; its meaning is not visible here.
treated.outlier.assign <- AssignOutlierDrop(treated.md.mat, 0.93)
# The positional arguments TRUE, 3 and 2 are settings whose meaning is not
# visible here -- check the CMDdemuxClass() signature.
treated.cmddemux.assign <- CMDdemuxClass(treated.md.mat, treated.hash.count, treated.outlier.assign, TRUE, treated.gex.count, 3, 2)
# Start the per-cell results table with the CMDdemux global class. Its row
# names are taken from the CMDdemux output and are used further down to align
# the results of the other methods.
treated.demux.result <- data.frame("CMDdemux" = treated.cmddemux.assign$demux_global_class)
rownames(treated.demux.result) <- rownames(treated.cmddemux.assign)

# 2. HTOdemux
# The Seurat object is created from the hashtag counts, and the same counts
# are added again as a separate "HTO" assay, which is the assay that gets
# CLR-normalised and passed to HTODemux().
treated.obj <- CreateSeuratObject(counts = treated.hash.count)
treated.obj[["HTO"]] <- CreateAssayObject(counts = treated.hash.count)
treated.obj <- NormalizeData(treated.obj, assay = "HTO", normalization.method = "CLR")
treated.obj <- HTODemux(treated.obj, assay = "HTO", positive.quantile = 0.99)
# Pull hash.ID by name so it follows the row order of the results table.
treated.demux.result$HTODemux <- treated.obj$hash.ID[rownames(treated.demux.result)]

# 3. GMM-Demux
# Prepare input data: GMM-demux is run outside R on a CSV of the transposed
# hashtag count matrix.
treated.gmm.input <- t(treated.hash.count)
write.csv(treated.gmm.input, "~/treated.gmm.input.csv", quote = FALSE)
# Command: GMM-demux -c ~/treated.gmm.input.csv mouse1,mouse2,mouse3 -x mouse1,mouse2,mouse3 -f .
# NOTE(review): the vehicle section of this script reads the same
# ~/GMM_full.csv and ~/GMM_full.config paths, so the command above has to be
# run for the treated sample right before the lines below; otherwise the
# vehicle output is read again. The command writes to "." -- presumably it is
# run from the home directory; confirm.
treated.gmm.output <- read.csv("~/GMM_full.csv")
treated.gmm.config <- read.table(file = "~/GMM_full.config", header = FALSE, sep = ",")
treated.gmm.demux <- GMM_demux_class(treated.gmm.output, treated.gmm.config, treated.hash.count)
# Stored as returned, without re-indexing by row name; this assumes
# GMM_demux_class() returns one value per row of treated.demux.result, in the
# same order -- TODO confirm.
treated.demux.result$`GMM-Demux` <- treated.gmm.demux

# 4. deMULTIplex2
# Cannot get results from deMULTIplex2
#treated.demultiplex2.output <- demultiplexTags(treated.gmm.input, plot.diagnostics = FALSE, seed = 2024)
#treated.demultiplex2.assign <- deMULTIplex2_class(treated.demultiplex2.output)
#treated.demux.result$deMULTIplex2 <- treated.demultiplex2.assign

# 5. demuxEM
# demuxEM is run outside R: write the hashtag counts as CSV and the GEX counts
# as a 10x HDF5 file, run the command, then read its two output tables back.
treated.hash.write <- as.data.frame(treated.hash.count) %>% rownames_to_column('Antibody')
# Replace the hashtag names with a running index. The index is derived from
# the number of rows instead of a hard-coded 1:3, so it stays correct (and
# cannot be silently recycled) when the number of hashtags changes.
treated.hash.write$Antibody <- seq_len(nrow(treated.hash.write))
write.csv(treated.hash.write, "~/treated.demuxEM.input.csv", quote = FALSE)
write10xCounts("~/treated.demuxEM.gex.h5", treated.gex.count, version='3')
#demuxEM -p 8 --random-state 2024 ~/treated.demuxEM.gex.h5 ~/treated.demuxEM.input.csv Treated
treated.demuxEM.out1 <- read.table(file = "~/treated.demuxEM.demux.txt", header = TRUE)
treated.demuxEM.out2 <- read.table(file = "~/treated.demuxEM.assign.txt", sep = "\t", header = TRUE)
treated.demuxEM.assign <- demuxEM_class(treated.demuxEM.out1, treated.demuxEM.out2, treated.hash.count, TRUE)
# Index by name so the calls follow the row order of the results table.
treated.demux.result$demuxEM <- treated.demuxEM.assign[rownames(treated.demux.result)]

# 6. demuxmix
# Number of genes with a non-zero count in each column of the GEX matrix;
# passed to demuxmix() as its `rna` argument.
treated.gex.genes <- colSums(treated.gex.count > 0)
treated.demuxmix.model <- demuxmix(hto = treated.hash.count, rna = treated.gex.genes)
treated.demuxmix.labels <- dmmClassify(treated.demuxmix.model)
treated.demuxmix.assign <- demuxmix_class(treated.demuxmix.labels)
# Index by name so the calls follow the row order of the results table.
treated.demux.result$demuxmix <- treated.demuxmix.assign[rownames(treated.demux.result)]

# 7. hashedDrops
# hashedDrops() is run with its default settings on the hashtag counts.
treated.hasheddrops.output <- hashedDrops(treated.hash.count)
treated.hasheddrops.assign <- hashedDrops_class(treated.hasheddrops.output, treated.hash.count)
treated.demux.result$hashedDrops <- treated.hasheddrops.assign[rownames(treated.demux.result)]

# 8. BFF
# Both BFF variants are requested and, unlike for the vehicle sample, both
# result columns are stored.
treated.bff.output <- GenerateCellHashingCalls(barcodeMatrix = treated.hash.count, methods = c("bff_raw", "bff_cluster"))
# Stored as returned, without re-indexing by row name; this assumes the output
# rows are in the order of treated.demux.result -- TODO confirm.
treated.demux.result$BFF_raw <- treated.bff.output$bff_raw
treated.demux.result$BFF_cluster <- treated.bff.output$bff_cluster

# Dimensional reduction for visualization
# t-SNE and UMAP are both computed from the "clr" assay, i.e. the CMDdemux
# normalised values stored next to the raw hashtag counts.
# NOTE(review): no seed is set in this file before these calls, so the
# embeddings may differ between runs -- confirm whether that matters.
treated.sce <- SingleCellExperiment(assays = list(hto = treated.hash.count, clr = treated.clr.norm))
treated.sce <- runTSNE(treated.sce,exprs_values = "clr")
treated.sce <- runUMAP(treated.sce,exprs_values = "clr")

# Library size
# Log of the per-column totals. A pseudo-count of 1 is added for the hashtag
# totals only.
# NOTE(review): the GEX totals are logged without a pseudo-count, so a column
# with zero GEX counts would give -Inf -- confirm that no such column exists.
# `treated.hto.lib` is not used again in this part of the file.
treated.hto.lib <- log(colSums(treated.hash.count)+1)
treated.gex.lib <- log(colSums(treated.gex.count))

# Summary of library size ratio
# Computed from the GEX (not the hashtag) library sizes.
treated.ratio.sum <- LibRatioSum(treated.gex.lib, treated.demux.result)

# MDS
# Combine the two GEX matrices on the genes they share, vehicle columns first.
# drop = FALSE keeps the matrix shape even if only one gene is shared.
common.genes <- intersect(rownames(treated.gex.count), rownames(vehicle.gex.count))
combine.mat <- cbind(
  vehicle.gex.count[common.genes, , drop = FALSE],
  treated.gex.count[common.genes, , drop = FALSE]
)
mouse.sce <- SingleCellExperiment(list(counts = combine.mat))

# Label every column with its sample of origin. rep() is used rather than
# index ranges such as 1:n, which would run backwards for an empty matrix.
mouse.sce$sample <- factor(
  rep(
    c("Vehicle", "Treated"),
    times = c(ncol(vehicle.gex.count), ncol(treated.gex.count))
  ),
  levels = c("Vehicle", "Treated")
)

# Per-column CMDdemux class, keyed by "<sample>_<column name>" so that column
# names occurring in both samples stay distinct. Columns without a CMDdemux
# call keep NA.
cmddemux.vec <- rep(NA_character_, ncol(mouse.sce))
names(cmddemux.vec) <- paste0(mouse.sce$sample, "_", colnames(mouse.sce))

# Copies of the result tables whose row names carry the sample prefix.
treated.demux.result2 <- treated.demux.result
rownames(treated.demux.result2) <- paste0("Treated_", rownames(treated.demux.result))
vehicle.demux.result2 <- vehicle.demux.result
rownames(vehicle.demux.result2) <- paste0("Vehicle_", rownames(vehicle.demux.result))

# as.character() guards against the CMDdemux column being a factor, in which
# case a plain assignment would store the integer codes instead of the labels.
cmddemux.vec[rownames(treated.demux.result2)] <- as.character(treated.demux.result2[, "CMDdemux"])
cmddemux.vec[rownames(vehicle.demux.result2)] <- as.character(vehicle.demux.result2[, "CMDdemux"])
mouse.sce$CMDdemux <- cmddemux.vec

# Drop the columns called "Negative" or "Doublet". %in% never returns NA, so
# columns without a call (NA) are kept.
mouse.sce <- mouse.sce[, !mouse.sce$CMDdemux %in% c("Negative", "Doublet")]

# Aggregate cells into pseudo-bulk profiles and prepare them for edgeR.
#
# Args:
#   sce: object with a `counts` assay and `sample` and `CMDdemux` columns in
#        its colData (a SingleCellExperiment in this script).
#
# Returns:
#   A DGEList with one column per sample x CMDdemux combination built from at
#   least 10 cells, filtered with filterByExpr() and carrying the
#   normalisation factors computed by calcNormFactors().
integrate_fit <- function(sce) {
  plk <- aggregateAcrossCells(sce, id = colData(sce)[, c("sample", "CMDdemux")])
  # edgeR functions are namespace-qualified because this file never attaches
  # edgeR (mds.df() below already calls edgeR::cpm the same way).
  y <- edgeR::DGEList(counts(plk), samples = colData(plk))

  # Drop pseudo-bulk profiles that were built from fewer than 10 cells.
  discarded <- plk$ncells < 10
  y <- y[, !discarded]

  # The grouping vector must be subset exactly like `y`; passing the full
  # plk$sample would let filterByExpr() derive its minimum group size from
  # profiles that were just removed.
  keep <- edgeR::filterByExpr(y, group = plk$sample[!discarded])
  y <- y[keep, , keep.lib.sizes = FALSE]
  y <- edgeR::calcNormFactors(y)
  return(y)
}
# Pseudo-bulk DGEList for the combined vehicle + treated object built above.
mouse.y <- integrate_fit(mouse.sce)

# Compute MDS coordinates for the columns of a DGEList-like object.
#
# The distance between two columns is the root mean square of their `top`
# largest squared log-CPM differences; the distance matrix is then passed to
# cmdscale().
#
# Args:
#   object:   object that as.matrix() turns into a count matrix and that has a
#             `samples` data frame with columns `CMDdemux.1` and `sample.1`.
#   top:      number of largest squared differences used per pair of columns
#             (capped at the number of rows).
#   dim.plot: two dimension indices to return, e.g. c(1, 2).
#
# Returns:
#   A data frame with columns `dim1`, `dim2`, `label` and `sample`, one row
#   per column of `object`.
mds.df <- function(object, top, dim.plot) {
  # NOTE(review): cpm() is applied to the plain count matrix, so normalisation
  # factors stored in `object` are presumably not used here -- confirm that
  # this is intended.
  x <- as.matrix(object)
  x <- edgeR::cpm(x, log = TRUE)
  nsamples <- ncol(x)
  nprobes <- nrow(x)
  cn <- colnames(x)
  top <- min(top, nprobes)

  if (length(dim.plot) < 2L) {
    stop("'dim.plot' must contain two dimension indices", call. = FALSE)
  }
  ndim <- max(dim.plot)
  # cmdscale() needs k <= n - 1; failing here gives a clearer message than the
  # out-of-bounds or 'k' errors raised further down otherwise.
  if (nsamples <= ndim) {
    stop(
      "need more than ", ndim, " samples to compute ", ndim,
      " MDS dimensions, got ", nsamples,
      call. = FALSE
    )
  }

  # Lower triangle of the pairwise distance matrix.
  dd <- matrix(0, nrow = nsamples, ncol = nsamples, dimnames = list(cn, cn))
  topindex <- nprobes - top + 1L
  for (i in seq_len(nsamples)[-1L]) {
    for (j in seq_len(i - 1L)) {
      # Partial sort: only positions topindex..nprobes need to hold the
      # largest values.
      sq.diff <- sort.int((x[, i] - x[, j])^2, partial = topindex)
      dd[i, j] <- sqrt(mean(sq.diff[topindex:nprobes]))
    }
  }

  a1 <- suppressWarnings(cmdscale(as.dist(dd), k = ndim))
  mds.df <- data.frame(
    "dim1" = a1[, dim.plot[1]],
    "dim2" = a1[, dim.plot[2]],
    "label" = object$samples$CMDdemux.1,
    "sample" = object$samples$sample.1
  )
  return(mds.df)
}
# MDS coordinates (dimensions 1 and 2) from the pseudo-bulk profiles, using
# the 500 largest squared differences per pair of profiles.
mouse.mds <- mds.df(mouse.y, 500, c(1,2))

# Proportion of doublets
# NOTE: this reuses the name `treated.demux.result2`, replacing the
# row-renamed copy created in the MDS section above.
treated.demux.result2 <- DemuxSingletClass(treated.demux.result, treated.hash.count)
treated.doublet.prop <- AssignProp(treated.demux.result2, "doublet")

# Proportion of negatives
treated.negative.prop <- AssignProp(treated.demux.result2, "negative")

# Silhouette score
# BenchMetricsNoGT() gets the hashtag counts and the table of calls from all
# methods; only its `Silhouette score` element is kept here.
treated.bench.metrics <- BenchMetricsNoGT(treated.hash.count, treated.demux.result)
treated.silhouette.score <- treated.bench.metrics$`Silhouette score`
