diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 00892b922..9812712eb 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -119,7 +119,7 @@ def init(inital_step_size: float) -> DualAveragingAdaptationState: return DualAveragingAdaptationState(*da_init(inital_step_size)) def update( - da_state: DualAveragingAdaptationState, acceptance_probability: float + da_state: DualAveragingAdaptationState, acceptance_rate: float ) -> DualAveragingAdaptationState: """Update the state of the Dual Averaging adaptive algorithm. @@ -134,7 +134,7 @@ def update( ------- The updated state of the dual averaging algorithm. """ - gradient = target - acceptance_probability + gradient = target - acceptance_rate return DualAveragingAdaptationState(*da_update(da_state, gradient)) def final(da_state: DualAveragingAdaptationState) -> float: @@ -252,7 +252,7 @@ def update(rss_state: ReasonableStepSizeState) -> Tuple: kernel = kernel_generator(step_size) _, info = kernel(rng_key, reference_state) - new_direction = jnp.where(target_accept < info.acceptance_probability, 1, -1) + new_direction = jnp.where(target_accept < info.acceptance_rate, 1, -1) return ReasonableStepSizeState(rng_key, new_direction, direction, step_size) rss_state = ReasonableStepSizeState(rng_key, 0, 0, initial_step_size) diff --git a/blackjax/kernels.py b/blackjax/kernels.py index 2b3bd9348..3366de944 100644 --- a/blackjax/kernels.py +++ b/blackjax/kernels.py @@ -719,7 +719,7 @@ def one_step(carry, xs): adaptation_state, adaptation_stage, new_state.position, - info.acceptance_probability, + info.acceptance_rate, ) return ( @@ -1310,7 +1310,7 @@ def one_step(carry, rng_key): **extra_parameters, ) new_adaptation_state = update( - adaptation_state, new_state.position, info.acceptance_probability + adaptation_state, new_state.position, info.acceptance_rate ) return ( (new_state, new_adaptation_state), diff --git a/blackjax/mcmc/ghmc.py b/blackjax/mcmc/ghmc.py 
index b40a121f0..ab46cc0e7 100644 --- a/blackjax/mcmc/ghmc.py +++ b/blackjax/mcmc/ghmc.py @@ -152,7 +152,7 @@ def potential_fn(x): jax.tree_map(lambda m, s: m * s, proposal.momentum, momentum_inverse_scale), proposal.potential_energy, proposal.potential_energy_grad, - info.acceptance_probability, + info.acceptance_rate, ) return state, info diff --git a/blackjax/mcmc/hmc.py b/blackjax/mcmc/hmc.py index 513110c10..4eca66059 100644 --- a/blackjax/mcmc/hmc.py +++ b/blackjax/mcmc/hmc.py @@ -34,7 +34,7 @@ class HMCInfo(NamedTuple): momentum: The momentum that was sampled and used to integrate the trajectory. - acceptance_probability + acceptance_rate The acceptance probability of the transition, linked to the energy difference between the original and the proposed states. is_accepted @@ -55,7 +55,7 @@ class HMCInfo(NamedTuple): """ momentum: PyTree - acceptance_probability: float + acceptance_rate: float is_accepted: bool is_divergent: bool energy: float diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index 8623e9314..d8ff0bb3f 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -32,15 +32,15 @@ class MALAInfo(NamedTuple): This additional information can be used for debugging or computing diagnostics. - acceptance_probability - The acceptance probability of the transition. + acceptance_rate + The acceptance rate of the transition. is_accepted Whether the proposed position was accepted or the original position was returned. """ - acceptance_probability: float + acceptance_rate: float is_accepted: bool diff --git a/blackjax/mcmc/marginal_latent_gaussian.py b/blackjax/mcmc/marginal_latent_gaussian.py index e8bbd1a5a..c3a6ed770 100644 --- a/blackjax/mcmc/marginal_latent_gaussian.py +++ b/blackjax/mcmc/marginal_latent_gaussian.py @@ -41,7 +41,7 @@ class MarginalInfo(NamedTuple): This additional information can be used for debugging or computing diagnostics. 
- acceptance_probability + acceptance_rate The acceptance probability of the transition, linked to the energy difference between the original and the proposed states. is_accepted @@ -51,7 +51,7 @@ class MarginalInfo(NamedTuple): The state proposed by the proposal. """ - acceptance_probability: float + acceptance_rate: float is_accepted: bool proposal: MarginalState diff --git a/blackjax/mcmc/nuts.py b/blackjax/mcmc/nuts.py index 6786286d2..70c726da8 100644 --- a/blackjax/mcmc/nuts.py +++ b/blackjax/mcmc/nuts.py @@ -44,7 +44,7 @@ class NUTSInfo(NamedTuple): num_integration_steps Number of integration steps that were taken. This is also the number of states in the full trajectory. - acceptance_probability + acceptance_rate average acceptance probabilty across entire trajectory """ @@ -56,7 +56,7 @@ class NUTSInfo(NamedTuple): trajectory_rightmost_state: integrators.IntegratorState num_trajectory_expansions: int num_integration_steps: int - acceptance_probability: float + acceptance_rate: float def kernel( @@ -67,15 +67,18 @@ def kernel( """Build an iterative NUTS kernel. This algorithm is an iteration on the original NUTS algorithm [Hoffman2014]_ - with two major differences: - We do not use slice samplig but multinomial - sampling for the proposal [Betancourt2017]_; - The trajectory expansion is - not recursive but iterative [Phan2019]_, [Lao2020]_. + with two major differences: + + - We do not use slice sampling but multinomial sampling for the proposal + [Betancourt2017]_; + - The trajectory expansion is not recursive but iterative [Phan2019]_, + [Lao2020]_. The implementation can seem unusual for those familiar with similar algorithms. Indeed, we do not conceptualize the trajectory construction as building a tree. We feel that the tree lingo, inherited from the recursive version, is unnecessarily complicated and hides the more general concepts - on which the NUTS algorithm is built. + upon which the NUTS algorithm is built. 
NUTS, in essence, consists in sampling a trajectory by iteratively choosing a direction at random and integrating in this direction a number of times @@ -85,17 +88,30 @@ def kernel( Parameters ---------- - logprob_fb - Log probability function we wish to sample from. - parameters - A NamedTuple that contains the parameters of the kernel to be built. + integrator + The symplectic integrator used to build trajectories. + divergence_threshold + The absolute difference in energy above which we consider + a transition "divergent". + max_num_doublings + The maximum number of times we expand the trajectory by + doubling the number of steps if the trajectory does not + turn onto itself. References ---------- - .. [Hoffman2014] Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo." J. Mach. Learn. Res. 15.1 (2014): 1593-1623. - .. [Betancourt2017] Betancourt, Michael. "A conceptual introduction to Hamiltonian Monte Carlo." arXiv preprint arXiv:1701.02434 (2017). - .. [Phan2019] Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects for flexible and accelerated probabilistic programming in NumPyro." arXiv preprint arXiv:1912.11554 (2019). - .. [Lao2020] Lao, Junpeng, et al. "tfp. mcmc: Modern markov chain monte carlo tools built for modern hardware." arXiv preprint arXiv:2002.01184 (2020). + .. [Hoffman2014] Hoffman, Matthew D., and Andrew Gelman. + "The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo." + J. Mach. Learn. Res. 15.1 (2014): 1593-1623. + .. [Betancourt2017] Betancourt, Michael. + "A conceptual introduction to Hamiltonian Monte Carlo." + arXiv preprint arXiv:1701.02434 (2017). + .. [Phan2019] Phan, Du, Neeraj Pradhan, and Martin Jankowiak. + "Composable effects for flexible and accelerated probabilistic programming in NumPyro." + arXiv preprint arXiv:1912.11554 (2019). + .. [Lao2020] Lao, Junpeng, et al. + "tfp. 
mcmc: Modern markov chain monte carlo tools built for modern hardware." + arXiv preprint arXiv:2002.01184 (2020). """ @@ -161,7 +177,8 @@ def iterative_nuts_proposal( kinetic_energy Function that computes the kinetic energy. uturn_check_fn: - Function that determines whether the trajectory is turning on itself (metric-dependant). + Function that determines whether the trajectory is turning on itself + (metric-dependant). step_size Size of the integration step. max_num_expansions @@ -171,7 +188,8 @@ def iterative_nuts_proposal( Returns ------- - A kernel that generates a new chain state and information about the transition. + A kernel that generates a new chain state and information about the + transition. """ ( @@ -223,7 +241,7 @@ def propose(rng_key, initial_state: integrators.IntegratorState, step_size): num_doublings, sampled_proposal, new_trajectory, _ = expansion_state # Compute average acceptance probabilty across entire trajectory, # even over subtrees that may have been rejected - acceptance_probability = ( + acceptance_rate = ( jnp.exp(sampled_proposal.sum_log_p_accept) / new_trajectory.num_states ) @@ -236,7 +254,7 @@ def propose(rng_key, initial_state: integrators.IntegratorState, step_size): new_trajectory.rightmost_state, num_doublings, new_trajectory.num_states, - acceptance_probability, + acceptance_rate, ) return sampled_proposal.state, info diff --git a/blackjax/mcmc/proposal.py b/blackjax/mcmc/proposal.py index 73c0a6cd0..4adda4bcc 100644 --- a/blackjax/mcmc/proposal.py +++ b/blackjax/mcmc/proposal.py @@ -10,9 +10,9 @@ class Proposal(NamedTuple): """Proposal for the next chain step. state: - The trajectory state corresponding to this proposal. + The trajectory state that corresponds to this proposal. energy: - The potential energy corresponding to the state. + The total energy that corresponds to this proposal. weight: Weight of the proposal. 
It is equal to the logarithm of the sum of the canonical densities of each state :math:`e^{-H(z)}` along the trajectory. @@ -39,7 +39,7 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo The trajectory state records information about the position in the state space and corresponding potential energy. A proposal also carries a weight that is equal to the difference between the current energy and - the previous one. It thus carries information about the previous state + the previous one. It thus carries information about the previous states as well as the current state. Parameters @@ -58,6 +58,7 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo # The weight of the new proposal is equal to H0 - H(z_new) weight = delta_energy + # Log of the acceptance statistic min(e^{H0 - H(z_new)}, 1) sum_log_p_accept = jnp.minimum(delta_energy, 0.0) @@ -80,11 +81,12 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo def static_binomial_sampling(rng_key, proposal, new_proposal): - """Accept or reject a proposal based on its weight. + """Accept or reject a proposal. - In the static setting, the `log_weight` of the proposal will be equal to the - difference of energy between the beginning and the end of the trajectory (truncated at 0.). It - is implemented this way to keep a consistent API with progressive sampling. + In the static setting, the probability with which the new proposal is + accepted is a function of the difference in energy between the previous and + the current states. If the current energy is lower than the previous one + then the new proposal is accepted with probability 1. 
""" p_accept = jnp.clip(jnp.exp(new_proposal.weight), a_max=1) @@ -98,35 +100,6 @@ def static_binomial_sampling(rng_key, proposal, new_proposal): ) -# -------------------------------------------------------------------- -# NON-REVERSIVLE SLICE SAMPLING -# -------------------------------------------------------------------- - - -def nonreversible_slice_sampling(slice, proposal, new_proposal): - """Slice sampling for non-reversible Metropolis-Hasting update. - - Performs a non-reversible update of a uniform [0, 1] value - for Metropolis-Hastings accept/reject decisions [1]_, in addition - to the accept/reject step of a current state and new proposal. - - References - ---------- - .. [1]: Neal, R. M. (2020). Non-reversibly updating a uniform - [0, 1] value for Metropolis accept/reject decisions. - arXiv preprint arXiv:2001.11950. - """ - - delta_energy = new_proposal.weight - do_accept = jnp.log(jnp.abs(slice)) <= delta_energy - return jax.lax.cond( - do_accept, - lambda _: (new_proposal, do_accept, slice * jnp.exp(-delta_energy)), - lambda _: (proposal, do_accept, slice), - operand=None, - ) - - # -------------------------------------------------------------------- # PROGRESSIVE SAMPLING # @@ -170,6 +143,12 @@ def progressive_biased_sampling(rng_key, proposal, new_proposal): Unlike uniform sampling, biased sampling favors new proposals. It thus biases the transition away from the trajectory's initial state. + References + ---------- + .. [1]: Betancourt, Michael. + "A conceptual introduction to Hamiltonian Monte Carlo." + arXiv preprint arXiv:1701.02434 (2017). 
+ """ p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), a_max=1) do_accept = jax.random.bernoulli(rng_key, p_accept) @@ -194,3 +173,32 @@ def progressive_biased_sampling(rng_key, proposal, new_proposal): ), operand=None, ) + + +# -------------------------------------------------------------------- +# NON-REVERSIBLE SLICE SAMPLING +# -------------------------------------------------------------------- + + +def nonreversible_slice_sampling(slice, proposal, new_proposal): + """Slice sampling for non-reversible Metropolis-Hastings update. + + Performs a non-reversible update of a uniform [0, 1] value + for Metropolis-Hastings accept/reject decisions [1]_, in addition + to the accept/reject step of a current state and new proposal. + + References + ---------- + .. [1]: Neal, R. M. (2020). + "Non-reversibly updating a uniform [0, 1] value for Metropolis accept/reject decisions." + arXiv preprint arXiv:2001.11950. + """ + + delta_energy = new_proposal.weight + do_accept = jnp.log(jnp.abs(slice)) <= delta_energy + return jax.lax.cond( + do_accept, + lambda _: (new_proposal, do_accept, slice * jnp.exp(-delta_energy)), + lambda _: (proposal, do_accept, slice), + operand=None, + ) diff --git a/blackjax/mcmc/rmh.py b/blackjax/mcmc/rmh.py index 0fe583305..521942686 100644 --- a/blackjax/mcmc/rmh.py +++ b/blackjax/mcmc/rmh.py @@ -29,7 +29,7 @@ class RMHInfo(NamedTuple): This additional information can be used for debugging or computing diagnostics. - acceptance_probability + acceptance_rate The acceptance probability of the transition, linked to the energy difference between the original and the proposed states. is_accepted @@ -39,7 +39,7 @@ class RMHInfo(NamedTuple): The state proposed by the proposal. 
""" - acceptance_probability: float + acceptance_rate: float is_accepted: bool proposal: RMHState @@ -125,12 +125,12 @@ def rmh( if proposal_logprob_fn is None: - def acceptance_probability(state: RMHState, proposal: RMHState): + def acceptance_rate(state: RMHState, proposal: RMHState): return proposal.log_probability - state.log_probability else: - def acceptance_probability(state: RMHState, proposal: RMHState): + def acceptance_rate(state: RMHState, proposal: RMHState): return ( proposal.log_probability + proposal_logprob_fn(proposal.position, state.position) # type: ignore @@ -163,7 +163,7 @@ def kernel(rng_key: PRNGKey, state: RMHState) -> Tuple[RMHState, RMHInfo]: new_log_probability = logprob_fn(new_position) new_state = RMHState(new_position, new_log_probability) - delta = acceptance_probability(state, new_state) + delta = acceptance_rate(state, new_state) delta = jnp.where(jnp.isnan(delta), -jnp.inf, delta) p_accept = jnp.clip(jnp.exp(delta), a_max=1.0) diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index c390f529c..0ee85f07c 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -1,6 +1,7 @@ """Procedures to build trajectories for algorithms in the HMC family. To propose a new state, algorithms in the HMC family generally proceed by [1]_: + 1. Sampling a trajectory starting from the initial point; 2. Sampling a new state from this sampled trajectory. @@ -24,7 +25,9 @@ References ---------- -.. [1]: Betancourt, Michael. "A conceptual introduction to Hamiltonian Monte Carlo." arXiv preprint arXiv:1701.02434 (2017). +.. [1]: Betancourt, Michael. + "A conceptual introduction to Hamiltonian Monte Carlo." + arXiv preprint arXiv:1701.02434 (2017). """ from typing import Callable, NamedTuple, Tuple @@ -151,7 +154,8 @@ def dynamic_progressive_integration( is_criterion_met Determines whether the termination criterion has been met. 
divergence_threshold - Value of the difference of energy between two consecutive states above which we say a transition is divergent. + Value of the difference of energy between two consecutive states above + which we say a transition is divergent. """ _, generate_proposal = proposal_generator(kinetic_energy, divergence_threshold) @@ -166,8 +170,9 @@ def integrate( step_size, initial_energy, ): - """Integrate the trajectory starting from `initial_state` and update - the proposal sequentially until the termination criterion is met. + """Integrate the trajectory starting from `initial_state` and update the + proposal sequentially (hence progressive) until the termination + criterion is met (hence dynamic). Parameters ---------- @@ -178,7 +183,8 @@ def integrate( direction int in {-1, 1} The direction in which to expand the trajectory. termination_state - The state that keeps track of the information needed for the termination criterion. + The state that keeps track of the information needed for the + termination criterion. max_num_steps The maximum number of integration steps. The expansion will stop when this number is reached if the termination criterion has not @@ -186,7 +192,8 @@ def integrate( step_size The step size of the symplectic integrator. initial_energy - Initial energy H0 of the HMC step (not to confused with the initial energy of the subtree) + Initial energy H0 of the HMC step (not to be confused with the initial + energy of the subtree) """ @@ -512,9 +519,9 @@ def do_keep_expanding(loop_state) -> bool: def expand_once(loop_state): """Expand the current trajectory. - At each step we draw a direction at random, build a subtrajectory starting - from the leftmost or rightmost point of the current trajectory that is - twice as long as the current trajectory. + At each step we draw a direction at random, build a subtrajectory + starting from the leftmost or rightmost point of the current + trajectory that is twice as long as the current trajectory. 
Once that is done, possibly update the current proposal with that of the subtrajectory. @@ -554,9 +561,10 @@ def expand_once(loop_state): # Update the proposal # - # We do not accept proposals that come from diverging or turning subtrajectories. - # However the definition of the acceptance probability is such that the - # acceptance probability needs to be computed across the entire trajectory. + # We do not accept proposals that come from diverging or turning + # subtrajectories. However the definition of the acceptance + # probability is such that the acceptance probability needs to be + # computed across the entire trajectory. def update_sum_log_p_accept(inputs): _, proposal, new_proposal = inputs return Proposal( @@ -577,8 +585,8 @@ def update_sum_log_p_accept(inputs): # Is the full trajectory making a U-Turn? # - # We first merge the subtrajectory that was just generated with the trajectory - # and check the U-Turn criterior on the whole trajectory. + # We first merge the subtrajectory that was just generated with the + # trajectory and check the U-Turn criterion on the whole trajectory. left_trajectory, right_trajectory = reorder_trajectories( direction, trajectory, new_trajectory ) diff --git a/examples/howto_use_numpyro.md b/examples/howto_use_numpyro.md index ac91dedb9..30b0b1108 100644 --- a/examples/howto_use_numpyro.md +++ b/examples/howto_use_numpyro.md @@ -110,7 +110,7 @@ def inference_loop(rng_key, kernel, initial_state, num_samples): _, (states, infos) = jax.lax.scan(one_step, initial_state, keys) return states, ( - infos.acceptance_probability, + infos.acceptance_rate, infos.is_divergent, infos.num_integration_steps, ) diff --git a/examples/howto_use_tfp.md b/examples/howto_use_tfp.md index bacf33cd1..0e27c8bdc 100644 --- a/examples/howto_use_tfp.md +++ b/examples/howto_use_tfp.md @@ -148,7 +148,7 @@ Extra information about the inference is contained in the `infos` namedtuple. 
Le ```{code-cell} ipython3 :tags: [hide-cell] -acceptance_rate = np.mean(infos.acceptance_probability) +acceptance_rate = np.mean(infos.acceptance_rate) print(f"Average acceptance rate: {acceptance_rate:.2f}") ```