Efficient linear combinations for the ORSF algorithm.
This document supplements my application for the Wake Forest Clinical and Translational Science Institute (CTSI) pilot award. It does not add any content to the application; it simply shows that the preliminary results in the application are reproducible.
We’ll use two datasets in this article.
This data comes from the Mayo Clinic trial in primary biliary cholangitis (PBC) conducted between 1974 and 1984. A total of 424 PBC patients, referred to Mayo Clinic during that ten-year interval, met eligibility criteria for the randomized, placebo-controlled trial of the drug D-penicillamine. After removing observations with missing values, a total of 276 observations remain.
paged_table(as_tibble(cbind(data_pbc$y, data_pbc$x)))
This is a stratified random sample containing half of the subjects from a study of the relationship between serum free light chain (FL chain) and mortality. The original study collected data on approximately 2/3 of the residents of Olmsted County aged 50 or older. After removing observations with missing values, a total of 6524 observations remain.
paged_table(as_tibble(cbind(data_flchain$y, data_flchain$x)))
This section verifies that, for both datasets, the code I have written gives exactly the same answer as the survival package. (Faster code wouldn't matter if it gave the wrong answer.)
bind_rows(pbc = test_pbc,
          flchain = test_flchain,
          .id = 'data') |>
  drop_na() |>
  mutate(
    data = recode(
      data,
      pbc = "Primary Biliary Cholangitis (PBC) data",
      flchain = "Assay of serum free light chain (FL chain) data"
    ),
    across(ends_with('pvalue'), table_pvalue),
    across(where(is.numeric), table_value)
  ) |>
  group_by(data) |>
  gt(rowname_col = 'variable') |>
  cols_align(align = 'center') |>
  cols_align(align = 'right', columns = 'surv_pvalue') |>
  cols_label(
    orsf_beta = 'AORSF',
    surv_beta = "survival",
    orsf_stderr = 'AORSF',
    surv_stderr = "survival",
    orsf_pvalue = 'AORSF',
    surv_pvalue = "survival"
  ) |>
  tab_stubhead(label = 'Variable') |>
  tab_spanner(
    label = 'Regression coefficients',
    columns = c("orsf_beta", "surv_beta")
  ) |>
  tab_spanner(
    label = 'Standard error',
    columns = c("orsf_stderr", "surv_stderr")
  ) |>
  tab_spanner(
    label = 'P-value',
    columns = c("orsf_pvalue", "surv_pvalue")
  ) |>
  tab_source_note(
    source_note = "table values were created using R/cox_test.R"
  )
| Variable | Coefficient (AORSF) | Coefficient (survival) | SE (AORSF) | SE (survival) | P-value (AORSF) | P-value (survival) |
|---|---|---|---|---|---|---|
| **Primary Biliary Cholangitis (PBC) data** | | | | | | |
| trt | -0.0988 | -0.0988 | 0.1198 | 0.1198 | .41 | .41 |
| age | 0.0422 | 0.0422 | 0.0062 | 0.0062 | <.001 | <.001 |
| ascites | 0.2902 | 0.2902 | 0.2075 | 0.2075 | .16 | .16 |
| hepato | 0.0451 | 0.0451 | 0.1349 | 0.1349 | .74 | .74 |
| spiders | 0.0465 | 0.0465 | 0.1340 | 0.1340 | .73 | .73 |
| edema | 1.074 | 1.074 | 0.2130 | 0.2130 | <.001 | <.001 |
| bili | 0.0679 | 0.0679 | 0.0145 | 0.0145 | <.001 | <.001 |
| chol | 0.0006 | 0.0006 | 0.0002 | 0.0002 | .02 | .02 |
| albumin | -0.5630 | -0.5630 | 0.1628 | 0.1628 | <.001 | <.001 |
| copper | 0.0039 | 0.0039 | 0.0005 | 0.0005 | <.001 | <.001 |
| alk.phos | 0.0000 | 0.0000 | 0.0000 | 0.0000 | .34 | .34 |
| ast | 0.0033 | 0.0033 | 0.0011 | 0.0011 | .002 | .002 |
| trig | -0.0010 | -0.0010 | 0.0008 | 0.0008 | .18 | .18 |
| platelet | -0.0002 | -0.0002 | 0.0007 | 0.0007 | .75 | .75 |
| protime | 0.2177 | 0.2177 | 0.0630 | 0.0630 | <.001 | <.001 |
| stage | 0.4433 | 0.4433 | 0.0938 | 0.0938 | <.001 | <.001 |
| **Assay of serum free light chain (FL chain) data** | | | | | | |
| age | 0.9171 | 0.9171 | 0.0132 | 0.0132 | <.001 | <.001 |
| sexF | -0.1352 | -0.1352 | 0.0138 | 0.0138 | <.001 | <.001 |
| sample.yr | 0.0690 | 0.0690 | 0.0119 | 0.0119 | <.001 | <.001 |
| kappa | 0.0271 | 0.0271 | 0.0111 | 0.0111 | .02 | .02 |
| lambda | 0.0907 | 0.0907 | 0.0098 | 0.0098 | <.001 | <.001 |
| flc.grp | 0.1118 | 0.1118 | 0.0150 | 0.0150 | <.001 | <.001 |
| creatinine | 0.0105 | 0.0105 | 0.0052 | 0.0052 | .04 | .04 |
| mgus | 0.0057 | 0.0057 | 0.0046 | 0.0046 | .21 | .21 |

Table values were created using R/cox_test.R.
This section shows the mean and median computation time taken by four separate approaches to finding a linear combination of predictor variables:

1. glmnet with 10-fold cross-validation (CV), one of the current options for finding linear combinations of predictors in obliqueRSF.
2. glmnet without CV, the default approach for finding linear combinations of predictors in obliqueRSF.
3. The coxph.fit function from the survival package, i.e., the code that was adapted to create the proposed routine for finding linear combinations in AORSF.
4. The proposed routine for finding linear combinations in AORSF.
bench_combined |>
  mutate(
    across(where(is.numeric), table_value),
    across(ends_with("rel"), ~ recode(.x, "1.000" = '1 (ref)')),
    data = recode(
      data,
      pbc = "Primary Biliary Cholangitis (PBC) data",
      flchain = "Assay of serum free light chain (FL chain) data"
    ),
    expr = recode(
      expr,
      glmnet.cv = "glmnet, 10-fold CV",
      glmnet = "glmnet, no CV",
      surv = "survival",
      orsf = "AORSF"
    )
  ) |>
  select(data, expr, mean_abs, median_abs, mean_rel, median_rel) |>
  group_by(data) |>
  gt(rowname_col = 'expr') |>
  cols_align(align = 'center') |>
  cols_align(align = 'right', columns = 'median_rel') |>
  cols_label(
    mean_abs = "Mean",
    median_abs = "Median",
    mean_rel = "Mean",
    median_rel = "Median"
  ) |>
  tab_spanner(label = 'Milliseconds', columns = c('mean_abs', 'median_abs')) |>
  tab_spanner(label = 'Ratio', columns = c('mean_rel', 'median_rel')) |>
  tab_source_note("Results are averaged over 500 runs of the computations.")
| Approach | Mean (ms) | Median (ms) | Mean (ratio) | Median (ratio) |
|---|---|---|---|---|
| **Primary Biliary Cholangitis (PBC) data** | | | | |
| glmnet, 10-fold CV | 8.999 | 8.828 | 240.0 | 236.8 |
| glmnet, no CV | 0.4713 | 0.4638 | 12.57 | 12.45 |
| survival | 0.0922 | 0.0884 | 2.460 | 2.371 |
| AORSF | 0.0375 | 0.0372 | 1 (ref) | 1 (ref) |
| **Assay of serum free light chain (FL chain) data** | | | | |
| glmnet, 10-fold CV | 90.97 | 90.22 | 252.3 | 252.0 |
| glmnet, no CV | 5.381 | 5.348 | 14.93 | 14.93 |
| survival | 0.8787 | 0.7923 | 2.439 | 2.214 |
| AORSF | 0.3615 | 0.3588 | 1 (ref) | 1 (ref) |

Results are averaged over 500 runs of the computations.
If you are interested in reproducing these results, you can clone or fork the GitHub repository and run the code yourself. You may find minor discrepancies in timings, as these vary from system to system. However, if you find any major discrepancy (e.g., the coefficients in the testing section no longer match), please click the reproducibility receipt below to see the exact specifications of my computing environment.
## datetime
Sys.time()
[1] "2021-11-05 09:17:54 EDT"
## repository
if (requireNamespace('git2r', quietly = TRUE)) {
  git2r::repository()
} else {
  c(
    system2("git", args = c("log", "--name-status", "-1"), stdout = TRUE),
    system2("git", args = c("remote", "-v"), stdout = TRUE)
  )
}
Local: main D:/coxph_bench
Remote: main @ origin (https://github.com/bcjaeger/coxph_bench.git)
Head: [1d7f1de] 2021-11-05: rename for website
## session info
sessionInfo()
R version 4.1.1 (2021-08-10)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19043)
Matrix products: default
locale:
[1] LC_COLLATE=English_U.S. Virgin Islands.1252
[2] LC_CTYPE=English_U.S. Virgin Islands.1252
[3] LC_MONETARY=English_U.S. Virgin Islands.1252
[4] LC_NUMERIC=C
[5] LC_TIME=English_United States.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods
[7] base
other attached packages:
[1] gt_0.3.1 Rcpp_1.0.7 glmnet_4.1-2
[4] Matrix_1.3-4 survival_3.2-11 forcats_0.5.1
[7] stringr_1.4.0 dplyr_1.0.7 purrr_0.3.4
[10] readr_2.0.2 tidyr_1.1.4 tibble_3.1.5
[13] ggplot2_3.3.5 tidyverse_1.3.1 rmarkdown_2.11
[16] table.glue_0.0.2 microbenchmark_1.4.8 tarchetypes_0.3.2
[19] targets_0.8.1 dotenv_1.0.3 conflicted_1.0.4
loaded via a namespace (and not attached):
[1] fs_1.5.0 lubridate_1.8.0
[3] httr_1.4.2 rprojroot_2.0.2
[5] tools_4.1.1 backports_1.3.0
[7] bslib_0.3.1 utf8_1.2.2
[9] R6_2.5.1 DBI_1.1.1
[11] colorspace_2.0-2 withr_2.4.2
[13] tidyselect_1.1.1 processx_3.5.2
[15] downlit_0.4.0 git2r_0.28.0
[17] compiler_4.1.1 cli_3.0.1
[19] rvest_1.0.2 xml2_1.3.2
[21] bookdown_0.24 sass_0.4.0
[23] checkmate_2.0.0 scales_1.1.1
[25] callr_3.7.0 digest_0.6.28
[27] pkgconfig_2.0.3 htmltools_0.5.2
[29] dbplyr_2.1.1 fastmap_1.1.0
[31] rlang_0.4.12 readxl_1.3.1
[33] rstudioapi_0.13 shape_1.4.6
[35] jquerylib_0.1.4 generics_0.1.1
[37] jsonlite_1.7.2 distill_1.3
[39] magrittr_2.0.1 munsell_0.5.0
[41] fansi_0.5.0 lifecycle_1.0.1
[43] stringi_1.7.5 yaml_2.2.1
[45] grid_4.1.1 crayon_1.4.1
[47] lattice_0.20-44 haven_2.4.3
[49] splines_4.1.1 hms_1.1.1
[51] knitr_1.36 ps_1.6.0
[53] pillar_1.6.4 igraph_1.2.7
[55] codetools_0.2-18 reprex_2.0.1
[57] glue_1.4.2 evaluate_0.14
[59] RcppArmadillo_0.10.7.0.0 data.table_1.14.2
[61] modelr_0.1.8 vctrs_0.3.8
[63] tzdb_0.2.0 foreach_1.5.1
[65] cellranger_1.1.0 gtable_0.3.0
[67] assertthat_0.2.1 cachem_1.0.6
[69] xfun_0.27 broom_0.7.9
[71] iterators_1.0.13 memoise_2.0.0
[73] ellipsis_0.3.2 here_1.0.1
The R code used to produce this article can be found in the R directory. The C++ code I have written for this benchmark is below.
#include <RcppArmadillo.h>
#include <RcppArmadilloExtensions/sample.h>
#include <Rcpp.h>
// [[Rcpp::depends(RcppArmadillo)]]
using namespace Rcpp;
// ----------------------------------------------------------------------------
// ---------------------------- global parameters -----------------------------
// ----------------------------------------------------------------------------
// special note: don't change these doubles to uword,
// even though some of them could be uwords;
// operations involving uwords and doubles are not
// straightforward and may break the routine.
// also: double + uword is slower than double + double.
double
weight_avg,
weight_events,
denom_events,
denom,
n_events,
n_risk,
temp1,
temp2,
w_node_person,
x_beta,
risk,
loglik;
// armadillo unsigned integers
arma::uword
i,
j,
k,
iter,
person,
n_vars;
// a delayed break statement
bool break_loop;
// armadillo vectors (doubles)
arma::vec
beta_current,
beta_new,
w_node,
u,
a,
a2,
XB,
Risk;
// armadillo matrices (doubles)
arma::mat
x_node,
y_node,
imat,
cmat,
cmat2;
// ----------------------------------------------------------------------------
// ---------------------------- scaling input data ----------------------------
// ----------------------------------------------------------------------------
// [[Rcpp::export]]
arma::mat x_scale_wtd(){
// set aside memory for outputs
// first column holds the mean values
// second column holds the scale values
arma::mat out(n_vars, 2);
arma::vec means = out.unsafe_col(0); // Reference to column 1
arma::vec scales = out.unsafe_col(1); // Reference to column 2
double w_node_sum = arma::sum(w_node);
for(i = 0; i < n_vars; i++) {
arma::vec x_i = x_node.unsafe_col(i);
means.at(i) = arma::sum( w_node % x_i ) / w_node_sum;
x_i -= means.at(i);
scales.at(i) = arma::sum(w_node % arma::abs(x_i));
if(scales(i) > 0)
scales.at(i) = w_node_sum / scales.at(i);
else
scales.at(i) = 1.0; // rare case of constant covariate;
x_i *= scales.at(i);
}
return(out);
}
// ----------------------------------------------------------------------------
// ---------------------------- cholesky functions ----------------------------
// ----------------------------------------------------------------------------
// [[Rcpp::export]]
void cholesky(){
double eps_chol = 0;
double toler = 1e-8;
double pivot;
for(i = 0; i < n_vars; i++){
if(imat.at(i,i) > eps_chol) eps_chol = imat.at(i,i);
// copy upper right values to bottom left
for(j = (i+1); j<n_vars; j++){
imat.at(j,i) = imat.at(i,j);
}
}
if (eps_chol == 0)
eps_chol = toler; // no positive diagonals!
else
eps_chol = eps_chol * toler;
for (i = 0; i < n_vars; i++) {
pivot = imat.at(i, i);
if (pivot < R_PosInf && pivot > eps_chol) {
for(j = (i+1); j < n_vars; j++){
temp1 = imat.at(j,i) / pivot;
imat.at(j,i) = temp1;
imat.at(j,j) -= temp1*temp1*pivot;
for(k = (j+1); k < n_vars; k++){
imat.at(k, j) -= temp1 * imat.at(k, i);
}
}
} else {
imat.at(i, i) = 0;
}
}
}
// [[Rcpp::export]]
void cholesky_solve(){
for (i = 0; i < n_vars; i++) {
  temp1 = u[i];
  for (j = 0; j < i; j++){
    temp1 -= u[j] * imat.at(i, j);
  }
  u[i] = temp1; // assign once, after the inner loop (matches survival's chsolve2)
}
for (i = n_vars; i >= 1; i--){
if (imat.at(i-1, i-1) == 0){
u[i-1] = 0;
} else {
temp1 = u[i-1] / imat.at(i-1, i-1);
for (j = i; j < n_vars; j++){
temp1 -= u[j] * imat.at(j, i-1);
}
u[i-1] = temp1;
}
}
}
// [[Rcpp::export]]
void cholesky_invert(){
/*
** invert the cholesky in the lower triangle
** take full advantage of the cholesky's diagonal of 1's
*/
for (i=0; i<n_vars; i++){
if (imat.at(i,i) >0) {
imat.at(i,i) = 1.0 / imat.at(i,i);
for (j=(i+1); j<n_vars; j++) {
imat.at(j, i) = -imat.at(j, i);
for (k=0; k<i; k++){
imat.at(j, k) += imat.at(j, i) * imat.at(i, k);
}
}
}
}
/*
** lower triangle now contains inverse of cholesky
** calculate F'DF (inverse of cholesky decomp process) to get inverse
** of original imat
*/
for (i=0; i<n_vars; i++) {
if (imat.at(i, i) == 0) {
for (j=0; j<i; j++) imat.at(i, j) = 0;
for (j=i; j<n_vars; j++) imat.at(j, i) = 0;
} else {
for (j=(i+1); j<n_vars; j++) {
temp1 = imat.at(j, i) * imat.at(j, j);
if (j!=i) imat.at(i, j) = temp1;
for (k=i; k<j; k++){
imat.at(i, k) += temp1*imat.at(j, k);
}
}
}
}
}
// ----------------------------------------------------------------------------
// ------------------- Newton Raphson algo for Cox PH model -------------------
// ----------------------------------------------------------------------------
// [[Rcpp::export]]
double newtraph_cph_iter(const arma::uword& method,
const arma::vec& beta){
denom = 0;
loglik = 0;
n_risk = 0;
person = x_node.n_rows - 1;
u.fill(0);
a.fill(0);
a2.fill(0);
imat.fill(0);
cmat.fill(0);
cmat2.fill(0);
// this loop has a strange break condition to accommodate
// the restriction that a uvec (uword) cannot be < 0
break_loop = false;
XB = x_node * beta;
Risk = arma::exp(XB) % w_node;
for( ; ; ){
temp2 = y_node.at(person, 0); // time of event for current person
n_events = 0 ; // number of deaths at this time point
weight_events = 0 ; // sum of w_node for the deaths
denom_events = 0 ; // sum of weighted risks for the deaths
// walk through this set of tied times
while(y_node.at(person, 0) == temp2){
n_risk++;
x_beta = XB.at(person);
risk = Risk.at(person);
// x_beta = 0;
//
// for(i = 0; i < n_vars; i++){
// x_beta += beta_current.at(i) * x.at(person, i);
// }
w_node_person = w_node.at(person);
//risk = exp(x_beta) * w_node_person;
if (y_node.at(person, 1) == 0) {
denom += risk;
/* a contains weighted sums of x, cmat sums of squares */
for (i=0; i<n_vars; i++) {
temp1 = risk * x_node.at(person, i);
a[i] += temp1;
for (j = 0; j <= i; j++){
cmat.at(j, i) += temp1 * x_node.at(person, j);
}
}
} else {
n_events++;
weight_events += w_node_person;
denom_events += risk;
loglik += w_node_person * x_beta;
for (i=0; i<n_vars; i++) {
u[i] += w_node_person * x_node.at(person, i);
a2[i] += risk * x_node.at(person, i);
for (j=0; j<=i; j++){
cmat2.at(j, i) += risk * x_node.at(person, i) * x_node.at(person, j);
}
}
}
if(person == 0){
break_loop = true;
break;
}
person--;
}
// we need to add to the main terms
if (n_events > 0) {
if (method == 0 || n_events == 1) { // Breslow
denom += denom_events;
loglik -= weight_events * log(denom);
for (i=0; i<n_vars; i++) {
a[i] += a2[i];
temp1 = a[i] / denom; // mean
u[i] -= weight_events * temp1;
for (j=0; j<=i; j++) {
cmat.at(j, i) += cmat2.at(j, i);
imat.at(j, i) += weight_events * (cmat.at(j, i) - temp1 * a[j]) / denom;
}
}
} else {
/* Efron
** If there are 3 deaths we have 3 terms: in the first the
** three deaths are all in, in the second they are 2/3
** in the sums, and in the last 1/3 in the sum. Let k go
** 1 to n_events: we sequentially add a2/n_events and cmat2/n_events
** and efron_wt/n_events to the totals.
*/
weight_avg = weight_events/n_events;
for (k=0; k<n_events; k++) {
denom += denom_events / n_events;
loglik -= weight_avg * log(denom);
for (i=0; i<n_vars; i++) {
a[i] += a2[i] / n_events;
temp1 = a[i] / denom;
u[i] -= weight_avg * temp1;
for (j=0; j<=i; j++) {
cmat.at(j, i) += cmat2.at(j, i) / n_events;
imat.at(j, i) += weight_avg * (cmat.at(j, i) - temp1 * a[j]) / denom;
}
}
}
}
a2.fill(0);
cmat2.fill(0);
}
if(break_loop) break;
}
return(loglik);
}
// [[Rcpp::export]]
double newtraph_cph_init(const arma::uword& method){
denom = 0;
loglik = 0;
n_risk = 0;
person = x_node.n_rows - 1;
u.fill(0);
a.fill(0);
a2.fill(0);
imat.fill(0);
cmat.fill(0);
cmat2.fill(0);
// this loop has a strange break condition to accommodate
// the restriction that a uvec (uword) cannot be < 0
break_loop = false;
x_beta = 0.0;
for( ; ; ){
temp2 = y_node.at(person, 0); // time of event for current person
n_events = 0 ; // number of deaths at this time point
weight_events = 0 ; // sum of w_node for the deaths
denom_events = 0 ; // sum of weighted risks for the deaths
// walk through this set of tied times
while(y_node.at(person, 0) == temp2){
n_risk++;
risk = w_node.at(person);
if (y_node.at(person, 1) == 0) {
denom += risk;
/* a contains weighted sums of x, cmat sums of squares */
for (i=0; i<n_vars; i++) {
temp1 = risk * x_node.at(person, i);
a[i] += temp1;
for (j = 0; j <= i; j++){
cmat.at(j, i) += temp1 * x_node.at(person, j);
}
}
} else {
n_events++;
denom_events += risk;
for (i=0; i<n_vars; i++) {
temp1 = risk * x_node.at(person, i);
u[i] += temp1;
a2[i] += temp1;
for (j=0; j<=i; j++){
cmat2.at(j, i) += temp1 * x_node.at(person, j);
}
}
}
if(person == 0){
break_loop = true;
break;
}
person--;
}
// we need to add to the main terms
if (n_events > 0) {
if (method == 0 || n_events == 1) { // Breslow
denom += denom_events;
loglik -= denom_events * log(denom);
for (i=0; i<n_vars; i++) {
a[i] += a2[i];
temp1 = a[i] / denom; // mean
u[i] -= denom_events * temp1;
for (j=0; j<=i; j++) {
cmat.at(j, i) += cmat2.at(j, i);
imat.at(j, i) += denom_events * (cmat.at(j, i) - temp1 * a[j]) / denom;
}
}
} else {
/* Efron
** If there are 3 deaths we have 3 terms: in the first the
** three deaths are all in, in the second they are 2/3
** in the sums, and in the last 1/3 in the sum. Let k go
** 1 to n_events: we sequentially add a2/n_events and cmat2/n_events
** and efron_wt/n_events to the totals.
*/
weight_avg = denom_events/n_events;
for (k = 0; k < n_events; k++) {
denom += denom_events / n_events;
loglik -= weight_avg * log(denom);
for (i = 0; i < n_vars; i++) {
a[i] += a2[i] / n_events;
temp1 = a[i] / denom;
u[i] -= weight_avg * temp1;
for (j=0; j<=i; j++) {
cmat.at(j, i) += cmat2.at(j, i) / n_events;
imat.at(j, i) += weight_avg * (cmat.at(j, i) - temp1 * a[j]) / denom;
}
}
}
}
a2.fill(0);
cmat2.fill(0);
}
if(break_loop) break;
}
return(loglik);
}
// [[Rcpp::export]]
arma::mat c_cox_fit(NumericMatrix& x,
NumericMatrix& y,
NumericVector& weights,
const arma::uword& method=0,
const double& eps = 0.1, // nb: writing 1/10 here is integer division, which yields 0
const arma::uword& iter_max = 20){
x_node = arma::mat(x.begin(), x.nrow(), x.ncol(), false);
y_node = arma::mat(y.begin(), y.nrow(), y.ncol(), false);
w_node = arma::vec(weights.begin(), weights.size(), false);
n_vars = x_node.n_cols;
arma::mat x_transforms = x_scale_wtd();
double stat_current, stat_best, halving = 0;
// set_size is fast but unsafe, initial values are random
beta_current.set_size(n_vars);
beta_new.set_size(n_vars);
// fill with 0's, o.w. different betas every time you run this
beta_current.fill(0);
beta_new.fill(0);
// these are filled with initial values later
XB.set_size(x_node.n_rows);
Risk.set_size(x_node.n_rows);
u.set_size(n_vars);
a.set_size(n_vars);
a2.set_size(n_vars);
imat.set_size(n_vars, n_vars);
cmat.set_size(n_vars, n_vars);
cmat2.set_size(n_vars, n_vars);
// do the initial iteration
stat_best = newtraph_cph_init(method);
//Rcpp::Rcout << "stat_best: " << stat_best << std::endl;
// update beta_current
cholesky();
cholesky_solve();
beta_new = beta_current + u;
if(iter_max > 1 && stat_best < R_PosInf){
for(iter = 1; iter < iter_max; iter++){
// do the next iteration
stat_current = newtraph_cph_iter(method, beta_new);
//Rcpp::Rcout << "stat_current: " << stat_current << std::endl;
cholesky();
// check for convergence
// break the loop if the new ll is ~ same as old best ll
if(std::fabs(1 - stat_best / stat_current) < eps){ // fabs: unqualified abs may resolve to abs(int)
break;
}
if(stat_current < stat_best){ // it's not converging!
halving++; // get more aggressive when it doesn't work
// reduce the magnitude by which beta_new modifies beta_current
for (i = 0; i < n_vars; i++){
beta_new[i] = (beta_new[i]+halving*beta_current[i]) / (halving+1.0);
}
} else { // it's converging!
halving = 0;
stat_best = stat_current;
cholesky_solve();
for (i = 0; i < n_vars; i++) {
beta_current[i] = beta_new[i];
beta_new[i] = beta_new[i] + u[i];
}
}
}
}
// invert imat and return to original scale
cholesky_invert();
for (i = 0; i < n_vars; i++) {
beta_current[i] = beta_new[i];
}
for (i=0; i < n_vars; i++) {
beta_current[i] *= x_transforms.at(i, 1);
imat.at(i, i) *= x_transforms.at(i, 1) * x_transforms.at(i, 1);
}
for(i = 0; i < n_vars; i++){
if(std::isinf(beta_current[i])) beta_current[i] = 0;
if(std::isinf(imat.at(i, i))) imat.at(i, i) = 1.0;
}
arma::vec se = arma::sqrt(imat.diag());
arma::vec pv(n_vars);
for(i = 0; i < n_vars; i++){
pv[i] = R::pchisq(pow(beta_current[i] / se[i], 2), 1, false, false);
}
arma::mat out(n_vars, 3);
out.col(0) = beta_current;
out.col(1) = se;
out.col(2) = pv;
return(out);
}