Stan  2.10.0
probability, sampling & optimization
impl_leapfrog.hpp
Go to the documentation of this file.
1 #ifndef STAN_MCMC_HMC_INTEGRATORS_IMPL_LEAPFROG_HPP
2 #define STAN_MCMC_HMC_INTEGRATORS_IMPL_LEAPFROG_HPP
3 
4 #include <stan/math/prim/mat/fun/Eigen.hpp>
6 
7 namespace stan {
8  namespace mcmc {
9 
// impl_leapfrog: leapfrog-style integrator in which the momentum sub-step
// (hat_tau) and the position sub-step (update_q) are computed by fixed-point
// iteration. Each iterate re-evaluates hamiltonian.dtau_dq / dtau_dp on the
// partially updated state z. A loop stops after its iteration cap or once the
// largest absolute change between consecutive iterates is below
// fixed_point_threshold_.
// NOTE(review): presumably intended for Hamiltonians whose kinetic term tau
// depends on both q and p (e.g. a position-dependent metric) -- confirm.
// NOTE(review): this is a Doxygen source listing. The leading numbers are the
// original header's line numbers. Original lines 17, 21-22, 31-32, 49, 53-54,
// 63-64, 73-74, 89, 93 and 97 (mostly function signatures and the trailing
// info_writer/error_writer parameters) are absent and must be restored from
// the real header before this compiles.
10  template <typename Hamiltonian>
11  class impl_leapfrog: public base_leapfrog<Hamiltonian> {
12  public:
// Defaults: at most 10 fixed-point iterations per iterated sub-step, and a
// convergence tolerance of 1e-8 on the max-abs change between iterates.
13  impl_leapfrog(): base_leapfrog<Hamiltonian>(),
14  max_num_fixed_point_(10),
15  fixed_point_threshold_(1e-8) {}
16 
// begin_update_p(z, hamiltonian, epsilon, info_writer, error_writer)
// (signature on missing lines 17, 21-22):
// First momentum part of the step. Applies the explicit hat_phi kick, then
// the hat_tau update solved by up to max_num_fixed_point_ fixed-point
// iterations.
// NOTE(review): whether epsilon is already a half step depends on the caller
// (base_leapfrog), which is not visible here.
18  typename Hamiltonian::PointType& z,
19  Hamiltonian& hamiltonian,
20  double epsilon,
23  hat_phi(z, hamiltonian, epsilon, info_writer, error_writer);
24  hat_tau(z, hamiltonian, epsilon, this->max_num_fixed_point_,
25  info_writer, error_writer);
26  }
27 
// Position sub-step (writer parameters on missing lines 31-32).
// Solves, for q, by fixed-point iteration:
//   q = q0 + 0.5*epsilon*dtau_dp(q0, p) + 0.5*epsilon*dtau_dp(q, p)
// where q0 is z.q on entry. q_init caches the first two terms; each iterate
// recomputes the last one. The metric is refreshed after every iterate, so the
// next dtau_dp sees the new q. Gradients are refreshed once, after the loop.
// NOTE(review): if max_num_fixed_point_ iterations run without meeting
// fixed_point_threshold_, the last iterate is kept and this method itself
// reports nothing.
// NOTE(review): Eigen's maxCoeff() requires a non-empty vector; confirm z.q
// can never have size 0 here.
28  void update_q(typename Hamiltonian::PointType& z,
29  Hamiltonian& hamiltonian,
30  double epsilon,
33  // hat{T} = dT/dp * d/dq
34  Eigen::VectorXd q_init = z.q + 0.5 * epsilon * hamiltonian.dtau_dp(z);
35  Eigen::VectorXd delta_q(z.q.size());
36 
37  for (int n = 0; n < this->max_num_fixed_point_; ++n) {
// Keep the previous iterate so the change can be measured below.
38  delta_q = z.q;
39  z.q.noalias() = q_init + 0.5 * epsilon * hamiltonian.dtau_dp(z);
// The metric depends on the new q; refresh it before the next dtau_dp call.
40  hamiltonian.update_metric(z, info_writer, error_writer);
41 
// delta_q now holds (previous iterate - current iterate).
42  delta_q -= z.q;
43  if (delta_q.cwiseAbs().maxCoeff() < this->fixed_point_threshold_)
44  break;
45  }
46  hamiltonian.update_gradients(z, info_writer, error_writer);
47  }
48 
// end_update_p(z, hamiltonian, epsilon, info_writer, error_writer)
// (signature on missing lines 49, 53-54):
// Second momentum part of the step. It mirrors begin_update_p in reverse
// order. hat_tau runs exactly one iteration, so dtau_dq is evaluated once at
// the incoming z.p (an explicit update, no fixed-point solve). The hat_phi
// kick is applied after it.
50  typename Hamiltonian::PointType& z,
51  Hamiltonian& hamiltonian,
52  double epsilon,
55  hat_tau(z, hamiltonian, epsilon, 1, info_writer, error_writer);
56  hat_phi(z, hamiltonian, epsilon, info_writer, error_writer);
57  }
58 
59  // hat{phi} = dphi/dq * d/dp
// Explicit momentum kick: z.p -= epsilon * dphi_dq(z). The writers are
// forwarded to dphi_dq (writer parameters on missing lines 63-64).
60  void hat_phi(typename Hamiltonian::PointType& z,
61  Hamiltonian& hamiltonian,
62  double epsilon,
65  z.p -= epsilon * hamiltonian.dphi_dq(z, info_writer, error_writer);
66  }
67 
68  // hat{tau} = dtau/dq * d/dp
// Solves p = p_init - epsilon * dtau_dq(z) for z.p by fixed-point iteration,
// where p_init is z.p on entry and each iterate re-evaluates dtau_dq at the
// current z. It performs at most num_fixed_point iterations and stops early
// once max |p_prev - p_new| < fixed_point_threshold_. With num_fixed_point
// <= 0, z.p is left unchanged. (Writer parameters on missing lines 73-74.)
// NOTE(review): same empty-vector precondition on maxCoeff() as update_q.
69  void hat_tau(typename Hamiltonian::PointType& z,
70  Hamiltonian& hamiltonian,
71  double epsilon,
72  int num_fixed_point,
75  Eigen::VectorXd p_init = z.p;
76  Eigen::VectorXd delta_p(z.p.size());
77 
78  for (int n = 0; n < num_fixed_point; ++n) {
79  delta_p = z.p;
// dtau_dq(z) is called while building the right-hand side, i.e. before z.p is
// overwritten, so it sees the previous iterate.
80  z.p.noalias() = p_init
81  - epsilon
82  * hamiltonian.dtau_dq(z, info_writer, error_writer);
// delta_p now holds (previous iterate - current iterate).
83  delta_p -= z.p;
84  if (delta_p.cwiseAbs().maxCoeff() < this->fixed_point_threshold_)
85  break;
86  }
87  }
88 
// Accessor returning max_num_fixed_point_. Its signature (original line 89)
// is missing from this listing.
90  return this->max_num_fixed_point_;
91  }
92 
// set_max_num_fixed_point(int n) (signature on missing line 93): updates the
// iteration cap. Non-positive n is silently ignored and the old value kept.
94  if (n > 0) this->max_num_fixed_point_ = n;
95  }
96 
// Accessor returning fixed_point_threshold_. Its signature (original line 97)
// is missing from this listing.
98  return this->fixed_point_threshold_;
99  }
100 
// Updates the convergence tolerance. Non-positive t is silently ignored; NaN
// is also ignored because (NaN > 0) is false.
101  void set_fixed_point_threshold(double t) {
102  if (t > 0) this->fixed_point_threshold_ = t;
103  }
104 
105  private:
// Upper bound on fixed-point iterations in update_q and begin_update_p. It
// stays > 0 because the constructor sets 10 and the setter rejects n <= 0.
106  int max_num_fixed_point_;
// Convergence tolerance on the max absolute change between iterates.
107  double fixed_point_threshold_;
108  };
109 
110  } // mcmc
111 } // stan
112 
113 #endif
void set_fixed_point_threshold(double t)
Probability, optimization and sampling library.
void end_update_p(typename Hamiltonian::PointType &z, Hamiltonian &hamiltonian, double epsilon, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
void update_q(typename Hamiltonian::PointType &z, Hamiltonian &hamiltonian, double epsilon, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
void hat_tau(typename Hamiltonian::PointType &z, Hamiltonian &hamiltonian, double epsilon, int num_fixed_point, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
base_writer is an abstract base class defining the interface for Stan writer callbacks.
Definition: base_writer.hpp:20
void begin_update_p(typename Hamiltonian::PointType &z, Hamiltonian &hamiltonian, double epsilon, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
void hat_phi(typename Hamiltonian::PointType &z, Hamiltonian &hamiltonian, double epsilon, interface_callbacks::writer::base_writer &info_writer, interface_callbacks::writer::base_writer &error_writer)
void set_max_num_fixed_point(int n)

     [ Stan Home Page ] © 2011–2016, Stan Development Team.