
#### 0. Directory and environment

# Work from the project root so relative paths (e.g. the default results/ outdir) resolve there.
cd /mnt/sdb/Colaboraciones/Roberto_Elizondo_2025/
# Activate the conda environment described by RUVSeq_env.yml.
conda activate RUVSeq_env

# =============================================================================
# UPDATED PIPELINE CODE (R + Python) - Heatmaps formatting + EnhancedVolcano
#
# This code updates your existing workflow to:
#   1) Heatmaps: optional "blank separations" (group gaps; off by default, see --heatmap_sepwidth), shrink dendrogram + key
#   2) Volcano: reproduce volcano plot using EnhancedVolcano (with/without sample2)
#   3) GO/Reactome export: robust CSV writing (fix list-column write.csv error)
#   4) Python TF/Immune heatmaps: add blank separations + shrink dendrogram/key
#   5) Ranked immune pathways table (GO:BP + REAC) from g:Profiler outputs
#
# Directory assumption:
#   cd /mnt/sdb/Colaboraciones/Roberto_Elizondo_2025/
#
# =============================================================================


###############################################################################
# A) CONDA ENVIRONMENT (UPDATED)
###############################################################################
# File: RUVSeq_env.yml
# -----------------------------------------------------------------------------
name: RUVSeq_env
# Channels are listed in priority order: conda-forge first, then bioconda.
channels:
  - conda-forge
  - bioconda
dependencies:
  # R runtime, pinned to the 4.3.x series, plus BiocManager.
  - r-base=4.3.*
  - r-biocmanager
  # CRAN packages (r-*) used for data handling, plotting and enrichment.
  - r-readr
  - r-ggplot2
  - r-dplyr
  - r-car
  - r-mass
  - r-rcolorbrewer
  - r-gplots
  - r-viridis
  - r-tidyverse
  - r-factominer
  - r-factoextra
  - r-reshape2
  - r-gprofiler2
  - r-ggrepel
  # Bioconductor packages: RUVSeq (normalization), edgeR (DE), multiMiR
  # (target retrieval), EnhancedVolcano (volcano plots).
  - bioconductor-ruvseq
  - bioconductor-edger
  - bioconductor-multimir
  - bioconductor-enhancedvolcano
  # NOTE(review): run_mirna_pipeline.R can also use miRBaseConverter for the
  # seed-conservation plots (they are skipped when it is missing); it is not
  # listed here.

# Commands:
#   conda env create -f RUVSeq_env.yml
#   conda activate RUVSeq_env


# Create the pipeline script in an editor and paste in the R code that follows.
# NOTE(review): the ''' line below is a paste delimiter, not valid R or shell
# syntax - make sure it does not end up inside the saved .R file.
nano run_mirna_pipeline.R
'''
#!/usr/bin/env Rscript

# =============================================================================
# run_mirna_pipeline.R
#
# End-to-end miRNA pipeline:
#   - Load miRNA count matrix
#   - Optional sample dropping (e.g., Ctrl_2,Olig_2)
#   - RUVSeq (RUVr) normalization
#   - edgeR differential expression (non-paired + paired if possible)
#   - DEG heatmaps (blank group separations + shrunken dendrogram + shrunken key)
#   - Volcano plot via EnhancedVolcano (Up/Down/No diff legend) with custom labels
#   - multiMiR target retrieval
#   - g:Profiler enrichment per miRNA (GO:BP, GO:MF, TF, REAC) with robust CSV export
#
# Usage examples:
#   # All samples
#   Rscript run_mirna_pipeline.R --input miRNA-counts.csv --outdir results/all_samples
#
#   # Without sample2 (drops Ctrl_2 and Olig_2)
#   Rscript run_mirna_pipeline.R --input miRNA-counts.csv --outdir results/without_sample2 --drop_samples Ctrl_2,Olig_2
#
# Notes:
#   - Requires packages: RUVSeq, edgeR, EnhancedVolcano, multiMiR, gprofiler2, gplots, viridis
#   - Recommended: run inside your conda env (RUVSeq_env)
# =============================================================================

suppressPackageStartupMessages({
  library(RUVSeq)
  library(edgeR)
  library(gplots)
  library(viridis)
  library(dplyr)
  library(ggplot2)
  library(EnhancedVolcano)
  library(ggrepel)
  library(multiMiR)
  library(gprofiler2)
})

# -----------------------------
# Minimal CLI parser (no optparse dependency)
# -----------------------------
get_arg <- function(flag, default = NULL) {
  # Return the token that follows `flag` in the trailing command-line
  # arguments. Falls back to `default` when the flag is absent or is the very
  # last token (i.e. it was given without a value).
  cli_args <- commandArgs(trailingOnly = TRUE)
  pos <- match(flag, cli_args)
  has_value <- !is.na(pos) && pos < length(cli_args)
  if (!has_value) {
    return(default)
  }
  cli_args[[pos + 1]]
}

input_file <- get_arg("--input", default = NA_character_)
outdir     <- get_arg("--outdir", default = "results")
drop_str   <- get_arg("--drop_samples", default = "")
fdr_heatmap <- as.numeric(get_arg("--fdr_heatmap", default = "0.05"))
pCutoff     <- as.numeric(get_arg("--pCutoff", default = "0.05"))
FCcutoff    <- as.numeric(get_arg("--FCcutoff", default = "1.0"))
k_ruv       <- as.integer(get_arg("--k_ruv", default = "1"))
label_str   <- get_arg("--label_mirnas", default =
                         "miR-3681-5p,miR-365a-5p,miR-6081,miR-193a-3p,miR-4684-5p,miR-622,miR-335-5p,miR-6883-3p,miR-3064-5p,miR-4478")
# Volcano aesthetics
volcano_width  <- as.numeric(get_arg("--volcano_width", default = "9.2"))
volcano_height <- as.numeric(get_arg("--volcano_height", default = "8.6"))
label_top_n_each_side <- as.integer(get_arg("--label_top_n", default = "3"))
max_labels_total <- as.integer(get_arg("--max_labels_total", default = "14"))
# Heatmap aesthetics (defaults tuned to match the reference viridis heatmap)
# - viridis palette by default
# - modest text sizes so labels don't get clipped
heatmap_cex_row  <- as.numeric(get_arg("--heatmap_cex_row", default = "0.60"))
heatmap_cex_col  <- as.numeric(get_arg("--heatmap_cex_col", default = "1.40"))
# Backward-compatible CLI arg: width of the white gap between sample groups.
# The default 0 means no separators are drawn.
heatmap_sepwidth <- as.numeric(get_arg("--heatmap_sepwidth", default = "0"))
heatmap_color    <- get_arg("--color", default = "viridis")


# Optional paired-order heatmap (OFF by default; enable with --paired_heatmap true)
use_paired_heatmap <- tolower(get_arg("--paired_heatmap", default = "false")) %in% c("true","t","1","yes","y")

if (is.na(input_file) || input_file == "") {
  stop("ERROR: --input is required (path to miRNA-counts.csv).")
}

# Fail fast on malformed numeric options. as.numeric()/as.integer() turn a
# non-numeric string into NA with only a coercion warning. That NA would
# otherwise surface much later, either as a cryptic error (e.g. head(x, NA))
# or as comparisons against NA that silently drop every row.
numeric_opts <- c(
  "--fdr_heatmap"      = fdr_heatmap,
  "--pCutoff"          = pCutoff,
  "--FCcutoff"         = FCcutoff,
  "--k_ruv"            = k_ruv,
  "--volcano_width"    = volcano_width,
  "--volcano_height"   = volcano_height,
  "--label_top_n"      = label_top_n_each_side,
  "--max_labels_total" = max_labels_total,
  "--heatmap_cex_row"  = heatmap_cex_row,
  "--heatmap_cex_col"  = heatmap_cex_col,
  "--heatmap_sepwidth" = heatmap_sepwidth
)
bad_opts <- names(numeric_opts)[is.na(numeric_opts)]
if (length(bad_opts) > 0) {
  stop("ERROR: non-numeric value given for: ", paste(bad_opts, collapse = ", "))
}

# -----------------------------
# Output folder structure
# -----------------------------
dir_qc       <- file.path(outdir, "00_qc")
dir_ruv      <- file.path(outdir, "01_ruvseq")
dir_edger    <- file.path(outdir, "02_edger")
dir_heatmap  <- file.path(outdir, "03_heatmaps")
dir_volcano  <- file.path(outdir, "04_volcano")
dir_targets  <- file.path(outdir, "05_targets")
dir_enrich   <- file.path(outdir, "06_enrichment")
dir_go_terms <- file.path(dir_enrich, "GO_Terms")
dir_seed     <- file.path(outdir, "07_seed_conservation")

for (d in c(dir_qc, dir_ruv, dir_edger, dir_heatmap, dir_volcano, dir_targets, dir_enrich, dir_go_terms, dir_seed)) {
  dir.create(d, showWarnings = FALSE, recursive = TRUE)
  # dir.create() reports failure only through its return value (the warning
  # is suppressed above). Check explicitly so a permission or path problem
  # stops here, instead of at the first file write much later.
  if (!dir.exists(d)) {
    stop("ERROR: could not create output directory: ", d)
  }
}

# -----------------------------
# Helpers
# -----------------------------
clean_mirna_id <- function(x) {
  # Normalize miRNA identifiers: drop all whitespace, then strip a leading
  # species prefix (hsa-, mmu-, rno-; case-insensitive). Prefixes are removed
  # one after another, in that order.
  out <- gsub("\\s+", "", as.character(x))
  for (species in c("hsa", "mmu", "rno")) {
    out <- gsub(paste0("^", species, "-"), "", out, ignore.case = TRUE)
  }
  out
}

extract_pair_id <- function(sample_names) {
  # Parse the numeric "_<digits>" suffix of each sample name (e.g. "Ctrl_2" ->
  # 2). Names without such a suffix yield NA_real_. If no name has a suffix,
  # the result is an all-NA numeric vector of the same length.
  suffix <- sub(".*_(\\d+)$", "\\1", sample_names)
  # Unmatched names come back unchanged from sub() and become NA here.
  suppressWarnings(as.numeric(suffix))
}

order_by_suffix_number <- function(x) {
  # Sort names by their trailing "_<number>" suffix (numerically). Names
  # without such a suffix go last; ties are broken by the name itself.
  suffix_num <- suppressWarnings(as.numeric(sub(".*_(\\d+)$", "\\1", x)))
  sort_key <- order(is.na(suffix_num), suffix_num, x)
  x[sort_key]
}

make_group_factor <- function(sample_names) {
  # Map each sample name to its condition by case-insensitive substring match.
  # Names containing "Ctrl" become Ctrl; otherwise names containing "Olig"
  # become Olig. Stops if any name matches neither. Levels are c("Ctrl", "Olig").
  labels <- rep(NA_character_, length(sample_names))
  labels[grepl("Olig", sample_names, ignore.case = TRUE)] <- "Olig"
  # Assigned second so "Ctrl" wins when a name contains both substrings.
  labels[grepl("Ctrl", sample_names, ignore.case = TRUE)] <- "Ctrl"
  if (anyNA(labels)) {
    stop("Could not infer group for some samples. Ensure column names contain 'Ctrl' or 'Olig'.")
  }
  factor(labels, levels = c("Ctrl","Olig"))
}

plot_qq <- function(df_counts, out_pdf) {
  # Write one normal Q-Q plot per column of `df_counts` into a single PDF.
  #
  # df_counts : data.frame (columns = samples) of numeric values.
  # out_pdf   : output PDF path.
  # Returns `out_pdf` invisibly.
  n <- ncol(df_counts)
  pdf(out_pdf, width = 25, height = 10)
  # Close the device even if a panel errors, so a failed call does not leave
  # a dangling PDF device that captures later plots.
  on.exit(dev.off(), add = TRUE)
  par(mfrow = c(2, max(2, ceiling(n/2))))
  # seq_len(): a zero-column input draws nothing instead of failing on 1:0.
  for (i in seq_len(n)) {
    qqnorm(df_counts[[i]], main = paste("Q-Q plot for", names(df_counts)[i]))
    qqline(df_counts[[i]], col = "steelblue")
  }
  invisible(out_pdf)
}

plot_rle_pca <- function(seqset, group_factor, out_pdf, title_prefix="") {
  # Write an RLE plot and a PCA plot of `seqset` (a SeqExpressionSet) to one
  # PDF, with samples colored by `group_factor` (Set2 palette).
  # Returns `out_pdf` invisibly.
  colors <- RColorBrewer::brewer.pal(6, "Set2")
  # Factor indexing uses the integer codes, i.e. one color per group level.
  sample_cols <- colors[group_factor]
  pdf(out_pdf, width = 6, height = 6)
  # Always release the device, even if plotRLE/plotPCA fails.
  on.exit(dev.off(), add = TRUE)
  plotRLE(seqset, outline = FALSE, ylim = c(-4, 4), col = sample_cols,
          main = paste0(title_prefix, "RLE"))
  plotPCA(seqset, col = sample_cols, cex = 1.2,
          main = paste0(title_prefix, "PCA"))
  invisible(out_pdf)
}

get_heatmap_colors <- function(palette_name, n = 256) {
  # Resolve a palette specification to a vector of `n` colors.
  #
  # Accepted forms (case-insensitive names):
  #   * comma-separated colors, e.g. "navy,white,firebrick3" or
  #     "#440154,#21908C,#FDE725" (used as-is, never reversed);
  #   * diverging ramps: greenred/gbr, redgreen/rbg, bluewhitered/bwr,
  #     redwhiteblue/rwb (plus their hyphen/underscore spellings);
  #   * base R: heat, terrain, topo, cm (also "*colors" / "*.colors");
  #   * viridis, magma, inferno, plasma, cividis, turbo;
  #   * any RColorBrewer palette, optionally written "brewer:<name>".
  # A trailing _r, -r, _rev or -rev reverses any named palette.
  # NULL / NA / "" fall back to "greenred". Unknown names emit a warning and
  # also fall back to greenred (reversed if a reverse suffix was given).
  pal_raw <- ifelse(is.null(palette_name) || is.na(palette_name) || palette_name == "",
                    "greenred", palette_name)
  pal <- tolower(pal_raw)

  # --- 0) User-defined ramp from comma-separated colors --------------------
  if (grepl(",", pal_raw)) {
    user_cols <- trimws(unlist(strsplit(pal_raw, ",")))
    user_cols <- user_cols[user_cols != ""]
    if (length(user_cols) >= 2) {
      return(colorRampPalette(user_cols)(n))
    }
  }

  # Detect and strip a reverse suffix; orient() applies it to any result.
  rev_pattern <- "(_r|-r|_rev|-rev)$"
  reverse_flag <- grepl(rev_pattern, pal)
  if (reverse_flag) {
    pal <- gsub(rev_pattern, "", pal)
  }
  orient <- function(cols) {
    if (reverse_flag) rev(cols) else cols
  }

  # --- 1) Three-stop diverging ramps ---------------------------------------
  ramp_stops <- switch(pal,
    greenred = , gbr = , `green-black-red` = , green_black_red = c("green", "black", "red"),
    redgreen = , rbg = , `red-black-green` = , red_black_green = c("red", "black", "green"),
    bluewhitered = , bwr = , `blue-white-red` = c("blue", "white", "red"),
    redwhiteblue = , rwb = , `red-white-blue` = c("red", "white", "blue"),
    NULL
  )
  if (!is.null(ramp_stops)) {
    return(orient(colorRampPalette(ramp_stops)(n)))
  }

  # --- 2) Base R generators / 3) viridis family (turbo from viridisLite) ---
  palette_fun <- switch(pal,
    heat = , heatcolors = , heat.colors = heat.colors,
    terrain = , terraincolors = , terrain.colors = terrain.colors,
    topo = , topocolors = , topo.colors = topo.colors,
    cm = , cmcolors = , cm.colors = cm.colors,
    viridis = , magma = , inferno = , plasma = , cividis = get(pal, asNamespace("viridis")),
    turbo = viridisLite::turbo,
    NULL
  )
  if (!is.null(palette_fun)) {
    return(orient(palette_fun(n)))
  }

  # --- 4) RColorBrewer palettes, interpolated to n colors ------------------
  brewer_key <- if (startsWith(pal, "brewer:")) sub("^brewer:", "", pal) else pal

  suppressWarnings({
    brewer_info <- try(RColorBrewer::brewer.pal.info, silent = TRUE)
  })
  if (!inherits(brewer_info, "try-error")) {
    known <- rownames(brewer_info)
    brewer_hit <- known[tolower(known) == brewer_key]
    if (length(brewer_hit) == 1) {
      anchor_cols <- RColorBrewer::brewer.pal(brewer_info[brewer_hit, "maxcolors"], brewer_hit)
      return(orient(colorRampPalette(anchor_cols)(n)))
    }
  }

  # --- Fallback --------------------------------------------------------------
  warning("Unknown --color palette: ", palette_name,
          ". Using green-black-red (greenred).",
          "\nTip: you can pass comma-separated colors, e.g. --color 'navy,white,firebrick3'.")
  orient(colorRampPalette(c("green", "black", "red"))(n))
}

plot_deg_heatmap <- function(mat_log2, out_pdf, main_title, colsep_positions,
                             cex_row = 1.10, cex_col = 1.60,
                             sepwidth_val = 0.05,
                             color_mode = "greenred") {
  # Draw a clustered gplots::heatmap.2 of `mat_log2` (rows = miRNAs, columns =
  # samples) into `out_pdf`, using a compact layout (thin dendrogram, small key).
  #
  # colsep_positions : column indices after which a white separator is drawn
  #                    (NULL for none). Only used when sepwidth_val > 0.
  # sepwidth_val     : separator width; <= 0 (or NA/NULL) disables separators.
  # color_mode       : palette spec passed to get_heatmap_colors().
  # Returns `out_pdf` invisibly. Stops if the matrix is smaller than 2x2.
  stopifnot(is.matrix(mat_log2))
  if (nrow(mat_log2) < 2 || ncol(mat_log2) < 2) {
    stop("Heatmap matrix must be at least 2x2. Current dim: ",
         paste(dim(mat_log2), collapse="x"))
  }

  # heatmap.2 draws each separator as a rectangle filled AND bordered in
  # `sepcolor`, so even sepwidth = 0 leaves a visible hairline. Only request
  # separators when a positive width is asked for.
  if (is.null(sepwidth_val) || is.na(sepwidth_val) || sepwidth_val <= 0) {
    colsep_positions <- NULL
    sepwidth_val <- 0
  }

  # Layout to shrink dendrogram/key and keep key on top-right
  lmat <- rbind(c(0, 3, 4),
                c(2, 1, 0))

  # Make dendrogram & key thinner to reduce blank space
  lwid <- c(0.70, 6.6, 0.65)  # row dendro | heatmap | key
  lhei <- c(0.50, 6.6)        # top row (col dendro+key) | heatmap

  hm_cols <- get_heatmap_colors(color_mode, n = 256)

  pdf(out_pdf, width = 17, height = 16)
  # Release the device even if heatmap.2 fails (e.g. clustering errors).
  on.exit(dev.off(), add = TRUE)
  heatmap.2(
    mat_log2,
    main = main_title,
    density.info = "none",
    trace = "none",
    margins = c(8, 10),
    cexRow = cex_row,
    cexCol = cex_col,
    col = hm_cols,

    # --- Minimal blank separations (none when colsep is NULL) ---
    colsep = colsep_positions,
    sepcolor = "white",
    sepwidth = c(sepwidth_val, sepwidth_val),

    # --- Shrink key ---
    key = TRUE,
    keysize = 0.60,
    key.title = "Color key",
    key.xlab = "Value",
    key.par = list(mar = c(1.5, 1.5, 0.8, 0.8)),

    # --- Shrink dendrogram/key via layout ---
    lmat = lmat, lwid = lwid, lhei = lhei
  )
  invisible(out_pdf)
}

plot_volcano_enhancedvolcano <- function(res_table,
                                         out_pdf,
                                         plot_title,
                                         pCutoff,
                                         FCcutoff,
                                         always_label,
                                         label_top_n_each_side = 3,
                                         max_labels_total = 14,
                                         pointSize = 2.1,
                                         labSize = 8.0,
                                         xlim_left = -5,
                                         xlim_right = NA_real_,
                                         width = 9.2,
                                         height = 8.6) {
  # Volcano plot: EnhancedVolcano base layer, recolored Up / No diff. / Down
  # overlay, and ggrepel labels; saved to `out_pdf` with ggsave().
  #
  # res_table    : data.frame with `logFC` and `FDR` columns and miRNA IDs as
  #                row names (IDs are normalized with clean_mirna_id()).
  # pCutoff      : FDR threshold. FCcutoff: |logFC| threshold.
  # always_label : miRNAs labelled whenever they are called Up or Down.
  # label_top_n_each_side / max_labels_total : label budget.
  # xlim_left / xlim_right : x-axis limits; NA = choose automatically. Both
  #                are widened (never narrowed) so every finite logFC fits.
  # Returns whatever ggsave() returns.

  stopifnot(is.data.frame(res_table))
  stopifnot(all(c("logFC","FDR") %in% colnames(res_table)))

  res <- res_table
  # Standardize miRNA labels for plotting and matching
  res$miRNA <- clean_mirna_id(rownames(res))
  res$FDR <- as.numeric(res$FDR)
  res$logFC <- as.numeric(res$logFC)

  # Collapse duplicates (prevents repeated labels); keep the most significant.
  res <- res %>%
    arrange(FDR, desc(abs(logFC))) %>%
    distinct(miRNA, .keep_all = TRUE)

  finite_fc <- res$logFC[is.finite(res$logFC)]
  if (length(finite_fc) == 0) {
    stop("Volcano plot: no finite logFC values in res_table.")
  }

  res <- res %>%
    mutate(
      # Floor FDR so -log10() stays finite when FDR == 0.
      FDR_safe = pmax(FDR, .Machine$double.xmin),
      negLog10FDR = -log10(FDR_safe),
      Regulation = case_when(
        FDR < pCutoff & logFC >  FCcutoff ~ "Up",
        FDR < pCutoff & logFC < -FCcutoff ~ "Down",
        TRUE                              ~ "No diff."
      )
    )
  res$Regulation <- factor(res$Regulation, levels = c("Up","No diff.","Down"))

  # Sanitize the always_label list
  always_label <- clean_mirna_id(always_label)
  always_label <- always_label[!is.na(always_label) & always_label != ""]
  always_label <- unique(always_label)

  top_up <- res %>%
    filter(Regulation == "Up") %>%
    arrange(FDR, desc(logFC)) %>%
    head(label_top_n_each_side) %>%
    pull(miRNA)

  top_down <- res %>%
    filter(Regulation == "Down") %>%
    arrange(FDR, logFC) %>%
    head(label_top_n_each_side) %>%
    pull(miRNA)

  always_present <- res %>%
    filter(miRNA %in% always_label, Regulation != "No diff.") %>%
    pull(miRNA)

  # Build the label set and cap the total to avoid clutter (always_present
  # labels are kept even if they alone exceed the cap).
  select_labels <- unique(c(always_present, top_up, top_down))
  if (length(select_labels) > max_labels_total) {
    remaining <- setdiff(select_labels, always_present)
    rem_ranked <- res %>%
      filter(miRNA %in% remaining) %>%
      arrange(FDR, desc(abs(logFC))) %>%
      pull(miRNA)
    select_labels <- unique(c(always_present, head(rem_ranked, max(0, max_labels_total - length(always_present)))))
  }

  label_df <- res %>% filter(miRNA %in% select_labels)

  # X limits: automatic defaults leave room for labels on each side.
  min_x <- min(finite_fc)
  max_x <- max(finite_fc)
  if (is.na(xlim_right)) {
    xlim_right <- min(12, max(6, ceiling(max_x + 2.0)))
  }
  if (is.na(xlim_left)) {
    xlim_left <- max(-12, min(-6, floor(min_x - 2.0)))
  }
  # NOTE(review): EnhancedVolcano appears to apply xlim as scale limits, which
  # drop out-of-range points (and our overlay points/labels) rather than just
  # clipping them. Widen, never narrow, the limits so no data point is lost.
  xlim_left  <- min(xlim_left, floor(min_x))
  xlim_right <- max(xlim_right, ceiling(max_x))

  max_y <- suppressWarnings(max(res$negLog10FDR, na.rm = TRUE))
  if (!is.finite(max_y)) {
    max_y <- 0
  }
  ylim <- c(0, max(3, max_y + 0.6))

  # Base EnhancedVolcano WITHOUT its own labels (we add ggrepel labels ourselves)
  p <- EnhancedVolcano(
    res,
    lab = rep("", nrow(res)),   # <- suppress EV labels
    x = "logFC",
    y = "FDR",
    title = plot_title,
    subtitle = "",
    xlab = "Log2(fold change)",
    ylab = "-Log10(FDR)",
    pCutoff = pCutoff,
    FCcutoff = FCcutoff,
    cutoffLineType = "dashed",
    cutoffLineCol = "black",
    xlim = c(xlim_left, xlim_right),
    ylim = ylim,
    pointSize = pointSize,
    labSize = labSize,
    col = c("grey70","grey70","grey70","grey70"),
    legendPosition = "none"
  )

  # Overlay colored points + legend Up/No diff/Down
  p <- p +
    geom_point(
      data = res,
      aes(x = logFC, y = negLog10FDR, color = Regulation),
      inherit.aes = FALSE,
      size = pointSize,
      alpha = 0.95
    ) +
    scale_color_manual(
      name = "Regulation",
      values = c("Up" = "red2", "No diff." = "grey70", "Down" = "royalblue")
    ) +
    guides(color = guide_legend(override.aes = list(size = 4))) +
    theme_classic(base_size = 18) +
    theme(
      legend.position = "right",
      legend.title = element_text(size = 16),
      legend.text  = element_text(size = 15),
      plot.title   = element_text(hjust = 0.5, size = 20),
      axis.title   = element_text(size = 18),
      axis.text    = element_text(size = 16),
      plot.margin  = margin(10, 45, 10, 10)  # extra right margin for labels
    ) +
    coord_cartesian(clip = "off")

  # Repelled labels pushed horizontally away from the volcano (right-side
  # labels further right, left-side further left) and then spread mostly
  # along y, so they do not drift back over the points.
  add_side_labels <- function(plot, df, nudge) {
    if (nrow(df) == 0) {
      return(plot)
    }
    plot + ggrepel::geom_label_repel(
      data = df,
      aes(x = logFC, y = negLog10FDR, label = miRNA),
      inherit.aes = FALSE,
      nudge_x = nudge,
      direction = "y",
      box.padding = 0.45,
      point.padding = 0.40,
      segment.size = 0.3,
      segment.alpha = 0.8,
      min.segment.length = 0,
      seed = 123,
      size = labSize/3,         # approx. points -> ggplot2 size units (mm)
      label.size = 0.25,
      color = "black",
      fill = "white"
    )
  }
  p <- add_side_labels(p, filter(label_df, logFC >= 0), 0.9)
  p <- add_side_labels(p, filter(label_df, logFC < 0), -0.9)

  ggsave(out_pdf, plot = p, width = width, height = height, device = "pdf")
}

# -----------------------------
# Seed (m8) conservation: human DE miRNAs vs mouse miRNAs
#   - Uses miRBaseConverter built-in miRBase sequences (if installed)
#   - Plots ONLY the positive/conserved ones (n_mouse_seed_matches > 0)
#     to keep the figure compact and readable.
# -----------------------------
seed_m8_from_seq <- function(seq) {
  # Extract the m8 seed (nucleotides 2-8) from a mature miRNA sequence,
  # upper-cased and with T converted to U. Returns NA_character_ for
  # NULL/NA input or sequences shorter than 8 nt.
  if (is.null(seq) || is.na(seq)) {
    return(NA_character_)
  }
  rna <- gsub("T", "U", toupper(as.character(seq)))  # DNA letters -> RNA
  if (nchar(rna) >= 8) substr(rna, 2, 8) else NA_character_
}

plot_seed_m8_conservation <- function(res_table,
                                     out_pdf,
                                     selection = c("up", "down"),
                                     fdr_cut = 0.05,
                                     fc_cut = 1.0,
                                     targetVersion = "v22",
                                     positive_only = TRUE,
                                     width = 7.0,
                                     height = NA_real_) {
  # For human DE miRNAs in one direction, count mouse mature miRNAs that share
  # the same m8 seed (nt 2-8). Saves a bar plot (PDF + PNG) and a CSV next to
  # `out_pdf`.
  #
  # res_table     : data.frame with `logFC`, `FDR` and miRNA IDs as row names
  #                 (stored without the "hsa-" prefix).
  # selection     : "up" (logFC > fc_cut) or "down" (logFC < -fc_cut), both
  #                 with FDR < fdr_cut.
  # targetVersion : miRBase version passed to miRBaseConverter.
  # positive_only : keep only miRNAs with >= 1 mouse seed match.
  # height        : NA = scale with the number of bars.
  # Returns the plotted table invisibly, or NULL when skipped.

  selection <- match.arg(selection)

  if (!requireNamespace("miRBaseConverter", quietly = TRUE)) {
    message("Seed conservation plot skipped: miRBaseConverter not installed.")
    return(invisible(NULL))
  }

  stopifnot(is.data.frame(res_table))
  stopifnot(all(c("logFC", "FDR") %in% colnames(res_table)))

  df <- res_table
  df$miRNA <- clean_mirna_id(rownames(df))
  df$FDR <- as.numeric(df$FDR)
  df$logFC <- as.numeric(df$logFC)

  # Select DE miRNAs
  if (selection == "up") {
    sel <- df %>%
      filter(!is.na(FDR)) %>%
      filter(FDR < fdr_cut, logFC > fc_cut)
  } else {
    sel <- df %>%
      filter(!is.na(FDR)) %>%
      filter(FDR < fdr_cut, logFC < -fc_cut)
  }

  # Several input rows can map to the same cleaned ID (e.g. "hsa-miR-x" and
  # "miR-x"). Keep the most significant one so bar labels stay unique;
  # factor() below errors on duplicated levels.
  sel <- sel %>%
    arrange(FDR, desc(abs(logFC))) %>%
    distinct(miRNA, .keep_all = TRUE)

  if (nrow(sel) == 0) {
    message("Seed conservation plot: no DE miRNAs for selection=", selection,
            " (FDR<", fdr_cut, ", |logFC|>", fc_cut, ")")
    return(invisible(NULL))
  }

  # Retrieve mature miRNAs for human and mouse (built-in miRBase)
  hsa_all <- miRBaseConverter::getAllMiRNAs(version = targetVersion, type = "mature", species = "hsa")
  mmu_all <- miRBaseConverter::getAllMiRNAs(version = targetVersion, type = "mature", species = "mmu")

  # Compute seeds (m8) for mouse, then count by seed
  mmu_all$seed_m8 <- vapply(mmu_all$Sequence, seed_m8_from_seq, FUN.VALUE = character(1))
  mmu_seed_counts <- mmu_all %>%
    filter(!is.na(seed_m8) & seed_m8 != "") %>%
    count(seed_m8, name = "n_mouse")

  # Map selected human miRNAs to their mature sequences to get the seed
  # NOTE: our pipeline stores miRNAs WITHOUT the 'hsa-' prefix.
  sel <- sel %>%
    mutate(hsa_full = paste0("hsa-", miRNA),
           hsa_full_lower = tolower(hsa_full))

  hsa_map <- hsa_all %>%
    mutate(Name_lower = tolower(Name),
           seed_m8 = vapply(Sequence, seed_m8_from_seq, FUN.VALUE = character(1))) %>%
    select(Name_lower, seed_m8) %>%
    # One row per name so the join below cannot duplicate miRNAs.
    distinct(Name_lower, .keep_all = TRUE)

  sel <- sel %>%
    left_join(hsa_map, by = c("hsa_full_lower" = "Name_lower")) %>%
    left_join(mmu_seed_counts, by = "seed_m8") %>%
    mutate(
      n_mouse = ifelse(is.na(n_mouse), 0L, as.integer(n_mouse)),
      seed_conserved = n_mouse > 0
    )

  # Keep only the positive/conserved ones if requested
  if (positive_only) {
    sel <- sel %>% filter(n_mouse > 0)
  }

  if (nrow(sel) == 0) {
    message("Seed conservation plot: no conserved seeds for selection=", selection)
    return(invisible(NULL))
  }

  # Order by significance (FDR) like the rest of the pipeline
  sel <- sel %>% arrange(FDR, desc(abs(logFC)))
  sel$miRNA <- factor(sel$miRNA, levels = rev(sel$miRNA))

  # Dynamic figure height so labels don't collide (smaller when fewer miRNAs)
  if (is.na(height)) {
    height <- max(2.8, min(9.5, 1.8 + 0.15 * nrow(sel)))
  }

  # Upper y limit is at least 1, so an all-zero table (positive_only = FALSE)
  # does not produce a degenerate c(0, 0) axis.
  y_max <- max(1L, max(sel$n_mouse))

  # Compact bar plot
  p <- ggplot(sel, aes(x = miRNA, y = n_mouse)) +
    geom_col(fill = "forestgreen", color = "black", width = 0.85) +
    coord_flip() +
    scale_y_continuous(breaks = seq(0, y_max, by = 1), limits = c(0, y_max)) +
    labs(
      title = "Seed (m8) conservation: human DE miRNAs vs mouse miRNAs",
      subtitle = paste0("Selection: ", selection, " (FDR<", fdr_cut, ", |logFC|>", fc_cut, ")"),
      x = NULL,
      y = "# mouse miRNAs with identical seed"
    ) +
    theme_bw(base_size = 12) +
    theme(
      plot.title = element_text(face = "bold", size = 14),
      axis.text.y = element_text(size = 7),
      legend.position = "none"
    )

  # Save PDF + PNG + CSV table
  ggsave(out_pdf, plot = p, width = width, height = height, device = "pdf")
  ggsave(sub("\\.pdf$", ".png", out_pdf), plot = p, width = width, height = height, dpi = 300, bg = "white")
  write.csv(sel, file = sub("\\.pdf$", ".csv", out_pdf), row.names = FALSE)

  invisible(sel)
}

# -----------------------------
# Load data
# -----------------------------
counts_raw <- read.csv(input_file, row.names = 1, check.names = FALSE)

# Keep only Ctrl/Olig columns
counts <- counts_raw[, grepl("Ctrl|Olig", colnames(counts_raw), ignore.case = TRUE), drop = FALSE]

# Drop samples if requested
drop_samples <- character(0)
if (!is.null(drop_str) && drop_str != "") {
  drop_samples <- trimws(unlist(strsplit(drop_str, ",")))
  drop_samples <- drop_samples[drop_samples != ""]
  # A typo in --drop_samples would otherwise be ignored silently, leaving the
  # "dropped" sample in every downstream analysis.
  unknown_drops <- setdiff(drop_samples, colnames(counts))
  if (length(unknown_drops) > 0) {
    warning("--drop_samples: not found among Ctrl/Olig columns (ignored): ",
            paste(unknown_drops, collapse = ", "))
  }
  keep <- !colnames(counts) %in% drop_samples
  counts <- counts[, keep, drop = FALSE]
}

if (ncol(counts) < 4) {
  stop("Not enough samples after filtering/dropping. Found ", ncol(counts), " columns.")
}

# QC: QQ plots (raw)
plot_qq(counts, file.path(dir_qc, "QQ_plots_raw.pdf"))

# Infer group
group <- make_group_factor(colnames(counts))

# -----------------------------
# RUVSeq: build SeqExpressionSet
# -----------------------------
set <- newSeqExpressionSet(as.matrix(counts),
                           phenoData = data.frame(group = group, row.names = colnames(counts)))

plot_rle_pca(set, group, file.path(dir_qc, "RLE_PCA_no_normalization.pdf"), title_prefix = "")

# Estimate unwanted factors using residuals (RUVr): first-pass edgeR GLM on
# upper-quartile-normalized counts, then deviance residuals feed RUVr.
design <- model.matrix(~group, data = pData(set))
y <- DGEList(counts = counts(set), group = group)
y <- calcNormFactors(y, method = "upperquartile")
y <- estimateGLMCommonDisp(y, design)
y <- estimateGLMTagwiseDisp(y, design)

fit <- glmFit(y, design)
residuals_fit <- residuals(fit, type="deviance")

# Every miRNA is passed to RUVr as a control row.
genes <- rownames(counts)
set_ruv <- RUVr(set, genes, k = k_ruv, residuals_fit)

plot_rle_pca(set_ruv, group, file.path(dir_qc, "RLE_PCA_RUVr.pdf"), title_prefix = "RUVr ")

# Save normalized counts
norm_file <- file.path(dir_ruv, "Ctrl_vs_Olig_empirical_norm.tsv")
write.table(normCounts(set_ruv), file = norm_file, quote = FALSE, sep = "\t")

# Load normalized counts. The file was written unquoted, so quote and comment
# parsing are disabled to read it back exactly, even if an ID contains ' " or #.
norm_counts <- read.table(norm_file, header = TRUE, sep = "\t", row.names = 1,
                          check.names = FALSE, quote = "", comment.char = "")

# QC: QQ plots (normalized)
plot_qq(as.data.frame(norm_counts), file.path(dir_qc, "QQ_plots_RUVr.pdf"))

# -----------------------------
# edgeR DE: non-paired exactTest
# -----------------------------
# Classic edgeR pipeline on the RUVr-normalized counts: normalization factors,
# common + tagwise dispersion, then an exact test Ctrl vs Olig. All miRNAs are
# written out (sorted by significance) with IDs as row names.
dge <- DGEList(counts = norm_counts, group = group) %>%
  calcNormFactors() %>%
  estimateCommonDisp() %>%
  estimateTagwiseDisp()

et <- exactTest(dge)
n_all_mirnas <- nrow(dge)
res_nonpaired <- topTags(et, n = n_all_mirnas)$table
nonpaired_csv_path <- file.path(dir_edger, "edgeR_results_non_paired.csv")
write.csv(res_nonpaired, nonpaired_csv_path, row.names = TRUE)

# -----------------------------
# edgeR DE: paired GLM (if pair IDs can be inferred)
# -----------------------------
pair_ids <- extract_pair_id(colnames(norm_counts))
# Every sample needs a pair ID. model.matrix() silently drops rows whose factor
# value is NA, which would leave the design shorter than the count matrix and
# make estimateDisp()/glmFit() fail.
paired_ok <- !anyNA(pair_ids) && length(unique(pair_ids)) >= 2

if (!paired_ok) {
  message("Paired analysis skipped: could not infer pair IDs from sample names.")
} else {
  pair_factor <- factor(pair_ids)
  group_glm <- group
  design_glm <- model.matrix(~ pair_factor + group_glm)

  # If pair IDs are confounded with group (e.g. Ctrl_1..3 vs Olig_4..6), the
  # design is rank-deficient and the group effect cannot be estimated.
  if (qr(design_glm)$rank < ncol(design_glm)) {
    message("Paired analysis skipped: pair IDs are confounded with group (rank-deficient design).")
    paired_ok <- FALSE
  } else {
    y_glm <- DGEList(counts = norm_counts)
    y_glm <- calcNormFactors(y_glm)
    y_glm <- estimateDisp(y_glm, design_glm)
    fit_glm <- glmFit(y_glm, design_glm)

    coef_name <- "group_glmOlig"
    if (!(coef_name %in% colnames(design_glm))) {
      # fallback: pick last column
      coef_idx <- ncol(design_glm)
    } else {
      coef_idx <- which(colnames(design_glm) == coef_name)
    }

    lrt <- glmLRT(fit_glm, coef = coef_idx)
    res_paired <- topTags(lrt, n = Inf)$table
    write.csv(res_paired, file.path(dir_edger, "edgeR_results_paired.csv"), row.names = TRUE)
  }
}

# -----------------------------
# Heatmap of DEGs (grouped order Ctrl then Olig) with blank group separation
# -----------------------------
res_for_heat <- read.csv(file.path(dir_edger, "edgeR_results_non_paired.csv"), row.names = 1, check.names = FALSE)
res_for_heat$miRNA <- clean_mirna_id(rownames(res_for_heat))
res_for_heat$FDR <- as.numeric(res_for_heat$FDR)

# Filter by FDR; fallback if <2
sig_mirs <- res_for_heat %>%
  filter(!is.na(FDR)) %>%
  filter(FDR < fdr_heatmap) %>%
  arrange(FDR) %>%
  pull(miRNA)

if (length(sig_mirs) < 2) {
  message("WARNING: <2 DE miRNAs with FDR < ", fdr_heatmap, ". Using TOP 50 by FDR for heatmap.")
  sig_mirs <- res_for_heat %>%
    filter(!is.na(FDR)) %>%
    arrange(FDR) %>%
    head(50) %>%
    pull(miRNA)
}

# Match the cleaned IDs used in sig_mirs (also used by later sections).
rownames(norm_counts) <- clean_mirna_id(rownames(norm_counts))

# Column order grouped by condition
samples <- colnames(norm_counts)
ctrl_samples <- order_by_suffix_number(samples[grepl("Ctrl", samples, ignore.case = TRUE)])
olig_samples <- order_by_suffix_number(samples[grepl("Olig", samples, ignore.case = TRUE)])
ordered_samples <- c(ctrl_samples, olig_samples)

# One separator between the Ctrl block and the Olig block (if both exist).
n_ctrl <- length(ctrl_samples)
colsep_positions <- if (n_ctrl > 0 && n_ctrl < length(ordered_samples)) n_ctrl else NULL

mat <- norm_counts[rownames(norm_counts) %in% sig_mirs, ordered_samples, drop = FALSE]
mat_log <- log2(as.matrix(mat) + 1)

# plot_deg_heatmap() stops on < 2x2 input; skip the figure instead of
# aborting the volcano/enrichment steps that follow.
if (nrow(mat_log) < 2 || ncol(mat_log) < 2) {
  message("DEG heatmap skipped: need at least 2 miRNAs and 2 samples, got ",
          nrow(mat_log), "x", ncol(mat_log), ".")
} else {
  plot_deg_heatmap(
    mat_log2 = mat_log,
    out_pdf = file.path(dir_heatmap, "DEG_heatmap_grouped.pdf"),
    main_title = "Control vs Olig (Log Counts)",
    colsep_positions = colsep_positions,
    cex_row = heatmap_cex_row,
    cex_col = heatmap_cex_col,
    sepwidth_val = heatmap_sepwidth,
    color_mode = heatmap_color
  )
}

# Optional paired-order heatmap: Ctrl_i next to Olig_i, with separators between pairs
if (use_paired_heatmap && paired_ok) {
  sample_ids <- colnames(norm_counts)
  pair_ids2 <- extract_pair_id(sample_ids)
  pair_levels <- sort(unique(pair_ids2[!is.na(pair_ids2)]))
  is_ctrl <- grepl("Ctrl", sample_ids, ignore.case = TRUE)
  is_olig <- grepl("Olig", sample_ids, ignore.case = TRUE)

  # One block per pair ID: that pair's Ctrl sample(s) followed by its Olig
  # sample(s). A block can hold a single column when a sample was dropped.
  pair_blocks <- lapply(pair_levels, function(pid) {
    in_pair <- !is.na(pair_ids2) & pair_ids2 == pid
    c(sample_ids[is_ctrl & in_pair], sample_ids[is_olig & in_pair])
  })
  pair_blocks <- pair_blocks[lengths(pair_blocks) > 0]
  paired_order <- unlist(pair_blocks, use.names = FALSE)

  # Separators after every block except the last, placed from the actual
  # block sizes rather than assuming exactly 2 columns per pair.
  block_ends <- cumsum(lengths(pair_blocks))
  colsep_pairs <- if (length(block_ends) > 1) head(block_ends, -1) else NULL

  mat2 <- norm_counts[rownames(norm_counts) %in% sig_mirs, paired_order, drop = FALSE]
  mat2_log <- log2(as.matrix(mat2) + 1)

  if (nrow(mat2_log) < 2 || ncol(mat2_log) < 2) {
    message("Paired-order heatmap skipped: need at least 2 miRNAs and 2 samples, got ",
            nrow(mat2_log), "x", ncol(mat2_log), ".")
  } else {
    plot_deg_heatmap(
      mat_log2 = mat2_log,
      out_pdf = file.path(dir_heatmap, "DEG_heatmap_paired_order.pdf"),
      main_title = "Control vs Olig (Log Counts) - paired order",
      colsep_positions = colsep_pairs,
      cex_row = heatmap_cex_row,
      cex_col = heatmap_cex_col,
      sepwidth_val = heatmap_sepwidth,
      color_mode = heatmap_color
    )
  }
}

# -----------------------------
# Volcano plots (EnhancedVolcano): non-paired AND paired (if available)
# -----------------------------
labels_to_show <- unique(trimws(unlist(strsplit(label_str, ","))))
labels_to_show <- labels_to_show[labels_to_show != ""]

# Shared call for both volcano variants. Only the input table, the output
# file suffix and the title suffix differ.
draw_volcano <- function(res_tbl, file_suffix, title_suffix) {
  plot_volcano_enhancedvolcano(
    res_table = res_tbl,
    out_pdf = file.path(dir_volcano, paste0("Volcano_EnhancedVolcano_", file_suffix, ".pdf")),
    plot_title = paste0("Control vs Oligomycin (", basename(outdir), ") - ", title_suffix),
    pCutoff = pCutoff,
    FCcutoff = FCcutoff,
    always_label = labels_to_show,
    label_top_n_each_side = label_top_n_each_side,
    max_labels_total = max_labels_total,
    width = volcano_width,
    height = volcano_height,
    xlim_left = -5
  )
}

# 1) Non-paired volcano
res_volcano_nonpaired <- read.csv(
  file.path(dir_edger, "edgeR_results_non_paired.csv"),
  row.names = 1,
  check.names = FALSE
)
draw_volcano(res_volcano_nonpaired, "nonpaired", "non-paired")

# 2) Paired volcano - only if the paired analysis ran in THIS run. Checking the
# file alone could pick up a stale CSV left in `outdir` by an earlier run.
paired_csv <- file.path(dir_edger, "edgeR_results_paired.csv")
if (!isTRUE(paired_ok)) {
  message("Paired volcano skipped: paired analysis was not run in this session.")
} else if (!file.exists(paired_csv)) {
  message("Paired volcano skipped: paired results file not found at: ", paired_csv)
} else {
  res_volcano_paired <- read.csv(
    paired_csv,
    row.names = 1,
    check.names = FALSE
  )
  draw_volcano(res_volcano_paired, "paired", "paired GLM")
}



# -----------------------------
# Seed (m8) conservation plots (human DE miRNAs vs mouse miRNAs)
#   - Plots only the positive / conserved ones (n_mouse > 0)
#   - Uses smaller, auto-scaled plot height
#   - Each direction ("up", "down") gets its own tryCatch, so a failure while
#     drawing one direction no longer prevents the other from being drawn.
# -----------------------------
for (seed_direction in c("up", "down")) {
  tryCatch(
    plot_seed_m8_conservation(
      res_table     = res_volcano_nonpaired,
      out_pdf       = file.path(
        dir_seed,
        paste0("Seed_m8_conservation_", seed_direction, "_positive.pdf")
      ),
      selection     = seed_direction,
      fdr_cut       = pCutoff,
      fc_cut        = FCcutoff,
      positive_only = TRUE,
      width         = 7.0
    ),
    error = function(e) {
      # Report which direction failed, then continue with the next one.
      message(
        "Seed conservation plotting failed (", seed_direction, "): ",
        conditionMessage(e)
      )
    }
  )
}
# -----------------------------
# Target retrieval (multiMiR) + g:Profiler enrichment (GO+REAC)
# -----------------------------
# NOTE: Edit this miRNA list as needed. Names are written without the species
# prefix; "hsa-" is prepended because the multiMiR lookup requires it.
mirnas <- paste0("hsa-", c(
  "miR-4434", "miR-7112-3p", "miR-4684-5p", "miR-3064-5p",
  "miR-4478", "miR-4534", "miR-8069", "miR-365a-5p",
  "miR-3180", "miR-6887-5p", "miR-4535", "miR-4257",
  "miR-6736-5p", "miR-758-5p", "miR-4470", "miR-5682",
  "miR-4302", "miR-2467-5p", "miR-6071", "miR-3166",
  "miR-298", "miR-3681-5p", "miR-4267", "miR-135b-3p",
  "miR-1178-5p", "miR-4303", "miR-6842-3p", "miR-6838-3p",
  "miR-367-5p", "miR-2114-5p", "miR-646", "miR-933",
  "miR-623", "miR-4697-5p", "miR-208a-3p", "miR-6762-3p",
  "miR-939-5p", "miR-218-1-3p", "miR-4296", "miR-6720-5p",
  "miR-193a-3p", "miR-6750-5p", "miR-33b-5p", "miR-3714",
  "miR-3677-5p", "miR-4519", "miR-6080", "miR-4746-5p",
  "miR-497-3p", "miR-648", "miR-4266", "miR-5008-5p",
  "miR-6848-3p", "miR-6770-3p", "miR-619-3p", "miR-572",
  "miR-6715b-3p", "miR-3169", "miR-6774-3p", "miR-662",
  "miR-7703", "miR-6816-3p", "miR-127-5p", "miR-4690-3p",
  "miR-1199-3p", "miR-6893-3p", "miR-6510-3p", "miR-4673",
  "miR-6081", "miR-6821-3p", "miR-3124-5p", "miR-6791-3p",
  "miR-3917", "miR-935", "miR-4707-5p", "miR-4515"
))

message("Querying multiMiR targets (this can take time)...")
mm <- get_multimir(mirna = mirnas, summary = TRUE)

miRNA_targets <- mm@data %>%
  dplyr::select(mature_mirna_id, target_symbol, database) %>%
  dplyr::distinct()

write.csv(miRNA_targets, file.path(dir_targets, "miRNA_Target_Genes.csv"), row.names = FALSE)

# Collapse list-columns of a g:Profiler result table into comma-separated
# strings (NA for NULL/empty/all-NA cells) so write.csv() can serialise it.
flatten_list_columns <- function(df) {
  is_list_col <- vapply(df, is.list, logical(1))
  if (any(is_list_col)) {
    df[is_list_col] <- lapply(df[is_list_col], function(col) {
      vapply(col, function(x) {
        if (is.null(x) || length(x) == 0 || all(is.na(x))) {
          return(NA_character_)
        }
        paste(unlist(x), collapse = ",")
      }, FUN.VALUE = NA_character_)
    })
  }
  df
}

message("Running g:Profiler enrichment per miRNA (GO:BP/GO:MF/TF/REAC)...")

# Group target symbols by miRNA once, instead of re-filtering the whole table
# on every iteration. Missing/empty miRNA IDs are skipped.
target_mirna_ids <- as.character(miRNA_targets$mature_mirna_id)
targets_by_mirna <- split(as.character(miRNA_targets$target_symbol), target_mirna_ids)
unique_mirnas <- unique(target_mirna_ids[!is.na(target_mirna_ids) & nzchar(target_mirna_ids)])

for (mirna in unique_mirnas) {

  # Drop NA and empty gene symbols, then de-duplicate (first-seen order kept).
  # Filtering by index avoids the "na.action" attribute that na.omit() adds.
  target_genes_vec <- targets_by_mirna[[mirna]]
  target_genes_vec <- unique(target_genes_vec[!is.na(target_genes_vec) & nzchar(target_genes_vec)])

  if (length(target_genes_vec) == 0) next

  # File-name-safe version of the miRNA ID.
  mirna_safe <- gsub("[^A-Za-z0-9_.-]", "_", mirna)

  go_results <- tryCatch(
    gost(
      query = target_genes_vec,
      organism = "hsapiens",
      correction_method = "fdr",
      sources = c("GO:BP", "GO:MF", "TF", "REAC")
    ),
    error = function(e) {
      message("gost() failed for ", mirna, " : ", conditionMessage(e))
      return(NULL)
    }
  )

  # gost() yields NULL (or an empty result) when nothing is enriched.
  if (is.null(go_results) || is.null(go_results$result) || nrow(go_results$result) == 0) next

  go_results_df <- flatten_list_columns(as.data.frame(go_results$result))

  # Rename intersection -> intersection_genes if present
  if ("intersection" %in% colnames(go_results_df)) {
    colnames(go_results_df)[colnames(go_results_df) == "intersection"] <- "intersection_genes"
  }

  write.csv(go_results_df,
            file = file.path(dir_go_terms, paste0("GO_Terms_", mirna_safe, ".csv")),
            row.names = FALSE)

  write.csv(data.frame(miRNA = mirna, target_genes = paste(target_genes_vec, collapse = ";")),
            file = file.path(dir_go_terms, paste0("Target_Genes_", mirna_safe, ".csv")),
            row.names = FALSE)
}

message("DONE. Outputs written under: ", outdir)
'''

nano tf_immune_analysis.py
'''
#!/usr/bin/env python3
# =============================================================================
# tf_immune_analysis.py
#
# Reads g:Profiler per-miRNA enrichment CSVs (GO_Terms_*.csv) and produces:
#   - TF presence/absence heatmap (blank separators, shrunken dendrogram & color key)
#   - Immune-term presence/absence heatmap (same aesthetics)
#   - Sankey TF -> Immune co-occurrence (HTML)
#   - Ranked, enriched immune pathways table (GO:BP + Reactome) in a CSV
#
# Usage:
#   python tf_immune_analysis.py --input <GO_Terms_dir> --outdir <outdir>
#
# Example:
#   python tf_immune_analysis.py \
#     --input results/without_sample2/06_enrichment/GO_Terms \
#     --outdir results/without_sample2/07_tf_immune
#
# =============================================================================

import argparse
import glob
import os
import re
from collections import defaultdict

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

try:
    import plotly.graph_objects as go
    PLOTLY_OK = True
except Exception:
    PLOTLY_OK = False

# -----------------------------
# CLI
# -----------------------------
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding the input/output directories, the
        heatmap presence threshold, the ranked-table p-value cutoff and the
        Sankey display options.
    """
    parser = argparse.ArgumentParser(description="TF + Immune term analysis from g:Profiler GO_Terms_*.csv files")
    parser.add_argument("--input", required=True, help="Directory containing GO_Terms_*.csv files")
    parser.add_argument("--outdir", required=True, help="Output directory for plots/tables")
    parser.add_argument("--min_presence", type=int, default=2, help="Min # miRNAs a term must appear in to keep (default: 2)")
    parser.add_argument("--pval_cutoff", type=float, default=0.05, help="p_value cutoff for ranked immune table (default: 0.05)")

    # Sankey controls (for legibility), one row per option:
    # (flag, value type or None for plain string, default, help text)
    sankey_options = (
        ("--sankey_top_tfs", int, 10,
         "Show top N TF nodes in Sankey by miRNA-count (default: 10). Use 0 for all."),
        ("--sankey_top_terms", int, 8,
         "Show top N Immune-term nodes in Sankey by miRNA-count (default: 8). Use 0 for all."),
        ("--sankey_min_link", int, 1,
         "Minimum co-occurrence count to draw a link (default: 1). Increase to declutter."),
        ("--sankey_scale", float, 1.0,
         "Multiply link widths by this constant (default: 1.0) for visibility."),
        ("--sankey_palette", None, "Set3",
         "Plotly qualitative palette for RIGHT nodes (e.g. Set3, Plotly, D3, Dark2, Pastel)."),
        ("--sankey_left_color", None, "#d978c2",
         "Hex color for LEFT TF nodes (default: pink)."),
        ("--sankey_alpha", float, 0.55,
         "Link transparency 0-1 (default: 0.55)."),
        ("--sankey_width", int, 1500, "Sankey figure width in px (default: 1500)."),
        ("--sankey_height", int, 900, "Sankey figure height in px (default: 900)."),
    )
    for flag, value_type, default_value, help_text in sankey_options:
        option_kwargs = {"default": default_value, "help": help_text}
        if value_type is not None:
            option_kwargs["type"] = value_type
        parser.add_argument(flag, **option_kwargs)

    return parser.parse_args()

# -----------------------------
# TF dictionary
# -----------------------------
# Canonical TF name -> synonyms searched for in g:Profiler term names/IDs.
TRANSCRIPTION_FACTORS = {
    "FOXP3":  ["FOXP3", "Foxp3", "FKH10", "scurfin"],
    "GATA3":  ["GATA3", "GATA-3"],
    "TBX21":  ["T-bet", "Tbet", "TBX21"],
    "RORC":   ["RORC", "RORγt", "RORgt", "ROR-gamma", "NR1F3"],
    "BCL6":   ["BCL6", "B-cell lymphoma 6"],
    "MAF":    ["c-Maf", "MAF"],
    "BACH2":  ["BACH2"],
    "PRDM1":  ["BLIMP1", "Blimp-1", "PRDM1"],
    "AHR":    ["AHR", "aryl hydrocarbon receptor", "AhR"],
    "EGR2":   ["EGR2", "EGR-2"],
    "EGR3":   ["EGR3", "EGR-3"],
    "IKZF2":  ["Helios", "IKZF2"],
    "IKZF1":  ["Ikaros", "IKZF1"],
    "BCL11B": ["BCL-11B", "BCL11B"],
    "NFIL3":  ["NFIL3", "E4BP4"],
    "FOXO1":  ["FOXO1", "Fkhr", "FKHRL1", "FOXO1A"],
    "IRF4":   ["IRF4"],
    "STAT5":  ["STAT5", "Stat5"],
    "NFAT":   ["NFAT", "Nuclear factor of activated T-cells"],
    "AP-1":   ["AP-1", "AP1", "FOS", "JUN"],
    "NF-κB":  ["NF-kB", "NF-kappaB", "RelA", "RelB", "c-Rel", "NFκB", "NFKB"]
}

# One case-insensitive alternation per TF; every synonym is regex-escaped, so
# it matches literally anywhere in the text.
# NOTE(review): the patterns have no word boundaries, so short synonyms also
# match inside ordinary words (e.g. "JUN" in "junction", "RelA" in "related").
# Confirm whether these hits are acceptable.
TF_PATTERNS = dict(zip(
    TRANSCRIPTION_FACTORS.keys(),
    (re.compile("|".join(map(re.escape, synonyms)), re.IGNORECASE)
     for synonyms in TRANSCRIPTION_FACTORS.values()),
))

def find_matched_tfs(term_name, term_id):
    """Return the canonical TF names whose synonym regex occurs in the term.

    The term name and ID are joined with a space (falsy values become "")
    and checked against every entry of ``TF_PATTERNS``. The result follows
    the dictionary's insertion order.
    """
    haystack = "{} {}".format(term_name or "", term_id or "")
    hits = []
    for tf_name, pattern in TF_PATTERNS.items():
        if pattern.search(haystack):
            hits.append(tf_name)
    return hits

def make_tf_label(matched_list):
    """Build a heatmap row label such as ``"TF: AP-1, NFAT"``.

    Duplicates are removed and names are sorted alphabetically. Returns
    ``None`` when nothing matched.
    """
    names = sorted(set(matched_list or ()))
    return "TF: " + ", ".join(names) if names else None

# -----------------------------
# Immune terms dictionary (for presence/absence heatmap)
# -----------------------------
# Strict list of immune term names/labels for the presence/absence heatmap.
# Each entry is regex-escaped and matched case-insensitively as a substring.
# NOTE(review): the "tissue; cell type[level]" entries look like Human Protein
# Atlas labels; main() keeps only GO:BP/GO:MF/TF/REAC rows, so they probably
# never match there. Confirm whether they are still needed.
IMMUNE_SYNONYMS = [
    "positive regulation of innate immune response",
    "somatic diversification of immune receptors",
    "lymphocyte activation",
    "lymphocyte differentiation",
    "regulation of lymphocyte differentiation",
    "regulation of lymphocyte activation",
    "positive regulation of lymphocyte activation",
    "positive regulation of lymphocyte differentiation",
    "appendix; lymphoid tissue[≥Low]",
    "appendix; lymphoid tissue[≥Medium]",
    "appendix; lymphoid tissue[High]",
    "skin 1; lymphocytes[≥Low]",
    "skin 1; lymphocytes[≥Medium]",
    "skin 1; lymphocytes[≥High]",
    "skin 2; lymphocytes[≥Low]",
    "skin 2; lymphocytes[≥Medium]",
    "skin 2; lymphocytes[≥High]",
    "rectum; mucosal lymphoid cells[High]",
    "miR targeted genes in lymphocytes"
]
IMMUNE_PATTERN = re.compile("|".join(map(re.escape, IMMUNE_SYNONYMS)), flags=re.IGNORECASE)

def row_matches_immune(term_name, term_id):
    """Return True when the strict ``IMMUNE_PATTERN`` occurs in the term.

    The term name and ID are joined with a space; falsy values become "".
    """
    searchable = "{} {}".format(term_name or "", term_id or "")
    return bool(IMMUNE_PATTERN.search(searchable))

# Wide immune regex for ranked table
# Alternatives for the broad immune filter used by the ranked pathway table;
# they are joined into a single case-insensitive capturing group below.
_IMMUNE_WIDE_ALTERNATIVES = (
    # T/B-cell subsets and general leukocyte / immune vocabulary
    r"Th17", r"Th1", r"Th2", r"T[\s-]?cell", r"B[\s-]?cell", r"Treg", r"regulatory T",
    r"lymphocyte", r"leukocyte", r"immune", r"immun", r"inflamm",
    # soluble mediators
    r"cytokine", r"chemokine", r"interleukin", r"interferon",
    # NF-kB spellings
    r"NF[-\s]?kappaB", r"NF[-\s]?κB", r"NFκB", r"NFKB",
    # innate sensing / signalling
    r"toll[-\s]?like", r"\bTLR\b", r"JAK[-\s]?STAT", r"\bTNF\b",
    r"antigen", r"MHC",
    # interleukins written as IL-6, IL 17, IL2, ...
    r"IL[-\s]?\d+",
)
IMMUNE_REGEX_WIDE = re.compile("(" + "|".join(_IMMUNE_WIDE_ALTERNATIVES) + ")", re.IGNORECASE)

def row_matches_immune_wide(term_name, term_id):
    """Return True if the broad ``IMMUNE_REGEX_WIDE`` hits the term.

    Used for the ranked immune-pathway table. Falsy inputs are treated as "".
    """
    searchable = "{} {}".format(term_name or "", term_id or "")
    return bool(IMMUNE_REGEX_WIDE.search(searchable))

# -----------------------------
# Plot helpers
# -----------------------------
def save_clustermap(df, out_pdf, title, cmap, norm=None, discrete=False, figsize=(16, 8)):
    """Draw a clustered presence/absence heatmap of ``df`` and save it.

    Args:
        df: numeric matrix (rows = TF labels / terms, columns = miRNAs).
        out_pdf: output path passed to ``ClusterGrid.savefig``.
        title: figure suptitle.
        cmap: colormap for the cells.
        norm: optional matplotlib norm (e.g. a BoundaryNorm for a 0/1 map).
        discrete: accepted for backward compatibility; not used.
        figsize: figure size in inches (default keeps the previous 16x8).

    Thick white cell borders act as blank separators. The dendrograms and the
    colour key are shrunk via ``dendrogram_ratio`` / ``cbar_pos``.
    """
    sns.set(style="white", font_scale=1.4)

    # Hierarchical clustering needs at least two observations along an axis.
    # A single row or miRNA column would make the linkage step fail, so
    # clustering is only enabled where there is something to cluster.
    n_rows, n_cols = df.shape

    g = sns.clustermap(
        df,
        cmap=cmap,
        norm=norm,
        row_cluster=n_rows > 1,
        col_cluster=n_cols > 1,
        linewidths=1.6,
        linecolor="white",
        figsize=figsize,
        dendrogram_ratio=(0.08, 0.12),       # shrink dendrogram
        cbar_pos=(0.02, 0.83, 0.03, 0.12),   # shrink color key
        cbar_kws={"label": "Presence (1) / Absence (0)"}
    )
    g.fig.suptitle(title, y=1.02)
    g.savefig(out_pdf, bbox_inches="tight")
    # Close this specific figure rather than whatever pyplot considers current.
    plt.close(g.fig)

def _presence_matrix(rows, index_col, min_presence):
    """Pivot hit rows into a 0/1 presence matrix and apply the presence filter.

    Args:
        rows: DataFrame with ``index_col``, ``miRNA`` and ``presence`` columns.
        index_col: column whose values become the matrix rows.
        min_presence: keep only rows present in at least this many miRNAs.

    Returns:
        DataFrame (rows = ``index_col`` values, columns = miRNAs). If ``rows``
        is empty, an empty DataFrame is returned without calling pivot_table.
    """
    if rows.empty:
        return pd.DataFrame()
    pivot = rows.pivot_table(index=index_col, columns="miRNA", values="presence",
                             fill_value=0, aggfunc="max")
    return pivot[pivot.sum(axis=1) >= min_presence]


def _evenly_spaced(n, start=0.05, end=0.95):
    """Return ``n`` evenly spaced positions in [start, end] ([0.5] if n <= 1)."""
    if n <= 1:
        return [0.5]
    step = (end - start) / (n - 1)
    return [start + i * step for i in range(n)]


def _to_rgba(color, alpha=0.55):
    """Convert a Plotly colour string to ``rgba(r,g,b,alpha)``.

    Accepts hex ('#rrggbb' or 'rrggbb') and 'rgb(r,g,b)' / 'rgba(...)' strings.
    Plotly's qualitative palettes use both forms; for example, the
    ColorBrewer-derived Set3 and Dark2 entries are rgb() strings. Anything
    else falls back to translucent grey.
    """
    text = str(color).strip()
    hex_match = re.fullmatch(r"#?([0-9a-fA-F]{6})", text)
    if hex_match:
        h = hex_match.group(1)
        r, g, b = (int(h[i:i + 2], 16) for i in (0, 2, 4))
        return f"rgba({r},{g},{b},{alpha})"
    rgb_match = re.match(r"rgba?\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)", text)
    if rgb_match:
        r, g, b = rgb_match.groups()
        return f"rgba({r},{g},{b},{alpha})"
    return f"rgba(120,120,120,{alpha})"


def _write_sankey(tf_subset, immune_subset, args, out_sankey):
    """Write the TF -> immune-term co-occurrence Sankey as HTML (plus PNG if possible).

    Both inputs are 0/1 presence matrices (rows = labels, columns = miRNAs).
    Each node label shows the number of miRNAs the row is present in. Each
    link's weight is the number of miRNAs in which both the TF label and the
    immune term are present, multiplied by ``args.sankey_scale``. Only the top
    ``args.sankey_top_tfs`` / ``args.sankey_top_terms`` rows are drawn (0 =
    all), and links below ``args.sankey_min_link`` are dropped.
    """
    import plotly.express as px  # plotly availability is checked by the caller

    common_cols = tf_subset.columns.intersection(immune_subset.columns)
    has_cooccurrence = any(
        (tf_subset[c] == 1).any() and (immune_subset[c] == 1).any() for c in common_cols
    )
    if not has_cooccurrence:
        print("No TF->Immune co-occurrences found. Skipping Sankey.")
        return

    # Node counts (miRNA presence counts), largest first.
    tf_counts = tf_subset.sum(axis=1).sort_values(ascending=False)
    immune_counts = immune_subset.sum(axis=1).sort_values(ascending=False)

    tf_keep = list(tf_counts.index)
    immune_keep = list(immune_counts.index)
    if args.sankey_top_tfs and args.sankey_top_tfs > 0:
        tf_keep = tf_keep[:args.sankey_top_tfs]
    if args.sankey_top_terms and args.sankey_top_terms > 0:
        immune_keep = immune_keep[:args.sankey_top_terms]

    sank_tf = tf_subset.loc[tf_keep]
    sank_im = immune_subset.loc[immune_keep]

    # Link counts (TF, immune term) by miRNA co-occurrence.
    pair_counts = defaultdict(int)
    for col in sank_tf.columns.intersection(sank_im.columns):
        for tf in sank_tf[sank_tf[col] == 1].index:
            for im in sank_im[sank_im[col] == 1].index:
                pair_counts[(tf, im)] += 1

    links = [(tf, im, v) for (tf, im), v in pair_counts.items() if v >= args.sankey_min_link]
    if not links:
        print("No TF->Immune co-occurrences after filtering. Skipping Sankey.")
        return

    n_tf = len(tf_keep)
    n_im = len(immune_keep)

    # Plotly renders "<br>" (not "\n") as a line break, so the count really
    # appears on its own line under each label.
    left_labels = [f"{tf}<br>{int(tf_counts[tf])}" for tf in tf_keep]
    right_labels = [f"{im}<br>{int(immune_counts[im])}" for im in immune_keep]
    nodes = left_labels + right_labels

    # Node indices by position: TFs first (0..n_tf-1), then immune terms.
    tf_pos = {tf: i for i, tf in enumerate(tf_keep)}
    im_pos = {im: j for j, im in enumerate(immune_keep)}

    # Palette for RIGHT nodes
    palette_dict = {
        "set3": px.colors.qualitative.Set3,
        "plotly": px.colors.qualitative.Plotly,
        "d3": px.colors.qualitative.D3,
        "dark2": px.colors.qualitative.Dark2,
        "pastel": px.colors.qualitative.Pastel,
        "pastel1": px.colors.qualitative.Pastel1,
        "bold": px.colors.qualitative.Bold,
        "prism": px.colors.qualitative.Prism,
        "safe": px.colors.qualitative.Safe,
        "vivid": px.colors.qualitative.Vivid
    }
    pal_name = str(args.sankey_palette).strip().lower()
    right_base = palette_dict.get(pal_name, px.colors.qualitative.Set3)
    right_colors = [right_base[j % len(right_base)] for j in range(n_im)]
    node_colors = [args.sankey_left_color] * n_tf + right_colors

    # Link arrays; each link takes its target (RIGHT) node's colour.
    src_idx, tgt_idx, val_arr, link_colors = [], [], [], []
    for tf, im, v in links:
        src_idx.append(tf_pos[tf])
        tgt_idx.append(n_tf + im_pos[im])
        val_arr.append(float(v) * float(args.sankey_scale))
        link_colors.append(_to_rgba(right_colors[im_pos[im]], alpha=args.sankey_alpha))

    # Fixed positions (left / right)
    x = [0.01] * n_tf + [0.99] * n_im
    y = _evenly_spaced(n_tf, 0.05, 0.95) + _evenly_spaced(n_im, 0.05, 0.95)

    fig = go.Figure(data=[go.Sankey(
        arrangement="fixed",
        node=dict(
            pad=18,
            thickness=22,
            line=dict(color="black", width=0.6),
            label=nodes,
            color=node_colors,
            x=x,
            y=y
        ),
        link=dict(
            source=src_idx,
            target=tgt_idx,
            value=val_arr,
            color=link_colors
        )
    )])

    fig.update_layout(
        title_text="Sankey: TF (left) → Immune GO terms (right)",
        font_size=18,
        width=args.sankey_width,
        height=args.sankey_height,
        margin=dict(l=40, r=340, t=80, b=40)
    )

    out_html = os.path.join(out_sankey, "tf_immune_sankey.html")
    fig.write_html(out_html)
    print(f"Sankey saved to: {out_html}")

    # Optional static export (requires kaleido)
    try:
        out_png = os.path.join(out_sankey, "tf_immune_sankey.png")
        fig.write_image(out_png, scale=2)
        print(f"Sankey PNG saved to: {out_png}")
    except Exception:
        print("Static PNG export skipped (install kaleido: pip install -U kaleido).")


def main():
    """Run the TF / immune-term analysis on a directory of GO_Terms_*.csv files.

    Writes the following under ``--outdir``:
      - heatmaps/: TF and immune-term presence/absence clustermaps (PDF)
      - sankey/:   TF -> immune-term co-occurrence Sankey (HTML, PNG if kaleido)
      - tables/:   ranked immune pathways (GO:BP + REAC) CSV

    Raises:
        FileNotFoundError: no GO_Terms_*.csv files in ``--input``.
        ValueError: the combined tables lack a required column.
    """
    args = parse_args()
    in_dir = args.input
    outdir = args.outdir

    # Output subfolders
    out_heatmaps = os.path.join(outdir, "heatmaps")
    out_tables = os.path.join(outdir, "tables")
    out_sankey = os.path.join(outdir, "sankey")
    for d in (out_heatmaps, out_tables, out_sankey):
        os.makedirs(d, exist_ok=True)

    files = sorted(glob.glob(os.path.join(in_dir, "GO_Terms_*.csv")))
    if not files:
        raise FileNotFoundError(f"No GO_Terms_*.csv files found in: {in_dir}")

    # The miRNA name is taken from the file name (GO_Terms_<miRNA>.csv).
    all_dfs = []
    for f in files:
        mirna = os.path.splitext(os.path.basename(f))[0].replace("GO_Terms_", "")
        df = pd.read_csv(f)
        df["miRNA"] = mirna
        all_dfs.append(df)

    combined = pd.concat(all_dfs, ignore_index=True)

    # Validate the schema before any section indexes these columns. Otherwise
    # a missing column would surface as a bare KeyError part-way through.
    required_cols = {"source", "term_id", "term_name", "p_value", "miRNA"}
    missing = required_cols - set(combined.columns)
    if missing:
        raise ValueError(f"Missing columns in combined data: {missing}")

    # Some g:Profiler exports can contain NA term_id/term_name; astype(str)
    # turns them into the literal string "nan".
    combined["term_id"] = combined["term_id"].astype(str)
    combined["term_name"] = combined["term_name"].astype(str)

    VALID_SOURCES = {"GO:BP", "GO:MF", "TF", "REAC"}
    df = combined[combined["source"].isin(VALID_SOURCES)].copy()

    # ------------------------------------------------------------
    # TF heatmap (presence/absence)
    # ------------------------------------------------------------
    # List comprehensions (not apply(axis=1)) also behave on an empty frame.
    df["MatchedTFs"] = pd.Series(
        [find_matched_tfs(name, tid) for name, tid in zip(df["term_name"], df["term_id"])],
        index=df.index, dtype=object,
    )
    tf_rows = df[df["MatchedTFs"].map(len) > 0].copy()
    tf_rows["TFlabel"] = tf_rows["MatchedTFs"].map(make_tf_label)
    tf_rows["presence"] = 1
    tf_rows = tf_rows.drop_duplicates(subset=["TFlabel", "miRNA"])

    tf_subset = _presence_matrix(tf_rows, "TFlabel", args.min_presence)

    if tf_subset.empty:
        print("No TF rows found (>=min_presence). Skipping TF heatmap.")
    else:
        cmap_tf = sns.light_palette("navy", as_cmap=True, reverse=True)
        save_clustermap(
            tf_subset,
            out_pdf=os.path.join(out_heatmaps, "Detected_TFs_heatmap.pdf"),
            title="Detected Transcription Factors",
            cmap=cmap_tf
        )

    # ------------------------------------------------------------
    # Immune heatmap (presence/absence)
    # ------------------------------------------------------------
    df["ImmuneMatch"] = pd.Series(
        [row_matches_immune(name, tid) for name, tid in zip(df["term_name"], df["term_id"])],
        index=df.index, dtype=bool,
    )
    immune_rows = df[df["ImmuneMatch"]].copy()
    immune_rows["presence"] = 1
    immune_rows = immune_rows.drop_duplicates(subset=["term_id", "miRNA"])

    immune_subset = _presence_matrix(immune_rows, "term_id", args.min_presence)

    # Relabel rows as "term_id term_name". Both columns are strings by now;
    # for repeated term_ids the last occurrence wins.
    term2name = dict(zip(immune_rows["term_id"],
                         immune_rows["term_id"] + " " + immune_rows["term_name"]))
    immune_subset.index = [term2name.get(i, i) for i in immune_subset.index]

    if immune_subset.empty:
        print("No immune rows found (>=min_presence). Skipping immune heatmap.")
    else:
        cmap_discrete = ListedColormap(["lightgrey", "darkred"])
        norm = BoundaryNorm([-0.5, 0.5, 1.5], cmap_discrete.N)
        save_clustermap(
            immune_subset,
            out_pdf=os.path.join(out_heatmaps, "Detected_ImmuneTerms_heatmap.pdf"),
            title="Detected Immune Terms",
            cmap=cmap_discrete,
            norm=norm
        )

    # ------------------------------------------------------------
    # Sankey TF -> Immune co-occurrence (optional)
    # ------------------------------------------------------------
    if not PLOTLY_OK:
        print("Plotly not available; skipping Sankey.")
    elif tf_subset.empty or immune_subset.empty:
        print("TF or Immune subset empty; skipping Sankey.")
    else:
        _write_sankey(tf_subset, immune_subset, args, out_sankey)

    # ------------------------------------------------------------
    # Ranked immune pathways table (GO:BP + REAC)
    # ------------------------------------------------------------
    combined["p_value"] = pd.to_numeric(combined["p_value"], errors="coerce")

    ranked_df = combined[
        combined["source"].isin(["GO:BP", "REAC"]) &
        (combined["p_value"] <= args.pval_cutoff)
    ].copy()

    ranked_df["ImmuneMatchWide"] = pd.Series(
        [row_matches_immune_wide(name, tid)
         for name, tid in zip(ranked_df["term_name"], ranked_df["term_id"])],
        index=ranked_df.index, dtype=bool,
    )
    immune_hits = ranked_df[ranked_df["ImmuneMatchWide"]].copy()

    if immune_hits.empty:
        print("No immune pathways found for ranked table with current filters.")
    else:
        ranked = (
            immune_hits
            .groupby(["source", "term_id", "term_name"], as_index=False)
            .agg(
                best_p=("p_value", "min"),
                n_miRNAs=("miRNA", "nunique"),
                miRNAs=("miRNA", lambda x: "; ".join(sorted(set(x))))
            )
        )
        ranked["minus_log10_p"] = -np.log10(ranked["best_p"].astype(float))
        ranked = ranked.sort_values(["best_p", "n_miRNAs"], ascending=[True, False]).reset_index(drop=True)
        ranked.insert(0, "rank", ranked.index + 1)

        ranked_path = os.path.join(out_tables, "immune_pathways_ranked_GO_REAC.csv")
        ranked.to_csv(ranked_path, index=False)
        print("Saved ranked immune pathways table:", ranked_path)

    print("DONE. Outputs written under:", outdir)

if __name__ == "__main__":
    main()
'''

nano tf_immune_heatmaps_ranked_table.py
'''
#!/usr/bin/env python3
# ===============================================================
# tf_immune_heatmaps_ranked_table.py  (additional/legacy utility)
#
# Post-process g:Profiler outputs (GO_Terms_*.csv)
#   - TF heatmap (>= N miRNA presence)
#   - Immune terms heatmap (>= N miRNA presence) using a STRICT synonym list
#   - Sankey: TF -> Immune terms (simple co-occurrence across miRNAs)
#   - Ranked immune pathways table (GO:BP + REAC) with wide regex filter
#
# Usage:
#   python tf_immune_heatmaps_ranked_table.py --input <GO_Terms_dir> --outdir <outdir>
#
# Example:
#   python tf_immune_heatmaps_ranked_table.py \
#     --input ./figures/without_sample2/enrichment/gprofiler/GO_Terms \
#     --outdir ./figures/without_sample2/enrichment/summary
#
# Note:
#   This legacy utility is kept for completeness alongside the main analysis script.
#   For a *more legible* Sankey (top-N nodes, fixed left/right layout, counts in labels),
#   prefer: tf_immune_analysis.py (also included in this package).
# ===============================================================

import os
import glob
import re
import argparse
from collections import defaultdict

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

try:
    import plotly.graph_objects as go
    HAS_PLOTLY = True
except Exception:
    HAS_PLOTLY = False


# -----------------------------
# TF DICTIONARY OF SYNONYMS
# -----------------------------
# Canonical TF name -> synonyms searched for in g:Profiler term names/IDs.
TRANSCRIPTION_FACTORS = {
    "FOXP3":  ["FOXP3", "Foxp3", "FKH10", "scurfin"],
    "GATA3":  ["GATA3", "GATA-3"],
    "TBX21":  ["T-bet", "Tbet", "TBX21"],
    "RORC":   ["RORC", "RORγt", "RORgt", "ROR-gamma", "NR1F3"],
    "BCL6":   ["BCL6", "B-cell lymphoma 6"],
    "MAF":    ["c-Maf", "MAF"],
    "BACH2":  ["BACH2"],
    "PRDM1":  ["BLIMP1", "Blimp-1", "PRDM1"],
    "AHR":    ["AHR", "aryl hydrocarbon receptor", "AhR"],
    "EGR2":   ["EGR2", "EGR-2"],
    "EGR3":   ["EGR3", "EGR-3"],
    "IKZF2":  ["Helios", "IKZF2"],
    "IKZF1":  ["Ikaros", "IKZF1"],
    "BCL11B": ["BCL-11B", "BCL11B"],
    "NFIL3":  ["NFIL3", "E4BP4"],
    "FOXO1":  ["FOXO1", "Fkhr", "FKHRL1", "FOXO1A"],
    "IRF4":   ["IRF4"],
    "STAT5":  ["STAT5", "Stat5"],
    "NFAT":   ["NFAT", "Nuclear factor of activated T-cells"],
    "AP-1":   ["AP-1", "AP1", "FOS", "JUN"],
    "NF-κB":  ["NF-kB", "NF-kappaB", "RelA", "RelB", "c-Rel", "NFκB", "NFKB"],
}

# One case-insensitive alternation of literal (regex-escaped) synonyms per TF.
# NOTE(review): the patterns have no word boundaries, so short synonyms also
# match inside ordinary words (e.g. "JUN" in "junction", "RelA" in "related").
# Confirm whether these hits are acceptable.
TF_PATTERNS = {
    tf_name: re.compile("|".join(map(re.escape, synonyms)), re.IGNORECASE)
    for tf_name, synonyms in TRANSCRIPTION_FACTORS.items()
}


def find_matched_tfs(term_name, term_id):
    """List the TFs (in TF_PATTERNS order) whose synonyms occur in the term.

    The searched text is "<term_name> <term_id>"; falsy parts become "".
    """
    text = "{} {}".format(term_name or "", term_id or "")
    return [name for name, regex in TF_PATTERNS.items() if regex.search(text)]


def make_tf_label(matched_list, original_term_id):
    """Return ``"TF: A, B"`` for matched TFs (order kept), else the original term ID."""
    return "TF: " + ", ".join(matched_list) if matched_list else original_term_id


# -----------------------------
# IMMUNE DICTIONARY (strict list)
# -----------------------------
# Strict list of immune term names/labels for the presence/absence heatmap.
# Entries are regex-escaped and matched case-insensitively as substrings.
# NOTE(review): the "tissue; cell type[level]" entries look like Human Protein
# Atlas labels; with the default --valid_sources (GO:BP,GO:MF,TF,REAC) they
# probably never match. Confirm whether they are still needed.
IMMUNE_SYNONYMS_STRICT = [
    "positive regulation of innate immune response",
    "somatic diversification of immune receptors",
    "lymphocyte activation",
    "lymphocyte differentiation",
    "regulation of lymphocyte differentiation",
    "regulation of lymphocyte activation",
    "positive regulation of lymphocyte activation",
    "positive regulation of lymphocyte differentiation",
    "appendix; lymphoid tissue[≥Low]",
    "appendix; lymphoid tissue[≥Medium]",
    "appendix; lymphoid tissue[High]",
    "skin 1; lymphocytes[≥Low]",
    "skin 1; lymphocytes[≥Medium]",
    "skin 1; lymphocytes[≥High]",
    "skin 2; lymphocytes[≥Low]",
    "skin 2; lymphocytes[≥Medium]",
    "skin 2; lymphocytes[≥High]",
    "rectum; mucosal lymphoid cells[High]",
    "miR targeted genes in lymphocytes",
]
IMMUNE_PATTERN_STRICT = re.compile("|".join(map(re.escape, IMMUNE_SYNONYMS_STRICT)), flags=re.IGNORECASE)


def row_matches_immune_strict(term_name, term_id):
    """Return True when the strict immune pattern occurs in "<term_name> <term_id>"."""
    searchable = "{} {}".format(term_name or "", term_id or "")
    return bool(IMMUNE_PATTERN_STRICT.search(searchable))


# -----------------------------
# IMMUNE REGEX (wide) for ranked table
# -----------------------------
# Broad immune filter for the ranked pathway table: one capturing group of
# alternatives, compiled case-insensitively.
IMMUNE_REGEX_WIDE = "(" + "|".join((
    # T/B-cell subsets and general leukocyte / immune vocabulary
    r"Th17", r"Th1", r"Th2", r"T[\s-]?cell", r"B[\s-]?cell", r"Treg", r"regulatory T",
    r"lymphocyte", r"leukocyte", r"immune", r"immun", r"inflamm",
    # soluble mediators
    r"cytokine", r"chemokine", r"interleukin", r"interferon",
    # NF-kB spellings
    r"NF[-\s]?kappaB", r"NF[-\s]?κB", r"NFκB", r"NFKB",
    # innate sensing / signalling
    r"toll[-\s]?like", r"\bTLR\b", r"JAK[-\s]?STAT", r"\bTNF\b",
    r"antigen", r"MHC",
    # interleukins written as IL-6, IL 17, IL2, ...
    r"IL[-\s]?\d+",
)) + ")"
IMMUNE_PATTERN_WIDE = re.compile(IMMUNE_REGEX_WIDE, flags=re.IGNORECASE)


def row_matches_immune_wide(term_name, term_id):
    """Return True if the broad immune regex hits "<term_name> <term_id>" (falsy parts -> "")."""
    searchable = "{} {}".format(term_name or "", term_id or "")
    return bool(IMMUNE_PATTERN_WIDE.search(searchable))


def ensure_dir(path: str) -> str:
    """Create ``path`` (including parents) unless it already is a directory.

    Returns the same path so calls can be chained. Raises the same OSError
    as ``os.makedirs`` if ``path`` exists but is not a directory.
    """
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)
    return path


def _rowwise(df, func, *columns, dtype=object):
    """Return ``func(*values)`` for every row of ``df`` as an index-aligned Series.

    ``columns`` names the columns whose per-row values are passed positionally to
    ``func``. Unlike ``DataFrame.apply(axis=1)``, this always yields a Series, even
    for an empty frame. In that case ``apply`` may return a DataFrame, which cannot
    be assigned to a single column.
    """
    values = [func(*row) for row in zip(*(df[c] for c in columns))]
    return pd.Series(values, index=df.index, dtype=dtype)


def _read_term_tables(file_list):
    """Read each GO_Terms_<miRNA>.csv and tag its rows with the miRNA from the file name.

    Empty files (no enrichment rows) are skipped with a message instead of aborting.
    """
    prefix = "GO_Terms_"
    frames = []
    for fname in file_list:
        base = os.path.splitext(os.path.basename(fname))[0]
        # Strip only the leading prefix, not later occurrences inside the name.
        mirna = base[len(prefix):] if base.startswith(prefix) else base
        try:
            df = pd.read_csv(fname)
        except pd.errors.EmptyDataError:
            print(f"Skipping empty file (no enrichment rows): {fname}")
            continue
        df["miRNA"] = mirna
        frames.append(df)
    return frames


def _presence_matrix(df, index_col, min_mirnas):
    """Pivot ``index_col`` x miRNA presence flags; keep rows present in >= ``min_mirnas`` miRNAs.

    Returns an empty DataFrame when ``df`` has no rows.
    """
    if df.empty:
        return pd.DataFrame()
    pivot = df.pivot_table(index=index_col, columns="miRNA",
                           values="presence", fill_value=0, aggfunc="max")
    return pivot[pivot.sum(axis=1) >= min_mirnas]


def _save_presence_clustermap(matrix, path, title, **heatmap_kwargs):
    """Draw a presence/absence clustermap of ``matrix`` and save it as a PDF at ``path``.

    Clustering is disabled along any axis with fewer than two entries, because
    hierarchical linkage needs at least two observations.
    """
    sns.set(style="white", font_scale=1.3)
    n_rows, n_cols = matrix.shape
    grid = sns.clustermap(matrix,
                          row_cluster=n_rows > 1, col_cluster=n_cols > 1,
                          linewidths=0.5, linecolor="gray", figsize=(16, 8),
                          cbar_kws={"label": "Presence (1) / Absence (0)"},
                          **heatmap_kwargs)
    grid.fig.suptitle(title, y=1.02)
    grid.savefig(path, format="pdf", bbox_inches="tight")
    plt.close(grid.fig)


def _write_tf_immune_sankey(subset_tf, subset_immune, sankey_path):
    """Write an HTML Sankey linking TF labels to immune terms that co-occur in the same miRNA.

    A link's value is the number of miRNAs in which both the TF and the immune term are present.
    """
    if not HAS_PLOTLY:
        print("Plotly not available; skipping Sankey.")
        return
    if subset_tf.empty or subset_immune.empty:
        print("Cannot build Sankey; no TF or immune data found.")
        return

    tf_labels = list(subset_tf.index)
    immune_labels = list(subset_immune.index)
    pair_counts = defaultdict(int)
    # Sorted for a deterministic link order in the output.
    common_cols = sorted(set(subset_tf.columns).intersection(set(subset_immune.columns)))
    for col in common_cols:
        tf_present = subset_tf[subset_tf[col] == 1].index
        immune_present = subset_immune[subset_immune[col] == 1].index
        for tf in tf_present:
            for im in immune_present:
                pair_counts[(tf, im)] += 1

    all_nodes = tf_labels + immune_labels
    node_index = {label: i for i, label in enumerate(all_nodes)}
    links = [(node_index[tf], node_index[im], v) for (tf, im), v in pair_counts.items() if v > 0]

    if not links:
        print("No TF -> Immune co-occurrences found. Skipping Sankey.")
        return
    fig = go.Figure(data=[go.Sankey(
        node=dict(pad=15, thickness=20, line=dict(color="black", width=0.5), label=all_nodes),
        link=dict(source=[s for s, _, _ in links],
                  target=[t for _, t, _ in links],
                  value=[v for _, _, v in links])
    )])
    fig.update_layout(title_text="Sankey: TF (left) -> Immune Terms (right)", font_size=14)
    fig.write_html(sankey_path)
    print(f"Sankey saved: {sankey_path}")


def _write_ranked_immune_table(combined_df, p_cutoff, ranked_path):
    """Rank GO:BP/REAC immune terms (wide regex) with p_value <= ``p_cutoff`` and write them as CSV.

    Each term is ranked by its best (minimum) p-value across miRNAs, with ties broken
    by the number of miRNAs that hit the term.
    """
    enrich_sources_for_table = {"GO:BP", "REAC"}
    # Coerce once and reuse, so best_p_value is aggregated and sorted numerically
    # (string p-values would otherwise sort lexically).
    p_numeric = pd.to_numeric(combined_df["p_value"], errors="coerce")
    sig_mask = combined_df["source"].isin(enrich_sources_for_table) & (p_numeric <= p_cutoff)
    df_sig = combined_df[sig_mask].copy()
    df_sig["p_value"] = p_numeric[sig_mask]

    df_sig["ImmuneMatchWide"] = _rowwise(df_sig, row_matches_immune_wide,
                                         "term_name", "term_id", dtype=bool)
    immune_hits = df_sig[df_sig["ImmuneMatchWide"]]
    if immune_hits.empty:
        print(f"No immune pathways found with p_value <= {p_cutoff} in GO:BP/REAC.")
        return

    ranked = (immune_hits.groupby(["source", "term_id", "term_name"], as_index=False)
              .agg(best_p_value=("p_value", "min"),
                   n_miRNAs=("miRNA", "nunique"),
                   miRNAs=("miRNA", lambda x: "; ".join(sorted(set(map(str, x)))))))
    best = pd.to_numeric(ranked["best_p_value"], errors="coerce").fillna(1.0)
    # Clamp to the smallest positive float so p == 0 gives a finite -log10.
    best = np.maximum(best, np.finfo(float).tiny)
    ranked["minus_log10_p"] = -np.log10(best)
    ranked = ranked.sort_values(["best_p_value", "n_miRNAs"], ascending=[True, False]).reset_index(drop=True)
    ranked.insert(0, "rank", ranked.index + 1)
    ranked.to_csv(ranked_path, index=False)
    print(f"Ranked immune pathway table saved to: {ranked_path}")


def main():
    """CLI entry point: aggregate per-miRNA g:Profiler CSVs and write heatmaps, Sankey and tables.

    Reads every ``GO_Terms_*.csv`` in ``--input``. Writes to ``--outdir``:
      - ``combined/combined_gprofiler_terms.csv``
      - TF and strict-immune presence heatmaps under ``heatmaps/``
      - an optional TF -> immune Sankey under ``sankey/``
      - a ranked GO:BP/REAC immune-term table under ``tables/``

    Raises SystemExit if no (non-empty) input file is found, and ValueError if required
    columns are missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="Directory containing GO_Terms_*.csv")
    ap.add_argument("--outdir", required=True, help="Output directory for plots/tables")
    ap.add_argument("--min_mirnas", type=int, default=2, help="Min miRNA presence threshold (default: 2)")
    ap.add_argument("--p_cutoff", type=float, default=0.05, help="p_value cutoff for ranked immune table (default: 0.05)")
    ap.add_argument("--valid_sources", default="GO:BP,GO:MF,TF,REAC",
                    help="Comma-separated list of sources to keep for heatmaps (default: GO:BP,GO:MF,TF,REAC)")
    args = ap.parse_args()

    in_dir = args.input
    outdir = args.outdir
    min_mirnas = args.min_mirnas
    p_cutoff = args.p_cutoff
    valid_sources = [x.strip() for x in args.valid_sources.split(",") if x.strip()]

    ensure_dir(outdir)
    out_heatmaps = ensure_dir(os.path.join(outdir, "heatmaps"))
    out_sankey = ensure_dir(os.path.join(outdir, "sankey"))
    out_tables = ensure_dir(os.path.join(outdir, "tables"))
    out_combined = ensure_dir(os.path.join(outdir, "combined"))

    pattern = os.path.join(in_dir, "GO_Terms_*.csv")
    file_list = sorted(glob.glob(pattern))
    if not file_list:
        raise SystemExit(f"No files found matching: {pattern}")

    all_dfs = _read_term_tables(file_list)
    if not all_dfs:
        raise SystemExit(f"All files matching {pattern} were empty.")

    combined_df = pd.concat(all_dfs, ignore_index=True)
    combined_path = os.path.join(out_combined, "combined_gprofiler_terms.csv")
    combined_df.to_csv(combined_path, index=False)

    required_cols = {"source", "term_id", "term_name", "p_value", "miRNA"}
    missing = required_cols - set(combined_df.columns)
    if missing:
        raise ValueError(f"Missing columns in combined_df: {missing}.")

    filtered_df = combined_df[combined_df["source"].isin(valid_sources)].copy()

    # ---- Heatmap 1: TF
    filtered_df["MatchedTFs"] = _rowwise(filtered_df, find_matched_tfs, "term_name", "term_id")
    has_tf = _rowwise(filtered_df, lambda matched: len(matched) > 0, "MatchedTFs", dtype=bool)
    focused_df = filtered_df[has_tf].copy()
    focused_df["TFlabel"] = _rowwise(focused_df, make_tf_label, "MatchedTFs", "term_id")
    focused_df["presence"] = 1
    focused_df.drop_duplicates(subset=["TFlabel", "miRNA"], inplace=True)
    subset_tf = _presence_matrix(focused_df, "TFlabel", min_mirnas)

    tf_heatmap_path = os.path.join(out_heatmaps, "heatmap_with_TF_names.pdf")
    if subset_tf.empty:
        print(f"No TF terms found in >= {min_mirnas} miRNAs. Skipping TF heatmap.")
    else:
        _save_presence_clustermap(subset_tf, tf_heatmap_path, "Detected TFs (presence/absence)")

    # ---- Heatmap 2: immune strict
    filtered_df["ImmuneMatchStrict"] = _rowwise(filtered_df, row_matches_immune_strict,
                                                "term_name", "term_id", dtype=bool)
    immune_df = filtered_df[filtered_df["ImmuneMatchStrict"]].copy()
    immune_df["presence"] = 1
    immune_df.drop_duplicates(subset=["term_id", "miRNA"], inplace=True)
    subset_immune = _presence_matrix(immune_df, "term_id", min_mirnas)

    # Relabel rows as "<term_id> <term_name>" when both are known.
    term2name = {}
    for tid, tname in zip(immune_df["term_id"], immune_df["term_name"]):
        term2name[tid] = f"{tid} {tname}" if pd.notnull(tid) and pd.notnull(tname) else tid
    subset_immune.index = [term2name.get(t, t) for t in subset_immune.index]

    immune_heatmap_path = os.path.join(out_heatmaps, "heatmap_immune_terms.pdf")
    if subset_immune.empty:
        print(f"No strict immune terms found in >= {min_mirnas} miRNAs. Skipping immune heatmap.")
    else:
        cmap_discrete = ListedColormap(["lightgrey", "darkred"])
        norm = BoundaryNorm([-0.5, 0.5, 1.5], cmap_discrete.N)
        _save_presence_clustermap(subset_immune, immune_heatmap_path,
                                  "Detected Immune Terms (strict list)",
                                  cmap=cmap_discrete, norm=norm)

    # ---- Sankey (simple)
    sankey_path = os.path.join(out_sankey, "tf_immune_sankey.html")
    _write_tf_immune_sankey(subset_tf, subset_immune, sankey_path)

    # ---- Ranked immune pathways table
    ranked_path = os.path.join(out_tables, "immune_pathways_ranked_GO_REAC.csv")
    _write_ranked_immune_table(combined_df, p_cutoff, ranked_path)

    print("DONE.")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
'''

nano cross_species_mouse_seed_enrichment.R
'''
#!/usr/bin/env Rscript
# ===============================================================
# Cross-species validation: human EV miRNAs -> mouse targeting
#   (Seed-scan mode using TargetScan-style miR family table)
#
# FIX v2:
#   - Robust parser for miR_Family_Info.txt (plain/zip) WITHOUT assuming exact column names.
#   - Handles BOM, extra header/comment lines, and name variations ("miR family" vs "miR.family" etc.)
#
# What this script does (for ONE DE contrast/run):
#   1) Reads an edgeR DE table from an existing pipeline /results folder
#   2) Selects a miRNA subset (default = significantly UP miRNAs)
#   3) Fetches HUMAN mature sequences + Seed+m8 from a miR family table (e.g. TargetScan miR_Family_Info)
#      - very relaxed name matching (ignores hsa- prefixes, case, underscores, dot-versions, etc.)
#   4) Computes seed conservation vs MOUSE miRNAs using the SAME miR family table (species_id=10090)
#   5) Option A (recommended): scans mouse 3'UTRs for canonical seed sites (8mer and 7mer-m8 by default; 7mer-A1 and 6mer optional)
#   6) Runs functional enrichment in mouse (GO:BP + Reactome) via g:Profiler (organism = mmusculus)
#   7) Writes tables + PDF + summary
#
# Requirements (R packages):
#   - dplyr, readr, stringr, ggplot2, tibble
#   - gprofiler2
#   - AnnotationDbi, org.Mm.eg.db
#   - Biostrings, GenomicFeatures
#   - Either TxDb+BSgenome for mm10/mm39 OR --mouse_utr_fasta + --mouse_tx2gene
# ===============================================================

suppressPackageStartupMessages({
  library(dplyr)
  library(readr)
  library(stringr)
  library(ggplot2)
  library(tibble)
  library(gprofiler2)
  library(AnnotationDbi)
  library(org.Mm.eg.db)
  library(Biostrings)
  library(GenomicFeatures)
})

# ---------------------------
# Arg parsing (no optparse)
# ---------------------------
args <- commandArgs(trailingOnly = TRUE)

# Return the token that follows `flag` on the command line, or `default` when the
# flag is absent, is the last token, or is immediately followed by another "--"
# option (i.e. the flag was given without a value).
# If the flag is repeated, the last occurrence wins. (Previously, a repeated flag
# made the `if` condition have length > 1, which errors in R >= 4.2.)
get_arg <- function(flag, default = NULL) {
  hit <- which(args == flag)
  if (length(hit) == 0) return(default)
  pos <- hit[length(hit)]
  if (pos == length(args)) return(default)
  value <- args[pos + 1]
  if (startsWith(value, "--")) return(default)
  value
}

# TRUE when `flag` appears anywhere on the command line.
has_flag <- function(flag) any(args == flag)

# Print usage and exit when help is requested or the script is run without arguments.
if (has_flag("--help") || has_flag("-h") || length(args) == 0) {
  cat("
cross_species_mouse_seed_enrichment.R

Required:
  --results_dir          Pipeline run folder containing 02_edger/edgeR_results_*.csv
                         (not needed when both --de_table and --outdir are given)
  --mir_family_table     miR family table (TSV or ZIP containing TSV) with columns like:
                         miR family, Seed+m8, Species ID, MiRBase ID, Mature sequence, ...

Optional:
  --design               paired | non_paired (default: non_paired)
  --de_table             Path to edgeR table (overrides --results_dir/--design)
  --direction            up | down | all_sig | all (default: up)
  --fdr                  FDR cutoff (default: 0.05)
  --logfc                logFC cutoff (default: 1)
  --outdir               output directory (default: <results_dir>/08_cross_species_mouse)
  --immune_regex         custom regex for immune term filtering (optional)
  --species_human        Species ID of human rows in the miR family table (default: 9606)
  --species_mouse        Species ID of mouse rows in the miR family table (default: 10090)

Targeting:
  --target_mode          seedscan_utr (default) | none
  --mouse_genome         mm10 | mm39 (default: mm10)
  --mouse_utr_fasta      Precomputed mouse 3'UTR FASTA (optional alternative to TxDb/BSgenome)
  --mouse_tx2gene        Transcript->Entrez mapping table (TSV/CSV) if using --mouse_utr_fasta
  --site_types           comma-separated: 8mer,7mer-m8,7mer-A1,6mer (default: 8mer,7mer-m8)
  --min_sites            minimum number of sites per transcript (default: 1)

Example:
  Rscript cross_species_mouse_seed_enrichment.R \\
    --results_dir results/without_sample2 \\
    --design non_paired \\
    --direction up --fdr 0.05 --logfc 1 \\
    --mir_family_table mirna_pipeline_package_v6/miR_Family_Info.txt.zip \\
    --target_mode seedscan_utr --mouse_genome mm10 \\
    --outdir results/without_sample2/08_cross_species_mouse_seedscan
\n")
  quit(status = 0)
}

results_dir <- get_arg("--results_dir", default = NULL)
design      <- get_arg("--design", default = "non_paired")
de_table    <- get_arg("--de_table", default = NULL)

direction   <- get_arg("--direction", default = "up")
fdr_cut     <- suppressWarnings(as.numeric(get_arg("--fdr", default = "0.05")))
logfc_cut   <- suppressWarnings(as.numeric(get_arg("--logfc", default = "1")))
outdir      <- get_arg("--outdir", default = NULL)

mir_family_table <- get_arg("--mir_family_table", default = NULL)

target_mode <- get_arg("--target_mode", default = "seedscan_utr")
mouse_genome <- get_arg("--mouse_genome", default = "mm10")
mouse_utr_fasta <- get_arg("--mouse_utr_fasta", default = NULL)
mouse_tx2gene   <- get_arg("--mouse_tx2gene", default = NULL)
site_types <- get_arg("--site_types", default = "8mer,7mer-m8")
min_sites  <- suppressWarnings(as.integer(get_arg("--min_sites", default = "1")))

species_human <- suppressWarnings(as.integer(get_arg("--species_human", default = "9606")))
species_mouse <- suppressWarnings(as.integer(get_arg("--species_mouse", default = "10090")))

# Validate numeric options up front. A malformed value coerces to NA, and
# filter(FDR <= NA) would then silently select zero miRNAs instead of failing.
if (is.na(fdr_cut) || fdr_cut < 0 || fdr_cut > 1) {
  stop("ERROR: --fdr must be a number between 0 and 1")
}
if (is.na(logfc_cut) || logfc_cut < 0) {
  stop("ERROR: --logfc must be a non-negative number")
}
if (is.na(min_sites) || min_sites < 1L) {
  stop("ERROR: --min_sites must be a positive integer")
}
if (is.na(species_human) || is.na(species_mouse)) {
  stop("ERROR: --species_human / --species_mouse must be integer species IDs")
}

# Default immune-term regex. The \\b before the T-cell/B-cell alternatives keeps
# words ending in "t"/"b" (e.g. "fat cell differentiation") from matching.
immune_regex <- get_arg("--immune_regex", default =
  paste0(
    "(Th17|Th1|Th2|\\bT[ -]?cell|\\bB[ -]?cell|Treg|regulatory T|",
    "lymphocyte|leukocyte|immune|immun|inflamm|",
    "cytokine|chemokine|interleukin|interferon|",
    "NF[- ]?kappaB|NF[- ]?κB|NFκB|NFKB|",
    "toll[- ]?like|\\bTLR\\b|JAK[- ]?STAT|\\bTNF\\b|",
    "antigen|MHC|IL[- ]?\\d+)"
  )
)

if (is.null(outdir)) {
  if (is.null(results_dir)) stop("ERROR: Provide --results_dir or --outdir + --de_table")
  outdir <- file.path(results_dir, "08_cross_species_mouse")
}

dir.create(outdir, showWarnings = FALSE, recursive = TRUE)
fig_dir <- file.path(outdir, "figures"); dir.create(fig_dir, showWarnings = FALSE, recursive = TRUE)
tab_dir <- file.path(outdir, "tables");  dir.create(tab_dir, showWarnings = FALSE, recursive = TRUE)

# ---------------------------
# Helpers
# ---------------------------
# Give every blank or NA column name a positional placeholder "V<i>" so that
# columns can be addressed by name afterwards. Other names are left untouched.
fix_empty_colnames <- function(df) {
  nms <- colnames(df)
  blank_pos <- which(is.na(nms) | nms == "")
  if (length(blank_pos) > 0) {
    nms[blank_pos] <- paste0("V", blank_pos)
    colnames(df) <- nms
  }
  df
}

# Build a lower-case matching key for miRNA names coming from different sources.
# Example: "hsa-miR_21-5p" -> "mir-21-5p", "mir21" -> "mir-21", "365a-5p" -> "mir-365a-5p".
normalize_mir_key <- function(x) {
  x %>%
    as.character() %>%
    str_trim() %>%
    # Drop all internal whitespace and unify "_" to "-" before lower-casing.
    str_replace_all("\\s+", "") %>%
    str_replace_all("_", "-") %>%
    tolower() %>%
    # Cut at the first whitespace character. This is effectively a safety net,
    # since whitespace was already removed above.
    sub("\\s.*$", "", .) %>%
    # Strip a 3-letter miRBase species prefix ("hsa-", "mmu-", ...), but never the
    # core "mir-"/"let-" prefix itself (so "mir-365a-5p" stays intact).
    sub("^(?!mir-|let-)([a-z]{3})-", "", ., perl = TRUE) %>%
    # Insert the missing hyphen in "mir21"/"let7a".
    sub("^(mir|let)(?=[0-9])", "\\1-", ., perl = TRUE) %>%
    # Restore a "mir-" prefix that an upstream step already removed ("365a-5p").
    { ifelse(grepl("^[0-9]", .), paste0("mir-", .), .) }
}

# Remove a trailing ".<digits>" version suffix, e.g. "mir-21.2" -> "mir-21".
strip_dot_version <- function(x) {
  sub("[.][0-9]+$", "", x)
}
# TRUE when a name ends with a -5p or -3p arm suffix (case-insensitive).
has_arm <- function(x) {
  grepl("-[35]p$", x, ignore.case = TRUE)
}

# Translate a miRNA Seed+m8 (7 nt, RNA alphabet) into the DNA target-site strings
# searched in 3'UTRs:
#   7mer-m8 = revcomp(seed 1..7)
#   8mer    = 7mer-m8 + "A"
#   7mer-A1 = revcomp(seed 1..6) + "A"
#   6mer    = revcomp(seed 1..6)
seed_to_site_patterns <- function(seed_m8_rna) {
  seed_dna <- chartr("U", "T", toupper(seed_m8_rna))
  revcomp <- function(s) {
    as.character(Biostrings::reverseComplement(Biostrings::DNAString(s)))
  }
  site7 <- revcomp(seed_dna)
  site6 <- revcomp(substr(seed_dna, 1, 6))
  list(
    `7mer-m8` = site7,
    `8mer`    = paste0(site7, "A"),
    `7mer-A1` = paste0(site6, "A"),
    `6mer`    = site6
  )
}

# Collapse every list column of a data frame into a ";"-joined character column
# (empty/NULL cells become ""), so the table can be written with write.csv().
# Non-data-frame input is returned unchanged.
flatten_list_cols <- function(df) {
  if (!is.data.frame(df)) {
    return(df)
  }
  collapse_cell <- function(cell) {
    if (is.null(cell) || length(cell) == 0) "" else paste(as.character(cell), collapse = ";")
  }
  for (col in names(df)) {
    if (!is.list(df[[col]])) next
    df[[col]] <- vapply(df[[col]], collapse_cell, character(1))
  }
  df
}

# ---------------------------
# Locate DE table
# ---------------------------
if (is.null(de_table)) {
  if (is.null(results_dir)) stop("ERROR: Provide --de_table or --results_dir")
  # Only paired / non_paired result tables exist. Reject anything else instead of
  # silently falling back to the non-paired table.
  if (!design %in% c("paired", "non_paired")) {
    stop("ERROR: Unknown --design: ", design, " (use paired|non_paired)")
  }
  edger_dir <- file.path(results_dir, "02_edger")
  if (!dir.exists(edger_dir)) stop("ERROR: expected folder not found: ", edger_dir)
  de_table <- file.path(edger_dir, paste0("edgeR_results_", design, ".csv"))
}
if (!file.exists(de_table)) stop("ERROR: DE table not found: ", de_table)

message("Reading DE table: ", de_table)
de <- read.csv(de_table, stringsAsFactors = FALSE, check.names = FALSE)
de <- fix_empty_colnames(de)

# miRNA IDs are taken from the first column. If that column holds only integers
# and an "X" column exists, "X" is used instead.
first_col <- colnames(de)[1]
de$miRNA_raw <- de[[first_col]]
if (all(grepl("^[0-9]+$", as.character(de$miRNA_raw))) && ("X" %in% colnames(de))) {
  de$miRNA_raw <- de$X
}
stopifnot(all(c("logFC", "FDR") %in% colnames(de)))

de <- de %>%
  mutate(
    miRNA_raw       = as.character(miRNA_raw),
    # Display label: trimmed, with all whitespace removed.
    miRNA_label     = str_replace_all(str_trim(miRNA_raw), "\\s+", ""),
    miRNA_key       = normalize_mir_key(miRNA_label),
    miRNA_key_nodot = strip_dot_version(miRNA_key)
  )

# ---------------------------
# Subset selection
# ---------------------------
# Pick the miRNAs to analyse according to --direction:
#   up/down/all_sig use the FDR and |logFC| cutoffs; all keeps every row.
sel <- switch(direction,
  up      = filter(de, FDR <= fdr_cut, logFC >= logfc_cut),
  down    = filter(de, FDR <= fdr_cut, logFC <= -logfc_cut),
  all_sig = filter(de, FDR <= fdr_cut, abs(logFC) >= logfc_cut),
  all     = de,
  stop("Unknown --direction: ", direction, " (use up|down|all_sig|all)")
)
message("Selected miRNAs (", direction, "): ", nrow(sel))

# ---------------------------
# Read miR family table (TSV or ZIP), ROBUSTLY
# ---------------------------
# --mir_family_table is mandatory and must point to an existing file.
if (is.null(mir_family_table)) {
  stop("ERROR: Provide --mir_family_table (TSV or ZIP containing TSV).")
}
if (!file.exists(mir_family_table)) {
  stop("ERROR: miR family table not found: ", mir_family_table)
}

# Parse a TargetScan-style miR family table (plain text or ZIP) into a tibble with
# columns mir_family, seed_m8, species_id, mirbase_id, mature_sequence,
# family_conservation, mirbase_accession, mirbase_id_key, mirbase_id_key_nodot.
# Column names are located with case-insensitive patterns rather than exact names.
# Stops with an informative error if the file cannot be parsed or required columns
# are missing.
read_mir_family_table <- function(path) {
  real_path <- path
  if (grepl("\\.zip$", path, ignore.case = TRUE)) {
    tmpdir <- tempfile("mirfam_unzip_")
    dir.create(tmpdir)
    # Remove the extracted copy when this function exits (normally or on error).
    on.exit(unlink(tmpdir, recursive = TRUE), add = TRUE)
    utils::unzip(path, exdir = tmpdir)
    cand <- list.files(tmpdir, recursive = TRUE, full.names = TRUE)
    cand <- cand[grepl("\\.(txt|tsv|tab|csv)$", cand, ignore.case = TRUE)]
    # Ignore macOS archive metadata (__MACOSX/ folders, AppleDouble "._*" files).
    # These share the table's extension but are binary, not tables.
    cand <- cand[!grepl("/__MACOSX/", cand, fixed = TRUE) & !startsWith(basename(cand), "._")]
    if (length(cand) == 0) stop("ERROR: ZIP contains no .txt/.tsv/.tab/.csv file: ", path)
    # If several files remain, prefer one whose name mentions "family"
    # (e.g. miR_Family_Info.txt) over plain list.files() order.
    family_named <- cand[grepl("family", basename(cand), ignore.case = TRUE)]
    real_path <- if (length(family_named) > 0) family_named[1] else cand[1]
  }

  # Detect header line index (some files may have preamble/comments)
  lines <- readLines(real_path, warn = FALSE)
  hdr_idx <- which(
    grepl("species\\s*id", lines, ignore.case = TRUE) &
      grepl("mirbase", lines, ignore.case = TRUE) &
      grepl("seed", lines, ignore.case = TRUE)
  )[1]
  skip <- if (!is.na(hdr_idx) && hdr_idx > 1) hdr_idx - 1 else 0

  # First try: tab-delimited
  fam <- tryCatch({
    read.delim(real_path,
               sep = "\t", header = TRUE, skip = skip,
               stringsAsFactors = FALSE,
               quote = "", comment.char = "",
               check.names = FALSE, fill = TRUE)
  }, error = function(e) NULL)

  # Fallback: whitespace-delimited if tab parse failed
  if (is.null(fam) || ncol(fam) < 5) {
    fam <- tryCatch({
      read.table(real_path,
                 sep = "", header = TRUE, skip = skip,
                 stringsAsFactors = FALSE,
                 quote = "", comment.char = "",
                 check.names = FALSE, fill = TRUE)
    }, error = function(e) NULL)
  }

  if (is.null(fam) || nrow(fam) == 0) stop("ERROR: Failed to parse miR family table: ", real_path)

  # Clean colnames: remove BOM, normalize spaces
  cn <- colnames(fam)
  cn <- sub("^\ufeff", "", cn)           # remove BOM at start
  cn <- gsub("\ufeff", "", cn, fixed = TRUE)
  cn <- gsub("\\s+", " ", cn)
  cn <- trimws(cn)
  colnames(fam) <- cn

  cn_norm <- tolower(cn)
  cn_norm <- gsub("\\s+", " ", cn_norm)
  cn_norm <- trimws(cn_norm)

  # Index of the first column matching any pattern (patterns tried in order), or NA.
  find_col <- function(patterns) {
    idx <- integer(0)
    for (p in patterns) {
      hit <- which(grepl(p, cn_norm, ignore.case = TRUE))
      idx <- c(idx, hit)
    }
    idx <- unique(idx)
    if (length(idx) == 0) return(NA_integer_)
    idx[1]
  }

  idx_family <- find_col(c("^mir\\s*family$", "^mir_family$", "^mir\\.family$"))
  idx_seed   <- find_col(c("^seed\\+m8$", "^seed\\s*\\+\\s*m8$", "^seed\\.?\\+?m8$", "seed\\+m8"))
  idx_species<- find_col(c("^species\\s*id$", "^species_id$", "^taxon\\s*id$"))
  idx_mirbase<- find_col(c("^mirbase\\s*id$", "^mirbase_id$", "^mirbase\\.id$"))
  idx_mature <- find_col(c("^mature\\s*sequence$", "^mature_sequence$", "^mature\\.sequence$"))

  idx_fc     <- find_col(c("^family\\s*conservation\\??$", "family\\s*conservation"))
  idx_acc    <- find_col(c("^mirbase\\s*accession$", "^mirbase_accession$"))

  need_idx <- c(idx_family, idx_seed, idx_species, idx_mirbase, idx_mature)
  if (any(is.na(need_idx))) {
    msg <- paste0(
      "ERROR: Could not identify required columns in miR family table.\n",
      "Parsed columns were:\n  - ", paste(cn, collapse = "\n  - "), "\n\n",
      "Required (approx): miR family | Seed+m8 | Species ID | MiRBase ID | Mature sequence\n",
      "Tip: ensure you are using TargetScan miR_Family_Info.txt (tab-delimited) or adjust parsing.\n"
    )
    stop(msg)
  }

  out <- tibble::tibble(
    mir_family        = as.character(fam[[idx_family]]),
    seed_m8           = as.character(fam[[idx_seed]]),
    species_id        = suppressWarnings(as.integer(fam[[idx_species]])),
    mirbase_id        = as.character(fam[[idx_mirbase]]),
    mature_sequence   = as.character(fam[[idx_mature]]),
    family_conservation = if (!is.na(idx_fc)) as.character(fam[[idx_fc]]) else NA_character_,
    mirbase_accession   = if (!is.na(idx_acc)) as.character(fam[[idx_acc]]) else NA_character_
  ) %>%
    mutate(
      seed_m8 = str_trim(seed_m8),
      mirbase_id = str_trim(mirbase_id),
      mature_sequence = str_trim(mature_sequence),
      mirbase_id_key = normalize_mir_key(mirbase_id),
      mirbase_id_key_nodot = strip_dot_version(mirbase_id_key)
    )

  out
}

message("Reading miR family table: ", mir_family_table)
fam <- read_mir_family_table(mir_family_table)

# Split the family table by the requested species IDs (defaults 9606 / 10090).
fam_hsa <- filter(fam, species_id == species_human)
fam_mmu <- filter(fam, species_id == species_mouse)

message("miR family records: human=", nrow(fam_hsa), " mouse=", nrow(fam_mmu))

# Human lookups keyed by normalized miRBase ID (with and without a ".N" version).
# When a key repeats, the row with the alphabetically first mirbase_id wins.
hsa_lookup <- distinct(arrange(fam_hsa, mirbase_id), mirbase_id_key, .keep_all = TRUE)
hsa_lookup_nodot <- distinct(arrange(fam_hsa, mirbase_id), mirbase_id_key_nodot, .keep_all = TRUE)

# For each mouse Seed+m8: the ";"-joined mouse miRBase IDs and how many there are.
mmu_seed2ids <- fam_mmu %>%
  filter(!is.na(seed_m8), seed_m8 != "") %>%
  group_by(seed_m8) %>%
  summarise(
    mouse_mirbase_ids = paste(unique(mirbase_id), collapse = ";"),
    n_mouse = n_distinct(mirbase_id),
    .groups = "drop"
  )

mmu_seed_map <- with(mmu_seed2ids, setNames(mouse_mirbase_ids, seed_m8))
mmu_seed_n   <- with(mmu_seed2ids, setNames(n_mouse, seed_m8))

# Map selected DE miRNAs to human sequences/seeds (very relaxed)
# Find the human family-table record for one DE miRNA.
# Candidate keys are tried in order: the key, its version-less form, then -5p/-3p
# arm variants when neither form already names an arm. Each candidate is matched
# against hsa_lookup$mirbase_id_key first, then hsa_lookup_nodot$mirbase_id_key_nodot.
# Returns a list with `matched` plus the record's fields (NA when nothing matches).
lookup_human_mir <- function(mir_key, mir_key_nodot) {
  keys <- c(mir_key, mir_key_nodot)
  if (!(has_arm(mir_key) || has_arm(mir_key_nodot))) {
    keys <- c(keys,
              paste0(mir_key, c("-5p", "-3p")),
              paste0(mir_key_nodot, c("-5p", "-3p")))
  }
  keys <- unique(keys[!is.na(keys) & keys != ""])

  field_of <- function(tbl, row, col) as.character(tbl[[col]][row])

  for (key in keys) {
    tbl <- hsa_lookup
    row <- match(key, tbl$mirbase_id_key)
    if (is.na(row)) {
      tbl <- hsa_lookup_nodot
      row <- match(key, tbl$mirbase_id_key_nodot)
    }
    if (is.na(row)) next
    return(list(
      matched = TRUE,
      mirbase_id = field_of(tbl, row, "mirbase_id"),
      seed_m8 = field_of(tbl, row, "seed_m8"),
      mature_sequence = field_of(tbl, row, "mature_sequence"),
      mir_family = field_of(tbl, row, "mir_family"),
      mirbase_accession = field_of(tbl, row, "mirbase_accession"),
      family_conservation = field_of(tbl, row, "family_conservation")
    ))
  }

  list(matched = FALSE,
       mirbase_id = NA_character_,
       seed_m8 = NA_character_,
       mature_sequence = NA_character_,
       mir_family = NA_character_,
       mirbase_accession = NA_character_,
       family_conservation = NA_character_)
}

# One lookup_human_mir() result per row of `sel`, in row order (unnamed list).
mapped <- mapply(lookup_human_mir, sel$miRNA_key, sel$miRNA_key_nodot,
                 SIMPLIFY = FALSE, USE.NAMES = FALSE)


# First element of x[[field]] as a single string. Returns NA_character_ when the
# field is missing/empty or its first value is NA or "".
get_chr1 <- function(x, field) {
  value <- x[[field]]
  if (length(value) == 0) {
    return(NA_character_)
  }
  first <- as.character(value[1])
  if (is.na(first) || !nzchar(first)) NA_character_ else first
}

# Copy the matched family-table fields from `mapped` into new columns of `sel`
# (NA where a miRNA had no match). Columns are added in the order listed.
sel[c("mirbase_id_matched", "seed_m8", "mature_sequence",
      "mir_family", "mirbase_accession", "family_conservation")] <-
  lapply(c("mirbase_id", "seed_m8", "mature_sequence",
           "mir_family", "mirbase_accession", "family_conservation"),
         function(fld) vapply(mapped, get_chr1, character(1), field = fld))
sel$mirbase_match_ok <- !is.na(sel$mirbase_id_matched) & sel$mirbase_id_matched != ""

# Seed conservation: mouse miRNAs sharing the human Seed+m8 ("" / 0 when none).
sel$mouse_miRNAs_same_seed <- vapply(sel$seed_m8, function(s) {
  known <- !is.na(s) && nzchar(s) && s %in% names(mmu_seed_map)
  if (known) mmu_seed_map[[s]] else ""
}, character(1))

sel$n_mouse_miRNAs_same_seed <- vapply(sel$seed_m8, function(s) {
  known <- !is.na(s) && nzchar(s) && s %in% names(mmu_seed_n)
  if (known) as.integer(mmu_seed_n[[s]]) else 0L
}, integer(1))

sel$has_mouse_seed_match <- sel$n_mouse_miRNAs_same_seed > 0

# Export mapping diagnostics
map_out <- file.path(tab_dir, "human_miRNA_familytable_mapping.tsv")
mapping_cols <- sel %>%
  dplyr::select(miRNA_label, miRNA_key, miRNA_key_nodot, mirbase_id_matched, seed_m8,
                mature_sequence, mir_family, mirbase_accession, family_conservation)
write_tsv(mapping_cols, map_out)
message("Wrote: ", map_out)

not_found_out <- file.path(tab_dir, "human_miRNAs_not_found_in_familytable.txt")
writeLines(sel$miRNA_label[!sel$mirbase_match_ok], con = not_found_out)
message("Wrote: ", not_found_out)

# ---------------------------
# Mouse targeting: seedscan in 3'UTRs
# ---------------------------
# Load mouse 3'UTR sequences plus a transcript -> gene table.
# Sources:
#   - mouse_utr_fasta (+ optional mouse_tx2gene TSV/CSV), or
#   - the UCSC knownGene TxDb + BSgenome packages for mouse_genome ("mm10"/"mm39").
# Returns list(utr_seqs = DNAStringSet named by transcript, tx2gene = tibble(tx_name, gene_id)).
# Rows with missing/empty transcript or gene IDs are dropped from tx2gene.
load_mouse_utrs <- function(mouse_genome, mouse_utr_fasta = NULL, mouse_tx2gene = NULL) {
  if (!is.null(mouse_utr_fasta)) {
    if (!file.exists(mouse_utr_fasta)) stop("ERROR: --mouse_utr_fasta not found: ", mouse_utr_fasta)
    message("Loading mouse 3'UTRs from FASTA: ", mouse_utr_fasta)
    utr_seqs <- Biostrings::readDNAStringSet(mouse_utr_fasta)
    tx_ids <- names(utr_seqs)

    tx2gene <- NULL
    if (!is.null(mouse_tx2gene)) {
      if (!file.exists(mouse_tx2gene)) stop("ERROR: --mouse_tx2gene not found: ", mouse_tx2gene)
      message("Loading transcript->gene mapping: ", mouse_tx2gene)
      # Choose the delimiter from the header line. read_tsv() does not error on a
      # CSV (it returns one merged column), so an error-based fallback to
      # read_csv() would never trigger.
      header_line <- readLines(mouse_tx2gene, n = 1L, warn = FALSE)
      is_csv <- length(header_line) == 1L &&
        !grepl("\t", header_line, fixed = TRUE) &&
        grepl(",", header_line, fixed = TRUE)
      tx2gene_df <- readr::read_delim(mouse_tx2gene, delim = if (is_csv) "," else "\t",
                                      show_col_types = FALSE, progress = FALSE)
      cn <- colnames(tx2gene_df)
      tx_col <- cn[which(tolower(cn) %in% c("tx","tx_name","transcript","transcript_id","transcriptid","txid","name"))[1]]
      gene_col <- cn[which(tolower(cn) %in% c("gene","gene_id","geneid","entrez","entrezid","entrez_id","symbol"))[1]]
      if (is.na(tx_col) || is.na(gene_col)) {
        stop("ERROR: --mouse_tx2gene must contain transcript and gene columns. Found: ", paste(cn, collapse = ", "))
      }
      # Drop unmapped rows (as in the TxDb branch) so NA never becomes a "gene".
      tx2gene <- tx2gene_df %>%
        transmute(tx_name = as.character(.data[[tx_col]]),
                  gene_id = as.character(.data[[gene_col]])) %>%
        filter(!is.na(tx_name), tx_name != "", !is.na(gene_id), gene_id != "") %>%
        distinct(tx_name, gene_id)
    } else {
      warning("No --mouse_tx2gene provided; treating FASTA names as gene IDs (NOT recommended).")
      tx2gene <- tibble(tx_name = tx_ids, gene_id = tx_ids)
    }
    return(list(utr_seqs = utr_seqs, tx2gene = tx2gene))
  }

  txdb_pkg <- if (mouse_genome == "mm10") "TxDb.Mmusculus.UCSC.mm10.knownGene" else if (mouse_genome == "mm39") "TxDb.Mmusculus.UCSC.mm39.knownGene" else NA
  if (is.na(txdb_pkg)) stop("ERROR: Unsupported --mouse_genome: ", mouse_genome, " (use mm10 or mm39)")

  bs_pkg <- if (mouse_genome == "mm10") "BSgenome.Mmusculus.UCSC.mm10" else "BSgenome.Mmusculus.UCSC.mm39"

  if (!requireNamespace(txdb_pkg, quietly = TRUE)) {
    stop("ERROR: Missing package ", txdb_pkg, ". Install it OR provide --mouse_utr_fasta + --mouse_tx2gene.")
  }
  if (!requireNamespace("BSgenome", quietly = TRUE)) {
    stop("ERROR: Missing package BSgenome. Install BSgenome + ", bs_pkg, " OR provide --mouse_utr_fasta.")
  }
  if (!requireNamespace(bs_pkg, quietly = TRUE)) {
    stop("ERROR: Missing package ", bs_pkg, ". Install it OR provide --mouse_utr_fasta.")
  }

  suppressPackageStartupMessages(library(txdb_pkg, character.only = TRUE))
  suppressPackageStartupMessages(library(bs_pkg, character.only = TRUE))

  txdb <- getExportedValue(txdb_pkg, txdb_pkg)
  genome <- get(bs_pkg)

  message("Extracting mouse 3'UTRs from ", txdb_pkg, " + ", bs_pkg, " ...")
  utr3 <- GenomicFeatures::threeUTRsByTranscript(txdb, use.names = TRUE)
  utr_seqs <- GenomicFeatures::extractTranscriptSeqs(genome, utr3)

  tx_gr <- GenomicFeatures::transcripts(txdb, columns = c("tx_name", "gene_id"))
  tx2gene <- tibble(tx_name = as.character(tx_gr$tx_name),
                    gene_id = as.character(tx_gr$gene_id)) %>%
    filter(!is.na(tx_name), tx_name != "", !is.na(gene_id), gene_id != "") %>%
    distinct(tx_name, gene_id)

  keep_tx <- intersect(names(utr_seqs), tx2gene$tx_name)
  utr_seqs <- utr_seqs[keep_tx]
  tx2gene <- tx2gene %>% filter(tx_name %in% keep_tx)

  message("Mouse UTRs loaded: transcripts=", length(utr_seqs), " mapped_tx2gene=", nrow(tx2gene))
  list(utr_seqs = utr_seqs, tx2gene = tx2gene)
}

# Names of UTRs (transcripts) with at least `min_sites` seed sites of the selected
# types for one Seed+m8.
#
# Every site type contains the 6mer core (revcomp of seed positions 1..6):
#   8mer    = m8-match + core + A
#   7mer-m8 = m8-match + core
#   7mer-A1 = core + A
# Each core occurrence is therefore counted once, when it matches at least one
# selected type. This avoids counting an 8mer again as a 7mer-A1/6mer (and a
# 7mer again as a 6mer), which previously inflated scores when 7mer-A1 or 6mer
# was combined with other types. For the default "8mer,7mer-m8" the score is
# unchanged (= all 7mer-m8 matches).
seedscan_targets_for_seed <- function(seed_m8, utr_seqs, site_types_vec, min_sites) {
  if (is.na(seed_m8) || seed_m8 == "") return(character(0))
  pats <- seed_to_site_patterns(seed_m8)

  count_sites <- function(type) {
    Biostrings::vcountPattern(Biostrings::DNAString(pats[[type]]), utr_seqs, fixed = TRUE)
  }
  wanted <- function(type) type %in% site_types_vec

  score <- if (wanted("6mer")) {
    # Every other site type contains the 6mer core.
    count_sites("6mer")
  } else if (wanted("7mer-m8") && wanted("7mer-A1")) {
    # Union of both 7mer patterns: 8mer sites match both (inclusion-exclusion).
    count_sites("7mer-m8") + count_sites("7mer-A1") - count_sites("8mer")
  } else if (wanted("7mer-m8")) {
    count_sites("7mer-m8")   # 8mer sites are a subset
  } else if (wanted("7mer-A1")) {
    count_sites("7mer-A1")   # 8mer sites are a subset
  } else if (wanted("8mer")) {
    count_sites("8mer")
  } else {
    integer(length(utr_seqs))
  }

  names(utr_seqs)[which(score >= min_sites)]
}

# Parse --site_types into canonical, de-duplicated type names.
# Matching is case-insensitive ("7mer-a1" -> "7mer-A1"). In seedscan_utr mode an
# unknown or empty list is an error, because the scanner matches types by exact
# name and would otherwise silently ignore them.
site_types_vec <- local({
  known <- c("8mer", "7mer-m8", "7mer-A1", "6mer")
  requested <- str_trim(as.character(str_split(site_types, ",", simplify = TRUE)))
  requested <- unique(requested[requested != ""])
  canonical <- known[match(tolower(requested), tolower(known))]
  if (target_mode == "seedscan_utr") {
    if (anyNA(canonical)) {
      stop("ERROR: Unknown --site_types: ", paste(requested[is.na(canonical)], collapse = ","),
           " (use ", paste(known, collapse = ","), ")")
    }
    if (length(canonical) == 0) {
      stop("ERROR: --site_types is empty (use ", paste(known, collapse = ","), ")")
    }
  }
  unique(canonical[!is.na(canonical)])
})

# Per-miRNA target columns, sized to nrow(sel). Assigning a length-1 value to a
# 0-row data.frame errors, which happens when no miRNAs pass the selection.
sel$mouse_targets_entrez <- rep("", nrow(sel))
sel$n_mouse_targets <- rep(0L, nrow(sel))
all_entrez <- character(0)

if (target_mode == "seedscan_utr") {
  usable <- sel %>% filter(!is.na(seed_m8), seed_m8 != "")
  if (nrow(usable) == 0) {
    warning("No selected miRNAs had Seed+m8 available from the miR family table. Skipping targeting/enrichment.")
  } else {
    utr_obj <- load_mouse_utrs(mouse_genome, mouse_utr_fasta, mouse_tx2gene)
    utr_seqs <- utr_obj$utr_seqs
    tx2gene <- utr_obj$tx2gene

    seeds <- unique(usable$seed_m8)
    message("Seed-scan targeting: unique seeds=", length(seeds),
            " site_types=", paste(site_types_vec, collapse = ","),
            " min_sites=", min_sites)

    seed2genes <- list()
    dbg_out <- file.path(tab_dir, "seedscan_site_counts_by_transcript_example.tsv")

    for (si in seq_along(seeds)) {
      seed <- seeds[si]
      tx_hits <- seedscan_targets_for_seed(seed, utr_seqs, site_types_vec, min_sites)
      if (length(tx_hits) == 0) {
        seed2genes[[seed]] <- character(0)
      } else {
        gene_ids <- tx2gene %>% filter(tx_name %in% tx_hits) %>% pull(gene_id) %>% unique()
        seed2genes[[seed]] <- as.character(gene_ids)
      }

      if (si == 1) {
        pats <- seed_to_site_patterns(seed)
        c7 <- Biostrings::vcountPattern(Biostrings::DNAString(pats[["7mer-m8"]]), utr_seqs, fixed = TRUE)
        c8 <- Biostrings::vcountPattern(Biostrings::DNAString(pats[["8mer"]]), utr_seqs, fixed = TRUE)
        dbg <- tibble(
          transcript = names(utr_seqs),
          count_7mer_m8 = as.integer(c7),
          count_8mer = as.integer(c8)
        ) %>%
          arrange(desc(count_8mer), desc(count_7mer_m8)) %>%
          slice_head(n = 200)
        write_tsv(dbg, dbg_out)
        message("Wrote: ", dbg_out)
      }
    }

    sel$mouse_targets_entrez <- vapply(sel$seed_m8, function(seed) {
      if (is.na(seed) || seed == "" || !(seed %in% names(seed2genes))) return("")
      genes <- seed2genes[[seed]]
      if (length(genes) == 0) return("")
      paste(unique(genes), collapse = ";")
    }, character(1))

    sel$n_mouse_targets <- vapply(sel$mouse_targets_entrez, function(x) {
      if (is.na(x) || x == "") return(0L)
      length(unique(strsplit(x, ";", fixed = TRUE)[[1]]))
    }, integer(1))

    all_entrez <- unique(unlist(strsplit(sel$mouse_targets_entrez[sel$mouse_targets_entrez != ""], ";", fixed = TRUE)))
    all_entrez <- unique(all_entrez[!is.na(all_entrez) & all_entrez != ""])
  }

} else if (target_mode == "none") {
  message("Target mode = none. Skipping targeting/enrichment.")
} else {
  stop("ERROR: Unknown --target_mode: ", target_mode, " (use seedscan_utr|none)")
}

# Convert Entrez -> Symbol
# Map the union of target Entrez IDs to mouse gene symbols via org.Mm.eg.db.
# With no targets, both the map and the symbol vector stay empty.
sym_map <- if (length(all_entrez) == 0) {
  tibble(ENTREZID = character(0), SYMBOL = character(0))
} else {
  AnnotationDbi::select(
    org.Mm.eg.db,
    keys = all_entrez,
    columns = c("SYMBOL"),
    keytype = "ENTREZID"
  ) %>%
    as_tibble() %>%
    filter(!is.na(SYMBOL)) %>%
    distinct(ENTREZID, SYMBOL)
}
mouse_symbols <- unique(sym_map$SYMBOL)

# ---------------------------
# Write summary tables
# ---------------------------
# Table 1: per-miRNA seed conservation and mouse targets, most significant first.
seed_out <- file.path(tab_dir, "seed_conservation_and_mouse_targets.csv")
sel_out <- dplyr::select(
  sel,
  miRNA_label, logFC, FDR,
  mirbase_id_matched, mir_family, mirbase_accession, family_conservation,
  seed_m8, mature_sequence,
  has_mouse_seed_match, n_mouse_miRNAs_same_seed, mouse_miRNAs_same_seed,
  n_mouse_targets, mouse_targets_entrez
)
sel_out <- arrange(sel_out, FDR, desc(logFC))
write.csv(sel_out, file = seed_out, row.names = FALSE)
message("Wrote: ", seed_out)

# Table 2: Entrez -> symbol map of the target genes passed to enrichment.
targets_out <- file.path(tab_dir, "mouse_target_genes_used_for_enrichment.csv")
target_table <- arrange(sym_map, SYMBOL)
write.csv(target_table, file = targets_out, row.names = FALSE)
message("Wrote: ", targets_out)

# ---------------------------
# Enrichment
# ---------------------------
# g:Profiler (GO:BP + Reactome, FDR-corrected) on the union of mouse target
# symbols. The immune-only table keeps terms whose name matches `immune_regex`.
enrich_full_out <- file.path(tab_dir, "mouse_enrichment_GO_BP_REAC_full.csv")
enrich_immune_out <- file.path(tab_dir, "mouse_enrichment_GO_BP_REAC_immune_only.csv")

enrich_df <- data.frame()
enrich_immune <- data.frame()

if (target_mode != "none") {
  if (length(mouse_symbols) < 10) {
    warning("Too few mouse targets for enrichment (n=", length(mouse_symbols), "). Skipping gost().")
  } else {
    # gost() queries the g:Profiler web service. Treat a network/API failure
    # like "no results" (warn and continue) so the figure and summary are
    # still written, matching the other skip paths in this section.
    gp <- tryCatch(
      gost(
        query = mouse_symbols,
        organism = "mmusculus",
        correction_method = "fdr",
        sources = c("GO:BP", "REAC")
      ),
      error = function(e) {
        warning("gost() failed: ", conditionMessage(e), call. = FALSE)
        NULL
      }
    )

    if (!is.null(gp) && "result" %in% names(gp) && !is.null(gp$result)) {
      enrich_df <- as.data.frame(gp$result)
      # List-columns (e.g. parents) cannot be written by write.csv directly.
      enrich_df <- flatten_list_cols(enrich_df)
      write.csv(enrich_df, enrich_full_out, row.names = FALSE)
      message("Wrote: ", enrich_full_out)

      enrich_immune <- enrich_df %>%
        filter(str_detect(term_name, regex(immune_regex, ignore_case = TRUE))) %>%
        arrange(p_value, desc(intersection_size))
      write.csv(enrich_immune, enrich_immune_out, row.names = FALSE)
      message("Wrote: ", enrich_immune_out)
    } else {
      warning("gost() returned no results.")
    }
  }
}

# ---------------------------
# Figure PDF
# ---------------------------
pdf_out <- file.path(fig_dir, "cross_species_seed_conservation_and_mouse_enrichment.pdf")

# Page 1: number of mouse miRNAs sharing each selected miRNA's Seed+m8.
p1 <- sel %>%
  mutate(miRNA_label = factor(
    miRNA_label,
    # unique(): factor() errors on duplicated levels if a label repeats.
    levels = unique(miRNA_label[order(n_mouse_miRNAs_same_seed, decreasing = TRUE)])
  )) %>%
  ggplot(aes(x = miRNA_label, y = n_mouse_miRNAs_same_seed, fill = has_mouse_seed_match)) +
  geom_col(width = 0.85, color = "black", linewidth = 0.2) +
  coord_flip() +
  scale_fill_manual(values = c("TRUE" = "#2E7D32", "FALSE" = "grey70"),
                    name = "Seed conserved\n(human->mouse)") +
  labs(
    title = "Seed (m8) conservation: human DE miRNAs vs mouse miRNAs",
    subtitle = paste0("Selection: ", direction, " (FDR≤", fdr_cut, ", |logFC|≥", logfc_cut, ")"),
    x = NULL,
    y = "# mouse miRNAs with identical seed"
  ) +
  theme_bw(base_size = 12) +
  theme(
    legend.position = "right",
    plot.title = element_text(face = "bold")
  )

# Page 2: top enrichment terms; immune-only terms when any matched, else all.
p2_data <- NULL
if (is.data.frame(enrich_immune) && nrow(enrich_immune) > 0) {
  p2_data <- enrich_immune
} else if (is.data.frame(enrich_df) && nrow(enrich_df) > 0) {
  p2_data <- enrich_df
}

if (!is.null(p2_data)) {
  # Rank by p_value (FDR-corrected by the gost() call) so "top" means most
  # significant. enrich_immune is already sorted this way, but enrich_df keeps
  # whatever row order gost() returned, so sort before taking the head.
  p2_top <- p2_data %>%
    arrange(p_value, desc(intersection_size)) %>%
    slice_head(n = 20) %>%
    mutate(minus_log10_fdr = -log10(p_value),
           term_label = str_wrap(term_name, width = 45),
           term_label = factor(term_label, levels = rev(unique(term_label))))

  p2 <- ggplot(p2_top, aes(x = minus_log10_fdr, y = term_label, color = source, size = intersection_size)) +
    geom_point(alpha = 0.9) +
    labs(
      title = "Mouse functional enrichment (seed-scan targets in mouse 3'UTRs)",
      subtitle = "Top terms (immune-only if available); g:Profiler organism=mmusculus; sources=GO:BP,REAC",
      x = "-log10(FDR)",
      y = NULL,
      color = "Source",
      size = "Intersection\nsize"
    ) +
    theme_bw(base_size = 12) +
    theme(
      plot.title = element_text(face = "bold"),
      legend.position = "right"
    )
} else {
  p2 <- ggplot() + theme_void() + labs(title = "No enrichment results to plot.")
}

# Close the device even if rendering a plot fails, so the PDF handle is not leaked.
pdf(pdf_out, width = 11, height = 8.5, onefile = TRUE)
tryCatch({
  print(p1)
  print(p2)
}, finally = dev.off())
message("Wrote: ", pdf_out)

# ---------------------------
# Summary TXT
# ---------------------------
# Plain-text run report: selection/targeting parameters, headline counts and
# the paths of every output written by this script.
summary_out <- file.path(outdir, "selection_and_cross_species_summary.txt")

# Headline counts (NA-safe).
has_seed <- !is.na(sel$seed_m8) & sel$seed_m8 != ""
total_selected <- nrow(sel)
mapped_n <- sum(sel$mirbase_match_ok, na.rm = TRUE)
seed_known <- sum(has_seed, na.rm = TRUE)
seed_conserved <- sum(sel$has_mouse_seed_match, na.rm = TRUE)
usable_targets <- sum(sel$n_mouse_targets > 0, na.rm = TRUE)
target_gene_n <- length(mouse_symbols)

cat(
  # Header
  "Cross-species validation summary\n",
  "================================\n\n",
  # Parameters
  "DE table: ", de_table, "\n",
  "Selection: direction=", direction, ", FDR≤", fdr_cut, ", |logFC|≥", logfc_cut, "\n\n",
  "miR family table: ", mir_family_table, "\n",
  "Human species_id=", species_human, " Mouse species_id=", species_mouse, "\n\n",
  "Targeting mode: ", target_mode, "\n",
  "Mouse genome: ", mouse_genome, "\n",
  "Site types: ", paste(site_types_vec, collapse = ","), " ; min_sites=", min_sites, "\n\n",
  # Counts
  "Counts:\n",
  "  Total selected miRNAs: ", total_selected, "\n",
  "  Selected miRNAs mapped to family table (human): ", mapped_n, "\n",
  "  Selected miRNAs with Seed+m8 available: ", seed_known, "\n",
  "  Selected miRNAs with ≥1 mouse miRNA sharing the same seed: ", seed_conserved, "\n",
  "  Selected miRNAs with ≥1 mouse target gene (seedscan): ", usable_targets, "\n",
  "  Union of mouse target genes (SYMBOL) used for enrichment: ", target_gene_n, "\n\n",
  # Output paths
  "Outputs:\n",
  "  - Mapping TSV: ", map_out, "\n",
  "  - Missing miRNAs TXT: ", not_found_out, "\n",
  "  - CSV seed/targets: ", seed_out, "\n",
  "  - CSV target genes: ", targets_out, "\n",
  "  - CSV enrichment full: ", enrich_full_out, "\n",
  "  - CSV enrichment immune-only: ", enrich_immune_out, "\n",
  "  - PDF figure: ", pdf_out, "\n",
  file = summary_out,
  sep = ""
)

message("Wrote: ", summary_out)
message("DONE.")
'''

#####
##### Testing final package
#####


cd /mnt/sdb/Colaboraciones/Roberto_Elizondo_2025/

# 0) Unpack the pipeline package. Every command below calls scripts under
#    mirna_pipeline_package_v6/, so the archive must be extracted (assumes it
#    unpacks into that folder). Making the .zip executable has no effect.
unzip -o mirna_pipeline_package_v6.zip

# 1) Create (first run only) + activate the R env.
#    If RUVSeq_env already exists, `conda env create` fails; refresh it with
#    `conda env update -f RUVSeq_env.yml --prune` instead.
conda env create -f RUVSeq_env.yml
conda activate RUVSeq_env

# Python dependencies for the post-processing scripts (steps 4-5; improved Sankey)
pip install pandas seaborn matplotlib plotly kaleido numpy

# 2) Run ALL samples (outputs under results/all_samples)
Rscript mirna_pipeline_package_v6/run_mirna_pipeline.R --input miRNA-counts.csv --outdir results/all_samples --color greenred --heatmap_cex_row 1.6 --heatmap_cex_col 2.2

# 3) Run WITHOUT sample2 (Ctrl_2 + Olig_2 removed)
Rscript mirna_pipeline_package_v6/run_mirna_pipeline.R --input miRNA-counts.csv --outdir results/without_sample2 \
  --drop_samples Ctrl_2,Olig_2 --color greenred --heatmap_cex_row 1.6 --heatmap_cex_col 2.2

# 4) Sankey Plot
python3 mirna_pipeline_package_v6/tf_immune_analysis.py \
  --input results/without_sample2/06_enrichment/GO_Terms \
  --outdir results/without_sample2/07_tf_immune \
  --sankey_top_tfs 10 --sankey_top_terms 8 --sankey_min_link 2 --sankey_palette Set3

# 5) Optional: run legacy/simple sankey
python3 mirna_pipeline_package_v6/tf_immune_heatmaps_ranked_table.py \
  --input results/without_sample2/06_enrichment/GO_Terms \
  --outdir results/without_sample2/07_tf_immune_legacy

# 6) Cross-species check, human vs mouse: seed (m8) conservation of the
#    selected human DE miRNAs + enrichment of their seed-scan targets in
#    mouse 3'UTRs.
#    miR_Family_Info.txt.zip is the TargetScan miR family table, downloadable from:
# https://www.targetscan.org/cgi-bin/targetscan/data_download.vert80.cgi

Rscript mirna_pipeline_package_v6/cross_species_mouse_seed_enrichment.R \
  --results_dir results/without_sample2 \
  --design non_paired \
  --direction up \
  --fdr 0.05 \
  --logfc 1 \
  --outdir results/without_sample2/08_cross_species_mouse_seedscan \
  --target_mode seedscan_utr \
  --mir_family_table mirna_pipeline_package_v6/miR_Family_Info.txt.zip \
  --mouse_genome mm10 \
  --site_types 8mer,7mer-m8 \
  --min_sites 1
