Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

S/acceptance probability/acceptance ratio #390

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions blackjax/adaptation/step_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def init(inital_step_size: float) -> DualAveragingAdaptationState:
return DualAveragingAdaptationState(*da_init(inital_step_size))

def update(
da_state: DualAveragingAdaptationState, acceptance_probability: float
da_state: DualAveragingAdaptationState, acceptance_rate: float
) -> DualAveragingAdaptationState:
"""Update the state of the Dual Averaging adaptive algorithm.

Expand All @@ -134,7 +134,7 @@ def update(
-------
The updated state of the dual averaging algorithm.
"""
gradient = target - acceptance_probability
gradient = target - acceptance_rate
return DualAveragingAdaptationState(*da_update(da_state, gradient))

def final(da_state: DualAveragingAdaptationState) -> float:
Expand Down Expand Up @@ -252,7 +252,7 @@ def update(rss_state: ReasonableStepSizeState) -> Tuple:
kernel = kernel_generator(step_size)
_, info = kernel(rng_key, reference_state)

new_direction = jnp.where(target_accept < info.acceptance_probability, 1, -1)
new_direction = jnp.where(target_accept < info.acceptance_rate, 1, -1)
return ReasonableStepSizeState(rng_key, new_direction, direction, step_size)

rss_state = ReasonableStepSizeState(rng_key, 0, 0, initial_step_size)
Expand Down
4 changes: 2 additions & 2 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ def one_step(carry, xs):
adaptation_state,
adaptation_stage,
new_state.position,
info.acceptance_probability,
info.acceptance_rate,
)

return (
Expand Down Expand Up @@ -1310,7 +1310,7 @@ def one_step(carry, rng_key):
**extra_parameters,
)
new_adaptation_state = update(
adaptation_state, new_state.position, info.acceptance_probability
adaptation_state, new_state.position, info.acceptance_rate
)
return (
(new_state, new_adaptation_state),
Expand Down
2 changes: 1 addition & 1 deletion blackjax/mcmc/ghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def potential_fn(x):
jax.tree_map(lambda m, s: m * s, proposal.momentum, momentum_inverse_scale),
proposal.potential_energy,
proposal.potential_energy_grad,
info.acceptance_probability,
info.acceptance_rate,
)

return state, info
Expand Down
4 changes: 2 additions & 2 deletions blackjax/mcmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class HMCInfo(NamedTuple):

momentum:
The momentum that was sampled and used to integrate the trajectory.
acceptance_probability
acceptance_rate
The acceptance probability of the transition, linked to the energy
difference between the original and the proposed states.
is_accepted
Expand All @@ -55,7 +55,7 @@ class HMCInfo(NamedTuple):
"""

momentum: PyTree
acceptance_probability: float
acceptance_rate: float
is_accepted: bool
is_divergent: bool
energy: float
Expand Down
6 changes: 3 additions & 3 deletions blackjax/mcmc/mala.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ class MALAInfo(NamedTuple):
This additional information can be used for debugging or computing
diagnostics.

acceptance_probability
The acceptance probability of the transition.
acceptance_rate
The acceptance rate of the transition.
is_accepted
Whether the proposed position was accepted or the original position
was returned.

"""

acceptance_probability: float
acceptance_rate: float
is_accepted: bool


Expand Down
4 changes: 2 additions & 2 deletions blackjax/mcmc/marginal_latent_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class MarginalInfo(NamedTuple):
This additional information can be used for debugging or computing
diagnostics.

acceptance_probability
acceptance_rate
The acceptance probability of the transition, linked to the energy
difference between the original and the proposed states.
is_accepted
Expand All @@ -51,7 +51,7 @@ class MarginalInfo(NamedTuple):
The state proposed by the proposal.
"""

acceptance_probability: float
acceptance_rate: float
is_accepted: bool
proposal: MarginalState

Expand Down
54 changes: 36 additions & 18 deletions blackjax/mcmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class NUTSInfo(NamedTuple):
num_integration_steps
Number of integration steps that were taken. This is also the number of
states in the full trajectory.
acceptance_probability
acceptance_rate
Average acceptance probability across the entire trajectory.
"""

Expand All @@ -56,7 +56,7 @@ class NUTSInfo(NamedTuple):
trajectory_rightmost_state: integrators.IntegratorState
num_trajectory_expansions: int
num_integration_steps: int
acceptance_probability: float
acceptance_rate: float


def kernel(
Expand All @@ -67,15 +67,18 @@ def kernel(
"""Build an iterative NUTS kernel.

This algorithm is an iteration on the original NUTS algorithm [Hoffman2014]_
with two major differences: - We do not use slice samplig but multinomial
sampling for the proposal [Betancourt2017]_; - The trajectory expansion is
not recursive but iterative [Phan2019]_, [Lao2020]_.
with two major differences:

- We do not use slice sampling but multinomial sampling for the proposal
[Betancourt2017]_;
- The trajectory expansion is not recursive but iterative [Phan2019]_,
[Lao2020]_.

The implementation can seem unusual for those familiar with similar
algorithms. Indeed, we do not conceptualize the trajectory construction as
building a tree. We feel that the tree lingo, inherited from the recursive
version, is unnecessarily complicated and hides the more general concepts
on which the NUTS algorithm is built.
upon which the NUTS algorithm is built.

NUTS, in essence, consists in sampling a trajectory by iteratively choosing
a direction at random and integrating in this direction a number of times
Expand All @@ -85,17 +88,30 @@ def kernel(

Parameters
----------
logprob_fn
Log probability function we wish to sample from.
parameters
A NamedTuple that contains the parameters of the kernel to be built.
integrator
The simplectic integrator used to build trajectories.
divergence_threshold
The absolute difference in energy above which we consider
a transition "divergent".
max_num_doublings
The maximum number of times we expand the trajectory by
doubling the number of steps if the trajectory does not
turn onto itself.

References
----------
.. [Hoffman2014] Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo." J. Mach. Learn. Res. 15.1 (2014): 1593-1623.
.. [Betancourt2017] Betancourt, Michael. "A conceptual introduction to Hamiltonian Monte Carlo." arXiv preprint arXiv:1701.02434 (2017).
.. [Phan2019] Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects for flexible and accelerated probabilistic programming in NumPyro." arXiv preprint arXiv:1912.11554 (2019).
.. [Lao2020] Lao, Junpeng, et al. "tfp. mcmc: Modern markov chain monte carlo tools built for modern hardware." arXiv preprint arXiv:2002.01184 (2020).
.. [Hoffman2014]: Hoffman, Matthew D., and Andrew Gelman.
"The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo."
J. Mach. Learn. Res. 15.1 (2014): 1593-1623.
.. [Betancourt2017]: Betancourt, Michael.
"A conceptual introduction to Hamiltonian Monte Carlo."
arXiv preprint arXiv:1701.02434 (2017).
.. [Phan2019]: Phan, Du, Neeraj Pradhan, and Martin Jankowiak.
"Composable effects for flexible and accelerated probabilistic programming in NumPyro."
arXiv preprint arXiv:1912.11554 (2019).
.. [Lao2020]: Lao, Junpeng, et al.
"tfp. mcmc: Modern markov chain monte carlo tools built for modern hardware."
arXiv preprint arXiv:2002.01184 (2020).

"""

Expand Down Expand Up @@ -161,7 +177,8 @@ def iterative_nuts_proposal(
kinetic_energy
Function that computes the kinetic energy.
uturn_check_fn:
Function that determines whether the trajectory is turning on itself (metric-dependant).
Function that determines whether the trajectory is turning on itself
(metric-dependent).
step_size
Size of the integration step.
max_num_expansions
Expand All @@ -171,7 +188,8 @@ def iterative_nuts_proposal(

Returns
-------
A kernel that generates a new chain state and information about the transition.
A kernel that generates a new chain state and information about the
transition.

"""
(
Expand Down Expand Up @@ -223,7 +241,7 @@ def propose(rng_key, initial_state: integrators.IntegratorState, step_size):
num_doublings, sampled_proposal, new_trajectory, _ = expansion_state
# Compute average acceptance probability across entire trajectory,
# even over subtrees that may have been rejected
acceptance_probability = (
acceptance_rate = (
jnp.exp(sampled_proposal.sum_log_p_accept) / new_trajectory.num_states
)

Expand All @@ -236,7 +254,7 @@ def propose(rng_key, initial_state: integrators.IntegratorState, step_size):
new_trajectory.rightmost_state,
num_doublings,
new_trajectory.num_states,
acceptance_probability,
acceptance_rate,
)

return sampled_proposal.state, info
Expand Down
80 changes: 44 additions & 36 deletions blackjax/mcmc/proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
class Proposal(NamedTuple):
"""Proposal for the next chain step.
state:
The trajectory state corresponding to this proposal.
The trajectory state that corresponds to this proposal.
energy:
The potential energy corresponding to the state.
The total energy that corresponds to this proposal.
weight:
Weight of the proposal. It is equal to the logarithm of the sum of the canonical
densities of each state :math:`e^{-H(z)}` along the trajectory.
Expand All @@ -39,7 +39,7 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo
The trajectory state records information about the position in the state
space and corresponding potential energy. A proposal also carries a
weight that is equal to the difference between the current energy and
the previous one. It thus carries information about the previous state
the previous one. It thus carries information about the previous states
as well as the current state.
Parameters
Expand All @@ -58,6 +58,7 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo

# The weight of the new proposal is equal to H0 - H(z_new)
weight = delta_energy

# Acceptance statistic min(e^{H0 - H(z_new)}, 1)
sum_log_p_accept = jnp.minimum(delta_energy, 0.0)

Expand All @@ -80,11 +81,12 @@ def update(initial_energy: float, state: IntegratorState) -> Tuple[Proposal, boo


def static_binomial_sampling(rng_key, proposal, new_proposal):
"""Accept or reject a proposal based on its weight.
"""Accept or reject a proposal.
In the static setting, the `log_weight` of the proposal will be equal to the
difference of energy between the beginning and the end of the trajectory (truncated at 0.). It
is implemented this way to keep a consistent API with progressive sampling.
In the static setting, the probability with which the new proposal is
accepted is a function of the difference in energy between the previous and
the current states. If the current energy is lower than the previous one
then the new proposal is accepted with probability 1.
"""
p_accept = jnp.clip(jnp.exp(new_proposal.weight), a_max=1)
Expand All @@ -98,35 +100,6 @@ def static_binomial_sampling(rng_key, proposal, new_proposal):
)


# --------------------------------------------------------------------
# NON-REVERSIBLE SLICE SAMPLING
# --------------------------------------------------------------------


def nonreversible_slice_sampling(slice, proposal, new_proposal):
    """Slice sampling for a non-reversible Metropolis-Hastings update.

    Performs a non-reversible update of a uniform [0, 1] value
    for Metropolis-Hastings accept/reject decisions [1]_, in addition
    to the accept/reject step of a current state and new proposal.

    References
    ----------
    .. [1]: Neal, R. M. (2020). Non-reversibly updating a uniform
            [0, 1] value for Metropolis accept/reject decisions.
            arXiv preprint arXiv:2001.11950.
    """

    # The proposal weight is H0 - H(z_new); accepting when log|u| <= weight
    # accepts with probability min(1, e^{H0 - H(z_new)}).
    delta_energy = new_proposal.weight
    do_accept = jnp.log(jnp.abs(slice)) <= delta_energy
    return jax.lax.cond(
        do_accept,
        # On acceptance the slice variable is rescaled (not redrawn), which
        # is what makes the update non-reversible.
        lambda _: (new_proposal, do_accept, slice * jnp.exp(-delta_energy)),
        lambda _: (proposal, do_accept, slice),
        operand=None,
    )


# --------------------------------------------------------------------
# PROGRESSIVE SAMPLING
#
Expand Down Expand Up @@ -170,6 +143,12 @@ def progressive_biased_sampling(rng_key, proposal, new_proposal):
Unlike uniform sampling, biased sampling favors new proposals. It thus
biases the transition away from the trajectory's initial state.
References
----------
.. [1]: Betancourt, Michael.
"A conceptual introduction to Hamiltonian Monte Carlo."
arXiv preprint arXiv:1701.02434 (2017).
"""
p_accept = jnp.clip(jnp.exp(new_proposal.weight - proposal.weight), a_max=1)
do_accept = jax.random.bernoulli(rng_key, p_accept)
Expand All @@ -194,3 +173,32 @@ def progressive_biased_sampling(rng_key, proposal, new_proposal):
),
operand=None,
)


# --------------------------------------------------------------------
# NON-REVERSIBLE SLICE SAMPLING
# --------------------------------------------------------------------


def nonreversible_slice_sampling(slice, proposal, new_proposal):
    """Accept or reject a proposal with non-reversible slice sampling.

    In addition to deciding between the current state and the new proposal,
    this performs a non-reversible update of a uniform [0, 1] value used for
    the Metropolis-Hastings accept/reject decision [1]_.

    References
    ----------
    .. [1]: Neal, R. M. (2020).
        "Non-reversibly updating a uniform [0, 1] value for Metropolis accept/reject decisions."
        arXiv preprint arXiv:2001.11950.
    """
    # Acceptance test: log|u| <= H0 - H(z_new), where the proposal weight
    # carries the energy difference.
    delta_energy = new_proposal.weight
    do_accept = jnp.log(jnp.abs(slice)) <= delta_energy

    def _accepted(_):
        # Rescale (rather than redraw) the slice variable: this is the
        # non-reversible part of the update.
        return new_proposal, do_accept, slice * jnp.exp(-delta_energy)

    def _rejected(_):
        return proposal, do_accept, slice

    return jax.lax.cond(do_accept, _accepted, _rejected, operand=None)
10 changes: 5 additions & 5 deletions blackjax/mcmc/rmh.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class RMHInfo(NamedTuple):
This additional information can be used for debugging or computing
diagnostics.

acceptance_probability
acceptance_rate
The acceptance probability of the transition, linked to the energy
difference between the original and the proposed states.
is_accepted
Expand All @@ -39,7 +39,7 @@ class RMHInfo(NamedTuple):
The state proposed by the proposal.
"""

acceptance_probability: float
acceptance_rate: float
is_accepted: bool
proposal: RMHState

Expand Down Expand Up @@ -125,12 +125,12 @@ def rmh(

if proposal_logprob_fn is None:

def acceptance_probability(state: RMHState, proposal: RMHState):
def acceptance_rate(state: RMHState, proposal: RMHState):
return proposal.log_probability - state.log_probability

else:

def acceptance_probability(state: RMHState, proposal: RMHState):
def acceptance_rate(state: RMHState, proposal: RMHState):
return (
proposal.log_probability
+ proposal_logprob_fn(proposal.position, state.position) # type: ignore
Expand Down Expand Up @@ -163,7 +163,7 @@ def kernel(rng_key: PRNGKey, state: RMHState) -> Tuple[RMHState, RMHInfo]:
new_log_probability = logprob_fn(new_position)
new_state = RMHState(new_position, new_log_probability)

delta = acceptance_probability(state, new_state)
delta = acceptance_rate(state, new_state)
delta = jnp.where(jnp.isnan(delta), -jnp.inf, delta)
p_accept = jnp.clip(jnp.exp(delta), a_max=1.0)

Expand Down
Loading