Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid infinite loop in vmapped Binomial with p=0 #1462

Merged
merged 3 commits into from
Aug 7, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def stirling_approx_tail(k):
)


_binomial_mu_thresh = 10


def _binomial_btrs(key, p, n):
"""
Based on the transformed rejection sampling algorithm (BTRS) from the
Expand Down Expand Up @@ -103,13 +106,19 @@ def accept_fn(k, u, v):
k, key, u, v = val
early_accept = (jnp.abs(u) <= tr_params.u_r) & (v <= tr_params.v_r)
early_reject = (k < 0) | (k > n)
return lax.cond(
# when vmapped _binomial_dispatch will convert the cond condition into
# a HLO select that will execute both branches. This is a workaround
# that avoids the resulting infinite loop when p=0. This should also
# improve performance in less catastrophic cases.
cond_exclude_small_mu = p * n > _binomial_mu_thresh
tbenthompson marked this conversation as resolved.
Show resolved Hide resolved
cond_main = lax.cond(
early_accept | early_reject,
(),
lambda _: ~early_accept,
(k, u, v),
lambda x: ~accept_fn(*x),
)
return cond_exclude_small_mu & cond_main

tr_params = _get_tr_params(n, p)
ret = lax.while_loop(
Expand All @@ -129,7 +138,11 @@ def _binom_inv_body_fn(val):

def _binom_inv_cond_fn(val):
i, _, geom_acc = val
return geom_acc <= n
# see the note on cond_exclude_small_mu in _binomial_btrs
# this cond_exclude_large_mu is unnecessary for correctness but will
# still improve performance.
cond_exclude_large_mu = p * n < _binomial_mu_thresh
return cond_exclude_large_mu & (geom_acc <= n)

log1_p = jnp.log1p(-p)
ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0))
Expand All @@ -142,7 +155,7 @@ def dispatch(key, p, n):
pq = jnp.where(is_le_mid, p, 1 - p)
mu = n * pq
k = lax.cond(
mu < 10,
mu < _binomial_mu_thresh,
(key, pq, n),
lambda x: _binomial_inversion(*x),
(key, pq, n),
Expand Down
10 changes: 10 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2498,3 +2498,13 @@ def test_kl_dirichlet_dirichlet(shape):
x = p.sample(random.PRNGKey(0), (10_000,)).copy()
expected = jnp.mean((p.log_prob(x) - q.log_prob(x)), 0)
assert_allclose(actual, expected, rtol=0.05)


def test_vmapped_binomial_p0():
# test that vmapped binomial with p = 0 does not have an infinite loop
def sample_binomial_withp0(key):
n = 2 * (random.uniform(key) > 0.5)
_, key = random.split(key)
return dist.Binomial(total_count=n, probs=0).sample(key)

jax.vmap(sample_binomial_withp0)(random.split(random.PRNGKey(0), 1))