Stan  2.10.0
probability, sampling & optimization
base_xhmc.hpp
Go to the documentation of this file.
1 #ifndef STAN_MCMC_HMC_NUTS_BASE_XHMC_HPP
2 #define STAN_MCMC_HMC_NUTS_BASE_XHMC_HPP
3 
5 #include <boost/math/special_functions/fpclassify.hpp>
6 #include <stan/math/prim/scal/fun/log_sum_exp.hpp>
9 #include <algorithm>
10 #include <cmath>
11 #include <limits>
12 #include <string>
13 #include <vector>
14 
15 namespace stan {
16  namespace mcmc {
/**
 * Combine two weighted running averages into a single running average.
 *
 * a1 and a2 are running averages of the form
 *   a_i = (sum_{n in N_i} w_n f_n) / (sum_{n in N_i} w_n),
 * and log_w1, log_w2 are the logs of the respective normalizing constants
 *   w_i = sum_{n in N_i} w_n.
 * On return sum_a holds (w1 * a1 + w2 * a2) / (w1 + w2) and log_sum_w holds
 * log(w1 + w2).  The larger weight is factored out before exponentiating so
 * that exp() is only ever applied to a non-positive number.
 *
 * A term whose log weight is -infinity carries zero weight and is ignored
 * outright.  This keeps a non-finite average attached to a zero weight from
 * turning the result into NaN (0 * inf), and makes combining two zero-weight
 * terms yield a zero-weight term instead of NaN (exp(-inf - -inf)).
 *
 * Declared inline because this non-template function is defined in a header;
 * without inline, including the header in more than one translation unit
 * produces multiple definitions at link time.
 *
 * @param[in] a1 first running average
 * @param[in] log_w1 log of the first normalizing constant
 * @param[in] a2 second running average
 * @param[in] log_w2 log of the second normalizing constant
 * @param[out] sum_a combined running average
 * @param[out] log_sum_w log of the combined normalizing constant
 */
inline void stable_sum(double a1, double log_w1, double a2, double log_w2,
                       double& sum_a, double& log_sum_w) {
  const double neg_inf = -std::numeric_limits<double>::infinity();

  // Zero-weight second term: the result is exactly the first term.
  if (log_w2 == neg_inf) {
    sum_a = a1;
    log_sum_w = log_w1;
    return;
  }
  // Zero-weight first term: the result is exactly the second term.
  if (log_w1 == neg_inf) {
    sum_a = a2;
    log_sum_w = log_w2;
    return;
  }

  if (log_w2 > log_w1) {
    // e = w1 / w2 <= 1
    double e = std::exp(log_w1 - log_w2);
    sum_a = (e * a1 + a2) / (1 + e);
    log_sum_w = log_w2 + std::log(1 + e);
  } else {
    // e = w2 / w1 <= 1
    double e = std::exp(log_w2 - log_w1);
    sum_a = (a1 + e * a2) / (1 + e);
    log_sum_w = log_w1 + std::log(1 + e);
  }
}
52 
57  template <class Model, template<class, class> class Hamiltonian,
58  template<class> class Integrator, class BaseRNG>
59  class base_xhmc : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
60  public:
61  base_xhmc(const Model& model, BaseRNG& rng)
62  : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
63  depth_(0), max_depth_(5), max_deltaH_(1000), x_delta_(0.1),
64  n_leapfrog_(0), divergent_(0), energy_(0) {
65  }
66 
68 
69  void set_max_depth(int d) {
70  if (d > 0)
71  max_depth_ = d;
72  }
73 
74  void set_max_deltaH(double d) {
75  max_deltaH_ = d;
76  }
77 
78  void set_x_delta(double d) {
79  if (d > 0)
80  x_delta_ = d;
81  }
82 
83  int get_max_depth() { return this->max_depth_; }
84  double get_max_deltaH() { return this->max_deltaH_; }
85  double get_x_delta() { return this->x_delta_; }
86 
87  sample
88  transition(sample& init_sample,
91  // Initialize the algorithm
92  this->sample_stepsize();
93 
94  this->seed(init_sample.cont_params());
95 
96  this->hamiltonian_.sample_p(this->z_, this->rand_int_);
97  this->hamiltonian_.init(this->z_, info_writer, error_writer);
98 
99  ps_point z_plus(this->z_);
100  ps_point z_minus(z_plus);
101 
102  ps_point z_sample(z_plus);
103  ps_point z_propose(z_plus);
104 
105  double ave = this->hamiltonian_.dG_dt(this->z_,
106  info_writer, error_writer);
107  double log_sum_weight = 0; // log(exp(H0 - H0))
108 
109  double H0 = this->hamiltonian_.H(this->z_);
110  int n_leapfrog = 0;
111  double sum_metro_prob = 1; // exp(H0 - H0)
112 
113  // Build a trajectory until the NUTS criterion is no longer satisfied
114  this->depth_ = 0;
115  this->divergent_ = 0;
116 
117  while (this->depth_ < this->max_depth_) {
118  // Build a new subtree in a random direction
119  bool valid_subtree = false;
120  double ave_subtree = 0;
121  double log_sum_weight_subtree
122  = -std::numeric_limits<double>::infinity();
123 
124  if (this->rand_uniform_() > 0.5) {
125  this->z_.ps_point::operator=(z_plus);
126  valid_subtree
127  = build_tree(this->depth_, z_propose,
128  ave_subtree, log_sum_weight_subtree,
129  H0, 1, n_leapfrog, sum_metro_prob,
130  info_writer, error_writer);
131  z_plus.ps_point::operator=(this->z_);
132  } else {
133  this->z_.ps_point::operator=(z_minus);
134  valid_subtree
135  = build_tree(this->depth_, z_propose,
136  ave_subtree, log_sum_weight_subtree,
137  H0, -1, n_leapfrog, sum_metro_prob,
138  info_writer, error_writer);
139  z_minus.ps_point::operator=(this->z_);
140  }
141 
142  if (!valid_subtree) break;
143  stable_sum(ave, log_sum_weight,
144  ave_subtree, log_sum_weight_subtree,
145  ave, log_sum_weight);
146 
147  // Sample from an accepted subtree
148  ++(this->depth_);
149 
150  double accept_prob
151  = std::exp(log_sum_weight_subtree - log_sum_weight);
152  if (this->rand_uniform_() < accept_prob)
153  z_sample = z_propose;
154 
155  // Break if exhaustion criterion is satisfied
156  if (std::fabs(ave) < x_delta_)
157  break;
158  }
159 
160  this->n_leapfrog_ = n_leapfrog;
161 
162  // Compute average acceptance probabilty across entire trajectory,
163  // even over subtrees that may have been rejected
164  double accept_prob
165  = sum_metro_prob / static_cast<double>(n_leapfrog + 1);
166 
167  this->z_.ps_point::operator=(z_sample);
168  this->energy_ = this->hamiltonian_.H(this->z_);
169  return sample(this->z_.q, -this->z_.V, accept_prob);
170  }
171 
172  void get_sampler_param_names(std::vector<std::string>& names) {
173  names.push_back("stepsize__");
174  names.push_back("treedepth__");
175  names.push_back("n_leapfrog__");
176  names.push_back("divergent__");
177  names.push_back("energy__");
178  }
179 
180  void get_sampler_params(std::vector<double>& values) {
181  values.push_back(this->epsilon_);
182  values.push_back(this->depth_);
183  values.push_back(this->n_leapfrog_);
184  values.push_back(this->divergent_);
185  values.push_back(this->energy_);
186  }
187 
204  int build_tree(int depth, ps_point& z_propose,
205  double& ave, double& log_sum_weight,
206  double H0, double sign, int& n_leapfrog,
207  double& sum_metro_prob,
210  // Base case
211  if (depth == 0) {
212  this->integrator_.evolve(this->z_, this->hamiltonian_,
213  sign * this->epsilon_,
214  info_writer, error_writer);
215  ++n_leapfrog;
216 
217  double h = this->hamiltonian_.H(this->z_);
218  if (boost::math::isnan(h))
219  h = std::numeric_limits<double>::infinity();
220 
221  if ((h - H0) > this->max_deltaH_) this->divergent_ = true;
222 
223  double dG_dt = this->hamiltonian_.dG_dt(this->z_,
224  info_writer,
225  error_writer);
226 
227  stable_sum(ave, log_sum_weight,
228  dG_dt, H0 - h,
229  ave, log_sum_weight);
230 
231  if (H0 - h > 0)
232  sum_metro_prob += 1;
233  else
234  sum_metro_prob += std::exp(H0 - h);
235 
236  z_propose = this->z_;
237 
238  return !this->divergent_;
239  }
240  // General recursion
241 
242  // Build the left subtree
243  double ave_left = 0;
244  double log_sum_weight_left = -std::numeric_limits<double>::infinity();
245 
246  bool valid_left
247  = build_tree(depth - 1, z_propose,
248  ave_left, log_sum_weight_left,
249  H0, sign, n_leapfrog, sum_metro_prob,
250  info_writer, error_writer);
251 
252  if (!valid_left) return false;
253  stable_sum(ave, log_sum_weight,
254  ave_left, log_sum_weight_left,
255  ave, log_sum_weight);
256 
257  // Build the right subtree
258  ps_point z_propose_right(this->z_);
259  double ave_right = 0;
260  double log_sum_weight_right = -std::numeric_limits<double>::infinity();
261 
262  bool valid_right
263  = build_tree(depth - 1, z_propose_right,
264  ave_right, log_sum_weight_right,
265  H0, sign, n_leapfrog, sum_metro_prob,
266  info_writer, error_writer);
267 
268  if (!valid_right) return false;
269  stable_sum(ave, log_sum_weight,
270  ave_right, log_sum_weight_right,
271  ave, log_sum_weight);
272 
273  // Multinomial sample from right subtree
274  double ave_subtree;
275  double log_sum_weight_subtree;
276  stable_sum(ave_left, log_sum_weight_left,
277  ave_right, log_sum_weight_right,
278  ave_subtree, log_sum_weight_subtree);
279 
280  double accept_prob
281  = std::exp(log_sum_weight_right - log_sum_weight_subtree);
282  if (this->rand_uniform_() < accept_prob)
283  z_propose = z_propose_right;
284 
285  return std::fabs(ave_subtree) >= x_delta_;
286  }
287 
288  int depth_;
290  double max_deltaH_;
291  double x_delta_;
292 
295  double energy_;
296  };
297 
298  } // mcmc
299 } // stan
300 #endif
Hamiltonian< Model, BaseRNG >::PointType z_
Definition: base_hmc.hpp:164
base_xhmc(const Model &model, BaseRNG &rng)
Definition: base_xhmc.hpp:61
sample transition(sample &init_sample, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Definition: base_xhmc.hpp:88
Exhaustive Hamiltonian Monte Carlo (XHMC) with multinomial sampling.
Definition: base_xhmc.hpp:59
void sample(stan::mcmc::base_mcmc *sampler, int num_warmup, int num_samples, int num_thin, int refresh, bool save, stan::services::sample::mcmc_writer< Model, SampleRecorder, DiagnosticRecorder, MessageRecorder > &mcmc_writer, stan::mcmc::sample &init_s, Model &model, RNG &base_rng, const std::string &prefix, const std::string &suffix, std::ostream &o, StartTransitionCallback &callback, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Definition: sample.hpp:17
Probability, optimization and sampling library.
Point in a generic phase space.
Definition: ps_point.hpp:17
void set_max_depth(int d)
Definition: base_xhmc.hpp:69
int build_tree(int depth, ps_point &z_propose, double &ave, double &log_sum_weight, double H0, double sign, int &n_leapfrog, double &sum_metro_prob, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Recursively build a new subtree to completion or until the subtree becomes invalid.
Definition: base_xhmc.hpp:204
double cont_params(int k) const
Definition: sample.hpp:25
base_writer is an abstract base class defining the interface for Stan writer callbacks.
Definition: base_writer.hpp:20
void stable_sum(double a1, double log_w1, double a2, double log_w2, double &sum_a, double &log_sum_w)
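a1 and a2 are running averages of the form $a_i = \frac{\sum_{n \in N_i} w_n f_n}{\sum_{n \in N_i} w_n}$, and the weights are the respective normalizing constants $w_i = \sum_{n \in N_i} w_n$, supplied in log form as log_w1 and log_w2; the outputs are the combined average $\frac{w_1 a_1 + w_2 a_2}{w_1 + w_2}$ and $\log(w_1 + w_2)$.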
Definition: base_xhmc.hpp:40
void set_x_delta(double d)
Definition: base_xhmc.hpp:78
void get_sampler_params(std::vector< double > &values)
Definition: base_xhmc.hpp:180
void seed(const Eigen::VectorXd &q)
Definition: base_hmc.hpp:53
void get_sampler_param_names(std::vector< std::string > &names)
Definition: base_xhmc.hpp:172
boost::uniform_01< BaseRNG & > rand_uniform_
Definition: base_hmc.hpp:171
void set_max_deltaH(double d)
Definition: base_xhmc.hpp:74
Hamiltonian< Model, BaseRNG > hamiltonian_
Definition: base_hmc.hpp:166
Integrator< Hamiltonian< Model, BaseRNG > > integrator_
Definition: base_hmc.hpp:165

     [ Stan Home Page ] © 2011–2016, Stan Development Team.