library(dplyr)
library(ggplot2)
library(pracma)
library(readxl)
source("prevalence_functions.R")

# --- Get Prevalence Data --- #

# Bootstrap size: the loop below runs Ns resampled iterations plus one final
# iteration on the unresampled data.
Ns <- 10000

# Flags forwarded positionally to correctedPrevalence()
# (argument order: smallthresh, adjustspec, agesexn).
adjustspec <- TRUE
smallthresh <- FALSE
agesexn <- TRUE

# Population sizes for Manaus and Sao Paulo. They are not used elsewhere in
# this script but are kept in the workspace written by save.image().
pop_MN <- 1.793e6
pop_SP <- 12.18e6

# Keep donation months March (3) through October (10), ordered by month.
posc0 <- correctedPrevalence(smallthresh, adjustspec, agesexn) %>%
  dplyr::filter(donmo >= 3, donmo <= 10) %>%
  arrange(donmo)

# Names of the per-iteration quantities kept from the bootstrap. After the loop,
# each name x has a companion list `<x>_s` with Ns + 1 slots: slots 1..Ns hold
# the bootstrap iterations and slot Ns + 1 holds the result computed from posc0.
outputs <- c("alphamin_MN", "alphamin_SP", "U_MN_MN", "U_SP_MN",
             "pmin_SP", "pmin_MN")

# One entry per iteration; each entry is a named list of the `outputs` values.
results <- vector("list", Ns + 1)

start_time <- Sys.time()
for (is in seq_len(Ns + 1)) {
  print(is)
  # Iterations 1..Ns use prev.bootstrap(posc0); the final one uses posc0 as-is.
  posc <- if (is == Ns + 1) posc0 else prev.bootstrap(posc0)

  # Clamp negative prevalence values to zero before the correction.
  posc$prev[posc$prev < 0] <- 0

  # Per-location prevalence vectors, named by donation month.
  sp <- dplyr::filter(posc, location == "Sao Paulo")
  pos_SP <- setNames(sp$prev, sp$donmo)
  N_SP <- sp$n

  mn <- dplyr::filter(posc, location == "Manaus")
  pos_MN <- setNames(mn$prev, mn$donmo)
  N_MN <- mn$n

  # --- Prevalence Correction --- #
  # Manaus parameters are computed first and passed to the Sao Paulo call.
  par <- calculate.alphamin.p(pos_MN)
  alphamin_MN <- par$alpha
  pmin_MN <- par$p
  par <- calculate.alphamin.p.SP(pos_SP, FALSE, alphamin_MN, pmin_MN)
  alphamin_SP <- par$alpha
  pmin_SP <- par$p

  u_MN_MN <- calculate.u(pos_MN, alphamin_MN, FALSE, pmin_MN)  # cases per day
  u_SP_MN <- calculate.u(pos_SP, alphamin_SP, FALSE, pmin_SP)

  U_MN_MN <- cumsum(u_MN_MN)  # cumulative cases
  U_SP_MN <- cumsum(u_SP_MN)

  # Snapshot this iteration's outputs. mget() keeps NULL values as list
  # elements, whereas `lst[[is]] <- NULL` would delete slot `is` and shift
  # every later iteration down by one.
  results[[is]] <- mget(outputs, envir = environment())
}
end_time <- Sys.time()
# format() keeps the difftime units ("secs", "mins", ...); paste() alone drops them.
print(paste("Elapsed:", format(end_time - start_time)))

# Regroup into one list per output: `<x>_s[[is]]` is the value of x at iteration is.
for (x in outputs) {
  assign(paste0(x, "_s"), lapply(results, `[[`, x))
}
# Drop the iteration-major copy so save.image() does not store every value twice.
rm(results)

# --- Summarise bootstrap iterations --- #
# For each name x in `outputs`, creates in the current environment:
#   <x>.est  slot Ns + 1 of <x>_s (the result computed from posc0)
#   <x>.cil  element-wise 2.5% quantile across slots 1..Ns of <x>_s
#   <x>.ciu  element-wise 97.5% quantile across slots 1..Ns of <x>_s
# Element i of the interval vectors summarises element i of every iteration;
# the number of elements is taken from the first iteration.
for (x in outputs) {
  draws <- get(paste0(x, "_s"))
  X <- draws[seq_len(Ns)]
  n_elem <- length(X[[1]])
  cil <- rep(NA_real_, n_elem)
  ciu <- rep(NA_real_, n_elem)
  for (i in seq_len(n_elem)) {
    # vapply guarantees one numeric value per iteration (sapply could silently
    # return a list).
    Xi <- vapply(X, function(v) v[i], numeric(1))
    q <- quantile(Xi, c(0.025, 0.975), names = FALSE)
    cil[i] <- q[1]
    ciu[i] <- q[2]
  }
  assign(paste0(x, ".est"), draws[[Ns + 1]])
  assign(paste0(x, ".cil"), cil)
  assign(paste0(x, ".ciu"), ciu)
}


# --- Generate plots --- #

# Measured prevalence matching the point estimate: posc0 with negative
# prevalences clamped to zero, exactly as the bootstrap loop does on its final
# iteration. Rebuilt explicitly rather than relying on the loop's leftover `posc`.
posc_point <- posc0
posc_point$prev[posc_point$prev < 0] <- 0

posc_SP <- posc_point %>% dplyr::filter(location == "Sao Paulo")
posc_MN <- posc_point %>% dplyr::filter(location == "Manaus")

# Stack the measured prevalence (with its ci_l/ci_u interval) and the corrected
# cumulative estimate (with its bootstrap quantile interval) into one long data
# frame with columns Month, Prevalence, Method, lb, ub.
# NOTE(review): "Prelavence" is misspelled; kept unchanged because these labels
# are written to the correction CSVs and downstream readers may match on them.
build_correction_df <- function(meas, est, cil, ciu) {
  n_months <- length(meas$donmo)
  data.frame(
    Month = rep(meas$donmo, 2),
    Prevalence = c(meas$prev, est),
    Method = c(rep("Measured Prelavence", n_months),
               rep("Corrected Prevalence", n_months)),
    lb = c(meas$ci_l, cil),
    ub = c(meas$ci_u, ciu)
  )
}

dfplot_MN <- build_correction_df(posc_MN, U_MN_MN.est, U_MN_MN.cil, U_MN_MN.ciu)
dfplot_SP <- build_correction_df(posc_SP, U_SP_MN.est, U_SP_MN.cil, U_SP_MN.ciu)

# Crossbar charts of measured vs. corrected prevalence by month: each bar spans
# [lb, ub] and fill encodes Method, with the legend suppressed. The plot objects
# are not printed here; they are kept in the workspace written by save.image().
g1_MN <- ggplot(dfplot_MN, aes(x = Month, y = Prevalence, fill = Method)) +
  geom_crossbar(aes(ymin = lb, ymax = ub), position = "dodge") +
  ggtitle("Corrected Prevalence for Manaus") +
  theme_bw() +
  theme(legend.position = "none")

g1_SP <- ggplot(dfplot_SP, aes(x = Month, y = Prevalence, fill = Method)) +
  geom_crossbar(aes(ymin = lb, ymax = ub), position = "dodge") +
  ggtitle("Corrected Prevalence for São Paulo") +
  theme_bw() +
  theme(legend.position = "none")


# --- Save outputs --- #
# Ensure the output folder exists first, so a long bootstrap run is not lost to
# a missing directory at the very last step. No-op if it already exists.
dir.create("data", showWarnings = FALSE)
write.csv(dfplot_MN, "data/correction Manaus.csv")
write.csv(dfplot_SP, "data/correction Sao Paulo.csv")
# Persist the whole workspace, including the per-iteration *_s lists.
# NOTE(review): the file name hard-codes "10k" and does not follow Ns.
save.image(file = "data/sim10k.RData")