Preliminary work

Efficient linear combinations for the ORSF algorithm.

Byron C. Jaeger (Wake Forest School of Medicine)
2021-11-05

Introduction

This document is written as a supplement to my application for the Wake Forest Clinical and Translational Science Institute (CTSI) pilot award. It does not add any content to the application; it simply shows that the preliminary results in the application are reproducible.

Datasets

We’ll use two datasets in this article.

Primary Biliary Cholangitis (PBC)

This data is from the Mayo Clinic trial in PBC conducted between 1974 and 1984. A total of 424 PBC patients, referred to Mayo Clinic during that ten-year interval, met eligibility criteria for the randomized, placebo-controlled trial of the drug D-penicillamine. After removing observations with missing values, we have a total of 276 observations.

Show code
paged_table(as_tibble(cbind(data_pbc$y, data_pbc$x)))

Assay of serum free light chain (FL chain)

This is a stratified random sample containing half of the subjects from a study of the relationship between serum free light chain (FL chain) and mortality. The original study sampled approximately two thirds of the residents of Olmsted County aged 50 or older. After removing observations with missing values, we have a total of 6524 observations.

Show code
paged_table(as_tibble(cbind(data_flchain$y, data_flchain$x)))

Testing

This section verifies that, for both datasets, the code I have written gives the exact same answer as the survival package. (Faster code wouldn't matter if it gave the wrong answer.)
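
The comparison itself is done in R/cox_test.R, which is not shown here. In spirit, the check asserts element-wise agreement between the two sets of estimates, as in this hypothetical sketch (the function name and tolerance are illustrative, not taken from the repository):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical agreement check: true when two coefficient vectors
// have the same length and match element-wise within `tol`.
// (Illustrative only; the actual comparison lives in R/cox_test.R.)
bool coefs_agree(const std::vector<double>& orsf,
                 const std::vector<double>& surv,
                 double tol = 1e-8) {
  if (orsf.size() != surv.size()) return false;
  for (std::size_t i = 0; i < orsf.size(); ++i) {
    if (std::fabs(orsf[i] - surv[i]) > tol) return false;
  }
  return true;
}
```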

Show code
bind_rows(pbc = test_pbc, 
          flchain = test_flchain, 
          .id = 'data') |> 
  drop_na() |> 
  mutate(
    data = recode(
      data, 
      pbc = "Primary Biliary Cholangitis (PBC) data",
      flchain = "Assay of serum free light chain (FL chain) data"
    ),
    across(ends_with('pvalue'), table_pvalue),
    across(where(is.numeric), table_value)
  ) |> 
  group_by(data) |> 
  gt(rowname_col = 'variable') |> 
  cols_align(align = 'center') |> 
  cols_align(align = 'right', columns = 'surv_pvalue') |> 
  cols_label(
    orsf_beta = 'AORSF', 
    surv_beta = "survival",
    orsf_stderr = 'AORSF', 
    surv_stderr = "survival",
    orsf_pvalue = 'AORSF', 
    surv_pvalue = "survival"
  ) |>
  tab_stubhead(label = 'Variable') |> 
  tab_spanner(
    label = 'Regression coefficients', 
    columns = c("orsf_beta", "surv_beta")
  ) |> 
  tab_spanner(
    label = 'Standard error', 
    columns = c("orsf_stderr", "surv_stderr")
  ) |> 
  tab_spanner(
    label = 'P-value', 
    columns = c("orsf_pvalue", "surv_pvalue")
  ) |> 
  tab_source_note(
    source_note = "table values were created using R/cox_test.R"
  )
Primary Biliary Cholangitis (PBC) data

| Variable | Coef. (AORSF) | Coef. (survival) | SE (AORSF) | SE (survival) | P-value (AORSF) | P-value (survival) |
|----------|---------------|------------------|------------|--------------|-----------------|--------------------|
| trt      | -0.0988 | -0.0988 | 0.1198 | 0.1198 | .41 | .41 |
| age      | 0.0422 | 0.0422 | 0.0062 | 0.0062 | <.001 | <.001 |
| ascites  | 0.2902 | 0.2902 | 0.2075 | 0.2075 | .16 | .16 |
| hepato   | 0.0451 | 0.0451 | 0.1349 | 0.1349 | .74 | .74 |
| spiders  | 0.0465 | 0.0465 | 0.1340 | 0.1340 | .73 | .73 |
| edema    | 1.074 | 1.074 | 0.2130 | 0.2130 | <.001 | <.001 |
| bili     | 0.0679 | 0.0679 | 0.0145 | 0.0145 | <.001 | <.001 |
| chol     | 0.0006 | 0.0006 | 0.0002 | 0.0002 | .02 | .02 |
| albumin  | -0.5630 | -0.5630 | 0.1628 | 0.1628 | <.001 | <.001 |
| copper   | 0.0039 | 0.0039 | 0.0005 | 0.0005 | <.001 | <.001 |
| alk.phos | 0.0000 | 0.0000 | 0.0000 | 0.0000 | .34 | .34 |
| ast      | 0.0033 | 0.0033 | 0.0011 | 0.0011 | .002 | .002 |
| trig     | -0.0010 | -0.0010 | 0.0008 | 0.0008 | .18 | .18 |
| platelet | -0.0002 | -0.0002 | 0.0007 | 0.0007 | .75 | .75 |
| protime  | 0.2177 | 0.2177 | 0.0630 | 0.0630 | <.001 | <.001 |
| stage    | 0.4433 | 0.4433 | 0.0938 | 0.0938 | <.001 | <.001 |

Assay of serum free light chain (FL chain) data

| Variable | Coef. (AORSF) | Coef. (survival) | SE (AORSF) | SE (survival) | P-value (AORSF) | P-value (survival) |
|----------|---------------|------------------|------------|--------------|-----------------|--------------------|
| age        | 0.9171 | 0.9171 | 0.0132 | 0.0132 | <.001 | <.001 |
| sexF       | -0.1352 | -0.1352 | 0.0138 | 0.0138 | <.001 | <.001 |
| sample.yr  | 0.0690 | 0.0690 | 0.0119 | 0.0119 | <.001 | <.001 |
| kappa      | 0.0271 | 0.0271 | 0.0111 | 0.0111 | .02 | .02 |
| lambda     | 0.0907 | 0.0907 | 0.0098 | 0.0098 | <.001 | <.001 |
| flc.grp    | 0.1118 | 0.1118 | 0.0150 | 0.0150 | <.001 | <.001 |
| creatinine | 0.0105 | 0.0105 | 0.0052 | 0.0052 | .04 | .04 |
| mgus       | 0.0057 | 0.0057 | 0.0046 | 0.0046 | .21 | .21 |

table values were created using R/cox_test.R

Benchmark

This section shows the mean and median computation time taken by four separate approaches to find a linear combination of predictor variables:

  1. glmnet with 10-fold cross-validation (CV), one of the current options to find linear combinations of predictors in obliqueRSF.

  2. glmnet without CV, the default approach to find linear combinations of predictors in obliqueRSF.

  3. The coxph.fit function from the survival package, i.e., the code that was adapted to create the proposed routine to find linear combinations in AORSF.

  4. The proposed routine to find linear combinations in AORSF.
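
Timings in the tables below come from repeated runs (500 per approach). Stripped to its essentials, that measurement pattern looks like this hypothetical C++ sketch, which times a callable repeatedly and summarizes with mean and median (the struct and function names are illustrative; the actual benchmark uses the microbenchmark R package listed in the session info):

```cpp
#include <algorithm>
#include <cassert>
#include <chrono>
#include <numeric>
#include <vector>

// Hypothetical timing harness: run `fn` a fixed number of times and
// report the mean and median elapsed time in milliseconds.
struct TimingSummary { double mean_ms; double median_ms; };

template <typename F>
TimingSummary time_runs(F fn, int n_runs) {
  std::vector<double> ms(n_runs);
  for (int r = 0; r < n_runs; ++r) {
    auto t0 = std::chrono::steady_clock::now();
    fn();
    auto t1 = std::chrono::steady_clock::now();
    ms[r] = std::chrono::duration<double, std::milli>(t1 - t0).count();
  }
  std::sort(ms.begin(), ms.end());
  double mean = std::accumulate(ms.begin(), ms.end(), 0.0) / n_runs;
  double median = (n_runs % 2 == 1)
    ? ms[n_runs / 2]
    : 0.5 * (ms[n_runs / 2 - 1] + ms[n_runs / 2]);
  return {mean, median};
}
```

The relative ("ratio") columns in the tables are then just each approach's mean or median divided by the corresponding value for the reference approach.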

Show code
bench_combined |> 
  mutate(
    across(where(is.numeric), table_value),
    across(ends_with("rel"), ~ recode(.x, "1.000" = '1 (ref)')),
    data = recode(
      data, 
      pbc = "Primary Biliary Cholangitis (PBC) data",
      flchain = "Assay of serum free light chain (FL chain) data"
    ),
    expr = recode(
      expr,
      glmnet.cv = "glmnet, 10-fold CV",
      glmnet = "glmnet, no CV",
      surv = "survival",
      orsf = "AORSF"
    )
  ) |> 
  select(data, expr, mean_abs, median_abs, mean_rel, median_rel) |> 
  group_by(data) |> 
  gt(rowname_col = 'expr') |> 
  cols_align(align = 'center') |> 
  cols_align(align = 'right', columns = 'median_rel') |> 
  cols_label(
    mean_abs = "Mean",
    median_abs = "Median",
    mean_rel = "Mean",
    median_rel = "Median"
  ) |> 
  tab_spanner(label = 'Milliseconds', columns = c('mean_abs', 'median_abs')) |> 
  tab_spanner(label = 'Ratio', columns = c('mean_rel', 'median_rel')) |> 
  tab_source_note("Results are averaged over 500 runs of the computations.")
Primary Biliary Cholangitis (PBC) data

| Approach | Mean (ms) | Median (ms) | Mean (ratio) | Median (ratio) |
|----------|-----------|-------------|--------------|----------------|
| glmnet, 10-fold CV | 8.999 | 8.828 | 240.0 | 236.8 |
| glmnet, no CV | 0.4713 | 0.4638 | 12.57 | 12.45 |
| survival | 0.0922 | 0.0884 | 2.460 | 2.371 |
| AORSF | 0.0375 | 0.0372 | 1 (ref) | 1 (ref) |

Assay of serum free light chain (FL chain) data

| Approach | Mean (ms) | Median (ms) | Mean (ratio) | Median (ratio) |
|----------|-----------|-------------|--------------|----------------|
| glmnet, 10-fold CV | 90.97 | 90.22 | 252.3 | 252.0 |
| glmnet, no CV | 5.381 | 5.348 | 14.93 | 14.93 |
| survival | 0.8787 | 0.7923 | 2.439 | 2.214 |
| AORSF | 0.3615 | 0.3588 | 1 (ref) | 1 (ref) |

Results are averaged over 500 runs of the computations.

Reproducibility

If you are interested in reproducing this analysis, you can clone or fork the GitHub repository and it should work fine! You may find minor discrepancies in timings, as these vary from system to system. However, if you find any major discrepancy (e.g., the coefficients in the testing section no longer match), please click the reproducibility receipt below to see the exact specifications of my computing environment.

Reproducibility receipt
Show code
## datetime
Sys.time()
[1] "2021-11-05 09:17:54 EDT"
Show code
## repository
if(requireNamespace('git2r', quietly = TRUE)) {
  git2r::repository()
} else {
  c(
    system2("git", args = c("log", "--name-status", "-1"), stdout = TRUE),
    system2("git", args = c("remote", "-v"), stdout = TRUE)
  )
}
Local:    main D:/coxph_bench
Remote:   main @ origin (https://github.com/bcjaeger/coxph_bench.git)
Head:     [1d7f1de] 2021-11-05: rename for website
Show code
## session info
sessionInfo()
R version 4.1.1 (2021-08-10)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19043)

Matrix products: default

locale:
[1] LC_COLLATE=English_U.S. Virgin Islands.1252 
[2] LC_CTYPE=English_U.S. Virgin Islands.1252   
[3] LC_MONETARY=English_U.S. Virgin Islands.1252
[4] LC_NUMERIC=C                                
[5] LC_TIME=English_United States.1252          

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods  
[7] base     

other attached packages:
 [1] gt_0.3.1             Rcpp_1.0.7           glmnet_4.1-2        
 [4] Matrix_1.3-4         survival_3.2-11      forcats_0.5.1       
 [7] stringr_1.4.0        dplyr_1.0.7          purrr_0.3.4         
[10] readr_2.0.2          tidyr_1.1.4          tibble_3.1.5        
[13] ggplot2_3.3.5        tidyverse_1.3.1      rmarkdown_2.11      
[16] table.glue_0.0.2     microbenchmark_1.4.8 tarchetypes_0.3.2   
[19] targets_0.8.1        dotenv_1.0.3         conflicted_1.0.4    

loaded via a namespace (and not attached):
 [1] fs_1.5.0                 lubridate_1.8.0         
 [3] httr_1.4.2               rprojroot_2.0.2         
 [5] tools_4.1.1              backports_1.3.0         
 [7] bslib_0.3.1              utf8_1.2.2              
 [9] R6_2.5.1                 DBI_1.1.1               
[11] colorspace_2.0-2         withr_2.4.2             
[13] tidyselect_1.1.1         processx_3.5.2          
[15] downlit_0.4.0            git2r_0.28.0            
[17] compiler_4.1.1           cli_3.0.1               
[19] rvest_1.0.2              xml2_1.3.2              
[21] bookdown_0.24            sass_0.4.0              
[23] checkmate_2.0.0          scales_1.1.1            
[25] callr_3.7.0              digest_0.6.28           
[27] pkgconfig_2.0.3          htmltools_0.5.2         
[29] dbplyr_2.1.1             fastmap_1.1.0           
[31] rlang_0.4.12             readxl_1.3.1            
[33] rstudioapi_0.13          shape_1.4.6             
[35] jquerylib_0.1.4          generics_0.1.1          
[37] jsonlite_1.7.2           distill_1.3             
[39] magrittr_2.0.1           munsell_0.5.0           
[41] fansi_0.5.0              lifecycle_1.0.1         
[43] stringi_1.7.5            yaml_2.2.1              
[45] grid_4.1.1               crayon_1.4.1            
[47] lattice_0.20-44          haven_2.4.3             
[49] splines_4.1.1            hms_1.1.1               
[51] knitr_1.36               ps_1.6.0                
[53] pillar_1.6.4             igraph_1.2.7            
[55] codetools_0.2-18         reprex_2.0.1            
[57] glue_1.4.2               evaluate_0.14           
[59] RcppArmadillo_0.10.7.0.0 data.table_1.14.2       
[61] modelr_0.1.8             vctrs_0.3.8             
[63] tzdb_0.2.0               foreach_1.5.1           
[65] cellranger_1.1.0         gtable_0.3.0            
[67] assertthat_0.2.1         cachem_1.0.6            
[69] xfun_0.27                broom_0.7.9             
[71] iterators_1.0.13         memoise_2.0.0           
[73] ellipsis_0.3.2           here_1.0.1              

Source code

The R code used to produce this article can be found in the R directory. The C++ code I have written for this benchmark is below.
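
At the heart of the listing below is an LDL'-style Cholesky factorization adapted from the survival package (cholesky2 and chsolve2). As orientation, here is a compact standalone sketch of the same factor-then-solve pattern; it omits the pivot tolerance and zero-diagonal handling that the full routine includes, and the function names are my own:

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Factor a symmetric positive-definite matrix A (dense, row-major,
// n x n) in place as L D L', storing L (unit diagonal implied) below
// the diagonal and D on the diagonal.
void ldl_factor(std::vector<double>& A, int n) {
  for (int i = 0; i < n; ++i) {
    double pivot = A[i * n + i];
    for (int j = i + 1; j < n; ++j) {
      double temp = A[j * n + i] / pivot;
      A[j * n + j] -= temp * temp * pivot;   // update D(j)
      for (int k = j + 1; k < n; ++k)        // rank-1 update of column j
        A[k * n + j] -= temp * A[k * n + i];
      A[j * n + i] = temp;                   // store L(j, i)
    }
  }
}

// Solve A x = b using the factor above: forward solve with unit L,
// then divide by D and back-solve with L'. Overwrites b with x.
void ldl_solve(const std::vector<double>& A, std::vector<double>& b, int n) {
  for (int i = 0; i < n; ++i)
    for (int j = 0; j < i; ++j)
      b[i] -= b[j] * A[i * n + j];
  for (int i = n - 1; i >= 0; --i) {
    b[i] /= A[i * n + i];
    for (int j = i + 1; j < n; ++j)
      b[i] -= b[j] * A[j * n + i];
  }
}
```

Keeping a unit diagonal on L avoids square roots, and it lets the full routine zero out a near-singular pivot instead of propagating it, which matters for the repeated small solves inside a Newton-Raphson loop.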

Show code

#include <RcppArmadillo.h>
#include <RcppArmadilloExtensions/sample.h>
#include <Rcpp.h>

// [[Rcpp::depends(RcppArmadillo)]]

using namespace Rcpp;

// ----------------------------------------------------------------------------
// ---------------------------- global parameters -----------------------------
// ----------------------------------------------------------------------------

// special note: don't change these doubles to uword,
//               even though some of them could be uwords;
//               operations involving uwords and doubles are not
//               straightforward and may break the routine.
// also: double + uword is slower than double + double.

double
  weight_avg,
  weight_events,
  denom_events,
  denom,
  n_events,
  n_risk,
  temp1,
  temp2,
  w_node_person,
  x_beta,
  risk,
  loglik;



// armadillo unsigned integers
arma::uword
  i,
  j,
  k,
  iter,
  person,
  n_vars;

// a delayed break statement
bool break_loop;

// armadillo vectors (doubles)
arma::vec
  beta_current,
  beta_new,
  w_node,
  u,
  a,
  a2,
  XB,
  Risk;

// armadillo matrices (doubles)
arma::mat
  x_node,
  y_node,
  imat,
  cmat,
  cmat2;

// ----------------------------------------------------------------------------
// ---------------------------- scaling input data ----------------------------
// ----------------------------------------------------------------------------

// [[Rcpp::export]]
arma::mat x_scale_wtd(){

  // set aside memory for outputs
  // first column holds the mean values
  // second column holds the scale values

  arma::mat out(n_vars, 2);
  arma::vec means = out.unsafe_col(0);   // reference to the first column
  arma::vec scales = out.unsafe_col(1);  // reference to the second column

  double w_node_sum = arma::sum(w_node);

  for(i = 0; i < n_vars; i++) {

    arma::vec x_i = x_node.unsafe_col(i);

    means.at(i) = arma::sum( w_node % x_i ) / w_node_sum;

    x_i -= means.at(i);

    scales.at(i) = arma::sum(w_node % arma::abs(x_i));

    if(scales(i) > 0)
      scales.at(i) = w_node_sum / scales.at(i);
    else
      scales.at(i) = 1.0; // rare case of constant covariate;

    x_i *= scales.at(i);

  }


  return(out);

}

// ----------------------------------------------------------------------------
// ---------------------------- cholesky functions ----------------------------
// ----------------------------------------------------------------------------

// [[Rcpp::export]]
void cholesky(){

  double eps_chol = 0;
  double toler = 1e-8;
  double pivot;

  for(i = 0; i < n_vars; i++){

    if(imat.at(i,i) > eps_chol) eps_chol = imat.at(i,i);

    // copy upper right values to bottom left
    for(j = (i+1); j<n_vars; j++){
      imat.at(j,i) = imat.at(i,j);
    }
  }

  if (eps_chol == 0)
    eps_chol = toler; // no positive diagonals!
  else
    eps_chol = eps_chol * toler;

  for (i = 0; i < n_vars; i++) {

    pivot = imat.at(i, i);

    if (pivot < R_PosInf && pivot > eps_chol) {

      for(j = (i+1); j < n_vars; j++){

        temp1 = imat.at(j,i) / pivot;
        imat.at(j,i) = temp1;
        imat.at(j,j) -= temp1*temp1*pivot;

        for(k = (j+1); k < n_vars; k++){

          imat.at(k, j) -= temp1 * imat.at(k, i);

        }

      }

    } else {

      imat.at(i, i) = 0;

    }

  }

}

// [[Rcpp::export]]
void cholesky_solve(){

  for (i = 0; i < n_vars; i++) {

    temp1 = u[i];

    for (j = 0; j < i; j++){
      temp1 -= u[j] * imat.at(i, j);
    }

    // assign once, after the inner loop finishes
    u[i] = temp1;

  }


  for (i = n_vars; i >= 1; i--){

    if (imat.at(i-1, i-1) == 0){

      u[i-1] = 0;

    } else {

      temp1 = u[i-1] / imat.at(i-1, i-1);

      for (j = i; j < n_vars; j++){
        temp1 -= u[j] * imat.at(j, i-1);
      }

      u[i-1] = temp1;

    }

  }

}

// [[Rcpp::export]]
void cholesky_invert(){

  /*
   ** invert the cholesky in the lower triangle
   **   take full advantage of the cholesky's diagonal of 1's
   */
  for (i=0; i<n_vars; i++){

    if (imat.at(i,i) >0) {

      imat.at(i,i) = 1.0 / imat.at(i,i);

      for (j=(i+1); j<n_vars; j++) {

        imat.at(j, i) = -imat.at(j, i);

        for (k=0; k<i; k++){
          imat.at(j, k) += imat.at(j, i) * imat.at(i, k);
        }

      }

    }

  }

  /*
   ** lower triangle now contains inverse of cholesky
   ** calculate F'DF (inverse of cholesky decomp process) to get inverse
   **   of original imat
   */
  for (i=0; i<n_vars; i++) {

    if (imat.at(i, i) == 0) {

      for (j=0; j<i; j++) imat.at(i, j) = 0;
      for (j=i; j<n_vars; j++) imat.at(j, i) = 0;

    } else {

      for (j=(i+1); j<n_vars; j++) {

        temp1 = imat.at(j, i) * imat.at(j, j);

        if (j!=i) imat.at(i, j) = temp1;

        for (k=i; k<j; k++){
          imat.at(i, k) += temp1*imat.at(j, k);
        }

      }

    }

  }

}

// ----------------------------------------------------------------------------
// ------------------- Newton Raphson algo for Cox PH model -------------------
// ----------------------------------------------------------------------------

// [[Rcpp::export]]
double newtraph_cph_iter(const arma::uword& method,
                         const arma::vec& beta){

  denom = 0;

  loglik = 0;

  n_risk = 0;

  person = x_node.n_rows - 1;

  u.fill(0);
  a.fill(0);
  a2.fill(0);
  imat.fill(0);
  cmat.fill(0);
  cmat2.fill(0);

  // this loop has a strange break condition to accommodate
  // the restriction that a uvec (uword) cannot be < 0

  break_loop = false;

  XB = x_node * beta;
  Risk = arma::exp(XB) % w_node;

  for( ; ; ){

    temp2 = y_node.at(person, 0); // time of event for current person
    n_events  = 0 ; // number of deaths at this time point
    weight_events = 0 ; // sum of w_node for the deaths
    denom_events = 0 ; // sum of weighted risks for the deaths

    // walk through this set of tied times
    while(y_node.at(person, 0) == temp2){

      n_risk++;

      x_beta = XB.at(person);
      risk = Risk.at(person);

      // note: x_beta and risk come from the vectorized XB and Risk
      // expressions above, rather than being re-computed
      // person-by-person inside this loop.

      w_node_person = w_node.at(person);

      if (y_node.at(person, 1) == 0) {

        denom += risk;

        /* a contains weighted sums of x, cmat sums of squares */

        for (i=0; i<n_vars; i++) {

          temp1 = risk * x_node.at(person, i);

          a[i] += temp1;

          for (j = 0; j <= i; j++){
            cmat.at(j, i) += temp1 * x_node.at(person, j);
          }

        }

      } else {

        n_events++;

        weight_events += w_node_person;
        denom_events += risk;
        loglik += w_node_person * x_beta;

        for (i=0; i<n_vars; i++) {

          u[i]  += w_node_person * x_node.at(person, i);
          a2[i] += risk * x_node.at(person, i);

          for (j=0; j<=i; j++){
            cmat2.at(j, i) += risk * x_node.at(person, i) * x_node.at(person, j);
          }

        }

      }

      if(person == 0){
        break_loop = true;
        break;
      }

      person--;

    }

    // we need to add to the main terms
    if (n_events > 0) {

      if (method == 0 || n_events == 1) { // Breslow

        denom  += denom_events;
        loglik -= weight_events * log(denom);

        for (i=0; i<n_vars; i++) {

          a[i]  += a2[i];
          temp1  = a[i] / denom;  // mean
          u[i]  -=  weight_events * temp1;

          for (j=0; j<=i; j++) {
            cmat.at(j, i) += cmat2.at(j, i);
            imat.at(j, i) += weight_events * (cmat.at(j, i) - temp1 * a[j]) / denom;
          }

        }

      } else {
        /* Efron
         **  If there are 3 deaths we have 3 terms: in the first the
         **  three deaths are all in, in the second they are 2/3
         **  in the sums, and in the last 1/3 in the sum.  Let k go
         **  1 to n_events: we sequentially add a2/n_events and cmat2/n_events
         **  and efron_wt/n_events to the totals.
         */
        weight_avg = weight_events/n_events;

        for (k=0; k<n_events; k++) {

          denom  += denom_events / n_events;
          loglik -= weight_avg * log(denom);

          for (i=0; i<n_vars; i++) {

            a[i] += a2[i] / n_events;
            temp1 = a[i]  / denom;
            u[i] -= weight_avg * temp1;

            for (j=0; j<=i; j++) {
              cmat.at(j, i) += cmat2.at(j, i) / n_events;
              imat.at(j, i) += weight_avg * (cmat.at(j, i) - temp1 * a[j]) / denom;
            }

          }

        }

      }

      a2.fill(0);
      cmat2.fill(0);

    }

    if(break_loop) break;

  }

  return(loglik);

}


// [[Rcpp::export]]
double newtraph_cph_init(const arma::uword& method){

  denom = 0;
  loglik = 0;
  n_risk = 0;

  person = x_node.n_rows - 1;

  u.fill(0);
  a.fill(0);
  a2.fill(0);
  imat.fill(0);
  cmat.fill(0);
  cmat2.fill(0);

  // this loop has a strange break condition to accommodate
  // the restriction that a uvec (uword) cannot be < 0

  break_loop = false;

  x_beta = 0.0;

  for( ; ; ){

    temp2 = y_node.at(person, 0); // time of event for current person
    n_events  = 0 ; // number of deaths at this time point
    weight_events = 0 ; // sum of w_node for the deaths
    denom_events = 0 ; // sum of weighted risks for the deaths

    // walk through this set of tied times
    while(y_node.at(person, 0) == temp2){

      n_risk++;

      risk = w_node.at(person);

      if (y_node.at(person, 1) == 0) {

        denom += risk;

        /* a contains weighted sums of x, cmat sums of squares */

        for (i=0; i<n_vars; i++) {

          temp1 = risk * x_node.at(person, i);

          a[i] += temp1;

          for (j = 0; j <= i; j++){
            cmat.at(j, i) += temp1 * x_node.at(person, j);
          }

        }

      } else {

        n_events++;

        denom_events += risk;

        for (i=0; i<n_vars; i++) {

          temp1 = risk * x_node.at(person, i);

          u[i]  += temp1;
          a2[i] += temp1;

          for (j=0; j<=i; j++){
            cmat2.at(j, i) += temp1 * x_node.at(person, j);
          }

        }

      }

      if(person == 0){
        break_loop = true;
        break;
      }

      person--;

    }

    // we need to add to the main terms
    if (n_events > 0) {

      if (method == 0 || n_events == 1) { // Breslow

        denom  += denom_events;
        loglik -= denom_events * log(denom);

        for (i=0; i<n_vars; i++) {

          a[i]  += a2[i];
          temp1  = a[i] / denom;  // mean
          u[i]  -=  denom_events * temp1;

          for (j=0; j<=i; j++) {
            cmat.at(j, i) += cmat2.at(j, i);
            imat.at(j, i) += denom_events * (cmat.at(j, i) - temp1 * a[j]) / denom;
          }

        }

      } else {
        /* Efron
         **  If there are 3 deaths we have 3 terms: in the first the
         **  three deaths are all in, in the second they are 2/3
         **  in the sums, and in the last 1/3 in the sum.  Let k go
         **  1 to n_events: we sequentially add a2/n_events and cmat2/n_events
         **  and efron_wt/n_events to the totals.
         */
        weight_avg = denom_events/n_events;

        for (k = 0; k < n_events; k++) {

          denom  += denom_events / n_events;
          loglik -= weight_avg * log(denom);

          for (i = 0; i < n_vars; i++) {

            a[i] += a2[i] / n_events;
            temp1 = a[i]  / denom;
            u[i] -= weight_avg * temp1;

            for (j=0; j<=i; j++) {
              cmat.at(j, i) += cmat2.at(j, i) / n_events;
              imat.at(j, i) += weight_avg * (cmat.at(j, i) - temp1 * a[j]) / denom;
            }

          }

        }

      }

      a2.fill(0);
      cmat2.fill(0);

    }

    if(break_loop) break;

  }

  return(loglik);

}


// [[Rcpp::export]]
arma::mat c_cox_fit(NumericMatrix& x,
                    NumericMatrix& y,
                    NumericVector& weights,
                    const arma::uword& method=0,
                    const double& eps = 1.0 / 10, // note: 1/10 is integer division (0)
                    const arma::uword& iter_max = 20){

  x_node = arma::mat(x.begin(), x.nrow(), x.ncol(), false);
  y_node = arma::mat(y.begin(), y.nrow(), y.ncol(), false);
  w_node = arma::vec(weights.begin(), weights.size(), false);

  n_vars = x_node.n_cols;

  arma::mat x_transforms = x_scale_wtd();

  double stat_current, stat_best, halving = 0;

  // set_size is fast but unsafe, initial values are random
  beta_current.set_size(n_vars);
  beta_new.set_size(n_vars);

  // fill with 0's; otherwise the betas differ every time you run this
  beta_current.fill(0);
  beta_new.fill(0);


  // these are filled with initial values later
  XB.set_size(x_node.n_rows);
  Risk.set_size(x_node.n_rows);
  u.set_size(n_vars);
  a.set_size(n_vars);
  a2.set_size(n_vars);
  imat.set_size(n_vars, n_vars);
  cmat.set_size(n_vars, n_vars);
  cmat2.set_size(n_vars, n_vars);

  // do the initial iteration
  stat_best = newtraph_cph_init(method);

  //Rcpp::Rcout << "stat_best: " << stat_best << std::endl;

  // update beta_current
  cholesky();
  cholesky_solve();
  beta_new = beta_current + u;

  if(iter_max > 1 && stat_best < R_PosInf){

    for(iter = 1; iter < iter_max; iter++){

      // do the next iteration
      stat_current = newtraph_cph_iter(method, beta_new);

      //Rcpp::Rcout << "stat_current: " << stat_current << std::endl;

      cholesky();

      // check for convergence
      // break the loop if the new ll is ~ same as old best ll
      if(std::fabs(1 - stat_best / stat_current) < eps){
        break;
      }

      if(stat_current < stat_best){ // it's not converging!

        halving++; // get more aggressive when it doesn't work

        // reduce the magnitude by which beta_new modifies beta_current
        for (i = 0; i < n_vars; i++){
          beta_new[i] = (beta_new[i]+halving*beta_current[i]) / (halving+1.0);
        }

      } else { // it's converging!

        halving = 0;
        stat_best = stat_current;

        cholesky_solve();

        for (i = 0; i < n_vars; i++) {

          beta_current[i] = beta_new[i];
          beta_new[i] = beta_new[i] +  u[i];

        }

      }

    }

  }

  // invert imat and return to original scale
  cholesky_invert();

  for (i = 0; i < n_vars; i++) {
    beta_current[i] = beta_new[i];
  }


  for (i=0; i < n_vars; i++) {

    beta_current[i] *= x_transforms.at(i, 1);
    imat.at(i, i) *= x_transforms.at(i, 1) * x_transforms.at(i, 1);

  }

  for(i = 0; i < n_vars; i++){

    if(std::isinf(beta_current[i])) beta_current[i] = 0;

    if(std::isinf(imat.at(i, i))) imat.at(i, i) = 1.0;

  }

  arma::vec se = arma::sqrt(imat.diag());

  arma::vec pv(n_vars);

  for(i = 0; i < n_vars; i++){
    pv[i] = R::pchisq(pow(beta_current[i] / se[i], 2), 1, false, false);
  }

  arma::mat out(n_vars, 3);

  out.col(0) = beta_current;
  out.col(1) = se;
  out.col(2) = pv;

  return(out);

}
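
The convergence logic inside c_cox_fit (accept a Newton-Raphson step when the partial log-likelihood improves; otherwise pull the proposal back toward the last accepted value, more aggressively each time) can be illustrated in one dimension. This is a hypothetical sketch, not part of the benchmarked code; the objective is kept strictly negative so the relative-change test is well defined:

```cpp
#include <cassert>
#include <cmath>

// Hypothetical 1-D illustration of Newton-Raphson with step-halving:
// maximize f given its gradient g and (negative) second derivative h.
double newton_maximize(double (*f)(double), double (*g)(double),
                       double (*h)(double), double x0,
                       double eps = 1e-9, int iter_max = 50) {
  double x_best = x0, f_best = f(x0);
  double x_new = x_best - g(x_best) / h(x_best);
  double halving = 0;
  for (int iter = 1; iter < iter_max; ++iter) {
    double f_new = f(x_new);
    // converged: new objective ~ same as old best
    if (std::fabs(1 - f_best / f_new) < eps) break;
    if (f_new < f_best) {          // not improving: pull the step back
      halving++;
      x_new = (x_new + halving * x_best) / (halving + 1.0);
    } else {                       // improving: accept and step again
      halving = 0;
      f_best = f_new;
      x_best = x_new;
      x_new = x_best - g(x_best) / h(x_best);
    }
  }
  return x_new;
}
```

The averaging step mirrors the loop in c_cox_fit that mixes beta_new with beta_current whenever the log-likelihood decreases.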