## clearing the environment
## Removes every object listed by ls() from the current environment (names
## starting with "." are not listed by default, so they survive) and then runs
## the garbage collector. Attached packages and options are left untouched;
## starting a fresh R session is the only truly clean start.
rm(list = ls())  
gc()    

library(rstan)
library(loo)
library(tidyverse)
library(bayesplot)
####################################################################################

## model specific details that need to be changed for every run
modelName <- "Incumbent_DeNovo"

## Setting all the directories for operations; every path is anchored at the
## working directory the script is started from.
projectDir <- getwd()
scriptDir <- file.path(projectDir, "scripts")
modelDir <- file.path(projectDir, "stan_models")
dataDir <- file.path(projectDir, "datafiles")
toolsDir <- file.path(scriptDir, "tools")
outputDir <- file.path(projectDir, "output_fit")
saveDir <- file.path(projectDir, "save_csv", "StanResultsMar2025")
LooDir <- file.path(projectDir, "loo_fit")

## Make sure the directories this script writes into exist: ggsave() and
## write_rds() can fail on a missing directory instead of creating it.
invisible(lapply(c(file.path(outputDir, modelName), LooDir),
                 dir.create, recursive = TRUE, showWarnings = FALSE))

# loading the script that contains functions for plotting stan parameters
source(file.path(toolsDir, "stanTools.R"))

# Combine the Stan CSV outputs of the separate runs (one per node) into a
# single stanfit object. Files are named <modelName>_sa<k>.csv and are merged
# in the order sa5, sa4, ..., sa1.
n_runs <- 5L
stanfit_list <- lapply(rev(seq_len(n_runs)), function(run_id) {
  read_stan_csv(file.path(saveDir, paste0(modelName, "_sa", run_id, ".csv")))
})

fit <- sflist2stanfit(stanfit_list)

# finding the parameters used in the model:
# every entry of fit@model_pars that precedes "global_params" is treated as a
# model parameter to summarise and plot (the marker name must match the Stan
# program that produced this fit).
global_params_idx <- match("global_params", fit@model_pars)
if (is.na(global_params_idx) || global_params_idx < 2L) {
  stop("'global_params' not found after at least one parameter in ",
       "fit@model_pars; cannot determine which parameters to plot",
       call. = FALSE)
}
num_pars <- global_params_idx - 1L
parametersToPlot <- fit@model_pars[seq_len(num_pars)]

# number of post-burn-in samples that are used for plotting 
nPost <- nrow(fit)

# posterior draws of the selected parameters, saved for later use
ymta <- as.data.frame(fit, pars = parametersToPlot) 
write.csv(ymta, file = paste0("save_csv/parsMat_", modelName, ".csv"))

################################################################################################
################################################################################################


################################################################################################
## posterior predictive distributions

## plot_function() is defined in the helper script sourced below and is called
## for its side effects. NOTE(review): presumably it creates the *_pred,
## *_data and *_sigma_obs objects (and ts_pred) used by the plots that follow
## -- confirm in the helper script.
source(file.path(scriptDir, "stan_extract_forplotting.R"))
plot_function(fitobject = fit)

## legend labels for the host-age-at-BMT bins (weeks)
legn_labels <- c("6-8", "8-10", "10-12", "12-25")

## Total naive Treg counts: model median and band (ribbon + line),
## observation-level intervals (error bars) and the data points. The same
## figure is drawn twice, differing only in the x scale and error-bar width.
counts_plot <- function(errorbar_width, x_scale) {
  ggplot() +
    geom_ribbon(data = Counts_pred,
                aes(x = timeseries, ymin = lb, ymax = ub, fill = ageBMT_bin),
                alpha = 0.2) +
    geom_line(data = Counts_pred,
              aes(x = timeseries, y = median, color = ageBMT_bin),
              linewidth = 1) +
    geom_errorbar(data = Counts_naive_sigma_obs,
                  aes(x = timeseries, ymin = lb, ymax = ub, col = ageBMT_bin),
                  alpha = 0.5, width = errorbar_width) +
    geom_point(data = counts_data,
               aes(x = age.at.S1K, y = naive, color = ageBMT_bin),
               size = 2) +
    labs(title = "Total counts of naive Tregs", y = NULL, x = "Host age (days)") +
    scale_color_discrete(name = "Host age at \n BMT (Wks)", labels = legn_labels) +
    x_scale +
    scale_y_continuous(limits = c(5e4, 2e7), trans = "log10",
                       breaks = c(1e4, 1e5, 1e6, 1e7, 1e8),
                       minor_breaks = log10minorbreaks,
                       labels = fancy_scientific) +
    guides(fill = "none") +
    myTheme
}

## linear host-age axis
counts_plot(
  errorbar_width = 5,
  x_scale = scale_x_continuous(limits = c(1, 450),
                               breaks = c(0, 100, 200, 300, 400, 500))
)

ggsave(filename = file.path(outputDir, modelName, "Counts_XonLinear.pdf"), last_plot(),
       device = "pdf", height = 5, width = 8)

## log10 host-age axis
counts_plot(
  errorbar_width = 0.01,
  x_scale = scale_x_continuous(limits = c(60, 460), trans = "log10",
                               breaks = c(75, 150, 300, 450))
)

ggsave(filename = file.path(outputDir, modelName, "Counts_Xonlogscale.pdf"), last_plot(),
       device = "pdf", height = 5, width = 8)


# normalised donor fractions

## facet strip labels keyed by the ageBMT_bin levels
fac_labels <- c(`agebin1` = "6-8 weeks", `agebin2` = "8-10 weeks",
                `agebin3` = "10-12 weeks", `agebin4` = "12-25 weeks")

## Normalised chimerism of naive Tregs: model median and band per
## host-age-at-BMT bin with the data overlaid.
##   multipanel = FALSE : all bins in a single panel
##   multipanel = TRUE  : one facet per bin, with observation-level error bars
##                        and no colour legend
nfd_plot <- function(multipanel) {
  nfd <- ggplot() +
    geom_ribbon(data = Nfd_pred,
                aes(x = timeseries, ymin = lb, ymax = ub, fill = ageBMT_bin),
                alpha = 0.15) +
    geom_line(data = Nfd_pred,
              aes(x = timeseries, y = median, color = ageBMT_bin),
              linewidth = 1) +
    geom_point(data = Nfd_data,
               aes(x = age.at.S1K, y = naive, color = ageBMT_bin),
               size = 2)
  if (multipanel) {
    nfd <- nfd +
      geom_errorbar(data = Nfd_naive_sigma_obs,
                    aes(x = timeseries, ymin = lb, ymax = ub, col = ageBMT_bin),
                    alpha = 0.5, width = 5)
  }
  nfd <- nfd +
    labs(x = "Host age (days)", y = NULL,
         title = "Normalised Chimerism in naive Tregs") +
    scale_color_discrete(name = "Host age at \n BMT (Wks)", labels = legn_labels) +
    scale_x_continuous(limits = c(1, 450), breaks = c(0, 100, 200, 300, 400, 500)) +
    scale_y_continuous(limits = c(0, 1.1), breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1.0))
  if (multipanel) {
    nfd +
      facet_wrap(~ ageBMT_bin, scales = "free", labeller = as_labeller(fac_labels)) +
      guides(fill = "none", col = "none") +
      myTheme
  } else {
    nfd + guides(fill = "none") + myTheme
  }
}

nfd_plot(multipanel = FALSE)

ggsave(filename = file.path(outputDir, modelName, "Nfd_singlepanel.pdf"), last_plot(),
       device = "pdf", height = 5, width = 8)

nfd_plot(multipanel = TRUE)

ggsave(filename = file.path(outputDir, modelName, "Nfd_multipanel.pdf"), last_plot(),
       device = "pdf", height = 7.5, width = 8.5)

# thymic Ki67 fractions
## Percentage Ki67-high (stored as fractions, hence the factor 100): one band
## and line per sub-compartment and host-age-at-BMT bin, data points overlaid.
## NOTE(review): the data are placed at age.at.S1K - age.at.BMT (time since
## BMT) while the x axis is labelled "Host age (days)" -- confirm which time
## scale ki_naive_pred$timeseries uses and fix the label if needed.
ggplot() +
  geom_ribbon(
    data = ki_naive_pred,
    mapping = aes(x = timeseries, ymin = 100 * lb, ymax = 100 * ub,
                  alpha = ageBMT_bin, fill = subcomp)
  ) +
  geom_line(
    data = ki_naive_pred,
    mapping = aes(x = timeseries, y = 100 * median,
                  color = subcomp, linetype = ageBMT_bin)
  ) +
  geom_point(
    data = ki_data,
    mapping = aes(x = age.at.S1K - age.at.BMT, y = 100 * naive, color = subcomp),
    size = 2
  ) +
  labs(x = "Host age (days)", y = NULL, title = "% Ki67hi in naive Tregs") +
  scale_color_manual(name = "", values = c(2, 4)) +
  scale_alpha_manual(name = "", values = rep(0.1, 4)) +
  scale_linetype_discrete(name = "Host age at \n BMT (Wks)", labels = legn_labels) +
  scale_x_continuous(limits = c(1, 300), breaks = seq(0, 500, by = 100)) +
  scale_y_continuous(limits = c(5, 25), breaks = seq(0, 50, by = 10)) +
  guides(fill = "none", alpha = "none") +
  myTheme


ggsave(filename = file.path(outputDir, modelName, "Ki67_DonorHost.pdf"), last_plot(),
       device = "pdf", height = 7.5, width = 8.5)

## mean Ki67-high fraction in each sub-compartment
ki_data %>% group_by(subcomp) %>%
  summarise(mean_ki = mean(naive))

## Wide table: one row per sample, one column per sub-compartment.
## NOTE(review): the pairing assumes ki_data holds all rows of one
## sub-compartment followed by those of the other, in matching sample order --
## confirm (a sample-id column would make the pairing explicit).
stopifnot(nrow(ki_data) %% 2 == 0)
ki_test <- ki_data %>%
  mutate(index_col = rep(seq_len(nrow(ki_data) %/% 2), 2)) %>%
  select(index_col, naive, subcomp) %>%
  pivot_wider(names_from = subcomp, values_from = naive)

## Paired one-sided t-test of the Ki67-high fraction between the two
## sub-compartments (first level of subcomp minus second level). The paired
## columns are passed explicitly because the formula method of t.test() does
## not accept paired = TRUE in current R.
ki_levels <- levels(factor(ki_data$subcomp))
stopifnot(length(ki_levels) == 2L)
t.test(x = ki_test[[ki_levels[1]]],
       y = ki_test[[ki_levels[2]]],
       alternative = "greater",
       mu = 0,
       paired = TRUE,
       conf.level = 0.95)

## distribution of the Ki67-high fraction per sub-compartment
boxplot(naive ~ subcomp, data = ki_data,
        col = c("#003C67FF", "#EFC000FF"),
        main = "Ki67hi fraction in naive Tregs",
        xlab = "Sub-compartment", ylab = "Fraction Ki67hi")


#######

## Core model parameters: every entry of fit@model_pars that precedes
## "sigma_counts_naive".
num_pars_m <- match("sigma_counts_naive", fit@model_pars) - 1L
if (is.na(num_pars_m) || num_pars_m < 1L) {
  stop("'sigma_counts_naive' not found after at least one parameter in ",
       "fit@model_pars", call. = FALSE)
}
params_m <- fit@model_pars[seq_len(num_pars_m)]
### Extract samples as dataframe (one row per posterior draw)
par_dist <- as.data.frame(fit, pars = params_m)


## Inter-division and residence times of the displaceable (D) and incumbent
## (I) subsets, one row per posterior draw (vectorised over draws).
par_est <- data.frame(
  ## 1/rate of division --> interdivision time
  div_times_disp = 1 / par_dist$rho_D,
  div_times_inc = 1 / par_dist$rho_I,
  ## 1/rate of loss --> residence time; the loss rate is taken as
  ## lambda_D + rho_D for displaceables and as rho_I for incumbents
  resid_times_disp = 1 / (par_dist$lambda_D + par_dist$rho_D),
  resid_times_inc = 1 / par_dist$rho_I
)



## Posterior median and 90% interval (5%, 50%, 95% quantiles) of each derived
## time, labelled by subset and by kind of time.
parplot <- par_est %>%
  pivot_longer(everything(), names_to = "parname", values_to = "par_value") %>%
  group_by(parname) %>%
  summarize(lb = quantile(par_value, probs = 0.05),
            estimate = quantile(par_value, probs = 0.5),
            ub = quantile(par_value, probs = 0.95)) %>%
  mutate(subpop = ifelse(grepl("inc", parname), "Incumbent", "Displaceable"),
         param = ifelse(grepl("div", parname), "Inter-division time", "Residence time"))


ggplot(parplot) +
  geom_errorbar(aes(y = estimate, ymin = lb, ymax = ub, x = subpop, col = subpop),
                width = 0.2, linetype = 1, position = position_dodge(0.4)) +
  geom_point(aes(y = estimate, x = subpop, col = subpop),
             position = position_dodge(width = 0.4), stat = "identity", size = 4) +
  scale_color_manual(values = c(4, 2), name = NULL) +
  facet_wrap(~ param) +
  labs(y = "Days", x = NULL, title = NULL) +
  myTheme + theme(axis.text.x = element_blank(),
                  legend.position = c(0.825, 0.9))

## Size of the source (de novo naive Treg) population over time: an
## exponential decline from 10^b0 cells at time t0 with time constant nu.
##
## ts_seq     : numeric vector of times (same units as t0 and nu)
## t0, b0, nu : reference time, log10 population size at t0, and decay time
##              constant; the defaults are the values this script uses
## Returns a numeric vector with one population size per element of ts_seq.
sp_numbers <- function(ts_seq, t0 = 40.0, b0 = 4.6, nu = 160) {
  # de novo naive Tregs numbers
  10^b0 * exp(-(ts_seq - t0) / nu)
}

## Precursor size changing with time, evaluated on the prediction grid;
## sp_numbers() is vectorised over time, so no per-element apply is needed.
sp_vec <- sp_numbers(ts_pred)

## Daily influx from the source into the target population at the given
## time(s): psi (the per-capita rate of daily influx) times the source
## population size returned by sp_numbers().
theta_spline <- function(ts_seq, psi) {
  source_size <- sp_numbers(ts_seq)
  psi * source_size
}

## target population size: posterior draws (rows) x prediction time points
## (columns) of counts_naive_mean_pred1
target_df <- as.data.frame(fit, pars = "counts_naive_mean_pred1")


## incumbent fraction
## The number of incumbents stays the same throughout the time course
## (exp(y3_0) + exp(y4_0) for each draw), so the fraction at each time point is
## that number divided by the total population size. par_dist and target_df
## are both extracted from `fit` with as.data.frame(), so row i of each refers
## to the same draw; dividing a vector by a matrix recycles the vector down the
## columns, pairing inc_n[i] with row i.
inc_n <- exp(par_dist$y3_0) + exp(par_dist$y4_0)
inc_frac_mat <- inc_n / as.matrix(target_df)
inc_frac <- as.data.frame(inc_frac_mat)

## 5%, 50% and 95% posterior quantiles at every time point (one column each)
inc_frac_q <- apply(inc_frac_mat, 2, quantile,
                    probs = c(0.05, 0.5, 0.95), names = FALSE)

tibble(timeseries = c(ts_pred),
       lb = inc_frac_q[1, ], estimate = inc_frac_q[2, ], ub = inc_frac_q[3, ]) %>%
  ggplot() +
  geom_line(aes(x = timeseries, y = estimate), linewidth = 0.75, col = 4) +
  geom_ribbon(aes(x = timeseries, ymin = lb, ymax = ub), fill = 4, alpha = 0.2) +
  labs(x = "Host age (days)", y = NULL, title = "Incumbent fraction") +
  ## zoom rather than ylim(): scale limits would drop ribbon rows whose bounds
  ## fall outside the range
  coord_cartesian(ylim = c(0, 0.5)) +
  myTheme



## contributions of division and source influx to the maintenance of the
## displaceable pool, for every posterior draw (rows) and prediction time
## point (columns):
##   division share = (rho_D * N_disp) / (theta + rho_D * N_disp)
##                  = 1 / (1 + theta / (rho_D * N_disp))
## with N_disp = total size - incumbent size (exp(y3_0) + exp(y4_0)) and
## theta the daily influx from the source (theta_spline()).
inc_n <- exp(par_dist$y3_0) + exp(par_dist$y4_0)
## theta_mat[i, j] = theta_spline(ts_pred[j], psi[i])
theta_mat <- outer(par_dist$psi, c(ts_pred),
                   function(psi, ts) theta_spline(ts, psi))
## vector-by-matrix arithmetic recycles down the columns, so draw i's rho_D and
## incumbent size are applied to row i
rho_n_disp <- par_dist$rho_D * (as.matrix(target_df) - inc_n)
contribution_mat <- 1 / (1 + theta_mat / rho_n_disp)
contribution_to_disp <- as.data.frame(contribution_mat)

## 5%, 50% and 95% posterior quantiles at every time point (one column each)
contribution_q <- apply(contribution_mat, 2, quantile,
                        probs = c(0.05, 0.5, 0.95), names = FALSE)

tibble(timeseries = c(ts_pred),
       lb = contribution_q[1, ], estimate = contribution_q[2, ],
       ub = contribution_q[3, ]) %>%
  ggplot() +
  geom_line(aes(x = timeseries, y = estimate), linewidth = 0.75, col = 4) +
  geom_ribbon(aes(x = timeseries, ymin = lb, ymax = ub), fill = 4, alpha = 0.2) +
  labs(x = "Host age (days)", y = NULL, title = "Fraction of total production of displaceables \nthat derives from division") +
  ## zoom rather than ylim(): scale limits would drop ribbon rows whose bounds
  ## fall outside the range
  coord_cartesian(ylim = c(0.3, 0.9)) +
  myTheme






# ################################################################################################
# calculating PSIS-LOO-CV for the fit
## Pointwise log-likelihoods of the four data streams, kept per chain
## (iterations x chains x observations) so that relative efficiencies can be
## estimated, then stacked along the observation dimension.
loglik_names <- c("log_lik_counts_naive", "log_lik_Nfd_naive",
                  "log_lik_ki_donor_naive", "log_lik_ki_host_naive")
loglik_arrays <- lapply(loglik_names, function(par_name) {
  extract_log_lik(fit, parameter_name = par_name, merge_chains = FALSE)
})
n_obs_per_stream <- vapply(loglik_arrays, function(ll) dim(ll)[3], integer(1))
## the observation index varies slowest in R's column-major storage, so
## concatenating the arrays stacks them along the third dimension
log_lik_comb <- array(unlist(loglik_arrays, use.names = FALSE),
                      dim = c(dim(loglik_arrays[[1]])[1:2], sum(n_obs_per_stream)))

## relative efficiency for every observation of every stream; it must cover
## the same observations that are passed to loo() to be used there
r_eff <- relative_eff(exp(log_lik_comb), cores = 8)

# loo-ic values
loo_loglik <- loo(log_lik_comb, r_eff = r_eff, save_psis = FALSE, cores = 8)

## saving loo object as an rds file for model comparison
loofilename <- paste0("loosave_s2", modelName, ".rds")
dir.create(LooDir, recursive = TRUE, showWarnings = FALSE)
write_rds(loo_loglik, file = file.path(LooDir, loofilename))

## one summary row per model: LOO-IC, its standard error, and the effective
## number of parameters
ploocv <- data.frame("Model" = modelName,
                     "LooIC" = loo_loglik$estimates["looic", "Estimate"],
                     "SE" = loo_loglik$estimates["looic", "SE"],
                     "PLoo" = loo_loglik$estimates["p_loo", "Estimate"])
ploocv

## appended without a header so that rows from successive models accumulate
write.table(ploocv, file = file.path(outputDir, "stat_table_naive.csv"),
            sep = ",", append = TRUE, quote = FALSE,
            col.names = FALSE, row.names = FALSE)

### posterior distributions of parameters
## Summary of every element of the plotted parameters. Columns 1, 3, 4 and 8
## of rstan::monitor()'s output are kept (with the default probs these are
## mean, sd, 2.5% and 97.5%). All rows are kept: a vector-valued parameter
## contributes one row per element, so the table can have more than num_pars
## rows.
ptable <- monitor(as.array(fit, pars = parametersToPlot), warmup = 0, print = FALSE)
out_table <- ptable[seq_len(nrow(ptable)), c(1, 3, 4, 8)]
out_table
write.csv(out_table, file = file.path(outputDir, paste0('params_', modelName, ".csv")))


################################################################################################



################################################################################################
## open graphics device 
## saving plots for quality control: with onefile = FALSE and the %03d
## pattern, each page is written to its own numbered PDF file
pdf(file = file.path(outputDir, paste0(modelName, "Plots%03d.pdf")),
    width = 12, height = 5, onefile = FALSE)

pairs(fit, pars = parametersToPlot)

options(bayesplot.base_size = 15,
        bayesplot.base_family = "sans")
bayesplot::color_scheme_set(scheme = "viridis")

## ggplot objects are printed explicitly: top-level auto-printing does not
## happen when this script is run via source(), which would leave these plots
## out of the PDF output
rhats <- rhat(fit, pars = parametersToPlot)
print(mcmc_rhat(rhats) + yaxis_text() + myTheme)

ratios1 <- neff_ratio(fit, pars = parametersToPlot)
print(mcmc_neff(ratios1) + yaxis_text() + myTheme)

## only the plotted parameters are needed, so the draws array is restricted to
## them instead of copying every quantity in the fit
posterior <- as.array(fit, pars = parametersToPlot)
print(mcmc_acf(posterior, pars = parametersToPlot) + myTheme)

mcmcHistory(fit, pars = parametersToPlot, nParPerPage = 4, myTheme = myTheme)

print(mcmc_dens_overlay(posterior, parametersToPlot))
print(mcmc_dens(posterior, parametersToPlot) + myTheme)


dev.off()
