 
## clearing the environment
# NOTE(review): rm(list = ls()) only removes visible objects from the global
# environment (not hidden objects, attached packages or options); running the
# script in a fresh R session is the more reliable way to start clean.
rm(list = ls())
gc()

library(rstan)
library(loo)
library(tidyverse)
library(bayesplot)
####################################################################################

## model-specific details that need to be changed for every run
modelName <- "LinearDLD_memTreg_EM"

## Setting all the directories used by this script (relative to the working
## directory the script is started from)
projectDir <- getwd()
scriptDir <- file.path(projectDir, "scripts")
modelDir <- file.path(projectDir, "stan_models")
dataDir <- file.path(projectDir, "datafiles")
toolsDir <- file.path(scriptDir, "tools")
outputDir <- file.path(projectDir, "output_fit")
saveDir <- file.path(projectDir, "save_csv", "memory_Treg2")
LooDir <- file.path(projectDir, "loo_fit")

# loading the script that contains functions for plotting stan parameters;
# it presumably also provides mcmcHistory() and myTheme, which this script
# uses but does not define.
source(file.path(toolsDir, "stanTools.R"))

# compiling multiple stan objects together that ran on different nodes:
# each node wrote one CSV (<modelName>_a1.csv ... <modelName>_a5.csv); they are
# read and merged in the order a5, a4, ..., a1.
chain_fits <- lapply(5:1, function(node) {
  read_stan_csv(file.path(saveDir, paste0(modelName, "_a", node, ".csv")))
})
fit <- sflist2stanfit(chain_fits)

# finding the parameters used in the model: every model parameter declared
# before "global_params" is treated as a parameter to summarise and plot.
global_idx <- match("global_params", fit@model_pars)
if (is.na(global_idx) || global_idx < 2L) {
  stop("Expected 'global_params' in fit@model_pars after at least one ",
       "parameter; cannot determine the parameters to plot.", call. = FALSE)
}
num_pars <- global_idx - 1L
parametersToPlot <- fit@model_pars[seq_len(num_pars)]

# NOTE(review): nrow() of a stanfit is presumably the number of post-warmup
# iterations per chain, not the total across chains -- confirm before use.
nPost <- nrow(fit)

# posterior draws of the plotted parameters (one row per draw)
ymta <- as.data.frame(fit, pars = parametersToPlot)

# median and central 95% interval of psi, scaled by 1000
quantile(ymta$psi, probs = c(0.025, 0.5, 0.975)) * 1000
#write.csv(ymta, file = paste0("save_csv/parsMat_", modelName, ".csv"))

################################################################################################
################################################################################################
## loading required datasets for plotting
## importing data to be fitted
# NOTE(review): the CSVs are read from the "data" folder, whereas dataDir is
# set to "datafiles" -- confirm which location is intended.

# Assign each age at BMT to a bin: <= 56, <= 70, <= 84, otherwise agebin4.
# Missing ages stay NA (case_when would otherwise send them to agebin4).
bin_age_at_bmt <- function(age) {
  case_when(
    is.na(age) ~ NA_character_,
    age <= 56 ~ "agebin1",
    age <= 70 ~ "agebin2",
    age <= 84 ~ "agebin3",
    TRUE ~ "agebin4"
  )
}

# Preprocessing shared by both Treg memory datasets: sort by age at S1K, add
# the age-at-BMT bin and the time since BMT, and drop the Popln column.
prep_treg_memory <- function(df) {
  df %>%
    arrange(age.at.S1K) %>%
    mutate(ageBMT_bin = bin_age_at_bmt(age.at.BMT),
           days_postBMT = age.at.S1K - age.at.BMT) %>%
    select(-Popln)
}

Nfd_file <- file.path("data", "Treg_memory_Nfd.csv")
Treg_memory_Nfd <- read.csv(Nfd_file) %>%
  prep_treg_memory() %>%
  # NOTE(review): this excludes mouse 314807 while the ki data excludes
  # 314806 -- confirm both IDs are intended.
  filter(mouse.ID != 314807)

ki_file <- file.path("data", "Treg_memory_ki.csv")
Treg_memory_ki <- read.csv(ki_file) %>%
  prep_treg_memory() %>%
  rename(Donor = ki_donor, Host = ki_host) %>%
  # Long format: one row per mouse and subcompartment. arrange(subcomp) gives
  # the stacked order (all Donor rows, then all Host rows); arrange() is
  # stable, so each block stays sorted by age.at.S1K.
  pivot_longer(c(Donor, Host), names_to = "subcomp", values_to = "ki_prop") %>%
  arrange(subcomp) %>%
  filter(mouse.ID != 314806)


# ################################################################################################
# calculating PSIS-LOO-CV for the fit
# Pointwise log-likelihood matrices (draws x observations, chains merged), one
# per data stream the model is fitted to.
naive_counts_loglik <- extract_log_lik(fit, parameter_name = "log_lik_counts", merge_chains = TRUE)
naive_fd_loglik <- extract_log_lik(fit, parameter_name = "log_lik_Nfd", merge_chains = TRUE)
ki_donor_loglik <- extract_log_lik(fit, parameter_name = "log_lik_ki_donor", merge_chains = TRUE)
ki_host_loglik <- extract_log_lik(fit, parameter_name = "log_lik_ki_host", merge_chains = TRUE)

#combined_loglik <- extract_log_lik(fit, parameter_name = "log_lik", merge_chains = TRUE)
# All observations pooled into a single matrix, columns in the order above.
log_lik_comb <- cbind(naive_counts_loglik, naive_fd_loglik,
                      ki_donor_loglik, ki_host_loglik)

# Relative efficiency of the MCMC draws: one value per column of log_lik_comb,
# in the same order, so loo() can account for autocorrelation in the chains.
loglik_pars <- c("log_lik_counts", "log_lik_Nfd",
                 "log_lik_ki_donor", "log_lik_ki_host")
r_eff <- unlist(lapply(loglik_pars, function(par_name) {
  # iterations x chains x observations, the layout relative_eff() expects
  ll_array <- extract_log_lik(fit, parameter_name = par_name, merge_chains = FALSE)
  relative_eff(exp(ll_array), cores = 8)
}))
stopifnot(length(r_eff) == ncol(log_lik_comb))

# loo-ic values
loo_loglik <- loo(log_lik_comb, r_eff = r_eff, save_psis = FALSE, cores = 8)

# saving loo object as rds file for model comparison
dir.create(LooDir, showWarnings = FALSE, recursive = TRUE)
loofilename <- paste0("loosave_", modelName, ".rds")
write_rds(loo_loglik, file = file.path(LooDir, loofilename))

# One summary row for this model: LOO-IC, its standard error and the effective
# number of parameters (p_loo).
loo_est <- loo_loglik$estimates
ploocv <- data.frame("Model" = modelName,
                     "LooIC" = loo_est["looic", "Estimate"],
                     "SE" = loo_est["looic", "SE"],
                     "PLoo" = loo_est["p_loo", "Estimate"])
ploocv

# Appended without a header so rows from successive model runs accumulate in
# one table.
dir.create(outputDir, showWarnings = FALSE, recursive = TRUE)
write.table(ploocv, file = file.path(outputDir, "stat_table_mem.csv"),
            sep = ",", append = TRUE, quote = FALSE,
            col.names = FALSE, row.names = FALSE)

### posterior distributions of parameters
# monitor() returns one row per parameter *element*, so every row is kept; a
# 1:num_pars slice would truncate the table for vector-valued parameters.
# With monitor()'s default probs, columns 1, 3, 4, 8 are mean, sd, 2.5%, 97.5%.
ptable <- monitor(as.array(fit, pars = parametersToPlot), warmup = 0, print = FALSE)
out_table <- ptable[, c(1, 3, 4, 8)]
out_table
write.csv(out_table, file = file.path(outputDir, paste0("params_", modelName, ".csv")))


################################################################################################

## open graphics device
## saving plots for quality control: with onefile = FALSE the %03d pattern
## makes pdf() write each page to its own numbered file.
pdf(file = file.path(outputDir, paste0(modelName, "Plots%03d.pdf")),
    width = 12, height = 5, onefile = FALSE)

pairs(fit, pars = parametersToPlot)

options(bayesplot.base_size = 15,
        bayesplot.base_family = "sans")
bayesplot::color_scheme_set(scheme = "viridis")

# convergence diagnostics
rhats <- rhat(fit, pars = parametersToPlot)
mcmc_rhat(rhats) + yaxis_text() + myTheme

ratios1 <- neff_ratio(fit, pars = parametersToPlot)
mcmc_neff(ratios1) + yaxis_text() + myTheme

# Only the plotted parameters are needed; as.array(fit) alone would also copy
# every other quantity in the fit (e.g. the log_lik_* vectors).
posterior <- as.array(fit, pars = parametersToPlot)
mcmc_acf(posterior, pars = parametersToPlot) + myTheme

mcmcHistory(fit, pars = parametersToPlot, nParPerPage = 4, myTheme = myTheme)

mcmc_dens_overlay(posterior, pars = parametersToPlot) + myTheme
mcmc_dens(posterior, pars = parametersToPlot) + myTheme


dev.off()
