
# Compare caret regression models for the species response using repeated cross-validation


# Working directory for the analysis. The placeholder project_dir / data_dir
# values set further down are relative paths, so as written they resolve
# against this directory.
# NOTE(review): `local` shares its name with base::local(); calls to local()
# still find the base function, but a more specific name would be clearer.
local <- file.path("~", "Dropbox", "Work", "Dan_Park", "SDM")
setwd(local)

# --------------------------------------------------------------------------------------------------------------------------------
# functions

# Full outer join of every data frame in `dflist` on the key column(s) `by.var`,
# merging pairwise from left to right; rows without a match get NA.
# Returns NULL for an empty list.
merge_all <- function(dflist, by.var) {
    outer_join <- function(lhs, rhs) merge(lhs, rhs, by = by.var, all = TRUE)
    Reduce(outer_join, dflist)
}

# Negation of %in%: TRUE for each element of `x` that does not occur in `table`.
`%!in%` <- function(x, table) !(x %in% table)

# --------------------------------------------------------------------------------------------------------------------------------
# options

# stringsAsFactors = FALSE is already the default from R 4.0.0 onwards; setting
# it explicitly keeps character columns as character on older R versions too.
options(stringsAsFactors = FALSE)

# --------------------------------------------------------------------------------------------------------------------------------
# packages

# install.packages("caret", dependencies = c("Imports", "Depends", "Suggests"))
packages <- c("caret", "e1071", "subselect", "ipred", "parallel", "doParallel", "corrplot", "rowr", "ggplot2")

# require() only warns and returns FALSE when a package cannot be loaded, so the
# script would otherwise carry on and fail later with an unrelated-looking error.
# Check the results explicitly instead.
loaded <- vapply(packages, require, logical(1), character.only = TRUE)
print(loaded)

# caret, parallel and doParallel supply functions this script calls directly
# (train, trainControl, modelLookup, makeCluster, detectCores,
# registerDoParallel), so a missing one is fatal; the others only warn.
core_packages <- c("caret", "parallel", "doParallel")
missing_core <- core_packages[!loaded[core_packages]]
if (length(missing_core) > 0) {
    stop("required package(s) not available: ",
        paste(missing_core, collapse = ", "), call. = FALSE)
}
if (!all(loaded)) {
    warning("optional package(s) not available: ",
        paste(names(loaded)[!loaded], collapse = ", "), call. = FALSE)
}


# --------------------------------------------------------------------------------------------------------------------------------
# load data

# Input and output locations (placeholders - replace with real paths).
data_dir <- "path-to-input-directory"
project_dir <- "path-to-output-directory"

# Restore the pre-processed objects stored in data_PP.Rdata into the calling
# environment.
# NOTE(review): the code below reads cv_data_PP (and presumably responseColumns)
# from this file - confirm.
load(file = file.path(data_dir, "data_PP.Rdata"))


# --------------------------------------------------------------------------------------------------------------------------------
# tuning parameters

# https://machinelearningmastery.com/tune-machine-learning-algorithms-in-r/

modelMethods <- c(
    "cubist", "rf", "xgbLinear", "rqnc", "gamSpline", "penalized", "BstLm",
    "simpls", "widekernelpls", "glmnet", "gaussprPoly", "pcr", "lm"
)
# modelLookup() reports the tunable parameters of each caret method; keep the
# results keyed by method name and print them for reference.
tuningParams <- setNames(lapply(modelMethods, modelLookup), modelMethods)
tuningParams

# --------------------------------------------------------------------------------------------------------------------------------
# Resampling scheme shared by every model: 10-fold cross-validation repeated
# 10 times; allowParallel lets caret use a registered foreach backend.
control <- trainControl(
    method = "repeatedcv",
    number = 10,
    repeats = 10,
    allowParallel = TRUE
)
metric <- "RMSE"  # performance metric passed to train()

# --------------------------------------------------------------------------------------------------------------------------------
# data for species response
# Working data for the species model: the species_PP dataset without the
# columns fips, region, totalOntree (and its .n/.i/.e/.int variants), and
# without every column listed in responseColumns except species itself.
datSpecies <- cv_data_PP$species_PP$dataset_PP[
    ,
    !(colnames(cv_data_PP$species_PP$dataset_PP) %in% c(
        "fips", "region", "totalOntree",
        paste0("totalOntree.", c("n", "i", "e", "int")),
        responseColumns[!(responseColumns %in% "species")]
    ))
]

# --------------------------------------------------------------------------------------------------------------------------------
# Optional, currently disabled: convert the categorical variable `state` to numeric dummy variables
# dmy <- dummyVars(" ~ state", data = datSpecies, fullRank = TRUE)
# datSpecies2 <- data.frame(datSpecies[, !(colnames(datSpecies) %in% "state")], data.frame(predict(dmy, newdata = datSpecies)))
   
# --------------------------------------------------------------------------------------------------------------------------------
# parallel processing
# Fit one caret model for the `species` response with the settings shared by
# every candidate: formula species ~ ., the global `metric` and `control`, and
# na.action = "na.exclude".
#   method: caret model name (e.g. one of modelMethods)
#   ...:    further arguments forwarded to train(), e.g. tuneGrid
#   data:   training data frame (defaults to datSpecies)
#   seed:   passed to set.seed() immediately before train()
# Returns the fitted caret `train` object.
# Re-seeding before every fit reproduces the original per-model set.seed(7)
# calls. caret generates the resampling folds from the RNG, so every model is
# evaluated on the same folds.
# NOTE(review): randomness inside model fitting on the parallel workers (e.g.
# rf) is not fixed by set.seed() here; trainControl(seeds = ...) controls that.
# Confirm whether exact reproducibility is needed.
fit_species_model <- function(method, ..., data = datSpecies, seed = 7) {
    set.seed(seed)
    train(species ~ ., data = data, method = method, metric = metric,
        trControl = control, na.action = "na.exclude", ...)
}

# Leave one core for the OS (the usual convention), but always start at least
# one worker. detectCores() returns NA when the core count is unknown, and on a
# single-core machine detectCores() - 1 is 0, which makeCluster() rejects.
n_workers <- max(1L, detectCores() - 1L, na.rm = TRUE)
cluster <- makeCluster(n_workers)
registerDoParallel(cluster)

# --------------------------------------------------------------------------------------------------------------------------------
# Fit every candidate model. The `finally` clause shuts the cluster down and
# restores sequential foreach even if one of the fits throws an error, so
# worker processes are not left running.
invisible(tryCatch({
    cat("starting model fitting", "\n")

    cv_cubist_species <- fit_species_model("cubist",
        tuneGrid = data.frame(committees = 100, neighbors = 9))

    # mtry = floor(sqrt(number of predictor columns)). Counting column names
    # avoids ncol() on a one-column `[` subset, which drops to a vector whose
    # ncol() is NULL.
    # NOTE(review): the formula interface expands factors into dummy columns,
    # so the model matrix may have more predictors than this count; also
    # randomForest's own regression default is floor(p / 3). Confirm sqrt(p)
    # is intended.
    n_predictors <- sum(!(colnames(datSpecies) %in% "species"))
    mtry <- floor(sqrt(n_predictors))
    tunegrid <- expand.grid(.mtry = mtry) # grid or random search?
    cv_rf_species <- fit_species_model("rf", tuneGrid = tunegrid)

    cv_xgbLinear_species <- fit_species_model("xgbLinear")
    cv_rqnc_species <- fit_species_model("rqnc")
    cv_gamSpline_species <- fit_species_model("gamSpline")
    cv_penalized_species <- fit_species_model("penalized",
        tuneGrid = data.frame(lambda1 = 0, lambda2 = 1))
    cv_BstLm_species <- fit_species_model("BstLm")
    cv_simpls_species <- fit_species_model("simpls")
    cv_widekernelpls_species <- fit_species_model("widekernelpls")
    cv_glmnet_species <- fit_species_model("glmnet")
    cv_gaussprPoly_species <- fit_species_model("gaussprPoly")
    cv_pcr_species <- fit_species_model("pcr")
    cv_lm_species <- fit_species_model("lm")

    cat("finished model fitting", "\n")
}, finally = {
    # de-register parallel processing cluster
    stopCluster(cluster)
    registerDoSEQ()
}))


# --------------------------------------------------------------------------------------------------------------------------------
# save output

# Collect every fitted model under its caret method name and write the list to
# a single gzip-compressed .Rdata file in project_dir.
cv_models_comp <- list(
    cubist = cv_cubist_species,
    rf = cv_rf_species,
    xgbLinear = cv_xgbLinear_species,
    rqnc = cv_rqnc_species,
    gamSpline = cv_gamSpline_species,
    penalized = cv_penalized_species,
    BstLm = cv_BstLm_species,
    simpls = cv_simpls_species,
    widekernelpls = cv_widekernelpls_species,
    glmnet = cv_glmnet_species,
    gaussprPoly = cv_gaussprPoly_species,
    pcr = cv_pcr_species,
    lm = cv_lm_species
)

save(list = "cv_models_comp",
    file = file.path(project_dir, "cv_models_comp_species.Rdata"),
    compress = "gzip")

    

