library(tidyverse)


if (!file.exists(here::here("figure1.preprocessed.rd"))) {

  ## Preprocessing ----
  ## Computes per-error-setting classification accuracy from the akt results
  ## and caches `accuracy_dat` and `accuracy_dat_zeros` to
  ## figure1.preprocessed.rd. Later runs load the cache instead (see `else`).
  message("Preprocessing data...")
  library(skater)

  ## ID of the simulated individual whose pairwise relationships are evaluated.
  sim_id <- "halfSiblings1_g3-b4-i1"

  ## Error-parameter names encoded, "_"-separated, in each akt result file name.
  error_vars <- c("prarr", "paara", "prrra", "praaa", "paarr", "prraa")

  ## Split the "_"-separated `params` column into one numeric column per error
  ## parameter, keeping `params` itself. Alphabetic characters are removed
  ## before conversion, so a token like "prarr0.01" becomes 0.01.
  ## NOTE(review): stripping letters would also remove the "e" of scientific
  ## notation ("1e-04" -> "1-04" -> NA); confirm rates use fixed notation.
  parse_error_params <- function(df, varnames = error_vars) {
    df %>%
      separate(params, into = varnames, sep = "_", remove = FALSE) %>%
      mutate(across(all_of(varnames),
                    ~ as.numeric(gsub("[[:alpha:]]", "", .x))))
  }

  ## Convert kinship coefficients to degree labels with kin2degree(); values
  ## it returns as NA are labelled "Unrelated". as.character() keeps the
  ## result character even when nothing is NA (ifelse() would then have
  ## returned plain numbers, making the label type data-dependent).
  call_degree <- function(k, max_degree = 3) {
    coalesce(as.character(kin2degree(k, max_degree = max_degree)), "Unrelated")
  }

  ## Read one akt result file and attach the error parameters parsed from its
  ## file name (the name minus the "<sim_id>." prefix and ".akt").
  read_akt2 <- function(fp, varnames = error_vars) {
    read_akt(fp) %>%
      mutate(file = basename(fp)) %>%
      ## sim_id holds no regex metacharacters ("-" is literal outside a
      ## bracket expression), so it can be spliced into the pattern directly.
      mutate(params = gsub(paste0(sim_id, "\\.|\\.akt"), "", file)) %>%
      parse_error_params(varnames)
  }

  ## True degrees between sim_id and the other members of the first pedigree.
  ## NOTE(review): only tmp_ped$ped[[1]] is used; akt pairs with individuals
  ## outside that family have no truth row and end up with NA truth_degree
  ## after the join below -- confirm this is intended.
  tmp_fam <- read_fam("data/input/simulation-everyone.fam")
  tmp_ped <- fam2ped(tmp_fam)

  truth_degrees <-
    ped2kinpair(tmp_ped$ped[[1]]) %>%
    filter(id1 == sim_id | id2 == sim_id) %>%
    mutate(degree = call_degree(k)) %>%
    rename(truth_degree = degree, truth_k = k) %>%
    arrange_ids(id1, id2)

  ## Read all akt results (one file per error-parameter setting, with the
  ## setting parsed from the file name).
  fps <- list.files("data/akt/", full.names = TRUE)
  if (length(fps) == 0) {
    stop("No akt result files found in data/akt/", call. = FALSE)
  }
  res <- map_dfr(fps, read_akt2)

  ## Keep only pairs involving sim_id, label the akt kinship estimates the
  ## same way as the truth, and join to the true degrees.
  res <-
    res %>%
    filter(id1 == sim_id | id2 == sim_id) %>%
    mutate(degree = call_degree(k)) %>%
    arrange_ids(id1, id2) %>%
    left_join(truth_degrees, by = c("id1", "id2"))

  ## One confusion matrix per error setting, named by its `params` string.
  ## group_split() and group_keys() return groups in the same order, so the
  ## names line up with the split data.
  res_grouped <- group_by(res, params)
  confmatres <-
    res_grouped %>%
    group_split() %>%
    set_names(group_keys(res_grouped)$params) %>%
    map(~ confusion_matrix(prediction = .x$degree, target = .x$truth_degree))

  ## Row-bind the "Accuracy" element of each confusion matrix, recording the
  ## list name in `params` (via .id, so rows cannot be misaligned), then parse
  ## the error rates back out. `params` is reformatted for plot labels
  ## ("_" -> ";", "p" removed) and ordered by total error rate.
  accuracy_dat <-
    confmatres %>%
    map_dfr("Accuracy", .id = "params") %>%
    parse_error_params() %>%
    mutate(total_error = prarr + paara + prrra + praaa + paarr + prraa) %>%
    mutate(params = gsub("_", ";", params)) %>%
    mutate(params = gsub("p", "", params)) %>%
    mutate(params = fct_reorder(factor(params), total_error))

  ## Accuracy of the error-free setting(s) in long format: one row per error
  ## mode, all with error_rate 0.
  accuracy_dat_zeros <-
    accuracy_dat %>%
    select(Accuracy, all_of(error_vars)) %>%
    filter(if_all(all_of(error_vars), ~ .x == 0)) %>%
    pivot_longer(-Accuracy, names_to = "errormode", values_to = "error_rate")

  save(accuracy_dat, accuracy_dat_zeros, file = here::here("figure1.preprocessed.rd"))
} else {
  message("Loading preprocessed data...")
  load(here::here("figure1.preprocessed.rd"))
}


## Figures ----
## Both figures show classification accuracy against error rate, one line per
## error mode. For each mode, rows where that mode's rate is 0 are dropped and
## the error-free rows from accuracy_dat_zeros are added back, so every line
## has a single point at error rate 0.
fig_dat <-
  accuracy_dat %>%
  select(Accuracy, prarr:prraa) %>%
  pivot_longer(-Accuracy, names_to = "errormode", values_to = "error_rate") %>%
  filter(error_rate > 0) %>%
  bind_rows(accuracy_dat_zeros) %>%
  ## Upper-case the "a"/"r" letters for labels, e.g. "prarr" -> "pRARR".
  mutate(`Error mode` = chartr("ar", "AR", errormode)) %>%
  ## Order modes (legend, colours, shapes) by decreasing mean accuracy.
  mutate(`Error mode` = fct_reorder(factor(`Error mode`), -Accuracy, .fun = mean))

## Reference line(s) at the confusion matrix's "Accuracy Guessing" value.
guess_acc <- unique(accuracy_dat$`Accuracy Guessing`)
x_breaks <- seq(0, 0.2, by = 0.01)

## Figure 1: points + lines with a colour/shape legend.
## To facet by incoming genotype, add to fig_dat e.g.
##   mutate(ingen = gsub("p", "", str_extract(`Error mode`, "^.{3}")))
## and append `+ facet_wrap(~ingen, ncol = 3)` to the plot below.
fig1 <-
  ggplot(fig_dat, aes(error_rate, Accuracy)) +
  geom_point(aes(col = `Error mode`, pch = `Error mode`)) +
  geom_hline(yintercept = guess_acc, lty = 1, lwd = 1.25) +
  geom_line(aes(col = `Error mode`)) +
  scale_x_continuous(breaks = x_breaks) +
  labs(x = "Error rate", y = "Classification accuracy",
       title = NULL,
       subtitle = NULL) +
  theme_bw()
print(fig1)

## Pass the plot explicitly rather than relying on ggsave()'s last_plot().
ggsave("figure1.png", plot = fig1, width = 10, height = 6, dpi = 350)


## Figure 2: same data with each line labelled inline (geomtextpath), so the
## legend is redundant and is removed.
library(geomtextpath)
fig2 <-
  ggplot(fig_dat, aes(error_rate, Accuracy)) +
  geom_hline(yintercept = guess_acc, lty = 1, lwd = 1.25) +
  geom_textline(aes(label = `Error mode`, group = `Error mode`, col = `Error mode`),
                straight = TRUE, hjust = .15, vjust = .4, fontface = 2) +
  scale_x_continuous(breaks = x_breaks) +
  labs(x = "Error rate", y = "Classification accuracy",
       title = NULL,
       subtitle = NULL) +
  theme_bw() +
  theme(legend.position = "none")
print(fig2)

ggsave("figure1.textline.png", plot = fig2, width = 10, height = 6, dpi = 350)
