
# compare models using cross validation

# --------------------------------------------------------------------------------------------------------------------------------
# functions

merge_all <- function(dflist, by.var) {
    # Fold the list left to right, full-outer-joining the running result
    # with the next data frame on `by.var` (all = TRUE keeps unmatched rows
    # from both sides, filling the missing columns with NA).
    outer_join <- function(acc, nxt) merge(acc, nxt, by = by.var, all = TRUE)
    Reduce(outer_join, dflist)
}

# Negated membership test: TRUE for each element of `x` that does not occur
# in `table` (the element-wise complement of %in%).
`%!in%` <- function(x, table) !(x %in% table)

# --------------------------------------------------------------------------------------------------------------------------------
# options

# Keep character columns as character (not factor) when data frames are built.
# NOTE(review): this is already the default on R >= 4.0, so it only changes
# behaviour on older R versions.
options(stringsAsFactors = FALSE)

# --------------------------------------------------------------------------------------------------------------------------------
# packages

# install.packages("caret", dependencies = c("Imports", "Depends", "Suggests"))
# NOTE(review): several of these (e.g. subselect, corrplot, rowr, ggplot2) are
# not referenced directly in this script, and rowr may no longer be available
# on CRAN -- confirm whether they are still needed.
packages <- c("caret", "e1071", "subselect", "ipred", "parallel", "doParallel", "corrplot",  "rowr", "ggplot2")
# require() merely returns FALSE for a package that is not installed, which
# would let the script continue and fail much later (e.g. at makeCluster() or
# train()). Attach everything, then stop immediately, naming every package
# that could not be loaded.
pkg_loaded <- vapply(packages, require, logical(1), character.only = TRUE)
if (!all(pkg_loaded)) {
    stop("required package(s) not available: ",
        paste(packages[!pkg_loaded], collapse = ", "), call. = FALSE)
}


# --------------------------------------------------------------------------------------------------------------------------------
# load data

# Input / output locations (placeholders -- set these before running).
data_dir <- "path-to-input-directory"
project_dir <- "path-to-output-directory"

# Restore the pre-processed objects from data_PP.Rdata into the current
# environment; cv_data_PP (and presumably responseColumns) come from here.
load(file = file.path(data_dir, "data_PP.Rdata"))


# --------------------------------------------------------------------------------------------------------------------------------
# tuning parameters

# https://machinelearningmastery.com/tune-machine-learning-algorithms-in-r/

modelMethods <- c(
    "cubist", "rf", "xgbLinear", "rqnc", "gamSpline", "penalized", "BstLm",
    "simpls", "widekernelpls", "glmnet", "gaussprPoly", "pcr", "lm"
)
# Look up each caret method's tunable parameters, keyed by method name, and
# show the result.
tuningParams <- setNames(lapply(modelMethods, modelLookup), modelMethods)
tuningParams

# --------------------------------------------------------------------------------------------------------------------------------
# Model-selection / comparison metric.
metric <- "RMSE"
# 10-fold cross-validation, repeated 10 times, allowed to run on the
# registered parallel backend.
control <- trainControl(
    method = "repeatedcv",
    number = 10,
    repeats = 10,
    allowParallel = TRUE
)

# --------------------------------------------------------------------------------------------------------------------------------
# data for MPD.s.ALL response
dat_MPD.s.ALL <- local({
    pp_data <- cv_data_PP$MPD.s.ALL_PP$dataset_PP
    # Drop the fips and region columns, the totalOntree.{n,i,e,int} columns
    # and every response column other than MPD.s.ALL.
    excluded <- c(
        "fips", "region",
        paste("totalOntree", c("n", "i", "e", "int"), sep = "."),
        responseColumns[responseColumns %!in% "MPD.s.ALL"]
    )
    pp_data[, !(colnames(pp_data) %in% excluded)]
})

# --------------------------------------------------------------------------------------------------------------------------------
# parallel processing
# Use all cores but one (leaving one for the OS). detectCores() returns NA
# when the count is unknown and 1 on a single-core machine; both would make
# makeCluster() fail, so always request at least one worker.
cluster <- makeCluster(max(1L, detectCores() - 1L, na.rm = TRUE))
registerDoParallel(cluster)

# --------------------------------------------------------------------------------------------------------------------------------

cat("starting model fitting", "\n")

# All models share the same formula, data, metric, resampling control and
# NA handling; only the caret method (and, for some methods, a fixed tuning
# grid) differs. fit_cv_model() holds the shared arguments so the per-model
# calls cannot drift apart.
#
# The formula is created here at top level so its environment (stored in each
# fitted model's `terms`) is the script environment, exactly as with an
# inline `MPD.s.ALL ~ .`. A formula written inside fit_cv_model() would
# capture that function's frame, which save() would then serialise.
cv_formula <- MPD.s.ALL ~ .

# Fit one caret model with repeated-CV resampling.
#   method: caret model code (see modelLookup()).
#   ...:    further arguments forwarded to train(), e.g. tuneGrid.
# The seed is reset immediately before every fit so caret draws the same
# resampling folds for every model, keeping the resamples comparable.
fit_cv_model <- function(method, ...) {
    set.seed(7)
    train(cv_formula, data = dat_MPD.s.ALL, method = method, metric = metric,
        trControl = control, na.action = "na.exclude", ...)
}

cv_cubist_MPD.s.ALL <- fit_cv_model("cubist",
    tuneGrid = data.frame(committees = 100, neighbors = 9))

# Random forest with a single mtry of floor(sqrt(p)), where p is the number of
# predictor columns. The columns are counted by name because ncol() on a column
# subset drops to a vector (returning NULL) when only one predictor remains.
# NOTE(review): randomForest's own default for regression is floor(p / 3);
# sqrt(p) is its classification default -- confirm sqrt is intended. Also,
# with the formula interface, factor predictors are expanded to dummy columns,
# so the model matrix may have more than p columns.
mtry <- floor(sqrt(sum(colnames(dat_MPD.s.ALL) != "MPD.s.ALL")))
tunegrid <- expand.grid(mtry = mtry) # grid or random search?
cv_rf_MPD.s.ALL <- fit_cv_model("rf", tuneGrid = tunegrid)

# Methods without a tuneGrid use caret's default tuning grid.
cv_xgbLinear_MPD.s.ALL <- fit_cv_model("xgbLinear")
cv_rqnc_MPD.s.ALL <- fit_cv_model("rqnc")
cv_gamSpline_MPD.s.ALL <- fit_cv_model("gamSpline")
cv_penalized_MPD.s.ALL <- fit_cv_model("penalized",
    tuneGrid = data.frame(lambda1 = 0, lambda2 = 1))
cv_BstLm_MPD.s.ALL <- fit_cv_model("BstLm")
cv_simpls_MPD.s.ALL <- fit_cv_model("simpls")
cv_widekernelpls_MPD.s.ALL <- fit_cv_model("widekernelpls")
cv_glmnet_MPD.s.ALL <- fit_cv_model("glmnet")
cv_gaussprPoly_MPD.s.ALL <- fit_cv_model("gaussprPoly")
cv_pcr_MPD.s.ALL <- fit_cv_model("pcr")
cv_lm_MPD.s.ALL <- fit_cv_model("lm")

cat("finished model fitting", "\n")


# Shut down the worker processes, then register the sequential foreach
# backend so any later foreach-based code (e.g. caret with allowParallel)
# does not try to use the stopped cluster.
stopCluster(cl = cluster)
registerDoSEQ()


# --------------------------------------------------------------------------------------------------------------------------------
# save output

# Gather the fitted models into one list keyed by caret method name and save
# it. The list is derived from `modelMethods` (the same vector used for the
# tuning-parameter lookup), giving the same names in the same order, so it
# cannot silently diverge from the set of methods being compared. mget()
# stops with an error if any cv_<method>_MPD.s.ALL object does not exist,
# rather than the method quietly being left out of the saved comparison.
cv_models_comp <- setNames(
    mget(paste0("cv_", modelMethods, "_MPD.s.ALL"), envir = environment()),
    modelMethods
)

save(cv_models_comp, file = file.path(project_dir, "cv_models_comp_MPD.s.ALL.Rdata"), compress = "gzip")

