# Declare names that are referenced through non-standard evaluation (e.g.,
# inside data.table expressions such as `long_data[, wts := NULL]`) so that
# static code checks do not report them as undefined global variables.
utils::globalVariables(c(":=", "in_bin", "bin_id"))

#' Conditional density estimation with HAL in a single cross-validation fold
#'
#' @details Estimates the conditional density of A|W for a subset of the full
#'  set of observations based on the inputted structure of the cross-validation
#'  folds. This is a helper function intended to be used to select the optimal
#'  value of the penalization parameter for the highly adaptive lasso estimates
#'  of the conditional hazard (via \code{\link[origami]{cross_validate}}). The
#'  HAL regression is fit on the training portion of the fold along the whole
#'  of \code{lambda_seq}; the fitted hazards for the validation portion are
#'  then mapped to density predictions, one observation at a time.
#'
#' @param fold Object specifying cross-validation folds as generated by a call
#'  to \code{\link[origami]{make_folds}}.
#' @param long_data A \code{data.table} or \code{data.frame} object containing
#'  the data in long format, as given in \insertRef{diaz2011super}{haldensify},
#'  as produced by \code{\link{format_long_hazards}}. The columns \code{obs_id}
#'  and \code{in_bin} must be present, and the first two columns are excluded
#'  from the design matrix passed to HAL.
#' @param wts A \code{numeric} vector of observation-level weights, matching in
#'  its length the number of records present in the long format data. Default
#'  is to weight all observations equally.
#' @param lambda_seq A \code{numeric} sequence of values of the tuning
#'  parameter of the Lasso L1 regression passed to
#'  \code{\link[hal9001]{fit_hal}}.
#'
#' @importFrom stats aggregate plogis
#' @importFrom origami training validation fold_index
#' @importFrom assertthat assert_that
#' @importFrom hal9001 fit_hal
#' @importFrom Rdpack reprompt
#'
#' @return A \code{list}, containing density predictions (\code{preds}),
#'  observations IDs (\code{ids}), observation-level weights (\code{wts}), and
#'  cross-validation indices (\code{fold}) for conditional density estimation
#'  on a single fold of the overall data. The entries of \code{ids} and
#'  \code{wts} are arranged in the same order as the per-observation
#'  predictions stacked in \code{preds}.
cv_haldensify <- function(fold, long_data, wts = rep(1, nrow(long_data)),
                          lambda_seq = exp(seq(-1, -13, length.out = 100))) {
  # make training and validation folds
  # NOTE: `fold` is not passed explicitly here; the origami helpers are relied
  #       upon to pick it up from this function's frame
  train_set <- origami::training(long_data)
  valid_set <- origami::validation(long_data)

  # subset observation-level weights to the correct size
  wts_train <- wts[fold$training_set]
  wts_valid <- wts[fold$validation_set]

  # fit a HAL regression on the training set
  # NOTE: not selecting lambda by CV so no need to pass IDs for fold splitting
  hal_fit_train <- hal9001::fit_hal(
    X = as.matrix(train_set[, -c(1, 2)]),
    Y = as.numeric(train_set$in_bin),
    max_degree = NULL,
    fit_type = "glmnet",
    family = "binomial",
    lambda = lambda_seq,
    cv_select = FALSE,
    standardize = FALSE, # pass to glmnet
    weights = wts_train, # pass to glmnet
    yolo = FALSE
  )

  # stack the intercepts on top of the coefficients: one column per lambda
  alpha_hat <- hal_fit_train$glmnet_lasso$a0
  betas_hat <- hal_fit_train$glmnet_lasso$beta
  coefs_hat <- rbind(alpha_hat, betas_hat)

  # make design matrix for validation set manually
  pred_x_basis <- hal9001::make_design_matrix(
    as.matrix(valid_set[, -c(1, 2)]),
    hal_fit_train$basis_list
  )
  pred_x_basis <- hal9001::apply_copy_map(
    pred_x_basis,
    hal_fit_train$copy_map
  )
  # leading column of ones lines up with the intercept row of coefs_hat
  pred_x_basis <- cbind(rep(1, nrow(valid_set)), pred_x_basis)

  # manually predict along sequence of lambdas
  preds_logit <- pred_x_basis %*% coefs_hat
  preds <- stats::plogis(as.matrix(preds_logit))

  # compute hazard for a given observation by looping over individuals
  valid_obs_ids <- valid_set$obs_id
  obs_ids <- unique(valid_obs_ids)
  density_pred_each_obs <- lapply(obs_ids, function(id) {
    # keep the matrix shape of the fit itself (rather than re-shaping to
    # length(lambda_seq) columns) so that hazards are never recycled into the
    # wrong cells when the fit holds fewer lambdas than were requested
    hazard_pred_this_obs <-
      unname(preds[valid_obs_ids == id, , drop = FALSE])

    # map hazard to density for a single observation and return
    map_hazard_to_density(hazard_pred_single_obs = hazard_pred_this_obs)
  })

  # aggregate predictions across observations
  density_pred <- do.call(rbind, density_pred_each_obs)

  # collapse weights to the observation level
  wts_valid_reduced <- stats::aggregate(
    wts_valid, list(valid_obs_ids),
    unique
  )
  colnames(wts_valid_reduced) <- c("id", "weight")

  # aggregate() arranges its output by group, whereas the predictions follow
  # the order of first appearance of each ID; align the two explicitly
  wts_valid_reduced <- wts_valid_reduced[
    match(obs_ids, wts_valid_reduced$id), ,
    drop = FALSE
  ]

  # construct output
  out <- list(
    preds = density_pred,
    ids = wts_valid_reduced$id,
    wts = wts_valid_reduced$weight,
    fold = origami::fold_index()
  )
  return(out)
}

################################################################################

#' Cross-validated conditional density estimation with HAL
#'
#' @details Estimation of the conditional density A|W through using the highly
#'  adaptive lasso to estimate the conditional hazard of failure in a given
#'  bin over the support of A. Cross-validation is used to select the optimal
#'  value of the penalization parameters, based on minimization of the weighted
#'  log-likelihood loss for a density. The same procedure is repeated for each
#'  combination of \code{grid_type} and \code{n_bins}, and the combination
#'  attaining the smallest cross-validated loss is used for the final fit.
#'
#' @param A The \code{numeric} vector or similar of the observed values of an
#'  intervention for a group of observational units of interest.
#' @param W A \code{data.frame}, \code{matrix}, or similar giving the values of
#'  baseline covariates (potential confounders) for the observed units whose
#'  observed intervention values are provided in the previous argument.
#' @param wts A \code{numeric} vector of observation-level weights. The default
#'  is to weight all observations equally.
#' @param grid_type A \code{character} indicating the strategy (or strategies)
#'  to be used in creating bins along the observed support of the intervention
#'  \code{A}. For bins of equal range, use "equal_range"; consult documentation
#'  of \code{\link[ggplot2]{cut_interval}} for more information. To ensure each
#'  bin has the same number of points, use "equal_mass"; consult documentation
#'  of \code{\link[ggplot2]{cut_number}} for details.
#' @param n_bins Only used if \code{grid_type} is set to \code{"equal_range"}
#'  or \code{"equal_mass"}. This \code{numeric} value indicates the number(s)
#'  of bins into which the support of the intervention \code{A} is to be
#'  divided.
#' @param lambda_seq A \code{numeric} sequence of values of the tuning
#'  parameter of the Lasso L1 regression passed to
#'  \code{\link[hal9001]{fit_hal}}.
#' @param use_future A \code{logical} indicating whether to attempt to use
#'  parallelization based on the \pkg{future} and \pkg{future.apply} packages.
#'  If set to \code{TRUE}, \code{\link[future.apply]{future_mapply}} will be
#'  used in place of \code{mapply}. When set to \code{TRUE}, a parallelization
#'  scheme must be set externally by using \code{\link[future]{plan}}.
#' @param seed_int An integer used to set the seed in the cross-validation
#'  procedure used to select binning values. This is passed to the argument
#'  \code{future.seed} of \code{\link[future.apply]{future_mapply}}, and so is
#'  only used when \code{use_future} is \code{TRUE}.
#'
#' @importFrom origami make_folds cross_validate
#' @importFrom future.apply future_mapply
#' @importFrom hal9001 fit_hal
#'
#' @return Object of class \code{haldensify}, a \code{list} containing a fitted
#'  \code{hal9001} object whose coefficients are those of the selected tuning
#'  parameter (\code{hal_fit}), the break points used in binning \code{A} over
#'  its support (\code{breaks}), the sizes of the bins used in the final fit
#'  (\code{bin_sizes}), the matched call (\code{call}), the binning strategy
#'  and number of bins selected by cross-validation (\code{tune_select}), the
#'  cross-validation results for every candidate (\code{select_out}), and the
#'  range of \code{A} (\code{range_a}).
#'
#' @examples
#' # simulate data: W ~ U[-4, 4] and A|W ~ N(mu = W, sd = 0.5)
#' set.seed(76924)
#' n_train <- 100
#' w <- runif(n_train, -4, 4)
#' a <- rnorm(n_train, w, 0.5)
#' # learn relationship A|W using HAL-based density estimation procedure
#' mod_haldensify <- haldensify(
#'   A = a, W = w, n_bins = 5,
#'   lambda_seq = exp(seq(-1, -13, length.out = 300))
#' )
#' @export
haldensify <- function(A,
                       W,
                       wts = rep(1, length(A)),
                       grid_type = c(
                         "equal_range", "equal_mass"
                       ),
                       n_bins = c(5, 10),
                       lambda_seq = exp(seq(-1, -13, length.out = 1000)),
                       use_future = FALSE,
                       seed_int = 791L) {
  # catch input
  call <- match.call(expand.dots = TRUE)

  # run CV-HAL for all combinations of n_bins and grid_type combos
  tune_grid <- expand.grid(
    grid_type = grid_type, n_bins = n_bins,
    stringsAsFactors = FALSE
  )

  # cross-validated loss along lambda_seq for a single binning configuration
  cv_select_lambda <- function(n_bins, grid_type) {
    # re-format input data into long hazards structure
    reformatted_output <- format_long_hazards(
      A = A, W = W, wts = wts,
      grid_type = grid_type, n_bins = n_bins
    )
    long_data <- reformatted_output$data
    bin_sizes <- reformatted_output$bin_length

    # extract weights from long format data structure
    wts_long <- long_data$wts
    long_data[, wts := NULL]

    # make folds with origami, keeping all rows of an observation together
    folds <- origami::make_folds(long_data, cluster_ids = long_data$obs_id)

    # call cross_validate on cv_density function...
    haldensity <- origami::cross_validate(
      cv_fun = cv_haldensify,
      folds = folds,
      long_data = long_data,
      wts = wts_long,
      lambda_seq = lambda_seq,
      use_future = FALSE,
      .combine = FALSE
    )

    # re-organize output of cross-validation procedure; everything below is
    # stacked fold by fold, NOT in the order of the original observations
    density_pred_unscaled <- do.call(rbind, as.list(haldensity$preds))
    obs_ids <- do.call(c, as.list(haldensity$ids))
    obs_wts <- do.call(c, as.list(haldensity$wts))

    # look up, via the observation IDs, the bin in which each observation
    # fails so that bin widths line up with the fold-ordered predictions
    fail_rows <- long_data$in_bin == 1
    fail_obs_ids <- long_data$obs_id[fail_rows]
    fail_bin_ids <- long_data$bin_id[fail_rows]
    bin_of_pred <- fail_bin_ids[match(obs_ids, fail_obs_ids)]
    stopifnot(
      length(bin_of_pred) == nrow(density_pred_unscaled),
      !anyNA(bin_of_pred)
    )

    # re-scale predictions by dividing by the width of each failure bin; the
    # width vector has one entry per row, so it is applied down every column
    density_pred_scaled <- density_pred_unscaled / bin_sizes[bin_of_pred]

    # weighted negative log-likelihood loss for each observation and lambda
    # NOTE: the weight multiplies the log-density; placing it inside the log
    #       would only shift the loss by a constant, leaving the choice of
    #       lambda unaffected by the weights
    density_loss <- -obs_wts * log(density_pred_scaled)
    # observations with zero weight contribute nothing (avoids 0 * Inf)
    density_loss[which(obs_wts == 0), ] <- 0

    # take column means to have average loss across sequence of lambdas
    loss_mean <- colMeans(density_loss)
    lambda_loss_min_idx <- which.min(loss_mean)
    lambda_loss_min <- lambda_seq[lambda_loss_min_idx]

    # format output
    out <- list(
      lambda_loss_min_idx = lambda_loss_min_idx,
      lambda_loss_min = lambda_loss_min,
      loss_mean = loss_mean
    )
    return(out)
  }

  # apply grid of binning strategies and bin number over estimation routine to
  # select Lasso tuning parameter via cross-validated loss minimization
  args <-
    list(
      FUN = cv_select_lambda,
      n_bins = tune_grid$n_bins,
      grid_type = tune_grid$grid_type,
      SIMPLIFY = FALSE
    )

  # tweak arguments to flexibly use future parallelization if so desired
  if (use_future) {
    mapply_fun <- future.apply::future_mapply
    args$future.seed <- seed_int
  } else {
    mapply_fun <- mapply
  }

  # run procedure to select tuning parameters via cross-validation
  select_out <- do.call(what = mapply_fun, args = args)

  # find the binning configuration attaining the smallest loss
  min_loss <- vapply(
    select_out, function(res) min(res$loss_mean),
    numeric(1)
  )
  best_idx <- which.min(min_loss)
  if (length(best_idx) == 0L) {
    stop(
      "Cross-validated loss is missing for every candidate binning; ",
      "cannot select tuning parameters.",
      call. = FALSE
    )
  }
  tune_select <- tune_grid[best_idx, , drop = FALSE]

  # re-format input data into long hazards structure
  reformatted_output <- format_long_hazards(
    A = A, W = W, wts = wts,
    grid_type = tune_select$grid_type,
    n_bins = tune_select$n_bins
  )
  long_data <- reformatted_output$data
  breakpoints <- reformatted_output$breaks
  bin_sizes <- reformatted_output$bin_length

  # extract weights from long format data structure
  wts_long <- long_data$wts
  long_data[, wts := NULL]

  # fit a HAL regression on the full data set along the whole lambda sequence;
  # the coefficients of the CV-selected lambda are picked out afterwards
  hal_fit <- hal9001::fit_hal(
    X = as.matrix(long_data[, -c(1, 2)]),
    Y = as.numeric(long_data$in_bin),
    max_degree = NULL,
    fit_type = "glmnet",
    family = "binomial",
    lambda = lambda_seq,
    cv_select = FALSE,
    standardize = FALSE, # passed to glmnet
    weights = wts_long, # passed to glmnet
    yolo = FALSE
  )

  # replace coefficients by those of the CV-selected lambda
  hal_fit$coefs <-
    hal_fit$coefs[, select_out[[best_idx]]$lambda_loss_min_idx]

  # construct output
  out <- list(
    hal_fit = hal_fit,
    breaks = breakpoints,
    bin_sizes = bin_sizes,
    call = call,
    tune_select = tune_select,
    select_out = select_out,
    range_a = range(A)
  )
  class(out) <- "haldensify"
  return(out)
}
