Make Distribution pytree work under jax.vmap #1317
Hi @nicolasemmenegger, thanks for the clear description! I'm not sure how to make distributions robust w.r.t. jax transforms. Currently we convert distributions to pytrees to be able to scan over them. For vmap, we need to redesign the distribution construction, which is a bit involved (see jax-ml/jax#3265). How about using tensorflow probability distributions? I think vmap should work for tfd.Independent(tfd.Normal).
Hi @fehiepsi, and thanks for the quick answer. Unfortunately, tensorflow probability does not support this either: "Because TFP distributions are registered as JAX pytree nodes, we can write functions with distributions as inputs or outputs and transform them using jit, but they are not yet supported as arguments to vmap-ed functions". It turns out there is a very similar open issue, tensorflow/probability#1271, there as well. I'd be happy to help with making numpyro distributions robust to vmap transforms, but I'd need some guidance.
I don't have many ideas right now. One approach, following the tfp thread, is to return your distribution's flattened form rather than the distribution itself, and then call the distribution's unflatten on the values returned by vmap. I guess this will work for your case and many other distributions, though there might be some edge cases. You could write a "vmap" helper to do such a job for you: given a function, it modifies the function to return flattened values (together with the tree definition, using jax.tree_util.tree_flatten), applies vmap over that modified function, and finally unflattens the output. If it works, it would be great to have that helper in numpyro.util. :)
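A minimal sketch of this manual recipe (the function name and shapes below are made up for illustration): flatten the distribution inside the vmapped function, and rebuild it from the batched leaves afterwards.

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

def make_posterior(mean):
    # per-example distribution with event_shape == mean.shape
    return dist.Normal(mean, jnp.ones_like(mean)).to_event(1)

def make_posterior_flat(mean):
    # return only the leaves; the tree structure is static and identical
    # across the batch, so it can be recovered outside of vmap
    return jax.tree_util.tree_flatten(make_posterior(mean))[0]

means = jnp.zeros((5, 3))
treedef = jax.tree_util.tree_structure(make_posterior(means[0]))
batched_leaves = jax.vmap(make_posterior_flat)(means)
batched_dist = jax.tree_util.tree_unflatten(treedef, batched_leaves)
print(batched_dist.batch_shape, batched_dist.event_shape)  # (5,) (3,)
```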
Thanks for the suggestion, this is cleaner than my own workaround! So I implemented a version that could potentially be generalized more.

Some of this could be done in more generality, for instance by supporting outputs that are pytrees possibly containing distributions. The main question is whether vmapping over the flattened parameters of the distribution while keeping the auxiliary data is always the correct semantics for vmap. In my case it does what I want, since it's equivalent to adding a batch dimension, but I wouldn't know what happens in other cases. Let me know what you think!
I don't think it will work for all distributions, but it is nice to cover many of them. I think we just need a simple helper.

Your implementation looks great to me. The usage of DIST_ID is a bit annoying, but I can't think of a better solution.
Hi, sorry for the delayed response.
One use-case I have is that I vmap over functions whose outputs contain distributions. These semantics make sense in my case. What do you think? The main issue I see so far is what you mention, namely that one could specify an out_axes argument that goes "inside" a Distribution leaf, since the vmap is applied to a tree that has more leaves. It's not clear to me what happens then. I am not sure if and how I should raise an error in that case, since I can only check for this inside the vmapped function, which I think is supposed to not have any side effects. So for now it is undefined behavior.
Also, what would the expected behaviour be when vmapping a Distribution with out_axes not equal to 0?
I guess, when flattening, we can do something that turns the distribution type into a dict with value None (vmap over None is still None).
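A toy illustration of that idea, using a made-up class rather than numpyro's actual Distribution: the class is carried as the key of a one-entry dict whose value is None, so it contributes no array leaves and vmap passes it through untouched.

```python
import jax
import jax.numpy as jnp

class ToyNormal:  # hypothetical stand-in for a distribution class
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

def _flatten(d):
    # the {type(d): None} child has no leaves, so vmap leaves it unchanged
    return (d.loc, d.scale, {type(d): None}), None

def _unflatten(aux, children):
    loc, scale, type_tag = children
    cls = next(iter(type_tag))  # recover the class from the dict key
    return cls(loc, scale)

jax.tree_util.register_pytree_node(ToyNormal, _flatten, _unflatten)

batched = jax.vmap(lambda m: ToyNormal(m, jnp.ones_like(m)))(jnp.zeros((4, 3)))
print(type(batched).__name__, batched.loc.shape)  # ToyNormal (4, 3)
```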
Sounds reasonable to me. Ideally, I think a nice API would be a vmap-like helper exposed by numpyro that handles this transparently. What do you think?
Hi! I'm currently working on a project where a lot of workarounds are needed to overcome the lack of interoperability between numpyro distributions and jax.vmap.

```python
print(simple_dist.batch_shape, "&", simple_dist.event_shape)    # prints: () & (8,1)
print(batched_dist.batch_shape, "&", batched_dist.event_shape)  # Should print: (3,) & (8,1)
```

What I propose is a more minimal interoperability spec that ensures that vmapping a function returning a distribution yields a distribution whose batch_shape picks up the mapped dimension, as in the second print above. Long term it'd be nice to have automagic batch/event shape adaptation, but IMO that's way beyond the minimum needed here. WDYT @fehiepsi?
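For concreteness, one construction consistent with those shapes (this is an assumption, not the original code), with batched_dist standing for the result of vmapping the same construction over a leading axis:

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

def make_dist(mean):  # mean has shape (8, 1)
    return dist.Independent(dist.Normal(mean, jnp.ones_like(mean)), 2)

simple_dist = make_dist(jnp.zeros((8, 1)))
print(simple_dist.batch_shape, "&", simple_dist.event_shape)  # () & (8, 1)

# batched_dist would correspond to jax.vmap(make_dist)(jnp.zeros((3, 8, 1))),
# which is exactly the pattern this issue is about.
```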
As a workaround for Independent, I guess we can do something along these lines,
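One sketch of such a workaround, assuming the idea is to build the base Normal inside the vmapped function and wrap it with Independent only outside of the transform (the shapes here are illustrative):

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

def get_base(mean):
    # a plain Normal goes through vmap fine; only Independent's constructor complains
    return dist.Normal(mean, jnp.ones_like(mean))

means = jnp.zeros((5, 3))
base = jax.vmap(get_base)(means)                         # batch_shape (5, 3)
d = dist.Independent(base, reinterpreted_batch_ndims=1)  # wrap outside of vmap
print(d.batch_shape, d.event_shape)                      # (5,) (3,)
```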
but I'm not sure how to make vmap work for MultivariateNormal. Any idea?
I think we can follow @nicolasemmenegger's solution at #1317 (comment) to create a utility:

```python
from collections import namedtuple

import jax
import jax.numpy as jnp
import numpyro.distributions as dist

DistFlatten = namedtuple("DistFlatten", ["value"])

def maybe_flatten_dist(d):
    if not isinstance(d, dist.Distribution):
        return d
    d_flat, d_treedef = jax.tree_util.tree_flatten(d)
    # store the leaves under the (static) treedef key so the structure survives vmap
    return DistFlatten({d_treedef: d_flat})

def maybe_unflatten_dist(d):
    if not isinstance(d, DistFlatten):
        return d
    treedef, flat = list(d.value.items())[0]
    return jax.tree_util.tree_unflatten(treedef, flat)

def vmap(fn):
    def wrapped(*args, **kwargs):
        def fn_flat(*args, **kwargs):
            # replace any Distribution in the output by its flattened form
            return jax.tree_util.tree_map(
                maybe_flatten_dist,
                fn(*args, **kwargs),
                is_leaf=lambda x: isinstance(x, dist.Distribution),
            )

        out = jax.vmap(fn_flat)(*args, **kwargs)
        # rebuild the distributions from the batched leaves
        return jax.tree_util.tree_map(
            maybe_unflatten_dist, out, is_leaf=lambda x: isinstance(x, DistFlatten)
        )

    return wrapped

def f(x):
    return dist.MultivariateNormal(x, jnp.eye(3))

vmap(f)(jnp.ones((4, 3))).batch_shape
```
I'll give both a try, but personally I'd root for a solution that just makes jax.vmap work with numpyro distributions out of the box.
This can be closed now!

Thanks, @pierreglaser!
I have a function which returns a posterior distribution of type Independent with base_dist being Normal. I want to vmap over the function in order to get the posteriors for a number of different models. However, the assertion in the constructor of Independent hinders the usage of jax.vmap. Here is a toy example that illustrates this behaviour. To simplify things, the following function get_dist takes a mean and returns a standard Gaussian distribution with event_shape = mean.shape. I'd expect that vmapping over this yields a function which returns a distribution with the same event_shape, but with one more dimension in the batch_shape. Unfortunately, the following code does not work:
Now commenting out the assertion lines in the constructor of the Independent class makes the code output the desired results. I assume raising an error in this scenario isn't desired behaviour, since tree_unflatten and tree_flatten are implemented and vmapping should therefore be possible.

Thank you in advance.