# Install every package this script needs that is not yet available.
# Packages are only installed here; caret and brms are attached with library()
# afterwards, the others are called through `pkg::fun()`.
invisible(lapply(
  c("caret", "brms", "data.table", "car", "hydroGOF"),
  function(pkg) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      install.packages(pkg)
    }
  }
))

library(caret)
library(brms)

#### data ####
# Read the catchment table as a plain data.frame (data.table = FALSE).
# Character columns become factors; presumably this is what turns grouping
# columns such as Gauge/River into factors for the random effects — TODO confirm.
catchments <- data.table::fread(
  input = "catchments.csv",
  data.table = FALSE,
  encoding = "UTF-8",
  stringsAsFactors = TRUE
)

#### centering and scaling ####
# The first five columns are carried over untouched (presumably identifiers and
# grouping factors — TODO confirm). All remaining columns are centred and
# scaled with caret::preProcess. `cenSca` keeps the per-column means and
# standard deviations so predictions can be mapped back to the original scale.
cenSca <- preProcess(
  x = catchments[, -seq_len(5)],
  method = c("center", "scale")
)
catchments_cenSca <- cbind(
  catchments[, seq_len(5)],
  predict(cenSca, newdata = catchments[, -seq_len(5)])
)

#### parameters ####
# Sampler settings passed to brm(). Set `cores` to the number of CPU cores
# available on your machine.
nSamples <- 6000     # total iterations per chain, warmup included (brm `iter`)
warmup <- 3000       # warmup iterations per chain, discarded afterwards
adaptDelta <- 0.99   # Stan `adapt_delta`: target acceptance rate of the sampler
maxTreedepth <- 15   # Stan `max_treedepth`: cap on the NUTS tree depth
cores <- 4           # CPU cores used to run the chains in parallel

#### BaHSYM ####
# Bayesian hierarchical model for the centred-and-scaled log_SSY. The four
# predictors enter as population-level slopes without an intercept ("- 1").
# The same slopes also vary by the grouping factor(s) of the selected equation
# (group-level effects, again without intercept).
#
# Random-effect structure of each model equation:
#   "3a": (slopes | Gauge)
#   "3b": (slopes | Gauge) + (slopes | River)
#   "3c": (slopes | Gauge) + (slopes | Basin)
#   "3d": (slopes | Cluster)
bahsym_equation <- "3b"
bahsym_slopes <- "E + q + Q_p95 + Ret_Coeff - 1"
bahsym_random <- switch(
  bahsym_equation,
  "3a" = sprintf("(%s | Gauge)", bahsym_slopes),
  "3b" = sprintf("(%s | Gauge) + (%s | River)", bahsym_slopes, bahsym_slopes),
  "3c" = sprintf("(%s | Gauge) + (%s | Basin)", bahsym_slopes, bahsym_slopes),
  "3d" = sprintf("(%s | Cluster)", bahsym_slopes),
  stop("Unknown BaHSYM equation: ", bahsym_equation, call. = FALSE)
)
bahsym_formula <- stats::as.formula(
  paste("log_SSY ~", bahsym_slopes, "+", bahsym_random)
)

best_fit_model <- brm(
  bf(bahsym_formula),
  data = catchments_cenSca,
  iter = nSamples,
  warmup = warmup,
  cores = cores,
  control = list(
    adapt_delta = adaptDelta,
    max_treedepth = maxTreedepth
  ),
  seed = 1234567 # fixed seed so the sampling is reproducible
)

# Explicit print() so the summary (priors included, 99% intervals) is also
# shown when the script is run with source(), which does not auto-print.
print(summary(best_fit_model, priors = TRUE, prob = 0.99))

#### predictions ####
# Posterior-predictive draws of the centred-and-scaled log_SSY for the data the
# model was fitted on. summary = FALSE returns the raw draws as a
# draws x observations matrix.
# `ndraws` (brms >= 2.16; replaces the deprecated `nsamples`) asks for nSamples
# draws. brms picks these at random from all post-warmup draws
# (chains x (iter - warmup)). Both that subset and the simulated observation
# noise use R's RNG, so a seed is set here to make the predictions
# reproducible between runs.
set.seed(1234567)
log_SSY_pred_cenSca <- predict(
  best_fit_model,
  ndraws = nSamples,
  summary = FALSE
)

#### undo centering and scaling ####
# Map every draw back to the original log_SSY scale (x = z * sd + mean), using
# the statistics caret::preProcess stored for the log_SSY column. If log_SSY
# was not among the scaled columns, these lookups yield NA.
log_SSY_pred <- local({
  sd_log_SSY <- cenSca$std["log_SSY"]
  mean_log_SSY <- cenSca$mean["log_SSY"]
  log_SSY_pred_cenSca * sd_log_SSY + mean_log_SSY
})

#### residuals ####
# Fitted value per observation = mean over the posterior-predictive draws
# (one column per observation). Assumes no rows of `catchments` were dropped
# by brm() for missing values, so the columns line up with catchments$log_SSY
# — TODO confirm.
log_SSY_fitted <- colMeans(log_SSY_pred)

# Residuals use the standard convention: observed - fitted.
log_SSY_residuals <- catchments$log_SSY - log_SSY_fitted

# Residuals vs fitted values, with equal axis scaling (`asp = 1`, spelled out
# rather than relying on partial argument matching) and a zero reference line.
plot(
  log_SSY_fitted,
  log_SSY_residuals,
  asp = 1,
  xlab = "fitted log_SSY",
  ylab = "residual (observed - fitted)"
)
abline(h = 0, lty = 2)

# Normal Q-Q plot of the residuals.
car::qqPlot(log_SSY_residuals)

#### undo logarithm ####
# Back-transform each posterior-predictive draw with exp() before averaging.
# SSY_fitted is therefore the per-observation mean on the original SSY scale.
# Assumes log_SSY is the natural logarithm of SSY — TODO confirm against the
# data preparation.
SSY_fitted <- colMeans(x = exp(log_SSY_pred))

#### evaluation ####
# Goodness of fit of the back-transformed fitted SSY against the observed SSY.
# The hydroGOF functions take the simulated series as `sim` and the observed
# series as `obs`; they are passed by name to make the order explicit.
# Collecting the metrics in one named vector and calling print() makes them
# visible also when the script is run with source(), which does not
# auto-print top-level values.
evaluation <- c(
  R2 = cor(SSY_fitted, catchments$SSY)^2,
  NSE = hydroGOF::NSE(sim = SSY_fitted, obs = catchments$SSY),
  mNSE = hydroGOF::mNSE(sim = SSY_fitted, obs = catchments$SSY),
  RMSE = hydroGOF::rmse(sim = SSY_fitted, obs = catchments$SSY),
  PBIAS = hydroGOF::pbias(sim = SSY_fitted, obs = catchments$SSY)
)
print(evaluation)
