Stan  2.10.0
probability, sampling & optimization
base_nuts.hpp
Go to the documentation of this file.
1 #ifndef STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
2 #define STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
3 
5 #include <boost/math/special_functions/fpclassify.hpp>
6 #include <stan/math/prim/scal/fun/log_sum_exp.hpp>
9 #include <algorithm>
10 #include <cmath>
11 #include <limits>
12 #include <string>
13 #include <vector>
14 
15 namespace stan {
16  namespace mcmc {
20  template <class Model, template<class, class> class Hamiltonian,
21  template<class> class Integrator, class BaseRNG>
22  class base_nuts : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
23  public:
24  base_nuts(const Model& model, BaseRNG& rng)
25  : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
26  depth_(0), max_depth_(5), max_deltaH_(1000),
27  n_leapfrog_(0), divergent_(0), energy_(0) {
28  }
29 
31 
32  void set_max_depth(int d) {
33  if (d > 0)
34  max_depth_ = d;
35  }
36 
37  void set_max_delta(double d) {
38  max_deltaH_ = d;
39  }
40 
41  int get_max_depth() { return this->max_depth_; }
42  double get_max_delta() { return this->max_deltaH_; }
43 
44  sample
45  transition(sample& init_sample,
48  // Initialize the algorithm
49  this->sample_stepsize();
50 
51  this->seed(init_sample.cont_params());
52 
53  this->hamiltonian_.sample_p(this->z_, this->rand_int_);
54  this->hamiltonian_.init(this->z_, info_writer, error_writer);
55 
56  ps_point z_plus(this->z_);
57  ps_point z_minus(z_plus);
58 
59  ps_point z_sample(z_plus);
60  ps_point z_propose(z_plus);
61 
62  Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_);
63  Eigen::VectorXd p_sharp_minus = this->hamiltonian_.dtau_dp(this->z_);
64  Eigen::VectorXd rho = this->z_.p;
65  double log_sum_weight = 0; // log(exp(H0 - H0))
66 
67  double H0 = this->hamiltonian_.H(this->z_);
68  int n_leapfrog = 0;
69  double sum_metro_prob = 1; // exp(H0 - H0)
70 
71  // Build a trajectory until the NUTS criterion is no longer satisfied
72  this->depth_ = 0;
73  this->divergent_ = 0;
74 
75  while (this->depth_ < this->max_depth_) {
76  // Build a new subtree in a random direction
77  Eigen::VectorXd rho_subtree(rho.size());
78  rho_subtree.setZero();
79 
80  bool valid_subtree = false;
81  double log_sum_weight_subtree
82  = -std::numeric_limits<double>::infinity();
83 
84  if (this->rand_uniform_() > 0.5) {
85  this->z_.ps_point::operator=(z_plus);
86  valid_subtree
87  = build_tree(this->depth_, rho_subtree, z_propose,
88  H0, 1, n_leapfrog,
89  log_sum_weight_subtree, sum_metro_prob,
90  info_writer, error_writer);
91  z_plus.ps_point::operator=(this->z_);
92  p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_);
93  } else {
94  this->z_.ps_point::operator=(z_minus);
95  valid_subtree
96  = build_tree(this->depth_, rho_subtree, z_propose,
97  H0, -1, n_leapfrog,
98  log_sum_weight_subtree, sum_metro_prob,
99  info_writer, error_writer);
100  z_minus.ps_point::operator=(this->z_);
101  p_sharp_minus = this->hamiltonian_.dtau_dp(this->z_);
102  }
103 
104  if (!valid_subtree) break;
105 
106  // Sample from an accepted subtree
107  ++(this->depth_);
108 
109  double accept_prob
110  = std::exp(log_sum_weight_subtree - log_sum_weight);
111  if (this->rand_uniform_() < accept_prob)
112  z_sample = z_propose;
113 
114  if (log_sum_weight_subtree > log_sum_weight) {
115  z_sample = z_propose;
116  } else {
117  double accept_prob
118  = std::exp(log_sum_weight_subtree - log_sum_weight);
119  if (this->rand_uniform_() < accept_prob)
120  z_sample = z_propose;
121  }
122 
123  log_sum_weight
124  = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
125 
126  // Break when NUTS criterion is not longer satisfied
127  rho += rho_subtree;
128  if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho))
129  break;
130  }
131 
132  this->n_leapfrog_ = n_leapfrog;
133 
134  // Compute average acceptance probabilty across entire trajectory,
135  // even over subtrees that may have been rejected
136  double accept_prob
137  = sum_metro_prob / static_cast<double>(n_leapfrog + 1);
138 
139  this->z_.ps_point::operator=(z_sample);
140  this->energy_ = this->hamiltonian_.H(this->z_);
141  return sample(this->z_.q, -this->z_.V, accept_prob);
142  }
143 
144  void get_sampler_param_names(std::vector<std::string>& names) {
145  names.push_back("stepsize__");
146  names.push_back("treedepth__");
147  names.push_back("n_leapfrog__");
148  names.push_back("divergent__");
149  names.push_back("energy__");
150  }
151 
152  void get_sampler_params(std::vector<double>& values) {
153  values.push_back(this->epsilon_);
154  values.push_back(this->depth_);
155  values.push_back(this->n_leapfrog_);
156  values.push_back(this->divergent_);
157  values.push_back(this->energy_);
158  }
159 
160  bool compute_criterion(Eigen::VectorXd& p_sharp_minus,
161  Eigen::VectorXd& p_sharp_plus,
162  Eigen::VectorXd& rho) {
163  return p_sharp_plus.dot(rho) > 0
164  && p_sharp_minus.dot(rho) > 0;
165  }
166 
183  int build_tree(int depth, Eigen::VectorXd& rho, ps_point& z_propose,
184  double H0, double sign, int& n_leapfrog,
185  double& log_sum_weight, double& sum_metro_prob,
188  // Base case
189  if (depth == 0) {
190  this->integrator_.evolve(this->z_, this->hamiltonian_,
191  sign * this->epsilon_,
192  info_writer, error_writer);
193  ++n_leapfrog;
194 
195  double h = this->hamiltonian_.H(this->z_);
196  if (boost::math::isnan(h))
197  h = std::numeric_limits<double>::infinity();
198 
199  if ((h - H0) > this->max_deltaH_) this->divergent_ = true;
200 
201  log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h);
202 
203  if (H0 - h > 0)
204  sum_metro_prob += 1;
205  else
206  sum_metro_prob += std::exp(H0 - h);
207 
208  z_propose = this->z_;
209  rho += this->z_.p;
210 
211  return !this->divergent_;
212  }
213  // General recursion
214  Eigen::VectorXd p_sharp_left = this->hamiltonian_.dtau_dp(this->z_);
215 
216  Eigen::VectorXd rho_subtree(rho.size());
217  rho_subtree.setZero();
218 
219  // Build the left subtree
220  double log_sum_weight_left = -std::numeric_limits<double>::infinity();
221 
222  bool valid_left
223  = build_tree(depth - 1, rho_subtree, z_propose,
224  H0, sign, n_leapfrog,
225  log_sum_weight_left, sum_metro_prob,
226  info_writer, error_writer);
227 
228  if (!valid_left) return false;
229 
230  // Build the right subtree
231  ps_point z_propose_right(this->z_);
232  double log_sum_weight_right = -std::numeric_limits<double>::infinity();
233 
234  bool valid_right
235  = build_tree(depth - 1, rho_subtree, z_propose_right,
236  H0, sign, n_leapfrog,
237  log_sum_weight_right, sum_metro_prob,
238  info_writer, error_writer);
239 
240  if (!valid_right) return false;
241 
242  // Multinomial sample from right subtree
243  double log_sum_weight_subtree
244  = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right);
245  log_sum_weight
246  = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
247 
248  if (log_sum_weight_right > log_sum_weight_subtree) {
249  z_propose = z_propose_right;
250  } else {
251  double accept_prob
252  = std::exp(log_sum_weight_right - log_sum_weight_subtree);
253  if (this->rand_uniform_() < accept_prob)
254  z_propose = z_propose_right;
255  }
256 
257  rho += rho_subtree;
258  Eigen::VectorXd p_sharp_right = this->hamiltonian_.dtau_dp(this->z_);
259  return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree);
260  }
261 
262  int depth_;
264  double max_deltaH_;
265 
268  double energy_;
269  };
270 
271  } // mcmc
272 } // stan
273 #endif
base_nuts(const Model &model, BaseRNG &rng)
Definition: base_nuts.hpp:24
Hamiltonian< Model, BaseRNG >::PointType z_
Definition: base_hmc.hpp:164
void sample(stan::mcmc::base_mcmc *sampler, int num_warmup, int num_samples, int num_thin, int refresh, bool save, stan::services::sample::mcmc_writer< Model, SampleRecorder, DiagnosticRecorder, MessageRecorder > &mcmc_writer, stan::mcmc::sample &init_s, Model &model, RNG &base_rng, const std::string &prefix, const std::string &suffix, std::ostream &o, StartTransitionCallback &callback, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Definition: sample.hpp:17
void get_sampler_params(std::vector< double > &values)
Definition: base_nuts.hpp:152
Probability, optimization and sampling library.
Point in a generic phase space.
Definition: ps_point.hpp:17
sample transition(sample &init_sample, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Definition: base_nuts.hpp:45
void set_max_delta(double d)
Definition: base_nuts.hpp:37
The No-U-Turn sampler (NUTS) with multinomial sampling.
Definition: base_nuts.hpp:22
double cont_params(int k) const
Definition: sample.hpp:25
base_writer is an abstract base class defining the interface for Stan writer callbacks.
Definition: base_writer.hpp:20
bool compute_criterion(Eigen::VectorXd &p_sharp_minus, Eigen::VectorXd &p_sharp_plus, Eigen::VectorXd &rho)
Definition: base_nuts.hpp:160
void seed(const Eigen::VectorXd &q)
Definition: base_hmc.hpp:53
boost::uniform_01< BaseRNG & > rand_uniform_
Definition: base_hmc.hpp:171
void set_max_depth(int d)
Definition: base_nuts.hpp:32
int build_tree(int depth, Eigen::VectorXd &rho, ps_point &z_propose, double H0, double sign, int &n_leapfrog, double &log_sum_weight, double &sum_metro_prob, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Recursively build a new subtree to completion or until the subtree becomes invalid.
Definition: base_nuts.hpp:183
void get_sampler_param_names(std::vector< std::string > &names)
Definition: base_nuts.hpp:144
Hamiltonian< Model, BaseRNG > hamiltonian_
Definition: base_hmc.hpp:166
Integrator< Hamiltonian< Model, BaseRNG > > integrator_
Definition: base_hmc.hpp:165

     [ Stan Home Page ] © 2011–2016, Stan Development Team.