1 #ifndef STAN_VARIATIONAL_ADVI_HPP
2 #define STAN_VARIATIONAL_ADVI_HPP
4 #include <stan/math.hpp>
14 #include <boost/circular_buffer.hpp>
15 #include <boost/lexical_cast.hpp>
26 namespace variational {
39 template <
class Model,
class Q,
class BaseRNG>
58 Eigen::VectorXd& cont_params,
60 int n_monte_carlo_grad,
61 int n_monte_carlo_elbo,
63 int n_posterior_samples)
71 static const char*
function =
"stan::variational::advi";
72 math::check_positive(
function,
73 "Number of Monte Carlo samples for gradients",
75 math::check_positive(
function,
76 "Number of Monte Carlo samples for ELBO",
78 math::check_positive(
function,
79 "Evaluate ELBO at every eval_elbo iteration",
81 math::check_positive(
function,
82 "Number of posterior samples for output",
103 static const char*
function =
104 "stan::variational::advi::calc_ELBO";
107 int dim = variational.dimension();
108 Eigen::VectorXd zeta(dim);
110 int n_dropped_evaluations = 0;
112 variational.sample(
rng_, zeta);
114 std::stringstream ss;
115 double log_prob =
model_.template log_prob<false, true>(zeta, &ss);
116 if (ss.str().length() > 0)
117 message_writer(ss.str());
118 stan::math::check_finite(
function,
"log_prob", log_prob);
121 }
catch (
const std::domain_error& e) {
122 ++n_dropped_evaluations;
123 if (n_dropped_evaluations >= n_monte_carlo_elbo_) {
124 const char* name =
"The number of dropped evaluations";
125 const char* msg1 =
"has reached its maximum amount (";
126 const char* msg2 =
"). Your model may be either severely "
127 "ill-conditioned or misspecified.";
128 stan::math::domain_error(
function, name, n_monte_carlo_elbo_,
134 elbo += variational.entropy();
151 static const char*
function =
152 "stan::variational::advi::calc_ELBO_grad";
154 stan::math::check_size_match(
function,
155 "Dimension of elbo_grad",
156 elbo_grad.dimension(),
157 "Dimension of variational q",
158 variational.dimension());
159 stan::math::check_size_match(
function,
160 "Dimension of variational q",
161 variational.dimension(),
162 "Dimension of variables in model",
165 variational.calc_grad(elbo_grad,
183 int adapt_iterations,
186 static const char*
function =
"stan::variational::advi::adapt_eta";
188 stan::math::check_positive(
function,
189 "Number of adaptation iterations",
192 message_writer(
"Begin eta adaptation.");
195 const int eta_sequence_size = 5;
196 double eta_sequence[eta_sequence_size] = {100, 10, 1, 0.1, 0.01};
199 double elbo = -std::numeric_limits<double>::max();
200 double elbo_best = -std::numeric_limits<double>::max();
203 elbo_init =
calc_ELBO(variational, message_writer);
204 }
catch (
const std::domain_error& e) {
205 const char* name =
"Cannot compute ELBO using the initial "
206 "variational distribution.";
207 const char* msg1 =
"Your model may be either "
208 "severely ill-conditioned or misspecified.";
209 stan::math::domain_error(
function, name,
"", msg1);
213 Q elbo_grad = Q(
model_.num_params_r());
216 Q history_grad_squared = Q(
model_.num_params_r());
218 double pre_factor = 0.9;
219 double post_factor = 0.1;
221 double eta_best = 0.0;
225 bool do_more_tuning =
true;
226 int eta_sequence_index = 0;
227 while (do_more_tuning) {
229 eta = eta_sequence[eta_sequence_index];
231 int print_progress_m;
232 for (
int iter_tune = 1; iter_tune <= adapt_iterations; ++iter_tune) {
233 print_progress_m = eta_sequence_index
234 * adapt_iterations + iter_tune;
237 adapt_iterations * eta_sequence_size,
238 adapt_iterations,
true,
"",
"", message_writer);
244 }
catch (
const std::domain_error& e) {
245 elbo_grad.set_to_zero();
249 if (iter_tune == 1) {
250 history_grad_squared += elbo_grad.square();
252 history_grad_squared = pre_factor * history_grad_squared
253 + post_factor * elbo_grad.square();
255 eta_scaled = eta / sqrt(static_cast<double>(iter_tune));
257 variational += eta_scaled * elbo_grad
258 / (tau + history_grad_squared.sqrt());
263 elbo =
calc_ELBO(variational, message_writer);
264 }
catch (
const std::domain_error& e) {
265 elbo = -std::numeric_limits<double>::max();
271 if (elbo < elbo_best && elbo_best > elbo_init) {
272 std::stringstream ss;
274 <<
" Found best value [eta = " << eta_best
276 if (eta_sequence_index < eta_sequence_size - 1)
277 ss << (
" earlier than expected.");
280 message_writer(ss.str());
282 do_more_tuning =
false;
284 if (eta_sequence_index < eta_sequence_size - 1) {
291 if (elbo > elbo_init) {
292 std::stringstream ss;
294 <<
" Found best value [eta = " << eta_best
296 message_writer(ss.str());
299 do_more_tuning =
false;
301 const char* name =
"All proposed step-sizes";
302 const char* msg1 =
"failed. Your model may be either "
303 "severely ill-conditioned or misspecified.";
304 stan::math::domain_error(
function, name,
"", msg1);
308 history_grad_squared.set_to_zero();
310 ++eta_sequence_index;
335 static const char*
function =
336 "stan::variational::advi::stochastic_gradient_ascent";
338 stan::math::check_positive(
function,
"Eta stepsize", eta);
339 stan::math::check_positive(
function,
340 "Relative objective function tolerance",
342 stan::math::check_positive(
function,
343 "Maximum iterations",
347 Q elbo_grad = Q(
model_.num_params_r());
350 Q history_grad_squared = Q(
model_.num_params_r());
352 double pre_factor = 0.9;
353 double post_factor = 0.1;
358 double elbo_best = -std::numeric_limits<double>::max();
359 double elbo_prev = -std::numeric_limits<double>::max();
360 double delta_elbo = std::numeric_limits<double>::max();
361 double delta_elbo_ave = std::numeric_limits<double>::max();
362 double delta_elbo_med = std::numeric_limits<double>::max();
366 =
static_cast<int>(std::max(0.1 * max_iterations /
eval_elbo_,
368 boost::circular_buffer<double> elbo_diff(cb_size);
370 message_writer(
"Begin stochastic gradient ascent.");
371 message_writer(
" iter"
378 clock_t start = clock();
383 bool do_more_iterations =
true;
384 for (
int iter_counter = 1; do_more_iterations; ++iter_counter) {
389 if (iter_counter == 1) {
390 history_grad_squared += elbo_grad.square();
392 history_grad_squared = pre_factor * history_grad_squared
393 + post_factor * elbo_grad.square();
395 eta_scaled = eta / sqrt(static_cast<double>(iter_counter));
398 variational += eta_scaled * elbo_grad
399 / (tau + history_grad_squared.sqrt());
404 elbo =
calc_ELBO(variational, message_writer);
405 if (elbo > elbo_best)
408 elbo_diff.push_back(delta_elbo);
409 delta_elbo_ave = std::accumulate(elbo_diff.begin(),
410 elbo_diff.end(), 0.0)
411 /
static_cast<double>(elbo_diff.size());
413 std::stringstream ss;
415 << std::setw(4) << iter_counter
417 << std::right << std::setw(9) << std::setprecision(1)
420 << std::setw(16) << std::fixed << std::setprecision(3)
423 << std::setw(15) << std::fixed << std::setprecision(3)
427 delta_t =
static_cast<double>(end - start) / CLOCKS_PER_SEC;
429 std::vector<double> print_vector;
430 print_vector.clear();
431 print_vector.push_back(iter_counter);
432 print_vector.push_back(delta_t);
433 print_vector.push_back(elbo);
434 diagnostic_writer(print_vector);
436 if (delta_elbo_ave < tol_rel_obj) {
437 ss <<
" MEAN ELBO CONVERGED";
438 do_more_iterations =
false;
441 if (delta_elbo_med < tol_rel_obj) {
442 ss <<
" MEDIAN ELBO CONVERGED";
443 do_more_iterations =
false;
447 if (delta_elbo_med > 0.5 || delta_elbo_ave > 0.5) {
448 ss <<
" MAY BE DIVERGING... INSPECT ELBO";
452 message_writer(ss.str());
454 if (do_more_iterations ==
false &&
456 message_writer(
"Informational Message: The ELBO at a previous "
457 "iteration is larger than the ELBO upon "
459 message_writer(
"This variational approximation may not "
460 "have converged to a good optimum.");
464 if (iter_counter == max_iterations) {
465 message_writer(
"Informational Message: The maximum number of "
466 "iterations is reached! The algorithm may not have "
468 message_writer(
"This variational approximation is not "
469 "guaranteed to be meaningful.");
470 do_more_iterations =
false;
487 int run(
double eta,
bool adapt_engaged,
int adapt_iterations,
488 double tol_rel_obj,
int max_iterations,
493 diagnostic_writer(
"iter,time_in_seconds,ELBO");
499 eta =
adapt_eta(variational, adapt_iterations, message_writer);
500 parameter_writer(
"Stepsize adaptation complete.");
501 std::stringstream ss;
502 ss <<
"eta = " << eta;
503 parameter_writer(ss.str());
507 tol_rel_obj, max_iterations,
508 message_writer, diagnostic_writer);
515 std::vector<int> disc_vector;
518 0, cont_vector, disc_vector,
523 std::stringstream ss;
524 ss <<
"Drawing a sample of size "
526 <<
" from the approximate posterior... ";
527 message_writer(ss.str());
535 0, cont_vector, disc_vector,
539 message_writer(
"COMPLETED.");
554 std::vector<double> v;
555 for (boost::circular_buffer<double>::const_iterator i = cb.begin();
556 i != cb.end(); ++i) {
560 size_t n = v.size() / 2;
561 std::nth_element(v.begin(), v.begin()+n, v.end());
573 return std::fabs((curr - prev) / prev);
Automatic Differentiation Variational Inference.
Probability, optimization and sampling library.
double adapt_eta(Q &variational, int adapt_iterations, interface_callbacks::writer::base_writer &message_writer) const
Heuristic grid search to adapt eta to the scale of the problem.
void write_iteration(Model &model, RNG &base_rng, double lp, std::vector< double > &cont_vector, std::vector< int > &disc_vector, interface_callbacks::writer::base_writer &message_writer, interface_callbacks::writer::base_writer &parameter_writer)
void calc_ELBO_grad(const Q &variational, Q &elbo_grad, interface_callbacks::writer::base_writer &message_writer) const
Calculates the "black box" gradient of the ELBO.
base_writer is an abstract base class defining the interface for Stan writer callbacks.
int run(double eta, bool adapt_engaged, int adapt_iterations, double tol_rel_obj, int max_iterations, interface_callbacks::writer::base_writer &message_writer, interface_callbacks::writer::base_writer &parameter_writer, interface_callbacks::writer::base_writer &diagnostic_writer) const
Runs ADVI and writes to output.
advi(Model &m, Eigen::VectorXd &cont_params, BaseRNG &rng, int n_monte_carlo_grad, int n_monte_carlo_elbo, int eval_elbo, int n_posterior_samples)
Constructor.
double rel_difference(double prev, double curr) const
Compute the relative difference between two double values.
void stochastic_gradient_ascent(Q &variational, double eta, double tol_rel_obj, int max_iterations, interface_callbacks::writer::base_writer &message_writer, interface_callbacks::writer::base_writer &diagnostic_writer) const
Runs stochastic gradient ascent with an adaptive stepsize sequence.
double circ_buff_median(const boost::circular_buffer< double > &cb) const
Compute the median of a circular buffer.
Eigen::VectorXd & cont_params_
void print_progress(int m, int start, int finish, int refresh, bool tune, const std::string &prefix, const std::string &suffix, interface_callbacks::writer::base_writer &writer)
Helper function for printing progress for variational inference.
double calc_ELBO(const Q &variational, interface_callbacks::writer::base_writer &message_writer) const
Calculates the Evidence Lower BOund (ELBO) by sampling from the variational distribution and then eva...