Make Distribution pytree work under jax.vmap #1317

Closed
nicolasemmenegger opened this issue Feb 2, 2022 · 14 comments
Labels
help wanted Extra attention is needed

Comments

@nicolasemmenegger

nicolasemmenegger commented Feb 2, 2022

I have a function which returns a posterior distribution of type Independent with base_dist being Normal. I want to vmap over the function in order to get the posteriors for a number of different models. However, the assertion in the constructor of Independent hinders the usage of jax.vmap. Here is a toy example that illustrates this behaviour. To simplify things, the following function get_dist takes a mean, and returns a standard Gaussian distribution with event_shape = mean.shape. I'd expect that vmapping over this yields a function which returns a distribution with the same event_shape, but with one more dimension in the batch_shape. Unfortunately the following code does not work:

import jax
import jax.numpy as jnp
import numpyro.distributions

def get_dist(mean):
    return numpyro.distributions.Normal(mean, jnp.ones_like(mean)).to_event(mean.ndim)

get_dists = jax.vmap(get_dist)

key = jax.random.PRNGKey(42)
mean = jax.random.normal(key, (8, 1))
means = jax.random.normal(key, (3, 8, 1))

simple_dist = get_dist(mean)
batched_dist = get_dists(means)

print(simple_dist.batch_shape, "&", simple_dist.event_shape)  # prints: () & (8,1)
print(batched_dist.batch_shape, "&", batched_dist.event_shape)  # Should print: (3,) & (8,1)

Now commenting out the lines

if reinterpreted_batch_ndims > len(base_dist.batch_shape):
    raise ValueError(
        "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
        "actual {} vs {}".format(
            reinterpreted_batch_ndims, len(base_dist.batch_shape)
        )
    )

in the constructor of the Independent class makes the code output the desired results. I assume raising an error in this scenario isn't desired behaviour, since tree_unflatten and tree_flatten are implemented and vmapping should therefore be possible.

Thank you in advance.

@fehiepsi fehiepsi added the "help wanted" label Feb 3, 2022
@fehiepsi
Member

fehiepsi commented Feb 3, 2022

Hi @nicolasemmenegger, thanks for the clear description! I'm not sure how to make distributions robust w.r.t. jax transforms. Currently we convert distributions to pytree to be able to scan over them. For vmap, we need to redesign the distribution construction, which is a bit involved (see jax-ml/jax#3265). How about using tensorflow probability distributions? I think vmap should work for tfd.Independent(tfd.Normal).

@fehiepsi fehiepsi changed the title from "Independent constructor ValueError hinders usage with jax.vmap" to "Make Distribution pytree work under jax.vmap" Feb 11, 2022
@nicolasemmenegger
Author

Hi @fehiepsi and thanks for the quick answer. Unfortunately, tensorflow probability does not support this either: "Because TFP distributions are registered as JAX pytree nodes, we can write functions with distributions as inputs or outputs and transform them using jit, but they are not yet supported as arguments to vmap-ed functions". It turns out there is a very similar open issue tensorflow/probability#1271 there as well.

Interestingly, for plain numpyro.distributions.Normal (with trivial event_shape), the snippet above runs just fine. However, it fails for numpyro.distributions.MultivariateNormal, for instance, and of course for the example described initially. So I suppose it really depends on what happens inside each specific constructor. I'll have to work around it some other way for now.

I'd be happy to help with making numpyro distributions robust to vmap transforms, but I'd need some guidance.

@fehiepsi
Member

fehiepsi commented Feb 14, 2022

I don't have many ideas right now. I think one way is to follow an approach from the tfp thread: return your distribution.flatten() rather than the distribution itself, then call Independent.unflatten on the values returned by vmap. I guess it will work for your case and many other distributions, but there might be some edge cases. You can write a "vmap" helper to do this for you: given a function, it modifies the function to return flattened values (together with the tree definition, using jax.tree_util.tree_flatten), applies vmap over that modified function, and finally unflattens the output. If it works, it would be great to have that helper in numpyro.util. :)
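
A minimal sketch of the flatten, vmap, unflatten idea described above. It deliberately avoids NumPyro: ToyIndependentNormal, its flatten/unflatten methods, and vmap_dist are hypothetical stand-ins invented for illustration, and the constructor check only mimics the one quoted earlier in the thread. Only array leaves cross the vmap boundary, and the static metadata is captured on the side.

```python
import jax
import jax.numpy as jnp


class ToyIndependentNormal:
    """Hypothetical stand-in for Independent(Normal); not a NumPyro class."""

    def __init__(self, loc, event_ndims):
        # Mimics the constructor check quoted earlier in the thread.
        if event_ndims > jnp.ndim(loc):
            raise ValueError("event_ndims exceeds the number of available dims")
        self.loc = loc
        self.event_ndims = event_ndims

    @property
    def batch_shape(self):
        shape = jnp.shape(self.loc)
        return shape[: len(shape) - self.event_ndims]

    @property
    def event_shape(self):
        shape = jnp.shape(self.loc)
        return shape[len(shape) - self.event_ndims :]

    def flatten(self):
        # Array leaves first, static (non-array) metadata second.
        return (self.loc,), self.event_ndims

    @classmethod
    def unflatten(cls, aux, leaves):
        return cls(*leaves, aux)


def vmap_dist(fn):
    """vmap `fn`, which returns a ToyIndependentNormal, over the leading axis."""

    def wrapped(*args):
        static = {}

        def flat_fn(*inner_args):
            leaves, static["aux"] = fn(*inner_args).flatten()
            return leaves  # only arrays are returned through jax.vmap

        batched_leaves = jax.vmap(flat_fn)(*args)
        # Rebuild the object outside vmap, from the batched leaves.
        return ToyIndependentNormal.unflatten(static["aux"], batched_leaves)

    return wrapped


def get_dist(mean):
    return ToyIndependentNormal(mean, mean.ndim)


batched = vmap_dist(get_dist)(jnp.zeros((3, 8, 1)))
print(batched.batch_shape, "&", batched.event_shape)
```

Here the rebuilt object sees a leading dimension of size 3 in its leaves, so the extra axis lands in batch_shape while event_shape stays (8, 1). Whether the same holds for every real distribution is exactly the open question in this thread.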

@nicolasemmenegger
Author

nicolasemmenegger commented Feb 15, 2022

Thanks for the suggestion, this is cleaner than my own workaround! So I implemented a version that could potentially be generalized more. The main issue I was running into is that the aux data from tree_flatten may contain distribution types, and jax complains about these being in output of vmapped functions (even if you don't actually vmap over the axis where they are present), see e.g. jax-ml/jax#4416. This is what I came up with. Please let me know if there is something cleaner that could be done:

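from typing import Any, Callable

import jax
from numpyro.distributions import Distribution, Independent, Normal

DIST_ID = [Independent, Normal]  # TODO add others, see how they behave


def _auxiliary_to_jax_type(pytree):
    # Replace distribution classes in the aux data by their index in DIST_ID.
    converted = jax.tree_util.tree_map(lambda leaf: DIST_ID.index(leaf) if leaf in DIST_ID else leaf, pytree)
    # Remember which leaves were classes so they can be restored later.
    blueprint = jax.tree_util.tree_map(lambda leaf: leaf in DIST_ID, pytree)
    return converted, blueprint


def _restore_auxiliary(converted, blueprint):
    # jax.tree_util.tree_map accepts several trees (formerly jax.tree_multimap).
    return jax.tree_util.tree_map(
        lambda leaf, do_restore: DIST_ID[leaf] if do_restore else leaf,
        converted,
        blueprint,
    )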


def _flatten_dist(dist: Distribution):
    params, aux = dist.tree_flatten()
    return params, *_auxiliary_to_jax_type(aux), DIST_ID.index(dist.__class__)


def vmap_distribution(f: Callable[..., Distribution], in_axes=0, out_axes=0, axis_name=None):
    """Helper function that vmaps a function that returns a distribution as output."""
    flat_f = lambda *args, **kwargs: _flatten_dist(f(*args, **kwargs)) 
    
    vmapped = jax.vmap(
        flat_f,
        in_axes=in_axes,
        out_axes=(out_axes, None, None, None),
        axis_name=axis_name,
    )

    def unflattened(*args, **kwargs):
        vmapped_output, aux_converted, aux_blueprint, cls_id = vmapped(*args, **kwargs)
        cls = DIST_ID[cls_id]
        aux = _restore_auxiliary(aux_converted, aux_blueprint)
        return cls.tree_unflatten(aux, vmapped_output)

    return unflattened

_flatten_dist is a little helper that also returns the type of the flattened distribution (in the form of a number, so we know which tree_unflatten to call) and transforms the aux data into something that jax.vmap accepts (using _auxiliary_to_jax_type). Later on, this gets reversed with _restore_auxiliary.

Some of this could be done in more generality, for instance by supporting outputs that are pytrees possibly containing Distribution and something similar is probably doable for inputs as well. For my project I really need only this use case, and will only have to extend it to some more distribution types.

The main question is whether vmapping over the flattened parameters of the distribution but keeping the auxiliary data is always the correct semantics for vmap. In my case it does what I want, since it's equivalent to adding a batch dimension, but I wouldn't know what happens in other cases.

Let me know what you think!

@fehiepsi
Member

whether vmapping over the flattened parameters of the distribution but keeping the auxiliary data is always the correct semantics for vmap

I don't think it will work for all distributions but it is nice to cover many of them. I think we just need a simple vmap over the first dimension. Things are pretty complicated if we allow to specify out axes, which might add batch dimensions to event dimensions of the distributions.

Please let me know if there is something cleaner that could be done:

Your implementation looks great to me. The usage of DIST_ID is a bit annoying but I can't think of a better solution.

@nicolasemmenegger
Author

Hi, sorry for the delayed response.

I think we just need a simple vmap over the first dimension. Things are pretty complicated if we allow to specify out axes, which might add batch dimensions to event dimensions of the distributions.

One use-case I have is that I vmap over functions returned by haiku.transform_with_state. These pure functions return the original output, along with some model state. In one of my applications, the former is a distribution, while the latter is a regular nested container/pytree. Thus, it is nice if we can handle nested outputs possibly containing Distribution objects. I now have a version that performs the transform above over an arbitrary pytree output, but treats all appearances of Distribution objects as leaves of said pytree, and handles them as above. out_axes can be specified as in the original vmap, since the shape of the prefix tree stays the same. It works at least for Independent(Normal) and MultivariateNormal, which is all I need for now.
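
To illustrate what "treating Distribution objects as leaves" means, here is a small sketch using the is_leaf argument of jax.tree_util.tree_leaves. ToyDist is a hypothetical registered pytree class standing in for a distribution, and the output tuple only imitates a (distribution, state) pair; neither comes from NumPyro or Haiku.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyDist:
    """Hypothetical pytree node standing in for a Distribution."""

    def __init__(self, loc):
        self.loc = loc

    def tree_flatten(self):
        return (self.loc,), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


# Imitates the (output, state) pair returned by a pure function.
output = (ToyDist(jnp.zeros(2)), {"state": jnp.ones(3)})

# Default behaviour: JAX descends into ToyDist and sees only arrays.
array_leaves = jax.tree_util.tree_leaves(output)

# With is_leaf, the ToyDist object itself is kept as a single leaf.
dist_leaves = jax.tree_util.tree_leaves(
    output, is_leaf=lambda x: isinstance(x, ToyDist)
)
```

With the second form, each distribution leaf can be handled separately (for example, flattened by hand) while the rest of the output tree keeps its shape, which is why out_axes can still be specified as a prefix tree.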

These semantics make sense in my case. What do you think? The main issue I see so far is what you mention, namely that one could specify an out_axes argument that goes "inside" a Distribution leaf, since the vmap is applied to a tree that has more leaves. It's not clear to me what happens then. I am not sure if and how I should raise an error in that case, since I can only check this inside the vmapped function, which I think is not supposed to have any side effects. So for now it is undefined behavior.

@nicolasemmenegger
Author

Also, what would the expected behaviour be when vmapping a Distribution with out_axes not equal to 0?

@fehiepsi
Member

fehiepsi commented Mar 6, 2022

I guess out_axes=-1 for dist.Normal(0, 1).expand([10, 3]).to_event(1) means that the output will be dist.Normal(0, 1).expand([10, vmap, 3]).to_event(1), i.e. vmap dimension semantic only applies for batch dimensions. Regarding the auxiliary issue, maybe instead of

jax.vmap(lambda x: (x, "123"))(jnp.ones(3))

we can do

jax.vmap(lambda x: (x, {"123": None}))(jnp.ones(3))

i.e. turns distribution type into a dict with value None (vmap over None is still None).
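
A runnable sketch of this trick, assuming a string tag in place of an actual distribution type ("Normal" here is only an illustrative key). A dict whose only value is None contributes no array leaves, so jax.vmap has nothing to batch there and hands the dict back unchanged.

```python
import jax
import jax.numpy as jnp


def scaled_with_tag(x):
    # The tag travels as a dict key; its value None has no leaves to batch.
    return 2.0 * x, {"Normal": None}


values, tag = jax.vmap(scaled_with_tag)(jnp.ones(3))
(name,) = tag  # recover the static tag from the single dict key
```

After the call, name can be mapped back to a class through a lookup table, similar to what DIST_ID does earlier in the thread.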

a version that performs the transform above over an arbitrary pytree output, but treats all appearances of Distribution objects as leaves of said pytree, and handles them as above

Sounds reasonable to me. Regarding out_axes I think it will make sense when its "non-negative" dimension is less than or equal to the batch dimensions of the distribution. Applying vmap for event dimensions should raise errors but I guess there's no usage case for that.
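
One possible reading of this rule, written as plain shape arithmetic. Both helper names are hypothetical, and the convention that a negative out_axes counts from the end of the batch dimensions (not of the full shape) is an assumption taken from the out_axes=-1 example above.

```python
def resolve_out_axis(out_axes, batch_ndim):
    """Position of the vmapped axis among the output's batch dims.

    The output has batch_ndim + 1 batch dims, so valid positions are
    0..batch_ndim. Negative values count from the end of the batch dims.
    """
    position = out_axes + batch_ndim + 1 if out_axes < 0 else out_axes
    if not 0 <= position <= batch_ndim:
        raise ValueError(
            f"out_axes={out_axes} would land inside the event dimensions"
        )
    return position


def batched_batch_shape(batch_shape, axis_size, out_axes):
    """Batch shape after inserting the vmapped axis at the resolved position."""
    position = resolve_out_axis(out_axes, len(batch_shape))
    return batch_shape[:position] + (axis_size,) + batch_shape[position:]


# batch_shape (10,) and event_shape (3,), vmapped over an axis of size 5:
print(batched_batch_shape((10,), 5, -1))  # vmapped axis sits right before the event dims
print(batched_batch_shape((10,), 5, 0))   # vmapped axis becomes the leading batch dim
```

Under this convention the event shape is never touched, and any out_axes that would fall past the batch dimensions raises an error instead of silently reshaping event dimensions.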

Ideally, I think a nice api would be

def f(...):
    # we can update `.flatten` to make sure we can vmap over it
    # I think we just need to update a few distributions with base_dist for this
    return dist_instance.flatten()

out_flatten = jax.vmap(f, ...)(...)
out = unflatten_distribution_util(out_flatten)

What do you think?

@pierreglaser
Contributor

pierreglaser commented Jan 28, 2023

Hi! I'm currently working on a project where a lot of workarounds are needed to overcome the lack of interoperability between numpyro.distributions.Distribution objects and jax.vmap.
Reading this thread, I have a suggestion about how to make these interoperable in a conservative manner.
@nicolasemmenegger, you mention in your example:

print(simple_dist.batch_shape, "&", simple_dist.event_shape)  # prints: () & (8,1)
print(batched_dist.batch_shape, "&", batched_dist.event_shape)  # Should print: (3,) & (8,1)

this implies that Distribution objects should be aware of vmapping operations, and modify their metadata accordingly upon calling a vmapped function that returns a Distribution.

What I propose is a more minimal interoperability spec that ensures that jax.vmap(func) simply does not error out when func returns Distribution objects, while returning the same batch_shape and event_shape as the original distribution. It would be the responsibility of the user to correctly wrap any operations on the vmapped distribution pytree, such as .sample or .log_prob, using the appropriate in_axes arguments, so that these calls operate under the hood on "well-defined" distributions (i.e. distributions whose batch_shape and event_shape reflect the shape of their parameters, unlike vmapped distributions). This would require disabling the checks that currently error out when unflattening, which are not compatible with vmap transformations. A possible way would be to create some cls._init_nocheck that does what cls.__init__ does, but without the checks, and invoke it during unflattening routines.
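
A rough sketch of the _init_nocheck idea on a toy class. ToyIndependent and its fields are hypothetical and only imitate the shape check of Independent; the point is that the unflattening path builds the object through cls.__new__ and skips validation, while the public constructor keeps it.

```python
class ToyIndependent:
    """Hypothetical stand-in; it stores shapes only, no real distribution."""

    def __init__(self, base_batch_shape, reinterpreted_batch_ndims):
        # Public constructor: keeps the validation.
        if reinterpreted_batch_ndims > len(base_batch_shape):
            raise ValueError(
                "Expected reinterpreted_batch_ndims <= len(base_batch_shape), "
                f"actual {reinterpreted_batch_ndims} vs {len(base_batch_shape)}"
            )
        self._set_fields(base_batch_shape, reinterpreted_batch_ndims)

    def _set_fields(self, base_batch_shape, reinterpreted_batch_ndims):
        self.base_batch_shape = tuple(base_batch_shape)
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

    @classmethod
    def _init_nocheck(cls, base_batch_shape, reinterpreted_batch_ndims):
        # Same field assignment as __init__, but no validation.
        obj = cls.__new__(cls)
        obj._set_fields(base_batch_shape, reinterpreted_batch_ndims)
        return obj

    def tree_flatten(self):
        return (self.base_batch_shape,), self.reinterpreted_batch_ndims

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Unflattening may receive placeholder leaves, so skip the checks.
        return cls._init_nocheck(*children, aux_data)


# Unflattening with an "empty" placeholder shape no longer raises.
rebuilt = ToyIndependent.tree_unflatten(2, ((),))
```

The trade-off is the one described above: an object rebuilt this way may carry metadata that does not match its leaves, so the caller is responsible for using it consistently.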

Long term it'd be nice to have automagic batch/event shape adaptation, but IMO that's way above the minimum that numpyro has to do to ensure compatibility with vmap.

WDYT @fehiepsi?

@fehiepsi
Member

As a workaround for Independent, I guess we can do

    # in __init__
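    if validate_args is not False and reinterpreted_batch_ndims > len(base_dist.batch_shape):
        raise ValueError(
            "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), "
            "actual {} vs {}".format(reinterpreted_batch_ndims, len(base_dist.batch_shape))
        )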

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        base_cls, base_aux, reinterpreted_batch_ndims = aux_data
        base_dist = base_cls.tree_unflatten(base_aux, params)
        return cls(base_dist, reinterpreted_batch_ndims, validate_args=False)

but I'm not sure how to make vmap work for MultivariateNormal

def f(x):
    return dist.MultivariateNormal(x, jnp.eye(3))

d = jax.vmap(f, out_axes=1)(jnp.ones((2, 4, 3)))

Any idea?

@fehiepsi
Member

I think we can follow @nicolasemmenegger solution at #1317 (comment) to create a utility vmap for this. Something like

from collections import namedtuple
import numpyro.distributions as dist
import jax
import jax.numpy as jnp

DistFlatten = namedtuple("DistFlatten", ["value"])

def maybe_flatten_dist(d):
    if not isinstance(d, dist.Distribution):
        return d
    d_flat, d_treedef = jax.tree_util.tree_flatten(d)
    return DistFlatten({d_treedef: d_flat})

def maybe_unflatten_dist(d):
    if not isinstance(d, DistFlatten):
        return d
    treedef, flat = list(d.value.items())[0]
    return jax.tree_util.tree_unflatten(treedef, flat)

def vmap(fn):
    def wrapped(*args, **kwargs):
        def fn_flat(*args, **kwargs):
            return jax.tree_util.tree_map(maybe_flatten_dist, fn(*args, **kwargs), is_leaf=lambda x: isinstance(x, dist.Distribution))

        out = jax.vmap(fn_flat)(*args, **kwargs)
        # Rebuild each distribution from its batched leaves.
        return jax.tree_util.tree_map(maybe_unflatten_dist, out, is_leaf=lambda x: isinstance(x, DistFlatten))
    return wrapped

def f(x):
    return dist.MultivariateNormal(x, jnp.eye(3))

vmap(f)(jnp.ones((4, 3))).batch_shape

@pierreglaser
Contributor

pierreglaser commented Jan 29, 2023

I'll give both a try but personally I'd root for a solution that just makes numpyro objects natively vmap-able, instead of having to rely on some vmap wrapper, which gets messy if other libs start doing the same thing.

@pierreglaser
Contributor

This can be closed now!

@fehiepsi
Member

Thanks, @pierreglaser!

3 participants