Poisson may enter infinite loop under vmap #582

Closed
lumip opened this issue Apr 30, 2020 · 2 comments

Comments

@lumip
Contributor

lumip commented Apr 30, 2020

We ran into a situation yesterday where we ran something like the following:

import jax
import jax.numpy as jnp

import numpyro.distributions as dist
from numpyro.handlers import seed
from numpyro.primitives import sample

k = 3  # illustrative number of rates; per this report, any valid k reproduces the hang


def model(k):
    # rate = jnp.exp(jnp.ones((k,)))  # a constant rate like this does not hang
    rate = jnp.exp(sample('rate', dist.Normal(jnp.zeros((k,)), jnp.ones((k,)))))
    x = sample('x', dist.Poisson(rate))
    return x


# vmap over 100 PRNG keys: each lane seeds the model with its own key and runs it once
seeded_model = jax.vmap(lambda key: seed(model, key)(k))
keys = jax.random.split(jax.random.PRNGKey(0), 100)
samples = seeded_model(keys)
print(samples)

This did not terminate (we gave up after about 10 minutes). The same happens for numpyro.infer.Predictive(model, num_samples=100, parallel=True)(<rngkey>).

The issue appears when the Poisson rate is non-constant (here, itself sampled) and the seeded model function is wrapped in vmap. With a constant rate, with a plain call to model, or even with a vmap over just one PRNGKey, everything works fine. It happens for any valid value of k, on both CPU and GPU (most of our tests were on CPU, so we are not sure GPU behaves the same in every case, but we believe it does).

After a long while, we were able to track down the bug, and it seems to be as follows:
numpyro.distributions.Poisson essentially just delegates to numpyro.distributions.util._poisson, which in turn calls
_poisson_one (vmapped over its dimensions).

The last of these is defined as

def _poisson_one(val):
    return lax.cond(val[1] >= 10, val, _poisson_large, val, _poisson_small)

i.e., it branches on the value of the Poisson rate: _poisson_large for rates of at least 10, _poisson_small otherwise.
Now, it seems that lax.cond actually executes both _poisson_large and _poisson_small (and not just while tracing for jit compilation),
and that the while loop in _poisson_large does not terminate for (some) small rate values, at least not within an acceptable
time frame. We are not sure why this would be, but it is the best explanation for our observations.
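This observation is consistent with how JAX batches lax.cond: when vmap makes the predicate batched, the cond is rewritten into a select over the outputs of every branch, so all branches run for all lanes. When the predicate is not batched (e.g. a rate shared by every key), the cond is kept as a real conditional. A minimal sketch, assuming a recent JAX version; branchy is a hypothetical toy function shaped like _poisson_one, not numpyro code, and jax.make_jaxpr is used only to inspect the traced programs.

```python
import jax
import jax.numpy as jnp
from jax import lax


def branchy(x):
    # Hypothetical toy with the same structure as _poisson_one:
    # pick a branch based on the value of x.
    return lax.cond(x >= 10.0, lambda v: v - 10.0, lambda v: v + 10.0, x)


xs = jnp.array([1.0, 20.0, 5.0])

# Unbatched predicate: the program contains a genuine `cond` primitive,
# so only the selected branch runs when executed.
print(jax.make_jaxpr(branchy)(1.0))

# Predicate batched by vmap: the `cond` disappears and both branches are
# computed for every lane, then combined with a select.
print(jax.make_jaxpr(jax.vmap(branchy))(xs))

# Predicate not batched (x is shared across the batch, like a constant rate):
# only `i` is mapped, so the `cond` survives.
shared = jax.vmap(lambda i, x: branchy(x) + i, in_axes=(0, None))
print(jax.make_jaxpr(shared)(jnp.arange(3.0), 1.0))
```

The numerical results are the same either way; what changes is that, in the batched case, any loop inside the untaken branch still has to finish.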

A pull request with a fix follows (essentially, we just needed to add the branching condition val[1] >= 10 to the condition of the while loop in _poisson_large).
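Why the extra condition helps: under vmap, a lax.while_loop with a batched condition keeps iterating until the condition is False for every lane (finished lanes are frozen with a select). If a lane that belongs to the other cond branch can never satisfy the loop's exit condition, the whole batch spins forever. A minimal sketch with hypothetical toy stand-ins (large_branch, small_branch, one, THRESHOLD are illustrative, not numpyro's code); here a rate of 0 plays the role of a "bad" small rate.

```python
import jax
import jax.numpy as jnp
from jax import lax

THRESHOLD = 10.0  # mirrors the `val[1] >= 10` split in _poisson_one


def large_branch(x, guarded):
    """Toy stand-in for _poisson_large: add x until the sum reaches 100.

    Returns the number of steps. For x == 0 the unguarded loop never exits.
    """
    def cond_fn(state):
        _, acc = state
        keep_going = acc < 100.0
        if guarded:
            # The fix: lanes routed to the other branch exit immediately.
            keep_going = keep_going & (x >= THRESHOLD)
        return keep_going

    def body_fn(state):
        steps, acc = state
        return steps + 1, acc + x

    init = (jnp.zeros((), jnp.int32), jnp.zeros_like(x))
    steps, _ = lax.while_loop(cond_fn, body_fn, init)
    return steps


def small_branch(x):
    # Toy stand-in for _poisson_small: return a sentinel step count.
    return jnp.full((), -1, jnp.int32)


def one(x, guarded=True):
    return lax.cond(x >= THRESHOLD,
                    lambda v: large_branch(v, guarded),
                    small_branch,
                    x)


xs = jnp.array([0.0, 20.0, 12.5])
print(jax.vmap(one)(xs))
# Do NOT run jax.vmap(lambda v: one(v, guarded=False))(xs): the x == 0 lane is
# also evaluated in large_branch, its loop condition never becomes False, and
# the batched while_loop never terminates.
```

With the guard, the lanes give -1 (small branch), 5 steps, and 8 steps. Called without vmap, even the unguarded version should return -1 for x = 0, since the predicate is then unbatched and lax.cond dispatches to a genuine conditional, which matches the observation that non-vmapped calls were fine.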

Escalating this issue to the JAX developers might also be a good idea, and I will do so soon.
Edit: This is now reported as jax issue 2947, to see whether they have a better explanation and possibly a nicer solution.

lumip added a commit to lumip/numpyro that referenced this issue Apr 30, 2020
fehiepsi pushed a commit that referenced this issue May 5, 2020
* Tighter conditions for loops in _poisson sampling

fixes issue #582

* fixing linting issues
@lumip
Contributor Author

lumip commented May 11, 2020

I guess this can be closed now, unless you want to wait for results on the related jax issue?

@fehiepsi
Member

Thanks, @lumip!
