Clean misc docstrings
rlouf committed Oct 27, 2022
1 parent 60881e2 commit df3081f
Showing 3 changed files with 98 additions and 64 deletions.
46 changes: 32 additions & 14 deletions blackjax/mcmc/nuts.py
@@ -67,15 +67,18 @@ def kernel(
"""Build an iterative NUTS kernel.
This algorithm is an iteration on the original NUTS algorithm [Hoffman2014]_
with two major differences: - We do not use slice sampling but multinomial
sampling for the proposal [Betancourt2017]_; - The trajectory expansion is
not recursive but iterative [Phan2019]_, [Lao2020]_.
with two major differences:
- We do not use slice sampling but multinomial sampling for the proposal
[Betancourt2017]_;
- The trajectory expansion is not recursive but iterative [Phan2019]_,
[Lao2020]_.
The implementation can seem unusual for those familiar with similar
algorithms. Indeed, we do not conceptualize the trajectory construction as
building a tree. We feel that the tree lingo, inherited from the recursive
version, is unnecessarily complicated and hides the more general concepts
on which the NUTS algorithm is built.
upon which the NUTS algorithm is built.
NUTS, in essence, consists in sampling a trajectory by iteratively choosing
a direction at random and integrating in this direction a number of times
@@ -85,17 +88,30 @@ def kernel(
Parameters
----------
logprob_fn
Log probability function we wish to sample from.
parameters
A NamedTuple that contains the parameters of the kernel to be built.
integrator
The symplectic integrator used to build trajectories.
divergence_threshold
The absolute difference in energy above which we consider
a transition "divergent".
max_num_doublings
The maximum number of times we expand the trajectory by
doubling the number of steps if the trajectory does not
turn on itself.
References
----------
.. [Hoffman2014] Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo." J. Mach. Learn. Res. 15.1 (2014): 1593-1623.
.. [Betancourt2017] Betancourt, Michael. "A conceptual introduction to Hamiltonian Monte Carlo." arXiv preprint arXiv:1701.02434 (2017).
.. [Phan2019] Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects for flexible and accelerated probabilistic programming in NumPyro." arXiv preprint arXiv:1912.11554 (2019).
.. [Lao2020] Lao, Junpeng, et al. "tfp.mcmc: Modern Markov chain Monte Carlo tools built for modern hardware." arXiv preprint arXiv:2002.01184 (2020).
.. [Hoffman2014]: Hoffman, Matthew D., and Andrew Gelman.
"The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo."
J. Mach. Learn. Res. 15.1 (2014): 1593-1623.
.. [Betancourt2017]: Betancourt, Michael.
"A conceptual introduction to Hamiltonian Monte Carlo."
arXiv preprint arXiv:1701.02434 (2017).
.. [Phan2019]: Phan, Du, Neeraj Pradhan, and Martin Jankowiak.
"Composable effects for flexible and accelerated probabilistic programming in NumPyro."
arXiv preprint arXiv:1912.11554 (2019).
.. [Lao2020]: Lao, Junpeng, et al.
"tfp.mcmc: Modern Markov chain Monte Carlo tools built for modern hardware."
arXiv preprint arXiv:2002.01184 (2020).
"""

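For orientation, here is a minimal sketch of how this kernel is usually reached through the high-level `blackjax.nuts` interface. The target density, step size and inverse mass matrix below are illustrative placeholders, not values taken from the library.

import jax
import jax.numpy as jnp
import blackjax

# Illustrative target: a standard normal log-density in two dimensions.
def logprob_fn(position):
    return -0.5 * jnp.sum(position ** 2)

# In practice the step size and inverse mass matrix are tuned with window
# adaptation; fixed values are used here only to keep the sketch short.
nuts = blackjax.nuts(logprob_fn, step_size=1e-1, inverse_mass_matrix=jnp.ones(2))
state = nuts.init(jnp.zeros(2))
state, info = nuts.step(jax.random.PRNGKey(0), state)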
@@ -161,7 +177,8 @@ def iterative_nuts_proposal(
kinetic_energy
Function that computes the kinetic energy.
uturn_check_fn:
Function that determines whether the trajectory is turning on itself (metric-dependent).
Function that determines whether the trajectory is turning on itself
(metric-dependent).
step_size
Size of the integration step.
max_num_expansions
Expand All @@ -171,7 +188,8 @@ def iterative_nuts_proposal(
Returns
-------
A kernel that generates a new chain state and information about the transition.
A kernel that generates a new chain state and information about the
transition.
"""
(
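
As an illustration of what `uturn_check_fn` is expected to decide, here is a sketch of the classic check under a Euclidean metric; the function actually passed in blackjax also accounts for the mass matrix and the iterative, generalized criterion.

import jax.numpy as jnp

def euclidean_uturn_check(position_left, position_right, momentum_left, momentum_right):
    # The trajectory is turning on itself when the momentum at either end
    # points back towards the segment joining the two endpoints.
    delta_position = position_right - position_left
    return (jnp.dot(delta_position, momentum_left) <= 0) | (
        jnp.dot(delta_position, momentum_right) <= 0
    )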
80 changes: 44 additions & 36 deletions blackjax/mcmc/proposal.py
@@ -10,9 +10,9 @@
class Proposal(NamedTuple):
"""Proposal for the next chain step.
state:
The trajectory state corresponding to this proposal.
The trajectory state that corresponds to this proposal.
energy:
The potential energy corresponding to the state.
The total energy that corresponds to this proposal.
weight:
Weight of the proposal. It is equal to the logarithm of the sum of the canonical
densities of each state :math:`e^{-H(z)}` along the trajectory.
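
A small sketch, with made-up numbers, of how these log-space weights behave when two (sub)trajectories are merged:

import jax.numpy as jnp

# Log of the summed canonical densities of each half of a trajectory.
weight_left, weight_right = -1.3, -0.7
# The merged trajectory's weight is the log of the sum of both sums of
# canonical densities, computed stably with logaddexp.
merged_weight = jnp.logaddexp(weight_left, weight_right)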
@@ -39,7 +39,7 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo
The trajectory state records information about the position in the state
space and corresponding potential energy. A proposal also carries a
weight that is equal to the difference between the current energy and
the previous one. It thus carries information about the previous state
the previous one. It thus carries information about the previous states
as well as the current state.
Parameters
@@ -58,6 +58,7 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo

# The weight of the new proposal is equal to H0 - H(z_new)
weight = delta_energy

# Acceptance statistic min(e^{H0 - H(z_new)}, 1)
sum_log_p_accept = jnp.minimum(delta_energy, 0.0)

@@ -80,11 +81,12 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo


def static_binomial_sampling(rng_key, proposal, new_proposal):
"""Accept or reject a proposal based on its weight.
"""Accept or reject a proposal.
In the static setting, the `log_weight` of the proposal will be equal to the
difference of energy between the beginning and the end of the trajectory (truncated at 0.). It
is implemented this way to keep a consistent API with progressive sampling.
In the static setting, the probability with which the new proposal is
accepted is a function of the difference in energy between the previous and
the current states. If the current energy is lower than the previous one
then the new proposal is accepted with probability 1.
"""
p_accept = jnp.clip(jnp.exp(new_proposal.weight), a_max=1)
@@ -98,35 +100,6 @@ def static_binomial_sampling(rng_key, proposal, new_proposal):
)


# --------------------------------------------------------------------
# NON-REVERSIBLE SLICE SAMPLING
# --------------------------------------------------------------------


def nonreversible_slice_sampling(slice, proposal, new_proposal):
"""Slice sampling for non-reversible Metropolis-Hastings update.
Performs a non-reversible update of a uniform [0, 1] value
for Metropolis-Hastings accept/reject decisions [1]_, in addition
to the accept/reject step of a current state and new proposal.
References
----------
.. [1]: Neal, R. M. (2020). Non-reversibly updating a uniform
[0, 1] value for Metropolis accept/reject decisions.
arXiv preprint arXiv:2001.11950.
"""

delta_energy = new_proposal.weight
do_accept = jnp.log(jnp.abs(slice)) <= delta_energy
return jax.lax.cond(
do_accept,
lambda _: (new_proposal, do_accept, slice * jnp.exp(-delta_energy)),
lambda _: (proposal, do_accept, slice),
operand=None,
)


# --------------------------------------------------------------------
# PROGRESSIVE SAMPLING
#
@@ -170,6 +143,12 @@ def progressive_biased_sampling(rng_key, proposal, new_proposal):
Unlike uniform sampling, biased sampling favors new proposals. It thus
biases the transition away from the trajectory's initial state.
References
----------
.. [1]: Betancourt, Michael.
"A conceptual introduction to Hamiltonian Monte Carlo."
arXiv preprint arXiv:1701.02434 (2017).
"""
p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), a_max=1)
do_accept = jax.random.bernoulli(rng_key, p_accept)
@@ -194,3 +173,32 @@
),
operand=None,
)
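
To make the contrast with uniform progressive sampling concrete, here is a sketch with made-up log-weights; the uniform rule is the one typically used while a subtrajectory is being built, the biased rule when subtrajectories are merged.

import jax.numpy as jnp
import jax.scipy.special as jsp

log_w_old, log_w_new = -0.9, -0.4  # illustrative log-weights
# Uniform sampling accepts the new proposal with probability w_new / (w_old + w_new).
p_uniform = jsp.expit(log_w_new - log_w_old)
# Biased sampling accepts with probability min(1, w_new / w_old), which favors
# states further away from the trajectory's initial state.
p_biased = jnp.clip(jnp.exp(log_w_new - log_w_old), a_max=1.0)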


# --------------------------------------------------------------------
# NON-REVERSIBLE SLICE SAMPLING
# --------------------------------------------------------------------


def nonreversible_slice_sampling(slice, proposal, new_proposal):
"""Slice sampling for non-reversible Metropolis-Hastings update.
Performs a non-reversible update of a uniform [0, 1] value
for Metropolis-Hastings accept/reject decisions [1]_, in addition
to the accept/reject step of a current state and new proposal.
References
----------
.. [1]: Neal, R. M. (2020).
"Non-reversibly updating a uniform [0, 1] value for Metropolis accept/reject decisions."
arXiv preprint arXiv:2001.11950.
"""

delta_energy = new_proposal.weight
do_accept = jnp.log(jnp.abs(slice)) <= delta_energy
return jax.lax.cond(
do_accept,
lambda _: (new_proposal, do_accept, slice * jnp.exp(-delta_energy)),
lambda _: (proposal, do_accept, slice),
operand=None,
)
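
A short numerical illustration of the update rule above, with made-up values:

import jax.numpy as jnp

slice_value = -0.3   # current uniform value on [-1, 1]
delta_energy = 0.5   # weight of the new proposal (illustrative)
do_accept = jnp.log(jnp.abs(slice_value)) <= delta_energy  # True: log(0.3) is about -1.2
updated_slice = slice_value * jnp.exp(-delta_energy)       # about -0.18, still in [-1, 1]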
36 changes: 22 additions & 14 deletions blackjax/mcmc/trajectory.py
@@ -1,6 +1,7 @@
"""Procedures to build trajectories for algorithms in the HMC family.
To propose a new state, algorithms in the HMC family generally proceed by [1]_:
1. Sampling a trajectory starting from the initial point;
2. Sampling a new state from this sampled trajectory.
@@ -24,7 +25,9 @@
References
----------
.. [1]: Betancourt, Michael. "A conceptual introduction to Hamiltonian Monte Carlo." arXiv preprint arXiv:1701.02434 (2017).
.. [1]: Betancourt, Michael.
"A conceptual introduction to Hamiltonian Monte Carlo."
arXiv preprint arXiv:1701.02434 (2017).
"""
from typing import Callable, NamedTuple, Tuple
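
The two steps described in the module docstring can be sketched as follows for a fixed trajectory length; `integrator_step` and `hamiltonian` are placeholders and this is not the blackjax implementation, which expands trajectories dynamically.

import jax
import jax.numpy as jnp

def propose(rng_key, initial_state, integrator_step, hamiltonian, num_steps):
    # 1. Sample a trajectory starting from the initial point.
    trajectory = [initial_state]
    state = initial_state
    for _ in range(num_steps):
        state = integrator_step(state)
        trajectory.append(state)
    # 2. Sample a new state from this trajectory, each state weighted by its
    #    canonical density exp(-H(z)) (multinomial sampling).
    energies = jnp.stack([hamiltonian(s) for s in trajectory])
    index = jax.random.categorical(rng_key, logits=-energies)
    return trajectory[int(index)]  # eager indexing; a real kernel stays traceable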
@@ -151,7 +154,8 @@ def dynamic_progressive_integration(
is_criterion_met
Determines whether the termination criterion has been met.
divergence_threshold
Value of the difference of energy between two consecutive states above which we say a transition is divergent.
Value of the difference of energy between two consecutive states above
which we say a transition is divergent.
"""
_, generate_proposal = proposal_generator(kinetic_energy, divergence_threshold)
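
For reference, a sketch of the divergence check that `divergence_threshold` controls, matching the "absolute difference in energy" wording used in the NUTS kernel docstring; the actual check lives in the proposal generator.

import jax.numpy as jnp

def is_divergent(energy_previous, energy_current, divergence_threshold):
    # A transition is flagged divergent when the energy difference between two
    # consecutive states exceeds the threshold in absolute value.
    return jnp.abs(energy_current - energy_previous) > divergence_threshold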
@@ -166,8 +170,9 @@ def integrate(
step_size,
initial_energy,
):
"""Integrate the trajectory starting from `initial_state` and update
the proposal sequentially until the termination criterion is met.
"""Integrate the trajectory starting from `initial_state` and update the
proposal sequentially (hence progressive) until the termination
criterion is met (hence dynamic).
Parameters
----------
@@ -178,15 +183,17 @@
direction int in {-1, 1}
The direction in which to expand the trajectory.
termination_state
The state that keeps track of the information needed for the termination criterion.
The state that keeps track of the information needed for the
termination criterion.
max_num_steps
The maximum number of integration steps. The expansion will stop
when this number is reached if the termination criterion has not
been met.
step_size
The step size of the symplectic integrator.
initial_energy
Initial energy H0 of the HMC step (not to be confused with the initial energy of the subtree)
Initial energy H0 of the HMC step (not to be confused with the initial
energy of the subtree)
"""

@@ -512,9 +519,9 @@ def do_keep_expanding(loop_state) -> bool:
def expand_once(loop_state):
"""Expand the current trajectory.
At each step we draw a direction at random, build a subtrajectory starting
from the leftmost or rightmost point of the current trajectory that is
twice as long as the current trajectory.
At each step we draw a direction at random, build a subtrajectory
starting from the leftmost or rightmost point of the current
trajectory that is twice as long as the current trajectory.
Once that is done, possibly update the current proposal with that of
the subtrajectory.
@@ -554,9 +561,10 @@ def expand_once(loop_state):

# Update the proposal
#
# We do not accept proposals that come from diverging or turning subtrajectories.
# However the definition of the acceptance probability is such that the
# acceptance probability needs to be computed across the entire trajectory.
# We do not accept proposals that come from diverging or turning
# subtrajectories. However the definition of the acceptance
# probability is such that the acceptance probability needs to be
# computed across the entire trajectory.
def update_sum_log_p_accept(inputs):
_, proposal, new_proposal = inputs
return Proposal(
@@ -577,8 +585,8 @@ def update_sum_log_p_accept(inputs):

# Is the full trajectory making a U-Turn?
#
# We first merge the subtrajectory that was just generated with the trajectory
and check the U-Turn criterion on the whole trajectory.
# We first merge the subtrajectory that was just generated with the
trajectory and check the U-Turn criterion on the whole trajectory.
left_trajectory, right_trajectory = reorder_trajectories(
direction, trajectory, new_trajectory
)
