Poisson may enter infinite loop under vmap #582
Comments
lumip added a commit to lumip/numpyro that referenced this issue on Apr 30, 2020
fehiepsi pushed a commit that referenced this issue on May 5, 2020
* Tighter conditions for loops in _poisson sampling fixes issue #582
* fixing linting issues
I guess this can be closed now, unless you want to wait for results on the related jax issue?
Thanks, @lumip!
We ran into a situation yesterday where we essentially ran something like
which did not terminate for a long while (we gave up after 10 minutes or so). The same happens for
numpyro.infer.Predictive(model, num_samples=100, parallel=True)(<rngkey>).
The issue appears when using non-constant rate parameters for the Poisson distribution and vmap around the (seeded) model function. Using a constant rate, a straightforward call to model, or even a vmap call that maps over just one PRNGKey are all fine. This is for any (valid) value of k and happens on CPU and GPU (most of our tests were on CPU, though, so we are not sure whether GPU has the same behavior in all cases, but we believe so).
After a long while, we were able to track down the bug and it seems to be as follows:
numpyro.distributions.Poisson essentially just refers to numpyro.distributions.util._poisson, which in turn calls numpyro.distributions._poisson_one (vmapped over its dimensions).
The last one of these is defined as
i.e., it performs conditional branching depending on the value of the Poisson rate parameter.
Now, it seems like lax.cond actually executes both of _poisson_large and _poisson_small (and not just for jit compilation) and that the while loop in _poisson_large does not terminate for (some) small rate values, at least not within an acceptable time frame. We are not sure why it would be like this, but it is the best explanation for our observations.
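One way to check the "both branches are executed" observation is to inspect the jaxpr of a toy lax.cond with and without vmap. The sketch below is not NumPyro code: branchy and its two lambdas are hypothetical stand-ins for _poisson_one, _poisson_large and _poisson_small, and the threshold 10 is borrowed from the branching condition mentioned in this issue. As far as we understand JAX's batching rule, a cond whose predicate differs per batch element is turned into a select over the results of both branches.

```python
import jax
import jax.numpy as jnp
from jax import lax


def branchy(x):
    # Hypothetical stand-in for a function that branches on its input,
    # similar in spirit to branching on the Poisson rate.
    return lax.cond(x >= 10.0, lambda v: v + 1.0, lambda v: v * 2.0, x)


xs = jnp.array([3.0, 30.0], dtype=jnp.float32)

# Without vmap the predicate is a scalar, so a real `cond` appears in the jaxpr.
single_jaxpr = str(jax.make_jaxpr(branchy)(jnp.float32(3.0)))

# With vmap the predicate is batched: both branches are traced and evaluated
# for every element, and a select picks the per-element result.
batched_jaxpr = str(jax.make_jaxpr(jax.vmap(branchy))(xs))

out = jax.vmap(branchy)(xs)

print(single_jaxpr)
print(batched_jaxpr)
print(out)
```

If the batched jaxpr shows both the add and the mul together with a select, then any loop inside the unselected branch also runs for every batch element, which would match what we observed.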
A fix for this behavior follows as a pull request (essentially, we just needed to include the branching condition val[1] >= 10 in the condition of the while loop in _poisson_large). Escalating this issue to the folks at jax might also be a smart choice and I will do so soon.
edit: This is reported as jax issue 2947 to see if they have some better explanation and potentially nice solution for this.
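The idea behind the fix can be sketched on a toy loop. This is not the actual _poisson_large code: grow_loop, MAX_ITERS and TARGET are hypothetical, and the loop is chosen only because it finishes quickly for large rates and never reaches its target for small ones. MAX_ITERS caps the unguarded case so the demo itself terminates. Repeating the branching condition (rate >= 10) inside the while-loop condition lets batch elements that should not take this branch exit immediately.

```python
import jax
import jax.numpy as jnp
from jax import lax

THRESHOLD = 10.0   # branching threshold, as in the condition quoted in this issue
MAX_ITERS = 1000   # hypothetical safety cap so the unguarded demo terminates
TARGET = 1e6       # hypothetical stopping value for the toy loop


def grow_loop(rate, guarded):
    # Toy loop: x grows by a factor rate / 5 per step, so it only reaches
    # TARGET when rate > 5; for small rates it would run "forever".
    def cond_fn(state):
        i, x = state
        keep_going = (i < MAX_ITERS) & (x < TARGET)
        if guarded:
            # The fix: repeat the branching condition in the loop condition.
            keep_going = keep_going & (rate >= THRESHOLD)
        return keep_going

    def body_fn(state):
        i, x = state
        return i + 1, x * rate / 5.0

    iters, _ = lax.while_loop(cond_fn, body_fn, (jnp.int32(0), jnp.float32(1.0)))
    return iters


rates = jnp.array([1.0, 20.0], dtype=jnp.float32)

# Iterations per batch element without and with the guard.
unguarded = jax.vmap(lambda r: grow_loop(r, guarded=False))(rates)
guarded = jax.vmap(lambda r: grow_loop(r, guarded=True))(rates)

print(unguarded)
print(guarded)
```

In this sketch the small-rate element runs into the cap without the guard and performs no iterations with it, while the large-rate element is unaffected. Since a batched while loop keeps running until every element is done, one such element is enough to stall the whole vmapped call.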