Predictive distribution fails on model with lax.scan #566
Comments
@neerajprad I think we can't trace this function: numpyro.handlers.trace(numpyro.handlers.seed(target, 0)).get_trace()
Hmm... yes, that makes sense. We can't really collect
I think that we can follow your direction to rewrite
I am trying to port the DMM example to NumPyro and am running into the same issue when using SVI. I really support the idea of making NumPyro primitives work within JAX loops 😄
@ahmadsalim - we agree that this will be important to have. Making effect handlers work generically within JAX control flow primitives might be hard, but there may be a way to rewrite some version of these control flow primitives so that writing such loops and collecting samples becomes easier, as @fehiepsi mentioned above. We will look into it.
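To illustrate why collecting values inside a JAX loop is hard, here is a minimal sketch in plain JAX, not NumPyro code and not the snippet from this issue. The `recorded` list and the `step` function are hypothetical stand-ins: `recorded` plays the role of a trace handler that tries to note every value the loop body sees. Because `lax.scan` traces the body rather than running it as Python on every iteration, the Python-side record ends up with fewer entries than there are steps.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-in for a trace handler: a plain Python list that
# tries to record the carry seen at every step of the loop.
recorded = []


def step(carry, x):
    # Python side effect: this line executes only while JAX traces the
    # body, not once per iteration of the resulting loop.
    recorded.append(carry)
    new_carry = carry + x
    # The second return value is stacked across iterations by lax.scan.
    return new_carry, new_carry


xs = jnp.arange(5.0)
final, ys = lax.scan(step, jnp.zeros(()), xs)

print("loop iterations:     ", xs.shape[0])
print("python-side records: ", len(recorded))
print("stacked scan outputs:", ys)
```

In this sketch, the per-step values are only reliably available through the stacked outputs `ys` that the body returns explicitly. That is consistent with the direction discussed above: some rewritten form of the control flow primitive would have to carry the collected samples out of the loop, since a handler observing Python-side effects cannot see each iteration.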
It looks like JAX has a new experimental library
Thanks for pointing this out, @eb8680! I'll definitely take a look at this as I find time. It would be really cool if effect handlers worked with this.
Looks interesting! I have been thinking about this issue for a while. I think we can mimic its ideas: rewrite content inside I think I can come up with a solution for a loop with simple content. Let's see how it goes. :) (I'll be pretty excited if it works)
Refer to the forum discussion
The following snippet:
fails with: