Fix binominal distribution #1860
Conversation
numpyro/distributions/util.py (outdated)

```
@@ -156,6 +156,8 @@ def _binom_inv_cond_fn(val):
def _binomial_dispatch(key, p, n):
    def dispatch(key, p, n):
        is_le_mid = p <= 0.5
        #Make sure p=0 is never taken into account as a fix for possible zeros in p.
```
Could we simply clip p to tiny? `jnp.clip(p, min=jnp.finfo(p.dtype).tiny)`
This works for me: `jnp.clip(p, jnp.finfo(jnp.float32).tiny)`.
The `a_min` argument of `jnp.clip` has been deprecated and renamed to `min`, so to support both older and newer versions I pass the lower bound as a positional argument. However, this could also break if the argument order ever changes.
`jnp.finfo(p.dtype)` doesn't seem to work for me; with `float32` it is working.
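For illustration, a minimal sketch of the version-agnostic call being discussed (the array `p` below is just example data): passing the lower bound positionally sidesteps the keyword rename from `a_min` to `min`.

```python
import jax.numpy as jnp

# Illustrative only: clip p away from exact zero using a positional lower
# bound, which works whether the keyword is spelled `a_min` (older JAX)
# or `min` (newer JAX).
p = jnp.array([0.0, 0.2, 0.7], dtype=jnp.float32)
p_clipped = jnp.clip(p, jnp.finfo(jnp.float32).tiny)  # lower bound only
print(p_clipped)  # ~[1.2e-38, 0.2, 0.7]
```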
This might affect samples with small p in x64. How about updating `_binomial_inversion` to

```python
geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
...
log1_p = jnp.log1p(-p)
log1_p = jnp.where(log1_p == 0, -jnp.finfo(log1_p.dtype).tiny, log1_p)
```

The issue seems to come from `log1_p == 0` and `jnp.log1p(-u) < 0`, which leads to a negative `geom`.
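As a small standalone check of that failure mode (not numpyro code; `u` stands in for the uniform draw inside the loop body), and of the effect of the suggested `jnp.where` guard:

```python
import jax.numpy as jnp

p = jnp.float32(0.0)   # problematic probability
u = jnp.float32(0.3)   # stand-in for a uniform(0, 1) draw

log1_p = jnp.log1p(-p)                        # exactly 0.0 when p == 0
geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
print(geom)  # -inf: a negative "geometric" step, so the accumulator
             # never exceeds n and the while_loop spins forever

# With the suggested guard the divisor becomes -tiny instead of 0.0,
# so geom is a huge positive number and the loop exits immediately.
log1_p = jnp.where(log1_p == 0, -jnp.finfo(log1_p.dtype).tiny, log1_p)
geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
print(geom)  # ~3e37
```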
Could you also fix the lint issue? I think the comment needs a space after the hash, i.e. `# comment`.
Thanks @InfinityMod! This issue is subtle.
Hi, Pull Request #1807 ends up in an infinite loop while running the test setup for the distributions, causing the tests to hang until the time limit is reached.
The misbehavior can be reproduced by running the whole test_distributions setup, or just the test_vmapped_binominal_p0 function on its own, with the following code included directly after the imports in test_distributions.py:
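(The snippet itself is not reproduced in this thread; given the reference to 64-bit accuracy in the next paragraph, it is presumably the standard JAX double-precision switch, sketched here as an assumption.)

```python
# Assumed reproduction setup, not quoted in the original text:
# enable 64-bit precision right after the imports in test_distributions.py.
from jax import config

config.update("jax_enable_x64", True)
```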
While I can't say with 100% certainty why this endless loop only occurs when the numeric accuracy is increased to 64 bits, I'm confident I have a suitable solution for the behavior (which also seems faulty at 32 bits).
The problem stems from the `_binomial_dispatch` function in util.py (see the commit description). In short, zero values of the probability p are still passed to the dispatch function even though they are filtered out afterward via `lax.cond`, so the underlying `while_loop` in the `_binomial_inversion` function runs forever.
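A rough sketch of that mechanism with simplified stand-in code (not the actual numpyro functions): under `vmap`, `lax.cond` lowers to a select, so the p-dependent branch is still evaluated for elements with p == 0 even though its result is discarded.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Simplified stand-ins: `risky` mimics the p-dependent computation inside
# _binomial_inversion, `safe` the branch meant to handle p == 0.
def risky(p):
    return jnp.floor(jnp.log1p(-0.3) / jnp.log1p(-p)) + 1  # -inf when p == 0

def safe(p):
    return jnp.zeros_like(p)

def dispatch(p):
    return lax.cond(p > 0, risky, safe, p)

print(jax.vmap(dispatch)(jnp.array([0.5, 0.0], dtype=jnp.float32)))
# -> [1. 0.]  The selected outputs are fine, but under vmap `risky` was
# also evaluated at p == 0; in the real code that evaluation is a
# while_loop whose exit condition can then never be met.
```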
This pull request solves the issue by ensuring that no p values equal to zero are passed to the underlying functions. Since `lax.cond` still filters out the results for the zero-corrected values, there is no behavioral change for code that uses these methods.
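A hedged sketch of that idea (the helper name `sanitize_p` is mine, not from the PR diff): replace exact zeros before dispatching, so the inversion path never divides by `log1p(-0) == 0`, while the `lax.cond` result for p == 0 stays unchanged.

```python
import jax.numpy as jnp

def sanitize_p(p):
    # Map exact zeros to the smallest positive normal float of the same
    # dtype; all other probabilities pass through untouched.
    tiny = jnp.finfo(p.dtype).tiny
    return jnp.where(p == 0, tiny, p)

p = jnp.array([0.0, 0.3, 1.0], dtype=jnp.float32)
print(sanitize_p(p))  # ~[1.2e-38, 0.3, 1.0]
```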