
rm(list = ls())
gc()
# setwd(dir = "/Volumes/work/晚期胃癌免疫联合化疗疗效标志物探寻-olink")

# 加载依赖包
packages <- c("data.table", "dplyr", "tidyr", "stringr", "readxl", "writexl", 
              "ggplot2", "ggpubr", "pheatmap", "survminer", "survival", 
              "glmnet", "pROC", "dcurves", "reshape2", "patchwork", "tableone", "UpSetR")
lapply(packages, require, character.only = TRUE)
suppressMessages(library(clusterProfiler))

# 加载自定义函数库
source("~/GemhR.R")

# Build a one-line text summary of a grouped survival comparison.
#
# Arguments:
#   data        data frame holding the time, status and grouping columns
#   time_col    name (string) of the follow-up time column
#   status_col  name (string) of the event indicator column
#   group_col   name (string) of the grouping column
#   control.chr value of `group_col` used as the reference (first) level
#
# Returns a single string: per-group median survival ("NR" when the median is
# not reached), a hazard ratio with 95% CI, and the log-rank P value.
# Stops with an error when `control.chr` is not present in `group_col` or when
# fewer than two groups remain.
get_custom_surv_report <- function(data, time_col, status_col, group_col, control.chr) {
  if (!control.chr %in% unique(data[[group_col]])) stop("对照组不存在")

  # Drop unused factor levels first: otherwise a two-group comparison whose
  # factor still carries an empty third level is routed to the Cox branch.
  data[[group_col]] <- relevel(droplevels(as.factor(data[[group_col]])), ref = control.chr)
  if (nlevels(data[[group_col]]) < 2) stop("At least two groups are required")

  # Back-quote column names so non-syntactic names (spaces, hyphens) still
  # produce a valid formula; anything that is not a column is passed through.
  quote_col <- function(x) if (x %in% names(data)) paste0("`", x, "`") else x
  f_surv <- as.formula(paste0(
    "Surv(", quote_col(time_col), ", ", quote_col(status_col), ") ~ ",
    quote_col(group_col)
  ))

  # Median survival per group
  km_tbl <- summary(survfit(f_surv, data = data))$table
  med_desc <- paste(vapply(seq_len(nrow(km_tbl)), function(i) {
    val <- km_tbl[i, "median"]
    paste0(rownames(km_tbl)[i], " mOS/PFS=", if (is.na(val)) "NR" else round(val, 2))
  }, character(1)), collapse = "；")

  # Log-rank test (upper tail taken directly instead of 1 - pchisq)
  sdf <- survdiff(f_surv, data = data)
  p_val <- pchisq(sdf$chisq, length(sdf$n) - 1, lower.tail = FALSE)

  # Hazard ratio
  if (nlevels(data[[group_col]]) == 2) {
    # Two groups: HR from observed/expected event counts of the log-rank test,
    # with the reference group (index 1) in the denominator.
    O1 <- sdf$obs[1]; E1 <- sdf$exp[1]; O2 <- sdf$obs[2]; E2 <- sdf$exp[2]
    HR_val <- (O2 / E2) / (O1 / E1)
    HR_ci <- exp(log(HR_val) + c(-1, 1) * 1.96 * sqrt(1 / E1 + 1 / E2))
  } else {
    # More than two groups: only the first non-reference level's Cox HR is reported.
    cox_sum <- summary(coxph(f_surv, data = data))
    HR_val <- cox_sum$conf.int[1, 1]
    HR_ci <- cox_sum$conf.int[1, 3:4]
  }

  sprintf("%s。HR=%.2f (95%%CI: %.2f-%.2f), P=%.3f", med_desc, HR_val, HR_ci[1], HR_ci[2], p_val)
}


# Summarise a numeric column per group as mean, SD, N and SEM, plus a
# formatted "mean±spread" label.
#
# Arguments:
#   data        data frame
#   value_col   name (string) of the numeric column to summarise
#   group_cols  character vector of grouping column names
#   measure     "SEM" to format Val as mean±SEM; any other value gives mean±SD
#
# Returns one row per group with columns Mean, SD, N, SEM and Val.
# N is the row count of the group; SEM is computed from the number of
# non-missing values so it stays consistent with Mean/SD (na.rm = TRUE).
get_stats_report <- function(data, value_col = "NPX", group_cols = c("Stage", "Response"), measure = "SEM") {
  data %>%
    group_by(across(all_of(group_cols))) %>%
    summarise(
      Mean = mean(.data[[value_col]], na.rm = TRUE),
      SD = sd(.data[[value_col]], na.rm = TRUE),
      N = n(),
      N_valid = sum(!is.na(.data[[value_col]])),
      .groups = "drop"
    ) %>%
    mutate(
      SEM = SD / sqrt(N_valid),
      # `measure` is a scalar, so use if/else: ifelse() would return a single
      # element and recycle the first group's label onto every row.
      Val = paste0(round(Mean, 2), "±", round(if (measure == "SEM") SEM else SD, 2))
    ) %>%
    dplyr::select(-N_valid)
}


# Load the cleaned input tables. Each file is read with fread() and converted
# to a plain data.frame.
pdata <- fread("data/clean_data/olink.V2/clinical_data.txt") %>% data.frame()
sample_data <- fread("data/clean_data/olink.V2/sample_data.txt") %>% data.frame()
expr_data <- fread("data/clean_data/olink.V2/expr_data_clean.txt") %>% data.frame()
compare_data <- fread("data/clean_data/olink.V2/compare_group.txt") %>% data.frame()
group_data <- fread("data/clean_data/olink.V2/compare_sample.txt") %>% data.frame()

# Gene-name correction table: `raw` label -> corrected `new` name.
# NOTE(review): Gene_change is not referenced anywhere else in the visible
# part of this script — confirm whether the mapping is meant to be applied.
Gene_change <- data.frame(
  raw = c("HO-1", "IFN-gamma", "PD-L2", "IL12", "Gal-9", "MUC-16", "PD-L1", "CASP-8", "MIC-A/B", "CSF-1", "Gal-1", "CAIX", "TIE2", "TRAIL", "VEGFR-2", "MCP-1", "IL-1 alpha", "MCP-3", "LAP TGF-beta-1", "TWEAK", "MCP-2", "PDGF subunit B", "IL8", "MCP-4", "CD40-L"),
  new = c("HMOX1", "IFNG", "PDCD1LG2", "IL12B", "LGALS9", "MUC16", "CD274", "CASP8", "MICA", "CSF1", "LGALS1", "CA9", "TEK", "TNFSF10", "KDR", "CCL2", "IL1A", "CCL7", "TGFB1", "TNFSF12", "CCL8", "PDGFB", "CXCL8", "CCL13", "CD40LG")
)

# Derive Stage from the second "_"-separated field of SampleID:
# "Q" becomes "T1"; any other value is prefixed with "T".
sample_data <- mutate(sample_data, Stage = str_split(SampleID, "_", simplify = T)[, 2]) %>%
  mutate(Stage = ifelse(Stage == "Q", "T1", paste0("T", Stage)))


# Run every case-vs-control comparison listed in the global `compare_data`
# and write a volcano plot, a box plot of labelled proteins and (for the
# t-test only) an Excel result table per comparison.
#
# Arguments:
#   method    test name forwarded to GemhWilcoxDE() ("t.test" or "wilcox.test")
#   out_path  output directory (created if missing)
#
# Reads the globals `compare_data`, `group_data` and `expr_data`.
# Returns (invisibly) a data frame with the sample count per group for each
# comparison that was run.
run_diff_analysis <- function(method = "t.test", out_path) {
  dir.create(out_path, recursive = TRUE, showWarnings = FALSE)
  group_counts <- list()
  target_labels <- c("MMP12", "MUC-16", "IL15")
  # Group labels become part of file names; replace path separators so a label
  # containing "/" cannot point the output into a non-existent sub-directory.
  safe_name <- function(x) gsub("[/\\\\]", "_", x)

  for (i in unique(compare_data$Compare.Group.Num)) {
    sub_cmp <- filter(compare_data, Compare.Group.Num == i)
    sub_grp <- filter(group_data, Compare.Group.Num == i)

    for (j in seq_len(nrow(sub_cmp))) {
      case <- sub_cmp$Case[j]
      control <- sub_cmp$Control[j]

      # Samples belonging to either arm of this comparison
      sams <- filter(sub_grp, Group %in% c(case, control))
      if (nrow(sams) == 0) next

      # Record how many samples each arm contributes
      group_counts[[length(group_counts) + 1]] <-
        data.frame(table(sams$Group)) %>% mutate(Code = i, Cycle = j)

      # Long-format expression restricted to these samples, arms relabelled
      # Case/Control, rows whose Gene contains "Ctrl" removed
      tmp_expr <- filter(expr_data, DataID %in% sub_grp$DataID) %>%
        inner_join(dplyr::select(sams, SampleID = Analysis.Name, Group), by = "SampleID") %>%
        filter(Group %in% c(case, control)) %>%
        mutate(Group = ifelse(Group == case, "Case", "Control")) %>%
        filter(!grepl("Ctrl", Gene))

      # Differential analysis (helper sourced from GemhR.R)
      res <- GemhWilcoxDE(inputData = tmp_expr, method = method, group = c("Case", "Control"), dif_method = "diff")
      res <- mutate(res,
                    p.adjust = p.adjust(pvalue, method = "BH"),
                    sig = case_when(pvalue < 0.05 & log2FC > 0 ~ "sig_up",
                                    pvalue < 0.05 & log2FC < 0 ~ "sig_down",
                                    TRUE ~ "no_sig"))

      # Label the top 10 up and top 10 down proteins plus the fixed targets
      top_genes <- if (any(res$sig != "no_sig")) {
        arrange(filter(res, sig != "no_sig"), sig, desc(abs(log2FC))) %>%
          group_by(sig) %>%
          slice_head(n = 10) %>%
          pull(Gene)
      } else {
        NULL
      }
      label_genes <- unique(c(top_genes, intersect(target_labels, res$Gene)))

      file_stub <- paste0(out_path, "/", i, "_", safe_name(case), "_vs_", safe_name(control))

      # 1. Volcano plot
      p_vol <- GemhPlot_volplot(res[, c("Gene", "log2FC", "pvalue")], p_value = 0.05, FoldChange = 1, LableGene = label_genes) +
        labs(title = paste0(sub_cmp$Compare.Group.Name[j], ": ", case, " vs ", control))
      ggsave(paste0(file_stub, "_Volcano.pdf"), p_vol, width = 4, height = 4)

      # 2. Box plot of the labelled proteins
      if (length(label_genes) > 0) {
        sig_data <- filter(tmp_expr, Gene %in% label_genes) %>%
          left_join(res[, c("Gene", "pvalue")], by = "Gene") %>%
          mutate(sig_label = case_when(pvalue <= 0.01 ~ "**",
                                       pvalue <= 0.05 ~ "*",
                                       TRUE ~ paste0("p=", round(pvalue, 3))))

        # One annotation per gene instead of one per sample row (the latter
        # overprints the same text many times at the same position).
        label_data <- distinct(sig_data, Gene, sig_label)
        label_data$y <- max(sig_data$NPX, na.rm = TRUE) * 1.1

        p_box <- ggplot(sig_data, aes(x = Gene, y = NPX)) +
          geom_boxplot(aes(fill = Group)) +
          geom_text(data = label_data, aes(x = Gene, y = y, label = sig_label)) +
          scale_fill_manual(values = c("Control" = "#377EB8", "Case" = "#E41A1C")) +
          theme_classic() +
          theme(axis.text.x = element_text(angle = 90, hjust = 1))
        ggsave(paste0(file_stub, "_Boxplot.pdf"), p_box,
               width = min(15, 3 + length(label_genes) * 0.3), height = 4)
      }

      # 3. Result table (only for the t-test run, used for downstream statistics)
      if (method == "t.test") writexl::write_xlsx(res, paste0(file_stub, "_Result.xlsx"))
    }
  }

  invisible(bind_rows(group_counts))
}

# Run the differential analysis twice: once with the t-test, once with Wilcoxon
run_diff_analysis("t.test", "res/2.Compare_Analysis-T.test")
run_diff_analysis("wilcox.test", "res/2.Compare_Analysis-Wilcox")

# ROC curve of IL15 NPX for the first case/control pair of comparison group 4.
target_cmp <- filter(compare_data, Compare.Group.Num == 4) # assumes group 4 is the response comparison
if(nrow(target_cmp) > 0) {
  case <- target_cmp$Case[1]; ctrl <- target_cmp$Control[1]
  sams <- filter(group_data, Compare.Group.Num == 4 & Group %in% c(case, ctrl))
  
  # IL15 rows for those samples, with the group label attached.
  # The join has no `by`, so it uses every column name the two tables share.
  roc_data <- filter(expr_data, Gene == "IL15" & SampleID %in% sams$Analysis.Name) %>%
    inner_join(dplyr::select(sams, SampleID=Analysis.Name, Group))
  
  # levels = c(control, case): the control label is listed first
  roc_obj <- roc(roc_data$Group, roc_data$NPX, levels=c(ctrl, case))
  pdf("res/IL15_ROC_Curve.pdf", width=5, height=5)
  plot(roc_obj, main="IL15 ROC", col="blue", print.auc=TRUE)
  dev.off()
}

# Swimmer plot: one horizontal bar per patient (length = Time) with response,
# surgery and death events overlaid as points. Any failure is reported and
# skipped so the rest of the script keeps running.
tryCatch({
  swim_df <- read_excel("doc/吴昊主任晚期胃癌一线免化项目泳道图需求/泳道图所需数据.xlsx", sheet = 2) %>%
    filter(!is.na(pid)) %>% dplyr::select(1:8)
  colnames(swim_df)[2] <- "Time"
  # Order patients by Time; unique() keeps factor() from failing on a repeated pid
  swim_df <- arrange(swim_df, Time) %>% mutate(pid = factor(pid, levels = unique(pid)))

  # Long format for the event markers; rows without an event time are dropped
  swim_long <- pivot_longer(swim_df, cols = c(CR, PR, SD, PD, death, surgery), names_to = "Type", values_to = "PointTime") %>%
    drop_na() %>% mutate(Type = factor(Type, levels = c("CR", "PR", "SD", "PD", "surgery", "death")))

  cols <- c("CR" = "red", "PR" = "red", "SD" = "darkblue", "PD" = "darkblue", "death" = "black", "surgery" = "darkgreen")
  shps <- c("CR" = 2, "PR" = 17, "SD" = 1, "PD" = 19, "death" = 15, "surgery" = 3)

  p_swim <- ggplot(swim_df) +
    geom_bar(aes(x = pid, y = Time), stat = "identity", fill = "grey", width = 0.7) +
    geom_point(data = swim_long, aes(x = pid, y = PointTime, color = Type, shape = Type), size = 2) +
    scale_color_manual(values = cols) + scale_shape_manual(values = shps) +
    coord_flip() + theme_classic() + labs(x = "Patient ID", y = "Time (weeks)")

  ggsave("res/Swimmer_Plot.pdf", p_swim, width = 6, height = 4.5)
}, error = function(e) {
  # Append the real reason: the handler catches every error, not only a missing file
  message("泳道图数据未找到，跳过", " (", conditionMessage(e), ")")
})


# Generic Kaplan-Meier plotting helper.
#
# Arguments:
#   data        data frame
#   time_col    name (string) of the follow-up time column
#   status_col  name (string) of the event indicator column
#   group_col   name (string) of the grouping column
#   title_text  plot title
#   out_path    PDF file to write
#
# Rows with NA in any of the three columns are dropped. Returns NULL (and
# writes nothing) when fewer than two groups remain; otherwise writes the PDF
# and returns the ggsurvplot object.
run_km_plot <- function(data, time_col, status_col, group_col, title_text, out_path) {
  # Keep complete cases only
  keep <- !is.na(data[[group_col]]) & !is.na(data[[time_col]]) & !is.na(data[[status_col]])
  df <- data[keep, ]

  if (length(unique(df[[group_col]])) < 2) return(NULL)

  # Fit
  f <- as.formula(paste0("Surv(", time_col, ", ", status_col, ") ~ ", group_col))
  fit <- survfit(f, data = df)
  # Store the formula itself in the call instead of the local symbol `f`.
  # NOTE(review): ggsurvplot() may look the formula up again from fit$call when
  # computing the p-value; the symbol would not resolve outside this function.
  fit$call$formula <- f

  # Plot
  p <- ggsurvplot(
    fit, data = df,
    pval = TRUE, conf.int = FALSE, risk.table = TRUE,
    palette = "npg", ggtheme = theme_bw(),
    title = title_text, xlab = "Time (Months)"
  )

  # Save; on.exit() closes the device even when print() fails
  pdf(out_path, width = 5, height = 5, onefile = FALSE)
  dev_id <- dev.cur()
  on.exit(dev.off(dev_id), add = TRUE)
  print(p)
  p
}

# Generic univariate Cox helper: fits Surv(time, status) ~ v for each v.
#
# Arguments:
#   data        data frame (or tibble) holding all columns
#   time_col    name (string) of the follow-up time column
#   status_col  name (string) of the event indicator column
#   var_cols    character vector of predictor column names
#
# Returns a data frame with one row per successfully fitted variable
# (Variable, HR, Lower, Upper, P_val, taken from the first coefficient), or
# NULL when nothing could be fitted. Variables whose model fails are skipped
# with a message.
run_unicox <- function(data, time_col, status_col, var_cols) {
  # Back-quote column names so non-syntactic ones ("MUC-16", "IL-1 alpha")
  # are not parsed as arithmetic. Anything that is not a column of `data`
  # is passed through unchanged.
  quote_col <- function(x) if (x %in% names(data)) paste0("`", x, "`") else x
  lhs <- paste0("Surv(", quote_col(time_col), ", ", quote_col(status_col), ")")

  res_list <- list()
  for (v in var_cols) {
    # Assigning NULL leaves the list unchanged, so failed variables are omitted
    res_list[[v]] <- tryCatch({
      f <- as.formula(paste0(lhs, " ~ ", quote_col(v)))
      s <- summary(coxph(f, data = data))
      data.frame(
        Variable = v,
        HR = s$coefficients[1, 2],
        Lower = s$conf.int[1, 3],
        Upper = s$conf.int[1, 4],
        P_val = s$coefficients[1, 5]
      )
    }, error = function(e) {
      message("run_unicox: skipped '", v, "': ", conditionMessage(e))
      NULL
    })
  }
  do.call(rbind, res_list)
}

# 1. Clinical characteristics vs PFS --------------------------------------------
dir.create("res/3.PFS相关分析/1.临床指标关系", recursive = TRUE)

# Clinical variables to screen
clin_vars <- c("Age", "BMI", "ECOG", "M.Site.Num", "Sex", "Smoke.His", "Alcohol.His", 
               "Family.His", "Liver.M", "Lymph.node.M", "Peritoneal.M", "PD.L1")

# 1.1 KM curves (discrete variables, or continuous variables split in two)
for (var in clin_vars) {
  if(var %in% names(pdata)) {
    plot_data <- pdata
    # Numeric variables with more than 5 distinct values are split at the
    # median: strictly above -> "High", otherwise "Low"
    if(is.numeric(plot_data[[var]]) && length(unique(plot_data[[var]])) > 5) {
      med <- median(plot_data[[var]], na.rm=TRUE)
      plot_data[[var]] <- ifelse(plot_data[[var]] > med, "High", "Low")
    }
    
    run_km_plot(plot_data, "pfs.time", "pfs", var, 
                title_text = paste0("PFS by ", var),
                out_path = paste0("res/3.PFS相关分析/1.临床指标关系/", var, "_KM.pdf"))
  }
}


# 1.2 Univariate Cox for every clinical variable, written to Excel.
# NOTE(review): run_unicox() returns NULL when no model can be fitted, and
# there is no guard for that before write_xlsx().
cox_clin <- run_unicox(pdata, "pfs.time", "pfs", clin_vars)
writexl::write_xlsx(cox_clin, "res/3.PFS相关分析/1.临床指标关系/Clinical_Cox_Result.xlsx")

# 2. Per-stage protein expression vs PFS: for each stage, screen every protein
# with univariate Cox, then draw median-split KM curves for the hits.
stages <- c("T1", "T2", "T3")

for (st in stages) {
  out_dir <- paste0("res/3.PFS相关分析/", st, "_表达与PFS关系")
  dir.create(out_dir, recursive = TRUE)
  
  # Samples of this stage joined to their expression rows and to PFS
  st_data <- sample_data %>% filter(Stage == st) %>%
    inner_join(expr_data, by = c("PID", "Pname", "SampleID", "DataID")) %>%
    inner_join(dplyr::select(pdata, PID, pfs.time, pfs), by = "PID")
  
  # 2.1 Batch Cox screen over all proteins
  all_genes <- unique(st_data$Gene)
  # Reshape to wide (one column per protein) so each Cox model reads one column
  expr_wide <- st_data %>% dplyr::select(PID, pfs.time, pfs, Gene, NPX) %>%
    pivot_wider(names_from = Gene, values_from = NPX)
  
  # Run the Cox models
  cox_res <- run_unicox(expr_wide, "pfs.time", "pfs", all_genes)
  
  if(!is.null(cox_res)) {
    writexl::write_xlsx(cox_res, paste0(out_dir, "/All_Proteins_Cox.xlsx"))
    
    # 2.2 KM curves for significant proteins (P < 0.05) and the target proteins
    # (IL15, MUC16, MMP12; both spellings of MUC16 are looked up)
    sig_genes <- cox_res %>% filter(P_val < 0.05) %>% pull(Variable)
    plot_genes <- unique(c(sig_genes, intersect(all_genes, c("IL15", "MUC16", "MMP12", "MUC-16"))))
    
    for (g in plot_genes) {
      # Median split: strictly above the median -> "High"
      sub_dat <- st_data %>% filter(Gene == g) %>%
        mutate(Group = ifelse(NPX > median(NPX, na.rm=T), "High", "Low"))
      
      # "/" in a protein name is replaced so the file name stays valid
      run_km_plot(sub_dat, "pfs.time", "pfs", "Group",
                  title_text = paste0(g, " (", st, ") PFS"),
                  out_path = paste0(out_dir, "/", gsub("[/]", "_", g), "_KM.pdf"))
    }
  }
}


# 3. Change in expression between two stages vs PFS: for each stage pair,
# compute the per-protein difference (post - pre) in patients sampled at both
# stages, screen it with univariate Cox, and draw KM curves for the hits.
comparisons <- list(c("T1", "T2"), c("T1", "T3"))

for (comp in comparisons) {
  t_pre <- comp[1]; t_post <- comp[2]
  dir_name <- paste0(t_post, "-", t_pre, "_变化与PFS关系")
  out_path <- paste0("res/3.PFS相关分析/", dir_name)
  dir.create(out_path, recursive = TRUE)

  # Patients sampled at both stages. sample_data is only joined to expr_data
  # (which supplies Gene) further down, so group by Gene only if that column
  # is actually present; otherwise group by patient.
  pair_pids <- sample_data %>% filter(Stage %in% comp) %>%
    group_by(across(any_of(c("PID", "Gene")))) %>%
    summarise(n = n_distinct(Stage), .groups = "drop") %>%
    filter(n == 2) %>% pull(PID) %>% unique()

  if (length(pair_pids) == 0) {
    message("No paired samples for ", t_post, " vs ", t_pre, "; skipping.")
    next
  }

  # One row per patient and protein, one column per stage
  delta_data <- sample_data %>% filter(PID %in% pair_pids, Stage %in% comp) %>%
    inner_join(expr_data, by = c("PID", "Pname", "SampleID", "DataID")) %>%
    dplyr::select(PID, Stage, Gene, NPX) %>%
    pivot_wider(names_from = Stage, values_from = NPX)

  # Difference (post - pre)
  delta_data$Dif <- delta_data[[t_post]] - delta_data[[t_pre]]

  # Attach survival data
  surv_delta <- delta_data %>% inner_join(dplyr::select(pdata, PID, pfs.time, pfs), by = "PID")

  # 3.1 Cox on the difference, one model per protein. The table is long
  # (one row per patient x protein), so each protein is fitted on its own
  # subset; a single pooled fit over all proteins would be meaningless.
  res_list <- list()
  for (g in unique(surv_delta$Gene)) {
    sub_d <- surv_delta %>% filter(Gene == g)
    res <- run_unicox(sub_d, "pfs.time", "pfs", "Dif")
    if (!is.null(res)) {
      res$Gene <- g
      res_list[[g]] <- res
    }
  }
  final_delta_cox <- do.call(rbind, res_list)

  if (is.null(final_delta_cox)) {
    message("No Cox model could be fitted for ", t_post, " vs ", t_pre, "; skipping.")
    next
  }
  writexl::write_xlsx(final_delta_cox, paste0(out_path, "/Delta_Cox_Result.xlsx"))

  # 3.2 KM curves for significant and target proteins (split at the median difference)
  sig_genes <- final_delta_cox %>% filter(P_val < 0.05) %>% pull(Gene)
  target_list <- c("IL15", "MUC16", "MMP12", "MUC-16")
  plot_genes <- unique(c(sig_genes, intersect(unique(surv_delta$Gene), target_list)))

  for (g in plot_genes) {
    sub_dat <- surv_delta %>% filter(Gene == g) %>%
      mutate(Group = ifelse(Dif > median(Dif, na.rm = TRUE), "High Delta", "Low Delta"))

    run_km_plot(sub_dat, "pfs.time", "pfs", "Group",
                title_text = paste0(g, " Delta (", t_post, "-", t_pre, ")"),
                out_path = paste0(out_path, "/", gsub("[/]", "_", g), "_KM.pdf"))
  }
}

# 4. Optimal cutoff of the T1 -> T2 percent change for the proteins of
# interest (both spellings of MUC16 are listed so either one matches).
explore_genes <- c("MUC-16", "MMP12", "MUC16", "IL15")
out_cutoff <- "res/4.探索变化的最佳Cutoff"
dir.create(out_cutoff, recursive = TRUE)

# Percent change between T1 and T2 per patient and protein, with PFS attached
t1_t2_data <- sample_data %>% filter(Stage %in% c("T1", "T2")) %>%
  inner_join(expr_data, by = c("PID", "Pname", "SampleID", "DataID")) %>%
  filter(Gene %in% explore_genes) %>%
  dplyr::select(PID, Stage, Gene, NPX) %>%
  pivot_wider(names_from = Stage, values_from = NPX) %>%
  filter(!is.na(T1) & !is.na(T2)) %>%
  mutate(Pct_Change = (T2 - T1) / abs(T1) * 100) %>% # percent change
  # T1 == 0 gives Inf/NaN, which cannot be used for a cutpoint search
  filter(is.finite(Pct_Change)) %>%
  inner_join(dplyr::select(pdata, PID, pfs.time, pfs), by = "PID")

cutoff_res <- list()

for (g in unique(t1_t2_data$Gene)) {
  sub_dat <- t1_t2_data %>% filter(Gene == g)

  # Search for the best cutpoint; a failure for one protein (e.g. too few
  # patients) must not abort the remaining ones
  res.cut <- tryCatch(
    surv_cutpoint(sub_dat, time = "pfs.time", event = "pfs", variables = "Pct_Change", minprop = 0.3),
    error = function(e) {
      message("Cutpoint search failed for ", g, ": ", conditionMessage(e))
      NULL
    }
  )
  if (is.null(res.cut)) next
  best_cut <- res.cut$cutpoint$cutpoint

  # Group by the cutpoint and plot
  sub_dat$Group <- ifelse(sub_dat$Pct_Change > best_cut, "High Change", "Low Change")

  run_km_plot(sub_dat, "pfs.time", "pfs", "Group",
              title_text = paste0(g, " Optimal Cutoff: ", round(best_cut, 2), "%"),
              out_path = paste0(out_cutoff, "/", gsub("[/]", "_", g), "_Optimal_KM.pdf"))

  cutoff_res[[g]] <- data.frame(Gene = g, Best_Cutoff_Pct = best_cut)
}

# write_xlsx() cannot write NULL, which is what rbind of an empty list yields
if (length(cutoff_res) > 0) {
  writexl::write_xlsx(do.call(rbind, cutoff_res), paste0(out_cutoff, "/Cutoff_Summary.xlsx"))
} else {
  message("No optimal cutoff could be determined; Cutoff_Summary.xlsx not written.")
}

# 5. Trend plots for MUC16 / MMP12 (and the other explored proteins), by Response
dir.create("res/5.趋势图", recursive = TRUE)

# Draw the NPX trajectory of one protein across stages, coloured by Response:
# thin lines per patient, a thick line and points for the group mean, and a
# mean +/- SD band. Reads the globals expr_data, sample_data and pdata and
# saves the figure under res/5.趋势图/.
plot_trend <- function(gene_name) {
  response_palette <- c("PR/CR"="#E41A1C", "PD/SD"="#377EB8", "Responder"="#E41A1C", "Non-Responder"="#377EB8")
  join_keys <- c("PID", "Pname", "SampleID", "DataID")

  # Per-sample data for this protein with a known Response
  gene_rows <- filter(expr_data, Gene == gene_name)
  with_stage <- inner_join(gene_rows, sample_data, by = join_keys)
  with_resp <- left_join(with_stage, dplyr::select(pdata, PID, Response), by="PID")
  plot_dat <- mutate(
    filter(with_resp, !is.na(Response)),
    Stage = factor(Stage, levels = c("T1", "T2", "T3", "T4"))
  )

  # Group mean and SD per stage and response (drives the thick line and band)
  mean_dat <- summarise(
    group_by(plot_dat, Stage, Response),
    Mean_NPX = mean(NPX, na.rm=TRUE),
    SD = sd(NPX, na.rm=TRUE),
    .groups="drop"
  )

  # Layers are added in this order: individual lines, mean line, mean points, band
  individual_lines <- geom_line(data = plot_dat, aes(x = Stage, y = NPX, group = PID, color = Response), alpha = 0.3, size = 0.5)
  mean_line <- geom_line(data = mean_dat, aes(x = Stage, y = Mean_NPX, group = Response, color = Response), size = 1.5)
  mean_points <- geom_point(data = mean_dat, aes(x = Stage, y = Mean_NPX, color = Response), size = 3)
  sd_band <- geom_ribbon(data = mean_dat, aes(x = Stage, ymin = Mean_NPX - SD, ymax = Mean_NPX + SD, fill = Response, group = Response), alpha = 0.1)

  p <- ggplot() +
    individual_lines +
    mean_line +
    mean_points +
    sd_band +
    theme_bw() +
    scale_color_manual(values = response_palette) +
    scale_fill_manual(values = response_palette) +
    labs(title = paste0(gene_name, " Dynamic Trend by Response"), y = "NPX Value")

  ggsave(paste0("res/5.趋势图/", gene_name, "_Trend.pdf"), p, width = 5, height = 4)
}

# Plot every explored protein; a failure for one is reported and skipped
for (g in explore_genes) {
  tryCatch(
    plot_trend(g),
    error = function(e) message(paste("Skip trend plot for", g))
  )
}


# Revision analyses, part 1: association between baseline target-protein
# levels and clinical characteristics. Requires expr_data to be loaded.
if(!exists("expr_data")) stop("请先运行 Part 1 代码加载数据！")

# Output directory
dir_rev <- "res/审稿返修_补充分析"
dir.create(dir_rev, recursive = TRUE)

# Target proteins
target_genes <- c("IL15", "MUC-16", "MMP12", "IL2") 
# Note: MUC-16 may be called MUC16 in some data sets; keep both in mind

# Baseline (T1/Q) expression: SampleIDs matching "_1" or "_Q", one column per protein.
# NOTE(review): the pattern is unanchored, so an ID containing e.g. "_10"
# would also match — confirm against the SampleID format.
base_expr <- expr_data %>% 
  filter(grepl("_1|_Q", SampleID)) %>%
  filter(Gene %in% target_genes) %>%
  dplyr::select(PID, Gene, NPX) %>%
  pivot_wider(names_from = Gene, values_from = NPX)

# Merge with the clinical data
merged_df <- pdata %>%
  dplyr::select(PID, Sex, Age, BMI, ECOG, Smoke.His, Alcohol.His, 
                Family.His, Liver.M, Lymph.node.M, Peritoneal.M, PD.L1, pfs.time, pfs) %>%
  inner_join(base_expr, by = "PID")

clinical_vars <- c("Sex", "Age", "BMI", "ECOG", "PD.L1", "Smoke.His", 
                   "Alcohol.His", "Family.His", "Liver.M", "Lymph.node.M", "Peritoneal.M")

result_table <- data.frame()

for (var in clinical_vars) {
  if(!var %in% names(merged_df)) next
  
  vals <- merged_df[[var]]
  # Treated as continuous when numeric with more than 5 distinct non-NA values
  is_cont <- is.numeric(vals) && length(unique(na.omit(vals))) > 5
  
  # A. Descriptive statistics: median (Q1 - Q3) for continuous variables,
  #    count (percent) per level otherwise
  if (is_cont) {
    stats_str <- sprintf("%.2f (%.2f - %.2f)", median(vals, na.rm=T), 
                         quantile(vals, 0.25, na.rm=T), quantile(vals, 0.75, na.rm=T))
    row_base <- data.frame(Variable = var, Group = "Median (IQR)", Stats = stats_str)
  } else {
    tbl <- table(vals)
    props <- prop.table(tbl) * 100
    row_base <- data.frame(Variable = rep(var, length(tbl)), 
                           Group = names(tbl), 
                           Stats = sprintf("%d (%.1f%%)", tbl, props))
  }
  
  # B. P value of the association with each target protein
  for (gene in intersect(names(merged_df), target_genes)) {
    gene_vec <- merged_df[[gene]]
    p_val <- NA
    
    if (is_cont) {
      # Spearman correlation
      p_val <- cor.test(vals, gene_vec, method = "spearman")$p.value
    } else {
      # Group comparison: Wilcoxon for two levels, Kruskal-Wallis for more;
      # p_val stays NA when there are fewer than two levels
      groups <- unique(na.omit(vals))
      if (length(groups) == 2) {
        p_val <- wilcox.test(gene_vec ~ vals, data = merged_df)$p.value
      } else if (length(groups) > 2) {
        p_val <- kruskal.test(gene_vec ~ vals, data = merged_df)$p.value
      }
    }
    
    # Store the P value (for categorical variables it is repeated on every level row)
    row_base[[paste0("P_", gene)]] <- p_val
  }
  
  result_table <- bind_rows(result_table, row_base)
}

# Format the P values and save; columns are picked by the "P_" name pattern
format_p <- function(x) ifelse(x < 0.001, "<0.001", sprintf("%.3f", x))
p_cols <- grep("P_", names(result_table))
result_table[p_cols] <- lapply(result_table[p_cols], format_p)

writexl::write_xlsx(result_table, paste0(dir_rev, "/Clinical_Association_Table.xlsx"))


# Revision analyses, part 2: Spearman correlation between the target proteins
# and inflammation / infection markers. Skipped when the marker file is absent.
inf_file <- "CA125 炎症感染指标.xlsx"


if(file.exists(inf_file)) {
  # 2.1 Read and clean the marker sheet
  inf_raw <- read_excel(inf_file, sheet = 1) %>% data.frame(stringsAsFactors = F)
  inf_raw <- inf_raw[, 1:17] # keep the first 17 columns
  
  # Coerce every column except the first to numeric
  for (i in 2:ncol(inf_raw)) inf_raw[,i] <- as.numeric(inf_raw[,i])
  
  # Long format; Stage is the "T<digits>" part of the column name and `var`
  # is the text before the first "."
  inf_long <- pivot_longer(inf_raw, -c(`序号`)) %>%
    mutate(Stage = str_extract(name, "T\\d+"),
           var = str_split(name, "[.]", simplify = T)[,1]) %>%
    filter(!is.na(value)) %>%
    dplyr::rename(ID = `序号`)
  
  # Map the sheet's ID to the patient name used in SampleID
  xuyao_id <- read_excel("data/【钉钉原始文件】2024.10.28-生信分析.xlsx", sheet = 1)
  id_map <- xuyao_id %>% 
    dplyr::select(ID = `序号`, Sample.ID) %>%
    filter(!is.na(ID)) %>%
    mutate(Pname = str_split(Sample.ID, "_", simplify = T)[,1], ID = as.character(ID))
  
  # Rebuild SampleID as "<Pname>_<stage number + 1>"; QZA at T0 is mapped to "QZA_Q".
  # NOTE(review): if id_map holds several rows per ID, this left_join repeats
  # each measurement and pivot_wider below would produce list columns —
  # confirm that IDs are unique in the mapping sheet.
  inf_data <- inf_long %>%
    mutate(ID = as.character(ID)) %>%
    left_join(dplyr::select(id_map, ID, Pname), by = "ID") %>%
    mutate(SampleID = case_when(
      Pname == "QZA" & Stage == "T0" ~ "QZA_Q",
      T ~ paste0(Pname, "_", as.numeric(str_extract(Stage, "\\d+")) + 1) # T0->_1, T1->_2 logic
    )) %>%
    filter(SampleID %in% expr_data$SampleID) %>%
    dplyr::select(SampleID, var, value)
  
  # 2.2 Association analysis (Spearman correlation heatmap)
  # Data layout: rows = samples, columns = variables
  inf_wide <- inf_data %>% pivot_wider(names_from = var, values_from = value)
  olink_wide <- expr_data %>% 
    filter(Gene %in% target_genes) %>%
    dplyr::select(SampleID, Gene, NPX) %>%
    pivot_wider(names_from = Gene, values_from = NPX)
  
  merged_inf <- inner_join(olink_wide, inf_wide, by = "SampleID")
  
  inf_vars <- c("WBC", "NE", "PCT", "CA125")
  cor_res <- data.frame()
  
  # One Spearman test per (target protein, marker) pair present in the data
  for (gene in intersect(names(merged_inf), target_genes)) {
    for (inf in intersect(names(merged_inf), inf_vars)) {
      test <- cor.test(merged_inf[[gene]], merged_inf[[inf]], method = "spearman")
      cor_res <- rbind(cor_res, data.frame(
        Target_Gene = gene,
        Confounder = inf,
        Correlation_r = test$estimate,
        P_value = test$p.value
      ))
    }
  }
  
  # 2.3 Plot: tile colour = correlation, label = rounded r plus "*" when P < 0.05
  cor_res <- cor_res %>%
    mutate(Label = paste0(round(Correlation_r, 2), "\n", 
                          ifelse(P_value < 0.05, "*", "")),
           P_val_cat = cut(P_value, breaks = c(0, 0.001, 0.01, 0.05, 1), labels = c("***", "**", "*", "ns")))
  
  p_heat <- ggplot(cor_res, aes(x = Confounder, y = Target_Gene, fill = Correlation_r)) +
    geom_tile(color = "white") +
    scale_fill_gradient2(low = "blue", high = "red", mid = "white", midpoint = 0, limit = c(-1, 1)) +
    geom_text(aes(label = Label), size = 3) +
    theme_minimal() +
    labs(title = "Correlation: Olink Proteins vs Inflammatory Markers", x = "", y = "")
  
  ggsave(paste0(dir_rev, "/Inflammation_Correlation_Heatmap.pdf"), p_heat, width = 5, height = 4)
  
} else {
  message("警告：未找到炎症指标文件 'CA125 炎症感染指标.xlsx'，跳过此模块。")
}

# Revision analyses, part 3: subgroup (forest-plot) data. For each target
# protein, patients are split at the median into Low/High and the High-vs-Low
# Cox hazard ratio is computed overall and within each level of the key
# subgroup variables, together with a Group x subgroup interaction P value.
forest_df <- data.frame()
subgroups <- c("Sex", "ECOG", "Liver.M", "Lymph.node.M", "PD.L1") # key subgroups

for (gene in intersect(names(merged_df), target_genes)) {
  # Median split (strictly above the median -> High); Low is the reference
  med <- median(merged_df[[gene]], na.rm = TRUE)
  merged_df$Group <- factor(ifelse(merged_df[[gene]] > med, "High", "Low"),
                            levels = c("Low", "High"))

  # 1. Whole population
  fit_all <- coxph(Surv(pfs.time, pfs) ~ Group, data = merged_df)
  s_all <- summary(fit_all)
  forest_df <- rbind(forest_df, data.frame(
    Gene = gene, Subgroup = "Overall", Level = "All",
    N = s_all$n, HR = s_all$coefficients[1, 2], P = s_all$coefficients[1, 5], Interact_P = NA
  ))

  # 2. Subgroups
  for (var in subgroups) {
    if (!var %in% names(merged_df)) next

    # Interaction test: joint Wald test over all Group:var terms. With a single
    # term this equals the coefficient's own Wald P value; with several terms
    # (multi-level subgroup) it still yields exactly one P value per subgroup.
    f_int <- as.formula(paste0("Surv(pfs.time, pfs) ~ Group * ", var))
    fit_int <- tryCatch(coxph(f_int, data = merged_df), error = function(e) NULL)
    p_int <- NA_real_
    if (!is.null(fit_int)) {
      b <- coef(fit_int)
      idx <- which(grepl(":", names(b), fixed = TRUE) & !is.na(b))
      if (length(idx) > 0) {
        p_int <- tryCatch({
          v <- vcov(fit_int)[idx, idx, drop = FALSE]
          wald <- drop(t(b[idx]) %*% solve(v) %*% b[idx])
          pchisq(wald, df = length(idx), lower.tail = FALSE)
        }, error = function(e) NA_real_)
      }
    }

    # Within-subgroup analysis
    for (lev in unique(na.omit(merged_df[[var]]))) {
      # which() drops rows where the subgroup variable is NA; plain logical
      # indexing would add all-NA rows and inflate the row count
      sub_dat <- merged_df[which(merged_df[[var]] == lev), ]
      if (nrow(sub_dat) < 5) next

      # A subgroup that cannot be fitted is skipped instead of aborting the script
      s_sub <- tryCatch(
        summary(coxph(Surv(pfs.time, pfs) ~ Group, data = sub_dat)),
        error = function(e) NULL
      )
      if (is.null(s_sub)) next

      forest_df <- rbind(forest_df, data.frame(
        Gene = gene, Subgroup = var, Level = as.character(lev),
        N = s_sub$n, HR = s_sub$coefficients[1, 2], P = s_sub$coefficients[1, 5], Interact_P = p_int
      ))
    }
  }
}

writexl::write_xlsx(forest_df, paste0(dir_rev, "/Subgroup_Forest_Data.xlsx"))

# Revision analyses, part 4: pairwise correlation of the target genes in the
# TCGA-STAD TPM table.
tcga_file <- "TCGA-STAD_tpm-clean.txt"
# Note: the file may be missing locally, so the whole module runs inside
# tryCatch and is skipped with a message on any error

tryCatch({
  if(!file.exists(tcga_file)) stop("TCGA文件不存在")
  
  tcga_tpm <- fread(tcga_file)
  
  # Keep the target genes and tumour samples
  # Assumes SampleType <= 10 marks tumour samples
  tcga_sub <- tcga_tpm %>%
    filter(Gene %in% c("IL15", "IL2", "MUC16", "MMP12")) %>%
    filter(SampleType <= 10) %>%
    dplyr::select(SampleId, Gene, tpm) %>%
    pivot_wider(names_from = Gene, values_from = tpm)
  
  # 4.1 Correlation analysis (IL15 vs IL2, IL15 vs MUC16, etc.)
  pairs <- list(c("IL15", "IL2"), c("IL15", "MUC16"), c("IL15", "MMP12"), c("MMP12", "MUC16"))
  
  # One scatter plot (regression line, Pearson coefficient) per pair whose
  # two genes are both present as columns
  plot_list <- list()
  for (p in pairs) {
    g1 <- p[1]; g2 <- p[2]
    if(all(c(g1, g2) %in% names(tcga_sub))) {
      sp <- ggscatter(tcga_sub, x = g1, y = g2, 
                      add = "reg.line", conf.int = TRUE, 
                      cor.coef = TRUE, cor.method = "pearson",
                      xlab = paste(g1, "TPM"), ylab = paste(g2, "TPM"),
                      title = "TCGA-STAD Cohort")
      plot_list[[paste(g1, g2, sep="_")]] <- sp
    }
  }
  
  # Combine the panels into a two-column layout and save
  p_tcga <- wrap_plots(plot_list, ncol = 2)
  ggsave(paste0(dir_rev, "/TCGA_Correlation_Validation.pdf"), p_tcga, width = 8, height = 8)
  
  message("TCGA 验证完成。")
  
}, error = function(e) {
  message(paste("跳过 TCGA 分析：", e$message))
  message("提示：请确认 'TCGA-STAD_tpm-clean.txt' 路径是否正确。")
})

# 5.1 Baseline IL15 vs Response
# Patients with a missing Response are excluded instead of being counted as
# Non-Responders (NA %in% c(...) is FALSE, which would silently misclassify them).
tmp_resp <- pdata %>%
  dplyr::select(PID, Response) %>%
  filter(!is.na(Response)) %>%
  mutate(Group = ifelse(Response %in% c("PR/CR", "PR", "CR"), "Responder", "Non-Responder"))

# base_expr is wide (one column per protein), so IL15 is selected as a column;
# there is no Gene column left to filter on.
il15_dat <- base_expr %>%
  dplyr::select(PID, any_of("IL15")) %>%
  inner_join(tmp_resp, by = "PID")

if ("IL15" %in% names(il15_dat) && nrow(il15_dat) > 0) {
  # Test every protein so that the FDR is computed against the full background
  all_base <- expr_data %>% filter(grepl("_1|_Q", SampleID)) %>%
    inner_join(tmp_resp, by = "PID")

  # Wilcoxon P value for one protein; NA when either group is empty or the
  # test itself fails, so one bad protein does not abort the whole table
  safe_wilcox_p <- function(value, group) {
    ok <- !is.na(value) & !is.na(group)
    v <- value[ok]
    g <- factor(group[ok])
    if (nlevels(g) != 2) return(NA_real_)
    tryCatch(wilcox.test(v ~ g)$p.value, error = function(e) NA_real_)
  }

  diff_res <- all_base %>%
    group_by(Gene) %>%
    summarise(P_val = safe_wilcox_p(NPX, Group), .groups = "drop") %>%
    mutate(FDR = p.adjust(P_val, method = "BH")) %>%
    arrange(P_val)

  writexl::write_xlsx(diff_res, paste0(dir_rev, "/Baseline_Response_FDR.xlsx"))
  print(head(diff_res))
}

# Cox model of each baseline target protein adjusted for chemotherapy.
# Chemotherapy is taken from pdata: merged_df was built from an explicit column
# list that does not contain it. any_of() lets the check below skip the
# analysis cleanly when the column is absent.
cox_df <- pdata %>%
  dplyr::select(PID, pfs.time, pfs, any_of("Chemotherapy")) %>%
  inner_join(base_expr, by = "PID")

if ("Chemotherapy" %in% names(cox_df)) {
  res_multi <- data.frame()

  for (g in intersect(names(cox_df), target_genes)) {
    # Formula: Surv ~ protein + Chemotherapy. The protein name is back-quoted
    # because names such as "MUC-16" are not syntactic.
    f <- as.formula(paste0("Surv(pfs.time, pfs) ~ `", g, "` + Chemotherapy"))
    fit <- coxph(f, data = cox_df)
    s <- summary(fit)

    # Row 1 is the protein; row 2 is the first Chemotherapy coefficient
    res_multi <- rbind(res_multi, data.frame(
      Gene = g,
      Gene_HR = s$coefficients[1, 2],
      Gene_P = s$coefficients[1, 5],
      Chemo_P = s$coefficients[2, 5]
    ))
  }
  writexl::write_xlsx(res_multi, paste0(dir_rev, "/Multivariate_Cox_Chemo.xlsx"))
} else {
  message("No 'Chemotherapy' column in pdata; skipping the chemotherapy-adjusted Cox analysis.")
}


# Model building, step 1: assemble the modelling table from baseline IL15 and
# the T2 - T1 change of MUC16 and MMP12. Requires pdata and expr_data.
if(!exists("pdata") | !exists("expr_data")) stop("请先运行 Part 1 代码加载数据！")

# Output directory
dir_model <- "res/模型构建与评估"
dir.create(dir_model, recursive = TRUE)

# 1.1 Baseline IL15 (SampleIDs matching "_1" or "_Q")
base_il15 <- expr_data %>%
  filter(Gene == "IL15" & grepl("_1|_Q", SampleID)) %>%
  dplyr::select(PID, Baseline_IL15 = NPX)

# 1.2 Delta MUC16 and Delta MMP12 (T2 - T1)
# Note: the protein name may appear as either "MUC-16" or "MUC16"
target_deltas <- c("MUC-16", "MMP12", "MUC16")

delta_df <- sample_data %>% 
  filter(Stage %in% c("T1", "T2")) %>%
  inner_join(expr_data, by = c("PID", "Pname", "SampleID", "DataID")) %>%
  filter(Gene %in% target_deltas) %>%
  dplyr::select(PID, Stage, Gene, NPX) %>%
  # Unify the protein name
  mutate(Gene = ifelse(Gene == "MUC16", "MUC-16", Gene)) %>%
  pivot_wider(names_from = Stage, values_from = NPX) %>%
  filter(!is.na(T1) & !is.na(T2)) %>%
  mutate(Delta = T2 - T1) %>%
  dplyr::select(PID, Gene, Delta) %>%
  pivot_wider(names_from = Gene, values_from = Delta, names_prefix = "Delta_")

# 1.3 Merge clinical outcome columns with the features
model_data <- pdata %>%
  dplyr::select(PID, pfs.time, pfs, Response) %>%
  inner_join(base_il15, by = "PID") %>%
  inner_join(delta_df, by = "PID") %>%
  mutate(
    # Binary response: PR/CR is Responder (1), anything else Non-Responder (0)
    Response_Binary = ifelse(Response %in% c("PR/CR", "PR", "CR"), 1, 0),
    Response_Factor = factor(ifelse(Response_Binary == 1, "Responder", "Non-Responder"), 
                             levels = c("Non-Responder", "Responder"))
  ) %>%
  na.omit() # drop rows with any missing value (including a missing Response)

# Rename the column so it is a syntactic name usable in formulas
if("Delta_MUC-16" %in% names(model_data)) {
  names(model_data)[names(model_data) == "Delta_MUC-16"] <- "Delta_MUC16"
}

message(paste("建模纳入样本量:", nrow(model_data)))
writexl::write_xlsx(model_data, paste0(dir_model, "/Modeling_Data_Clean.xlsx"))

# Model building, step 2: fit a Cox model for PFS and a logistic model for
# response on the three features, then save summaries and formulas to a file.
features <- c("Baseline_IL15", "Delta_MUC16", "Delta_MMP12")

# 2.1 Cox model for PFS
f_cox <- as.formula(paste("Surv(pfs.time, pfs) ~", paste(features, collapse = " + ")))
fit_cox <- coxph(f_cox, data = model_data)

# Cox coefficients rendered as "coef * name + ..." (rounded to 4 digits)
cox_coef <- coef(fit_cox)
cox_formula <- paste(paste(round(cox_coef, 4), "*", names(cox_coef)), collapse = " + ")
message("PFS Risk Score Formula: ", cox_formula)

# Risk score = linear predictor of the fitted Cox model
model_data$PFS_Risk_Score <- predict(fit_cox, type = "lp")

# 2.2 Logistic model for response
f_logit <- as.formula(paste("Response_Binary ~", paste(features, collapse = " + ")))
fit_logit <- glm(f_logit, data = model_data, family = binomial)

# Logistic formula: intercept + "coef * name" terms (rounded to 4 digits)
logit_coef <- coef(fit_logit)
logit_formula <- paste(round(logit_coef[1], 4), "+", 
                       paste(round(logit_coef[-1], 4), "*", names(logit_coef[-1]), collapse = " + "))
message("Response Logit Formula: ", logit_formula)

# Predicted response probability
model_data$Response_Prob <- predict(fit_logit, type = "response")

# Save the results: console output is redirected into the text file until the
# closing sink().
# NOTE(review): if anything between the two sink() calls fails, output stays
# redirected — there is no cleanup on error.
sink(paste0(dir_model, "/Model_Formulas.txt"))
cat("=== PFS Cox Model ===\n")
print(summary(fit_cox))
cat("\nRisk Score Formula:\n", cox_formula, "\n\n")

cat("=== Response Logistic Model ===\n")
print(summary(fit_logit))
cat("\nLogit Formula:\n", logit_formula, "\n")
sink()

# 3.1 Feature matrix and binary outcome
x <- as.matrix(model_data[, features])
y <- model_data$Response_Binary

set.seed(123)
# 3.2 Lasso model (5 folds instead of the default because the sample is small)
cv_fit <- cv.glmnet(x, y, family = "binomial", alpha = 1, type.measure = "auc", nfolds = 5)
best_lambda <- cv_fit$lambda.min
final_lasso <- glmnet(x, y, family = "binomial", alpha = 1, lambda = best_lambda)

# In-sample predicted probabilities
model_data$Lasso_Prob <- as.numeric(predict(final_lasso, newx = x, type = "response"))

# 3.3 ROC curve; the device is closed even if plotting fails
roc_lasso <- roc(model_data$Response_Binary, model_data$Lasso_Prob, quiet = TRUE)
pdf(paste0(dir_model, "/Lasso_ROC.pdf"), width = 5, height = 5)
tryCatch(
  plot(roc_lasso, print.auc = TRUE, main = "Lasso ROC", col = "#2E5A88", lwd = 3),
  finally = dev.off()
)

# 3.4 Calibration plot
# Manual binning into quartiles of the predicted probability (small sample).
# When the predictions are tied (e.g. the lasso shrinks every coefficient to
# zero) the quantiles coincide and cut() rejects duplicated breaks, so the
# breaks are de-duplicated and the plot is skipped if fewer than two remain.
cal_breaks <- unique(quantile(model_data$Lasso_Prob, probs = seq(0, 1, 0.25), na.rm = TRUE))

if (length(cal_breaks) >= 2) {
  model_data$Cal_Group <- cut(model_data$Lasso_Prob, breaks = cal_breaks, include.lowest = TRUE)
  cal_summary <- model_data %>%
    group_by(Cal_Group) %>%
    summarise(Predicted = mean(Lasso_Prob), Observed = mean(Response_Binary), .groups = "drop")

  p_cal <- ggplot(cal_summary, aes(x = Predicted, y = Observed)) +
    geom_line(color = "#2E5A88", size = 1) +
    geom_point(color = "#D62728", size = 3) +
    geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "grey50") +
    xlim(0, 1) + ylim(0, 1) +
    labs(title = "Lasso Calibration Plot", x = "Predicted Probability", y = "Observed Fraction") +
    theme_bw()

  ggsave(paste0(dir_model, "/Lasso_Calibration.pdf"), p_cal, width = 4, height = 4)
} else {
  model_data$Cal_Group <- factor(NA)
  message("Lasso predictions are constant; calibration plot skipped.")
}

# 3.5 Decision curve analysis (dcurves package)
tryCatch({
  dca_df <- model_data
  dca_res <- dcurves::dca(Response_Binary ~ Lasso_Prob, data = dca_df, thresholds = seq(0, 1, by = 0.01))

  pdf(paste0(dir_model, "/Lasso_DCA.pdf"), width = 5, height = 4)
  # The plot object must be printed explicitly: inside a braced block the
  # value of plot() is not auto-printed, which would leave the PDF empty.
  tryCatch(print(plot(dca_res, smooth = TRUE)), finally = dev.off())
}, error = function(e) message("DCA分析报错 (可能是样本量过小或单一类别): ", e$message))

# 4.1 KM by risk-score group (split at the median; Low Risk is the reference)
model_data$Risk_Group <- ifelse(model_data$PFS_Risk_Score > median(model_data$PFS_Risk_Score), "High Risk", "Low Risk")
model_data$Risk_Group <- factor(model_data$Risk_Group, levels = c("Low Risk", "High Risk"))

# Text survival report
report <- get_custom_surv_report(model_data, "pfs.time", "pfs", "Risk_Group", "Low Risk")
writeLines(report, paste0(dir_model, "/Risk_Score_KM_Report.txt"))

# Plot
fit_risk <- survfit(Surv(pfs.time, pfs) ~ Risk_Group, data = model_data)
p_km_risk <- ggsurvplot(fit_risk, data = model_data, pval = TRUE, risk.table = TRUE,
                        palette = c("#2E9FDF", "#E7B800"),
                        title = "PFS by Composite Risk Score",
                        xlab = "Time (Months)")

# The device is closed even if printing fails
pdf(paste0(dir_model, "/Composite_Risk_Score_KM.pdf"), width = 6, height = 6, onefile = FALSE)
tryCatch(print(p_km_risk), finally = dev.off())

# 4.2 Multivariate Cox forest plot (adjusted for clinical features)
# Candidate covariates; only those present in pdata are used
covariates <- c("Sex", "Age", "ECOG", "Liver.M", "PD.L1") # example covariates, adjust to the actual data
valid_covars <- intersect(covariates, names(pdata))

# Attach the covariates to model_data
forest_data <- model_data %>%
  left_join(dplyr::select(pdata, PID, all_of(valid_covars)), by = "PID")

# Formula: Surv ~ Risk_Score + covariates. Collapsing all terms in one go keeps
# the formula valid when no covariate is available (no dangling "+").
f_multi <- as.formula(paste(
  "Surv(pfs.time, pfs) ~",
  paste(c("PFS_Risk_Score", valid_covars), collapse = " + ")
))

tryCatch({
  fit_multi <- coxph(f_multi, data = forest_data)

  # Forest plot
  p_forest <- ggforest(fit_multi, data = forest_data,
                       main = "Multivariate Cox: Risk Score & Clinical Features",
                       fontsize = 0.8, noDigits = 2)

  pdf(paste0(dir_model, "/Multivariate_Cox_Forest.pdf"), width = 8, height = 6)
  tryCatch(print(p_forest), finally = dev.off())
}, error = function(e) message("森林图绘制失败 (可能是变量缺失或模型不收敛): ", e$message))


