## clearing the environment
# rm(list = ls())
# gc()

library(rstan)
library(tidyverse)
library(cowplot)
#########################################################
#### plotting style
## Theme tweaks added on top of each figure panel: every text element is set
## to size 11, the plot title is centred, and legend/strip backgrounds are
## removed.
my_theme <- local({
  text_11 <- element_text(size = 11)
  blank <- element_blank()
  theme(
    axis.text = text_11,
    strip.text = text_11,
    axis.title = text_11,
    plot.title = element_text(size = 11, hjust = 0.5),
    legend.background = blank, legend.key = blank,
    legend.text = text_11,
    legend.title = text_11,
    strip.background = blank
  )
})

# setting the default ggplot theme for the rest of the plots
theme_set(theme_bw())


####### plotting
fancy_scientific <- function(l) {
  # Axis-label formatter: turns numbers into plotmath expressions of the form
  # 'm' %*% 10^k, collapsing a mantissa of exactly 1 to plain 10^k.
  #
  # l: numeric vector (e.g. axis breaks).
  # Returns an expression vector with one element per input value.

  # Ordered (pattern, replacement) rewrites; each one is applied to the output
  # of the previous one.
  rewrites <- list(
    # quote the part before the exponent to keep all the digits
    c("^(.*)e", "'\\1'e"),
    # remove + after exponent, if exists. E.g.: (e^+2 -> e^2)
    c("e\\+", "e"),
    # turn the 'e' into plotmath format
    c("e", "%*%10^"),
    # convert 1x10^ or 1.000x10^ -> 10^
    c("\\'1[\\.0]*\\'\\%\\*\\%", "")
  )

  # start from the character form in scientific notation
  plotmath_txt <- Reduce(
    function(txt, rule) gsub(rule[1], rule[2], txt),
    rewrites,
    format(l, scientific = TRUE)
  )

  # return this as an expression
  parse(text = plotmath_txt)
}

# Minor-break grid for log10 axes: every multiple 1..10 of each power of ten
# from 10^-8 to 10^8, flattened column by column.
log10minorbreaks <- as.numeric(outer(1:10, 10^(-8:8)))

## Setting all the directories for operations (relative to the working dir)
project_dir <- getwd()
data_dir <- file.path(project_dir, "datafiles")
save_dir <- file.path(project_dir, "save_csv", "memory_Treg")

## modelname -- also the stem of the per-chain CSV file names
modelname <- "LinearDLD_memTreg_naiTreg"

## Reads one Stan CSV output file: <save_dir>/<modelname><chain_tag>.csv
read_chain_csv <- function(chain_tag) {
  csv_path <- file.path(save_dir, paste0(modelname, chain_tag, ".csv"))
  read_stan_csv(csv_path)
}

# compiling multiple stan objects together that ran on different nodes
stanfit1 <- read_chain_csv("_a1")
stanfit2 <- read_chain_csv("_a2")
stanfit3 <- read_chain_csv("_a3")
stanfit4 <- read_chain_csv("_a4")
stanfit5 <- read_chain_csv("_a5")

## merged stanfit object (chains are merged in the order 5, 4, 3, 2, 1)
fit <- sflist2stanfit(list(stanfit5, stanfit4, stanfit3, stanfit2, stanfit1))


# finding the parameters used in the model:
# every entry of fit@model_pars that comes before "sigma_counts" is treated
# as a model parameter, so the position of "sigma_counts" minus one gives the
# number of parameters.
sigma_pos <- match("sigma_counts", fit@model_pars)
if (is.na(sigma_pos)) {
  # Without the marker the parameter list cannot be delimited; fail loudly
  # instead of building a nonsensical index.
  stop("'sigma_counts' not found in fit@model_pars", call. = FALSE)
}
num_pars_m <- sigma_pos - 1
# seq_len() yields an empty selection (not c(1, 0)) if the marker comes first
params_m <- fit@model_pars[seq_len(num_pars_m)]


## running plotting script for model fits
source("scripts/stan_extract_forplotting_mem.R")

### Extract parameter samples as a dataframe (one row per posterior draw)
par_dist <- as.data.frame(fit, pars = params_m)

quant1.1 <- function(y1_0, y2_0, psi, lambda_D, rho_D, mu){
  # Returns (exp(y1_0) + exp(y2_0)) minus
  # psi * 10^6.2 / (lambda_D + rho_D + mu - 0.027).
  # All arguments are numeric; the arithmetic is element-wise, so vectors of
  # equal length give a vector result.
  exp_sum <- exp(y1_0) + exp(y2_0)
  scaled_psi <- psi * 10^6.2 / (lambda_D + rho_D + mu - 0.027)
  exp_sum - scaled_psi
}

## quant1.1 evaluated for every posterior draw.
## The function is plain element-wise arithmetic, so one call on the whole
## columns replaces the per-row loop (which grew the vector one element at a
## time and iterated over 1:nrow()).
q11_df <- quant1.1(
  par_dist$y1_0, par_dist$y2_0, par_dist$psi,
  par_dist$lambda_D, par_dist$rho_D, par_dist$mu
)


quantile(q11_df, probs = c(0.025, 0.5, 0.975))
mean(q11_df)


quant1.2 <- function(lambda_I, lambda_D){
  # Element-wise difference lambda_D - lambda_I (note the argument order:
  # the first argument is the one subtracted).
  rate_gap <- lambda_D - lambda_I
  rate_gap
}

## quant1.2 (lambda_D - lambda_I) for every posterior draw, indexed by draw.
## Built in one vectorised call instead of growing the data frame row by row.
## The value column keeps the name V2 because the plots refer to it.
q12_df <- data.frame(
  id = seq_len(nrow(par_dist)),
  V2 = quant1.2(par_dist$lambda_I, par_dist$lambda_D)
)

## trace of the quantity across draws
ggplot(q12_df) +
  geom_line(aes(x = id, y = V2)) +
  ylim(-1, 1)

quant3.2 <- function(mu, lambda_I){
  # Returns nu - mu - lambda_I (element-wise) with nu fixed at 0.027.
  nu <- 0.027
  nu - mu - lambda_I
}

## quant3.2 (nu - mu - lambda_I) for every posterior draw, indexed by draw.
## This previously called quant1.2(mu, lambda_I), which returns
## lambda_I - mu and left quant3.2 unused.
## The value column keeps the name V2 because the plot refers to it.
q32_df <- data.frame(
  id = seq_len(nrow(par_dist)),
  V2 = quant3.2(par_dist$mu, par_dist$lambda_I)
)

## trace of the quantity across draws
ggplot(q32_df) +
  geom_line(aes(x = id, y = V2)) +
  ylim(-1, 1)

## distribution of the q12 values as a single boxplot
ggplot(q12_df, aes(y = V2)) +
  geom_boxplot() +
  ylim(-1, 1)

## 2.5% / 50% / 97.5% quantiles and mean of exp(y1_0) + exp(y2_0)
y12_exp_sum <- exp(par_dist$y1_0) + exp(par_dist$y2_0)
quantile(y12_exp_sum, probs = c(0.025, 0.5, 0.975))
mean(y12_exp_sum)

## same summaries for exp(y3_0) + exp(y4_0)
y34_exp_sum <- exp(par_dist$y3_0) + exp(par_dist$y4_0)
quantile(y34_exp_sum, probs = c(0.025, 0.5, 0.975))
mean(y34_exp_sum)


## calculate residence and interdivision times, one row per posterior draw
## (vectorised over the draws instead of growing the data frame row by row)
par_est <- data.frame(
  ## 1/rate of division --> interdivision time
  div_times_fast = 1 / par_dist$rho_D,
  div_times_slow = 1 / par_dist$rho_I,
  ## 1/rate of loss --> residence time
  ## fast: loss rate is delta_D + mu
  resid_times_fast = 1 / (par_dist$delta_D + par_dist$mu),
  ## slow: loss rate is delta_I. This used to be 1 / (delta_I + rho_I), which
  ## adds the division rate to the loss rate and disagrees with the
  ## 1 / delta_I summary computed further down in this script.
  resid_times_slow = 1 / par_dist$delta_I
)

## Posterior summaries: for each quantity the 2.5%, 50% and 97.5% quantiles,
## followed by the mean.
summary_probs <- c(0.025, 0.5, 0.975)

## delta_D + mu - rho_D
net_rate_D <- par_dist$delta_D + par_dist$mu - par_dist$rho_D
quantile(net_rate_D, probs = summary_probs)
mean(net_rate_D)

## delta_I - rho_I
net_rate_I <- par_dist$delta_I - par_dist$rho_I
quantile(net_rate_I, probs = summary_probs)
mean(net_rate_I)

## 1 / (delta_D + mu)
inv_loss_D <- 1 / (par_dist$delta_D + par_dist$mu)
quantile(inv_loss_D, probs = summary_probs)
mean(inv_loss_D)

## 1 / rho_D
inv_rho_D <- 1 / par_dist$rho_D
quantile(inv_rho_D, probs = summary_probs)
mean(inv_rho_D)

## 1 / delta_I
inv_delta_I <- 1 / par_dist$delta_I
quantile(inv_delta_I, probs = summary_probs)
mean(inv_delta_I)

## 1 / rho_I
inv_rho_I <- 1 / par_dist$rho_I
quantile(inv_rho_I, probs = summary_probs)
mean(inv_rho_I)


## summarize across par dist: one row per par_est column holding the 5% and
## 95% quantiles (lb, ub) and the mean (estimate), plus display labels derived
## from the column name ("slow"/"div" substrings).
parplot <- par_est %>%
  pivot_longer(
    cols = everything(),
    names_to = "parname", values_to = "par_value"
  ) %>%
  group_by(parname) %>%
  summarize(
    lb = quantile(par_value, probs = 0.05),
    estimate = mean(par_value),
    ub = quantile(par_value, probs = 0.95)
  ) %>%
  mutate(
    subpop = ifelse(grepl("slow", parname), "Slow", "Fast"),
    param = ifelse(
      grepl("div", parname),
      "Inter-division \ntime",
      "Residence \ntime"
    )
  )


## function that describes changes in total pool size of source population:
## 10^6.2 at time t0 = 40, changing exponentially at rate -0.027 per unit of
## ts_seq thereafter. Works element-wise on a numeric vector of times.
sp_numbers <- function(ts_seq) {
  t0 <- 40.0
  basl <- 6.2
  nu <- 0.027
  elapsed <- ts_seq - t0
  # peripheral naive Tregs numbers
  10^basl * exp(-nu * elapsed)
}

## Precursor size changing with time: sp_numbers() at each prediction time.
## vapply() always returns a numeric vector, whereas sapply() silently returns
## a list when ts_pred1 is empty.
sp_vec <- vapply(ts_pred1, sp_numbers, numeric(1))

## function to evaluate daily influx from source into target population:
## psi (the rate of daily influx) times the source pool size at ts_seq
theta_spline <- function(ts_seq, psi) {
  psi * sp_numbers(ts_seq)
}

## Pulls the posterior draws of one quantity from the merged fit as a
## data frame.
draws_of <- function(par_name) {
  as.data.frame(fit, pars = par_name)
}

## target population size -- is same across age BMT groups
target_df <- draws_of("counts_mean_pred1")

## fraction of fast cells in the total mem Treg
## fast fraction is mapped on the earliest age at BMT and thus would be same across age BMT groups
fastfrac_df <- draws_of("fast_fraction_pred1")


## Derived quantities for every posterior draw (rows) and prediction time
## (columns). Each one is computed with matrix arithmetic instead of filling
## a data frame one cell at a time; dimnames are dropped so the columns come
## out as V1, V2, ... exactly as the cell-wise assignment produced them.
fastfrac_mat <- unname(as.matrix(fastfrac_df))
target_mat <- unname(as.matrix(target_df))

### influx into fast as a fraction of total (fast + slow)
## theta_spline(time_j, psi_i) / counts[i, j]
pred_times <- ts_pred1[seq_len(ncol(target_df))]
theta_mat <- outer(
  par_dist$psi, pred_times,
  function(psi, ts) theta_spline(ts, psi)
)
influx_fast <- as.data.frame(unname(theta_mat / target_mat))

### influx into slow as a fraction of slow
## mu x fast_counts/slow_counts
## mu x fast_fraction x N/(1 - fast_fraction) x N
## (par_dist$mu has one value per draw and is recycled down each column)
influx_slow <- as.data.frame(
  par_dist$mu * fastfrac_mat / (1 - fastfrac_mat)
)

### slow cells as a fraction of the total
## (1 - fast_fraction)
slow_frac <- as.data.frame(1 - fastfrac_mat)

### number of fast cells
## fast_fraction x N
fast_count <- as.data.frame(fastfrac_mat * target_mat)

### number of slow cells
## (1 - fast_fraction) x N
slow_count <- as.data.frame((1 - fastfrac_mat) * target_mat)


## posterior predictive distributions
## NOTE(review): the data objects plotted below (Counts_pred, Nfd_pred,
## ki_pred, the *_sigma_obs tables, Treg_memory_Nfd, Treg_memory_ki) and
## ts_pred1 are not created anywhere in this file; presumably they come from
## the sourced script and/or this call -- confirm.
plot_function(fitobject = fit) ### this is from source("scripts/stan_extract_forplotting_mem.R")

## legend Key
## labels for the age-at-BMT bins, in weeks (used as colour-legend labels)
legn_labels <- c("6-8", "8-10", "10-12", "12-25")

## counts
## Axis scales are built first so the layer stack below stays readable.
## Both axes are on a log10 scale.
p1_x_scale <- scale_x_continuous(
  limits = c(60, 460),
  trans = "log10", breaks = c(75, 150, 300)
)
p1_y_scale <- scale_y_continuous(
  limits = c(1e5, 1e7), trans = "log10", breaks = c(1e4, 1e5, 1e6, 1e7, 1e8),
  minor_breaks = log10minorbreaks, labels = fancy_scientific
)

## Layers, bottom to top: lb-ub ribbon and median line from Counts_pred,
## error bars from Counts_sigma_obs, points from Treg_memory_Nfd.
## Everything is coloured by ageBMT_bin; legends are switched off.
p1 <- ggplot() +
  geom_ribbon(
    mapping = aes(x = timeseries, ymin = lb, ymax = ub, fill = ageBMT_bin),
    data = Counts_pred, alpha = 0.2
  ) +
  geom_line(
    mapping = aes(x = timeseries, y = median, color = ageBMT_bin),
    data = Counts_pred
  ) +
  geom_errorbar(
    mapping = aes(x = timeseries, ymin = lb, ymax = ub, col = ageBMT_bin),
    data = Counts_sigma_obs,
    alpha = 0.5, width = 0.01
  ) +
  geom_point(
    mapping = aes(x = age.at.S1K, y = total_counts, color = ageBMT_bin),
    data = Treg_memory_Nfd, size = 2
  ) +
  p1_x_scale +
  p1_y_scale +
  scale_color_discrete(
    name = "Host age at \n BMT (Wks)",
    labels = legn_labels
  ) +
  labs(title = "Total counts", y = NULL, x = "Host age (days)") +
  guides(fill = "none", col = "none") +
  my_theme

## nfd
## Axis scales (linear) defined up front.
p2_x_scale <- scale_x_continuous(limits = c(1, 450), breaks = c(0, 200, 400))
p2_y_scale <- scale_y_continuous(
  limits = c(0, 1.0),
  breaks = c(0, 0.25, 0.5, 0.75, 1.0)
)

## Layers, bottom to top: lb-ub ribbon and median line from Nfd_pred, points
## from Treg_memory_Nfd, error bars from Nfd_sigma_obs. Coloured by
## ageBMT_bin; legends are switched off.
p2 <- ggplot() +
  geom_ribbon(
    mapping = aes(x = timeseries, ymin = lb, ymax = ub, fill = ageBMT_bin),
    data = Nfd_pred, alpha = 0.15
  ) +
  geom_line(
    mapping = aes(x = timeseries, y = median, color = ageBMT_bin),
    data = Nfd_pred
  ) +
  geom_point(
    mapping = aes(x = age.at.S1K, y = Nfd, color = ageBMT_bin),
    data = Treg_memory_Nfd, size = 2
  ) +
  geom_errorbar(
    mapping = aes(x = timeseries, ymin = lb, ymax = ub, col = ageBMT_bin),
    data = Nfd_sigma_obs,
    alpha = 0.5, width = 5
  ) +
  p2_x_scale +
  p2_y_scale +
  scale_color_discrete(name = "Host age at BMT (Wks)", labels = legn_labels) +
  labs(x = "Host age (days)", y = NULL, title = "Normalised Chimerism") +
  guides(fill = "none", col = "none") +
  my_theme +
  theme(legend.position = "top", legend.direction = "horizontal")


## facet strip labels, keyed by the ageBMT_bin values
fac_labels <- c(
  agebin1 = "6-8 weeks",
  agebin2 = "8-10 weeks",
  agebin3 = "10-12 weeks",
  agebin4 = "12-25 weeks"
)


## ki prop
## All proportions are multiplied by 100 for display.
## The two error-bar layers share one mapping.
ki_errbar_aes <- aes(
  x = timeseries, ymin = lb * 100, ymax = ub * 100, col = subcomp
)

## Layers, bottom to top: lb-ub ribbon and median line from ki_pred, points
## from Treg_memory_ki, then error bars from the donor and host tables.
## Coloured by subcomp, one facet per ageBMT_bin; legends are switched off.
p3 <- ggplot() +
  geom_ribbon(
    mapping = aes(
      x = timeseries, ymin = lb * 100, ymax = ub * 100, fill = subcomp
    ),
    data = ki_pred,
    alpha = 0.25
  ) +
  geom_line(
    mapping = aes(x = timeseries, y = median * 100, color = subcomp),
    data = ki_pred
  ) +
  geom_point(
    mapping = aes(x = age.at.S1K, y = ki_prop * 100, color = subcomp),
    data = Treg_memory_ki, size = 2
  ) +
  geom_errorbar(
    mapping = ki_errbar_aes, data = ki_donor_sigma_obs,
    alpha = 0.5, width = 5
  ) +
  geom_errorbar(
    mapping = ki_errbar_aes, data = ki_host_sigma_obs,
    alpha = 0.5, width = 5
  ) +
  facet_wrap(~ageBMT_bin, scales = "free", labeller = as_labeller(fac_labels)) +
  scale_color_manual(values = c(2, 4)) +
  scale_x_continuous(limits = c(60, 450), breaks = c(0, 200, 400)) +
  scale_y_continuous(limits = c(0, 60), breaks = c(0, 20, 40, 60)) +
  labs(
    x = "Host age (days)", y = NULL,
    title = expression("Proportion of Ki67"^"hi" ~ " cells")
  ) +
  guides(fill = "none", col = "none") +
  my_theme


### parameter plots --
## interdivision and residence times: mean (point) with lb-ub error bar for
## each subpopulation, one facet per quantity. The x / y / colour mappings
## are shared by both layers, so they are declared once on the plot.
p4 <- ggplot(parplot, aes(x = subpop, y = estimate, colour = subpop)) +
  geom_errorbar(
    aes(ymin = lb, ymax = ub),
    width = 0.2, linetype = 1, position = position_dodge(width = 0.4)
  ) +
  geom_point(position = position_dodge(width = 0.4), size = 4) +
  facet_wrap(~param) +
  scale_color_manual(values = c("#004594", "#BE749E"), name = NULL) +
  labs(y = "Days", x = NULL, title = NULL) +
  my_theme +
  theme(
    axis.text.x = element_blank(),
    legend.position = c(0.78, 0.86)
  )



## quick side-by-side look at the chimerism and parameter panels
pww <- cowplot::plot_grid(p2, p4, ncol = 2, nrow = 1)
print(pww)

## Daily influx into the fast subset as a fraction of total, over time:
## per prediction time (column of influx_fast) the 5%, 50% and 95% quantiles
## across posterior draws, drawn as a median line with a ribbon on log-log
## axes.
p5 <- influx_fast %>%
  # factor_key = TRUE keeps the columns in their original (time) order
  gather(factor_key = TRUE) %>%
  group_by(key) %>%
  summarize(
    lb = quantile(value, probs = 0.05),
    estimate = quantile(value, probs = 0.5),
    ub = quantile(value, probs = 0.95)
  ) %>%
  bind_cols("timeseries" = c(ts_pred1)) %>%
  ggplot() +
  # line thickness goes through `linewidth`; `size` is deprecated for lines
  geom_line(
    aes(x = timeseries, y = estimate),
    linewidth = 0.75, col = "#004594"
  ) +
  # geom_line(aes(x=timeseries, y=sp_vec), col=2)+
  geom_ribbon(aes(x = timeseries, ymin = lb, ymax = ub),
    fill = "#004594", alpha = 0.2
  ) +
  labs(
    x = "Host age (days)", y = NULL,
    title = "Daily influx into the fast subset \nas a fraction of total"
  ) +
  scale_y_continuous(limits = c(1e-8, 2e-2),
                     trans = "log10",
                     breaks = c(1e-8, 1e-6, 1e-4, 1e-2),
                     minor_breaks = log10minorbreaks,
                     labels = fancy_scientific) +
  scale_x_log10() +
  my_theme


## Daily influx into the slow subset as a fraction of slow, over time:
## per prediction time (column of influx_slow) the 5%, 50% and 95% quantiles
## across posterior draws, drawn as a median line with a ribbon (log x axis).
p6 <- influx_slow %>%
  # factor_key = TRUE keeps the columns in their original (time) order
  gather(factor_key = TRUE) %>%
  group_by(key) %>%
  summarize(
    lb = quantile(value, probs = 0.05),
    estimate = quantile(value, probs = 0.5),
    ub = quantile(value, probs = 0.95)
  ) %>%
  bind_cols("timeseries" = ts_pred1) %>%
  ggplot() +
  # line thickness goes through `linewidth`; `size` is deprecated for lines
  geom_line(
    aes(x = timeseries, y = estimate),
    linewidth = 0.75, col = "#004594"
  ) +
  # geom_line(aes(x=timeseries, y=sp_vec), col=2)+
  geom_ribbon(aes(x = timeseries, ymin = lb, ymax = ub),
              alpha = 0.2, fill = "#004594"
  ) +
  labs(
    x = "Host age (days)", y = NULL,
    title = "Daily influx into the slow subset \nas a fraction of slow"
  ) +
  scale_y_continuous(limits = c(1e-3, 0.04),
                     breaks = c(0, 0.02, 0.04),
                     minor_breaks = log10minorbreaks) +
  scale_x_log10() +
  my_theme


## Fraction of slow cells over time.
## Step 1: per prediction time (column of slow_frac), the 5%, 50% and 95%
## quantiles across posterior draws, with the prediction times attached.
slow_frac_summary <- slow_frac %>%
  gather(factor_key = TRUE) %>%
  group_by(key) %>%
  summarize(
    lb = quantile(value, probs = 0.05),
    estimate = quantile(value, probs = 0.5),
    ub = quantile(value, probs = 0.95)
  ) %>%
  bind_cols("timeseries" = ts_pred1)

## Step 2: median line with lb-ub ribbon, log x axis.
p7 <- ggplot(slow_frac_summary) +
  geom_line(
    aes(x = timeseries, y = estimate),
    linewidth = 0.75, col = "#004594"
  ) +
  # geom_line(aes(x=timeseries, y=sp_vec), col=2)+
  geom_ribbon(
    aes(x = timeseries, ymin = lb, ymax = ub),
    alpha = 0.2, fill = "#004594"
  ) +
  labs(
    x = "Host age (days)", y = NULL,
    title = "Fraction of slow in Treg memory"
  ) +
  scale_y_continuous(
    limits = c(0, 1),
    breaks = c(0, 0.2, 0.4, 0.6, 0.8, 1),
    minor_breaks = log10minorbreaks
  ) +
  scale_x_log10() +
  my_theme

## Summarises a draws-by-time data frame: per column the 5%, 50% and 95%
## quantiles across draws, with the prediction times and a constant
## subpopulation label attached.
summarise_pool_draws <- function(draws_df, subpop_label) {
  draws_df %>%
    gather(factor_key = TRUE) %>%
    group_by(key) %>%
    summarize(
      lb = quantile(value, probs = 0.05),
      estimate = quantile(value, probs = 0.5),
      ub = quantile(value, probs = 0.95)
    ) %>%
    bind_cols(
      "timeseries" = ts_pred1,
      "subpop" = rep(subpop_label, length(ts_pred1))
    )
}

fast_count_plot <- summarise_pool_draws(fast_count, "Fast")

slow_count_plot <- summarise_pool_draws(slow_count, "Slow")

## both subpopulations stacked into one table for plotting
fast_slow_count_plot <- rbind(fast_count_plot, slow_count_plot)

## Pool sizes of the fast and slow subsets over time: median line with lb-ub
## ribbon, one facet per subpopulation, log-log axes, legends switched off.
p8 <- ggplot(fast_slow_count_plot) +
  # line thickness goes through `linewidth`; `size` is deprecated for lines
  geom_line(
    aes(x = timeseries, y = estimate, col = subpop),
    linewidth = 0.75
  ) +
  geom_ribbon(
    data = fast_slow_count_plot,
    aes(x = timeseries, ymin = lb, ymax = ub, fill = subpop),
    alpha = 0.2
  ) +
  labs(
    x = "Host age (days)", title = NULL,
    y = "Pool size"
  ) +
  scale_color_manual(values = c("#004594", "#BE749E"), name = NULL) +
  scale_fill_manual(values = c("#004594", "#BE749E"), name = NULL) +
  facet_wrap(~subpop) +
  scale_y_continuous(limits = c(2e3, 1e6),
                     trans = "log10",
                     breaks = c(1e4, 1e5, 1e6),
                     minor_breaks = log10minorbreaks,
                     labels = fancy_scientific) +
  scale_x_log10() +
  my_theme +
  guides(col = "none", fill = "none") +
  theme(
    legend.position = c(0.18, 0.86)
  )




###

# Extract the legend to label across plot row in the plot_grid:
# a throw-away scatter plot whose only purpose is to carry the ageBMT_bin
# colour legend, placed at the bottom.
legend_scale <- scale_color_discrete(
  name = "Host age at BMT (Wks)",
  labels = legn_labels
)
p_legend <- ggplot(Treg_memory_Nfd) +
  geom_point(aes(x = age.at.S1K, y = Nfd, color = ageBMT_bin)) +
  legend_scale +
  theme(legend.position = "bottom") +
  my_theme

legend_across <- cowplot::get_legend(p_legend)


# Create a title --- for a column in the plot_grid function
# Each title is an empty cowplot canvas carrying one bold, left-anchored label.
title1_label <- cowplot::draw_label(
  "Linear model fits",
  fontface = "bold", size = 12, x = 0.45, hjust = 0.0
)
title1 <- cowplot::ggdraw() + title1_label

title2_label <- cowplot::draw_label(
  "Parameter estimates",
  fontface = "bold", size = 11, x = 0.45, hjust = 0.0
)
title2 <- cowplot::ggdraw() + title2_label

## first row: total counts (A) and normalised chimerism (B)
ptop <- cowplot::plot_grid(p1, p2,
  labels = c("A", "B"), label_size = 11,
  ncol = 2, nrow = 1
)

## with the title and the shared colour legend stacked above it
pfirst <- cowplot::plot_grid(title1, legend_across, ptop,
  nrow = 3, rel_heights = c(0.1, 0.05, 1),
  align = "vh"#, axis = "bt"
)

## middle row: Ki67 panel (C)
pmiddle <- cowplot::plot_grid(p3,
  labels = c("C"), label_size = 11,
  ncol = 1, nrow = 1,
  align = "vh"#, axis = "bt"
)

## bottom row: three panels, so three labels (the third panel used to be
## left without one because only "D" and "E" were supplied)
pbottom <- cowplot::plot_grid(p4, p5, p6,
  labels = c("D", "E", "F"), label_size = 11,
  rel_widths = c(1.0, 1.0, 1.0),
  ncol = 3, nrow = 1, align = "vh", axis = "bt"
)

### bottom row with the title
## NOTE(review): plastrow is built but the final figure below uses pbottom,
## so title2 never appears -- confirm whether that is intended.
plastrow <- cowplot::plot_grid(title2, pbottom,
  nrow = 2, rel_heights = c(0.05, 1),
  align = "vh", axis = "bt"
)

## full figure: the three rows stacked vertically
p_all <- cowplot::plot_grid(pfirst, pmiddle, pbottom,
  nrow = 3,
  align = "vh", axis = "btlr",
  rel_heights = c(1, 1.6, 1.0)
)

print(p_all)

#
# paligned <- cowplot::align_plots(ptop, pmiddle, p7,
#                                  axis = "btlr", align = "v")
# plast <- cowplot::plot_grid(paligned[[3]], p8, p9, ncol = 3, align = "h", axis = "bt")
# print(plast)
# final_plot <- cowplot::plot_grid(paligned[[1]], paligned[[2]], plast,
#                                  nrow = 3,
#                                  align = "v", axis = "rl")
# print(final_plot)


## Save the assembled figure. The plot is passed explicitly: without `plot =`
## ggsave() writes ggplot2::last_plot(), which depends on whatever plot object
## happened to be created or modified last.
ggsave(
  filename = "Figure_4.pdf", plot = p_all, device = "pdf",
  height = 10, width = 7.5, units = "in"
)

