# Install each required package only if it is not already available.
# The packages are checked (and, if missing, installed) in the listed order;
# invisible() keeps the list returned by lapply() from being auto-printed.
invisible(lapply(
  c("caret", "brms", "data.table", "hydroGOF"),
  function(pkg) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      install.packages(pkg)
    }
  }
))

library(caret)
library(brms)

#### data ####
# Read the catchment table as a plain data.frame (not a data.table), with
# character columns converted to factors.
catchments <- data.table::fread(
  "catchments.csv",
  data.table = FALSE,
  stringsAsFactors = TRUE,
  encoding = "UTF-8"
)

#### cross-validation ####
k <- 10 # number of folds
set.seed(1234567) # makes the random fold assignment below reproducible

#### folds leaving out gauges stratified by cluster ####
# One entry per Cluster level, holding the unique gauge ids in that cluster.
# as.integer() yields the level codes when Gauge was read as a factor.
cluster <- tapply(
  X = as.integer(catchments$Gauge),
  INDEX = catchments$Cluster,
  FUN = unique
)
# Split the gauges of a single cluster into k cross-validation folds.
#
# cluster: vector of gauge ids belonging to one cluster.
# k:       number of folds.
# Returns a list of length k. Element i holds the gauges of this cluster that
# are used for TRAINING in fold i, so each gauge is held out in at most one fold.
createClusterFolds <- function(cluster, k) {
  # A cluster without gauges contributes nothing to any fold.
  # createFolds() would fail on empty input, so handle it up front.
  if (length(cluster) == 0L) {
    return(rep(list(cluster), k))
  }
  folds <- createFolds(cluster, k, returnTrain = TRUE)
  trainGauges <- lapply(folds, function(idx) cluster[idx])
  # When k > length(cluster), createFolds() returns only length(cluster) folds
  # (one gauge left out per fold). Pad with folds that keep every gauge in
  # training so that each cluster yields exactly k training sets. Otherwise
  # the later Map() over all clusters recycles the short list and holds the
  # same gauges out again in several folds. The padding uses no randomness,
  # so the RNG stream (and the folds of other clusters) is unchanged.
  nPad <- k - length(trainGauges)
  if (nPad > 0L) {
    trainGauges <- c(trainGauges, rep(list(cluster), nPad))
  }
  trainGauges
}
# Per cluster: a list of fold-wise training gauge sets.
clusterFolds <- lapply(cluster, createClusterFolds, k = k)
# Transpose: element i concatenates the fold-i training gauges of all clusters.
clusterFolds <- do.call(Map, c(list(c), clusterFolds))
# Translate training gauge ids into training row indices of `catchments`.
folds <- local({
  gaugeOfRow <- as.integer(catchments$Gauge)
  lapply(clusterFolds, function(trainGauges) {
    which(gaugeOfRow %in% trainGauges)
  })
})

#### sampler settings ####
# These are identical for every fold, so they are set once outside the loop.
# Adjust `cores` to the number of cores available on your machine.
nSamples     <- 6000 # iterations per chain (brm `iter`, includes warmup)
warmup       <- 3000 # number of warmup iterations
adaptDelta   <- 0.99 # parameter to control sampler's behaviour (adapt_delta)
maxTreedepth <- 15   # parameter to control sampler's behaviour (max_treedepth)
cores        <- 4    # number of cores

#### centering and scaling helper ####
# Returns `df` with its first five columns unchanged and all remaining
# columns transformed by the fitted caret preProcess object `pp`.
applyCenSca <- function(df, pp) {
  cbind(df[, 1:5], predict(pp, df[, -(1:5)]))
}

#### vector for predictions ####
# Receives one out-of-fold prediction per row of `catchments`.
SSY_fitted <- numeric(nrow(catchments))

for (fold in folds) {
  # `fold` holds the TRAINING row indices; every other row is a test row.
  testRows <- setdiff(seq_len(nrow(catchments)), fold)
  if (length(testRows) == 0L) {
    next # no held-out rows in this fold, so there is nothing to predict
  }

  #### centering and scaling of training data ####
  catchments_train <- catchments[fold, ]
  # Statistics come from the training rows only and are reused for the test
  # rows below, so test data does not influence the scaling.
  cenSca <- preProcess(
    catchments_train[, -(1:5)],
    method = c("center", "scale")
  )
  catchments_train_cenSca <- applyCenSca(catchments_train, cenSca)

  #### centering and scaling of test data ####
  catchments_test <- catchments[testRows, ]
  catchments_test_cenSca <- applyCenSca(catchments_test, cenSca)

  #### BaHSYM ####
  model <- brm(
    bf(
      # No intercept (`- 1`); the same four predictors also vary by Cluster
      # (the `| Cluster` term, Equation 3d). Using fixed effects only would
      # additionally require removing the prior on "sd" below.
      log_SSY ~ E + q + Q_p95 + Ret_Coeff - 1 +
        (E + q + Q_p95 + Ret_Coeff - 1 | Cluster)
    ),
    data = catchments_train_cenSca,
    prior = c(
      set_prior("normal(0, 0.5)", class = "b"),
      set_prior("exponential(1)", class = "sd"),
      set_prior("exponential(1)", class = "sigma")
    ),
    iter = nSamples,
    warmup = warmup,
    cores = cores,
    control = list(
      adapt_delta = adaptDelta,
      max_treedepth = maxTreedepth
    ),
    seed = 1234567
  )

  #### predictions of test data ####
  # With summary = FALSE this is a matrix of posterior predictive draws
  # (rows) by test observations (columns), on the centered/scaled log scale.
  # NOTE(review): brms >= 2.16 renamed `nsamples` to `ndraws`; confirm the
  # installed version still accepts the old name.
  log_SSY_pred_cenSca <- predict(
    model,
    newdata = catchments_test_cenSca,
    nsamples = nSamples,
    summary = FALSE
  )

  #### undo centering and scaling ####
  # log_SSY is among the columns scaled by `cenSca`, so its training mean and
  # standard deviation map the predictions back to the log scale.
  log_SSY_pred <- log_SSY_pred_cenSca * cenSca$std["log_SSY"] +
    cenSca$mean["log_SSY"]

  #### undo logarithm ###
  # Back-transform each draw, then average per test observation.
  SSY_fitted[testRows] <- colMeans(exp(log_SSY_pred))
}

#### evaluation ####
# Out-of-fold predictions (sim) against observed SSY (obs); each value is
# auto-printed at top level.
cor(x = SSY_fitted, y = catchments$SSY)^2 # R2 (squared Pearson correlation)
hydroGOF::NSE(sim = SSY_fitted, obs = catchments$SSY)   # Nash-Sutcliffe efficiency
hydroGOF::mNSE(sim = SSY_fitted, obs = catchments$SSY)  # modified Nash-Sutcliffe efficiency
hydroGOF::rmse(sim = SSY_fitted, obs = catchments$SSY)  # root mean square error
hydroGOF::pbias(sim = SSY_fitted, obs = catchments$SSY) # percent bias
