# Make sure every package this script relies on is installed. Packages that
# are already available are left untouched; nothing is attached here.
invisible(lapply(
  c("caret", "brms", "data.table", "car", "hydroGOF"),
  function(pkg) {
    if (!requireNamespace(pkg, quietly = TRUE)) {
      install.packages(pkg)
    }
  }
))

library(caret)
library(brms)
# ggplot2 functions (ggplot, aes, ggsave, ...) are called unqualified by the
# plotting code, so attach the package explicitly instead of relying on
# another package to attach it.
library(ggplot2)

#### data ####
# Read the catchment table as a plain data.frame (not a data.table).
# Character columns become factors and strings are marked as UTF-8.
catchments <- data.table::fread(
  "catchments.csv",
  data.table = FALSE,
  stringsAsFactors = TRUE,
  encoding = "UTF-8"
)

#### centering and scaling ####
# The first five columns are passed through unchanged; all remaining columns
# are handed to caret's center/scale pre-processing. `cenSca` keeps the fitted
# transformation so that it can be undone after prediction.
passthrough_cols <- catchments[, 1:5]
model_cols <- catchments[, -(1:5)]
cenSca <- preProcess(model_cols, method = c("center", "scale"))
catchments_cenSca <- cbind(passthrough_cols, predict(cenSca, model_cols))

#### parameters ####
# Settings handed to brm() when the model is fitted.

# Parallelism: please adjust according to available cores on your machine.
cores <- 4

# Chain length.
nSamples <- 6000 # iterations per chain; also reused as the number of
                 # posterior draws requested from predict()
warmup <- 3000 # warmup iterations per chain

# Sampler control settings (passed through brm()'s `control` list).
adaptDelta <- 0.99
maxTreedepth <- 15

#### BaHSYM ####
# Model formula: population-level slopes for E, q, Q_p95 and Ret_Coeff without
# an intercept, plus group-level slopes for the same predictors. The
# alternative grouping structures (Equations 3a-3d) are kept as comments;
# the uncommented line is the one that is fitted.
bahsym_formula <- bf(
  log_SSY ~
    E + q + Q_p95 + Ret_Coeff - 1
    # + (E + q + Q_p95 + Ret_Coeff - 1 | Gauge  )                                           # Equation 3a
    + (E + q + Q_p95 + Ret_Coeff - 1 | Gauge  ) + (E + q + Q_p95 + Ret_Coeff - 1 | River)   # Equation 3b
    # + (E + q + Q_p95 + Ret_Coeff - 1 | Gauge  ) + (E + q + Q_p95 + Ret_Coeff - 1 | Basin) # Equation 3c
    # + (E + q + Q_p95 + Ret_Coeff - 1 | Cluster)                                           # Equation 3d
)

# Priors for the population-level coefficients ("b"), the group-level standard
# deviations ("sd") and the residual standard deviation ("sigma").
bahsym_priors <- c(
  set_prior("normal(0, 0.5)", class = "b"),
  set_prior("exponential(1)", class = "sd"),
  set_prior("exponential(1)", class = "sigma")
)

# Sampler control settings taken from the parameters section.
bahsym_control <- list(
  adapt_delta = adaptDelta,
  max_treedepth = maxTreedepth
)

# Fit on the centered/scaled data; the fixed seed is passed on to the sampler.
best_fit_model <- brm(
  formula = bahsym_formula,
  data = catchments_cenSca,
  prior = bahsym_priors,
  iter = nSamples,
  warmup = warmup,
  cores = cores,
  control = bahsym_control,
  seed = 1234567
)

# Model summary including the priors, with 99% intervals.
summary(best_fit_model, priors = TRUE, prob = 0.99)

#### predictions ####
# Posterior predictive draws for the fitted data, still on the centered/scaled
# response scale. With summary = FALSE the individual draws are returned
# instead of per-observation summaries.
# NOTE(review): recent brms releases deprecate `nsamples` in favour of
# `ndraws` -- confirm against the installed version before renaming.
# NOTE(review): no R seed is set before predict(), so the draws (and all
# numbers derived from them) presumably vary between runs -- confirm.
log_SSY_pred_cenSca <- predict(
  best_fit_model,
  summary = FALSE,
  nsamples = nSamples
)

#### undo centering and scaling ####
# Reverse the transformation using the standard deviation and mean that
# preProcess() stored for log_SSY.
log_SSY_sd <- cenSca$std["log_SSY"]
log_SSY_mean <- cenSca$mean["log_SSY"]
log_SSY_pred <- log_SSY_pred_cenSca * log_SSY_sd + log_SSY_mean

#### residuals ####
# Fitted value per observation = column mean over the posterior predictive
# draws (back-scaled, still on the log scale).
log_SSY_fitted <- colMeans(log_SSY_pred)
# NOTE(review): residuals are computed as fitted minus observed, i.e. with the
# opposite sign of the usual observed-minus-fitted convention. Kept as is.
log_SSY_residuals <- log_SSY_fitted - catchments$log_SSY

# Residuals-vs-fitted plot written to a 16 x 9 cm PNG at 600 dpi.
png(
  "P://STOBIMO-Spurenstoffe//Publikationen//Article_Sediment_yield//Environ_Model_Softw/Revision/BaHSYM_Manuscript_Rev/Figures/residuals_plot.png",
  width = 16,
  height = 9,
  units = "cm",
  res = 600
)
plot(
  log_SSY_fitted,
  log_SSY_residuals,
  # The original passed `as = 1`, which only worked because R partially
  # matched it to `asp`. Spelled out here; the rendered plot is unchanged.
  # NOTE(review): if `las = 1` was intended instead, change it deliberately.
  asp = 1,
  ylim = c(-1.5, 1.5),
  type = "n", # set up empty axes first so the zero line is drawn below the points
  xlab = "Fitted logSSY",
  ylab = "Residuals"
)
abline(0, 0, lty = 2, col = "cornflowerblue")
points(log_SSY_fitted, log_SSY_residuals)
dev.off()

# Normal QQ plot of the residuals, drawn on the current/default device.
car::qqPlot(log_SSY_residuals)

#### undo logarithm ###
# Exponentiate every draw first and average afterwards, so the fitted value is
# the mean of the back-transformed draws (not exp() of the mean log value).
# NOTE(review): exp() assumes log_SSY is a natural logarithm -- confirm.
SSY_fitted <- colMeans(exp(log_SSY_pred))

#### evaluation ####
# Goodness of fit of modelled vs. observed SSY on the original scale.
SSY_observed <- catchments$SSY
cor(SSY_fitted, SSY_observed)^2 # R2
hydroGOF::NSE(SSY_fitted, SSY_observed)
hydroGOF::mNSE(SSY_fitted, SSY_observed)
hydroGOF::rmse(SSY_fitted, SSY_observed)
hydroGOF::pbias(SSY_fitted, SSY_observed)


# Plot figure for paper ###
# Observed values used for plotting, combined with the modelled SSY.
# NOTE(review): columns are selected by position (3, 7, 9); the code below
# needs `Gauge` and `SSY` to be among them -- confirm against catchments.csv.
Beob <- catchments[, c(3, 7, 9), drop = FALSE]
plot_file <- cbind(Beob, SSY_fitted)

# Keep two gauges as their own groups and lump every other gauge together.
highlighted_gauges <- c("Kössen-Hütte", "Neumarkt")
other_gauges_label <- "all other Gauges"
plot_file$Gauge <- as.character(plot_file$Gauge)
plot_file$Gauge[!plot_file$Gauge %in% highlighted_gauges] <- other_gauges_label

# Fixed level order: highlighted gauges first, the lumped group last. The
# labels equal the levels, so only `levels` is given (the original passed the
# same vector twice and left a stray trailing comma in the call).
plot_file$Gauge <- factor(
  plot_file$Gauge,
  levels = c(highlighted_gauges, other_gauges_label)
)

# Sort by descending level so the lumped group comes first in the data and
# the highlighted gauges come last.
plot_file <- plot_file[order(plot_file$Gauge, decreasing = TRUE), ]

# plot_model <- ggplot(plot_file, aes(x = SSY, y = SSY_fitted)) +
#   geom_point(size = 2.5) +
#   scale_y_continuous(limits=c(0, 1600)) +
#   scale_x_continuous(limits=c(0, 1600)) +
#   labs(x = expression(paste("Observed annual SSY [" ~ t ~ km^{-2} ~ y^{-1}~"]"))) +
#   theme(axis.title.x = element_text(face="bold", size=25), axis.text.x  = element_text(size=22)) +
#   labs(y = expression(paste("Modelled annual SSY [" ~ t ~ km^{-2} ~ y^{-1}~"]"))) +
#   theme(axis.title.y = element_text(face="bold", size=25), axis.text.y  = element_text(size=22)) +
#   geom_segment(aes(x = 0, xend = 1500, y = 0, yend = 1500), linetype="dashed", size = 1, color = "blue") +
#   coord_fixed(ratio = 1)


#plot_model
#ggsave("P://STOBIMO-Spurenstoffe//Publikationen//Article_Sediment_yield//Environ_Model_Softw/Revision/BaHSYM_Manuscript_Rev/Figures/Plot_Model.pdf", plot = plot_model, width = 16, height = 9, scale = 2, units = "cm", dpi = 600)

# Plot figure in log-log scale ###
# Observed vs. modelled SSY on log10 axes with a dashed 1:1 reference line;
# points are coloured by the gauge grouping stored in plot_file$Gauge.
# NOTE(review): `size` for line width is deprecated in newer ggplot2 releases
# in favour of `linewidth` -- confirm the installed version before renaming.
plot_model <- ggplot(plot_file, aes(x = SSY, y = SSY_fitted)) +
  geom_abline(
    intercept = 0,
    slope = 1,
    linetype = "dashed",
    size = 0.3,
    color = "cornflowerblue"
  ) +
  # x and y are inherited from the ggplot() call; only colour is added here.
  geom_point(aes(color = Gauge), size = 1) +
  labs(
    x = expression(paste(Observed~annual~SSY~"[t"~km^{-2}~y^{-1}, "]")),
    y = expression(paste(Modelled~annual~SSY~"[t"~km^{-2}~y^{-1}, "]"))
  ) +
  scale_x_log10() +
  scale_y_log10() +
  # Colours are assigned in factor-level order of Gauge.
  scale_color_manual(values = c("#ffaa00", "#c300ff", "grey30")) +
  coord_fixed(ratio = 1) +
  theme_minimal() +
  theme(
    axis.title = element_text(face = "bold", size = 10),
    axis.text = element_text(size = 9),
    legend.title = element_text(size = 9),
    legend.text = element_text(size = 9)
  )

#ggsave("P://STOBIMO-Spurenstoffe//Publikationen//Article_Sediment_yield//Environ_Model_Softw/Revision/BaHSYM_Manuscript_Rev/Figures/Plot_Model_loglog.pdf", plot = plot_model, width = 16, height = 9, scale = 2, units = "cm", dpi = 600)

# Bubble plot showing catchment size

# plot_model <- ggplot(plot_file, aes(x = SSY, y = SSY_fitted, size = A)) +
#   geom_point() +
#   labs(x = expression(paste("Observed annual SSY [" ~ t ~ km^{-2} ~ y^{-1}~"]"))) +
#   theme(axis.title.x = element_text(face="bold", size=25), axis.text.x  = element_text(size=22)) +
#   labs(y = expression(paste("Modelled annual SSY [" ~ t ~ km^{-2} ~ y^{-1}~"]"))) +
#   theme(axis.title.y = element_text(face="bold", size=25), axis.text.y  = element_text(size=22)) +
#   geom_segment(aes(x = 0, xend = 1500, y = 0, yend = 1500), linetype="dashed", size = 1, color = "blue") +
#   scale_x_continuous(trans='log10') +
#   scale_y_continuous(trans='log10') +
#   coord_fixed(ratio = 1)

#plot_model

# Save plot_model as a 16 x 9 cm PDF.
ggsave(
  filename = "P://STOBIMO-Spurenstoffe//Publikationen//Article_Sediment_yield//Environ_Model_Softw/Revision/BaHSYM_Manuscript_Rev/Figures/Plot_Model_loglog_col_leg.pdf",
  plot = plot_model,
  width = 16,
  height = 9,
  units = "cm",
  dpi = 600
)

# catchments$A_class <- cut(catchments$A, c(-Inf, 2500, 5000, 7500, Inf))
# for (a in levels(catchments$A_class)) {
#   print(c(
#     hydroGOF::NSE(SSY_fitted[catchments$A_class == a], catchments$SSY[catchments$A_class == a]),
#     hydroGOF::mNSE(SSY_fitted[catchments$A_class == a], catchments$SSY[catchments$A_class == a])
#   ))
# }
