Stan  2.10.0
probability, sampling & optimization
stan::variational::advi< Model, Q, BaseRNG > Class Template Reference

Automatic Differentiation Variational Inference.

#include <advi.hpp>

Public Member Functions

 advi (Model &m, Eigen::VectorXd &cont_params, BaseRNG &rng, int n_monte_carlo_grad, int n_monte_carlo_elbo, int eval_elbo, int n_posterior_samples)
 Constructor.
 
double calc_ELBO (const Q &variational, interface_callbacks::writer::base_writer &message_writer) const
 Calculates the Evidence Lower BOund (ELBO) by sampling from the variational distribution and then evaluating the log joint, adjusted by the entropy term of the variational distribution.
 
void calc_ELBO_grad (const Q &variational, Q &elbo_grad, interface_callbacks::writer::base_writer &message_writer) const
 Calculates the "black box" gradient of the ELBO.
 
double adapt_eta (Q &variational, int adapt_iterations, interface_callbacks::writer::base_writer &message_writer) const
 Heuristic grid search to adapt eta to the scale of the problem.
 
void stochastic_gradient_ascent (Q &variational, double eta, double tol_rel_obj, int max_iterations, interface_callbacks::writer::base_writer &message_writer, interface_callbacks::writer::base_writer &diagnostic_writer) const
 Runs stochastic gradient ascent with an adaptive stepsize sequence.
 
int run (double eta, bool adapt_engaged, int adapt_iterations, double tol_rel_obj, int max_iterations, interface_callbacks::writer::base_writer &message_writer, interface_callbacks::writer::base_writer &parameter_writer, interface_callbacks::writer::base_writer &diagnostic_writer) const
 Runs ADVI and writes to output.
 
double circ_buff_median (const boost::circular_buffer< double > &cb) const
 Compute the median of a circular buffer.
 
double rel_difference (double prev, double curr) const
 Compute the relative difference between two double values.
 

Protected Attributes

Model & model_
 
Eigen::VectorXd & cont_params_
 
BaseRNG & rng_
 
int n_monte_carlo_grad_
 
int n_monte_carlo_elbo_
 
int eval_elbo_
 
int n_posterior_samples_
 

Detailed Description

template<class Model, class Q, class BaseRNG>
class stan::variational::advi< Model, Q, BaseRNG >

Automatic Differentiation Variational Inference.

Implements "black box" variational inference, using stochastic gradient ascent to maximize the Evidence Lower Bound (ELBO) for a given model and variational family.
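As a reference point, the objective can be written in its standard form. The notation here is ours, not taken from advi.hpp: θ denotes the model's continuous parameters, x the data, q_φ the variational distribution with parameters φ, and N the number of Monte Carlo draws (n_monte_carlo_elbo in the constructor). The second line is the sampling estimate that calc_ELBO describes: an average of log joint evaluations plus the entropy of q_φ.

```latex
\begin{aligned}
\mathcal{L}(\phi) &= \mathbb{E}_{q_\phi(\theta)}\left[\log p(x, \theta)\right] + \mathbb{H}\left[q_\phi\right] \\
\hat{\mathcal{L}}(\phi) &= \frac{1}{N}\sum_{i=1}^{N} \log p\left(x, \theta^{(i)}\right) + \mathbb{H}\left[q_\phi\right],
\qquad \theta^{(i)} \sim q_\phi
\end{aligned}
```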

Template Parameters
Model: class of the model
Q: class of the variational distribution
BaseRNG: class of the random number generator

Definition at line 40 of file advi.hpp.

Constructor & Destructor Documentation

template<class Model , class Q , class BaseRNG >
stan::variational::advi< Model, Q, BaseRNG >::advi ( Model &  m,
Eigen::VectorXd &  cont_params,
BaseRNG &  rng,
int  n_monte_carlo_grad,
int  n_monte_carlo_elbo,
int  eval_elbo,
int  n_posterior_samples 
)
inline

Constructor.

Parameters
m: Stan model
cont_params: initial values of the continuous parameters
rng: random number generator
n_monte_carlo_grad: number of Monte Carlo draws used to estimate the ELBO gradient
n_monte_carlo_elbo: number of Monte Carlo draws used to estimate the ELBO
eval_elbo: evaluate the ELBO every eval_elbo iterations
n_posterior_samples: number of draws to take from the variational approximation to the posterior
Exceptions
std::runtime_error: if n_monte_carlo_grad is not positive
std::runtime_error: if n_monte_carlo_elbo is not positive
std::runtime_error: if eval_elbo is not positive
std::runtime_error: if n_posterior_samples is not positive

Definition at line 57 of file advi.hpp.
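The four exception entries above amount to one rule: every count passed to the constructor must be strictly positive. The following is a minimal sketch of that validation, assuming nothing about Stan's internals; check_positive and advi_settings are hypothetical names, not part of the Stan API.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical helper: throws std::runtime_error unless value is positive,
// mirroring the documented constructor checks.
inline void check_positive(int value, const std::string& name) {
  if (value <= 0) {
    throw std::runtime_error(name + " must be positive, got " +
                             std::to_string(value));
  }
}

// Hypothetical settings holder (not Stan's advi class) that validates
// its counts on construction, in the order the exceptions are listed.
struct advi_settings {
  int n_monte_carlo_grad;
  int n_monte_carlo_elbo;
  int eval_elbo;
  int n_posterior_samples;

  advi_settings(int grad, int elbo, int eval, int draws)
      : n_monte_carlo_grad(grad), n_monte_carlo_elbo(elbo),
        eval_elbo(eval), n_posterior_samples(draws) {
    check_positive(n_monte_carlo_grad, "n_monte_carlo_grad");
    check_positive(n_monte_carlo_elbo, "n_monte_carlo_elbo");
    check_positive(eval_elbo, "eval_elbo");
    check_positive(n_posterior_samples, "n_posterior_samples");
  }
};
```

Validating in the constructor means a misconfigured run fails immediately rather than partway through optimization.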

Member Function Documentation

template<class Model , class Q , class BaseRNG >
double stan::variational::advi< Model, Q, BaseRNG >::adapt_eta ( Q &  variational,
int  adapt_iterations,
interface_callbacks::writer::base_writer &  message_writer 
) const
inline

Heuristic grid search to adapt eta to the scale of the problem.

Parameters
[in] variational: initial variational distribution
adapt_iterations: number of stochastic gradient ascent iterations to run at each proposed eta value
message_writer: writer for messages
Returns
the adapted (tuned) value of eta found by the heuristic grid search
Exceptions
std::domain_error: if either (a) the ELBO cannot be computed at the initial variational distribution, or (b) all step-size proposals in eta_sequence fail.

Definition at line 182 of file advi.hpp.
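A grid search of this kind tries each candidate eta, runs a short optimization with it, and keeps the candidate that reaches the best ELBO. The sketch below shows only that core idea under stated assumptions: grid_search_eta is a hypothetical name, elbo_after_adaptation stands in for "run adapt_iterations steps from the initial variational distribution and report the ELBO", and it simply keeps the best finite ELBO. The actual heuristic in advi.hpp may apply additional rules (for example, comparing against the initial ELBO).

```cpp
#include <cmath>
#include <functional>
#include <limits>
#include <stdexcept>
#include <vector>

// Hypothetical sketch of a heuristic eta grid search.
// elbo_after_adaptation(eta) returns the ELBO reached after a short run with
// stepsize scale eta, or a non-finite value if that run failed.
double grid_search_eta(const std::vector<double>& eta_sequence,
                       const std::function<double(double)>& elbo_after_adaptation) {
  double best_eta = std::numeric_limits<double>::quiet_NaN();
  double best_elbo = -std::numeric_limits<double>::infinity();
  for (double eta : eta_sequence) {
    double elbo = elbo_after_adaptation(eta);
    if (!std::isfinite(elbo)) continue;  // this proposal failed; try the next one
    if (elbo > best_elbo) {              // keep the best finite ELBO seen so far
      best_elbo = elbo;
      best_eta = eta;
    }
  }
  if (std::isnan(best_eta)) {
    // Corresponds to case (b) above: every proposal failed.
    throw std::domain_error("all eta proposals failed");
  }
  return best_eta;
}
```

The grid used in the test (100 down to 0.01) is illustrative only.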

template<class Model , class Q , class BaseRNG >
double stan::variational::advi< Model, Q, BaseRNG >::calc_ELBO ( const Q &  variational,
interface_callbacks::writer::base_writer &  message_writer 
) const
inline

Calculates the Evidence Lower BOund (ELBO) by sampling from the variational distribution and then evaluating the log joint, adjusted by the entropy term of the variational distribution.

Parameters
[in] variational: variational approximation at which to evaluate the ELBO
message_writer: writer for messages
Returns
the evidence lower bound.
Exceptions
std::domain_error: if all n_monte_carlo_elbo_ draws from the variational distribution give non-finite log joint evaluations. This means that the model is severely ill-conditioned or that the variational distribution has somehow collapsed.

Definition at line 100 of file advi.hpp.
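To make the estimate concrete, here is a minimal sketch for a one-dimensional Gaussian variational family q = Normal(mu, exp(omega)), where omega is the log standard deviation. mc_elbo and log_joint are hypothetical names standing in for the class's use of the model's log density and the variational family's sampling and entropy. This sketch averages over the finite draws only; how advi.hpp normalizes when some draws are dropped is not stated in this documentation.

```cpp
#include <cmath>
#include <functional>
#include <random>
#include <stdexcept>

// Hypothetical sketch: Monte Carlo ELBO for q = Normal(mu, exp(omega)).
// log_joint(zeta) stands in for the model's log joint density at zeta.
double mc_elbo(double mu, double omega, int n_draws,
               const std::function<double(double)>& log_joint,
               std::mt19937& rng) {
  const double pi = 3.14159265358979323846;
  std::normal_distribution<double> std_normal(0.0, 1.0);
  double sum = 0.0;
  int n_finite = 0;
  for (int i = 0; i < n_draws; ++i) {
    double zeta = mu + std::exp(omega) * std_normal(rng);  // draw from q
    double lp = log_joint(zeta);
    if (!std::isfinite(lp)) continue;  // drop non-finite evaluations
    sum += lp;
    ++n_finite;
  }
  if (n_finite == 0) {
    throw std::domain_error("all log joint evaluations were non-finite");
  }
  // Entropy of Normal(mu, sigma) is 0.5 * (1 + log(2 * pi)) + log(sigma).
  double entropy = 0.5 * (1.0 + std::log(2.0 * pi)) + omega;
  return sum / n_finite + entropy;
}
```

As a sanity check, if the log joint is exactly the standard normal log density and q is Normal(0, 1), the expected log joint and the entropy cancel, so the ELBO is 0 up to Monte Carlo error.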

template<class Model , class Q , class BaseRNG >
void stan::variational::advi< Model, Q, BaseRNG >::calc_ELBO_grad ( const Q &  variational,
Q &  elbo_grad,
interface_callbacks::writer::base_writer &  message_writer 
) const
inline

Calculates the "black box" gradient of the ELBO.

Parameters
[in] variational: variational approximation at which to evaluate the ELBO gradient
[out] elbo_grad: gradient of the ELBO with respect to the variational approximation
message_writer: writer for messages

Definition at line 147 of file advi.hpp.
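One common way to compute such a "black box" gradient is the reparameterization trick: write each draw as zeta = mu + exp(omega) * eps with eps standard normal, then average the log joint's gradient pushed through that transform and add the entropy gradient. The sketch below assumes a one-dimensional Gaussian family; mc_elbo_grad, elbo_gradient, and grad_log_joint are hypothetical names, and the per-family details in Stan's variational classes may differ.

```cpp
#include <cmath>
#include <functional>
#include <random>

// Hypothetical result type holding the two gradient components.
struct elbo_gradient {
  double d_mu;
  double d_omega;
};

// Hypothetical sketch of a reparameterization ELBO gradient for
// q = Normal(mu, exp(omega)). grad_log_joint(zeta) is d/dzeta log p(zeta).
elbo_gradient mc_elbo_grad(double mu, double omega, int n_draws,
                           const std::function<double(double)>& grad_log_joint,
                           std::mt19937& rng) {
  std::normal_distribution<double> std_normal(0.0, 1.0);
  const double sigma = std::exp(omega);
  double g_mu = 0.0;
  double g_omega = 0.0;
  for (int i = 0; i < n_draws; ++i) {
    double eps = std_normal(rng);     // standard normal noise
    double zeta = mu + sigma * eps;   // reparameterized draw from q
    double g = grad_log_joint(zeta);  // gradient of the log joint at zeta
    g_mu += g;                        // d zeta / d mu = 1
    g_omega += g * eps * sigma;       // d zeta / d omega = eps * sigma
  }
  g_mu /= n_draws;
  g_omega /= n_draws;
  g_omega += 1.0;  // entropy is omega + const, so its gradient is 1
  return {g_mu, g_omega};
}
```

For a standard normal target (gradient -z) evaluated at mu = 1, omega = 0, the expected gradient is -1 for mu and 0 for omega.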

template<class Model , class Q , class BaseRNG >
double stan::variational::advi< Model, Q, BaseRNG >::circ_buff_median ( const boost::circular_buffer< double > &  cb) const
inline

Compute the median of a circular buffer.

Parameters
cb: circular buffer with some number of values in it
Returns
the median of the values in the circular buffer.

Definition at line 552 of file advi.hpp.
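A median of a small buffer can be computed by copying the values and partially sorting them with std::nth_element, which leaves the buffer itself untouched (the member function takes it by const reference). In this sketch, buffer_median is a hypothetical name and std::deque<double> stands in for boost::circular_buffer<double> so the example has no Boost dependency. The even-size convention (upper of the two middle values) is a choice made here, not something this documentation specifies.

```cpp
#include <algorithm>
#include <cstddef>
#include <deque>
#include <stdexcept>
#include <vector>

// Hypothetical sketch: median of a fixed-size history of values.
// For an even number of values this returns the upper middle element
// (index size / 2 after partial sorting).
double buffer_median(const std::deque<double>& buffer) {
  if (buffer.empty()) throw std::invalid_argument("empty buffer");
  std::vector<double> values(buffer.begin(), buffer.end());  // copy; input stays const
  std::size_t mid = values.size() / 2;
  std::nth_element(values.begin(), values.begin() + mid, values.end());
  return values[mid];
}
```

std::nth_element runs in linear time on average, which is all a median needs; a full sort is unnecessary.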

template<class Model , class Q , class BaseRNG >
double stan::variational::advi< Model, Q, BaseRNG >::rel_difference ( double  prev,
double  curr 
) const
inline

Compute the relative difference between two double values.

Parameters
prev: previous value
curr: current value
Returns
the absolute value of the relative difference

Definition at line 572 of file advi.hpp.
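A natural reading of "absolute value of the relative difference" is |(curr - prev) / prev|, which is scale-free and therefore suits a tolerance such as tol_rel_obj. The sketch below assumes that formula. relative_difference and below_tolerance are hypothetical names, and how advi.hpp combines this value into its convergence test is not described here.

```cpp
#include <cmath>

// Hypothetical sketch: |(curr - prev) / prev|.
// Note that prev == 0 yields inf or NaN.
double relative_difference(double prev, double curr) {
  return std::fabs((curr - prev) / prev);
}

// Hypothetical convergence check against a relative tolerance
// (for example, a value like tol_rel_obj).
bool below_tolerance(double prev, double curr, double tol) {
  return relative_difference(prev, curr) < tol;
}
```

Because the ELBO is usually negative, taking the absolute value keeps the measure positive regardless of sign.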

template<class Model , class Q , class BaseRNG >
int stan::variational::advi< Model, Q, BaseRNG >::run ( double  eta,
bool  adapt_engaged,
int  adapt_iterations,
double  tol_rel_obj,
int  max_iterations,
interface_callbacks::writer::base_writer &  message_writer,
interface_callbacks::writer::base_writer &  parameter_writer,
interface_callbacks::writer::base_writer &  diagnostic_writer 
) const
inline

Runs ADVI and writes to output.

Parameters
eta: eta parameter of the stepsize sequence
adapt_engaged: boolean flag; if true, eta is adapted before optimization
adapt_iterations: number of iterations for eta adaptation
tol_rel_obj: relative tolerance parameter for convergence
max_iterations: maximum number of iterations to run the algorithm
message_writer: writer for messages
parameter_writer: writer for parameters (typically to a file)
diagnostic_writer: writer for diagnostic information

Definition at line 487 of file advi.hpp.
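The parameters of run suggest a two-stage flow: optionally tune eta, then optimize with either the tuned value or the user-supplied one. The sketch below captures only that control flow under stated assumptions. run_sketch, adapt, and optimize are hypothetical stand-ins for run, adapt_eta, and stochastic_gradient_ascent. Writing the approximation and posterior draws is omitted, and the 0 return value is an assumed success code, not Stan's documented return value.

```cpp
#include <functional>

// Hypothetical sketch of run()'s control flow: pick eta, then optimize.
int run_sketch(double eta, bool adapt_engaged,
               const std::function<double()>& adapt,
               const std::function<void(double)>& optimize) {
  double eta_used = adapt_engaged ? adapt() : eta;  // tuned or user-supplied eta
  optimize(eta_used);
  return 0;  // assumed "success" code for this sketch
}
```

When adapt_engaged is false, the eta argument is used as given, so adaptation can be skipped when a good stepsize scale is already known.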

template<class Model , class Q , class BaseRNG >
void stan::variational::advi< Model, Q, BaseRNG >::stochastic_gradient_ascent ( Q &  variational,
double  eta,
double  tol_rel_obj,
int  max_iterations,
interface_callbacks::writer::base_writer &  message_writer,
interface_callbacks::writer::base_writer &  diagnostic_writer 
) const
inline

Runs stochastic gradient ascent with an adaptive stepsize sequence.

Parameters
[in,out] variational: initial variational distribution
eta: stepsize scaling parameter
tol_rel_obj: relative tolerance parameter for convergence
max_iterations: maximum number of iterations to run the algorithm
message_writer: writer for messages
diagnostic_writer: writer for diagnostic information
Exceptions
std::domain_error: if the ELBO or its gradient is non-finite at any iteration

Definition at line 328 of file advi.hpp.
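An adaptive stepsize sequence of the kind described in the ADVI paper by Kucukelbir et al. scales each step by eta / sqrt(i) and divides by tau + sqrt(s_i), where s_i is an exponentially weighted average of squared gradients (s_1 = g_1^2, then s_i = alpha * g_i^2 + (1 - alpha) * s_{i-1}). The one-dimensional sketch below assumes alpha = 0.1 and tau = 1. adaptive_ascent is a hypothetical name, a deterministic gradient stands in for the noisy ELBO gradient, and the convergence check on tol_rel_obj is left out.

```cpp
#include <cmath>
#include <functional>

// Hypothetical 1-D sketch of gradient ascent with an adaptive stepsize:
//   step_i = eta / sqrt(i) * g_i / (tau + sqrt(s_i))
double adaptive_ascent(double x, double eta, int max_iterations,
                       const std::function<double(double)>& grad) {
  const double alpha = 0.1;  // weight on the newest squared gradient (assumed)
  const double tau = 1.0;    // keeps the denominator away from zero (assumed)
  double s = 0.0;            // running average of squared gradients
  for (int i = 1; i <= max_iterations; ++i) {
    double g = grad(x);
    s = (i == 1) ? g * g : alpha * g * g + (1.0 - alpha) * s;
    x += eta / std::sqrt(static_cast<double>(i)) * g / (tau + std::sqrt(s));
  }
  return x;
}
```

Dividing by the running gradient magnitude makes early steps roughly scale-invariant, while the 1/sqrt(i) factor shrinks steps over time so the iterates settle.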

Member Data Documentation

template<class Model , class Q , class BaseRNG >
Eigen::VectorXd& stan::variational::advi< Model, Q, BaseRNG >::cont_params_
protected

Definition at line 578 of file advi.hpp.

template<class Model , class Q , class BaseRNG >
int stan::variational::advi< Model, Q, BaseRNG >::eval_elbo_
protected

Definition at line 582 of file advi.hpp.

template<class Model , class Q , class BaseRNG >
Model& stan::variational::advi< Model, Q, BaseRNG >::model_
protected

Definition at line 577 of file advi.hpp.

template<class Model , class Q , class BaseRNG >
int stan::variational::advi< Model, Q, BaseRNG >::n_monte_carlo_elbo_
protected

Definition at line 581 of file advi.hpp.

template<class Model , class Q , class BaseRNG >
int stan::variational::advi< Model, Q, BaseRNG >::n_monte_carlo_grad_
protected

Definition at line 580 of file advi.hpp.

template<class Model , class Q , class BaseRNG >
int stan::variational::advi< Model, Q, BaseRNG >::n_posterior_samples_
protected

Definition at line 583 of file advi.hpp.

template<class Model , class Q , class BaseRNG >
BaseRNG& stan::variational::advi< Model, Q, BaseRNG >::rng_
protected

Definition at line 579 of file advi.hpp.


The documentation for this class was generated from the following file: advi.hpp

© 2011–2016, Stan Development Team.