## clearing the environment
#rm(list = ls())
#gc()

library(rstan)
library(tidyverse)
library(kableExtra)
#########################################################
#### plotting style
## Text sizing shared by every panel: 14 pt axis text, facet strip text,
## axis titles and horizontally centred plot titles; 13 pt legend text and
## legend title, with the legend background and key boxes not drawn.
my_theme <- theme(
  axis.text         = element_text(size = 14),
  strip.text        = element_text(size = 14),
  axis.title        = element_text(size = 14),
  plot.title        = element_text(size = 14, hjust = 0.5),
  legend.background = element_blank(),
  legend.key        = element_blank(),
  legend.text       = element_text(size = 13),
  legend.title      = element_text(size = 13)
)

# black-and-white base theme used by all later plots
theme_set(theme_bw())


####### plotting
## Axis-label formatter: turns numeric breaks into plotmath expressions of the
## form 'm' %*% 10^k, collapsing a unit mantissa ('1', '1.0', ...) to plain
## 10^k.
##
## @param l numeric vector of break positions
## @return an expression vector (one element per break) for ggplot labels
fancy_scientific <- function(l) {
  # Ordered regex rewrites applied to the scientific-notation strings:
  rules <- list(
    # quote the mantissa so parse() keeps all of its digits
    c("^(.*)e", "'\\1'e"),
    # drop the '+' of a positive exponent (e+02 -> e02)
    c("e\\+", "e"),
    # the exponent marker becomes a plotmath multiplication
    c("e", "%*%10^"),
    # '1' %*% 10^k (or '1.000' %*% 10^k) -> 10^k
    c("\\'1[\\.0]*\\'\\%\\*\\%", "")
  )
  plotmath_text <- Reduce(
    function(txt, rule) gsub(rule[[1]], rule[[2]], txt),
    rules,
    init = format(l, scientific = TRUE)
  )
  parse(text = plotmath_text)
}

## Log-scale minor grid positions: 1..10 times every power of ten from
## 1e-3 to 1e8 (120 values, column-major order of the outer product).
log10minorbreaks <- as.numeric(outer(1:10, 10^(-3:8)))

## Setting all the directories for operations
project_dir <- getwd()
data_dir    <- file.path(project_dir, "datafiles")
output_dir  <- file.path(project_dir, "output_fit")
save_dir    <- file.path(project_dir, "save_csv", "naive_Treg")

## Prediction grid for the analysis: 300 log-spaced points from 66 to 450
## (host age; presumably days, matching the plot axis labels)
ts_pred <- 10^seq(from = log10(66), to = log10(450), length.out = 300)

## Name of the first model whose posterior draws are analysed
modelname <- "Incumbent_DeNovo"

# Compile the per-node Stan CSV outputs ("<modelname>_a1.csv" ...
# "<modelname>_a5.csv" in save_dir) into one merged stanfit object.
# Chains are read and merged in ascending file order; the posterior quantile
# summaries computed below do not depend on chain order.
chain_files <- file.path(save_dir, paste0(modelname, "_a", 1:5, ".csv"))
fit <- sflist2stanfit(lapply(chain_files, read_stan_csv))


# Model parameters are every entry of fit@model_pars declared before
# "sigma_counts_naive". Stop loudly if that marker is missing or is the first
# entry: the old `1:(idx - 1)` would then have selected the wrong parameters
# (1:0 picks element 1, i.e. sigma_counts_naive itself).
sigma_idx <- match("sigma_counts_naive", fit@model_pars)
if (is.na(sigma_idx) || sigma_idx < 2L) {
  stop("No model parameters found before 'sigma_counts_naive' in fit for ",
       modelname, call. = FALSE)
}
num_pars_m <- sigma_idx - 1L
params_m <- fit@model_pars[seq_len(num_pars_m)]


### Extract posterior draws of the model parameters as a data frame
### (one row per posterior draw)
par_dist <- as.data.frame(fit, pars = params_m)

## Derived timescales, one row per posterior draw of par_dist:
##  - inter-division time = 1 / division rate (rho_D, rho_I)
##  - residence time      = 1 / loss rate, taken as lambda_D + rho_D for
##    displaceable cells and rho_I alone for incumbents.
##    NOTE(review): the incumbent form assumes incumbent loss balances
##    division (their numbers are treated as constant below) — confirm
##    against the Stan model.
## Built column-wise in one call instead of growing a data.frame cell by cell.
par_est <- data.frame(
  div_times_disp   = 1 / par_dist$rho_D,
  div_times_inc    = 1 / par_dist$rho_I,
  resid_times_disp = 1 / (par_dist$lambda_D + par_dist$rho_D),
  resid_times_inc  = 1 / par_dist$rho_I
)



## Posterior median and 90% interval (5%/50%/95% quantiles) of each derived
## timescale, labelled by subpopulation (from the "inc" suffix) and quantity
## (from the "div" prefix) for plotting.
parplot <- par_est %>%
  pivot_longer(everything(), names_to = "parname", values_to = "par_value") %>%
  group_by(parname) %>%
  summarize(
    lb = quantile(par_value, probs = 0.05),
    estimate = quantile(par_value, probs = 0.5),
    ub = quantile(par_value, probs = 0.95),
    .groups = "drop"
  ) %>%
  mutate(
    subpop = ifelse(grepl("inc", parname), "Incumbent", "Displaceable"),
    param = ifelse(grepl("div", parname), "Inter-division time", "Residence time")
  )



## Total pool size of the source (de novo naive Treg) population over time:
##   N_source(t) = 10^b0 * exp(-(t - t0) / nu)
## i.e. 10^b0 at time t0, changing exponentially with time constant nu.
## The defaults are the values previously hard-coded here, so existing calls
## behave identically; they can now be overridden. Vectorised over ts_seq.
##
## @param ts_seq numeric vector of times (same units as t0 and nu)
## @param t0 reference time at which the pool size equals 10^b0
## @param b0 log10 of the pool size at t0
## @param nu exponential time constant
## @return numeric vector of pool sizes, same length as ts_seq
sp_numbers <- function(ts_seq, t0 = 40.0, b0 = 4.6, nu = 160) {
  10^b0 * exp(-(ts_seq - t0) / nu)
}

## Precursor pool size along the prediction grid. sp_numbers() is already
## vectorised, so it is applied to the whole grid at once (always returns a
## numeric vector, even for an empty grid).
sp_vec <- sp_numbers(ts_pred)

## Daily influx from the source population into the target population:
##   theta(t) = psi * N_source(t), with N_source(t) from sp_numbers().
##
## @param ts_seq time(s) at which to evaluate the influx
## @param psi rate of daily influx (multiplies the source pool size)
## @return numeric vector psi * sp_numbers(ts_seq) (standard R recycling)
theta_spline <- function(ts_seq, psi) {
  source_size <- sp_numbers(ts_seq)
  psi * source_size
}

## Posterior draws of the predicted total target-population size
## ("counts_naive_mean_pred1"): one row per draw, one column per prediction
## time point.
target_df <- as.data.frame(fit, pars = "counts_naive_mean_pred1")


## Incumbent fraction = incumbent numbers / total population size, for every
## draw (row) and time point (column). Incumbent numbers are treated as
## constant over the time course: exp(y3_0) + exp(y4_0) for each draw.
## Vectorised: no cell-by-cell growth of a data.frame.
inc_frac <- local({
  inc_numbers <- exp(par_dist$y3_0) + exp(par_dist$y4_0)
  # par_dist and target_df must hold the same posterior draws, row for row
  stopifnot(length(inc_numbers) == nrow(target_df))
  # a length-nrow vector divided by an nrow x ncol matrix recycles down each
  # column, so element [i, j] is inc_numbers[i] / total[i, j]
  out <- as.data.frame(inc_numbers / as.matrix(target_df))
  # columns named V1, V2, ... (one per time point, in target_df column order)
  names(out) <- paste0("V", seq_len(ncol(out)))
  out
})


## Fraction of new displaceable cells that derives from division rather than
## from source influx, per draw (row) and time point (column):
##   rho_D * N_disp / (theta + rho_D * N_disp) = 1 / (1 + theta / (rho_D * N_disp))
## where N_disp = total size - incumbent numbers (exp(y3_0) + exp(y4_0)) and
## theta = theta_spline(t, psi).
## NOTE(review): theta is evaluated on this script's ts_pred grid, while the
## results are later plotted against ts_pred1; confirm the two grids match.
contribution_to_disp <- local({
  n_times <- ncol(target_df)
  # indexing ts_pred past its end would silently produce NA fractions
  stopifnot(n_times <= length(ts_pred), nrow(par_dist) == nrow(target_df))
  inc_numbers <- exp(par_dist$y3_0) + exp(par_dist$y4_0)
  # theta_mat[i, j] = theta_spline(ts_pred[j], psi[i])
  theta_mat <- outer(par_dist$psi, ts_pred[seq_len(n_times)],
                     function(psi, t) theta_spline(t, psi))
  # per-draw vectors recycle down the columns: row i uses draw i's values
  rho_n_disp <- par_dist$rho_D * (as.matrix(target_df) - inc_numbers)
  out <- as.data.frame(1 / (1 + (theta_mat / rho_n_disp)))
  # columns named V1, V2, ... (one per time point, in target_df column order)
  names(out) <- paste0("V", seq_len(n_times))
  out
})


## posterior predictive distributions
## NOTE(review): plot_function() appears to populate the global data frames
## used below (Counts_pred, Counts_naive_sigma_obs, counts_data, Nfd_pred,
## ki_naive_pred, ...) for the given fit — confirm in
## scripts/stan_extract_forplotting.R.
source('scripts/stan_extract_forplotting.R')
plot_function(fitobject = fit)

# legend labels for the four host-age-at-BMT bins (weeks)
legn_labels <- c('6-8', '8-10', '10-12', '12-25')

# Incumbent_DeNovo fit: total counts against host age, log-log axes
p2 <- ggplot() +
  geom_ribbon(
    data = Counts_pred,
    aes(x = timeseries, ymin = lb, ymax = ub, fill = ageBMT_bin),
    alpha = 0.2
  ) +
  # alternative band including observation noise:
  # geom_ribbon(data = Counts_withsigma, aes(x = timeseries, ymin = lb, ymax = ub, fill = ageBMT_bin), alpha = 0.2) +
  geom_line(
    data = Counts_pred,
    aes(x = timeseries, y = median, color = ageBMT_bin),
    linewidth = 1
  ) +
  geom_errorbar(
    data = Counts_naive_sigma_obs,
    aes(x = timeseries, ymin = lb, ymax = ub, col = ageBMT_bin),
    alpha = 0.5,
    width = 0.01
  ) +
  geom_point(
    data = counts_data,
    aes(x = age.at.S1K, y = naive, color = ageBMT_bin),
    size = 2
  ) +
  labs(title = 'Total counts', y = NULL, x = "Host age (days)") +
  scale_color_discrete(name = "Host age at \n BMT (Wks)", labels = legn_labels) +
  scale_x_continuous(
    limits = c(60, 460),
    trans = "log10",
    breaks = c(75, 150, 300)
  ) +
  scale_y_continuous(
    limits = c(1e5, 1e7),
    trans = "log10",
    breaks = c(1e4, 1e5, 1e6, 1e7, 1e8),
    minor_breaks = log10minorbreaks,
    labels = fancy_scientific
  ) +
  guides(fill = 'none', col = "none") +
  my_theme

# Incumbent_DeNovo fit: normalised chimerism against host age; this panel
# carries the shared colour legend along the top
p4 <- ggplot() +
  geom_ribbon(
    data = Nfd_pred,
    aes(x = timeseries, ymin = lb, ymax = ub, fill = ageBMT_bin),
    alpha = 0.15
  ) +
  geom_line(
    data = Nfd_pred,
    aes(x = timeseries, y = median, color = ageBMT_bin),
    linewidth = 1
  ) +
  geom_point(
    data = Nfd_data,
    aes(x = age.at.S1K, y = naive, color = ageBMT_bin),
    size = 2
  ) +
  geom_errorbar(
    data = Nfd_naive_sigma_obs,
    aes(x = timeseries, ymin = lb, ymax = ub, col = ageBMT_bin),
    alpha = 0.5,
    width = 5
  ) +
  labs(title = "Normalised Chimerism", y = NULL, x = "Host age (days)") +
  scale_color_discrete(name = "Host age at BMT (Wks)", labels = legn_labels) +
  scale_x_continuous(limits = c(1, 450), breaks = c(0, 200, 400)) +
  scale_y_continuous(
    limits = c(0, 1.1),
    breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1.0)
  ) +
  guides(fill = 'none') +
  my_theme +
  theme(legend.position = "top", legend.direction = "horizontal")


# facet labels for the host-age-at-BMT bins; referenced only by the
# commented-out facet_wrap below
fac_labels <- c(
  agebin1 = '6-8 weeks',
  agebin2 = '8-10 weeks',
  agebin3 = '10-12 weeks',
  agebin4 = '12-25 weeks'
)


# Incumbent_DeNovo fit: % Ki67-high cells by subcomp against host age
# (subcomp presumably distinguishes donor/host cells — confirm)
p6 <- ggplot() +
  geom_ribbon(
    data = ki_naive_pred,
    aes(x = timeseries, ymin = lb * 100, ymax = ub * 100, fill = subcomp),
    alpha = 0.25
  ) +
  geom_line(
    data = ki_naive_pred,
    aes(x = timeseries, y = median * 100, color = subcomp),
    linewidth = 1
  ) +
  geom_point(
    data = ki_data,
    aes(x = age.at.S1K, y = naive * 100, color = subcomp),
    size = 2
  ) +
  # observation-error bars, currently disabled:
  # geom_errorbar(data = ki_donor_naive_sigma_obs, aes(x = timeseries, ymin = lb*100, ymax = ub*100, col=subcomp), alpha = 0.5, width=5) +
  # geom_errorbar(data = ki_host_naive_sigma_obs, aes(x = timeseries, ymin = lb*100, ymax = ub*100, col=subcomp), alpha = 0.5, width=5) +
  scale_color_manual(values = c(2, 4)) +
  labs(title = "% Ki67hi", y = NULL, x = "Host age (days)") +
  scale_x_continuous(limits = c(60, 450), breaks = c(0, 200, 400)) +
  scale_y_continuous(limits = c(0, 25), breaks = c(0, 10, 20, 30, 40, 50)) +
  # facet_wrap(~ ageBMT_bin, scales = 'free', labeller = as_labeller(fac_labels)) +
  guides(fill = 'none', col = "none") +
  my_theme

# Incumbent_DeNovo fit: posterior median and 5-95% interval of the
# inter-division and residence times (days), displaceable vs incumbent
p7 <- ggplot(data = parplot) +
  geom_errorbar(
    mapping = aes(x = subpop, y = estimate, ymin = lb, ymax = ub, col = subpop),
    width = 0.2,
    linetype = 1,
    position = position_dodge(0.4)
  ) +
  geom_point(
    mapping = aes(x = subpop, y = estimate, col = subpop),
    stat = "identity",
    position = position_dodge(width = 0.4),
    size = 4
  ) +
  scale_color_manual(name = NULL, values = c(4, 2)) +
  facet_wrap(~param) +
  labs(title = NULL, x = NULL, y = "Days") +
  my_theme +
  theme(
    axis.text.x = element_blank(),
    legend.position = c(0.825, 0.9)
  )

## Summarise a draws-by-time data frame (rows = posterior draws, columns =
## time points in the same order as `timeseries`) into the per-time-point
## 5%, 50% and 95% posterior quantiles.
##
## @param draws data frame of numeric draws, one column per time point
## @param timeseries numeric vector of times, length ncol(draws)
## @return tibble with columns timeseries, lb, estimate, ub
summarise_over_time <- function(draws, timeseries) {
  if (ncol(draws) != length(timeseries)) {
    stop("draws has ", ncol(draws), " columns but timeseries has ",
         length(timeseries), " points", call. = FALSE)
  }
  col_quantile <- function(p) {
    vapply(draws, quantile, numeric(1), probs = p, names = FALSE,
           USE.NAMES = FALSE)
  }
  tibble(
    timeseries = timeseries,
    lb = col_quantile(0.05),
    estimate = col_quantile(0.5),
    ub = col_quantile(0.95)
  )
}

## Median line plus 5-95% ribbon of a time-resolved posterior quantity.
##
## @param draws, timeseries passed to summarise_over_time()
## @param title panel title
## @param y_limits length-2 numeric y-axis limits (values outside are dropped)
## @return a ggplot object
plot_over_time <- function(draws, timeseries, title, y_limits) {
  ggplot(summarise_over_time(draws, timeseries)) +
    geom_line(aes(x = timeseries, y = estimate), linewidth = 0.75, col = 4) +
    geom_ribbon(aes(x = timeseries, ymin = lb, ymax = ub), fill = 4, alpha = 0.2) +
    labs(x = "Host age (days)", y = NULL, title = title) +
    ylim(y_limits[1], y_limits[2]) +
    my_theme
}

# Incumbent_DeNovo fit: share of displaceable production due to division
p8 <- plot_over_time(
  contribution_to_disp, c(ts_pred1),
  title = "Fraction of total production \nof displaceables that \nderives from division",
  y_limits = c(0.3, 0.9)
)

# Incumbent_DeNovo fit: incumbent fraction of the total population
p9 <- plot_over_time(
  inc_frac, c(ts_pred1),
  title = "Incumbent fraction",
  y_limits = c(0, 0.5)
)

### 

## modelname
modelname <- "Neutral_DeNovo"

# Compile the per-node Stan CSV outputs for this model ("<modelname>_a1.csv"
# ... "_a5.csv" in save_dir) into one stanfit object, chains merged in
# ascending file order.
chain_files <- file.path(save_dir, paste0(modelname, "_a", 1:5, ".csv"))
fit_2 <- sflist2stanfit(lapply(chain_files, read_stan_csv))

# Model parameters are every entry of fit_2@model_pars declared before
# "sigma_counts_naive"; stop if that marker is missing or is the first entry
# (1:0 would otherwise select sigma_counts_naive itself).
sigma_idx <- match("sigma_counts_naive", fit_2@model_pars)
if (is.na(sigma_idx) || sigma_idx < 2L) {
  stop("No model parameters found before 'sigma_counts_naive' in fit for ",
       modelname, call. = FALSE)
}
num_pars_m <- sigma_idx - 1L
params_m <- fit_2@model_pars[seq_len(num_pars_m)]

# Regenerate the posterior-predictive plotting data for this fit.
# NOTE(review): plot_function() appears to (re)populate the global data
# frames the plots below read (Counts_pred, Nfd_pred, ki_naive_pred, ...) —
# confirm in scripts/stan_extract_forplotting.R.
source('scripts/stan_extract_forplotting.R')
plot_function(fitobject = fit_2)

# legend labels for the four host-age-at-BMT bins (weeks)
legn_labels <- c('6-8', '8-10', '10-12', '12-25')

# Neutral_DeNovo fit: total counts against host age, log-log axes
p1 <- ggplot() +
  geom_ribbon(
    data = Counts_pred,
    aes(x = timeseries, ymin = lb, ymax = ub, fill = ageBMT_bin),
    alpha = 0.2
  ) +
  # alternative band including observation noise:
  # geom_ribbon(data = Counts_withsigma, aes(x = timeseries, ymin = lb, ymax = ub, fill = ageBMT_bin), alpha = 0.2) +
  geom_line(
    data = Counts_pred,
    aes(x = timeseries, y = median, color = ageBMT_bin),
    linewidth = 1
  ) +
  geom_errorbar(
    data = Counts_naive_sigma_obs,
    aes(x = timeseries, ymin = lb, ymax = ub, col = ageBMT_bin),
    alpha = 0.5,
    width = 0.01
  ) +
  geom_point(
    data = counts_data,
    aes(x = age.at.S1K, y = naive, color = ageBMT_bin),
    size = 2
  ) +
  labs(title = 'Total counts', y = NULL, x = "Host age (days)") +
  scale_color_discrete(name = "Host age at \n BMT (Wks)", labels = legn_labels) +
  scale_x_continuous(
    limits = c(60, 460),
    trans = "log10",
    breaks = c(75, 150, 300)
  ) +
  scale_y_continuous(
    limits = c(1e5, 1e7),
    trans = "log10",
    breaks = c(1e4, 1e5, 1e6, 1e7, 1e8),
    minor_breaks = log10minorbreaks,
    labels = fancy_scientific
  ) +
  guides(fill = 'none', col = "none") +
  my_theme

# Neutral_DeNovo fit: normalised chimerism against host age (no legend)
p3 <- ggplot() +
  geom_ribbon(
    data = Nfd_pred,
    aes(x = timeseries, ymin = lb, ymax = ub, fill = ageBMT_bin),
    alpha = 0.15
  ) +
  geom_line(
    data = Nfd_pred,
    aes(x = timeseries, y = median, color = ageBMT_bin),
    linewidth = 1
  ) +
  geom_point(
    data = Nfd_data,
    aes(x = age.at.S1K, y = naive, color = ageBMT_bin),
    size = 2
  ) +
  geom_errorbar(
    data = Nfd_naive_sigma_obs,
    aes(x = timeseries, ymin = lb, ymax = ub, col = ageBMT_bin),
    alpha = 0.5,
    width = 5
  ) +
  labs(title = "Normalised Chimerism", y = NULL, x = "Host age (days)") +
  scale_color_discrete(name = "Host age at \n BMT (Wks)", labels = legn_labels) +
  scale_x_continuous(limits = c(1, 450), breaks = c(0, 200, 400)) +
  scale_y_continuous(
    limits = c(0, 1.1),
    breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1.0)
  ) +
  guides(fill = 'none', col = 'none') +
  my_theme



# facet labels for the host-age-at-BMT bins; referenced only by the
# commented-out facet_wrap below
fac_labels <- c(
  agebin1 = '6-8 weeks',
  agebin2 = '8-10 weeks',
  agebin3 = '10-12 weeks',
  agebin4 = '12-25 weeks'
)


# Neutral_DeNovo fit: % Ki67-high cells by subcomp against host age, with an
# untitled legend placed inside the panel
p5 <- ggplot() +
  geom_ribbon(
    data = ki_naive_pred,
    aes(x = timeseries, ymin = lb * 100, ymax = ub * 100, fill = subcomp),
    alpha = 0.25
  ) +
  geom_line(
    data = ki_naive_pred,
    aes(x = timeseries, y = median * 100, color = subcomp),
    linewidth = 1
  ) +
  geom_point(
    data = ki_data,
    aes(x = age.at.S1K, y = naive * 100, color = subcomp),
    size = 2
  ) +
  # observation-error bars, currently disabled:
  # geom_errorbar(data = ki_donor_naive_sigma_obs, aes(x = timeseries, ymin = lb*100, ymax = ub*100, col=subcomp), alpha = 0.5, width=5) +
  # geom_errorbar(data = ki_host_naive_sigma_obs, aes(x = timeseries, ymin = lb*100, ymax = ub*100, col=subcomp), alpha = 0.5, width=5) +
  scale_color_manual(values = c(2, 4)) +
  labs(title = "% Ki67hi", y = NULL, x = "Host age (days)") +
  scale_x_continuous(
    limits = c(60, 450),
    breaks = c(0, 100, 200, 300, 400, 500)
  ) +
  scale_y_continuous(limits = c(0, 25), breaks = c(0, 10, 20, 30, 40, 50)) +
  # facet_wrap(~ ageBMT_bin, scales = 'free', labeller = as_labeller(fac_labels)) +
  guides(fill = 'none') +
  my_theme +
  theme(
    legend.title = element_blank(),
    legend.position = c(0.19, 0.92)
  )



## Assemble Figure 3.
##  top row:    total counts and normalised chimerism, Neutral_DeNovo (p1, p3)
##              then Incumbent_DeNovo (p2, p4)
##  middle row: % Ki67hi, Neutral_DeNovo (p5) then Incumbent_DeNovo (p6)
##  bottom row: Incumbent_DeNovo timescales (p7), division contribution (p8)
##              and incumbent fraction (p9)
ptop <- cowplot::plot_grid(p1, p3, p2, p4, ncol = 4, nrow = 1)
pmiddle <- cowplot::plot_grid(p5, p6, ncol = 2, nrow = 1)
pbottom <- cowplot::plot_grid(p7, p8, p9, ncol = 3, nrow = 1)
figure_3 <- cowplot::plot_grid(ptop, pmiddle, pbottom, nrow = 3)
figure_3

## Save the assembled figure explicitly instead of relying on ggsave()'s
## default `plot = last_plot()` having been set by cowplot internals.
ggsave(filename = "Figure_3.pdf", plot = figure_3, device = "pdf",
       height = 11, width = 8.5, units = "in")


