library(tidyr)
library(ggplot2)
library(dplyr)
library(mgcv)
library(scam)
library(caret)
library(readr)

# Log-odds transform: maps a probability p to log(p / (1 - p)).
# Vectorised over p; p = 0 gives -Inf and p = 1 gives Inf.
logit <- function(p) {
  odds <- p / (1 - p)
  log(odds)
}

# Import data for calibration fitting.
#
# `X1` is the name readr < 2.0 gives to an unnamed first CSV column
# (presumably a row index written by the exporter -- verify); readr >= 2.0
# names the same column `...1`. Dropping it through any_of() accepts either
# spelling, so the script no longer aborts on a missing column after a readr
# upgrade.
index_cols <- c("X1", "...1")

# Per-event table that supplies adj_proba, joined below on (cell, particle).
labeled_sample <- read_csv("labeled_sample-adj_proba.csv") %>%
  select(-any_of(index_cols))

# Re-drawn scored samples; their cal_proba column is discarded.
redrawn_samples <- read_csv("redrawn_labeled_samples_nodupes.csv") %>%
  select(-any_of(index_cols), -cal_proba) %>%
  mutate(cargo = as.factor(cargo),
         puff = ifelse(score == 1, "puff", "nonpuff")) %>%
  drop_na() # some scores are NA?

# Stack the binned and re-drawn scored samples, remove exact duplicate rows,
# and re-attach adj_proba from labeled_sample.
scored_dat <- read_csv("binned_labeled_samples_20191121.csv") %>%
  select(-any_of(index_cols), -adj_proba) %>%
  mutate(cargo = as.factor(cargo),
         puff = ifelse(score == 1, "puff", "nonpuff")) %>%
  rbind(redrawn_samples) %>%
  arrange(cargo) %>%
  distinct() %>%
  left_join(labeled_sample %>% select(cell, particle, adj_proba),
            by = c("cell", "particle")) %>%
  # Removes rows with an NA in any column, including events that found no
  # match in labeled_sample.
  drop_na() %>%
  # Shift probabilities away from 0 so logit() stays finite at that end.
  # NOTE(review): an adj_proba of exactly 1 still maps to 1 (logit = Inf);
  # confirm the classifier never emits 1, or shrink symmetrically.
  mutate(adj_proba = (adj_proba + 0.000001) / (1 + 0.000001),
         l_prob = logit(adj_proba))


# Calibration model: a shape-constrained additive model (scam) predicting the
# actual probability that an event is a puff, given its cargo and the random
# forest output probability (on the logit scale, l_prob).
# scam works like gam, but with an additional monotonicity constraint on the
# smooth (bs = "mpi"); `by = cargo` gives each cargo its own smooth.
#
# Disabled alternative that smooths the raw probability instead of its logit:
# cal_gam <- scored_dat %>%
#   scam(score ~ cargo + s(adj_proba, by=cargo, bs="mpi"),
#        family=binomial(), data=.)
cal_gam <- scam(
  score ~ cargo + s(l_prob, by = cargo, bs = "mpi"),
  family = binomial(),
  data = scored_dat
)

# Plot the fitted calibrating functions: fitted (calibrated) probability
# against adj_proba, one coloured line per cargo.
ggplot(
  mutate(scored_dat, cal_prob = cal_gam$fitted.values),
  aes(x = adj_proba, y = cal_prob, color = cargo)
) +
  geom_line()

# Data the model is applied to: the cargo and l_prob columns of scored_dat.
eval_data <- with(scored_dat, data.frame(cargo = cargo, l_prob = l_prob))

# Calibrated (response-scale) probability for each row of eval_data.
cal_prob <- predict(cal_gam, newdata = eval_data, type = "response")


# Calibration curves of the calibrated probabilities for each cargo, with
# pointwise confidence bands (diagnostic only; not necessary for the sampling).

# Attach the predictions as a real column so that every step below, including
# the rug layer, reads cal_prob from its own data instead of from a free
# variable that merely lines up row-for-row with scored_dat.
scored_cal <- scored_dat %>%
  mutate(cal_prob = cal_prob)

# Binned calibration table (the $data element of caret::calibration) for each
# cargo, computed with one shared recipe rather than three copies.
cal_by_cargo <- lapply(
  c(b2 = "B2", mor = "MOR", tfr = "TfR"),
  function(cargo_name) {
    cargo_rows <- filter(scored_cal, cargo == cargo_name)
    calibration(puff ~ cal_prob, class = "puff", cuts = 10,
                data = cargo_rows)$data
  }
)
cal_dat_b2 <- cal_by_cargo$b2
cal_dat_mor <- cal_by_cargo$mor
cal_dat_tfr <- cal_by_cargo$tfr

rbind(
  mutate(cal_dat_mor, cargo = "MOR"),
  mutate(cal_dat_b2, cargo = "B2"),
  mutate(cal_dat_tfr, cargo = "TfR")
) %>%
  # Where a bound is NA, open the band to the full 0-100 range.
  mutate(Upper = ifelse(is.na(Upper), 100, Upper),
         Lower = ifelse(is.na(Lower), 0, Lower)) %>%
  ggplot(aes(x = midpoint)) +
  geom_point(aes(y = Percent)) +
  geom_line(aes(y = Percent)) +
  geom_line(aes(y = Lower), color = "orange", alpha = 0.8) +
  geom_line(aes(y = Upper), color = "orange", alpha = 0.8) +
  # Identity line for reference.
  geom_abline(intercept = 0, slope = 1, color = "blue", alpha = 0.4) +
  # cal_prob is rescaled by 100 to share the x axis with midpoint.
  geom_rug(data = scored_cal, aes(x = 100 * cal_prob), sides = "b",
           alpha = 0.2) +
  facet_wrap(~cargo, nrow = 2)

# Sampling plan, for the unlabeled data:
# 1) Make eval_data from the full test data, rather than from scored_dat
# 2) Use cal_gam (trained on scored_dat) to predict on eval_data
# 3) cal_prob is then the calibration-adjusted probability for every event
#    in eval_data (i.e. the full test data, not the training data)
# 4) Divide cal_prob values into bins [0, 0.1], (0.1, 0.2], etc. within each cargo
# 5) Sample without replacement in each bin, for each cargo, aiming for the
#    same numbers as before in each bin.

# The above, but on the unlabeled data.
#
# `X1` is the name readr < 2.0 gives to an unnamed first CSV column
# (presumably a row index written by the exporter -- verify); readr >= 2.0
# names the same column `...1`. Dropping it through any_of() accepts either
# spelling, so the script no longer aborts on a missing column after a readr
# upgrade.
index_cols <- c("X1", "...1")

# Per-event table that supplies adj_proba, joined below on (cell, particle).
unlabeled_sample <- read_csv("unlabeled_samples-adj_proba.csv") %>%
  select(-any_of(index_cols))

# NOTE: redrawn_samples and scored_dat are reassigned from here on, replacing
# the labeled-data versions built earlier in the script.
redrawn_samples <- read_csv("redrawn_unlabeled_samples_nodupes.csv") %>%
  select(-any_of(index_cols), -cal_proba) %>%
  mutate(cargo = as.factor(cargo),
         puff = ifelse(score == 1, "puff", "nonpuff")) %>%
  drop_na() # some scores are NA?

# Stack the binned and re-drawn scored samples, remove exact duplicate rows,
# and re-attach adj_proba from unlabeled_sample.
scored_dat <- read_csv("scored_binned_samples-20191121.csv") %>%
  select(-any_of(index_cols), -adj_proba) %>%
  mutate(cargo = as.factor(cargo),
         puff = ifelse(score == 1, "puff", "nonpuff")) %>%
  rbind(redrawn_samples) %>%
  arrange(cargo) %>%
  distinct() %>%
  left_join(unlabeled_sample %>% select(cell, particle, adj_proba),
            by = c("cell", "particle")) %>%
  # Removes rows with an NA in any column, including events that found no
  # match in unlabeled_sample.
  drop_na() %>%
  # Shift probabilities away from 0 so logit() stays finite at that end.
  # NOTE(review): an adj_proba of exactly 1 still maps to 1 (logit = Inf);
  # confirm the classifier never emits 1, or shrink symmetrically.
  mutate(adj_proba = (adj_proba + 0.000001) / (1 + 0.000001),
         l_prob = logit(adj_proba))

# Data the model is applied to: the cargo and l_prob columns of the current
# scored_dat.
eval_data <- with(scored_dat, data.frame(cargo = cargo, l_prob = l_prob))

# Calibrated (response-scale) probability for each row of eval_data, from the
# already-fitted cal_gam.
cal_prob <- predict(cal_gam, newdata = eval_data, type = "response")

# Calibration curves of the calibrated probabilities for each cargo, with
# pointwise confidence bands, for the current scored_dat / cal_prob.

# Attach the predictions as a real column so that every step below, including
# the rug layer, reads cal_prob from its own data instead of from a free
# variable that merely lines up row-for-row with scored_dat.
scored_cal <- scored_dat %>%
  mutate(cal_prob = cal_prob)

# Binned calibration table (the $data element of caret::calibration) for each
# cargo, computed with one shared recipe rather than three copies.
cal_by_cargo <- lapply(
  c(b2 = "B2", mor = "MOR", tfr = "TfR"),
  function(cargo_name) {
    cargo_rows <- filter(scored_cal, cargo == cargo_name)
    calibration(puff ~ cal_prob, class = "puff", cuts = 10,
                data = cargo_rows)$data
  }
)
cal_dat_b2 <- cal_by_cargo$b2
cal_dat_mor <- cal_by_cargo$mor
cal_dat_tfr <- cal_by_cargo$tfr

rbind(
  mutate(cal_dat_mor, cargo = "MOR"),
  mutate(cal_dat_b2, cargo = "B2"),
  mutate(cal_dat_tfr, cargo = "TfR")
) %>%
  # Where a bound is NA, open the band to the full 0-100 range.
  mutate(Upper = ifelse(is.na(Upper), 100, Upper),
         Lower = ifelse(is.na(Lower), 0, Lower)) %>%
  ggplot(aes(x = midpoint)) +
  geom_point(aes(y = Percent)) +
  geom_line(aes(y = Percent)) +
  geom_line(aes(y = Lower), color = "orange", alpha = 0.8) +
  geom_line(aes(y = Upper), color = "orange", alpha = 0.8) +
  # Identity line for reference.
  geom_abline(intercept = 0, slope = 1, color = "blue", alpha = 0.4) +
  # cal_prob is rescaled by 100 to share the x axis with midpoint.
  geom_rug(data = scored_cal, aes(x = 100 * cal_prob), sides = "b",
           alpha = 0.2) +
  facet_wrap(~cargo, nrow = 2)