diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index 9de9ed080..70c726da8 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -67,15 +67,18 @@ def kernel( """Build an iterative NUTS kernel. This algorithm is an iteration on the original NUTS algorithm [Hoffman2014]_ - with two major differences: - We do not use slice samplig but multinomial - sampling for the proposal [Betancourt2017]_; - The trajectory expansion is - not recursive but iterative [Phan2019]_, [Lao2020]_. + with two major differences: + + - We do not use slice samplig but multinomial sampling for the proposal + [Betancourt2017]_; + - The trajectory expansion is not recursive but iterative [Phan2019]_, + [Lao2020]_. The implementation can seem unusual for those familiar with similar algorithms. Indeed, we do not conceptualize the trajectory construction as building a tree. We feel that the tree lingo, inherited from the recursive version, is unnecessarily complicated and hides the more general concepts - on which the NUTS algorithm is built. + upon which the NUTS algorithm is built. NUTS, in essence, consists in sampling a trajectory by iteratively choosing a direction at random and integrating in this direction a number of times @@ -85,17 +88,30 @@ def kernel( Parameters ---------- - logprob_fb - Log probability function we wish to sample from. - parameters - A NamedTuple that contains the parameters of the kernel to be built. + integrator + The simplectic integrator used to build trajectories. + divergence_threshold + The absolute difference in energy above which we consider + a transition "divergent". + max_num_doublings + The maximum number of times we expand the trajectory by + doubling the number of steps if the trajectory does not + turn onto itself. References ---------- - .. [Hoffman2014] Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo." J. Mach. Learn. Res. 15.1 (2014): 1593-1623. - .. [Betancourt2017] Betancourt, Michael. "A conceptual introduction to Hamiltonian Monte Carlo." arXiv preprint arXiv:1701.02434 (2017). - .. [Phan2019] Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects for flexible and accelerated probabilistic programming in NumPyro." arXiv preprint arXiv:1912.11554 (2019). - .. [Lao2020] Lao, Junpeng, et al. "tfp. mcmc: Modern markov chain monte carlo tools built for modern hardware." arXiv preprint arXiv:2002.01184 (2020). + .. [Hoffman2014]: Hoffman, Matthew D., and Andrew Gelman. + "The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo." + J. Mach. Learn. Res. 15.1 (2014): 1593-1623. + .. [Betancourt2017]: Betancourt, Michael. + "A conceptual introduction to Hamiltonian Monte Carlo." + arXiv preprint arXiv:1701.02434 (2017). + .. [Phan2019]: Phan Du, Neeraj Pradhan, and Martin Jankowiak. + "Composable effects for flexible and accelerated probabilistic programming in NumPyro." + arXiv preprint arXiv:1912.11554 (2019). + .. [Lao2020]: Lao, Junpeng, et al. + "tfp. mcmc: Modern markov chain monte carlo tools built for modern hardware." + arXiv preprint arXiv:2002.01184 (2020). """ @@ -161,7 +177,8 @@ def iterative_nuts_proposal( kinetic_energy Function that computes the kinetic energy. uturn_check_fn: - Function that determines whether the trajectory is turning on itself (metric-dependant). + Function that determines whether the trajectory is turning on itself + (metric-dependant). step_size Size of the integration step. max_num_expansions @@ -171,7 +188,8 @@ def iterative_nuts_proposal( Returns ------- - A kernel that generates a new chain state and information about the transition. + A kernel that generates a new chain state and information about the + transition. """ ( diff --git a/blackjax/mcmc/proposal.py b/blackjax/mcmc/proposal.py index 73c0a6cd0..4adda4bcc 100644 --- a/blackjax/mcmc/proposal.py +++ b/blackjax/mcmc/proposal.py @@ -10,9 +10,9 @@ class Proposal(NamedTuple): """Proposal for the next chain step. state: - The trajectory state corresponding to this proposal. + The trajectory state that corresponds to this proposal. energy: - The potential energy corresponding to the state. + The total energy that corresponds to this proposal. weight: Weight of the proposal. It is equal to the logarithm of the sum of the canonical densities of each state :math:`e^{-H(z)}` along the trajectory. @@ -39,7 +39,7 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo The trajectory state records information about the position in the state space and corresponding potential energy. A proposal also carries a weight that is equal to the difference between the current energy and - the previous one. It thus carries information about the previous state + the previous one. It thus carries information about the previous states as well as the current state. Parameters @@ -58,6 +58,7 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo # The weight of the new proposal is equal to H0 - H(z_new) weight = delta_energy + # Acceptance statistic min(e^{H0 - H(z_new)}, 1) sum_log_p_accept = jnp.minimum(delta_energy, 0.0) @@ -80,11 +81,12 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo def static_binomial_sampling(rng_key, proposal, new_proposal): - """Accept or reject a proposal based on its weight. + """Accept or reject a proposal. - In the static setting, the `log_weight` of the proposal will be equal to the - difference of energy between the beginning and the end of the trajectory (truncated at 0.). It - is implemented this way to keep a consistent API with progressive sampling. + In the static setting, the probability with which the new proposal is + accepted is a function of the difference in energy between the previous and + the current states. If the current energy is lower than the previous one + then the new proposal is accepted with probability 1. """ p_accept = jnp.clip(jnp.exp(new_proposal.weight), a_max=1) @@ -98,35 +100,6 @@ def static_binomial_sampling(rng_key, proposal, new_proposal): ) -# -------------------------------------------------------------------- -# NON-REVERSIVLE SLICE SAMPLING -# -------------------------------------------------------------------- - - -def nonreversible_slice_sampling(slice, proposal, new_proposal): - """Slice sampling for non-reversible Metropolis-Hasting update. - - Performs a non-reversible update of a uniform [0, 1] value - for Metropolis-Hastings accept/reject decisions [1]_, in addition - to the accept/reject step of a current state and new proposal. - - References - ---------- - .. [1]: Neal, R. M. (2020). Non-reversibly updating a uniform - [0, 1] value for Metropolis accept/reject decisions. - arXiv preprint arXiv:2001.11950. - """ - - delta_energy = new_proposal.weight - do_accept = jnp.log(jnp.abs(slice)) <= delta_energy - return jax.lax.cond( - do_accept, - lambda _: (new_proposal, do_accept, slice * jnp.exp(-delta_energy)), - lambda _: (proposal, do_accept, slice), - operand=None, - ) - - # -------------------------------------------------------------------- # PROGRESSIVE SAMPLING # @@ -170,6 +143,12 @@ def progressive_biased_sampling(rng_key, proposal, new_proposal): Unlike uniform sampling, biased sampling favors new proposals. It thus biases the transition away from the trajectory's initial state. + References + ---------- + .. [1]: Betancourt, Michael. + "A conceptual introduction to Hamiltonian Monte Carlo." + arXiv preprint arXiv:1701.02434 (2017). + """ p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), a_max=1) do_accept = jax.random.bernoulli(rng_key, p_accept) @@ -194,3 +173,32 @@ def progressive_biased_sampling(rng_key, proposal, new_proposal): ), operand=None, ) + + +# -------------------------------------------------------------------- +# NON-REVERSIVLE SLICE SAMPLING +# -------------------------------------------------------------------- + + +def nonreversible_slice_sampling(slice, proposal, new_proposal): + """Slice sampling for non-reversible Metropolis-Hasting update. + + Performs a non-reversible update of a uniform [0, 1] value + for Metropolis-Hastings accept/reject decisions [1]_, in addition + to the accept/reject step of a current state and new proposal. + + References + ---------- + .. [1]: Neal, R. M. (2020). + "Non-reversibly updating a uniform [0, 1] value for Metropolis accept/reject decisions." + arXiv preprint arXiv:2001.11950. + """ + + delta_energy = new_proposal.weight + do_accept = jnp.log(jnp.abs(slice)) <= delta_energy + return jax.lax.cond( + do_accept, + lambda _: (new_proposal, do_accept, slice * jnp.exp(-delta_energy)), + lambda _: (proposal, do_accept, slice), + operand=None, + ) diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index c390f529c..0ee85f07c 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -1,6 +1,7 @@ """Procedures to build trajectories for algorithms in the HMC family. To propose a new state, algorithms in the HMC family generally proceed by [1]_: + 1. Sampling a trajectory starting from the initial point; 2. Sampling a new state from this sampled trajectory. @@ -24,7 +25,9 @@ References ---------- -.. [1]: Betancourt, Michael. "A conceptual introduction to Hamiltonian Monte Carlo." arXiv preprint arXiv:1701.02434 (2017). +.. [1]: Betancourt, Michael. + "A conceptual introduction to Hamiltonian Monte Carlo." + arXiv preprint arXiv:1701.02434 (2017). """ from typing import Callable, NamedTuple, Tuple @@ -151,7 +154,8 @@ def dynamic_progressive_integration( is_criterion_met Determines whether the termination criterion has been met. divergence_threshold - Value of the difference of energy between two consecutive states above which we say a transition is divergent. + Value of the difference of energy between two consecutive states above + which we say a transition is divergent. """ _, generate_proposal = proposal_generator(kinetic_energy, divergence_threshold) @@ -166,8 +170,9 @@ def integrate( step_size, initial_energy, ): - """Integrate the trajectory starting from `initial_state` and update - the proposal sequentially until the termination criterion is met. + """Integrate the trajectory starting from `initial_state` and update the + proposal sequentially (hence progressive) until the termination + criterion is met (hence dynamic). Parameters ---------- @@ -178,7 +183,8 @@ def integrate( direction int in {-1, 1} The direction in which to expand the trajectory. termination_state - The state that keeps track of the information needed for the termination criterion. + The state that keeps track of the information needed for the + termination criterion. max_num_steps The maximum number of integration steps. The expansion will stop when this number is reached if the termination criterion has not @@ -186,7 +192,8 @@ def integrate( step_size The step size of the symplectic integrator. initial_energy - Initial energy H0 of the HMC step (not to confused with the initial energy of the subtree) + Initial energy H0 of the HMC step (not to confused with the initial + energy of the subtree) """ @@ -512,9 +519,9 @@ def do_keep_expanding(loop_state) -> bool: def expand_once(loop_state): """Expand the current trajectory. - At each step we draw a direction at random, build a subtrajectory starting - from the leftmost or rightmost point of the current trajectory that is - twice as long as the current trajectory. + At each step we draw a direction at random, build a subtrajectory + starting from the leftmost or rightmost point of the current + trajectory that is twice as long as the current trajectory. Once that is done, possibly update the current proposal with that of the subtrajectory. @@ -554,9 +561,10 @@ def expand_once(loop_state): # Update the proposal # - # We do not accept proposals that come from diverging or turning subtrajectories. - # However the definition of the acceptance probability is such that the - # acceptance probability needs to be computed across the entire trajectory. + # We do not accept proposals that come from diverging or turning + # subtrajectories. However the definition of the acceptance + # probability is such that the acceptance probability needs to be + # computed across the entire trajectory. def update_sum_log_p_accept(inputs): _, proposal, new_proposal = inputs return Proposal( @@ -577,8 +585,8 @@ def update_sum_log_p_accept(inputs): # Is the full trajectory making a U-Turn? # - # We first merge the subtrajectory that was just generated with the trajectory - # and check the U-Turn criterior on the whole trajectory. + # We first merge the subtrajectory that was just generated with the + # trajectory and check the U-Turn criterior on the whole trajectory. left_trajectory, right_trajectory = reorder_trajectories( direction, trajectory, new_trajectory )