How to use vmap in the context of a model? #1684

Closed
justindomke opened this issue Nov 21, 2023 · 10 comments · Fixed by #1686
Labels
bug Something isn't working

Comments

@justindomke

I was excited to see that numpyro now supports vmap. But I'm a little confused about what exactly can be done with it. When I do this:

import numpyro
import jax
from jax import numpy as jnp
dist = numpyro.distributions

loc = jnp.array([0,1])
scale = jnp.array([0,1])
d = jax.vmap(dist.Normal, in_axes=(0,0))(loc,scale)
key = jax.random.PRNGKey(0)
d.sample(key)

Then it works fine—I get a sample from a standard normal in two dimensions. However, when I try to do this:

import numpyro
import jax
from jax import numpy as jnp
dist = numpyro.distributions

loc = jnp.array([0,1])
scale = jnp.array([0,1])
def model():
    numpyro.sample("x",jax.vmap(dist.Normal, in_axes=(0,0))(loc,scale))

kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(kernel, num_warmup=250, num_samples=750)
key = jax.random.PRNGKey(0)
mcmc.run(key)

Then I get TypeError: '>' not supported between instances of 'object' and 'float', which seems to come from _GreaterThan.__call__ in numpyro/distributions/constraints.py.

Am I using vmap incorrectly here? Is there any other way that I can vmap distributions and then build numpyro models that I can do MCMC on, etc.?

@fehiepsi added the bug (Something isn't working) label Nov 22, 2023
@fehiepsi
Member

I think this has something to do with validation.

with dist.distribution.validation_enabled():
    jax.vmap(dist.Normal)(loc, scale)

raises the above error.

@pierreglaser do you want to take a look?

@pierreglaser
Contributor

The problem comes from tree_unflatten calling Distribution.__init__ when Distribution._validate_args is set to True, a code path that is not tested right now, hence the bug. @fehiepsi, WDYT about turning off validation in tree_unflatten? This could be done by skipping the Distribution.__init__ call, or by creating a validation_disabled context manager that overrides any enclosing validation instructions.

@fehiepsi
Copy link
Member

Hi @pierreglaser, setting validate_args=False in that init call makes sense to me. We might want to override _validate_args after that call to respect the global value.
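
For concreteness, a hedged sketch of that idea (illustrative only; the helper name is made up and this is not necessarily the patch that landed in #1686): build the distribution without argument checks, then fall back to the class-level flag so the global setting is still respected.

import numpyro.distributions as dist

def rebuild_without_validation(loc, scale):
    # Skip arg checks, which would otherwise run on tracer values under vmap ...
    d = dist.Normal(loc, scale, validate_args=False)
    # ... then restore the class-level flag so a global enable_validation() is still honored.
    d._validate_args = dist.Normal._validate_args
    return d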

@justindomke
Author

Thank you both for the fast response! Happy to wait for an update, but in the meantime is there by any chance some kind of workaround? E.g. might there be some kind of magic function like numpyro_config that would allow me to do this?

def model():
    with numpyro_config(validate_args=False):
        numpyro.sample("x",jax.vmap(dist.Normal, in_axes=(0,0))(loc,scale))

Or something like that?

@pierreglaser
Contributor

pierreglaser commented Nov 22, 2023

@justindomke a couple clarifications regarding numpyro and vmap.

First, let me clarify your claim that numpyro now supports vmap: the only thing the recent feature added by #1529 does is make Distribution objects compliant pytrees, with all the benefits that come with that. In particular, it means that Distribution objects can be used as inputs and outputs of functions transformed by jax program transformations such as jit, jvp, and, in particular, vmap. However, this feature does not let you plug a vmap-ed Distribution object into any place where a standard distribution is expected; that is outside the scope of the vmap model.

Keeping that in mind, your first example:

import numpyro
import jax
from jax import numpy as jnp
dist = numpyro.distributions

loc = jnp.array([0,1])
scale = jnp.array([0,1])
d = jax.vmap(dist.Normal, in_axes=(0,0))(loc,scale)
key = jax.random.PRNGKey(0)
d.sample(key)

assumes that d, a vmap-ed Distribution object, implements a sample method with the same semantics as the sample method of a non-vmap-ed Distribution. But as explained above, this call is not guaranteed to work as you expect: under the vmap model, a vmap-ed object v_obj of a class Obj defining a method Obj.f(self) does not support calls of the form v_obj.f(). Instead, it supports calls of the form vmap(Obj.f)(v_obj).
Thus, although your first code snippet gives you a result, that is purely coincidental and not guaranteed for more elaborate mapping axes or other Distribution classes.

Instead, a use of vmap and numpyro with guaranteed support that achieves similar functionality to your first example would be:

import numpyro
import jax
from jax import numpy as jnp
from numpyro.distributions.batch_util import vmap_over
dist = numpyro.distributions


loc = jnp.array([0,1])
scale = jnp.array([0,1])
v_dist = jax.vmap(dist.Normal, in_axes=(0, 0))(loc, scale)

key = jax.random.PRNGKey(0)
samples = jax.vmap(dist.Normal.sample, in_axes=(vmap_over(v_dist, loc=0, scale=0), 0))(v_dist, jax.random.split(key))

Obviously, this use of numpyro and vmap is quite simple. More complicated settings leveraging that feature include batched MCMC sampling of multiple intractable Distribution instances, as described at length in this thread.

Similarly, your second example

import numpyro
import jax
from jax import numpy as jnp
dist = numpyro.distributions

loc = jnp.array([0,1])
scale = jnp.array([0,1])
def model():
    numpyro.sample("x",jax.vmap(dist.Normal, in_axes=(0,0))(loc,scale))

kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(kernel, num_warmup=250, num_samples=750)
key = jax.random.PRNGKey(0)
mcmc.run(key)

makes a similar substitutability assumption between vmap-ed and non-vmap-ed Distribution objects, and is thus likely to error out. A snippet that achieves similar functionality and leverages the recently introduced Distribution-as-pytree feature is:

import numpyro
import jax
from jax import numpy as jnp
dist = numpyro.distributions


def run_mcmc(d, key):
    def model():
        numpyro.sample("x", d)
    kernel = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=250, num_samples=750)
    mcmc.run(key)

v_dist = jax.vmap(dist.Normal, in_axes=(0, 0))(jnp.zeros((2,)), jnp.ones((2,)))
key = jax.random.PRNGKey(0)
jax.vmap(run_mcmc, in_axes=(0, 0))(v_dist, jax.random.split(key))

Notice how the vmap-ed distribution is passed as an argument to a transformed version of a function expecting a standard distribution; it is not passed directly to run_mcmc. Without this feature, one could not have used such a function signature, since Distribution objects were not compliant pytrees and thus could not be passed to functions transformed by jax. Instead, one would have had to unpack the Distribution's individual arguments:

import numpyro
import jax
from jax import numpy as jnp
dist = numpyro.distributions


loc = jnp.array([0,1])
scale = jnp.array([1,2])
def run_mcmc(key, loc, scale):
    def model():
        numpyro.sample("x",dist.Normal(loc, scale))
    kernel = numpyro.infer.NUTS(model)
    mcmc = numpyro.infer.MCMC(kernel, num_warmup=250, num_samples=750)
    mcmc.run(key)

key = jax.random.PRNGKey(0)
jax.vmap(run_mcmc, in_axes=(0, 0, 0))(jax.random.split(key), loc, scale)

which is much clumsier; that is why this feature is beneficial.

Finally, it seems that the last two code snippets (using MCMC with or without #1529) actually error out 😄. We should fix that @fehiepsi. However, the reason seems orthogonal to #1529.

@justindomke
Author

justindomke commented Nov 22, 2023

@pierreglaser thanks very much for the clarification! Regarding the second example, I see how your fix works, but I was hoping to be able to address a more general case where there are upstream random variables. A more realistic example would be something like this:

import numpyro
import jax
from jax import numpy as jnp
dist = numpyro.distributions

scale = jnp.array([0,1])
def model():
    z = numpyro.sample("z",dist.Normal(0,1))
    x = numpyro.sample("x",jax.vmap(dist.Normal, in_axes=(None,0))(z,scale))
kernel = numpyro.infer.NUTS(model)
mcmc = numpyro.infer.MCMC(kernel, num_warmup=250, num_samples=750)
key = jax.random.PRNGKey(0)
mcmc.run(key)

Here the two different dimensions of x are tied together through the latent variable z. So doing a vmap over the full run of MCMC wouldn't be helpful here.

I know that plate can be used for things like this, but vmap is more powerful so I was hoping to be able to use it instead. But if this use-case is totally outside of what's intended to work, that's understandable!

@pierreglaser
Contributor

scale = jnp.array([0,1])
def model():
    z = numpyro.sample("z",dist.Normal(0,1))
    x = numpyro.sample("x",jax.vmap(dist.Normal, in_axes=(None,0))(z,scale))

plate may work, or even just using dist.Normal(z, scale) and letting numpyro's shape broadcasting mechanism do its thing, right?
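
For reference, a minimal sketch of that broadcasting-based version (with an illustrative positive scale):

import numpyro
import jax
from jax import numpy as jnp
dist = numpyro.distributions

scale = jnp.array([1.0, 2.0])  # illustrative positive scales
def model():
    z = numpyro.sample("z", dist.Normal(0, 1))
    # scalar z broadcasts against scale of shape (2,), so x gets batch shape (2,)
    numpyro.sample("x", dist.Normal(z, scale))

mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), num_warmup=250, num_samples=750)
mcmc.run(jax.random.PRNGKey(0))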

But if this use-case is totally outside of what's intended to work, that's understandable!

So far it is. In general, I don't think arbitrary substitutability between Distributions and vmap-ed Distributions is possible, or at least not without a lot of auto-magic heavy lifting that goes beyond the scope within which jax advertises the use of vmap.

@justindomke
Author

Of course, this particular example is easy to do with plate or auto broadcasting (or even a for loop), but there are more complex cases that are quite tricky and where jax.vmap would be helpful.

If you're interested, the real reason I was asking is that I'm developing another PPL (https://github.com/justindomke/pangolin) that supports multiple inference backends. I already have pretty full support for numpyro as a backend with arbitrary uses of vmap in defining models, but I do it by sort of building my own potential function that vmaps over calls to things like Normal(loc,scale).log_prob(z) and so on, and then calling numpyro's inference routines on that potential function, rather than using numpyro's abstraction of "models". This works fine, but I do lose some of numpyro's goodies, notably the clever initialization schemes and the automatic "unconstraining" transforms.
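
A rough sketch of that potential-function pattern, for readers (variable names are illustrative and this is not pangolin's actual code; it assumes NUTS is given a potential_fn and explicit init_params):

import numpyro
import jax
from jax import numpy as jnp
dist = numpyro.distributions

scale = jnp.array([1.0, 2.0])  # illustrative positive scales

def potential_fn(params):
    z, x = params["z"], params["x"]
    lp = dist.Normal(0.0, 1.0).log_prob(z)
    # vmap the per-component log-density of x given the shared latent z
    lp = lp + jax.vmap(lambda xi, si: dist.Normal(z, si).log_prob(xi))(x, scale).sum()
    return -lp  # the potential is the negative joint log-density

kernel = numpyro.infer.NUTS(potential_fn=potential_fn)
mcmc = numpyro.infer.MCMC(kernel, num_warmup=250, num_samples=750)
mcmc.run(jax.random.PRNGKey(0), init_params={"z": jnp.zeros(()), "x": jnp.zeros(2)})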

In any case, I think I have my answer (I can't do what I want) but perhaps what I'm doing above is only possible because of these enhancements, so thanks for making them! And if this surfaced some bugs in the intended vmap functionality, I'm glad.

@pierreglaser
Contributor

Of course, this particular example is easy to do with plate or auto broadcasting (or even a for loop) but there are more complex cases that are quite tricky and jax.vmap would be helpful
This works fine, but I do lose some of numpyro's goodies, notably the clever initialization schemes and the automatic "unconstraining" transforms.

I think it would be useful if I could see an example where using jax.vmap, rather than numpyro's batching-based semantics, increases code simplicity.

I already have pretty full support for numpyro as a backend with arbitrary uses of vmap in defining models

Interesting! I'd be interested to see this feature in action in some code snippet, if you can point me to one.

@pierreglaser
Contributor

By the way, that whole validate_args class attribute handling is not thread-safe @fehiepsi. We might want to change the mechanism to have it use thread-local variables.
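
A minimal sketch of the thread-local idea (illustrative only, not numpyro's actual mechanism):

import threading

_validation_state = threading.local()

def set_validate_args(value: bool):
    # Each thread flips only its own flag instead of a shared class attribute.
    _validation_state.validate_args = value

def get_validate_args(default: bool = False) -> bool:
    return getattr(_validation_state, "validate_args", default)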
