library(CMDdemux)
library(Seurat)
library(scran)
library(scater)
library(deMULTIplex2)
library(demuxmix)
library(cellhashR)
library(DropletUtils)
library(stringr)

# Load helper code and input data.
# The path must be a quoted string: a bare ~/bench.R is not valid R syntax and
# prevents the whole script from parsing.
# NOTE(review): bench.R is presumed to define the evaluation helpers used at the
# end of this script (GTStats, GTStatsAvg, ConfusionMat, ConcordanceDF) — confirm.
source("~/bench.R")
# Restores saved objects into the workspace; the rest of the script relies on
# `balb1.hash.count` being one of them.
load("~/balb1.hash.count.rdata")
# Vireo donor tables; the columns `Barcode` and `genetic_donor` are used below.
balb1.vireo1 <- read.csv("~/batch1_c1_donors.csv")
balb1.vireo2 <- read.csv("~/batch1_c2_donors.csv")

# 1. Vireo
# Combine the two Vireo donor tables and store the donor label of every droplet
# in the `Vireo` column of the result table (used as ground truth further down).
balb1.vireo <- rbind(balb1.vireo1, balb1.vireo2)
# Donor labels: spaces become dashes.
balb1.vireo$genetic_donor <- gsub(" ", "-", balb1.vireo$genetic_donor)
# Barcodes: dashes become dots and an "X" prefix is added.
# NOTE(review): presumably this mirrors how the column names of
# balb1.hash.count were formatted — confirm.
balb1.vireo$Barcode <- gsub("-", ".", balb1.vireo$Barcode)
balb1.vireo$Barcode <- paste0("X", balb1.vireo$Barcode)
# One row per droplet of the hash count matrix, initialised to NA.
balb1.demux.result <- data.frame("Vireo" = rep(NA, ncol(balb1.hash.count)))
rownames(balb1.demux.result) <- colnames(balb1.hash.count)
# Barcodes that Vireo reports but that are absent from the hash matrix give NA
# positions; an NA subscript would abort the assignment, so they are dropped.
balb1.vireo.idx <- match(balb1.vireo$Barcode, rownames(balb1.demux.result))
balb1.vireo.hit <- !is.na(balb1.vireo.idx)
balb1.demux.result$Vireo[balb1.vireo.idx[balb1.vireo.hit]] <-
  balb1.vireo$genetic_donor[balb1.vireo.hit]

# 2. CMDdemux
# Step-by-step CMDdemux run; every call consumes objects produced by the calls
# above it, so the order must be kept.
balb1.clr.norm <- LocalCLRNorm(balb1.hash.count)
balb1.kmed.cl <- KmedCluster(balb1.clr.norm)
balb1.cl.dist <- EuclideanClusterDist(balb1.clr.norm, balb1.kmed.cl)
# Eight cutoffs passed to DefineNonCore.
# NOTE(review): presumably one value per cluster/hashtag — confirm against the
# CMDdemux documentation.
balb1.noncore.cutoffs <- c(0.79, 0.85, 0.81, 0.78, 0.83, 0.89, 0.84, 0.84)
balb1.noncore <- DefineNonCore(
  balb1.cl.dist,
  balb1.kmed.cl,
  balb1.noncore.cutoffs
)
balb1.cluster.assign <- LabelClusterHTO(
  balb1.clr.norm,
  balb1.kmed.cl,
  balb1.noncore,
  "medoids"
)
balb1.md.mat <- CalculateMD(
  balb1.clr.norm,
  balb1.noncore,
  balb1.kmed.cl,
  balb1.cluster.assign
)
balb1.outlier.assign <- AssignOutlierDrop(balb1.md.mat, md_cut_q = 0.84)
balb1.cmddemux.assign <- CMDdemuxClass(
  balb1.md.mat,
  balb1.hash.count,
  balb1.outlier.assign,
  4,
  4
)
# Keep the `demux_global_class` column as this method's call.
balb1.demux.result$CMDdemux <- balb1.cmddemux.assign$demux_global_class

# 3. HTOdemux
# The hash counts serve both as the counts of the Seurat object and as its
# "HTO" assay.
balb1.obj <- CreateSeuratObject(counts = balb1.hash.count)
balb1.hto.assay <- CreateAssayObject(counts = balb1.hash.count)
balb1.obj[["HTO"]] <- balb1.hto.assay
# CLR-normalise the HTO assay, then classify droplets with HTODemux.
balb1.obj <- NormalizeData(
  balb1.obj,
  assay = "HTO",
  normalization.method = "CLR"
)
balb1.obj <- HTODemux(balb1.obj, assay = "HTO", positive.quantile = 0.99)
# Look the calls up by barcode so they follow the row order of the result table.
balb1.hto.calls <- balb1.obj$hash.ID
balb1.demux.result$HTODemux <- balb1.hto.calls[rownames(balb1.demux.result)]

# 4. GMM-Demux
# Prepare input data: GMM-Demux is given the transposed hash count matrix.
balb1.gmm.input <- t(balb1.hash.count)
write.csv(balb1.gmm.input, "~/balb1.gmm.input.csv", quote = FALSE)
# Command: GMM-demux -c ~/balb1.gmm.input.csv BAL-A,BAL-B,BAL-C,BAL-D,BAL-E,BAL-F,BAL-G,BAL-H -x BAL-A,BAL-B,BAL-C,BAL-D,BAL-E,BAL-F,BAL-G,BAL-H -f .
balb1.gmm.output <- read.csv("~/GMM_full.csv")
# Translate the cluster ids into labels in one vectorised pass:
#   0                        -> "Negative"
#   1..(number of hashtags)  -> the hashtag name at that row of balb1.hash.count
#   any other non-missing id -> "Doublet"
# Labels are derived from the untouched ids rather than by rewriting a vector
# in place, so an already-relabelled entry can never be matched a second time.
balb1.gmm.ids <- balb1.gmm.output$Cluster_id
balb1.gmm.demux <- rep(NA_character_, length(balb1.gmm.ids))
balb1.gmm.demux[!is.na(balb1.gmm.ids)] <- "Doublet"
balb1.gmm.demux[balb1.gmm.ids %in% 0] <- "Negative"
balb1.gmm.singlet <- balb1.gmm.ids %in% seq_len(nrow(balb1.hash.count))
balb1.gmm.demux[balb1.gmm.singlet] <-
  rownames(balb1.hash.count)[balb1.gmm.ids[balb1.gmm.singlet]]
# Column `X` of the GMM-Demux output is used as the barcode of each call.
names(balb1.gmm.demux) <- balb1.gmm.output$X
balb1.demux.result$`GMM-Demux` <- balb1.gmm.demux[rownames(balb1.demux.result)]

# 5. deMULTIplex2
# Reuses the transposed count matrix that was built for GMM-Demux.
balb1.demultiplex2.output <- demultiplexTags(
  balb1.gmm.input,
  plot.diagnostics = FALSE,
  seed = 2025
)
balb1.demultiplex2.assign <- balb1.demultiplex2.output$final_assign
# Rename the two special classes to the labels shared across methods.
balb1.demultiplex2.assign[balb1.demultiplex2.assign %in% "multiplet"] <- "Doublet"
balb1.demultiplex2.assign[balb1.demultiplex2.assign %in% "negative"] <- "Negative"
# Look the calls up by barcode so they follow the row order of the result table.
balb1.demux.result$deMULTIplex2 <-
  balb1.demultiplex2.assign[rownames(balb1.demux.result)]

# demuxEM cannot be applied to this data, because the gene expression data is not available. 

# 6. demuxmix
# Fit the naive demuxmix model and classify every droplet.
balb1.demuxmix.model <- demuxmix(hto = balb1.hash.count, model = "naive")
balb1.demuxmix.labels <- dmmClassify(balb1.demuxmix.model)
# Start from the HTO column and rename the special classes to the shared labels.
balb1.demuxmix.assign <- balb1.demuxmix.labels$HTO
balb1.demuxmix.assign[balb1.demuxmix.assign %in% "uncertain"] <- "Uncertain"
balb1.demuxmix.assign[balb1.demuxmix.assign %in% "negative"] <- "Negative"
# Multiplets are identified from the Type column and relabelled last, so this
# label overrides whatever the HTO column held for those droplets.
balb1.demuxmix.assign[balb1.demuxmix.labels$Type %in% "multiplet"] <- "Doublet"
# Name by barcode, then reorder to the rows of the result table.
names(balb1.demuxmix.assign) <- rownames(balb1.demuxmix.labels)
balb1.demux.result$demuxmix <-
  balb1.demuxmix.assign[rownames(balb1.demux.result)]

# 7. hashedDrops
balb1.hashedrops.output <- hashedDrops(balb1.hash.count)
# `Best` indexes the rows of the count matrix; translate it to hashtag names.
balb1.hashedrops.best <- balb1.hashedrops.output$Best
balb1.hashedrops.assign <- rownames(balb1.hash.count)[balb1.hashedrops.best]
# Calls that are not flagged confident become "Negative"; doublets are
# relabelled afterwards, so "Doublet" wins when both conditions hold.
balb1.hashedrops.assign[!balb1.hashedrops.output$Confident] <- "Negative"
balb1.hashedrops.assign[balb1.hashedrops.output$Doublet %in% TRUE] <- "Doublet"
# Name by barcode, then reorder to the rows of the result table.
names(balb1.hashedrops.assign) <- colnames(balb1.hash.count)
balb1.demux.result$hashedDrops <-
  balb1.hashedrops.assign[rownames(balb1.demux.result)]

# 8. BFF
# Run both BFF variants through cellhashR.
balb1.bff.output <- GenerateCellHashingCalls(
  barcodeMatrix = balb1.hash.count,
  methods = c("bff_raw", "bff_cluster")
)
# Align the calls with the result table by barcode instead of by row position,
# so a reordered or filtered cellhashR table cannot mislabel droplets.
# NOTE(review): assumes the returned table carries the barcodes in a
# `cellbarcode` column — confirm against the installed cellhashR version. When
# the column is absent the original positional copy is used.
if ("cellbarcode" %in% colnames(balb1.bff.output)) {
  balb1.bff.idx <- match(
    rownames(balb1.demux.result),
    balb1.bff.output$cellbarcode
  )
} else {
  balb1.bff.idx <- seq_len(nrow(balb1.bff.output))
}
balb1.demux.result$BFF_raw <- balb1.bff.output$bff_raw[balb1.bff.idx]
balb1.demux.result$BFF_cluster <- balb1.bff.output$bff_cluster[balb1.bff.idx]

# Dimensional reduction for visualization
# Store the raw hash counts ("hto") and the CLR-normalised values ("clr") as
# two assays of one SingleCellExperiment.
balb1.sce <- SingleCellExperiment(
  assays = list(hto = balb1.hash.count, clr = balb1.clr.norm)
)
# t-SNE and UMAP are stochastic; fix the seed so the embeddings are
# reproducible between runs (same seed value as the deMULTIplex2 call).
set.seed(2025)
balb1.sce <- runTSNE(balb1.sce, exprs_values = "clr")
balb1.sce <- runUMAP(balb1.sce, exprs_values = "clr")

# Precision, recall, F, MCC
# All labels a method can emit: the hashtag names plus the three special classes.
balb1.label.levels <- c(
  rownames(balb1.hash.count),
  "Doublet", "Negative", "Uncertain"
)
# Per-class statistics leave out "Uncertain" and "Negative".
balb1.called.levels <-
  balb1.label.levels[!balb1.label.levels %in% c("Uncertain", "Negative")]
balb1.evaluation <- GTStats(balb1.demux.result, "Vireo", balb1.called.levels)

# Average precision, recall, F, MCC
# The averaged statistics leave out "Uncertain" only.
balb1.certain.levels <- balb1.label.levels[!balb1.label.levels %in% "Uncertain"]
balb1.evaluation.avg <- GTStatsAvg(
  balb1.demux.result,
  "Vireo",
  balb1.certain.levels,
  "micro"
)

# Concordance of each method with the ground truth
balb1.bench.concord <- ConfusionMat(
  balb1.demux.result,
  "Vireo",
  balb1.label.levels
)

# Overall concordance
balb1.ov.concord <- ConcordanceDF(balb1.demux.result, "Vireo")

# Library size
# Per-droplet total hash count, log-transformed after adding a pseudocount of 1.
balb1.hto.total <- colSums(balb1.hash.count)
balb1.hto.lib <- log(balb1.hto.total + 1)
