library(CMDdemux)
library(Seurat)
library(scran)
library(scater)
library(deMULTIplex2)
library(demuxmix)
library(cellhashR)
library(DropletUtils)
library(stringr)

# Load the benchmark helpers. GTStats, GTStatsAvg, ConfusionMat and
# ConcordanceDF are called below but not defined in this script, so they are
# presumably provided by bench.R.
# The path must be a quoted string: a bare ~/bench.R does not parse.
source("~/bench.R")
# NOTE(review): these .rdata files are expected to define hu.hash.count and
# hu.gex.count, which are used below without being created here -- confirm.
load("~/hu.hash.count.rdata")
load("~/hu.gex.count.rdata")
# Demuxlet output table; the BEST and BARCODE columns are read below.
hu.demuxlet <- read.delim("~/experiment1_demuxlet.best")

# 1. Demuxlet
# Collapse the BEST calls into the shared label set: anything containing "AMB"
# becomes "Negative", anything containing "DBL" becomes "Doublet", and the
# "SNG-" prefix is stripped from the remaining calls.
hu.donor.assign <- hu.demuxlet$BEST
amb.idx <- grepl("AMB", hu.donor.assign)
hu.donor.assign <- replace(hu.donor.assign, amb.idx, "Negative")
dbl.idx <- grepl("DBL", hu.donor.assign)
hu.donor.assign <- replace(hu.donor.assign, dbl.idx, "Doublet")
hu.donor.assign <- setNames(
  gsub("SNG-", "", hu.donor.assign),
  hu.demuxlet$BARCODE
)
# Start the result table with one row per column of the hash count matrix,
# looked up by barcode name.
hu.demux.result <- data.frame(
  Demuxlet = hu.donor.assign[colnames(hu.hash.count)]
)

# 2. CMDdemux
# Runs the CMDdemux steps in sequence on the hash count matrix; each step
# consumes the outputs of the previous ones. Argument meanings are defined by
# the CMDdemux package -- NOTE(review): the positional arguments below are
# kept exactly as given; check the package docs before changing any of them.
hu.clr.norm <- LocalCLRNorm(hu.hash.count)
hu.kmed.cl <- KmedCluster(hu.clr.norm)
hu.cl.dist <- EuclideanClusterDist(hu.clr.norm, hu.kmed.cl)
# Eight cut-off values are supplied (presumably one per cluster/hashtag).
hu.noncore <- DefineNonCore(
  hu.cl.dist,
  hu.kmed.cl,
  c(0.9, 0.9, 0.75, 0.92, 0.95, 0.94, 0.95, 0.92)
)
hu.cluster.assign <- LabelClusterHTO(
  hu.clr.norm,
  hu.kmed.cl,
  hu.noncore,
  "medoids"
)
hu.md.mat <- CalculateMD(
  hu.clr.norm,
  hu.noncore,
  hu.kmed.cl,
  hu.cluster.assign
)
hu.outlier.assign <- AssignOutlierDrop(hu.md.mat, md_cut_q = 0.94)
hu.cmddemux.assign <- CMDdemuxClass(
  hu.md.mat,
  hu.hash.count,
  hu.outlier.assign,
  TRUE,
  hu.gex.count,
  3,
  2
)
# Store the global classification as this method's column.
hu.demux.result[["CMDdemux"]] <- hu.cmddemux.assign$demux_global_class

# 3. HTOdemux
# The hash counts are used both to create the Seurat object and as a
# dedicated "HTO" assay, which is then CLR-normalised and demultiplexed.
hu.obj <- CreateSeuratObject(counts = hu.hash.count)
hu.obj[["HTO"]] <- CreateAssayObject(counts = hu.hash.count)
hu.obj <- NormalizeData(
  hu.obj,
  assay = "HTO",
  normalization.method = "CLR"
)
hu.obj <- HTODemux(
  hu.obj,
  assay = "HTO",
  positive.quantile = 0.99
)
# Pull the per-cell hash.ID calls and align them to the result table by barcode.
hu.htodemux.id <- hu.obj$hash.ID
hu.demux.result[["HTODemux"]] <- hu.htodemux.id[rownames(hu.demux.result)]

# 4. GMM-Demux
# Prepare input data
# GMM-Demux is run outside R on the transposed hash count matrix.
hu.gmm.input <- t(hu.hash.count)
write.csv(hu.gmm.input, "~/hu.gmm.input.csv", quote = FALSE)
# Command: GMM-demux -c ~/hu.gmm.input.csv S1HuF,S2HuM,S3HuF,S4HuM,S5HuF,S6HuM,S7HuF,S8HuM -x S1HuF,S2HuM,S3HuF,S4HuM,S5HuF,S6HuM,S7HuF,S8HuM -f .
hu.gmm.output <- read.csv("~/GMM_full.csv")
# Translate the numeric cluster ids into labels:
#   0                        -> "Negative"
#   1 .. nrow(hu.hash.count) -> the hashtag name at that row index
#   any other id             -> "Doublet"
# NOTE(review): this assumes GMM-Demux numbers singlet clusters in the row
# order of hu.hash.count -- confirm against the GMM_full.config file.
# Missing ids stay NA, as in the original loop, which skipped them.
hu.gmm.cluster <- hu.gmm.output$Cluster_id
hu.gmm.is.negative <- hu.gmm.cluster %in% 0
# seq_len() is safe when the matrix has zero rows (1:0 would yield c(1, 0)).
hu.gmm.is.singlet <- hu.gmm.cluster %in% seq_len(nrow(hu.hash.count))
hu.gmm.demux <- rep(NA_character_, length(hu.gmm.cluster))
hu.gmm.demux[!is.na(hu.gmm.cluster)] <- "Doublet"
hu.gmm.demux[hu.gmm.is.negative] <- "Negative"
hu.gmm.demux[hu.gmm.is.singlet] <-
  rownames(hu.hash.count)[hu.gmm.cluster[hu.gmm.is.singlet]]
# Column X of the CSV holds the barcodes (the unnamed index column on read).
names(hu.gmm.demux) <- hu.gmm.output$X
hu.demux.result$`GMM-Demux` <- hu.gmm.demux[rownames(hu.demux.result)]

# 5. deMULTIplex2
# Takes the transposed hash count matrix (hu.gmm.input, built in the
# GMM-Demux section) and uses a fixed seed.
hu.demultiplex2.output <- demultiplexTags(
  hu.gmm.input,
  plot.diagnostics = FALSE,
  seed = 2024
)
hu.demultiplex2.assign <- hu.demultiplex2.output$final_assign
# Rename the method-specific labels to the shared "Doublet"/"Negative" labels.
is.multiplet <- hu.demultiplex2.assign %in% "multiplet"
hu.demultiplex2.assign[is.multiplet] <- "Doublet"
is.negative <- hu.demultiplex2.assign %in% "negative"
hu.demultiplex2.assign[is.negative] <- "Negative"
hu.demux.result[["deMULTIplex2"]] <-
  hu.demultiplex2.assign[rownames(hu.demux.result)]

# 6. demuxEM
# demuxEM -p 8 --random-state 2024 ~/experiment1_human_st_raw_10x.h5 ~/experiment1_human_st_ADT.csv hu_demuxEM
# Python command
#import pandas as pd
#import pegasusio as io
#data = io.read_input('~/hu_demuxEM.out.demuxEM.zarr.zip')
#df = pd.DataFrame(data.obs['demux_type'])
#df.to_csv('~/hu.demuxEM.demux.txt', sep='\t', index=True)
#df = pd.DataFrame(data.obs['assignment'])
#df.to_csv('~/hu.demuxEM.assign.txt', sep='\t', index=True)
# Both files are written tab-separated by the Python snippet above, so both
# are read with sep = "\t".
hu.demuxEM.out1 <- read.table(
  file = "~/hu.demuxEM.demux.txt",
  sep = "\t",
  header = TRUE
)
hu.demuxEM.out2 <- read.table(
  file = "~/hu.demuxEM.assign.txt",
  sep = "\t",
  header = TRUE
)
hu.demuxEM.assign <- hu.demuxEM.out1$demux_type
# Append "-1" so the names can be looked up with the result table's row names.
names(hu.demuxEM.assign) <- paste0(hu.demuxEM.out1$barcodekey, "-1")
# Treat "unknown" as "Negative"
hu.demuxEM.assign[hu.demuxEM.assign %in% "unknown"] <- "Negative"
hu.demuxEM.assign[hu.demuxEM.assign %in% "doublet"] <- "Doublet"
# Singlets take their sample name from the assignment table. Rows are matched
# on barcodekey, not on position, so a different row order in the two files
# cannot silently mislabel cells.
singlet.idx <- which(hu.demuxEM.out1$demux_type %in% "singlet")
hu.demuxEM.assign.row <- match(
  hu.demuxEM.out1$barcodekey[singlet.idx],
  hu.demuxEM.out2$barcodekey
)
hu.demuxEM.assign[singlet.idx] <-
  hu.demuxEM.out2$assignment[hu.demuxEM.assign.row]
hu.demux.result$demuxEM <- hu.demuxEM.assign[rownames(hu.demux.result)]

# 7. demuxmix
# Per column of the GEX matrix, count the rows with a non-zero value
# (presumably detected genes per cell) and pass it as the rna covariate.
hu.gex.genes <- colSums(hu.gex.count > 0)
hu.demuxmix.model <- demuxmix(hto = hu.hash.count, rna = hu.gex.genes)
hu.demuxmix.labels <- dmmClassify(hu.demuxmix.model)
# Start from the HTO column, then rename to the shared labels. The Type-based
# doublet relabelling is applied last so it overrides the HTO label.
hu.demuxmix.assign <- hu.demuxmix.labels$HTO
hu.demuxmix.assign[hu.demuxmix.assign %in% "uncertain"] <- "Uncertain"
hu.demuxmix.assign[hu.demuxmix.assign %in% "negative"] <- "Negative"
hu.demuxmix.assign[hu.demuxmix.labels$Type %in% "multiplet"] <- "Doublet"
hu.demuxmix.assign <- setNames(
  hu.demuxmix.assign,
  rownames(hu.demuxmix.labels)
)
hu.demux.result[["demuxmix"]] <- hu.demuxmix.assign[rownames(hu.demux.result)]

# 8. hashedDrops
hu.hashedrops.output <- hashedDrops(hu.hash.count)
# Best indexes into the rows of the hash count matrix; convert it to a name.
hu.hashedrops.assign <- rownames(hu.hash.count)[hu.hashedrops.output$Best]
# Non-confident calls become "Uncertain"; doublet calls are applied afterwards
# and therefore take precedence.
hu.hashedrops.assign <- replace(
  hu.hashedrops.assign,
  !hu.hashedrops.output$Confident,
  "Uncertain"
)
hu.hashedrops.assign <- replace(
  hu.hashedrops.assign,
  which(hu.hashedrops.output$Doublet),
  "Doublet"
)
hu.hashedrops.assign <- setNames(
  hu.hashedrops.assign,
  colnames(hu.hash.count)
)
hu.demux.result[["hashedDrops"]] <-
  hu.hashedrops.assign[rownames(hu.demux.result)]

# 9. BFF
hu.bff.output <- GenerateCellHashingCalls(
  barcodeMatrix = hu.hash.count,
  methods = c("bff_raw", "bff_cluster")
)
# Align the calls to the result table by barcode, as every other method in
# this script does, instead of relying on row position.
# NOTE(review): cellhashR is expected to return the barcodes in a
# `cellbarcode` column -- confirm. If the column is absent, fall back to the
# positional copy, which requires matching row counts.
if ("cellbarcode" %in% colnames(hu.bff.output)) {
  hu.bff.row <- match(rownames(hu.demux.result), hu.bff.output$cellbarcode)
} else {
  stopifnot(nrow(hu.bff.output) == nrow(hu.demux.result))
  hu.bff.row <- seq_len(nrow(hu.demux.result))
}
hu.demux.result$BFF_raw <- hu.bff.output$bff_raw[hu.bff.row]
hu.demux.result$BFF_cluster <- hu.bff.output$bff_cluster[hu.bff.row]

# Dimensional reduction for visualization
# Both embeddings are computed from the "clr" assay.
hu.sce <- SingleCellExperiment(
  assays = list(hto = hu.hash.count, clr = hu.clr.norm)
)
# t-SNE and UMAP are stochastic. Seed each run so the embeddings are
# reproducible, using the same seed (2024) as the rest of this script.
set.seed(2024)
hu.sce <- runTSNE(hu.sce, exprs_values = "clr")
set.seed(2024)
hu.sce <- runUMAP(hu.sce, exprs_values = "clr")

# Precision, recall, F, MCC
# Label set: every hashtag name plus the three shared non-singlet labels.
# The ground-truth column of hu.demux.result is "Demuxlet" (capital D); the
# same spelling is passed to every helper below.
hu.label.levels <- c(rownames(hu.hash.count), "Doublet", "Negative", "Uncertain")
hu.evaluation <- GTStats(
  hu.demux.result,
  "Demuxlet",
  hu.label.levels[!hu.label.levels %in% c("Uncertain", "Negative")]
)

# Average precision, recall, F, MCC
# Only "Uncertain" is excluded here ("Negative" is kept).
hu.evaluation.avg <- GTStatsAvg(
  hu.demux.result,
  "Demuxlet",
  hu.label.levels[!hu.label.levels %in% "Uncertain"],
  "micro"
)

# Concordance of each method with the ground truth
hu.bench.concord <- ConfusionMat(hu.demux.result, "Demuxlet", hu.label.levels)

# Overall concordance
hu.ov.concord <- ConcordanceDF(hu.demux.result, "Demuxlet")

# Library size
# Log-transformed column totals of the hash and GEX count matrices.
# NOTE(review): the hash totals get a +1 pseudo-count but the GEX totals do
# not, so a column with zero GEX counts yields -Inf -- confirm this is intended.
hu.hto.lib <- log(1 + colSums(hu.hash.count))
hu.gex.lib <- log(colSums(hu.gex.count))
