#!/usr/bin/env Rscript

## Load the custom helper functions we've written for this project.
## TODO: the directory below is hard-coded to a home-directory checkout; we still
## have to figure out how to load these when running in rocker.
## Only *.R / *.r files are sourced: a non-recursive list.files() also returns
## subdirectory names and any other files, and source() would fail on those.
files.sources <- list.files('~/work/RCode/NASA-FaCeT/R/',
                            pattern = '\\.[Rr]$',
                            full.names = TRUE)
invisible(lapply(files.sources, source))

library(argparser, quietly=TRUE)
suppressMessages(library(dplyr, quietly=TRUE))
suppressMessages(library(data.table, quietly=TRUE))
suppressMessages(library(dismo, quietly=TRUE))
library(logger, quietly=TRUE)

## Build up our command line argument parser.
## Positional arguments are declared in the order input_csv, config_file,
## output_model; --gbm_step is the only optional argument.
p <- arg_parser("Fit BRT model")
p <- add_argument(p, "input_csv", help="input CSV file of presence-(pseudo)absence data. Minimum req'd columns correspond to response variable (e.g. observation type) and predictor variable(s) in model_config")
p <- add_argument(p, "config_file", help="path to csv file containing model config as csv")
p <- add_argument(p, "--gbm_step", default = FALSE, help="logical, indicating whether to use gbm.step from dismo package to fit the BRT. Default is FALSE and gbm.fixed is used.")
p <- add_argument(p, "output_model", help="path to output model results to as R native .RDS")
#p <- add_argument(p, "output_eval", help="path to output model evaluation results to as csv")

## parse the arguments & log each name/value pair
args <- parse_args(p)
for (i in seq_along(args)){
  logger::log_info(paste0('Args include ', names(args)[i],': ', args[[i]]))
}

## coerce the switch to a logical and fail early with a clear message if the
## supplied value cannot be interpreted (otherwise `if (args$gbm_step)` further
## down stops with "missing value where TRUE/FALSE needed")
args$gbm_step <- as.logical(args$gbm_step)
if (length(args$gbm_step) != 1 || is.na(args$gbm_step)) {
  stop('--gbm_step must be TRUE or FALSE.', call. = FALSE)
}

## set seed to facilitate consistency across runs
set.seed(311)

## read in data
df <- data.table::fread(args$input_csv, sep = ',', header = TRUE)

## read config: a CSV whose `config_variable` / `config_value` columns are
## looked up below; every row is logged as read
config <- data.table::fread(args$config_file, sep = ',', header = TRUE)
for (i in seq_len(nrow(config))){
  logger::log_info(paste0('Config include ', config[i,1],': ', config[i,2]))
}
if (!all(c('config_variable', 'config_value') %in% names(config))) {
  stop('Config file must contain the columns config_variable and config_value.')
}

## gbm.x holds the predictor names as a ';'-separated list; surrounding
## whitespace and empty entries (e.g. from a trailing ';') are dropped
gbm_x <- as.character(config$config_value[which(config$config_variable == 'gbm.x')])
predictors <- trimws(unlist(strsplit(gbm_x, ';', fixed = TRUE)))
predictors <- predictors[nzchar(predictors)]
response <- trimws(as.character(config$config_value[which(config$config_variable == 'gbm.y')]))

if (length(predictors) == 0) stop('No predictors found: config file must define gbm.x.')
if (length(response) != 1) stop('Config file must define exactly one gbm.y (response) entry.')

## convert the data.table returned by fread to a plain data.frame before
## validating, so the checks see the same column names the model will see
df <- data.frame(df)

## validate BEFORE touching the predictor columns, so that a missing column
## produces the messages below rather than a low-level subsetting error
if (any(!(predictors %in% names(df)))) stop('Not all variables specified for gbm.x (predictors) in config file are present in input data.')

if (!(response %in% names(df))) stop('Variable specified for gbm.y (response) in config file is not present in input data.')

## get rid of NA values in predictors (env vars)
df <- df[stats::complete.cases(df[, predictors, drop = FALSE]), , drop = FALSE]
rownames(df) <- NULL

## sort out which config variables we have or dont
##
## Look up one optional setting in the config table.
##   config  - table with config_variable / config_value columns
##   name    - value to match in config$config_variable
##   default - returned when `name` is not present in the config
##   numeric - if TRUE the configured value is converted with as.numeric()
## When a setting appears more than once the first occurrence is used. A
## numeric setting that cannot be parsed stops the script rather than handing
## NA to the model fit.
config_or_default <- function(config, name, default, numeric = TRUE) {
  hits <- which(config$config_variable == name)
  if (length(hits) == 0) {
    return(default)
  }
  value <- config$config_value[hits[1]]
  if (!numeric) {
    return(value)
  }
  ## the NA produced by a failed conversion is reported explicitly just below
  parsed <- suppressWarnings(as.numeric(value))
  if (is.na(parsed)) {
    stop('Config value for ', name, ' is not numeric: ', value, call. = FALSE)
  }
  parsed
}

## each setting falls back to the default shown when absent from the config
family  <- config_or_default(config, 'family', 'bernoulli', numeric = FALSE)
tc      <- config_or_default(config, 'tree.complexity', 1)
lr      <- config_or_default(config, 'learning.rate', 0.01)
bf      <- config_or_default(config, 'bag.fraction', 0.75)
n.trees <- config_or_default(config, 'n.trees', 2000)
k.folds <- config_or_default(config, 'k.folds', 10)

if (args$gbm_step){
  ## step through model fit: gbm.step picks the number of trees itself via
  ## k-fold cross-validation, so the configured n.trees is not passed here.
  ## The configured k.folds sets the number of folds (it was previously read
  ## from the config but never handed to the fit).
  brt <- dismo::gbm.step(data = df, 
                         gbm.x = predictors, 
                         gbm.y = response, ### response variable
                         family = family,
                         tree.complexity = tc, ### complexity of the interactions that the model will fit
                         learning.rate = lr,  ### optimized to end up with >1000 trees
                         bag.fraction = bf, ### recommended by Elith, amount of input data used each time
                         n.folds = k.folds) ### number of cross-validation folds
  
} else{
  ## use fixed # of trees
  brt <- dismo::gbm.fixed(data = df, 
                          gbm.x = predictors, 
                          gbm.y = response, ### response variable
                          family = family,
                          tree.complexity = tc, ### complexity of the interactions that the model will fit
                          learning.rate = lr, 
                          bag.fraction = bf, ### recommended by Elith, amount of input data used each time
                          n.trees = n.trees)
  
}

## gbm.step may give up and return NULL instead of a model; stop here so a
## NULL is never written out as if it were a fitted model
if (is.null(brt)) {
  stop('Model fitting returned no model (NULL); consider a smaller learning.rate.', call. = FALSE)
}

## persist the fitted model as a native R .RDS file, then log the destination
model_path <- args$output_model
saveRDS(object = brt, file = model_path)
log_info(paste0('Model written to ', model_path, '.'))


