library(ggplot2)
library(dplyr)
library(tidyr)
library(patchwork)
library(scales)
library(fmsb)
# Project root. Set the SCGEN_DIR environment variable to run this script from
# another checkout; the author's original path is used when it is unset/empty.
project_dir <- Sys.getenv("SCGEN_DIR")
if (!nzchar(project_dir)) {
  project_dir <- "/Users/fang/Documents/Projects/HF/scGEN/re"
}
setwd(project_dir)

##################Fig2###################
# Colour assigned to each method, keyed by the method's display name.
cols <- c(
  "AttentionAE-sc"      = "#A6EBEB",
  "GraphSCC"            = "pink",
  "k-means"             = "#F2ECA2",
  "scDCCA"              = "#3E76AB",
  "scDEFR"              = "#718D8A",
  "scDeepCluster"       = "#D59E71",
  "scTPC"               = "#b96d62",
  "Spectral clustering" = "#467952",
  "SSNMDI"              = "#B7B2D0",
  "scGEN"               = "#29274A"
)

df <- read.csv(
  file = "./Github/Rcode and results/Performance/performance.csv",
  header = TRUE,
  # Keep method headers such as "k-means" or "Spectral clustering" verbatim so
  # they match the names of `cols`, instead of being mangled by make.names()
  # and then patched back by hard-coded row positions.
  check.names = FALSE
)

# A ----

# Reshape one metric of `perf` into the layout fmsb::radarchart() expects:
# row 1 holds the axis maxima, row 2 the minima, and every following row is a
# method (row name) with one column per dataset.
#
# perf:    performance table with `Metric` and `Dataset` columns plus one
#          column per method.
# metric:  value of `Metric` to keep (e.g. "ACC").
# max_val, min_val: radar axis limits written into rows 1 and 2.
# Returns a data.frame; rows 1-2 are named "1" and "2".
radar_table <- function(perf, metric, max_val = 100, min_val = 0) {
  # which() drops NA matches, mirroring subset()
  rows <- perf[which(perf$Metric == metric), , drop = FALSE]
  method_cols <- setdiff(names(rows), c("Metric", "Dataset"))
  scores <- as.matrix(rows[, method_cols, drop = FALSE])
  rownames(scores) <- rows$Dataset
  scores <- t(scores)  # methods x datasets
  n_datasets <- ncol(scores)
  out <- rbind(rep(max_val, n_datasets), rep(min_val, n_datasets), scores)
  rownames(out)[1:2] <- c("1", "2")
  # check.names = FALSE keeps dataset names intact; they become axis labels
  data.frame(out, check.names = FALSE)
}

# Draw one radar chart for a table built by radar_table(). Each method is
# coloured by name via `colours`, so the result does not depend on the CSV
# column order matching the palette order.
plot_radar <- function(radar_df, metric, colours = cols) {
  methods <- rownames(radar_df)[-(1:2)]
  unknown <- setdiff(methods, names(colours))
  if (length(unknown) > 0) {
    stop("No colour defined for method(s): ",
         paste(unknown, collapse = ", "), call. = FALSE)
  }
  radarchart(radar_df, axistype = 1, axislabcol = "grey",
             pcol = unname(colours[methods]), plwd = 2, plty = 1,
             cglcol = "grey", cglwd = 1, cglty = 5, vlcex = 1.5,
             title = metric)
}

ACC <- radar_table(df, "ACC")
plot_radar(ACC, "ACC")

ARI <- radar_table(df, "ARI")
plot_radar(ARI, "ARI")

NMI <- radar_table(df, "NMI")
plot_radar(NMI, "NMI")



# B ----

# Box plot of one clustering metric: one box per method, one point per dataset.
#
# df:            table laid out for radarchart (as built for Fig 2A): rows 1-2
#                are axis max/min and are dropped here; the remaining rows are
#                methods (row names) and every column is a dataset.
# val:           metric label, used as the y-axis title (e.g. "ACC").
# colours:       named fill colours per method (defaults to the global `cols`).
# method_levels: x-axis order of methods (defaults to the order of `colours`).
# Returns a ggplot object.
box_plot <- function(df, val, colours = cols, method_levels = names(colours)) {
  scores <- df[-(1:2), , drop = FALSE]
  scores$methods <- rownames(scores)
  # Any number of dataset columns is supported; previously fixed at 8.
  long <- pivot_longer(scores, cols = -methods,
                       names_to = "datasets", values_to = "value")
  long$methods <- factor(long$methods, levels = method_levels)

  ggplot(long, aes(x = methods, y = value, fill = methods)) +
    # geom_point() already draws every observation, including outliers, so
    # hide the boxplot's own outlier markers to avoid drawing them twice.
    geom_boxplot(outlier.shape = NA) +
    geom_point(size = 0.5) +
    theme_bw() +
    labs(x = "", y = val) +
    scale_fill_manual(values = colours) +
    theme(
      axis.text.x = element_text(angle = 45, hjust = 1, vjust = 1),
      legend.position = "none"
    )
}
# One box plot per clustering metric.
box_plot(df = ACC, val = "ACC")
box_plot(df = ARI, val = "ARI")
box_plot(df = NMI, val = "NMI")

# Ablation ----
# Mean score difference (full scGEN minus ablated variant) per metric, averaged
# over the eight datasets. Every vector lists the datasets in the same order;
# the full-model scores are stored once instead of being retyped per variant.
scgen_full <- list(
  ACC = c(90.29, 56.54, 69.68, 96.41, 86.03, 93.09, 90.64, 77.83),
  NMI = c(77.81, 48.07, 62.76, 96.22, 79.31, 93.23, 85.74, 60.79),
  ARI = c(77.14, 29.12, 58.59, 96.32, 69.39, 88.95, 90.03, 62.87)
)

scgen_ablated <- list(
  "w/o ZINB" = list(
    ACC = c(88.83, 52.18, 66.59, 81.06, 82.56, 85.87, 68.25, 68.21),
    NMI = c(74.65, 44.94, 60.15, 87.92, 77.66, 89.19, 76.35, 56.67),
    ARI = c(74.02, 28.04, 55.08, 81.18, 65.17, 83.19, 66.00, 55.47)
  ),
  "w/o HSL" = list(
    ACC = c(89.32, 55.82, 58.77, 88.14, 78.28, 85.82, 75.59, 67.62),
    NMI = c(75.68, 48.75, 56.00, 90.33, 76.07, 84.44, 75.03, 54.85),
    ARI = c(75.09, 29.20, 42.66, 87.78, 64.64, 77.92, 67.59, 54.50)
  )
)

# Matrix: rows = metrics (ACC, NMI, ARI), columns = ablated variants.
ablation_gain <- vapply(
  scgen_ablated,
  function(variant) {
    vapply(names(scgen_full), function(metric) {
      # guard against a dataset being dropped from one vector by mistake
      stopifnot(length(variant[[metric]]) == length(scgen_full[[metric]]))
      mean(scgen_full[[metric]] - variant[[metric]])
    }, numeric(1))
  },
  numeric(length(scgen_full))
)

# Explicit print so the summary also appears when the script is source()d.
print(ablation_gain)


