Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix binominal distribution #1860

Merged
merged 5 commits into from
Sep 10, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ def _binom_inv_cond_fn(val):
def _binomial_dispatch(key, p, n):
def dispatch(key, p, n):
is_le_mid = p <= 0.5
#Make sure p=0 is never taken into account as a fix for possible zeros in p.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could we just simply clip p to tiny? jnp.clip(p, minval=jnp.finfo(p.dtype).tiny)

Copy link
Contributor Author

@InfinityMod InfinityMod Sep 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This works for me: jnp.clip(p, jnp.finfo(jnp.float32).tiny)
There's a change with a depreciation of the argument a_min, which changes to min for jnp.clip.
So, to support future and early versions, I set the min argument as a positional argument.
However, this can also be an error if the argument order changes.

jnp.finfo(p.dtype) doesn't seem to work, with float32 it's working.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this might affect samples with small p in x64. How about updating _binomial_inversion to

geom = jnp.floor(jnp.log1p(-u) / log1_p) + 1
...
log1_p = jnp.log1p(-p)
log1_p = jnp.where(log1_p == 0, -jnp.finfo(log1_p.dtype).tiny, log1_p)

The issue seems to come from log1_p=0 and jnp.log1p(-u) < 0, which leads to a negative geom

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also fix the lint issue? I think you need a space # comment for the comment.

p = jnp.sum(jnp.stack((p, jnp.ones(p.shape)*0.001)), where = (p == 0))
pq = jnp.where(is_le_mid, p, 1 - p)
mu = n * pq
k = lax.cond(
Expand Down