Stan  2.13.1
probability, sampling & optimization
base_nuts.hpp
Go to the documentation of this file.
1 #ifndef STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
2 #define STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
3 
5 #include <boost/math/special_functions/fpclassify.hpp>
6 #include <stan/math/prim/scal.hpp>
9 #include <algorithm>
10 #include <cmath>
11 #include <limits>
12 #include <string>
13 #include <vector>
14 
15 namespace stan {
16  namespace mcmc {
20  template <class Model, template<class, class> class Hamiltonian,
21  template<class> class Integrator, class BaseRNG>
22  class base_nuts : public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
23  public:
24  base_nuts(const Model& model, BaseRNG& rng)
25  : base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
26  depth_(0), max_depth_(5), max_deltaH_(1000),
27  n_leapfrog_(0), divergent_(0), energy_(0) {
28  }
29 
31 
32  void set_max_depth(int d) {
33  if (d > 0)
34  max_depth_ = d;
35  }
36 
37  void set_max_delta(double d) {
38  max_deltaH_ = d;
39  }
40 
41  int get_max_depth() { return this->max_depth_; }
42  double get_max_delta() { return this->max_deltaH_; }
43 
44  sample
45  transition(sample& init_sample,
48  // Initialize the algorithm
49  this->sample_stepsize();
50 
51  this->seed(init_sample.cont_params());
52 
53  this->hamiltonian_.sample_p(this->z_, this->rand_int_);
54  this->hamiltonian_.init(this->z_, info_writer, error_writer);
55 
56  ps_point z_plus(this->z_);
57  ps_point z_minus(z_plus);
58 
59  ps_point z_sample(z_plus);
60  ps_point z_propose(z_plus);
61 
62  Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_);
63  Eigen::VectorXd p_sharp_minus = this->hamiltonian_.dtau_dp(this->z_);
64  Eigen::VectorXd rho = this->z_.p;
65  double log_sum_weight = 0; // log(exp(H0 - H0))
66 
67  double H0 = this->hamiltonian_.H(this->z_);
68  int n_leapfrog = 0;
69  double sum_metro_prob = 1; // exp(H0 - H0)
70 
71  // Build a trajectory until the NUTS criterion is no longer satisfied
72  this->depth_ = 0;
73  this->divergent_ = 0;
74 
75  while (this->depth_ < this->max_depth_) {
76  // Build a new subtree in a random direction
77  Eigen::VectorXd rho_subtree(rho.size());
78  rho_subtree.setZero();
79 
80  bool valid_subtree = false;
81  double log_sum_weight_subtree
82  = -std::numeric_limits<double>::infinity();
83 
84  if (this->rand_uniform_() > 0.5) {
85  this->z_.ps_point::operator=(z_plus);
86  valid_subtree
87  = build_tree(this->depth_, rho_subtree, z_propose,
88  H0, 1, n_leapfrog,
89  log_sum_weight_subtree, sum_metro_prob,
90  info_writer, error_writer);
91  z_plus.ps_point::operator=(this->z_);
92  p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_);
93  } else {
94  this->z_.ps_point::operator=(z_minus);
95  valid_subtree
96  = build_tree(this->depth_, rho_subtree, z_propose,
97  H0, -1, n_leapfrog,
98  log_sum_weight_subtree, sum_metro_prob,
99  info_writer, error_writer);
100  z_minus.ps_point::operator=(this->z_);
101  p_sharp_minus = this->hamiltonian_.dtau_dp(this->z_);
102  }
103 
104  if (!valid_subtree) break;
105 
106  // Sample from an accepted subtree
107  ++(this->depth_);
108 
109  if (log_sum_weight_subtree > log_sum_weight) {
110  z_sample = z_propose;
111  } else {
112  double accept_prob
113  = std::exp(log_sum_weight_subtree - log_sum_weight);
114  if (this->rand_uniform_() < accept_prob)
115  z_sample = z_propose;
116  }
117 
118  log_sum_weight
119  = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
120 
121  // Break when NUTS criterion is not longer satisfied
122  rho += rho_subtree;
123  if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho))
124  break;
125  }
126 
127  this->n_leapfrog_ = n_leapfrog;
128 
129  // Compute average acceptance probabilty across entire trajectory,
130  // even over subtrees that may have been rejected
131  double accept_prob
132  = sum_metro_prob / static_cast<double>(n_leapfrog + 1);
133 
134  this->z_.ps_point::operator=(z_sample);
135  this->energy_ = this->hamiltonian_.H(this->z_);
136  return sample(this->z_.q, -this->z_.V, accept_prob);
137  }
138 
/**
 * Append the names of the per-iteration sampler diagnostics to @p names.
 * The order matches the values appended by get_sampler_params().
 * Existing entries in @p names are left untouched.
 *
 * @param names vector to which the five diagnostic names are appended
 */
void get_sampler_param_names(std::vector<std::string>& names) {
  static const char* const param_names[] = {
    "stepsize__",
    "treedepth__",
    "n_leapfrog__",
    "divergent__",
    "energy__"
  };
  names.insert(names.end(), param_names,
               param_names + sizeof(param_names) / sizeof(param_names[0]));
}
146 
147  void get_sampler_params(std::vector<double>& values) {
148  values.push_back(this->epsilon_);
149  values.push_back(this->depth_);
150  values.push_back(this->n_leapfrog_);
151  values.push_back(this->divergent_);
152  values.push_back(this->energy_);
153  }
154 
155  bool compute_criterion(Eigen::VectorXd& p_sharp_minus,
156  Eigen::VectorXd& p_sharp_plus,
157  Eigen::VectorXd& rho) {
158  return p_sharp_plus.dot(rho) > 0
159  && p_sharp_minus.dot(rho) > 0;
160  }
161 
178  int build_tree(int depth, Eigen::VectorXd& rho, ps_point& z_propose,
179  double H0, double sign, int& n_leapfrog,
180  double& log_sum_weight, double& sum_metro_prob,
183  // Base case
184  if (depth == 0) {
185  this->integrator_.evolve(this->z_, this->hamiltonian_,
186  sign * this->epsilon_,
187  info_writer, error_writer);
188  ++n_leapfrog;
189 
190  double h = this->hamiltonian_.H(this->z_);
191  if (boost::math::isnan(h))
192  h = std::numeric_limits<double>::infinity();
193 
194  if ((h - H0) > this->max_deltaH_) this->divergent_ = true;
195 
196  log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h);
197 
198  if (H0 - h > 0)
199  sum_metro_prob += 1;
200  else
201  sum_metro_prob += std::exp(H0 - h);
202 
203  z_propose = this->z_;
204  rho += this->z_.p;
205 
206  return !this->divergent_;
207  }
208  // General recursion
209  Eigen::VectorXd p_sharp_left = this->hamiltonian_.dtau_dp(this->z_);
210 
211  Eigen::VectorXd rho_subtree(rho.size());
212  rho_subtree.setZero();
213 
214  // Build the left subtree
215  double log_sum_weight_left = -std::numeric_limits<double>::infinity();
216 
217  bool valid_left
218  = build_tree(depth - 1, rho_subtree, z_propose,
219  H0, sign, n_leapfrog,
220  log_sum_weight_left, sum_metro_prob,
221  info_writer, error_writer);
222 
223  if (!valid_left) return false;
224 
225  // Build the right subtree
226  ps_point z_propose_right(this->z_);
227  double log_sum_weight_right = -std::numeric_limits<double>::infinity();
228 
229  bool valid_right
230  = build_tree(depth - 1, rho_subtree, z_propose_right,
231  H0, sign, n_leapfrog,
232  log_sum_weight_right, sum_metro_prob,
233  info_writer, error_writer);
234 
235  if (!valid_right) return false;
236 
237  // Multinomial sample from right subtree
238  double log_sum_weight_subtree
239  = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right);
240  log_sum_weight
241  = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
242 
243  if (log_sum_weight_right > log_sum_weight_subtree) {
244  z_propose = z_propose_right;
245  } else {
246  double accept_prob
247  = std::exp(log_sum_weight_right - log_sum_weight_subtree);
248  if (this->rand_uniform_() < accept_prob)
249  z_propose = z_propose_right;
250  }
251 
252  rho += rho_subtree;
253  Eigen::VectorXd p_sharp_right = this->hamiltonian_.dtau_dp(this->z_);
254  return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree);
255  }
256 
// Number of trajectory doublings accepted during the most recent transition.
int depth_;
// Divergence threshold: a leapfrog step whose H - H0 exceeds this value
// marks the transition as divergent (see set_max_delta).
double max_deltaH_;

// Hamiltonian evaluated at the state returned by the most recent transition.
double energy_;
264  };
265 
266  } // mcmc
267 } // stan
268 #endif
base_nuts(const Model &model, BaseRNG &rng)
Definition: base_nuts.hpp:24
Hamiltonian< Model, BaseRNG >::PointType z_
Definition: base_hmc.hpp:164
void sample(stan::mcmc::base_mcmc *sampler, int num_warmup, int num_samples, int num_thin, int refresh, bool save, stan::services::sample::mcmc_writer< Model, SampleRecorder, DiagnosticRecorder, MessageRecorder > &mcmc_writer, stan::mcmc::sample &init_s, Model &model, RNG &base_rng, const std::string &prefix, const std::string &suffix, std::ostream &o, StartTransitionCallback &callback, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Definition: sample.hpp:17
void get_sampler_params(std::vector< double > &values)
Definition: base_nuts.hpp:147
Probability, optimization and sampling library.
Point in a generic phase space.
Definition: ps_point.hpp:17
sample transition(sample &init_sample, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Definition: base_nuts.hpp:45
void set_max_delta(double d)
Definition: base_nuts.hpp:37
The No-U-Turn sampler (NUTS) with multinomial sampling.
Definition: base_nuts.hpp:22
base_writer is an abstract base class defining the interface for Stan writer callbacks.
Definition: base_writer.hpp:20
bool compute_criterion(Eigen::VectorXd &p_sharp_minus, Eigen::VectorXd &p_sharp_plus, Eigen::VectorXd &rho)
Definition: base_nuts.hpp:155
void seed(const Eigen::VectorXd &q)
Definition: base_hmc.hpp:53
boost::uniform_01< BaseRNG & > rand_uniform_
Definition: base_hmc.hpp:171
void set_max_depth(int d)
Definition: base_nuts.hpp:32
int build_tree(int depth, Eigen::VectorXd &rho, ps_point &z_propose, double H0, double sign, int &n_leapfrog, double &log_sum_weight, double &sum_metro_prob, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Recursively build a new subtree to completion or until the subtree becomes invalid.
Definition: base_nuts.hpp:178
void get_sampler_param_names(std::vector< std::string > &names)
Definition: base_nuts.hpp:139
double cont_params(int k) const
Definition: sample.hpp:24
Hamiltonian< Model, BaseRNG > hamiltonian_
Definition: base_hmc.hpp:166
Integrator< Hamiltonian< Model, BaseRNG > > integrator_
Definition: base_hmc.hpp:165

     [ Stan Home Page ] © 2011–2016, Stan Development Team.