1 #ifndef STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
2 #define STAN_MCMC_HMC_NUTS_BASE_NUTS_HPP
5 #include <boost/math/special_functions/fpclassify.hpp>
6 #include <stan/math/prim/scal/fun/log_sum_exp.hpp>
20 template <
class Model,
template<
class,
class>
class Hamiltonian,
21 template<
class>
class Integrator,
class BaseRNG>
25 :
base_hmc<Model, Hamiltonian, Integrator, BaseRNG>(model, rng),
62 Eigen::VectorXd p_sharp_plus = this->
hamiltonian_.dtau_dp(this->
z_);
63 Eigen::VectorXd p_sharp_minus = this->
hamiltonian_.dtau_dp(this->
z_);
64 Eigen::VectorXd rho = this->
z_.p;
65 double log_sum_weight = 0;
69 double sum_metro_prob = 1;
77 Eigen::VectorXd rho_subtree(rho.size());
78 rho_subtree.setZero();
80 bool valid_subtree =
false;
81 double log_sum_weight_subtree
82 = -std::numeric_limits<double>::infinity();
85 this->
z_.ps_point::operator=(z_plus);
89 log_sum_weight_subtree, sum_metro_prob,
90 info_writer, error_writer);
91 z_plus.ps_point::operator=(this->
z_);
94 this->
z_.ps_point::operator=(z_minus);
98 log_sum_weight_subtree, sum_metro_prob,
99 info_writer, error_writer);
100 z_minus.ps_point::operator=(this->
z_);
104 if (!valid_subtree)
break;
110 = std::exp(log_sum_weight_subtree - log_sum_weight);
112 z_sample = z_propose;
114 if (log_sum_weight_subtree > log_sum_weight) {
115 z_sample = z_propose;
118 = std::exp(log_sum_weight_subtree - log_sum_weight);
120 z_sample = z_propose;
124 = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
137 = sum_metro_prob /
static_cast<double>(n_leapfrog + 1);
139 this->
z_.ps_point::operator=(z_sample);
141 return sample(this->
z_.q, -this->z_.V, accept_prob);
145 names.push_back(
"stepsize__");
146 names.push_back(
"treedepth__");
147 names.push_back(
"n_leapfrog__");
148 names.push_back(
"divergent__");
149 names.push_back(
"energy__");
154 values.push_back(this->
depth_);
157 values.push_back(this->
energy_);
161 Eigen::VectorXd& p_sharp_plus,
162 Eigen::VectorXd& rho) {
163 return p_sharp_plus.dot(rho) > 0
164 && p_sharp_minus.dot(rho) > 0;
184 double H0,
double sign,
int& n_leapfrog,
185 double& log_sum_weight,
double& sum_metro_prob,
192 info_writer, error_writer);
196 if (boost::math::isnan(h))
197 h = std::numeric_limits<double>::infinity();
201 log_sum_weight = math::log_sum_exp(log_sum_weight, H0 - h);
206 sum_metro_prob += std::exp(H0 - h);
208 z_propose = this->
z_;
214 Eigen::VectorXd p_sharp_left = this->
hamiltonian_.dtau_dp(this->
z_);
216 Eigen::VectorXd rho_subtree(rho.size());
217 rho_subtree.setZero();
220 double log_sum_weight_left = -std::numeric_limits<double>::infinity();
223 =
build_tree(depth - 1, rho_subtree, z_propose,
224 H0, sign, n_leapfrog,
225 log_sum_weight_left, sum_metro_prob,
226 info_writer, error_writer);
228 if (!valid_left)
return false;
232 double log_sum_weight_right = -std::numeric_limits<double>::infinity();
235 =
build_tree(depth - 1, rho_subtree, z_propose_right,
236 H0, sign, n_leapfrog,
237 log_sum_weight_right, sum_metro_prob,
238 info_writer, error_writer);
240 if (!valid_right)
return false;
243 double log_sum_weight_subtree
244 = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right);
246 = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);
248 if (log_sum_weight_right > log_sum_weight_subtree) {
249 z_propose = z_propose_right;
252 = std::exp(log_sum_weight_right - log_sum_weight_subtree);
254 z_propose = z_propose_right;
258 Eigen::VectorXd p_sharp_right = this->
hamiltonian_.dtau_dp(this->
z_);
base_nuts(const Model &model, BaseRNG &rng)
Hamiltonian< Model, BaseRNG >::PointType z_
void sample(stan::mcmc::base_mcmc *sampler, int num_warmup, int num_samples, int num_thin, int refresh, bool save, stan::services::sample::mcmc_writer< Model, SampleRecorder, DiagnosticRecorder, MessageRecorder > &mcmc_writer, stan::mcmc::sample &init_s, Model &model, RNG &base_rng, const std::string &prefix, const std::string &suffix, std::ostream &o, StartTransitionCallback &callback, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
void get_sampler_params(std::vector< double > &values)
Probability, optimization and sampling library.
Point in a generic phase space.
sample transition(sample &init_sample, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
void set_max_delta(double d)
The No-U-Turn sampler (NUTS) with multinomial sampling.
double cont_params(int k) const
base_writer is an abstract base class defining the interface for Stan writer callbacks.
bool compute_criterion(Eigen::VectorXd &p_sharp_minus, Eigen::VectorXd &p_sharp_plus, Eigen::VectorXd &rho)
void seed(const Eigen::VectorXd &q)
boost::uniform_01< BaseRNG & > rand_uniform_
void set_max_depth(int d)
int build_tree(int depth, Eigen::VectorXd &rho, ps_point &z_propose, double H0, double sign, int &n_leapfrog, double &log_sum_weight, double &sum_metro_prob, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
Recursively build a new subtree to completion or until the subtree becomes invalid.
void get_sampler_param_names(std::vector< std::string > &names)
Hamiltonian< Model, BaseRNG > hamiltonian_
Integrator< Hamiltonian< Model, BaseRNG > > integrator_