Predictive distribution fails on model with lax.scan #566

Closed
neerajprad opened this issue Apr 7, 2020 · 8 comments · Fixed by #595
Labels
bug Something isn't working

Comments

@neerajprad
Member

Refer to the forum discussion.

The following snippet:

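```python
import jax
import jax.numpy as np
from jax.random import PRNGKey
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive


def target(T=10, q=1., r=1., phi=0.5, beta=0.5):

    def transition(state, xs):
        # Each scan step receives its index and its own PRNG key.
        i, key = xs
        key1, key2 = jax.random.split(key)
        x0, mu0 = state
        # These sample sites run while lax.scan traces the body,
        # so `i` and the sampled values are JAX tracers here.
        x1 = numpyro.sample(f'x_{i}', dist.Normal(phi * x0, q), rng_key=key1)
        mu1 = beta * mu0 + x1
        y1 = numpyro.sample(f'y_{i}', dist.Normal(mu1, r), rng_key=key2)
        return (x1, mu1), (x1, y1)

    mu0 = x0 = numpyro.sample('x_0', dist.Normal(0, q))
    y0 = numpyro.sample('y_0', dist.Normal(mu0, r))

    key = numpyro.sample('key', dist.PRNGIdentity())
    _, xy = jax.lax.scan(transition, (x0, mu0), (np.arange(1, T), jax.random.split(key, T - 1)))
    x, y = xy

    return np.append(x0, x), np.append(y0, y)


# Draw 10 samples from the prior (no posterior samples are given).
prior = Predictive(target, posterior_samples={}, num_samples=10)
prior_samples = prior(PRNGKey(2), T=10)
```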

fails with:

```
UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state.
Details: Can't lift level Traced<ShapedArray(float32[]):JaxprTrace(level=1/0)> to JaxprTrace(level=0/0).
```
@neerajprad neerajprad added the bug Something isn't working label Apr 7, 2020
@fehiepsi
Member

fehiepsi commented Apr 7, 2020

@neerajprad I think we can't trace this function:

```python
numpyro.handlers.trace(numpyro.handlers.seed(target, 0)).get_trace()
```

@neerajprad
Member Author

Hmm, yes, that makes sense. Unfortunately, we can't really collect the x_i, y_i samples this way, so this doesn't work. I would like to keep this issue open to explore better workarounds for use cases where we need a JAX control flow primitive inside a model. I think the only workaround right now is to make sure that the body of the control flow doesn't contain any NumPyro primitives, but I wonder if there's a better solution.
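For illustration, a minimal sketch of that workaround for the `target` model above, written in plain JAX rather than NumPyro. All randomness is drawn before the loop as standard-normal noise (a reparameterization we are assuming is acceptable for this model), so the `lax.scan` body is purely deterministic. The helper name `simulate` is hypothetical; in a NumPyro model, the noise and the `y` observations would presumably become vectorized sample sites placed outside the scan.

```python
import jax
import jax.numpy as jnp


def simulate(key, T=10, q=1.0, r=1.0, phi=0.5, beta=0.5):
    """Hypothetical helper: sample (x, y) trajectories of length T.

    All noise is drawn up front, so the scan body contains no sampling.
    """
    key_x, key_y = jax.random.split(key)
    # Standard-normal noise for every time step, drawn outside the loop.
    eps_x = jax.random.normal(key_x, (T,))
    eps_y = jax.random.normal(key_y, (T,))

    # Initial state: x_0 ~ Normal(0, q) and mu_0 = x_0.
    x0 = q * eps_x[0]
    mu0 = x0

    def transition(state, eps):
        x_prev, mu_prev = state
        # x_t ~ Normal(phi * x_{t-1}, q), written as a reparameterized draw.
        x_t = phi * x_prev + q * eps
        mu_t = beta * mu_prev + x_t
        return (x_t, mu_t), (x_t, mu_t)

    _, (xs, mus) = jax.lax.scan(transition, (x0, mu0), eps_x[1:])
    x = jnp.concatenate([x0[None], xs])
    mu = jnp.concatenate([mu0[None], mus])
    # y_t ~ Normal(mu_t, r), computed vectorized after the loop.
    y = mu + r * eps_y
    return x, y
```

The trade-off is that each time step is no longer its own named site (`x_{i}`); the per-step values come out as one stacked array, which handlers outside the loop can record normally.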

@fehiepsi
Member

fehiepsi commented Apr 7, 2020

I think we can follow your suggestion and rewrite scan so that it takes the RNG key as part of the carry and moves the drawn samples into the collected outputs. For log_density, given the collection of samples, we can use an index i in the carry to select the corresponding samples at each scan step. I think a similar approach can be used for fori_loop.
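A minimal sketch of that two-pass idea in plain JAX, not NumPyro's eventual implementation. It assumes a simplified model (only the `x` chain with Gaussian transitions), and the helper names `sample_chain` and `chain_log_density` are hypothetical. The sampling pass threads the PRNG key through the scan carry and emits each draw as a scan output, so the samples end up stacked in one array; the density pass takes that array and uses a step index `i` in the carry to select the sample for each step.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def sample_chain(key, x0, T, phi=0.5, q=1.0):
    """Sampling pass: the PRNG key lives in the carry, draws are scan outputs."""
    x0 = jnp.asarray(x0, dtype=jnp.float32)

    def body(carry, _):
        key, x_prev = carry
        key, subkey = jax.random.split(key)
        x_t = phi * x_prev + q * jax.random.normal(subkey)
        return (key, x_t), x_t  # emit x_t so scan stacks it into the output

    _, xs = jax.lax.scan(body, (key, x0), None, length=T)
    return xs  # shape (T,)


def chain_log_density(xs, x0, phi=0.5, q=1.0):
    """Density pass: an index i in the carry selects the stored sample per step."""
    x0 = jnp.asarray(x0, dtype=jnp.float32)

    def body(carry, _):
        i, x_prev, total = carry
        x_t = xs[i]  # the sample recorded for step i
        total = total + norm.logpdf(x_t, phi * x_prev, q)
        return (i + 1, x_t, total), None

    init = (jnp.int32(0), x0, jnp.zeros((), dtype=jnp.float32))
    (_, _, total), _ = jax.lax.scan(body, init, None, length=xs.shape[0])
    return total
```

Because the key and the draws are ordinary arrays, code outside the loop only ever sees one stacked array of samples. A `fori_loop` version would presumably follow the same pattern, writing each draw into a preallocated buffer instead of a scan output.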

@ahmadsalim
Contributor

I am trying to port the DMM example to NumPyro and am running into the same issue when using SVI. I really support the idea of making NumPyro primitives work within JAX loops 😄

@neerajprad
Member Author

@ahmadsalim - we agree that this will be important to have. Making effect handlers work generically within JAX control flow primitives might be hard, but there may be a way to rewrite some version of these primitives so that writing such loops and collecting samples is easier, as @fehiepsi mentioned above. We will look into it.

@eb8680
Member

eb8680 commented May 13, 2020

It looks like JAX has a new experimental library, loops, that might play nicely with effect handlers: https://github.com/google/jax/blob/master/jax/experimental/loops.py

@neerajprad
Member Author

Thanks for pointing this out, @eb8680! I'll definitely take a look at this when I find time. It would be really cool if effect handlers worked with this.

@fehiepsi
Member

Looks interesting! I have been thinking about this issue for a while. I think we can mimic its ideas: rewrite the contents of a loops.Scope() block into lax.control_flow primitives.

I think I can come up with a solution for loops with simple bodies. Let's see how it goes. :) (I'll be pretty excited if it works)

@fehiepsi fehiepsi mentioned this issue May 15, 2020
4 participants