# hhSAR functions

# custom functions to help with analysis

# Days elapsed between vaccination and sampling, floored at zero.
#
# sample, vaccine: date (or date-time) vectors accepted by difftime().
# Returns a numeric vector of elapsed days; 0 where either date is missing
# or where the sample was taken before the vaccination.
days_with_protection <- function(sample,vaccine){
    # as.numeric() on a difftime extracts the value in the requested units
    # directly, avoiding a lossy numeric -> character -> numeric round trip
    x <- as.numeric(difftime(sample, vaccine, units = 'days'), units = 'days')
    # missing dates count as no protection
    x[is.na(x)] <- 0
    # samples taken before vaccination also count as zero days
    pmax(x,0)
}

# Encode vaccine names as integers:
#   "ChAdOx1" -> 1, "BNT162b2" -> 2, "None" -> 3.
# Any other value (including NA) becomes NA_integer_.
recode_vax <- function(x){
    # position in this vector is the integer code
    vaccine_names <- c("ChAdOx1", "BNT162b2", "None")
    match(x, vaccine_names)
}

# Inverse of the integer vaccine coding: 1 -> "ChAdOx1", 2 -> "BNT162b2",
# 3 -> "None"; anything else becomes NA.
# Returns a factor with levels ordered None, ChAdOx1, BNT162b2.
decode_vax <- function(x){
    # position in this vector is the integer code being decoded
    vaccine_names <- c("ChAdOx1", "BNT162b2", "None")
    decoded <- vaccine_names[match(x, seq_along(vaccine_names))]
    factor(decoded, levels = c("None", "ChAdOx1", "BNT162b2"))
}

# Encode dose labels as integers: "One" -> 1, "Two" -> 2, "Zero" -> 3.
# Any other value (including NA) becomes NA_integer_.
recode_doses <- function(x){
    # position in this vector is the integer code
    dose_labels <- c("One", "Two", "Zero")
    match(x, dose_labels)
}

# Inverse of the integer dose coding: 1 -> "One", 2 -> "Two", 3 -> "Zero".
# Any other value (including NA) becomes NA_character_.
# Returns a character vector.
decode_doses <- function(x){
    # position in this vector is the integer code being decoded
    dose_labels <- c("One", "Two", "Zero")
    dose_labels[match(x, seq_along(dose_labels))]
}

# Create a directory if it does not already exist.
#
# dir: path of the directory to create. Missing parent directories are
#   created as well, so nested paths work in one call.
mkdir <- function(dir){
    if (!dir.exists(dir)){
        # recursive = TRUE: without it dir.create() only warns and returns
        # FALSE when a parent directory is missing
        dir.create(dir, recursive = TRUE)
    }
}

# Generate a random alphanumeric string.
#
# n: number of characters (default 8).
# Returns a single string drawn, with replacement, from A-Z, a-z and 0-9.
random_string <- function(n = 8){
    alphabet <- c(LETTERS, letters, 0:9)
    # TRUE rather than T: T is an ordinary variable that can be reassigned
    paste0(sample(alphabet, size = n, replace = TRUE), collapse = "")
}

# Clean the raw line list before analysis.
#
# dat_raw:               raw data frame, one row per sample.
# HOCONUMBER_to_exclude: household identifiers to drop entirely.
# NickID_to_exclude:     participant identifiers to drop entirely.
#
# Steps, in order: de-duplicate rows; drop withdrawn participants, untested
# samples, samples with no result text, Moderna recipients and records with
# missing vaccination dates; apply the caller's exclusions; tidy Ct values;
# repair known date typos; recode vaccine type and dose count; keep only
# households that still contain an index case and more than one person.
#
# Returns the cleaned data frame.
clean_data <- function(dat_raw, 
                       HOCONUMBER_to_exclude = character(0),
                       NickID_to_exclude     = character(0)){
    
    # drop duplicate rows, then remove the carriage return character from
    # the result text
    dat <- dat_raw %>% 
        distinct %>% 
        mutate(FLCORC = sub(pattern     = "\r",
                            replacement = "",
                            x           = FLCORC, 
                            fixed       = TRUE))
    
    # remove any withdrawn
    dat %<>% filter(!grepl(pattern = 'withdrawn', 
                           x = str_to_lower(LOSTorWITHDRAWN)))
    
    # remove any not tested (keep only the "#det" and "#nvd" result codes)
    dat %<>% filter(FLCOR8 %in% c("#det", "#nvd"))
    
    # remove missing FLCORC values
    dat %<>% filter(!is.na(FLCORC))
    
    # remove two doses of nothing
    # lapsed as withdrawn flag now available
    dat %<>% filter(!(NickID %in% c("CAC1599/MS2",
                                    "VEC57523/KR2")))
    
    # remove anyone whose recorded vaccine type mentions Moderna
    dat %<>% filter(!grepl(pattern = 'moderna', 
                           x = str_to_lower(ReCodedVaccineType)))
    
    # exclude any with missing vaccine dates: a first-dose date is required
    # unless recorded as "Not vaccinated", a second-dose date is required
    # for "Two doses"
    dat %<>% filter(!(is.na(VaccDate1) & DosesCovidVacc != "Not vaccinated"),
                   !(is.na(VaccDate2) & DosesCovidVacc == "Two doses"))
    
    # caller-supplied exclusions
    dat %<>% filter(!(HOCONUMBER %in% HOCONUMBER_to_exclude))
    dat %<>% filter(!(NickID     %in% NickID_to_exclude))
    
    # treat the "#na" / "#nd" placeholders as missing
    dat %<>% mutate_at(.vars = vars(SWAB2CTe, FLCOEC),
                       .funs = ~ifelse(. == "#na" | . == "#nd", NA, .))
    
    # convert Ct strings to numbers
    # NOTE(review): FLCOEC has "#nd" blanked just above, so fix_Ct's
    # "#nd" -> 40 rule can only affect FLCO8C here -- confirm this is intended
    dat %<>% mutate_at(.vars = vars(FLCO8C, FLCOEC),
                       .funs = fix_Ct)
    
    
    # fix dates with typos
    # first, convert from dttm to date
    dat %<>% mutate_at(.vars = vars(SAMPLE_DT, VaccDate1, VaccDate2, RECEPT_DT),
                       .funs = as.Date) %>%
        # then fix the 2012 typo in sample dates and the 3021 typo in
        # contact symptom onset dates
        mutate(SAMPLE_DT = as.character(SAMPLE_DT),
               SAMPLE_DT = sub(pattern     = "2012", 
                               replacement = "2021",
                               x           = SAMPLE_DT),
               SAMPLE_DT = as.Date(SAMPLE_DT),
               CONTACTSymptomStartDate = as.character(CONTACTSymptomStartDate),
               CONTACTSymptomStartDate = sub(pattern = "^3021",
                                             replacement = "2021",
                                             x = CONTACTSymptomStartDate),
               CONTACTSymptomStartDate = as.Date(CONTACTSymptomStartDate)) %>%
        # then fix dates where the second vaccine was given before the first:
        # the logical comparison adds one year where it is TRUE
        mutate(VaccDate2 = VaccDate2 + years(VaccDate2 < VaccDate1)) 
    
    # then if the sample date is missing, assume it was taken a day before the
    # sample was returned
    # this may be a bit dicey, as we should have days 1, 3, 7. How do we better check when it was taken?
    dat %<>% 
        mutate(SAMPLE_DT = if_else(is.na(SAMPLE_DT), # if SAMPLE_DT is missing
                                   RECEPT_DT - 1,    # assume 1 day delay
                                   SAMPLE_DT))       # or leave as is
    
    # fix incorrect SAMPLE dates (manual per-participant corrections)
    dat %<>% 
        mutate(SAMPLE_DT = case_when(
            NickID == 'CAC0196/UP1' & SAMPLE_DT == '2021-04-09' ~ as.Date('2021-05-09'),
            NickID == 'CAC0258/AR1' & SAMPLE_DT == '2021-02-07' ~ as.Date('2021-05-05'),
            NickID == 'CAC0293/AW4' & SAMPLE_DT == '2021-03-12' ~ as.Date('2021-03-18'),
            NickID == 'VEC1978/TG1' & SAMPLE_DT == '2021-01-15' ~ as.Date('2021-04-15'),
            NickID == 'VEC2418/BM1' & SAMPLE_DT == '2021-01-30' ~ as.Date('2021-04-30'),
            NickID == 'VEC0398/GR1' & SAMPLE_DT == '2021-02-18' ~ as.Date('2021-03-18'),
            NickID == 'VEC1867/SB1' & SAMPLE_DT == '2021-03-18' ~ as.Date('2021-04-15'),
            NickID == 'VEC1867/SB1' & is.na(SAMPLE_DT)          ~ as.Date('2021-04-09'),
            NickID == 'VEC0227/YK1' & SAMPLE_DT == '2021-03-19' ~ as.Date('2021-02-19'),
            NickID == 'VEC0227/HK2' & SAMPLE_DT == '2021-03-19' ~ as.Date('2021-02-19'), 
            NickID == 'VEC0938/GR1' & SAMPLE_DT == '2021-02-18' ~ as.Date('2021-03-18'), 
            NickID == 'VEC2034/CC1' & SAMPLE_DT == '2021-04-09' ~ as.Date('2021-04-14'), 
            NickID == 'VEC2430/AM1' & RECEPT_DT == '2021-04-23' ~ as.Date('2021-04-23'), 
            NickID == 'VEC2430/AM1' & RECEPT_DT == '2021-05-04' ~ as.Date('2021-04-30'), 
            TRUE                                                ~ SAMPLE_DT
        ))
    
    # fix some bad vaccination dates
    dat %<>% mutate(VaccDate1 = case_when(
        NickID == 'VEC58146/KY1' ~ as.Date("2021-04-04"),
        TRUE ~ VaccDate1
    ))
    
    
    # fix a typo and recode factors
    dat %<>% mutate(ReCodedVaccineType =
                        sub(pattern     = "AstraZenica", 
                            replacement = "AstraZeneca",
                            x           = ReCodedVaccineType),
                    ReCodedVaccineType = factor(ReCodedVaccineType,
                                                levels = c("None", "Pfizer", "AstraZeneca")),
                    # missing vaccine type is treated as "None"
                    ReCodedVaccineType = forcats::fct_explicit_na(ReCodedVaccineType, "None"),
                    
                    ReCodedVaccineType = fct_recode(ReCodedVaccineType,
                                                    "BNT162b2" = "Pfizer",
                                                    "ChAdOx1"  = "AstraZeneca"),
                    ReCodedVaccineType = fct_relevel(ReCodedVaccineType,
                                                     "ChAdOx1",
                                                     "BNT162b2",
                                                     "None"),
                    DosesCovidVacc_num = 
                        case_when(DosesCovidVacc == "Not vaccinated" ~ 0L,
                                  DosesCovidVacc == "One dose"       ~ 1L,
                                  DosesCovidVacc == "Two doses"      ~ 2L),
                    # explicit levels: without them factor() errors whenever
                    # one of the three dose counts is absent from the data
                    DosesCovidVacc = factor(DosesCovidVacc_num, 
                                            levels = 0:2,
                                            labels = c("Zero", "One", "Two")))
    
    # make sure we only keep households with an index case
    dat_hh_remaining_indexes <- 
        ungroup(dat) %>% 
        filter(STATUS == "CASE") %>%
        distinct(HOCONUMBER)
    
    dat %<>% inner_join(dat_hh_remaining_indexes, by = "HOCONUMBER")
    
    # toss any singleton households (at most one distinct NickID remaining)
    dat_hh_size <- ungroup(dat) %>% 
        distinct(HOCONUMBER, NickID) %>%
        count(HOCONUMBER) %>%
        filter(n <= 1) %>% 
        select(HOCONUMBER)
    
    dat %<>% anti_join(dat_hh_size, by = "HOCONUMBER")
    
    # manual override of the stated household size for one household
    dat %<>% mutate(StatedHouseholdSize = 
                        case_when(HOCONUMBER == "CAC0112" ~ "2",
                                  TRUE ~ StatedHouseholdSize))
    
    dat
}

# Pair every contact with the case(s) in the same household and compute the
# serial interval between their symptom onset dates.
#
# dat: data frame with columns STATUS ("CASE" / "CONTACT"), NickID,
#   HOCONUMBER, CASESymptomStartDate, CONTACTSymptomStartDate and FLCORC.
#
# Returns one row per contact-case pair containing the contact's NickID,
# Covid (1 if any of the contact's FLCORC results contains "SARS CoV-2
# detected in this sample", else 0), the household and case identifiers,
# both onset dates, SerialInterval (contact onset minus case onset, a
# difftime in days) and Asymptomatic (TRUE when the contact has no onset
# date).
calculate_serial_intervals <- function(dat){
    
    
    # need to build a data frame that contains info on both case and contact vaccination status
    # need to ensure we grab the days with protection, derivable from vaccination date
    
    # split by STATUS, rename NickID to NickID_<STATUS> in each piece, then
    # attach the household's case(s) to each contact
    contact_cases <- dat %>% split(.$STATUS) %>%
        map(~distinct(.x, NickID, HOCONUMBER)) %>%
        {map2(.x = ., .y = names(.),
              .f = ~rename_at(.x, 
                              .vars = vars(NickID),
                              .funs = function(x)paste(x, .y, sep = "_")))} %>%
        {left_join(.$CONTACT, .$CASE, by = "HOCONUMBER")}
    
    # look up the onset date of each case and of each contact
    Symptoms <- contact_cases %>% #select(contains("NickID")) %>%
        left_join(distinct(select(dat, NickID_CASE    = NickID, CASESymptomStartDate)),
                  by = "NickID_CASE") %>%
        left_join(distinct(select(dat, NickID_CONTACT = NickID, CONTACTSymptomStartDate)),
                  by = "NickID_CONTACT") 
    
    # serial interval per pair; negative values mean the contact's onset
    # preceded the case's
    # the Asymptomatic flag set here (based on the case) is overwritten by
    # the contact-based flag in the final step below
    dat_SI <-  Symptoms %>%
        mutate(SerialInterval = difftime(CONTACTSymptomStartDate, 
                                         CASESymptomStartDate, units = 'days')) %>%
        mutate(Asymptomatic = is.na(CASESymptomStartDate))
    
    # a/symptomatic contacts: collapse each contact's samples to a single
    # ever-positive indicator, then attach the pairings
    dat %>% 
        filter(STATUS == "CONTACT") %>%
        mutate(Covid = 0L + grepl(pattern = "SARS CoV-2 detected in this sample", x = FLCORC)) %>%
        group_by(NickID) %>%
        summarise(Covid = max(Covid)) %>%
        inner_join(dat_SI,  by = c("NickID" = "NickID_CONTACT")) %>%
        mutate(Asymptomatic = is.na(CONTACTSymptomStartDate))
    
}

# Fit a skew-normal distribution to the observed serial intervals.
#
# dat_SI: data frame with columns Covid and SerialInterval.
# min_SI: smallest serial interval to include in the fit (default: no limit).
#
# Only rows with Covid == 1, SerialInterval >= min_SI and no missing value in
# any column are used. Returns the object produced by fitdistr().
fit_sn_to_serial_interval <- function(dat_SI, min_SI = -Inf){
    infected_pairs <- filter(dat_SI, Covid == 1, SerialInterval >= min_SI)
    complete_pairs <- na.omit(infected_pairs)
    SI_x <- as.numeric(pull(complete_pairs, SerialInterval))
    
    # starting values: location at minus the sample mean, scale at the
    # sample standard deviation, no skew
    starting_values <- list(xi    = -mean(SI_x), 
                            omega = sd(SI_x),
                            alpha = 0)
    
    fitdistr(x       = SI_x, 
             densfun = sn::dsn, 
             method  = "SANN",
             start   = starting_values)
}

# Evaluate a skew-normal density on an evenly spaced grid.
#
# parms:  parameter vector passed to sn::dsn() as `dp`.
# dat_SI: data frame with a SerialInterval column; only used to derive the
#   grid limits when `r` is NULL.
# n:      number of grid points.
# r:      optional length-2 vector with the grid limits; defaults to the
#   observed range of SerialInterval (ignoring missing values).
#
# Returns a data frame with columns x (grid) and d (density at x).
get_sn_from_parms <- function(parms, dat_SI, n = 101, r = NULL){
    if (is.null(r)){
        # TRUE rather than T: T is an ordinary variable that can be reassigned
        r <- as.numeric(range(dat_SI$SerialInterval, na.rm = TRUE))
    }
    
    x <- seq(r[1], r[2], length.out = n)
    
    data.frame(x = x,
               d = sn::dsn(x = x, dp = parms))
}


# Quantiles of a skew-normal distribution, named by probability.
#
# parms:  parameter vector passed to sn::qsn() as `dp`.
# p:      probabilities at which to evaluate the quantile function.
# digits: number of digits to round to (default Inf).
#
# Returns a numeric vector whose names are the probabilities formatted as
# percentages with one decimal place.
get_sn_quantiles_from_parms <- function(parms, p = c(0.025, 0.5, 0.975), digits = Inf){
    quantiles <- sn::qsn(p  = p, 
                         dp = parms)
    rounded   <- round(quantiles, digits = digits)
    setNames(rounded, scales::percent(p, accuracy = 0.1))
}

# Convenience wrapper: fit the skew-normal serial interval distribution.
# Arguments and return value are those of fit_sn_to_serial_interval().
get_dsn <- function(dat_SI, min_SI = -Inf){
    fit_sn_to_serial_interval(dat_SI = dat_SI,
                              min_SI = min_SI)
}

# Evaluate a fitted skew-normal density on a grid.
#
# dsn:    fitted object with an `estimate` element (as returned by get_dsn()).
# dat_SI, n, r: passed through to get_sn_from_parms().
#
# Returns the data frame produced by get_sn_from_parms().
get_sn <- function(dsn, dat_SI, n = 101, r = NULL){
    fitted_parms <- dsn[["estimate"]]
    get_sn_from_parms(parms  = fitted_parms,
                      dat_SI = dat_SI,
                      n      = n,
                      r      = r)
}

# Remove whole households that contain a serial interval below `min_SI`.
#
# dat:    data frame with a HOCONUMBER column.
# dat_SI: data frame with HOCONUMBER and SerialInterval columns.
# min_SI: threshold; any household with a pair whose SerialInterval is
#   strictly less than this is dropped. Pairs with a missing interval never
#   trigger a drop.
#
# Returns `dat` without the rows of the dropped households.
drop_long_SIs <- function(dat, dat_SI, min_SI = -Inf){
    pairs_below_threshold <- filter(dat_SI, SerialInterval < min_SI)
    households_to_drop    <- distinct(pairs_below_threshold, HOCONUMBER)
    
    anti_join(dat, households_to_drop, by = "HOCONUMBER")
}

# Extract one vaccination record per participant, separated by STATUS.
#
# dat: data frame with columns STATUS, HOCONUMBER, NickID,
#   ReCodedVaccineType, DosesCovidVacc, VaccDate1, VaccDate2 and RoundedAge.
#
# Returns a list with one data frame per STATUS value, each holding the
# distinct combinations of HOCONUMBER, NickID, vax, dose, VaccDate1,
# VaccDate2 and RoundedAge.
get_vaccine_dates <- function(dat){
    vaccine_records <- distinct(dat,
                                STATUS, HOCONUMBER, NickID, 
                                vax  = ReCodedVaccineType, 
                                dose = DosesCovidVacc,
                                VaccDate1,
                                VaccDate2,
                                RoundedAge)
    
    records_by_status <- split(vaccine_records, vaccine_records$STATUS)
    
    # STATUS is now carried by the list names, so drop the column
    map(records_by_status,
        function(status_records) select(status_records, -STATUS))
}

# Number each participant's samples and flag which ones were positive.
#
# dat: data frame with columns NickID, SAMPLE_DT, RECEPT_DT, FLCO8C, FLCORC
#   and the swab columns matching "SWAB[1-3]CTOrf".
#
# Returns a data frame with columns NickID, SAMPLE_DT, RECEPT_DT, Swab (the
# sample's position in the participant's date-ordered sequence) and Covid
# (0 when FLCORC contains "NOT detected", otherwise 1). Only samples whose
# position matches one of the swab columns' numbers are kept.
get_swab_dates <- function(dat){
    
    Swabs_start <- dplyr::select(dat, NickID,
                                 SAMPLE_DT, RECEPT_DT,
                                 matches("SWAB[1-3]CTOrf"), FLCO8C, FLCORC) %>%
        arrange(NickID, RECEPT_DT, SAMPLE_DT) %>%
        group_by(NickID) %>%
        # row_number() instead of 1:n(): identical for non-empty groups, but
        # 1:n() yields c(1, 0) if it is ever evaluated with n() == 0
        mutate(id = row_number()) %>%
        # one row per sample per swab column; Swab is the number embedded in
        # the column name
        tidyr::gather(Swab, value, matches("SWAB[1-3]CTOrf")) %>%
        mutate(Swab = parse_number(Swab))
    
    Swabs <- Swabs_start %>%
        ungroup %>%
        # keep, for each sample, only the row whose swab number equals the
        # sample's position; the gathered value itself is not used
        filter(id == Swab) %>%
        select(-id, -value) %>%
        # NOTE(review): positivity here is "result text does not say NOT
        # detected", whereas calculate_serial_intervals() looks for the text
        # "SARS CoV-2 detected in this sample" -- confirm both are intended
        mutate(Covid = 0L + !grepl(pattern = "NOT detected", x = FLCORC)) %>%
        select(-FLCO8C, -FLCORC)
    
    Swabs
    
}

# Combine vaccination records and swab results into one row per
# contact-case pair.
#
# vaccines: named list with elements "CASE" and "CONTACT" (as returned by
#   get_vaccine_dates()), each with columns HOCONUMBER, NickID, vax, dose,
#   VaccDate1, VaccDate2 and RoundedAge.
# Swabs:    data frame with columns NickID, SAMPLE_DT, Swab and Covid (as
#   returned by get_swab_dates()).
#
# Returns a data frame with one row per combination of the non-swab columns
# (identifiers, vaccine, dose and age of contact and case), the minimum days
# with protection from each dose for contact and case across their samples,
# and Covid, the maximum of the contact's swab results.
bind_vaccines_and_Swabs <- function(vaccines, Swabs){
    # the rest of the function hard-codes the _CASE / _CONTACT suffixes, so
    # fail early and clearly if either element is missing
    stopifnot(all(c("CASE", "CONTACT") %in% names(vaccines)))
    
    map(.x = vaccines, 
        ~left_join(.x, Swabs, by = "NickID")) %>%
        # days since each dose at the time of every sample
        map(~mutate(.x, 
                    dwp1 = days_with_protection(SAMPLE_DT, VaccDate1),
                    dwp2 = days_with_protection(SAMPLE_DT, VaccDate2))) %>%
        map(~select(.x,
                    HOCONUMBER, NickID,
                    vax, dose, 
                    SAMPLE_DT, Swab,
                    Covid,
                    dwp1, dwp2, age = RoundedAge)) %>%
        # suffix every column except HOCONUMBER with the list element's name
        {map2(.x = ., 
              .y = names(.), 
              .f = ~rename_at(.x,
                              .vars = vars(-HOCONUMBER),
                              .funs = function(x){paste(x,.y,sep = "_")}))} %>%
        # select by name rather than position so the result does not depend
        # on the order in which the statuses were split
        {left_join(.[["CONTACT"]], .[["CASE"]], by = "HOCONUMBER")} %>% 
        # collapse over samples: group by everything that is not
        # sample-specific
        group_by_at(.vars = vars(-contains("dwp"),
                                 -contains("SAMPLE_DT"),
                                 -contains("Swab"),
                                 -contains("Covid"))) %>%
        dplyr::summarise(dwp1_CONTACT = min(dwp1_CONTACT),
                         dwp2_CONTACT = min(dwp2_CONTACT),
                         dwp1_CASE    = min(dwp1_CASE),
                         dwp2_CASE    = min(dwp2_CASE),
                         Covid        = max(Covid_CONTACT),
                         .groups = 'drop') %>%
        distinct
}

# Add the integer-coded vaccine and dose columns.
#
# x: data frame with columns vax_CONTACT, vax_CASE, dose_CASE and
#   dose_CONTACT.
#
# Returns `x` with vax_con, vax_case, dose_case and dose_con appended (in
# that order), coded by recode_vax() and recode_doses().
recode_vars <- function(x){
    x %>%
        mutate(vax_con  = recode_vax(vax_CONTACT),
               vax_case = recode_vax(vax_CASE)) %>%
        mutate(dose_case = recode_doses(dose_CASE),
               dose_con  = recode_doses(dose_CONTACT))
}

# Build the prediction grid and the counterfactual index pairs.
#
# x: data frame; every column whose name contains "vax" is used (the code
#   below relies on vax_CONTACT, vax_CASE, vax_con and vax_case being among
#   them).
#
# Returns a list containing
#   - one element per column of the prediction grid, named <column>_pred,
#     plus n_pred, the number of grid rows;
#   - idx and idx_cf: for each grid row in which exactly one of case and
#     contact is vaccinated, its row index and the index of the matching row
#     in which both are unvaccinated;
#   - n_idx: the length of idx.
make_pred <- function(x){
    # all observed vaccine combinations, crossed with days-with-protection,
    # age and variant scenarios
    pred_all <- x %>% 
        ungroup %>% 
        select(contains('vax')) %>% 
        distinct %>%
        crossing(dwp1_CONTACT = c(0, 21),
                 dwp2_CONTACT = c(0,  7),
                 dwp1_CASE    = c(0, 21),
                 dwp2_CASE    = c(0,  7),
                 age_CASE     = c(25, 55),
                 age_CONTACT  = c(15, 25)) %>%
        crossing(data.frame(Variant = c("Alpha", "Delta"),
                            p_delta = c(0,1))) %>%
        # unvaccinated people cannot have days with protection
        filter(!(vax_CONTACT == "None" & 
                     (dwp1_CONTACT != 0 | dwp2_CONTACT != 0 ) )) %>%
        filter(!(vax_CASE    == "None" & 
                     (dwp1_CASE    != 0 | dwp2_CASE    != 0 ) )) %>%
        # vaccinated people must have some days with protection
        filter(!(vax_CONTACT != "None" & 
                     (dwp1_CONTACT == 0 & dwp2_CONTACT == 0 ) )) %>%
        filter(!(vax_CASE    != "None" & 
                     (dwp1_CASE    == 0 & dwp2_CASE    == 0 ) )) %>%
        # days since the second dose cannot exceed days since the first
        filter(dwp2_CONTACT <= dwp1_CONTACT & dwp2_CASE <= dwp1_CASE) %>%
        # no vaccinated contacts under 18
        filter(!(age_CONTACT < 18 & vax_CONTACT != "None"))
    
    
    # row index used to link each scenario to its counterfactual
    pred_all %<>% mutate(idx = 1:n())
    
    # scenarios in which at least one of case and contact is unvaccinated
    # (integer code 3)
    pred_temp <- filter(pred_all, vax_con == 3 | vax_case == 3)
    
    # V: exactly one of the two is vaccinated; U: neither is
    pred_temp_V <- filter(pred_temp, !(vax_con == 3 & vax_case == 3))
    pred_temp_U <- filter(pred_temp,  (vax_con == 3 & vax_case == 3))
    
    # first let's look at the vaccinated CASES vs their unvaccinated counterparts
    # we don't care what status the contact has, just compare to the case
    # no `by` is given, so the join uses every column the two tables still
    # share once the vaccine and dwp columns are removed from pred_temp_U
    vaxed_cases_with_counterfactuals <-
        pred_temp_V %>%
        filter(vax_case != 3) %>%
        left_join(select(pred_temp_U, 
                         -vax_case, -vax_CASE,
                         -vax_CONTACT, -vax_con,
                         -dwp1_CASE, -dwp2_CASE,
                         -dwp1_CONTACT, -dwp2_CONTACT) %>%
                      rename(idx_cf = idx))
    
    # same again for the vaccinated CONTACTS
    vaxed_cons_with_counterfactuals <-
        pred_temp_V %>%
        filter(vax_con != 3) %>%
        left_join(select(pred_temp_U, 
                         -vax_case, -vax_CASE,
                         -vax_CONTACT, -vax_con,
                         -dwp1_CASE, -dwp2_CASE,
                         -dwp1_CONTACT, -dwp2_CONTACT) %>%
                      rename(idx_cf = idx))
    # index pairs: vaccinated contacts first, then vaccinated cases
    cfs <- bind_rows(vaxed_cons_with_counterfactuals,
                     vaxed_cases_with_counterfactuals) %>%
        select(idx, idx_cf) %>%
        as.list
    
    # grid columns (plus the row count n) get a "_pred" suffix; the index
    # pairs are appended afterwards and keep their names
    pred_list <- pred_all %>% #select(-idx) %>%
        {append(x = as.list(.), values = list(n = nrow(.)))} %>%
        {setNames(object = ., nm = paste(names(.), "pred", sep = "_"))} %>%
        append(., values = cfs)
    
    pred_list$n_idx <- length(pred_list$idx)
    
    pred_list
    
}

# Assemble the data list handed to the JAGS model.
#
# vas_dat:  data frame of observations; each column becomes a list element.
# vas_pred: list of prediction inputs (as returned by make_pred()).
#
# Returns a list with the columns of `vas_dat`, then n (its number of rows),
# then the elements of `vas_pred`.
make_jags_data_list <- function(vas_dat, vas_pred){
    observed <- as.list(vas_dat)
    observed <- append(x = observed, values = list(n = nrow(vas_dat)))
    append(observed, values = vas_pred)
}

# Compile the JAGS model and draw posterior samples.
#
# vas_list: data list for the model (see make_jags_data_list()).
# file:     path of the JAGS model file.
# n.burn:   iterations run first, whose samples are discarded.
# n.post:   iterations sampled for the model parameters.
# n.pred:   iterations sampled for "p_pred" and then again for "VE"; set to
#   0 to skip both.
#
# Returns a list with vas_pst (parameter samples) and, when n.pred > 0,
# vas_prd ("p_pred" samples) and vas_VEs ("VE" samples).
fit_model <- function(vas_list,
                      file = 'hhSARjagsModel.R', 
                      #out  = random_string(),
                      n.burn = 1e4,
                      n.post = 1e5,
                      n.pred = 1e5){
    
    model <- jags.model(file = file, data = vas_list)
    
    monitored <- c("beta.0",
                   "delta",
                   "epsilon_contact", "epsilon_case",
                   "Beta.1", "Beta.2",
                   "Gamma.1", "Gamma.2")
    
    # burn-in: run the chain forward and throw the samples away
    jags.samples(model, variable.names = monitored, n.iter = n.burn)
    
    fitted <- list(vas_pst = coda.samples(model,
                                          variable.names = monitored,
                                          n.iter         = n.post))
    
    if (n.pred > 0L){
        # two further, consecutive runs of the same chain
        fitted$vas_prd <- coda.samples(model, variable.names = "p_pred", n.iter = n.pred)
        fitted$vas_VEs <- coda.samples(model, variable.names = "VE",     n.iter = n.pred)
    }
    
    fitted
}

# Summarise the vaccine effect parameters as odds ratios.
#
# x:          list with element vas_pst, the posterior samples of the model
#   parameters (as returned by fit_model()).
# conf_level: credible interval levels passed to the mcmc summary.
#
# Returns one row per parameter matching <name>.<dose>[<vaccine>,<variant>],
# with the summary columns exponentiated and the parameter name split into
# Effect, Dose, Vaccine and Variant.
calculate_ORs <- function(x, conf_level = c(0.5, 0.95)){
    mmcc:::tidy.mcmc.list(x$vas_pst, conf_level = conf_level) %>%
        ungroup %>%
        # keep only parameters of the form <name>.<1|2>[<1|2>,<1|2>]
        filter(grepl(pattern = "\\.[1-2]\\[[1-2],[1-2]\\]", x = parameter)) %>%
        # precision is computed from sd before anything is exponentiated
        mutate(prec = 1/sd^2) %>%
        # exponentiate every summary column except sd and prec
        mutate_at(.vars = vars(-parameter, -sd, -prec), .funs = exp) %>%
        # splitting on ".", "[", "," and "]" leaves an empty trailing piece,
        # captured as `dummy` and dropped
        separate(parameter, into = c("Effect", "Dose", "Vaccine", "p_delta_pred",
                                     "dummy"), 
                 sep = "(\\.|\\[|,|\\])") %>%
        select(-dummy) %>%
        mutate(Effect = case_when(
            grepl(pattern = "^Beta",  x = Effect) ~ "Infection protection",
            grepl(pattern = "^Gamma", x = Effect) ~ "Transmission reduction")) %>%
        # first array index is the vaccine code, second the variant
        mutate(Vaccine = decode_vax(Vaccine),
               Variant = case_when(p_delta_pred == 1 ~ "Alpha",
                                   p_delta_pred == 2 ~ "Delta",
                                   TRUE ~ "Unknown variant")
        ) %>%
        arrange(Effect, Vaccine, Dose, Variant)
}


# Summarise the posterior vaccine effectiveness for each counterfactual
# comparison and attach the scenario it belongs to.
#
# x:          list with element vas_VEs, the posterior samples of "VE" (as
#   returned by fit_model()).
# y:          prediction list (as returned by make_pred()); its "_pred"
#   elements describe the scenarios and y$idx selects the compared rows.
# conf_level: credible interval levels passed to the mcmc summary.
#
# Returns the summary with the scenario columns bound on (their "_pred"
# suffix removed) and doses_CONTACT / doses_CASE, the number of doses with
# non-zero days of protection.
calculate_VEs <- function(x, y, conf_level = c(0.5, 0.95)){
    # scenario rows that were compared against a counterfactual
    # TRUE rather than T: T is an ordinary variable that can be reassigned
    variables <- as.data.frame(y[grep(x = names(y),
                                      value = TRUE, pattern = "_pred")])[y$idx,]
    
    # NOTE(review): bind_cols() pairs rows by position, so this assumes the
    # summary rows come back in the same order as y$idx -- confirm
    mmcc:::tidy.mcmc.list(x$vas_VEs, conf_level = conf_level) %>%
        ungroup %>%
        bind_cols(variables) %>%
        rename_all(.funs = ~sub(x = ., pattern = "_pred", replacement = "")) %>%
        mutate(doses_CONTACT = 0L + (dwp1_CONTACT > 0) + (dwp2_CONTACT > 0) ) %>%
        mutate(doses_CASE = 0L + (dwp1_CASE > 0) + (dwp2_CASE > 0) ) %>%
        select(-vax_con, -vax_case, -contains("dwp"), -idx, -p_delta)
}

# Format odds ratios as a display table.
#
# x: data frame with columns median, `2.5%`, `97.5%`, Effect, Dose, Vaccine,
#   Variant and Analysis.
#
# Returns a table with one column per Effect containing
# "median (lower, upper)" to two decimal places, sorted by Analysis,
# Variant, Vaccine and Doses, with repeated labels in those columns blanked.
pretty_ors <- function(x){
    with_interval <- mutate(x, CI = sprintf("%0.2f (%0.2f, %0.2f)", median, `2.5%`, `97.5%`))
    
    display_columns <- select(with_interval,
                              one_of(c("Analysis", "Variant", "Vaccine", "Dose", "Effect", "CI")))
    
    # one column per effect
    one_row_per_dose <- spread(display_columns, Effect, CI)
    one_row_per_dose <- rename(one_row_per_dose, Doses = Dose)
    
    ordered_rows <- arrange(one_row_per_dose, Analysis, Variant, Vaccine, Doses)
    
    # blank out labels that repeat the row above
    mutate_at(ordered_rows,
              .vars = vars(one_of(c("Analysis", "Variant", "Vaccine", "Doses"))),
              .funs = unfill_for_table)
}

# Format vaccine effectiveness estimates as a display table.
#
# x:      data frame with columns median, `2.5%`, `97.5%`, vax_CASE,
#   vax_CONTACT, doses_CASE, doses_CONTACT, age_CASE, age_CONTACT, Analysis
#   and Variant.
# unfill: if TRUE, blank out repeated Analysis / Variant / Vaccine labels,
#   show "-" for missing effects and "" for any other missing value.
#
# Returns a table with one row per Analysis, Variant, Vaccine, Case and
# Contact age group, and one column per Effect containing
# "median (lower, upper)" as whole percentages.
pretty_VEs <- function(x, unfill = FALSE){
    VEs <- mutate_at(x,
                     .vars = vars(median, `2.5%`, `97.5%`),
                     .funs = ~percent(x = ., accuracy =  1)) %>%
        mutate(CI = sprintf("%s (%s, %s)",
                            median,
                            `2.5%`,
                            `97.5%`)) %>%
        # an unvaccinated case means the row describes the contact's vaccine
        # (infection protection); otherwise it describes the case's vaccine
        # (transmission reduction)
        mutate(Vaccine = ifelse(vax_CASE == "None",
                                paste(doses_CONTACT, as.character(vax_CONTACT)),
                                paste(doses_CASE,    as.character(vax_CASE))),
               Effect = ifelse(vax_CASE == "None",
                               "Infection protection",
                               "Transmission reduction")) %>%
        # label the representative ages used in the prediction grid
        mutate(Case = factor(age_CASE, levels = c(25, 55),
                             labels = c("Adult <50",
                                        "Adult 50+")),
               Contact = factor(age_CONTACT, levels = c(15, 25),
                                labels = c("Child <18",
                                           "Adult 18+"))) %>%
        select(Analysis, Variant, Vaccine, Effect,
               Case, Contact, CI) %>%
        pivot_wider(names_from = 'Effect', values_from = 'CI') %>%
        arrange(Analysis, Variant, Vaccine, Case, Contact)
    
    if (unfill){
        VEs %<>%
            mutate_at(.vars = vars(Case, Contact),
                      .funs = as.character) %>%
            mutate_at(.vars = vars(`Infection protection`,
                                   `Transmission reduction`),
                      .funs = ~ifelse(is.na(.), "-", .)) %>%
            mutate_at(.vars = vars(Analysis, Variant, Vaccine),
                      .funs = unfill_for_table) %>%
            mutate_all(.funs = ~ifelse(is.na(.), "", .))
    }
    
    VEs
}


# Heat map of the predicted secondary attack rate by vaccination status.
#
# x:     data frame with columns case_status and contact_status (coded 0, 1,
#   2), median_numeric (used for the fill), median, `2.5%` and `97.5%`
#   (drawn as text labels), vax_CONTACT, vax_CASE and Variant (facets).
# white: labels are drawn in white where median_numeric >= white and in
#   black otherwise.
#
# Returns a ggplot object.
plot_pred <- function(x, white = 0.25){
    # x %<>% mutate(fill = cut(median_numeric, breaks = seq(0, 1, by = 0.1))) %>%
    #     mutate(fill = droplevels(fill))
    
    # show the status codes as dose counts, with a blank label for code 0
    x %<>% mutate_at(.vars = vars(case_status, contact_status),
                     .funs = ~factor(., levels = c(0,1,2),
                                     labels = c(" ", "1", "2")))
    
    ggplot(data= x ,
           aes(x     = case_status,
               y     = contact_status,
               color = median_numeric >= white)) +
        geom_tile(aes(fill = median_numeric),
                  color = 'black',
                  size = 0.5) +
        # median in the middle of the tile, interval bounds below and above
        geom_text(aes(label = median),
                  size = 3) +
        geom_text(aes(label = `2.5%`),
                  size = 2, nudge_y = -0.25) +
        geom_text(aes(label = `97.5%`), 
                  size = 2, nudge_y = 0.25) +
        facet_nested(vax_CONTACT ~ Variant +  vax_CASE,
                     switch = 'y',
                     space  = 'free', 
                     scales = 'free',
                     nest_line = TRUE#,
                     # labeller = labeller(
                     #     vax_CASE    = function(x){sprintf("Case:\n%s ",  x)},
                     #     vax_CONTACT = function(x){sprintf("Contact:\n%s", x)})
        ) +
        xlab("Case vaccine status") +
        ylab("Contact vaccine status") +
        # scale_fill_brewer(palette = 'Reds',
        #                   name = 'Median predicted\nSecondary Attack Rate',
        #                   drop = FALSE) +
        scale_fill_viridis_c(name   = 'Median predicted\nSecondary Attack Rate', 
                             option = "A",
                             direction = -1,
                             
                             limits = c(0,1)) +
        #theme_minimal() +
        # axis.title is blanked here, so the xlab/ylab text set above is not
        # drawn
        theme(panel.grid = element_blank(),
              axis.title = element_blank(),
              strip.placement = "outside") +
        scale_color_manual(values = c(`TRUE`  = "white",
                                      `FALSE` = "black")) + 
        guides(color= 'none') +
        # ggtitle("Predicted household Secondary Attack Rate",
        #         "Model accounts for vaccination status of case and contact") +
        scale_y_discrete(position = 'left') +
        scale_x_discrete(position = 'top')
}

# Replace values that repeat the element immediately before them.
#
# x:             vector.
# replace_value: value substituted for each repeat (default NA).
#
# Returns `x` with every element equal to its predecessor replaced; the
# first element and missing values are left as they are.
unfill_vec <- function(x, replace_value = NA) {
    repeats_previous <- x == dplyr::lag(x)
    # comparisons involving a missing value do not count as repeats
    repeats_previous[is.na(repeats_previous)] <- FALSE
    ifelse(repeats_previous, replace_value, x)
}

# Blank out repeated labels for display in a table.
#
# x: vector; converted to character first.
#
# Returns a character vector in which every element that repeats the one
# before it, and every missing value, is the empty string.
unfill_for_table <- function(x){
    labels       <- as.character(x)
    deduplicated <- unfill_vec(labels)
    tidyr::replace_na(deduplicated, "")
}

# Cross-tabulate case and contact vaccination status.
#
# x:      list of data frames (one table per element; the element names
#   fill the "Min. SI" column), each with factor columns dose_CONTACT,
#   dose_CASE, vax_CONTACT and vax_CASE.
# format: 'latex' or 'html' to get a formatted kable; any other value
#   returns the underlying data frame.
#
# Returns the counts of contact-case pairs with one row per case status and
# one column per contact status, stacked over the elements of `x`.
vaccine_table <- function(x, format = 'latex'){
    # x a list
    full_tab <- map(.x = x, ~ungroup(.x) %>%
                        mutate_at(.vars = vars(dose_CONTACT, dose_CASE),
                                  .funs = function(x){fct_recode(x, 
                                                                 " " = "Zero", 
                                                                 "1" = "One",
                                                                 "2" = "Two")}) %>%
                        mutate_at(.vars = vars(vax_CONTACT, vax_CASE),
                                  .funs = function(x){fct_recode(x,
                                                                 "Unvaccinated" = "None")}) %>%
                        # mutate_at(.vars = vars(vax_CONTACT, vax_CASE),
                        #           .funs = function(x){fct_recode(x,
                        #                                          "BNT162b2" = "AstraZeneca",
                        #                                          "ChAdOx1"  = "Pfizer")}) %>%
                        # labels such as "1 ChAdOx1"; trimws() removes the
                        # blank dose label of the unvaccinated
                        transmute(CONTACT = paste(dose_CONTACT, vax_CONTACT),
                                  CASE    = paste(dose_CASE,    vax_CASE)) %>%
                        mutate_all(.funs = trimws) %>%
                        group_by(CASE, CONTACT) %>%
                        count %>%
                        ungroup %>%
                        rename(Case = CASE) %>%
                        # one column per contact status, absent combinations
                        # filled with 0
                        spread(CONTACT, n, 0)) %>%
        bind_rows(., .id = "Min. SI") 
    
    # number of distinct case statuses, used as the spacing of the rules
    # drawn between the stacked tables
    n_vax <- length(unique(full_tab$Case))
    
    # show each "Min. SI" label only on the first row of its table
    full_tab %<>%
        mutate(`Min. SI` = unfill_vec(`Min. SI`)) %>%
        replace_na(replace = list(`Min. SI` = "")) 
    
    to_return <- full_tab
    
    if (format == 'latex'){
        
        to_return <- knitr::kable(full_tab, booktabs = TRUE, 
                                  linesep = "",
                                  format = format)
        
        # rules are only needed between stacked tables; with a single table
        # (or none) seq() would be asked to count downwards and error
        if (nrow(full_tab) > n_vax){
            to_return <- kableExtra::row_spec(kable_input = to_return, 
                                              row = seq(n_vax, nrow(full_tab) - 1, by = n_vax),
                                              hline_after = TRUE) 
        }
    }
    
    if (format == 'html'){
        to_return <- full_tab %>%
            knitr::kable(., booktabs = TRUE, 
                         linesep = "",
                         format = format) %>%
            kableExtra::kable_classic(full_width = FALSE) 
    }
    
    to_return
    
}


# Re-derive vaccination status at the time each sample was taken.
#
# x:      a data frame with columns SAMPLE_DT, VaccDate1, VaccDate2 and
#   ReCodedVaccineType.
# k1, k2: the time from vaccination to immunisation, i.e. the number of days
#   after the first (k1) and second (k2) dose before the dose counts.
#
# Returns `x` with
#   - DosesCovidVacc: ordered factor (Zero < One < Two) giving the doses in
#     effect on SAMPLE_DT; NA when that cannot be determined (for example a
#     missing sample date);
#   - ReCodedVaccineType: factor with levels None, ChAdOx1, BNT162b2, set to
#     "None" wherever DosesCovidVacc is "Zero".
amend_vaccination_status <- function(x, k1 = 21, k2 = 7){
    
    x %>%
        mutate(
            # conditions are tried in order; a comparison against a missing
            # date is NA and falls through to the next condition
            DosesCovidVacc = case_when(
                SAMPLE_DT >= VaccDate2 + k2     ~ "Two",
                SAMPLE_DT >= VaccDate1 + k1     ~ "One",
                SAMPLE_DT <  VaccDate1 + k1     ~ "Zero",
                is.na(VaccDate1)                ~ "Zero",
                TRUE                            ~ NA_character_),
            ReCodedVaccineType = as.character(ReCodedVaccineType),
            ReCodedVaccineType = ifelse(DosesCovidVacc == "Zero", "None", ReCodedVaccineType),
            ReCodedVaccineType = factor(ReCodedVaccineType, levels = c("None",
                                                                       "ChAdOx1",
                                                                       "BNT162b2"))
        ) %>%
        # TRUE rather than T: T is an ordinary variable that can be reassigned
        mutate(DosesCovidVacc = factor(DosesCovidVacc, levels = c("Zero", "One", "Two"),
                                       ordered = TRUE))
}

# Swap CASE / CONTACT labels where the phylogenetic analysis indicates that
# the contact infected the case.
#
# dat: data frame with columns HOCONUMBER, NickID, STATUS,
#   CASESymptomStartDate and CONTACTSymptomStartDate.
#
# Reads "phylo/hoco_meta.csv" and "phylo/analyse_tree_summarized.csv"
# relative to the working directory.
#
# Returns `dat` with STATUS updated for the affected participants and each
# participant's symptom onset date moved to the column that matches their
# (possibly updated) STATUS; the other onset column is set to NA.
reallocate_status <- function(dat){
    
    # read in Jada's work
    hoco  <- read_csv("phylo/hoco_meta.csv", col_types = cols(ECOM = col_character()))
    
    phylo <- read_csv("phylo/analyse_tree_summarized.csv")
    
    # households where the contact infected the case: relabel the source
    # sequence as CASE and the recipient as CONTACT, one row per sequence
    updated_case <- phylo %>% 
        filter(relationship == "contact infected case") %>%
        select(HOCONUMBER) %>%
        inner_join(phylo) %>%
        select(HOCONUMBER, contains("ERR")) %>%
        rename(CASE = ERR_source,
               CONTACT = ERR_Recipient) %>%
        gather(STATUS_updated, ERR, -HOCONUMBER)
    # now we need to anti_join and then do an inner_join on the hoco data less the STATUS
    
    # households not affected by the relabelling (not used further below)
    untouched <-  anti_join(hoco, select(updated_case, HOCONUMBER))
    
    touched   <- inner_join(hoco, select(updated_case, HOCONUMBER))
    
    # participants whose recorded STATUS differs from the phylogenetic one
    updated_statuses <- distinct(touched, HOCONUMBER, STATUS,NickID, ERR) %>%
        arrange(HOCONUMBER) %>%
        left_join(updated_case) %>%
        filter(STATUS != STATUS_updated) %>%
        distinct(NickID, STATUS_updated)
    
    left_join(dat, updated_statuses) %>%
        select(HOCONUMBER, NickID, starts_with("STATUS"), 
               starts_with("CASE"),
               starts_with("CONTACT"),
               everything()) %>%
        arrange(HOCONUMBER, desc(STATUS_updated), NickID) %>%
        # merge the two onset columns into one string column: missing dates
        # become "" so that pasting them together leaves whichever is present
        mutate_at(.vars = vars(contains("SymptomStartDate")), ~ifelse(is.na(.), "", as.character(.))) %>%
        unite("SymptomDate", CASESymptomStartDate, CONTACTSymptomStartDate, sep="") %>%
        # # make a new column, SymptomDate that contains the date of onset
        # mutate(SymptomDate = 
        #            case_when(STATUS_updated == "CONTACT" ~ CASESymptomStartDate,
        #                      STATUS_updated == "CASE"    ~ CONTACTSymptomStartDate,
        #                      STATUS         == "CONTACT" ~ CONTACTSymptomStartDate,
        #                      STATUS         == "CASE"    ~ CASESymptomStartDate)) %>%
        # apply the updated status where there is one
        mutate(STATUS = 
                   ifelse(is.na(STATUS_updated),
                          STATUS,
                          STATUS_updated)) %>%
        # put the onset date back into the column matching the final status
        mutate(CASESymptomStartDate =
                   ifelse(STATUS == "CASE",
                          SymptomDate,
                          NA),
               CONTACTSymptomStartDate = 
                   ifelse(STATUS == "CONTACT",
                          SymptomDate,
                          NA)) %>%
        select(HOCONUMBER, NickID, 
               starts_with("STATUS"), 
               starts_with("CASE"),
               starts_with("CONTACT"),
               SymptomDate,
               everything()) %>%
        select(-SymptomDate, -STATUS_updated) %>%
        # back from strings to dates; "" means no onset date
        mutate_at(.vars = vars(contains("SymptomStartDate")), ~as.Date(ifelse(. == "", NA, .)))
    
    
}


# Assign each index case a variant and a probability of being Delta.
#
# dat:   data frame with columns NickID, HOCONUMBER, STATUS, FLCVSR and
#   FLCVSQ.
# Swabs: swab-level data frame joined to `dat` on their shared columns; must
#   provide Covid and SAMPLE_DT.
# model: fitted model whose predict(type = 'response') method takes a data
#   frame containing `time`, the days since 2020-01-01 of the first sample.
#
# Reads "phylo/hoco_meta.csv" relative to the working directory to find out
# which participants were sequenced.
#
# Returns a data frame with HOCONUMBER, Variant ("Alpha", "Delta",
# "Unknown", "Not sequenced" or "Uninfected") and p_delta: 0 for Alpha, 1
# for Delta, otherwise the model prediction.
calculate_variant_prob <- function(dat, Swabs, model){
    
    meta <- read_csv("phylo/hoco_meta.csv") %>%
        mutate(Sequenced = !is.na(ERR)) %>%
        select(NickID, Sequenced)
    
    dat %>%
        # normalise "202012" to "20DEC" before matching
        mutate(FLCVSR_ = sub(pattern = "202012", replacement = "20DEC", x = FLCVSR)) %>%
        inner_join(Swabs) %>%
        group_by(NickID) %>%
        filter(STATUS == "CASE") %>%
        mutate(Variant = case_when(grepl("20DEC", FLCVSR_) ~ "Alpha",
                                   grepl("21APR", FLCVSR_) ~ "Delta",
                                   TRUE ~ NA_character_)) %>%
        select(FLCVSR, FLCVSQ, Variant, NickID, Covid, SAMPLE_DT, HOCONUMBER) %>% 
        left_join(meta) %>%
        group_by(NickID, HOCONUMBER) %>%
        # spread a variant call from any one sample to all of that person's
        # samples
        fill(Variant, .direction = 'downup') %>%
        replace_na(list(Sequenced = FALSE)) %>%
        group_by(NickID, HOCONUMBER) %>%
        # NOTE(review): unique(Variant) yields more than one value if a
        # person has samples called as different variants -- confirm this
        # cannot happen in the data
        # TRUE rather than T: T is an ordinary variable that can be reassigned
        summarise(SAMPLE_DT = min(SAMPLE_DT),
                  Variant   = unique(Variant),
                  Covid     = max(Covid),
                  Sequenced = max(Sequenced, na.rm = TRUE)) %>%
        ungroup %>%
        # label the cases that have no variant call
        mutate(Variant = case_when(is.na(Variant) & Covid & Sequenced  ~ "Unknown",
                                   is.na(Variant) & Covid & !Sequenced ~ "Not sequenced",
                                   is.na(Variant) & !Covid             ~ "Uninfected",
                                   TRUE ~ Variant)) %>%
        mutate(time = as.numeric(SAMPLE_DT - as.Date('2020-01-01'))) %>%
        mutate(pred = predict(newdata = ., 
                              object = model,
                              type = 'response')) %>%
        # a known variant overrides the model prediction
        mutate(pred = case_when(Variant == "Alpha" ~ 0,
                                Variant == "Delta" ~ 1,
                                TRUE               ~ pred)) %>%
        select(HOCONUMBER, Variant, p_delta = pred) 
}


# Convert recorded Ct values to numbers.
#
# x: vector of Ct values as recorded; character input may contain the
#   placeholders "#nd", ">40" and ">45".
#
# Returns a numeric vector in which "#nd", ">40" and ">45" are 40 and every
# other entry is the number parsed from it (NA if there is none).
fix_Ct <- function(x){
    # parse_number() only accepts character input, so coerce first; this
    # keeps columns that were read as numeric or are entirely NA from
    # raising an error
    x <- as.character(x)
    dplyr::case_when(x == "#nd" ~ 40,
                     x == ">40" ~ 40,
                     x == ">45" ~ 40,
                     TRUE       ~ readr::parse_number(x))
}
