## clearing the environment
# Remove every object listed by ls() from the current workspace so the script
# starts from a clean slate.
# NOTE(review): because of this wipe, any name used further down that is not
# created in this file (e.g. M2, fit1, fit2, fit3, counts_data) will not exist
# when the script is run top to bottom -- confirm where those should come from.
rm(list = ls())  
# run a garbage collection after the wipe
gc()    

library(rstan)
library(tidyverse)
library(kableExtra)
####################################################################################

#### plotting style
# Shared text sizes and legend styling; each ggplot in this script adds
# `myTheme` on top of the default theme.
myTheme <- theme(
  axis.text = element_text(size = 20),
  axis.title = element_text(size = 20),
  strip.text = element_text(size = 19),
  plot.title = element_text(size = 20, hjust = 0.5),
  legend.text = element_text(size = 18),
  legend.key = element_blank(),
  legend.background = element_blank()
)

# use the black-and-white theme as the default for the rest of the plots
theme_set(theme_bw())


####### plotting
# Axis-label formatter: turns numbers into plotmath expressions of the form
# mantissa %*% 10^exponent (a mantissa of exactly 1 is dropped, giving 10^k).
#
# l : numeric vector of values to label
#
# Returns an expression vector obtained by parsing the rewritten labels.
fancy_scientific <- function(l) {
  # rewrite rules (pattern, replacement), applied in this order to the
  # scientific-notation strings
  rewrites <- list(
    # quote the mantissa so that all of its digits are kept
    c("^(.*)e", "'\\1'e"),
    # drop the "+" of a positive exponent (e+2 -> e2)
    c("e\\+", "e"),
    # switch the "e" to plotmath multiplication by a power of ten
    c("e", "%*%10^"),
    # a mantissa of 1 (or 1.000...) is redundant: 1 x 10^k -> 10^k
    c("\\'1[\\.0]*\\'\\%\\*\\%", "")
  )
  label_text <- Reduce(
    function(txt, rule) gsub(rule[1], rule[2], txt),
    rewrites,
    format(l, scientific = TRUE)
  )
  # hand the strings back as plotmath expressions
  parse(text = label_text)
}

# candidate minor-break positions for log10 axes: 1..10 times every power of
# ten from 10^-3 to 10^8, flattened into a plain numeric vector
log10minorbreaks <- as.numeric(outer(1:10, 10^(-3:8)))

## Directories used by this script, all relative to the working directory
projectDir <- getwd()
dataDir <- file.path(projectDir, "datafiles")
outputDir <- file.path(projectDir, "output_fit")
saveDir <- file.path(projectDir, "save_csv", "naive_Treg")


## model-specific details that need to be changed for every run
M1 <- "LinearDLD_memTreg_naiTreg"

# best fit params
# Reads the saved parameter matrix for model M1, converts the rates into time
# scales, and summarises every derived column by its median and its 2.5% /
# 97.5% quantiles (presumably one row of the CSV per posterior draw -- confirm).
params_mat <- read.csv(file = paste0("save_csv/parsMat_", M1, ".csv")) %>%
  select(-X) %>%
  mutate(
    Lifespan_fast = 1 / (lambda_D + rho_D),
    Lifespan_slow = 1 / (lambda_I + rho_I),
    clonal_halftime_fast = log(2) / lambda_D,
    clonal_halftime_slow = log(2) / lambda_I,
    Maturation_time = 1 / mu,
    Interdivision_time_fast = 1 / (rho_D),
    Interdivision_time_slow = 1 / (rho_I)
  ) %>%
  # keep only the columns whose names contain "fast", "slow" or "Maturation"
  select(contains("fast"), contains("slow"), contains("Maturation")) %>%
  # long format; factor_key keeps the column order as the factor level order
  gather(factor_key = TRUE) %>%
  group_by(key) %>%
  summarize(
    estimate = quantile(value, probs = 0.5),
    lb = quantile(value, probs = 0.025),
    ub = quantile(value, probs = 0.975)
  ) %>%
  mutate(
    # display name of the quantity, derived from the column name
    parname = case_when(
      grepl("Life", key) ~ "Residence time",
      grepl("Inter", key) ~ "Interdivision time",
      grepl("Maturation", key) ~ "Maturation time",
      TRUE ~ "Clonal halflife"
    ),
    # everything not tagged "slow" (including Maturation_time) is labelled Fast
    compname = if_else(grepl("slow", key), "Slow", "Fast")
  )

# text inserted into the table caption to say which model is summarised
modelname <- "linear model with naive Treg subset as the precursor population"
#modelname <- "Incumbent model with thymic SP CD4 subset as the precursor population"

# Formatted table of the derived time scales, rounded to 0 decimals.
# `estimate` in params_mat is the 0.5 quantile of each quantity, so its column
# is headed "Median" (it was previously mislabelled "Mean").
partab_export <- params_mat %>%
  arrange(desc(parname)) %>%
  mutate_if(is.numeric, ~round(., 0)) %>%
  select(parname, compname, estimate, lb, ub) %>%
  kable(align = c("l", "l", "c", "c", "c"),
        caption = paste0("Estimates from the ", modelname),
        col.names = c("Parameter",
                      "Compartment",
                      "Median",
                      "2.5%",
                      "97.5%")) %>%
  #add_header_above(c(" " = 3, "Credible interval" = 2 )) %>%
  add_header_above(c(" " = 2, "Estimates in days" = 3)) %>%
  # merge repeated parameter names in the first column into one cell
  collapse_rows(columns = 1, valign = "top") %>%
  kable_classic_2(full_width = FALSE) %>%
  kable_styling(font_size = 16, html_font = "helvetica") #%>%
  #footnote(alphabet = "Error bars depict 1 S.D. around the mean estimate")


# show the table, then write it to a PDF (relative path, i.e. resolved against
# the working directory)
partab_export

kableExtra::save_kable(partab_export, file = paste0(M1, "_partab.pdf"))

# drop the clonal half-life rows; the remaining time scales are plotted below
param_mem <- params_mat %>%
  filter(!grepl("Clonal", parname))


# Estimates with lb-ub error bars, one facet per parameter name, coloured by
# compartment.
ggplot(param_mem, aes(y = estimate, x = factor(key), col = compname)) +
  geom_errorbar(
    aes(y = estimate, ymin = lb, ymax = ub, x = key),
    width = 0.15, linetype = 1, position = position_dodge(0.4)
  ) +
  geom_point(position = position_dodge(width = 0.4), stat = "identity", size = 4) +
  facet_wrap(~ factor(parname), scales = "free") +
  scale_color_manual(values = c(4, 2), name = NULL) +
  # make sure every facet's y range covers at least 1 to 290
  expand_limits(y = c(1, 290)) +
  #scale_y_log10(limits=c(1, 300), breaks=c(1, 3, 10, 30, 100), minor_breaks=log10minorbreaks)+
  labs(y = "Days") +
  myTheme +
  theme(
    axis.text.x = element_blank(),
    axis.title.x = element_blank(),
    legend.position = c(0.917, 0.9)
  )


# write the most recently drawn plot to the output directory
ggsave(
  filename = file.path(outputDir, paste0(M1, "_Pars.pdf")),
  plot = last_plot(),
  device = "pdf",
  height = 4.5,
  width = 8.5
)





# compiling multiple stan objects together that ran on different nodes
# NOTE(review): `M2` is never assigned in this file (only `M1` is), so these
# reads fail as written -- confirm which model name was intended here.
stanfit1 <- read_stan_csv(file.path(saveDir, paste0(M2, "_a1", ".csv")))
stanfit2 <- read_stan_csv(file.path(saveDir, paste0(M2, "_a2", ".csv")))
stanfit3 <- read_stan_csv(file.path(saveDir, paste0(M2, "_a3", ".csv")))
stanfit4 <- read_stan_csv(file.path(saveDir, paste0(M2, "_a4", ".csv")))
stanfit5 <- read_stan_csv(file.path(saveDir, paste0(M2, "_a5", ".csv")))

fit_M2 <- sflist2stanfit(list(stanfit5, stanfit4, stanfit3, stanfit2, stanfit1))


# finding the parameters used in the model
# "sigma_counts_naive" is used as a marker in `model_pars`: every entry listed
# before it is kept, so its position minus one is the number of parameters
num_pars_M2 <- match("sigma_counts_naive", fit_M2@model_pars) - 1
if (is.na(num_pars_M2)) {
  stop("'sigma_counts_naive' not found in fit_M2@model_pars", call. = FALSE)
}
# seq_len() gives an empty selection when the marker is the first entry,
# whereas 1:0 would wrongly pick element 1
params_M2 <- fit_M2@model_pars[seq_len(num_pars_M2)]



# Pull the columns used for plotting out of a monitor() summary table.
#
# ptable : summary returned by monitor(), one row per parameter
# n_pars : number of leading rows (parameters) to keep
#
# Returns a data frame with columns Estimates, par_lb and par_ub taken from
# columns 1, 4 and 8 of `ptable` (presumably the mean and the 2.5% / 97.5%
# quantiles under monitor()'s default `probs` -- confirm).
extract_estimates <- function(ptable, n_pars) {
  # seq_len() selects nothing for n_pars == 0, where 1:0 would select row 1
  out <- data.frame(ptable[seq_len(n_pars), c(1, 4, 8)])
  names(out) <- c('Estimates', 'par_lb', 'par_ub')
  out
}

# NOTE(review): fit1, fit2, fit3, params_M1, num_pars_M1, params_M3 and
# num_pars_M3 are not created anywhere in this file, and params_M2 /
# num_pars_M2 are derived from `fit_M2` but applied to `fit2` here -- confirm
# the intended objects.
ptable1 <- monitor(as.array(fit1, pars = params_M1), warmup = 0, print = FALSE)
out_table1 <- extract_estimates(ptable1, num_pars_M1)

ptable2 <- monitor(as.array(fit2, pars = params_M2), warmup = 0, print = FALSE)
out_table2 <- extract_estimates(ptable2, num_pars_M2)


ptable3 <- monitor(as.array(fit3, pars = params_M3), warmup = 0, print = FALSE)
out_table3 <- extract_estimates(ptable3, num_pars_M3)


# Stack the three estimate tables and tag every row with its parameter name
# and the population it came from; parameters whose name contains "_0" are
# dropped. The names are taken from the individual tables rather than from the
# stacked one (presumably because rbind() can alter duplicated row names).
df_pars <- rbind(out_table1, out_table2, out_table3) %>%
  mutate(
    parname = c(
      rownames(out_table1),
      rownames(out_table2),
      rownames(out_table3)
    ),
    Source = rep(
      c("naive Treg", "naive conv", "EM conv"),
      times = c(num_pars_M1, num_pars_M2, num_pars_M3)
    )
  ) %>%
  filter(!grepl("_0", parname))

# Invisible anchor values, one per parameter and source; they are drawn with
# geom_blank() in the plot below, so they only influence the axis ranges.
blank_data <- data.frame(
  parname = rep(c("psi", "mu", "rho_D", "lambda_D", "lambda_I", "rho_I"), times = 3),
  # Param = rep(c("Rate of influx", "Rate of division displaceable", "Rate of loss displaceable", "Rate of division Incumbent"), 3),
  Source = rep(c("naive Treg", "naive conv", "EM conv"), each = 6),
  Estimates = c(
    # naive Treg
    3e-5, 0.003, 0.3, 1e-5, 1e-5, 3e-3,
    # naive conv
    0.03, 0.3, 0.03, 0.1, 0.1, 0.1,
    # EM conv
    0.03, 0.3, 0.003, 0.1, 0.1, 0.3
  )
)

# facet titles: parameter name -> human-readable label
fac_labels <- c(
  psi = "Rate of influx",
  mu = "Fast to slow",
  rho_D = "Division rate of fast",
  lambda_D = "Net loss rate of fast",
  lambda_I = "Net loss rate of slow",
  rho_I = "Division rate of slow"
)

# Parameter estimates with par_lb-par_ub error bars, compared across sources;
# one free-scaled, log10 facet per parameter.
ggplot(df_pars, aes(y = Estimates, x = factor(Source), col = Source)) +
  geom_errorbar(
    aes(y = Estimates, ymin = par_lb, ymax = par_ub, x = Source),
    width = 0.2, linetype = 1, position = position_dodge(0.4)
  ) +
  # nothing is drawn for blank_data; it only feeds the scale ranges
  geom_blank(data = blank_data) +
  geom_point(position = position_dodge(width = 0.4), stat = "identity", size = 4) +
  facet_wrap(
    ~ factor(parname),
    scales = "free",
    labeller = as_labeller(fac_labels)
  ) +
  #expand_limits(y = c(0.001, 1))  +
  scale_y_log10() +
  labs(y = NULL) +
  myTheme +
  theme(
    axis.text.x = element_blank(),
    axis.title.x = element_blank()
  )


# Prediction grids, one per age bin in the data: 300 time points spaced evenly
# on the log10 scale from a bin-specific start up to 450.
ts_pred1 <- 10^seq(from = log10(66), to = log10(450), length.out = 300)
ts_pred2 <- 10^seq(from = log10(91), to = log10(450), length.out = 300)
ts_pred3 <- 10^seq(from = log10(90), to = log10(450), length.out = 300)
ts_pred4 <- 10^seq(from = log10(174), to = log10(450), length.out = 300)
# One constant per age bin, repeated to the length of its grid
# (presumably the age at BMT for that bin -- confirm).
tb_pred1 <- rep(45, times = 300)
tb_pred2 <- rep(66, times = 300)
tb_pred3 <- rep(76, times = 300)
tb_pred4 <- rep(118, times = 300)


# Summarise predicted counts for the four age bins.
#
# fit        : fitted model object accepted by as.data.frame(fit, pars = ...)
# pars       : names of the four prediction quantities (..._pred1 to ..._pred4)
# timeseries : time points bound row-by-row to the summarised predictions;
#              defaults to the four prediction grids stacked in bin order
#
# Returns one row per predicted column with its 5%, 50% and 95% quantiles
# (lb, median, ub -- i.e. a 90% envelope), the matching time point, and an
# age-bin label derived from the "predN" part of the column name.
summarise_pred_counts <- function(fit, pars,
                                  timeseries = c(ts_pred1, ts_pred2, ts_pred3, ts_pred4)) {
  as.data.frame(fit, pars = pars) %>%
    # long format; factor_key keeps the column order as the factor level order,
    # so the summary rows line up with `timeseries`
    gather(factor_key = TRUE) %>%
    group_by(key) %>%
    summarize(lb = quantile(value, probs = 0.05),
              median = quantile(value, probs = 0.5),
              ub = quantile(value, probs = 0.95)) %>%
    bind_cols("timeseries" = timeseries) %>%
    mutate(ageBMT_bin = ifelse(grepl("pred1", key), "agebin1",
                               ifelse(grepl("pred2", key), "agebin2",
                                      ifelse(grepl("pred3", key), "agebin3", "agebin4"))))
}

# NOTE(review): `fit1` and `fit2` are not created anywhere in this file.
# NOTE(review): no `Counts_thy_pred1` is built here although the rbind()
# further down expects one, and Counts_pred1 / Counts_pred2 are not used again
# in this file -- confirm whether one of them was meant to be the thymic M1
# table.

# "counts_naive_mean_pred*" from fit1, with 90% envelopes
Counts_pred1 <- summarise_pred_counts(fit1, paste0("counts_naive_mean_pred", 1:4))

# "counts_mean_pred*" from fit2, with 90% envelopes
Counts_pred2 <- summarise_pred_counts(fit2, paste0("counts_mean_pred", 1:4))


# naive Treg counts in the periphery with 90% envelopes (fit1)
Counts_per_pred1 <- summarise_pred_counts(fit1, paste0("counts_per_mean_pred", 1:4)) %>%
  mutate(location = "Periphery",
         Model = "M1")


# naive Treg counts in the thymus with 90% envelopes (fit2)
Counts_thy_pred2 <- summarise_pred_counts(fit2, paste0("counts_thy_mean_pred", 1:4)) %>%
  mutate(location = "Thymus",
         Model = "M2")


# naive Treg counts in the periphery with 90% envelopes (fit2)
Counts_per_pred2 <- summarise_pred_counts(fit2, paste0("counts_per_mean_pred", 1:4)) %>%
  mutate(location = "Periphery",
         Model = "M2")



# Stack the per-model, per-location prediction summaries for plotting.
# NOTE(review): `Counts_thy_pred1` and `counts_data` are not created anywhere
# in this file -- confirm where they are supposed to come from.
Counts_pred <- rbind(
  Counts_thy_pred1,
  Counts_thy_pred2,
  Counts_per_pred1,
  Counts_per_pred2
)

# legend labels for the four age-at-BMT bins (in weeks, per the legend title)
legn_labels <- c("6-8", "8-10", "10-12", "12-25")

# Predicted medians (lines) with lb-ub ribbons and the observed counts
# (points) on log-log axes; models in rows, locations in columns.
ggplot() +
  geom_ribbon(
    data = Counts_pred,
    aes(x = timeseries, ymin = lb, ymax = ub, fill = ageBMT_bin),
    alpha = 0.15
  ) +
  geom_line(
    data = Counts_pred,
    aes(x = timeseries, y = median, color = ageBMT_bin)
  ) +
  geom_point(
    data = counts_data,
    aes(x = age.at.S1K, y = total_counts, color = ageBMT_bin),
    size = 2
  ) +
  facet_grid(factor(Model) ~ factor(location, levels = c("Thymus", "Periphery"))) +
  scale_color_discrete(name = "Host age at \n BMT (Wks)", labels = legn_labels) +
  scale_x_continuous(
    limits = c(60, 450),
    trans = "log10",
    breaks = c(10, 30, 100, 300)
  ) +
  scale_y_continuous(
    limits = c(5e3, 5e6),
    trans = "log10",
    breaks = c(1e4, 1e5, 1e6, 1e7, 1e8),
    minor_breaks = log10minorbreaks,
    labels = fancy_scientific
  ) +
  labs(title = "Total counts of naive Tregs", y = NULL, x = "Host age (days)") +
  # the ribbon fill duplicates the colour legend, so hide it
  guides(fill = "none") +
  myTheme










