#!/usr/bin/env Rscript

# =============================================================================
# Parallel nMDS + bootstrap clustering across latitudinal bins
# =============================================================================
# - English code/comments (Spanish column names preserved)
# - Reads SERNAPESCA landing records, filters municipalities,
#   builds lat_bin × species community matrices (adjusted landings),
#   runs metaMDS (Bray-Curtis), clusters ordination space, and bootstraps support.
#
# Output:
# - mean silhouette vs k plot (optional)
# - average distance dendrogram
# - co-clustering support matrix (pairwise)
# - exported unique lat/lon/lat_bin table for mapping
#
# Notes:
# - This is computationally demanding (we used an HPC server of the DATACENTER-Secos).
#   Start with n_bootstrap ~ 200–2000, then scale up.
# - metaMDS is stochastic: we set seed per iteration for reproducibility.
# =============================================================================

suppressPackageStartupMessages({
  library(vegan)
  library(dplyr)
  library(tidyr)
  library(cluster)
  library(parallel)
  library(stringi)
  library(ggplot2)
  library(readr)
})

# -----------------------------
# Configuration
# -----------------------------
# Working directory: all input/output file names below are relative to it.
WORKDIR <- "~/R-SCRIPTS/FISHERIES_BIOGEOGRAPHY"
setwd(WORKDIR)
# NOTE: the workspace is intentionally NOT cleared with rm(list = ls()).
# Doing so here would delete WORKDIR itself and would wipe the caller's
# objects whenever this script is source()d from an interactive session.

# Input files
IN_DATA   <- "SERNAPESCA_benthic_landings_Records_2015_2022_cleaned_18may24.csv"
IN_COMUNAS <- "comunas_seleccionadas.csv"

# Output files
OUT_LATLON <- "latlon_latbin_unique.csv"
OUT_SUPPORT <- "latbin_coclustering_support.csv"
OUT_MEAN_DIST <- "latbin_mean_distance.csv"

# Bootstrap and clustering parameters
n_bootstrap <- 2000          # number of bootstrap iterations; start small, scale up
k_mds       <- 5             # number of NMDS dimensions passed to metaMDS
k_values    <- 2:10          # cluster counts evaluated with the silhouette
k_cutree    <- 5             # cluster count used to build co-clustering support
# Leave one core free. detectCores() can return NA when the core count is
# unknown; na.rm = TRUE makes that case fall back to a single worker.
n_cores     <- max(1, detectCores() - 1, na.rm = TRUE)

# Latitudinal bin removed from the analysis by the filter step
exclude_lat_bin <- -24

# metaMDS controls (tune if needed)
trymax_val  <- 20
maxit_val   <- 200
autotransform_val <- FALSE   # input values are adjusted landings; no autotransform

# -----------------------------
# Helpers
# -----------------------------
normalize_text <- function(x) {
  # Standardize municipality names so they can be matched across files:
  # coerce to character, apply the "Latin-ASCII" transliteration, then
  # convert to upper case.
  transliterated <- stringi::stri_trans_general(as.character(x), "Latin-ASCII")
  toupper(transliterated)
}

# Build community matrix lat_bin × spp_scname from a sampled dataset
build_community_matrix <- function(df, spp_list) {
  # Build a lat_bin x species matrix of summed adjusted landings.
  #
  # df       : data frame with columns lat_bin, spp_scname and
  #            Desembarques.Ajustados.
  # spp_list : character vector of species names; fixes the set and the
  #            order of the matrix columns.
  #
  # Returns a list with `mat` (the matrix, one column per entry of spp_list)
  # and `lat_bins` (the lat_bin value of each matrix row, in row order).

  # Total landings per (lat_bin, species) cell; NA landings are ignored.
  cell_totals <- summarise(
    group_by(df, lat_bin, spp_scname),
    Desembarques.Ajustados = sum(Desembarques.Ajustados, na.rm = TRUE),
    .groups = "drop"
  )

  # One column per species; combinations absent from df are filled with 0.
  wide <- pivot_wider(
    cell_totals,
    names_from = spp_scname,
    values_from = Desembarques.Ajustados,
    values_fill = list(Desembarques.Ajustados = 0)
  )

  # Keep the row labels aside and drop them from the numeric table.
  bin_labels <- wide$lat_bin
  wide <- select(wide, -lat_bin)

  # Species that did not appear in df get an all-zero column, so every call
  # yields the same columns in the same order.
  for (sp in setdiff(spp_list, names(wide))) {
    wide[[sp]] <- 0
  }

  list(
    mat = as.matrix(wide[, spp_list, drop = FALSE]),
    lat_bins = bin_labels
  )
}

# One bootstrap iteration
bootstrap_iteration <- function(iter_id, datos, spp_list, k_values, seed_base = 1234) {
  # Run one bootstrap replicate: resample records, build the community
  # matrix, ordinate with metaMDS (Bray-Curtis), cluster the ordination
  # scores and evaluate silhouettes.
  #
  # iter_id   : integer iteration index; the RNG seed is seed_base + iter_id.
  # datos     : landing records (rows are resampled with replacement).
  # spp_list  : species names defining the community matrix columns.
  # k_values  : cluster counts for which the mean silhouette is computed.
  # seed_base : base seed for reproducibility.
  #
  # Also reads the script-level settings k_mds, trymax_val, maxit_val,
  # autotransform_val and k_cutree.
  #
  # Returns a list with:
  #   dist        : dist object among lat_bins in NMDS space (NULL on failure)
  #   silhouettes : numeric vector, one mean silhouette per k_values entry
  #                 (NA where undefined or on failure)
  #   cl          : cluster labels for k = k_cutree (NULL if unavailable)
  #   lat_bins    : lat_bin label of each row of the community matrix
  #   error       : NULL on success, otherwise a message describing why the
  #                 iteration produced no ordination
  set.seed(seed_base + iter_id)

  # Shared shape for every "no result" return.
  failed_result <- function(lat_bins, reason) {
    list(
      dist = NULL,
      silhouettes = rep(NA_real_, length(k_values)),
      cl = NULL,
      lat_bins = lat_bins,
      error = reason
    )
  }

  # Resample rows with replacement
  sample_data <- datos[sample.int(nrow(datos), replace = TRUE), , drop = FALSE]

  cm <- build_community_matrix(sample_data, spp_list)
  community_matrix <- cm$mat
  lat_bins <- cm$lat_bins

  # Need at least 2 lat_bins and 1 species column
  if (nrow(community_matrix) < 2 || ncol(community_matrix) < 1) {
    return(failed_result(lat_bins, "community matrix too small for NMDS"))
  }

  # NMDS on lat_bin community matrix. An error is returned as a condition
  # object so its message can be reported instead of being discarded.
  mds <- tryCatch(
    metaMDS(
      community_matrix,
      distance = "bray",
      k = k_mds,
      trymax = trymax_val,
      maxit = maxit_val,
      autotransform = autotransform_val,
      trace = FALSE
    ),
    error = function(e) e
  )

  if (inherits(mds, "error")) {
    return(failed_result(lat_bins, conditionMessage(mds)))
  }
  if (is.null(mds) || is.null(mds$points)) {
    return(failed_result(lat_bins, "metaMDS returned no ordination points"))
  }

  n_sites <- nrow(mds$points)

  # Distance among lat_bins in ordination space
  dist_matrix <- dist(mds$points)

  # Hierarchical clustering on ordination distances
  hc <- hclust(dist_matrix, method = "average")

  # Mean silhouette for each requested k. silhouette() returns a bare NA
  # (not a matrix) unless 2 <= k <= n - 1, so those cases are answered with
  # NA up front instead of failing on the column lookup. vapply() guarantees
  # a numeric vector of length(k_values), even when k_values is empty.
  sil_scores <- vapply(
    k_values,
    function(k) {
      if (k < 2 || k >= n_sites) {
        return(NA_real_)
      }
      sil <- silhouette(cutree(hc, k = k), dist_matrix)
      mean(sil[, "sil_width"])
    },
    numeric(1)
  )

  # Cluster labels for the chosen k (used in co-clustering support)
  cl_fixed <- if (k_cutree < n_sites) cutree(hc, k = k_cutree) else NULL

  list(
    dist = dist_matrix,
    silhouettes = sil_scores,
    cl = cl_fixed,
    lat_bins = lat_bins,
    error = NULL
  )
}

# -----------------------------
# 1) Load data (robust numeric parsing)
# -----------------------------
# The landings file is read as ";"-delimited text with "," as decimal mark.
datos <- readr::read_delim(
  IN_DATA,
  delim = ";",
  locale = locale(encoding = "UTF-8", decimal_mark = ","),
  show_col_types = FALSE,
  progress = FALSE
)

# Force numeric types, derive the latitudinal bin (latitude truncated toward
# zero) and normalize municipality names for matching.
datos <- mutate(
  datos,
  Lat = as.numeric(Lat),
  Lon = as.numeric(Lon),
  Desembarques.Ajustados = as.numeric(Desembarques.Ajustados),
  lat_bin = trunc(Lat),
  Comuna = normalize_text(Comuna)
)

# Selected municipalities, normalized the same way as the landings table
seleccionadas <- readr::read_delim(
  IN_COMUNAS,
  delim = ",",
  locale = locale(encoding = "UTF-8"),
  show_col_types = FALSE
)
seleccionadas <- mutate(seleccionadas, Comuna = normalize_text(Comuna))

# Keep records from selected municipalities, outside the excluded bin, and
# with non-missing landings, bin and species name.
datos <- filter(
  datos,
  Comuna %in% seleccionadas$Comuna,
  lat_bin != exclude_lat_bin,
  !is.na(Desembarques.Ajustados),
  !is.na(lat_bin),
  !is.na(spp_scname)
)

# Reference sets used to align every bootstrap replicate
all_lat_bins <- sort(unique(datos$lat_bin))
spp_list <- sort(unique(datos$spp_scname))

cat("Rows:", nrow(datos), "\n")
cat("Lat bins:", length(all_lat_bins), "\n")
cat("Species:", length(spp_list), "\n")

# Export the distinct Lat/Lon/lat_bin combinations for mapping (optional)
datos_unique <- distinct(datos, Lat, Lon, lat_bin)
write_csv(datos_unique, OUT_LATLON)

# -----------------------------
# 2) Parallel bootstrap
# -----------------------------
cat("Using cores:", n_cores, "\n")
cl <- makeCluster(n_cores)

# Environment holding the objects to ship to the workers; captured here so
# the lookup does not depend on where the export call is evaluated.
export_env <- environment()

# Everything that talks to the workers runs inside tryCatch(finally = ...),
# so the worker processes are shut down even when the export or a bootstrap
# iteration fails.
results <- tryCatch(
  {
    # Data, settings and helper functions needed by bootstrap_iteration()
    clusterExport(
      cl,
      varlist = c(
        "datos", "spp_list", "k_values",
        "bootstrap_iteration",
        "build_community_matrix",
        "k_mds", "trymax_val", "maxit_val",
        "autotransform_val", "k_cutree"
      ),
      envir = export_env
    )

    # Packages used inside the worker-side functions
    clusterEvalQ(cl, {
      library(vegan)
      library(dplyr)
      library(tidyr)
      library(cluster)
    })

    # One bootstrap replicate per index; the index also determines the seed
    parLapply(
      cl,
      X = seq_len(n_bootstrap),
      fun = function(i) bootstrap_iteration(i, datos = datos, spp_list = spp_list, k_values = k_values)
    )
  },
  finally = stopCluster(cl)
)

# -----------------------------
# 3) Summaries: silhouette and mean distance
# -----------------------------
# Without at least one successful ordination every summary below would be
# built from zeros/NaN and still be written out as if it were a result.
n_with_dist <- sum(!vapply(results, function(r) is.null(r$dist), logical(1)))
if (n_with_dist == 0) {
  stop(
    "No bootstrap iteration produced an NMDS distance matrix; ",
    "nothing to summarise.",
    call. = FALSE
  )
}

# Mean silhouette per k across bootstraps (rows = iterations, cols = k)
sil_mat <- do.call(rbind, lapply(results, `[[`, "silhouettes"))
mean_sil <- colMeans(sil_mat, na.rm = TRUE)

sil_df <- data.frame(k = k_values, mean_silhouette = mean_sil)
p_sil <- ggplot(sil_df, aes(x = k, y = mean_silhouette)) +
  geom_line() +
  geom_point() +
  labs(x = "Number of clusters (k)", y = "Mean silhouette", title = "Mean silhouette vs. k (bootstrap)") +
  theme_minimal()
print(p_sil)

# Average distance matrix across bootstraps. A replicate may contain only a
# subset of the lat_bins, so each replicate's distances are placed into the
# full lat_levels x lat_levels layout by matching its lat_bin labels.
lat_levels <- all_lat_bins
n_lat <- length(lat_levels)
sum_mat <- matrix(0, n_lat, n_lat, dimnames = list(lat_levels, lat_levels))
count_mat <- matrix(0, n_lat, n_lat)

for (res in results) {
  d <- res$dist
  lb <- res$lat_bins
  if (is.null(d) || is.null(lb)) next

  # Positions of this replicate's lat_bins in the full layout
  idx <- match(lb, lat_levels)
  sum_mat[idx, idx] <- sum_mat[idx, idx] + as.matrix(d)
  count_mat[idx, idx] <- count_mat[idx, idx] + 1
}

# Pairs that never appeared together have no distance estimate; they stay at
# 0 below (hclust cannot handle NA), so make that visible to the user.
n_unobserved <- sum(count_mat[upper.tri(count_mat)] == 0)
if (n_unobserved > 0) {
  warning(
    n_unobserved,
    " lat_bin pair(s) never occurred together in a bootstrap sample; ",
    "their mean distance is reported as 0.",
    call. = FALSE
  )
}

mean_mat <- sum_mat / pmax(count_mat, 1)  # avoid division by zero
diag(mean_mat) <- 0

# Save mean distance matrix in long format (one row per ordered pair)
mean_dist_df <- as.data.frame(as.table(mean_mat)) %>%
  rename(lat_bin_1 = Var1, lat_bin_2 = Var2, mean_distance = Freq)
write_csv(mean_dist_df, OUT_MEAN_DIST)

# Dendrogram from mean distances
mean_dist <- as.dist(mean_mat)
average_hc <- hclust(mean_dist, method = "average")

plot(average_hc, labels = lat_levels, main = "Hierarchical clustering (mean NMDS distances)")
rect.hclust(average_hc, k = k_cutree, border = "blue")

# -----------------------------
# 4) Co-clustering support (bootstrap)
# -----------------------------
# Support(i,j) = proportion of bootstraps where i and j are in the same cluster (k = k_cutree)
# Accumulator of co-membership counts in the full lat_levels layout
support <- matrix(0, n_lat, n_lat, dimnames = list(lat_levels, lat_levels))
valid_iters <- 0

for (res in results) {
  membership <- res$cl
  bins <- res$lat_bins

  # Skip replicates without cluster labels or with inconsistent lengths
  usable <- !is.null(membership) && !is.null(bins) &&
    length(membership) == length(bins)
  if (!usable) next

  valid_iters <- valid_iters + 1

  # Co-membership indicator for every pair of lat_bins in this replicate,
  # added at the positions those lat_bins occupy in the full layout.
  pos <- match(bins, lat_levels)
  same_cluster <- outer(membership, membership, FUN = "==")
  support[pos, pos] <- support[pos, pos] + same_cluster
}

# Counts -> proportion of valid replicates (guard against zero replicates)
support <- support / max(valid_iters, 1)

support_df <- rename(
  as.data.frame(as.table(support)),
  lat_bin_1 = Var1, lat_bin_2 = Var2, support = Freq
)
write_csv(support_df, OUT_SUPPORT)

cat("Valid bootstrap iterations for support:", valid_iters, " / ", n_bootstrap, "\n")
cat("Wrote:", OUT_SUPPORT, "\n")
cat("Done.\n")