library(dplyr)
library(ggplot2)
library(tidyr)
library(purrr)
library(stringr)

# =======================
# 0) Unified theme (outer frame + horizontal grid only)
# =======================
theme_paper <- function(base_size = 18, base_family = "sans") {
  # Shared ggplot2 theme for the paper figures: theme_classic() plus a black
  # outer frame, horizontal-only grid lines, enlarged axis/legend text and a
  # bare (untitled, transparent) legend.
  #
  # base_size:   base font size passed to theme_classic(); axis titles are
  #              drawn 4 pt larger and axis/legend text 2 pt larger.
  # base_family: font family passed to theme_classic().
  # Returns a ggplot2 theme object to be added to a plot with `+`.
  title_pt <- base_size + 4
  label_pt <- base_size + 2

  black_stroke <- element_line(color = "black", linewidth = 0.6)

  frame_and_grid <- theme(
    panel.border = element_rect(color = "black", fill = NA, linewidth = 0.9),
    # vertical grid lines suppressed; only horizontal ones are drawn
    panel.grid.major.x = element_blank(),
    panel.grid.minor.x = element_blank(),
    panel.grid.major.y = element_line(color = "grey85", linewidth = 0.6),
    panel.grid.minor.y = element_line(color = "grey92", linewidth = 0.4)
  )

  axes_and_legend <- theme(
    axis.line = black_stroke,
    axis.ticks = black_stroke,
    axis.ticks.length = unit(2.5, "pt"),
    axis.title = element_text(size = title_pt),
    axis.text = element_text(size = label_pt),
    legend.text = element_text(size = label_pt),
    legend.title = element_blank(),
    legend.key = element_blank(),
    legend.background = element_blank(),
    legend.box.background = element_blank()
  )

  theme_classic(base_size = base_size, base_family = base_family) +
    frame_and_grid +
    axes_and_legend
}

# =======================
# 1) Load data
# =======================
# Read the cleaned MoCA/SPACE table and drop rows missing either plotted
# variable. Week is deliberately not required here: supervised rows may have
# no Week value.
df <- drop_na(
  read.csv("MoCA_SPACE_nana_final0905_final_removed_outliers.csv"),
  SPACE_error, Moca
)

# =======================
# 2) Build plotting groups (robust Week parsing)
# =======================
# Canonical plotting groups, in legend/plot order.
intended_levels <- c("Unsupervised - Week 1",
                     "Unsupervised - Week 2",
                     "Unsupervised - Week 3",
                     "Supervised")

# One row per observation, tagged with its plotting group:
#   Supervision == "1" and Week in 1-3 -> "Unsupervised - Week <n>"
#   Supervision == "0"                 -> "Supervised" (Week is ignored)
#   anything else                      -> dropped
df_plot <- df %>%
  transmute(
    SPACE_error,
    Moca,
    Supervision_chr = str_squish(as.character(Supervision)),
    # Take the first whole number in the Week field ("Week 2", "W2", 2, ...).
    # Matching a full digit run rather than a single [123] character stops
    # values such as "Week 12" from being read as week 1. Unsupervised rows
    # whose week is outside 1-3 match no group below and are dropped.
    # suppressWarnings() only covers as.integer() overflow on absurdly long
    # digit runs, which yields NA either way.
    Week_num = suppressWarnings(
      as.integer(str_extract(str_squish(as.character(Week)), "\\d+"))
    ),
    Group = case_when(
      Supervision_chr == "1" & Week_num %in% 1:3 ~ paste0("Unsupervised - Week ", Week_num),
      Supervision_chr == "0"                     ~ "Supervised",
      TRUE                                       ~ NA_character_
    )
  ) %>%
  # Group strings are built above from ASCII literals only, so no dash/space
  # normalisation is needed. factor() fixes the level order, and the filter
  # drops rows that matched no group (or any value outside intended_levels).
  mutate(Group = factor(Group, levels = intended_levels)) %>%
  filter(!is.na(Group))

print(df_plot %>% count(Group, name = "n_rows"))

# =======================
# 3) Spearman correlations for legend labels
# =======================
safe_spearman <- function(x, y) {
  # Spearman correlation of x and y on their complete (non-NA) pairs.
  #
  # Returns a named numeric vector c(r = <rho>, p = <p-value>). Both entries
  # are NA when fewer than 3 complete pairs exist or when cor.test() errors.
  # Warnings from cor.test() (e.g. about ties) are not suppressed.
  na_result <- c(r = NA_real_, p = NA_real_)

  paired <- complete.cases(x, y)
  if (sum(paired) < 3) {
    return(na_result)
  }

  tryCatch(
    {
      fit <- cor.test(x[paired], y[paired], method = "spearman")
      c(r = unname(fit$estimate), p = fit$p.value)
    },
    error = function(e) na_result
  )
}

# Per-group Spearman correlation between MoCA and SPACE error, formatted into
# legend labels such as "Unsupervised: Week 1 (ρ = 0.12, p = 0.034)".
corr_tbl <- df_plot %>%
  # .drop = FALSE keeps factor levels that have no rows. Every level then gets
  # a label ("ρ = NA, p = NA" via safe_spearman's NA result) instead of an NA
  # legend entry; the plot scales use drop = FALSE and show all levels.
  group_by(Group, .drop = FALSE) %>%
  summarise(tmp = list(safe_spearman(Moca, SPACE_error)), .groups = "drop") %>%
  mutate(
    rho = map_dbl(tmp, "r"),
    p   = map_dbl(tmp, "p")
  ) %>%
  select(-tmp) %>%
  mutate(
    # "Unsupervised - Week 1" -> "Unsupervised: Week 1"; "Supervised" has no
    # match and passes through unchanged.
    group_pretty = str_replace(as.character(Group),
                               "^Unsupervised - Week ", "Unsupervised: Week "),
    # \u03c1 is the Greek letter rho. The escape keeps this source line ASCII,
    # so it parses the same regardless of the session's native encoding.
    r_lab = sprintf("\u03c1 = %.2f", rho),
    p_lab = case_when(
      is.na(p) ~ "p = NA",
      p < .001 ~ "p < .001",
      TRUE     ~ sprintf("p = %.3f", p)
    ),
    legend_lab = paste0(group_pretty, " (", r_lab, ", ", p_lab, ")")
  )

# Labels aligned with the factor level order used for the scale breaks.
legend_labs <- corr_tbl$legend_lab[match(levels(df_plot$Group), corr_tbl$Group)]

# =======================
# 4) Colors (always match levels -> NO warnings)
# =======================
PAL_SANEKY <- c(
  "Pink"   = "#EA6073",
  "Navy"   = "#46506B",
  "Yellow" = "#E4B74F",
  "Teal"   = "#2EB4B8"   # teal matching Fig 6/7
)

# Plot/legend group order. This was previously referenced (here and in the
# plot scales) without ever being defined.
grp_levels <- levels(df_plot$Group)

# Explicit group -> colour mapping, keyed by level name so that a change in
# level order cannot silently shift colours between groups.
# NOTE(review): the previous comment here read "W1 navy, W2 teal, W3 pink,
# supervised yellow", which contradicts the colours the code assigned (kept
# below). Confirm which scheme matches Fig 6/7.
group_cols <- c(
  "Unsupervised - Week 1" = PAL_SANEKY[["Yellow"]],
  "Unsupervised - Week 2" = PAL_SANEKY[["Teal"]],
  "Unsupervised - Week 3" = PAL_SANEKY[["Pink"]],
  "Supervised"            = PAL_SANEKY[["Navy"]]
)
stopifnot(all(grp_levels %in% names(group_cols)))

# Named by level, as scale_*_manual(values = ) expects.
col_vals  <- group_cols[grp_levels]
fill_vals <- col_vals

# =======================
# 5) Plot
# =======================
# ggplot2 3.5.0 deprecated a numeric `legend.position`; newer versions take
# the inside coordinates via `legend.position.inside`. Choose by version so the
# script runs warning-free on both old and new ggplot2.
legend_inside_xy <- c(0.50, 0.08)
legend_placement <- if (utils::packageVersion("ggplot2") >= "3.5.0") {
  theme(legend.position = "inside", legend.position.inside = legend_inside_xy)
} else {
  theme(legend.position = legend_inside_xy)
}

# Scatter of SPACE error vs MoCA with one linear fit (and its confidence band)
# per group. The PNG is written once, in the export section below.
p <- ggplot(df_plot, aes(x = Moca, y = SPACE_error, color = Group, fill = Group)) +
  geom_point(alpha = 0.65, size = 2.0) +
  # An explicit formula avoids the "using formula = 'y ~ x'" message.
  geom_smooth(method = "lm", formula = y ~ x, se = TRUE,
              alpha = 0.18, linewidth = 1.1) +
  scale_color_manual(values = col_vals,
                     breaks = grp_levels,
                     labels = legend_labs,
                     drop   = FALSE) +
  scale_fill_manual(values = fill_vals,
                    breaks = grp_levels,
                    labels = legend_labs,
                    drop   = FALSE) +
  labs(x = "MoCA", y = "SPACE error", color = NULL, fill = NULL) +
  guides(
    fill  = "none",
    color = guide_legend(ncol = 1, byrow = TRUE)
  ) +
  # coord_cartesian() zooms without dropping data, so the lm fits still use
  # every point. The fixed range leaves empty room at the bottom of the panel
  # for the legend.
  coord_cartesian(ylim = c(-3, 3)) +
  theme_paper(base_size = 20) +
  legend_placement +
  theme(
    legend.justification = c("center", "bottom"),
    # opaque white box so grid lines do not show through the legend
    legend.background = element_rect(fill = "white", color = "white"),
    legend.box.background = element_rect(fill = "white", color = "white"),
    legend.margin = margin(6, 8, 6, 8),
    legend.key.height = unit(1.0, "lines")
  )

print(p)

# =======================
# 6) Export (optional)
# =======================
# Write the figure to disk; width/height use ggsave()'s default units.
ggsave(
  filename = "Fig5_MoCA_SPACEerror_unified.png",
  plot = p,
  width = 8.5, height = 6.2, dpi = 300
)