# Import libraries
library("data.table")
library("ggplot2")
library("magrittr")

# Path to the MIScnn evaluation directory (one sub-directory per CV fold).
# Set the MISCNN_EVAL_PATH environment variable to use a different location.
path <- Sys.getenv(
  "MISCNN_EVAL_PATH",
  unset = "/home/mudomini/projects/KITS_challenge2019/kits19.MIScnn.validation.101019/kits19.MIScnn.validation/evaluation/"
)


# List fold directories only: plain files inside the evaluation directory
# would otherwise be treated as folds and break the loaders below.
folds <- sort(list.dirs(path, full.names = TRUE, recursive = FALSE))
# Fail early with a clear message instead of a cryptic melt()/fread() error.
if (length(folds) == 0L) {
  stop("No fold directories found in evaluation path: ", path, call. = FALSE)
}

# Load the training history (history.tsv) of one CV fold and tag every row
# with the fold identifier, so that all folds can be stacked into one table.
#
# subdir_path: path to a fold directory inside the evaluation directory.
# Returns a data.table with the columns of history.tsv plus a character
# column `fold`.
load_val_data <- function(subdir_path) {
  val_path <- file.path(subdir_path, "history.tsv")
  # Fold id = trailing digits of the directory name (e.g. "fold_12" -> "12").
  # Taking only the last character would truncate multi-digit ids. Fall back
  # to the whole directory name if it has no trailing digits.
  dir_name <- basename(subdir_path)
  fold_id <- regmatches(dir_name, regexpr("[0-9]+$", dir_name))
  if (length(fold_id) == 0L) {
    fold_id <- dir_name
  }
  # A distinct local name keeps `fold := fold_id` unambiguous even if the
  # file already contains a `fold` column.
  fread(val_path, sep = "\t")[, fold := fold_id]
}
# Run data collection: read the history of every fold and stack the
# per-fold tables into a single data.table.
validation <- rbindlist(lapply(folds, load_val_data))
# Load the per-sample KiTS19 scoring table (detailed_validation.tsv) of one
# CV fold and tag every row with the fold identifier.
#
# subdir_path: path to a fold directory inside the evaluation directory.
# Returns a data.table with the columns of detailed_validation.tsv plus a
# character column `fold`.
load_score_data <- function(subdir_path) {
  # Fold id = trailing digits of the directory name (e.g. "fold_12" -> "12").
  # Taking only the last character would truncate multi-digit ids. Fall back
  # to the whole directory name if it has no trailing digits.
  dir_name <- basename(subdir_path)
  fold_id <- regmatches(dir_name, regexpr("[0-9]+$", dir_name))
  if (length(fold_id) == 0L) {
    fold_id <- dir_name
  }
  score_path <- file.path(subdir_path, "detailed_validation.tsv")
  # A distinct local name keeps `fold := fold_id` unambiguous even if the
  # file already contains a `fold` column.
  fread(score_path, sep = "\t")[, fold := fold_id]
}
# Run data collection: read the scoring table of every fold and stack the
# per-fold tables into a single data.table.
scoring <- rbindlist(lapply(folds, load_score_data))








# Plot loss validation figure - Preprocessing
# Reshape to long format: one row per epoch, fold and data set.
loss_df <- melt(validation, id.vars=c("epoch", "fold"),
                measure.vars=c("loss","val_loss"),
                variable.name="data_set",
                value.name="tversky_loss")
# Relabel the melted factor by reference. This keeps the level order
# (Train set, Test set) and avoids the full-table copy that
# `DT[i]$col <- value` performs.
loss_df[, data_set := factor(data_set, levels=c("loss", "val_loss"),
                             labels=c("Train set", "Test set"))]
# Plot loss validation figure - Plotting (variant1): one line per fold and set
plot_loss1 <- ggplot(loss_df, aes(epoch, tversky_loss, group=interaction(fold, data_set),
                                  color=fold, linetype=data_set)) +
  geom_line(size=0.75) +
  #geom_smooth(method="loess") +
  #geom_point(size=0.75) +
  #scale_y_continuous(breaks=seq(0,2.5, 0.1), limits=c(0, 2.5)) +
  #theme_classic() +
  theme_bw() +
  scale_color_brewer(palette="Dark2") +
  labs(linetype="Data Set", color="CV Fold", x="Epoch", y="Tversky Loss") +
  ggtitle("Loss Function during Model Training")
png("loss.var1.png", width=700, height=500, res=120)
# Print explicitly. A bare ggplot object is auto-printed only at top level,
# so a source()d run would otherwise write an empty PNG.
print(plot_loss1)
dev.off()

# Plot loss validation figure - Plotting (variant2): mean loss over all folds
loss_df_var2 <- loss_df[, .(loss_mean=mean(tversky_loss), loss_min=min(tversky_loss),
                            loss_max=max(tversky_loss)),
                        by=c("epoch", "data_set")]
plot_loss2 <- ggplot(loss_df_var2, aes(epoch, loss_mean, group=data_set, color=data_set)) +
  #geom_errorbar(aes(ymin=loss_min, ymax=loss_max), width=0.1, linetype="longdash") +
  geom_line(size=0.6, alpha=0.65) +
  geom_point(size=0.5) +
  #theme_classic() +
  theme_bw() +
  scale_color_brewer(palette="Dark2") +
  labs(color="Data Set", x="Epoch", y="Tversky Loss") +
  ggtitle("Loss Function during Model Training")
png("loss.var2.png", width=700, height=500, res=120)
print(plot_loss2)
dev.off()










# Plot dice validation figure - Preprocessing
# Reshape to long format: one row per epoch, fold and data set.
dice_df <- melt(validation, id.vars=c("epoch", "fold"),
                measure.vars=c("dice_soft","val_dice_soft"),
                variable.name="data_set",
                value.name="dice_soft")
# Relabel the melted factor by reference. This keeps the level order
# (Train set, Test set) and avoids the full-table copy that
# `DT[i]$col <- value` performs.
dice_df[, data_set := factor(data_set, levels=c("dice_soft", "val_dice_soft"),
                             labels=c("Train set", "Test set"))]
# Plot dice validation figure - Plotting (variant1): one line per fold and set
plot_dice1 <- ggplot(dice_df, aes(epoch, dice_soft, group=interaction(fold, data_set),
                                  color=fold, linetype=data_set)) +
  geom_line(size=0.5) +
  #geom_point(size=0.75) +
  #theme_classic() +
  theme_bw() +
  scale_color_brewer(palette="Dark2") +
  labs(linetype="Data Set", color="CV Fold", x="Epoch", y="Soft DSC") +
  ggtitle("Dice Coefficient during Model Training")
png("dice.var1.png", width=700, height=500, res=120)
# Print explicitly. A bare ggplot object is auto-printed only at top level,
# so a source()d run would otherwise write an empty PNG.
print(plot_dice1)
dev.off()

# Plot dice validation figure - Plotting (variant2): mean soft Dice over all folds
dice_df_var2 <- dice_df[, .(dice_mean=mean(dice_soft), dice_min=min(dice_soft),
                            dice_max=max(dice_soft)),
                        by=c("epoch", "data_set")]
plot_dice2 <- ggplot(dice_df_var2, aes(epoch, dice_mean, group=data_set, color=data_set)) +
  #geom_errorbar(aes(ymin=dice_min, ymax=dice_max), width=0.1, linetype="longdash") +
  geom_line(size=0.6, alpha=0.65) +
  geom_point(size=0.5) +
  #theme_classic() +
  theme_bw() +
  scale_color_brewer(palette="Dark2") +
  labs(color="Data Set", x="Epoch", y="Soft DSC") +
  ggtitle("Dice Coefficient during Model Training")
png("dice.var2.png", width=700, height=500, res=120)
print(plot_dice2)
dev.off()






# Plot scoring figure - Preprocessing
# Reshape the per-sample Dice scores to long format: one row per sample and class.
val_df <- melt(scoring, id.vars=c("sample_id"),
               measure.vars=c("dice_class-1","dice_class-2"),
               variable.name="class",
               value.name="dice",
               variable.factor=TRUE)
# Relabel the classes by reference. To also show the background, add
# "dice_class-0" to measure.vars and to levels/labels below ("Background").
val_df[, class := factor(class, levels=c("dice_class-1", "dice_class-2"),
                         labels=c("Kidney", "Tumor"))]
# Plot scoring figure: Dice distribution per class as boxplots
plot_score <- ggplot(val_df, aes(class, dice, fill=class)) +
  geom_boxplot() +
  scale_y_continuous(breaks=seq(0, 1, 0.1), limits=c(0, 1)) +
  #theme_classic() +
  theme_bw() +
  scale_fill_brewer(palette="Dark2") +
  theme(legend.position = "none") +
  labs(x = "", y="Dice Similarity Coefficient") +
  ggtitle("Dice Coefficients from Validation")
png("score.png", width=700, height=500, res=120)
# Print explicitly. A bare ggplot object is auto-printed only at top level,
# so a source()d run would otherwise write an empty PNG.
print(plot_score)
dev.off()



# Multiple plot function
#
# Draw several ggplot objects on the current graphics device using a grid
# layout.
#
# ggplot objects can be passed in ..., or to plotlist (as a list of ggplot objects)
# - cols:   Number of columns in layout
# - layout: A matrix specifying the layout. If present, 'cols' is ignored.
# - file:   Unused; kept only so existing calls with this argument still work.
#
# If the layout is something like matrix(c(1,2,3,3), nrow=2, byrow=TRUE),
# then plot 1 will go in the upper left, 2 will go in the upper right, and
# 3 will go all the way across the bottom.
#
# Return value:
# - invisible NULL when there is nothing to draw;
# - the value of print() for a single plot;
# - NULL otherwise.
#
multiplot <- function(..., plotlist=NULL, file, cols=1, layout=NULL) {
  # Make a list from the ... arguments and plotlist
  plots <- c(list(...), plotlist)
  num_plots <- length(plots)

  # Nothing to draw: return early instead of failing on an empty layout.
  if (num_plots == 0L) {
    return(invisible(NULL))
  }

  # If layout is NULL, then use 'cols' to determine layout
  if (is.null(layout)) {
    # nrow: number of rows needed for 'cols' columns
    n_rows <- ceiling(num_plots / cols)
    layout <- matrix(seq_len(cols * n_rows), ncol = cols, nrow = n_rows)
  }

  if (num_plots == 1L) {
    return(print(plots[[1]]))
  }

  # Namespace-qualified grid calls instead of library(grid), so calling this
  # function does not attach grid to the search path as a side effect.
  grid::grid.newpage()
  grid::pushViewport(grid::viewport(
    layout = grid::grid.layout(nrow(layout), ncol(layout))
  ))

  # Make each plot, in the correct location
  for (i in seq_len(num_plots)) {
    # Get the i,j matrix positions of the regions that contain this subplot
    matchidx <- as.data.frame(which(layout == i, arr.ind = TRUE))
    print(plots[[i]], vp = grid::viewport(layout.pos.row = matchidx$row,
                                          layout.pos.col = matchidx$col))
  }
}

# Combine the loss, dice and score figures side by side in one image.
# The layout spans 8 columns: loss and dice get 3 each, the score boxplot 2.
png("all.png", width=1800, height=500, res=140)
layout <- matrix(rep(c(1, 2, 3), times=c(3, 3, 2)), nrow=1)
multiplot(plotlist=list(plot_loss2, plot_dice2, plot_score), cols=3, layout=layout)
dev.off()

# Calculate scores: mean and median Dice per class over all validation
# samples of all folds. Columns are referenced directly inside j (backticks
# for the hyphenated names). print() makes the result visible even when the
# script is source()d.
print(scoring[, .(background_mean=mean(`dice_class-0`), background_median=median(`dice_class-0`),
                  kidney_mean=mean(`dice_class-1`), kidney_median=median(`dice_class-1`),
                  tumor_mean=mean(`dice_class-2`), tumor_median=median(`dice_class-2`))])

# Return the row(s) of the final training epoch of one CV fold.
#
# fold_numb: fold identifier. It is compared with the character `fold` column
#   via `==`, so both 0 and "0" select fold "0".
# Reads the global `validation` data.table; returns an empty table for an
# unknown fold.
val_scores_calc <- function(fold_numb){
  fold_rows <- validation[fold == fold_numb]
  if (nrow(fold_rows) == 0L) {
    return(fold_rows)
  }
  # Select by the largest epoch value rather than the row count. The row
  # count equals the last epoch only when epochs are numbered 1..N with no gaps.
  fold_rows[epoch == max(epoch)]
}
# Average the final-epoch training metrics over all CV folds.
# Iterating over the folds actually present in `validation` keeps the
# summary correct for any number of folds. print() makes the result visible
# even when the script is source()d.
rbindlist(lapply(unique(validation$fold), val_scores_calc)) %>%
  .[, .(loss=mean(loss), val_loss=mean(val_loss),
        dice_soft=mean(dice_soft), val_dice_soft=mean(val_dice_soft),
        dice_crossentropy=mean(dice_crossentropy), val_dice_crossentropy=mean(val_dice_crossentropy))] %>%
  print()
