Stan 2.10.0
probability, sampling & optimization
base_nuts_classic.hpp
Go to the documentation of this file.
1 #ifndef STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
2 #define STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
3 
5 #include <boost/math/special_functions/fpclassify.hpp>
8 #include <algorithm>
9 #include <cmath>
10 #include <limits>
11 #include <string>
12 #include <vector>
13 
14 namespace stan {
15  namespace mcmc {
16 
// State threaded through transition() and every recursive build_tree()
// call of the classic (slice-sampling) NUTS implementation.
struct nuts_util {
  // --- Held fixed while a subtree is being built -------------------------
  double log_u;     // log of the slice variable, std::log(uniform draw)
  double H0;        // Hamiltonian evaluated at the transition's start point
  int sign;         // direction of integration for the current subtree: +1 / -1

  // --- Accumulated across the whole recursion ----------------------------
  int n_tree;       // leapfrog steps taken so far in this transition
  double sum_prob;  // running sum of min(1, exp(H0 - H)) over those steps
  bool criterion;   // stays true until a divergence or U-turn halts growth
};
28 
29  // The No-U-Turn Sampler (NUTS) with the
30  // original slice sampler implementation
31  template <class Model, template<class, class> class Hamiltonian,
32  template<class> class Integrator, class BaseRNG>
34  public base_hmc<Model, Hamiltonian, Integrator, BaseRNG> {
35  public:
36  base_nuts_classic(const Model& model, BaseRNG& rng):
37  base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
38  depth_(0), max_depth_(5), max_delta_(1000),
39  n_leapfrog_(0), divergent_(0), energy_(0) {
40  }
41 
43 
44  void set_max_depth(int d) {
45  if (d > 0)
46  max_depth_ = d;
47  }
48 
49  void set_max_delta(double d) {
50  max_delta_ = d;
51  }
52 
53  int get_max_depth() { return this->max_depth_; }
54  double get_max_delta() { return this->max_delta_; }
55 
56  sample
57  transition(sample& init_sample,
60  // Initialize the algorithm
61  this->sample_stepsize();
62 
63  nuts_util util;
64 
65  this->seed(init_sample.cont_params());
66 
67  this->hamiltonian_.sample_p(this->z_, this->rand_int_);
68  this->hamiltonian_.init(this->z_, info_writer, error_writer);
69 
70  ps_point z_plus(this->z_);
71  ps_point z_minus(z_plus);
72 
73  ps_point z_sample(z_plus);
74  ps_point z_propose(z_plus);
75 
76  int n_cont = init_sample.cont_params().size();
77 
78  Eigen::VectorXd rho_init = this->z_.p;
79  Eigen::VectorXd rho_plus(n_cont); rho_plus.setZero();
80  Eigen::VectorXd rho_minus(n_cont); rho_minus.setZero();
81 
82  util.H0 = this->hamiltonian_.H(this->z_);
83 
84  // Sample the slice variable
85  util.log_u = std::log(this->rand_uniform_());
86 
87  // Build a balanced binary tree until the NUTS criterion fails
88  util.criterion = true;
89  int n_valid = 0;
90 
91  this->depth_ = 0;
92  this->divergent_ = 0;
93 
94  util.n_tree = 0;
95  util.sum_prob = 0;
96 
97  while (util.criterion && (this->depth_ <= this->max_depth_)) {
98  // Randomly sample a direction in time
99  ps_point* z = 0;
100  Eigen::VectorXd* rho = 0;
101 
102  if (this->rand_uniform_() > 0.5) {
103  z = &z_plus;
104  rho = &rho_plus;
105  util.sign = 1;
106  } else {
107  z = &z_minus;
108  rho = &rho_minus;
109  util.sign = -1;
110  }
111 
112  // And build a new subtree in that direction
113  this->z_.ps_point::operator=(*z);
114 
115  int n_valid_subtree = build_tree(depth_, *rho, 0, z_propose, util,
116  info_writer, error_writer);
117  ++(this->depth_);
118 
119  *z = this->z_;
120 
121  // Metropolis-Hastings sample the fresh subtree
122  if (!util.criterion)
123  break;
124 
125  double subtree_prob = 0;
126 
127  if (n_valid) {
128  subtree_prob = static_cast<double>(n_valid_subtree) /
129  static_cast<double>(n_valid);
130  } else {
131  subtree_prob = n_valid_subtree ? 1 : 0;
132  }
133 
134  if (this->rand_uniform_() < subtree_prob)
135  z_sample = z_propose;
136 
137  n_valid += n_valid_subtree;
138 
139  // Check validity of completed tree
140  this->z_.ps_point::operator=(z_plus);
141  Eigen::VectorXd delta_rho = rho_minus + rho_init + rho_plus;
142 
143  util.criterion = compute_criterion(z_minus, this->z_, delta_rho);
144  }
145 
146  this->n_leapfrog_ = util.n_tree;
147 
148  double accept_prob = util.sum_prob / static_cast<double>(util.n_tree);
149 
150  this->z_.ps_point::operator=(z_sample);
151  this->energy_ = this->hamiltonian_.H(this->z_);
152  return sample(this->z_.q, - this->z_.V, accept_prob);
153  }
154 
155  void get_sampler_param_names(std::vector<std::string>& names) {
156  names.push_back("stepsize__");
157  names.push_back("treedepth__");
158  names.push_back("n_leapfrog__");
159  names.push_back("divergent__");
160  names.push_back("energy__");
161  }
162 
163  void get_sampler_params(std::vector<double>& values) {
164  values.push_back(this->epsilon_);
165  values.push_back(this->depth_);
166  values.push_back(this->n_leapfrog_);
167  values.push_back(this->divergent_);
168  values.push_back(this->energy_);
169  }
170 
171  virtual bool compute_criterion(ps_point& start,
172  typename Hamiltonian<Model, BaseRNG>
173  ::PointType& finish,
174  Eigen::VectorXd& rho) = 0;
175 
176  // Returns number of valid points in the completed subtree
177  int build_tree(int depth, Eigen::VectorXd& rho,
178  ps_point* z_init_parent, ps_point& z_propose,
179  nuts_util& util,
182  // Base case
183  if (depth == 0) {
184  this->integrator_.evolve(this->z_, this->hamiltonian_,
185  util.sign * this->epsilon_,
186  info_writer, error_writer);
187  rho += this->z_.p;
188 
189  if (z_init_parent) *z_init_parent = this->z_;
190  z_propose = this->z_;
191 
192  double h = this->hamiltonian_.H(this->z_);
193  if (boost::math::isnan(h))
194  h = std::numeric_limits<double>::infinity();
195 
196  util.criterion = util.log_u + (h - util.H0) < this->max_delta_;
197  if (!util.criterion) ++(this->divergent_);
198 
199  util.sum_prob += std::min(1.0, std::exp(util.H0 - h));
200  util.n_tree += 1;
201 
202  return (util.log_u + (h - util.H0) < 0);
203 
204  } else {
205  // General recursion
206  Eigen::VectorXd left_subtree_rho(rho.size());
207  left_subtree_rho.setZero();
208  ps_point z_init(this->z_);
209 
210  int n1 = build_tree(depth - 1, left_subtree_rho, &z_init,
211  z_propose, util,
212  info_writer, error_writer);
213 
214  if (z_init_parent) *z_init_parent = z_init;
215 
216  if (!util.criterion) return 0;
217 
218  Eigen::VectorXd right_subtree_rho(rho.size());
219  right_subtree_rho.setZero();
220  ps_point z_propose_right(z_init);
221 
222  int n2 = build_tree(depth - 1, right_subtree_rho, 0,
223  z_propose_right, util,
224  info_writer, error_writer);
225 
226  double accept_prob = static_cast<double>(n2) /
227  static_cast<double>(n1 + n2);
228 
229  if ( util.criterion && (this->rand_uniform_() < accept_prob) )
230  z_propose = z_propose_right;
231 
232  Eigen::VectorXd& subtree_rho = left_subtree_rho;
233  subtree_rho += right_subtree_rho;
234 
235  rho += subtree_rho;
236 
237  util.criterion &= compute_criterion(z_init, this->z_, subtree_rho);
238 
239  return n1 + n2;
240  }
241  }
242 
243  int depth_;
245  double max_delta_;
246 
249  double energy_;
250  };
251 
252  } // mcmc
253 } // stan
254 #endif
Hamiltonian< Model, BaseRNG >::PointType z_
Definition: base_hmc.hpp:164
void sample(stan::mcmc::base_mcmc *sampler, int num_warmup, int num_samples, int num_thin, int refresh, bool save, stan::services::sample::mcmc_writer< Model, SampleRecorder, DiagnosticRecorder, MessageRecorder > &mcmc_writer, stan::mcmc::sample &init_s, Model &model, RNG &base_rng, const std::string &prefix, const std::string &suffix, std::ostream &o, StartTransitionCallback &callback, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Definition: sample.hpp:17
Probability, optimization and sampling library.
virtual bool compute_criterion(ps_point &start, typename Hamiltonian< Model, BaseRNG >::PointType &finish, Eigen::VectorXd &rho)=0
Point in a generic phase space.
Definition: ps_point.hpp:17
int build_tree(int depth, Eigen::VectorXd &rho, ps_point *z_init_parent, ps_point &z_propose, nuts_util &util, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
double cont_params(int k) const
Definition: sample.hpp:25
base_writer is an abstract base class defining the interface for Stan writer callbacks.
Definition: base_writer.hpp:20
base_nuts_classic(const Model &model, BaseRNG &rng)
void seed(const Eigen::VectorXd &q)
Definition: base_hmc.hpp:53
boost::uniform_01< BaseRNG & > rand_uniform_
Definition: base_hmc.hpp:171
Hamiltonian< Model, BaseRNG >::PointType & z()
Definition: base_hmc.hpp:130
sample transition(sample &init_sample, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
void get_sampler_param_names(std::vector< std::string > &names)
Hamiltonian< Model, BaseRNG > hamiltonian_
Definition: base_hmc.hpp:166
Integrator< Hamiltonian< Model, BaseRNG > > integrator_
Definition: base_hmc.hpp:165
void get_sampler_params(std::vector< double > &values)

     [ Stan Home Page ] © 2011–2016, Stan Development Team.