diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 7b0e26325..d32d2da25 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -149,6 +149,8 @@ def _binom_inv_cond_fn(val): return cond_exclude_large_mu & (geom_acc <= n) log1_p = jnp.log1p(-p) + # Make sure p=0 is never taken into account as a fix for possible zeros in p. + log1_p = jnp.where(log1_p == 0, -jnp.finfo(log1_p.dtype).tiny, log1_p) ret = lax.while_loop(_binom_inv_cond_fn, _binom_inv_body_fn, (-1, key, 0.0)) return ret[0]