# XGBoost models for one-vs-rest classification of defense-system types using prophage features
library(xgboost) 
library(pROC)
library(dplyr)
library(irr)          # to obtain kappa
library(caret)


# ---- Setup -------------------------------------------------------------------
# Start from an empty workspace so objects left over from an interactive
# session cannot leak into the analysis. all.names = TRUE also removes hidden
# objects such as .Random.seed, which set.seed() then re-creates.
rm(list = ls(all.names = TRUE))
set.seed(3)

# Load the input feature matrix (tab-separated, with a header row). Column 2 of
# the file supplies the row names, and the first remaining column is dropped.
# The loop below searches a `types` column for each defense-system name.
df <- read.table(
  file = "matrix_90_ab_ml_freqmlst8_wored100_onlyph.tsv",
  header = TRUE, sep = "\t", row.names = 2
)
df <- df[, -1, drop = FALSE]
if (!"types" %in% names(df)) {
  stop("Input matrix has no `types` column.", call. = FALSE)
}

# Defense-system types, each modelled separately; every name is searched for
# in `df$types`.
type_list <- c("Cas", "CBASS", "Gabija", "RosmerTA", "R-M", "Gao_Qat",
               "PD-T4-5", "PD-T7-5", "Ssp")
roc_list <- list()

# Empty performance table with the reported column names (one row per type).
# %CC = percentage of correctly classified test samples.
perf_cols <- c("Types", "Kappa", "%CC", "AUC", "Sensitivity", "Specificity",
               "Precision", "F1")
df_performance <- data.frame(matrix(nrow = 0, ncol = length(perf_cols)))
colnames(df_performance) <- perf_cols

# ---- Per-type helpers ---------------------------------------------------------

# Load the saved random-search grid for `type` (object `gridxgb` stored in
# gridxgb_<type>.RData). Return its best row as a one-row data frame: the
# lowest cross-validated error among rows whose error is not exactly 0.
# Loading into a private environment keeps load() from overwriting variables
# of this script.
# NOTE(review): rows with minerror == 0 are excluded on purpose, presumably to
# discard degenerate/overfit fits; confirm the rationale.
load_best_hyperparams <- function(type) {
  grid_env <- new.env()
  load(file = paste0("gridxgb_", type, ".RData"), envir = grid_env)
  grid <- grid_env$gridxgb
  if (is.null(grid)) {
    stop(sprintf("gridxgb_%s.RData does not contain `gridxgb`.", type),
         call. = FALSE)
  }
  keep <- !is.na(grid$minerror) & grid$minerror != 0
  grid <- grid[keep, , drop = FALSE]
  if (nrow(grid) == 0) {
    stop(sprintf("No usable hyperparameter rows for type '%s'.", type),
         call. = FALSE)
  }
  # which.min() returns the first minimum, which is the same row that
  # order(...)[1] picks on ties.
  grid[which.min(grid$minerror), , drop = FALSE]
}

# Fit the final binary XGBoost model on (x, y) using the hyperparameters in
# the one-row data frame `hp`. scale_pos_weight = 1 is XGBoost's default
# (no re-weighting of the positive class).
fit_best_xgb <- function(x, y, hp, scale_pos_weight = 1) {
  xgboost(
    data = x, label = y,
    nrounds = hp$optntrees,
    objective = "binary:logistic",
    scale_pos_weight = scale_pos_weight,
    eta = hp$eta,
    max_depth = hp$max_depth,
    min_child_weight = hp$min_child_weight,
    subsample = hp$subsample,
    colsample_bytree = hp$colsample_bytree,
    gamma = hp$gamma,
    lambda = hp$lambda,
    alpha = hp$alpha,
    verbose = 1
  )
}

# Build one row of the performance table from the test labels `obs`, the 0/1
# predictions `pred` and an already-rounded AUC. Recall is not reported
# separately because, for the positive class, it equals Sensitivity.
summarise_performance <- function(type, obs, pred, auc_value) {
  cm <- confusionMatrix(
    factor(pred, levels = c(0, 1)),
    factor(obs, levels = c(0, 1)),
    positive = "1"
  )
  data.frame(
    Types = type,
    Kappa = round(cm$overall[["Kappa"]], 2),
    `%CC` = round(100 * cm$overall[["Accuracy"]], 1),
    AUC = auc_value,
    Sensitivity = round(cm$byClass[["Sensitivity"]], 2),
    Specificity = round(cm$byClass[["Specificity"]], 2),
    Precision = round(cm$byClass[["Precision"]], 2),
    F1 = round(cm$byClass[["F1"]], 2),
    check.names = FALSE # keep the "%CC" header as written
  )
}

# ---- One model per defense-system type ----------------------------------------
perf_rows <- list()

for (t in type_list) {
  print(t)

  # One-vs-rest labels: rows whose `types` text contains `t` are positives (1),
  # and all other rows are negatives (0). fixed = TRUE matches `t` literally
  # rather than as a regular expression.
  is_pos <- grepl(t, df$types, fixed = TRUE)
  if (!any(is_pos)) {
    stop(sprintf("No rows annotated with type '%s'.", t), call. = FALSE)
  }
  Pres <- df[is_pos, ]
  Pres$types <- 1L
  Aus <- df[!is_pos, ]
  Aus$types <- 0L
  # Positives first, then negatives. Row order feeds the random split below,
  # so this ordering is kept stable.
  data <- rbind(Pres, Aus)

  # 70/30 random (non-stratified) train/test split.
  m <- sample.int(n = nrow(data), size = floor(0.7 * nrow(data)),
                  replace = FALSE)
  traindata <- data[m, ]
  testdata <- data[-m, ]
  # Predictors as a numeric matrix; labels as a 0/1 vector.
  traindatax <- as.matrix(subset(traindata, select = -types))
  traindatay <- traindata$types
  testdatax <- as.matrix(subset(testdata, select = -types))
  testdatay <- testdata$types
  # Class weighting and the ROC curve both need each class on each side.
  if (length(unique(traindatay)) < 2 || length(unique(testdatay)) < 2) {
    stop(sprintf("Train/test split for '%s' lacks one of the classes.", t),
         call. = FALSE)
  }

  # ---- Hyperparameter random search (commented out) ----
  # This code is run once per type; its results are saved as
  # gridxgb_<type>.RData and loaded below.
  # NOTE(review): `c(.5, 1, by=0.15)` below builds the 3-element vector
  # c(0.5, 1, 0.15), not a sequence; seq(0.5, 1, by = 0.15) was probably
  # intended. Presumably the saved grids were produced by this code as
  # written; confirm before re-running.
  # NOTE(review): `t == "CAS"` below never matches "Cas" in type_list, so the
  # Cas search ran without scale_pos_weight, even though the final fit below
  # now weights Cas. Re-run the search with weighting for full consistency.
  # # Grid search for hyperparameters optimization
  #   gridxgb <- expand.grid(eta = c(1e-3, 1e-2, 1e-1), # Learning rate
  #                          max_depth = c(1, 5, 10), # Max depth of the tree
  #                          min_child_weight = c(1, 5), # Minimal number of requiered samples in each terminal node
  #                          subsample = c(.5, 1, by=0.15), # Training subset percentage for each tree
  #                          colsample_bytree= c(.5, 1, by=0.15), # Percentage of features for each tree
  #                          gamma = c(1e-3, 1e-2, 1e-1,1, 5), # Regularization (loss reduction)
  #                          lambda = c(0, 1e-3, 1e-2, 1e-1, 1, 10), # Regularization L2
  #                          alpha = c( 0, 1e-3, 1e-2, 1e-1, 1, 10), # Regularization L1
  # #
  #                          CC = 0,
  #                          optntrees = 0, # Save results
  #                          minerror = 0) # Save results
  #   ncombs <- 10000 # No. random combinations to try
  #   nc <- sample(nrow(gridxgb), ncombs) # Random index
  #   gridxgb <- gridxgb[nc, ] # Random hyperparameters
  # #
  #   if (t == "CAS" || t == 'Gabija'){ # Unbalanced data
  #     ratio<- nrow(Aus)/nrow(Pres)
  #     for(i in 1:ncombs){
  #       fitxgb <- xgb.cv(data=traindatax, label=traindatay, nrounds = 2000, nthread = 10,
  #                        metrics = "error", early_stopping_rounds = 10, nfold = 5,
  #                        scale_pos_weight = ratio,
  #                        verbose = F,
  #                        objective = "binary:logistic",
  #                        prediction = T,
  #                        eta = gridxgb$eta[i],
  #                        max_depth = gridxgb$max_depth[i],
  #                        min_child_weight = gridxgb$min_child_weight[i],
  #                        subsample = gridxgb$subsample[i],
  #                        colsample_bytree = gridxgb$colsample_bytree[i],
  #                        gamma = gridxgb$gamma[i],
  #                        lambda = gridxgb$lambda[i],
  #                        alpha = gridxgb$alpha[i])
  # #
  #       # Saving results
  #       gridxgb$optntrees[i] <- fitxgb$best_iteration
  #       gridxgb$minerror[i] <- min(fitxgb$evaluation_log$test_error_mean)
  #       gridxgb$CC[i] <- 100 - round(min(fitxgb$evaluation_log$test_error_mean * 100))
  # #
  #       if (i %% 100 == 0){
  #         print(i) # Print iteration number for monitoring the process
  #       }
  #     }
  #    }else{
  #     for(i in 1:ncombs){
  #       fitxgb <- xgb.cv(data=traindatax, label=traindatay, nrounds = 2000, nthread = 10,
  #                        metrics = "error", early_stopping_rounds = 10, nfold = 5,
  #                        verbose = F,
  #                        objective = "binary:logistic",
  #                        prediction = T,
  #                        eta = gridxgb$eta[i],
  #                        max_depth = gridxgb$max_depth[i],
  #                        min_child_weight = gridxgb$min_child_weight[i],
  #                        subsample = gridxgb$subsample[i],
  #                        colsample_bytree = gridxgb$colsample_bytree[i],
  #                        gamma = gridxgb$gamma[i],
  #                        lambda = gridxgb$lambda[i],
  #                        alpha = gridxgb$alpha[i])
  # 
  #       # Saving results
  #       gridxgb$optntrees[i] <- fitxgb$best_iteration
  #       gridxgb$minerror[i] <- min(fitxgb$evaluation_log$test_error_mean)
  #       gridxgb$CC[i] <- 100 - round(min(fitxgb$evaluation_log$test_error_mean * 100))
  # 
  #       if (i %% 100 == 0){
  #         print(i) # Print iteration number for monitoring the process
  #       }
  # 
  #     }
  #   }
  # save(gridxgb, file = paste0("gridxgb_", t,".RData"))

  besthppsxgb <- load_best_hyperparams(t)

  # Cas and Gabija are treated as unbalanced. Up-weight positives by
  # negatives/positives, counted on the training labels only, so that no
  # test-set information enters the fit. Other types use the default weight 1.
  ratio <- if (t %in% c("Cas", "Gabija")) {
    sum(traindatay == 0) / sum(traindatay == 1)
  } else {
    1
  }
  bestmodelxgb <- fit_best_xgb(traindatax, traindatay, besthppsxgb,
                               scale_pos_weight = ratio)
  prob <- predict(bestmodelxgb, testdatax)
  pred <- as.integer(prob > 0.5)

  # direction = "<": controls (0) are expected to score lower than cases (1).
  # The pROC default "auto" instead picks the direction from the test data
  # itself, which can flip the curve and overstate the AUC of a poor model.
  roc_obj <- roc(response = testdatay, predictor = prob, levels = c("0", "1"),
                 direction = "<", quiet = TRUE)
  AUC <- round(as.numeric(auc(roc_obj)), 2)

  roc_plot <- ggroc(roc_obj, color = "black", size = 1.2) +
    ggtitle(sprintf("ROC Curve %s (AUC = %.2f)", t, AUC)) +
    theme_minimal()
  roc_list[[t]] <- roc_plot
  # Also bind the plot to a variable named after the type. The plotting
  # section after the loop looks the plots up by these names.
  assign(t, roc_plot)

  perf_rows[[t]] <- summarise_performance(t, testdatay, pred, AUC)

  # RDS is an R-specific format that preserves all components of the model.
  model_data <- list(model = bestmodelxgb, pres = Pres)
  saveRDS(model_data, paste0("xgb_model_defensome_", t, ".rds"))
}

# Assemble the performance table from the per-type rows; columns stay numeric.
df_performance <- bind_rows(perf_rows)

# ---- Combine all ROC curves into one PDF (3 x 3 grid) -------------------------
# Each ROC plot was bound to a variable named after its type inside the loop;
# they are fetched by name in this panel order. plot_grid() is called as
# cowplot::plot_grid because no library() call visible in this script
# attaches cowplot.
roc_panel_order <- c("Cas", "CBASS", "Gao_Qat", "R-M", "PD-T4-5", "PD-T7-5",
                     "Gabija", "RosmerTA", "Ssp")
pdf("ROC_curves.pdf", height = 12, width = 18)
# An explicit print() is required because top-level plots are not
# auto-printed when the script is run with source(). `finally` closes the PDF
# device even if plotting fails.
tryCatch(
  print(cowplot::plot_grid(plotlist = mget(roc_panel_order),
                           ncol = 3, nrow = 3)),
  finally = dev.off()
)

# Export the per-type performance table as the supplementary spreadsheet.
library("writexl")
write_xlsx(x = df_performance, path = "Suppl. Table.xlsx")

