How to use vmap in the context of a model? #1684
I think this has something to do with validation.
raises the above error. @pierreglaser do you want to take a look?
The problem comes from
Hi @pierreglaser, setting
Thank you both for the fast response! Happy to wait for an update, but in the meantime is there by any chance some kind of workaround? E.g. might there be some kind of magic function like
Or something like that?
@justindomke a couple of clarifications regarding

First, let me clarify your claim

Keeping that in mind, your first example:

assumes that d, a

Instead, a usage of

```python
import numpyro
import jax
from jax import numpy as jnp
from numpyro.distributions.batch_util import vmap_over

dist = numpyro.distributions
loc = jnp.array([0, 1])
scale = jnp.array([0, 1])
# A "vmapped" Normal: each parameter leaf carries a leading axis of size 2.
v_dist = jax.vmap(dist.Normal, in_axes=(0, 0))(loc, scale)
key = jax.random.PRNGKey(0)
# vmap_over builds in_axes for the distribution pytree: map over loc and scale.
samples = jax.vmap(
    dist.Normal.sample,
    in_axes=(vmap_over(v_dist, loc=0, scale=0), 0),
)(v_dist, jax.random.split(key))
```

Obviously, this use of

Similarly, your second example

```python
import numpyro
import jax
from jax import numpy as jnp

dist = numpyro.distributions
loc = jnp.array([0, 1])
scale = jnp.array([0, 1])

def model():
    # The vmapped distribution is passed straight to numpyro.sample.
    numpyro.sample("x", jax.vmap(dist.Normal, in_axes=(0, 0))(loc, scale))

kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(kernel, num_warmup=250, num_samples=750)
key = jax.random.PRNGKey(0)
mcmc.run(key)
```

makes a similar substitutability assumption between

```python
import numpyro
import jax
from jax import numpy as jnp

dist = numpyro.distributions

def run_mcmc(d, key):
    # Inside the vmapped call, d is used as a standard (unmapped) distribution.
    def model():
        numpyro.sample("x", d)
    kernel = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=250, num_samples=750)
    mcmc.run(key)

# Scales must be strictly positive; a zero scale gives an improper density for NUTS.
v_dist = jax.vmap(dist.Normal, in_axes=(0, 0))(jnp.zeros((2,)), jnp.ones((2,)))
key = jax.random.PRNGKey(0)
jax.vmap(run_mcmc, in_axes=(0, 0))(v_dist, jax.random.split(key))
```

Notice how the vmapped distribution is passed as an argument to a transformed version of a function expecting a standard distribution; it is not passed directly to

```python
import numpyro
import jax
from jax import numpy as jnp

dist = numpyro.distributions
loc = jnp.array([0, 1])
scale = jnp.array([1.0, 2.0])  # strictly positive scales

def run_mcmc(key, loc, scale):
    # Without vmapped distributions, the raw parameters are threaded through instead.
    def model():
        numpyro.sample("x", dist.Normal(loc, scale))
    kernel = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=250, num_samples=750)
    mcmc.run(key)

key = jax.random.PRNGKey(0)
jax.vmap(run_mcmc, in_axes=(0, 0, 0))(jax.random.split(key), loc, scale)
```

which is much clumsier, which is why this feature is beneficial. Finally, it seems that the last 2 code snippets (using
@pierreglaser thanks very much for the clarification! Regarding the second example, I see how your fix works, but I was hoping to address a more general case where there are upstream random variables. A more realistic example would be something like this:

```python
import numpyro
import jax
from jax import numpy as jnp

dist = numpyro.distributions
scale = jnp.array([0, 1])

def model():
    z = numpyro.sample("z", dist.Normal(0, 1))
    # x depends on the upstream variable z, which is broadcast (in_axes=None), not mapped.
    x = numpyro.sample("x", jax.vmap(dist.Normal, in_axes=(None, 0))(z, scale))

kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(kernel, num_warmup=250, num_samples=750)
key = jax.random.PRNGKey(0)
mcmc.run(key)
```

Here the two different dimensions of

I know that
> ```python
> scale = jnp.array([0, 1])
> def model():
>     z = numpyro.sample("z", dist.Normal(0, 1))
>     x = numpyro.sample("x", jax.vmap(dist.Normal, in_axes=(None, 0))(z, scale))
> ```

So far it is. In general, I don't think arbitrary substitutability between
Of course, this particular example is easy to do with plate or auto broadcasting (or even a for loop), but there are more complex cases that are quite tricky and

If you're interested, the real reason I was asking is that I'm developing another PPL (https://github.com/justindomke/pangolin) that supports multiple inference backends. I already have pretty full support for numpyro as a backend with arbitrary uses of vmap in defining models, but I do it by sort of building my own potential function that vmaps over calls to things like

In any case, I think I have my answer (I can't do what I want), but perhaps what I'm doing above is only possible because of these enhancements, so thanks for making them! And if this surfaced some bugs in the intended vmap functionality, I'm glad.
I think that it would be useful if I could see an example where using
Interesting! I'd be interested to see this feature in action in some code snippet, if you can point me to one.
By the way, that whole
I was excited to see that numpyro now supports vmap, but I'm a little confused about what exactly can be done with it. When I do this:

Then it works fine: I get a sample from a standard normal in two dimensions. However, when I try to do this:

Then I get

```
TypeError: '>' not supported between instances of 'object' and 'float'
```

which seems to be an error in `_GreaterThan.__call__` in `numpyro/distributions/constraints.py`. Am I using `vmap` incorrectly here? Is there any other way that I can `vmap` distributions and then build numpyro models that I can do MCMC on, etc.?