vmap-able Distributions #1529

Merged: 51 commits, Jul 30, 2023


pierreglaser
Contributor

@pierreglaser pierreglaser commented Jan 30, 2023

Does what the title says. For context, I am following up on my #1317 (comment), and figured it might be easier to discuss this issue over an implementation, so here it is.

In this PR, Normal overrides the tree_flatten/tree_unflatten methods of the Distribution base class, following the design of flax.struct.PyTreeNode, which implements tree_flatten/tree_unflatten for dataclass-style classes:

  • tree_flatten builds the relevant tree_flatten tuple by assigning each dataclass field either a data or an aux status.
  • tree_unflatten just calls cls.__init__, which, for dataclasses, is autogenerated and only sets the attributes of the vmap-ed object as given in data and aux, with zero (0) additional business logic.
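To make the two bullets above concrete, here is a minimal sketch of a dataclass-style pytree with a pass-through tree_unflatten. ToyNormal is a hypothetical stand-in written for illustration; it is not the numpyro Normal class nor the code of this PR, and it only assumes jax.tree_util.register_pytree_node_class and jax.vmap.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class ToyNormal:
    # Hypothetical stand-in for a distribution; not a numpyro class.
    loc: Any                  # "data" field: becomes a pytree leaf
    scale: Any                # "data" field: becomes a pytree leaf
    event_shape: tuple = ()   # "aux" field: static, stored in the treedef

    def tree_flatten(self):
        return (self.loc, self.scale), (self.event_shape,)

    @classmethod
    def tree_unflatten(cls, aux, data):
        # Pure pass-through: no validation, broadcasting, or shape inference.
        return cls(*data, *aux)


def make(loc, scale):
    return ToyNormal(loc, scale)


locs = jnp.ones((3, 2))
scale = jnp.ones((2,))

# out_axes is itself a ToyNormal whose leaves are axis specs (int or None).
batched = jax.vmap(make, in_axes=(0, None), out_axes=ToyNormal(0, None))(locs, scale)

print(batched.loc.shape)    # (3, 2)
print(batched.scale.shape)  # (2,)
```

Because tree_unflatten does nothing but restore attributes, the same class can hold arrays, tracers, or plain axis specifications without any constructor logic getting in the way.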

This design is important if we want full compatibility with the vmap axes specification model, in which in/out_axes can be any tree prefix (where the leaves are ints or None) of the input/output treedef; see this section of the jax docs for instance.
Under that model, the attributes of vmap-ed Distributions will have a structure that may run counter to numpyro's expectations: for instance, in the case of Normal instances, the loc and scale attributes may have non-broadcastable shapes. Because of this, we cannot rely on numpyro's batch_shape inference mechanism (which assumes broadcastable attribute shapes) to auto-generate the distribution's batch_shape attribute, and passing batch_shape as auxiliary data during tree_flatten becomes necessary if we want to retain the batch_shape of the non-vmapped Normal. Re-establishing a batch_shape that actually matches the structure of the vmapped distribution (were it to be understood as, for instance, a batched version of the original distribution) could be done optionally in an infer_batch_shape method of the Distribution class:

def infer_batch_shape(self, tree_prefix):
    """
    Given a `vmap`-ed `Distribution` instance :param:`self`, and a :param:`tree_prefix`
    reflecting the axes over which :param:`self` was `vmap`-ed, make a reasonable guess
    of what its `batch_shape` would be if the `vmap`-ed `Distribution` were treated as
    a batched `Distribution`.
    """
    ...

The implementation is rough: I used in_vmap instead of validate_args (whose role is to tell __init__ not to perform any attribute checking) so as not to spuriously override the validate_args attribute of the vmapped distribution to False after vmapping, but obviously, in_vmap should never be part of the __init__ API. At this point, this PR is just a proof-of-concept/showcase to exemplify what I had in mind. But I don't see any reason why a polished version of this model could not be rolled out for all Distribution and Transform subclasses, which would be nice (at least for me!) to have.

cc @fehiepsi

@pierreglaser pierreglaser changed the title from "Provide full compatibility of Normal w.r.t jax.vmap operations." to "Provide full compatibility for Normal instances w.r.t jax.vmap operations." on Jan 30, 2023
@fehiepsi
Member

Thanks @pierreglaser! Could you make a version for Independent and MultivariateNormal? Currently, Normal is working so it might be easier to discuss design choices for non-working distributions.

Re batch shape: You're right, we can simply initialize _batch_shape to None and perform the inference in the batch_shape property. Currently, passing it in tree_unflatten would cause broadcasting issues when we use vmap(..., out_axes=1).

@fehiepsi
Member

fehiepsi commented Jan 30, 2023

This is based on your implementation but without populating the __init__ method

    def tree_flatten(self):
        return (
            tuple(getattr(self, param) for param in self.arg_constraints.keys()),
            self.event_shape,
        )

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        event_shape = aux_data
        d = cls.__new__(cls)
        batch_shapes = []
        for n, p in dict(zip(cls.arg_constraints.keys(), params)).items():
            event_dim = cls.arg_constraints[n].event_dim
            shape = jnp.shape(p)
            batch_shapes.append(shape[:len(shape) - event_dim])
            setattr(d, n, p)
        batch_shape = jax.lax.broadcast_shapes(*batch_shapes)
        Distribution.__init__(d, batch_shape, event_shape)
        # setattr(d, "_batch_shape", batch_shape)
        # setattr(d, "_event_shape", event_shape)
        return d

This might serve as a default implementation. Special classes can have their own flatten, unflatten logic (as currently).
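The shape arithmetic in the snippet above (strip each parameter's trailing event dimensions, then broadcast what is left) can be sketched without any dependency. The helper names below are hypothetical, and the hand-written broadcast_shapes only mimics NumPy-style broadcasting rather than calling jax.lax.broadcast_shapes. The last case shows the kind of failure discussed earlier for out_axes=1.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """NumPy-style broadcasting of shape tuples (right-aligned)."""
    result = []
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"Incompatible shapes for broadcasting: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


def batch_shape_from_params(param_shapes, event_dims):
    """Strip each parameter's trailing event dims, then broadcast the rest."""
    batch_shapes = [
        shape[: len(shape) - event_dims[name]] for name, shape in param_shapes.items()
    ]
    return broadcast_shapes(*batch_shapes)


# Normal-like parameters (event_dim 0 for both): plain broadcasting works.
print(batch_shape_from_params({"loc": (3, 1), "scale": (2,)}, {"loc": 0, "scale": 0}))
# (3, 2)

# MVN-like parameters: loc has event_dim 1, scale_tril has event_dim 2.
print(batch_shape_from_params(
    {"loc": (5, 2), "scale_tril": (2, 2)}, {"loc": 1, "scale_tril": 2}
))
# (5,)

# A loc whose vmapped axis was placed at position 1 no longer broadcasts with scale.
try:
    batch_shape_from_params({"loc": (2, 3), "scale": (2,)}, {"loc": 0, "scale": 0})
except ValueError as err:
    print("inference failed:", err)
```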

@pierreglaser
Contributor Author

pierreglaser commented Jan 30, 2023

Normal is working so it might be easier to discuss design choices for non-working distributions

Just to be clear, as of now, Normal in master is not compliant with the axes specification model of jax.vmap (aside from very basic in_axes/out_axes specifications, like 0 or None) due to the broadcasting operation happening in __init__. For instance, the following snippet:

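import copy

import jax
from jax import random
import jax.numpy as jnp

import numpyro
from numpyro.distributions import Normal


def make_normal_dist(mean, std) -> Normal:
    # print(mean.ndim)
    d = numpyro.distributions.Normal(mean, std)
    # print(d.batch_shape)
    return d


def sample(d):
    # Same helper as in the MultivariateNormal snippet later in this thread.
    return d.sample(random.PRNGKey(0))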


loc = jnp.ones((2,))
scale = jnp.ones((2,))

d = make_normal_dist(loc, scale)

locs = jnp.ones((3, 2))
scales = jnp.ones((3, 2))


assert loc.shape == d.loc.shape
assert scale.shape == d.scale.shape


print("vmapping normal dist creation over first arg and out first arg")
dist_axes = copy.deepcopy(d)
dist_axes.loc = 0
dist_axes.scale = None
dist_axes._batch_shape = (2,)

batched_d = jax.vmap(make_normal_dist, in_axes=(0, None), out_axes=dist_axes)(locs, scale)
samples_batched_dist = jax.vmap(sample, in_axes=(dist_axes,))(batched_d)

assert locs.shape == batched_d.loc.shape
assert scale.shape == batched_d.scale.shape

fails on master.

Moreover, this PR even changes the previous partial compatibility of Normal w.r.t. vmap, since in this PR, vmap-ed Normal distributions do not perform batch_shape/event_shape inference automatically. I think this may be what is causing the test failures.

@pierreglaser
Contributor Author

pierreglaser commented Jan 31, 2023

OK, vmap compatibility for MultivariateNormal instances is implemented in fc96664.

Importantly, when a MultivariateNormal distribution is vmap-ed, some property-based attributes such as covariance_matrix may not work properly (unless they're accessed within a vmap-ed wrapper function), as the current logic within these properties only works when batch dimensions (e.g., vmap dimensions in our case) are placed at the leftmost part of the scale_tril attribute. However, this is not guaranteed to hold in vmap-ed distributions, since out_axes can be set to non-leading positions, like 1/2/3.... Non-property-based attributes like loc and scale_tril, on the other hand, should be set to the values/shapes expected given the in/out_axes specification passed to vmap, which I check in the tests.

I also added an infer_post_vmap_shapes helper (surely there is a better name) that turns a vmap-ed distribution into a batched distribution, which I also test in the tests.

@fehiepsi
Member

I see, thanks! With the MVN distribution, things are clearer to me (I'm still not sure why Normal is not working in complicated patterns when looking at the code). How about we still pass the correct shapes around? IIRC the errors happen when parameters are objects, which makes most operations in the constructor fail. I'm not sure if the solution in #1529 (comment) works. Let me give it a try and let you know.

@fehiepsi
Member

It seems that this works for both Normal and MVN. So I guess this will work for most non-special distributions. But I don't quite understand the semantics of dist_axes in your test so I might miss something.

    def tree_flatten(self):
        # for MVN, we return (self.loc, self.scale_tril), ("loc", "scale_tril")
        return (self.loc, self.scale), ("loc", "scale")

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        params = dict(zip(aux_data, params))
        d = cls.__new__(cls)
        for name, value in params.items():
            setattr(d, name, value)
        if any(type(x) in [object, type(None)] for x in params.values()):
            batch_shape, event_shape = None, None
        else:
            batch_shape, event_shape = cls.infer_shapes(
                **{name: jnp.shape(value) for name, value in params.items()})
        dist.Distribution.__init__(d, batch_shape, event_shape)
        return d

@pierreglaser
Contributor Author

pierreglaser commented Feb 1, 2023

It seems that this works for both Normal and MVN.

Just tried your solution. I like the fact that it does not interact with __init__.

However, the current batch/event_shape inference mechanism (relying on cls.infer_shapes) errors out for complicated in_axes/out_axes specifications.
Take this example, where I vmap make_multivariate_normal_dist twice, first with in_axes=(None, 0) and then with in_axes=(0, None), with out_axes placing the new axis of scale_tril at dimension 1 in the first vmap and the new axis of loc at dimension 1 in the second:

import copy
from typing import cast

import jax
from jax import Array, random
import jax.numpy as jnp

import numpyro
from numpyro.distributions import Distribution, Independent, Normal
from numpyro.distributions.continuous import MultivariateNormal


def make_multivariate_normal_dist(mean, covariance_matrix) -> MultivariateNormal:
    d = numpyro.distributions.MultivariateNormal(
        loc=mean, covariance_matrix=covariance_matrix
    )
    return d


def sample(d: Distribution) -> Array:
    return cast(Array, d.sample(random.PRNGKey(0)))


loc = jnp.ones((2,))
# covariance_matrix = jnp.eye(2)
_rot_mat = jnp.array(
    [[1 / jnp.sqrt(2), -1 / jnp.sqrt(2)], [1 / jnp.sqrt(2), 1 / jnp.sqrt(2)]]
)
covariance_matrix = jnp.matmul(
    _rot_mat, jnp.matmul(jnp.diag(jnp.array([1.0, 2.0])), _rot_mat.T)
)

locs = jnp.ones((3, 2))
# covariance_matrices = jnp.stack([jnp.eye(2), jnp.eye(2), jnp.eye(2)])
covariance_matrices = jnp.stack([covariance_matrix] * 3)


d = make_multivariate_normal_dist(loc, covariance_matrix)

print("double vmapping normal dist creation over first arg and out no second arg")
dist_axes = copy.deepcopy(d)
dist_axes.loc = None
dist_axes.scale_tril = 1
dist_axes.covariance_matrix = None
dist_axes.precision_matrix = None

vmap_once = jax.vmap(
    make_multivariate_normal_dist, in_axes=(None, 0), out_axes=dist_axes
)

dist_axes2 = copy.deepcopy(d)
dist_axes2.loc = 1
dist_axes2.scale_tril = None
dist_axes2.covariance_matrix = None
dist_axes2.precision_matrix = None

vmap_twice = jax.vmap(vmap_once, in_axes=(0, None), out_axes=dist_axes2)

locs2 = jnp.ones((6, 2))

double_batched_d = vmap_twice(locs2, covariance_matrices)

samples_double_batched_dist = jax.vmap(
    jax.vmap(sample, in_axes=(dist_axes,)), in_axes=(dist_axes2,)
)(double_batched_d)
assert samples_double_batched_dist.shape == (6, 3, 2)

The pair of tree_flatten/tree_unflatten proposed in #1529 (comment) fails. What I implemented in this PR currently works, since I do not try to infer batch/event_shape info without having access to the 2 out_axes prefixes used to vmap make_multivariate_normal_dist twice. I believe that using such out_axes prefixes greatly simplifies batch_shape retrieval (or is just flat-out necessary for it).

@fehiepsi let me explain dist_axes in the next comment :-)

@pierreglaser
Contributor Author

pierreglaser commented Feb 1, 2023

A quick explanation of dist_axes.

While the well-known structure of in_axes (resp. out_axes) consists of a tuple of Union[None, int] with arity equal to the number of arguments (resp. entries in the return tuple if one returns more than 1 element) of the vmap-ed function, in_axes and out_axes are actually allowed other values, as discussed briefly in this section of the jax website.

In particular, quoting this section of the vmap docstrings:

If the positional arguments to fun are container (pytree) types, the
corresponding element of in_axes can itself be a matching container,
so that distinct array axes can be mapped for different container
elements. in_axes must be a container tree prefix of the positional
argument tuple passed to fun.

("tree prefix of the positional arguments tuple" is the compact takeaway to retain here for in_axes; for out_axes, it would be the return value instead of the positional arguments tuple.) That means that when vmapping make_multivariate_normal_dist, out_axes is allowed to be 0 or None (the latter only in the dummy case where the distribution is not affected at all by the vmapping operation), but also things like:

MultivariateNormal(loc=0, scale_tril=None)

or

MultivariateNormal(loc=None, scale_tril=1) 

Which are both valid tree prefixes (in that case, maximal tree prefixes, since the treedef of out_axes actually matches the treedef of the return value of make_multivariate_normal_dist) of the output of make_multivariate_normal_dist. (Note that creating these two values won't actually work currently without starting from a valid MultivariateNormal distribution and mutating the arguments, as I do with dist_axes in the tests).

In these examples, loc/scale_tril can be any combination of int or None as long as the in_axes, out_axes and behavior of the function to be vmap-ed are all compatible.

As with in_axes, the out_axes argument tells jax where to put the vmapped axis in the new pytree leaves of the output of func (which are loc and scale_tril in our case).
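As an illustration of the tree-prefix semantics, here is a minimal sketch that uses a plain dict in place of a MultivariateNormal, so that it runs without numpyro. make_params is a hypothetical stand-in for make_multivariate_normal_dist; only jax.vmap and its documented in_axes/out_axes behaviour are assumed.

```python
import jax
import jax.numpy as jnp


def make_params(loc, scale_tril):
    # Stand-in for a distribution constructor: returns a pytree with two leaves.
    return {"loc": loc, "scale_tril": scale_tril}


loc = jnp.zeros((2,))
scale_trils = jnp.stack([jnp.eye(2)] * 3)  # shape (3, 2, 2)

# out_axes mirrors the structure of the output: loc stays unmapped, and the
# mapped axis of scale_tril is placed at position 1.
out = jax.vmap(
    make_params, in_axes=(None, 0), out_axes={"loc": None, "scale_tril": 1}
)(loc, scale_trils)

print(out["loc"].shape)         # (2,)
print(out["scale_tril"].shape)  # (2, 3, 2)
```

Note that loc is not duplicated along the mapped axis, which is the memory-saving behaviour a single out_axes=0 would not give.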

Consider the case where make_multivariate_normal_dist creates a MultivariateNormal distribution with batch_shape == () and event_shape == (2,), meaning that without any additional vmap transformation, the output d of make_multivariate_normal_dist is such that d.loc.shape == (2,) and d.scale_tril.shape == (2, 2). In that case:

  • since d.loc.ndim == 1, the only values that make sense for the loc parameter of the MultivariateNormal container passed to out_axes are 0 and 1,
  • while since d.scale_tril.ndim == 2, the values that make sense for scale_tril are 0, 1 and 2.

That flexible axis specification model is useful, but makes it very complicated to perform batch_shape inference by only introspecting the vmapped scale_tril and loc shapes (while it is easy to infer such values if one additionally has access to the out_axes used to construct such vmapped attributes, as I discussed in my previous #1529 (comment)).
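A small dependency-free sketch of why the shape alone is ambiguous while shape plus out_axes is not. Both helpers are hypothetical and written for this illustration, and they assume a non-negative out_axis.

```python
def vmapped_shape(base_shape, out_axis, size):
    """Shape of a leaf after vmap inserts a mapped axis of length `size`."""
    if out_axis is None:  # leaf not mapped: shape unchanged
        return tuple(base_shape)
    shape = list(base_shape)
    shape.insert(out_axis, size)
    return tuple(shape)


def strip_vmapped_axis(shape, out_axis):
    """Inverse operation: recover the un-vmapped shape when out_axis is known."""
    if out_axis is None:
        return tuple(shape)
    return tuple(d for i, d in enumerate(shape) if i != out_axis)


# loc of an MVN with event_shape (2,), vmapped over 3 elements with out_axis=1.
loc_shape = vmapped_shape((2,), out_axis=1, size=3)
print(loc_shape)  # (2, 3)

# Reading (2, 3) with the usual "batch dims first, event dims last" convention
# gives the wrong answer: batch_shape (2,) and event_shape (3,).
naive_batch, naive_event = loc_shape[:-1], loc_shape[-1:]
print(naive_batch, naive_event)  # (2,) (3,)

# With out_axis at hand, the original event_shape is recovered exactly.
print(strip_vmapped_axis(loc_shape, out_axis=1))  # (2,)
```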

Hope that makes sense!

@pierreglaser
Contributor Author

(I tested the case discussed in #1529 (comment) in cf9f966)

@fehiepsi
Member

fehiepsi commented Feb 1, 2023

Wow, I learned a long way with vmap out_axes. Thanks for the detailed explanation.

What if we just simply do

    @classmethod
    def tree_unflatten(cls, aux_data, params):
        params = dict(zip(aux_data, params))
        d = cls.__new__(cls)
        for name, value in params.items():
            setattr(d, name, value)            
        try:
            batch_shape, event_shape = cls.infer_shapes(
                **{name: jnp.shape(value) for name, value in params.items()})
        except Exception:
            batch_shape, event_shape = None, None
        dist.Distribution.__init__(d, batch_shape, event_shape)
        return d

which means to set shapes to None whenever things are not working. This way is equivalent to doing nothing in the constructor, but potentially give us correct batch dimensions if possible (I'm not sure, I still don't understand why we need to play with tricky batch axes in vmap - but I guess you have some good applications for such support).

@pierreglaser
Contributor Author

pierreglaser commented Feb 1, 2023

Wow, I learned a long way with vmap out_axes. Thanks for the detailed explanation.

No worries, this specification model is actually less known than it should be IMO.

I still don't understand why we need to play with tricky batch axes in vmap - but I guess you have some good applications for such support

The very exotic choices of vmap dimensions in out_axes I used were more to showcase the situations where things might break automatic shape inference; they don't refer to an exact example of mine. In practice, I'll most likely stick to combinations of 0 and None. But I definitely resort to the tree prefix semantics a lot, rather than only using a single 0 or None for out_axes.

In my use-case, my distributions are posteriors parametrized by neural network parameters, and typically, I work with posteriors given different settings of the conditioned variable, but I don't want to duplicate the neural network parameters, which are the same for each observation, since that would lead to memory errors. Thus, I would want to set None for the out_axes part pointing to my network parameters, but 0 for the one pointing to my conditioned variable :-). On top of this, I often do MCMC inference for such posteriors with multiple chains, which incurs additional vmap operations on top of such posteriors. Again, I need to be careful in my choices of in/out_axes in order not to blow up my working memory.

What if we just simply do ... which means to set shapes to None whenever things are not working. This way is equivalent to doing nothing in the constructor, but potentially give us correct batch dimensions if possible

About this solution: I like it more than the prior option since now, vmapping will not error out. However, the batch_shape/event_shape of a vmapped distribution now becomes hard to predict, which is a bit problematic if one specializes code (as is possible for static attributes such as batch_shape/event_shape) based on such attributes.

I am still personally quite strongly in favor of not doing any business logic apart from passthrough attribute restoration in tree_unflatten, and exposing a Distribution.turn_vmapped_distribution_into_batched_distribution(self, vmap_axes)

There are multiple advantages to this:

  • Distribution.turn_vmapped_distribution_into_batched_distribution(self, vmap_axes) is guaranteed to work all the time (since we rely on the vmap_axes, we have all necessary information at our disposal), as opposed to using infer_shapes on tree_unflatten.
  • By not touching the batch_shape/event_shape unless asked to, any numpyro operations on the vmapped distributions that are wrapped into subsequent vmap calls are guaranteed to work, since the batch shape will reflect the un-vmapped distribution that is manipulated within the vmap-wrapped operations. This is what I've tried to showcase in my tests, with sample for instance (see here), or with attribute access (see here).
  • In a sense, the approach without cls.infer_shapes is more conservative: we don't promise any auto-magic additional features coming from numpyro when vmap-ing - we just promise compatibility between numpyro distributions and vmap transformations.

I can concede that my vmapping habits are not necessarily representative of what numpyro users want (which I don't know). Although I would hope that other users would agree with me, my opinion is definitely biased :D

@fehiepsi
Member

fehiepsi commented Feb 1, 2023

So if I understand correctly, by skipping updating batch shape, the vmapped distribution is no longer a distribution in the usual sense, but needs to be used in subsequent vmap operators? (E.g., vmapped_dist.sample will not work because its implementation assumes the original batch shape.) That seems like a limitation to me. If you want to use the "fix_vmap" pattern, how about storing something like _original_batch_shape in the unflattened dist, so that the fix_vmap method can fix the batch shape using such information? I'm not sure how to implement fix_vmap with single (or nested) out axes though.

infer_shapes is guaranteed to give us correct batch shape, event shape given parameter shapes. It is not clear to me how things look under special out axes. If things turn out to be wrong, we can resort to fix_vmap method I guess.

@fehiepsi
Member

fehiepsi commented Feb 1, 2023

In my use-case, my distributions are posteriors parametrized by neural network parameters, and typically, I work with posteriors given different settings of the conditioned variable, but I don't want to duplicate the neural network parameters, which are the same for each observation, since that would lead to memory errors. Thus, I would want to set None for the out_axes part pointing to my network parameters, but 0 for the one pointing to my conditioned variable :-). On top of this, I often do MCMC inference for such posteriors with multiple chains, which incurs additional vmap operations on top of such posteriors. Again, I need to be careful in my choices of in/out_axes in order not to blow up my working memory.

This seems like an important application. Do you have a sketch in pseudocode for this example?

@pierreglaser
Contributor Author

pierreglaser commented Feb 1, 2023

Sure, here it is: https://github.com/pierreglaser/unle/blob/main/sbi_ebm/sbi_ebm/likelihood_trainer.py#L59-L159

A few important remarks:

  • Because, to this day, Distributions are not vmappable, I've been using a custom, vmappable LogDensity ersatz class that basically wraps around a log_prob. This also sort of makes sense in my case, because most distributions I work with are intractable, so they don't really have sample semantics (I was actually wondering whether you'd be OK with introducing an IntractableDistribution base class in numpyro that does not expose .sample, and from which all other distribution classes could derive while remaining compliant with subtype polymorphism, but that's for another time). I'm partly looking forward to this PR so that I can just use Distribution subclasses instead of such a jerry-rigged LogDensity.
  • Notice the heavy vmap gymnastics :-)
  • I've written my own MCMC machinery that is 100% pytree-based (à la equinox/flax.struct): you can vmap/jit over absolutely anything you want in a completely transparent manner, a model that (as you've probably noticed) really appeals to me. This MCMC machinery can wrap around numpyro kernels like NUTS, as well as blackjax ones, and can also use both libs' warmup methods. One thing I'd like to do at some point is to have both numpyro and blackjax settle on a unified API for MCMC kernels and warmup methods. Both libs are cool: numpyro is still much more user friendly with its MCMC class, while blackjax is under heavy development and new features are being added frequently.
  • Although I've used "posterior", the proper term would probably have been "conditioned densities": in that case, I perform vmapped MCMC on (conditional) likelihoods $p(x|\theta)$ with different settings for the conditioned variable $\theta$.

@pierreglaser
Contributor Author

pierreglaser commented Feb 1, 2023

So if I understand correctly, by skipping updating batch shape, the vmapped distribution is no longer a distribution in the usual sense, but needs to be used in subsequent vmap operators?

Exactly.

That seems like a limitation to me

Not in my opinion: vmapping functions changes the nature of their return values. If a function f returns a $d$-dimensional vector v, vmap(f, out_axes=0) does not return a $d$-dimensional vector anymore; it returns an $n \times d$ array (or matrix), let's call it vmapped_v. The semantics of the objects are different, and we cannot treat this matrix as we would treat a vector. However, we can treat this vmapped_v as a batch of vectors. Thus, any vector-based operation (like jnp.dot, jnp.linalg.norm, len, etc.) can operate on vmapped_v as long as it is wrapped by a vmap transformation (like vmap(lambda v: jnp.linalg.norm(v), in_axes=0)(vmapped_v)).
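A minimal sketch of that vector example, with a hypothetical f chosen for illustration and $d = 3$, $n = 2$:

```python
import jax
import jax.numpy as jnp


def f(x):
    # Returns a d-dimensional vector (here d = 3) for a scalar input x.
    return x * jnp.arange(1.0, 4.0)


xs = jnp.array([1.0, 2.0])               # n = 2 inputs
vmapped_v = jax.vmap(f, out_axes=0)(xs)  # an (n, d) array, not a vector
print(vmapped_v.shape)                   # (2, 3)

# Wrapped in vmap, the vector operation sees one d-vector at a time.
norms = jax.vmap(jnp.linalg.norm, in_axes=0)(vmapped_v)
print(norms.shape)                       # (2,)

# Applied directly, the same call treats vmapped_v as a matrix and returns
# a single (Frobenius) norm, which is a different quantity.
whole = jnp.linalg.norm(vmapped_v)
print(whole.shape)                       # ()
```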

Sometimes, it is possible to convert a vmapped object into some other version of that object. In the case of numpyro, we can convert a vmapped Distribution into a single Distribution with an augmented batch_shape, thanks to cls.infer_post_vmap_shapes. That's nice. But IMO, doing this automatically might be surprising for people that have this mental model of "vmapped object != object", like me, and thus, I would not advocate for trying to do this automatically. But very happy to hear your points of course.

@fehiepsi
Member

fehiepsi commented Feb 1, 2023

Now I see why those structured out axes are needed. Yeah, agree that we shouldn't broadcast things unless necessary. I can see why this saves memory in your case.

I like the example about two views of 2D arrays. I agree both views make sense, but treating them as different objects might be unfortunate. If needed, we can introduce VmapDistribution to distinguish two interpretations. But let's see what we can do if we consider them to be the same object first.

I guess we now agree that we can avoid populating the constructor by avoiding using it in the unflatten method. But we disagree on whether to compute batch shape during unflatten or just use the auxiliary data from flatten dist. Let's see pros and cons of each one (please correct me if I'm wrong)

  • infer batch shape pros: the output of vmap over a distribution is still a distribution
  • infer batch shape cons: it might be surprising - is it good to be surprised?, will this cause issues for subsequent computations?
  • no infer pros: the output of vmap is a special distribution that works for subsequent vmap computations
  • no infer cons: the vmapped distribution might not work for non-vmap computations.

What do you think? We might introduce VmapDistribution if inferred batch shape is different from the flatten/original batch shape. For both options, we can provide fix vmap utility if needed (I don't know how to implement it - seems tricky to me).

@pierreglaser
Contributor Author

pierreglaser commented Feb 1, 2023

I guess we now agree that we can avoid populating the constructor by avoiding using it in the unflatten method

That for sure, I believe this is a pure implementation detail question, and I'm fine with whatever solution you find more maintainable.

We might introduce VmapDistribution if inferred batch shape is different from the flatten/original batch shape

This would be very confusing from a user point of view, because the type of the output of the vmap-ed function output would now depend on whether numpyro's batch_shape inference logic works or not, which is going to be very difficult to predict from a user perspective. I'd rather return the same return type all the time.

I guess we now agree that we can avoid populating the constructor by avoiding using it in the unflatten method. But we disagree on whether to compute batch shape during unflatten or just use the auxiliary data from flatten dist. Let's see pros and cons of each one (please correct me if I'm wrong)
- infer batch shape pros: the output of vmap over a distribution is still a distribution
- infer batch shape cons: it might be surprising - is it good to be surprised?, will this cause issues for subsequent computations?
- no infer pros: the output of vmap is a special distribution that works for subsequent vmap computations
- no infer cons: the vmapped distribution might not work for non-vmap computations.

The "no infer cons" and the "infer batch shape pros" should rather be understood as a pro/con of vmap itself. The purpose of my PR is simply to bring vmap compatibility to numpyro, not to make vmap more user-friendly (and FWIW, again, I would not qualify them as pros and cons, since these are different objects and ought to be treated differently).
However, I do agree with you that we can be nice and add vmap->batched dist conversion utilities (which I've implemented). But I would not force that conversion upon all users.

For both options, we can provide fix vmap utility if needed (I don't know how to implement it - seems tricky to me).

I'm happy with that. I don't think it will be too tricky as long as we have access to the in/out_axes information used to create the VmappedDistribution. I implemented such a function here, and it could be made even much more compact.

@pierreglaser
Contributor Author

pierreglaser commented Feb 1, 2023

infer batch shape cons: it might be surprising - is it good to be surprised?, will this cause issues for subsequent computations?

I think what I'm most concerned about with changing batch_shape automagically is breaking vmap-wrapped numpyro operations, which I believe we need to guarantee if we want to advertise vmap compatibility. I'm not familiar enough with numpyro's internals to know whether mutating the batch_shape of a distribution has consequences for subsequent numpyro operations like sample/log_prob/log_det_jacobian...

but treating them as different objects might be unfortunate.

Just to be clear - I'm fine with enabling people to do post-hoc conversion of vmapped distribution so that they can treat it the way they want. I'm less fine with having that done automatically, since IMO it violates idiomatic vmap semantics.

@pierreglaser
Contributor Author

Another possibility (not a big fan of it, but maybe it'll help us find common ground) is to introduce a numpyro.vmap utility that does automatic shape inference, while jax.vmap would not do automatic shape inference. That way we remain non-surprising when using standard vmap utilities, but provide users who want such automatic shape inference with a user-friendly way to get it.

@fehiepsi
Member

fehiepsi commented Feb 1, 2023

I was just trying to see if we can come up with a solution given different viewpoints. Keeping the original batch shape under vmap is surprising to me because, to me, the batch_shape of a distribution is like the shape of a numpy array (you might disagree), so under vmap, the output would have a new shape. Returning a new type of distribution under vmap is also surprising to me. I have been wanting to use vmap(trace(model))(...) natively to get a batch of traces, but currently I have to resort to returning flattened traces. Having an incorrect batch shape will lead to an incorrect sample call because we always return samples with shape sample_shape + batch_shape + event_shape. I think numpyro.util.vmap can be used for this purpose, where we return an original-batch-shape vmapped dist and then fix the shape inside it, rather than the other way around (using infer_post_vmap_shapes is surprising to me, so I hope to avoid calling it if possible). I hope that clarifies my viewpoint.
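The shape contract mentioned above can be written down in a few lines. This is only a sketch of the arithmetic with a hypothetical helper name; it does not call numpyro.

```python
def expected_sample_shape(sample_shape, batch_shape, event_shape):
    """Shape contract quoted above: sample_shape + batch_shape + event_shape."""
    return tuple(sample_shape) + tuple(batch_shape) + tuple(event_shape)


# An MVN with event_shape (2,), vmapped over 3 locs with out_axes=0,
# so that loc now has shape (3, 2).
stale = expected_sample_shape((), (), (2,))      # batch_shape kept from before vmap
updated = expected_sample_shape((), (3,), (2,))  # batch_shape updated after vmap

print(stale)    # (2,)
print(updated)  # (3, 2)
```

With the stale batch_shape, an unwrapped sample call would promise shape (2,) even though the parameters describe three distributions; wrapped in a matching vmap, each call sees a single un-vmapped distribution, for which (2,) is the right per-element shape.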

most concerned about with changing batch_shape automagically, is to break vmap-wrapped numpyro operations

Yeah, I'm trying to see when things can go wrong. Does your code break if we update batch shape during flatten?

implemented such a function here, and it could be made even much more compact.

This is great! I guess that pattern will also work for other distributions. I hope we can come up with a solution that works for most distributions. If things are tricky, we can make a singledispatch function to have separate implementations for various distributions. It would be nice to put the implementation in a separate file because the current flatten, unflatten logic is already complicated for users to define a new distribution.
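For reference, a minimal sketch of what per-distribution dispatch with functools.singledispatch could look like. The classes and the infer_batch_shape function below are hypothetical stand-ins that only carry parameter shapes; they are not numpyro classes or the implementation in this PR.

```python
from functools import singledispatch


class ToyDistribution:
    """Hypothetical stand-in for a distribution; holds parameter shapes only."""

    def __init__(self, **param_shapes):
        self.param_shapes = param_shapes


class NormalLike(ToyDistribution):
    pass


class MVNLike(ToyDistribution):
    pass


@singledispatch
def infer_batch_shape(d):
    # Fallback for types without a registered rule.
    raise NotImplementedError(f"No batch-shape rule for {type(d).__name__}")


@infer_batch_shape.register(NormalLike)
def _(d):
    # Scalar event: every dimension of loc is a batch dimension.
    return d.param_shapes["loc"]


@infer_batch_shape.register(MVNLike)
def _(d):
    # Vector event: the last dimension of loc is the event dimension.
    return d.param_shapes["loc"][:-1]


print(infer_batch_shape(NormalLike(loc=(3, 2))))  # (3, 2)
print(infer_batch_shape(MVNLike(loc=(3, 2))))     # (3,)
```

The rules live outside the classes, so they can sit in a separate module without touching each distribution's flatten/unflatten logic.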

@pierreglaser
Contributor Author

pierreglaser commented Feb 1, 2023

I think numpyro.util.vmap can be used for this purpose

Just to clarify, you mean that numpyro.util.vmap would do automatic batch_shape inference, while vmap wouldn't (which is what I'm in favor of)?

Another nice thing about doing automatic batch_shape inference within a separate numpyro.util.vmap function is that, in that context, we can intercept the in/out_axes and use the logic of infer_post_vmap_shapes to return a valid batched distribution with a 100% success rate and no intervention from the user. But since this inference routine can modify the shape of the distribution's attributes (it has to turn the result into a batched distribution), any value other than 0 in out_axes would be undone by the inference step, if that makes sense (I can elaborate if needed).

@fehiepsi
Member

fehiepsi commented Feb 1, 2023

Yes, the utility vmap will incorporate fixes if needed (the fixes look fairly complicated to implement to me, given a fn with arbitrary pytree inputs and output axes, but you would have a clearer vision of what it looks like). Regarding requirements, as long as it works for your application, it is good enough. A complicated system will be hard to maintain.

Do you want to make an end-to-end PR, or should we split up the tasks? I can work on the flatten/unflatten logic but don't have enough background to work on the vmap fixes.

Rather than adding new stuff to Distribution, I think we can move the batch-shape-fix method to a standalone utility and dispatch the implementation for each distribution. Maybe we can put everything in numpyro.contrib.vmap_fix first, then incorporate it into the main library later, when things are more robust. What do you think?

@pierreglaser pierreglaser changed the title Provide full compatibility for Normal instances w.r.t jax.vmap operations. vmap-able Distributions Jul 9, 2023
@pierreglaser
Contributor Author

pierreglaser commented Jul 10, 2023

OK @fehiepsi, vmap-support for all distributions is now implemented and tested. Here are some details regarding the implementation I've chosen:

A new unified pair of (un)flattening methods for all distributions

  • I implemented a very general pair of tree_flatten/tree_unflatten methods for the Distribution base class. These methods do not need to be overridden for the overwhelming majority of distributions. This model has allowed me to remove a lot of tree_flatten/tree_unflatten methods in subclasses, which are now made obsolete by their parent's implementation.
  • The only thing these methods do is look recursively through the MRO of a distribution class for all attributes that are either pytree nodes or pytree auxiliary data. tree_flatten extracts the corresponding attributes and puts them in the appropriate position in the returned tuple, while tree_unflatten calls setattr for all auxiliary and pytree node attributes on the newly created distribution subclass instance.
  • tree_flatten figures out which attributes are pytree nodes and which are auxiliary data through two static class attributes, pytree_data_fields and pytree_aux_fields, each containing a list of attribute names to be treated as pytree nodes or as pytree auxiliary data, respectively. As I said, for most distributions these two attributes are the only things necessary to make a Distribution class flatten/unflattenable. This declarative model is very similar to that of flax.struct.
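
As a rough illustration of the declarative model described in the list above (a simplified sketch, not the code from this PR): the helper name _collect_fields and the toy Normal class are hypothetical, and registration with JAX's pytree machinery is left out so that the snippet runs with the standard library only.

```python
def _collect_fields(cls, attr_name):
    """Hypothetical helper: gather declared field names along the MRO."""
    names = []
    for base in reversed(cls.__mro__):
        # Only look at what each class declares itself, not what it inherits.
        for name in vars(base).get(attr_name, ()):
            if name not in names:
                names.append(name)
    return tuple(names)


class Distribution:
    # Attribute names treated as pytree children (array-like leaves).
    pytree_data_fields = ()
    # Attribute names treated as static auxiliary data.
    pytree_aux_fields = ("_batch_shape", "_event_shape")

    def tree_flatten(self):
        data_names = _collect_fields(type(self), "pytree_data_fields")
        aux_names = _collect_fields(type(self), "pytree_aux_fields")
        children = tuple(getattr(self, n) for n in data_names)
        aux = tuple(getattr(self, n) for n in aux_names)
        return children, aux

    @classmethod
    def tree_unflatten(cls, aux, children):
        # Bypass __init__: no validation or shape inference happens here,
        # attributes are simply set back on a fresh instance.
        obj = object.__new__(cls)
        for name, value in zip(_collect_fields(cls, "pytree_data_fields"), children):
            setattr(obj, name, value)
        for name, value in zip(_collect_fields(cls, "pytree_aux_fields"), aux):
            setattr(obj, name, value)
        return obj


class Normal(Distribution):
    # The subclass only declares its data fields; aux fields are inherited.
    pytree_data_fields = ("loc", "scale")

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale
        self._batch_shape = ()
        self._event_shape = ()


children, aux = Normal(0.0, 1.0).tree_flatten()
rebuilt = Normal.tree_unflatten(aux, children)
```

Because tree_unflatten in this sketch never calls __init__, it can rebuild an instance from arbitrary leaves (for example axis specifications), which is the property the vmap axes model relies on.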

Manual Batch shape resolving for scan/enum-ed Distributions

  • For the motivation and an explanation of what this even means, see vmap-able Distributions #1529 (comment) and vmap-able Distributions #1529 (comment).
  • I ended up taking a similar approach: a general implementation at the base class level plus static declarative customisation at the subclass level. A promote_batch_shape method is implemented in Distribution, and relies on knowing the shape of each of a distribution's attributes for a batch_shape of () (this knowledge is often relied on implicitly in distribution subclasses' __init__ methods, as discussed here: vmap-able Distributions #1529 (comment)).
  • Again, what tells Distribution.promote_batch_shape about the attributes' shapes is a subclass-specific class attribute called attr_atomic_ndim. It takes the form of a dict whose keys are the attribute names and whose values are the rank of each attribute for a corresponding batch shape of ().
  • Currently, I have only defined this attribute for cases that are tested in the test suite. I am not sure whether the other, untested cases were expected to work out of the box before this PR. If so, I should define this attribute for all distributions and try to come up with a systematic way to test this.
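
To make the role of attr_atomic_ndim more concrete, here is a minimal sketch that works on shape tuples only. It assumes that the batch part of an attribute is whatever remains after dropping its trailing "atomic" dimensions; the function names broadcast_shapes and resolve_batch_shape are hypothetical and this is not the promote_batch_shape implementation from the PR.

```python
from itertools import zip_longest


def broadcast_shapes(*shapes):
    """Broadcast shape tuples following the usual NumPy rules."""
    out = []
    # Align shapes on their trailing dimensions; missing dims count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        out.append(sizes.pop() if sizes else 1)
    return tuple(reversed(out))


def resolve_batch_shape(attr_shapes, attr_atomic_ndim):
    """Recover a batch shape from attribute shapes (hypothetical helper).

    attr_atomic_ndim[name] is the rank of the attribute when batch_shape is ().
    """
    batch_parts = []
    for name, shape in attr_shapes.items():
        ndim = attr_atomic_ndim[name]
        # Strip the trailing atomic dimensions; what is left is batch-like.
        batch_parts.append(shape[: len(shape) - ndim])
    return broadcast_shapes(*batch_parts)


# A multivariate-normal-like case: loc is a vector, scale_tril a matrix.
atomic_ndim = {"loc": 1, "scale_tril": 2}
# After stacking (e.g. inside a scan), loc gained a leading dimension of 10.
shapes = {"loc": (10, 3), "scale_tril": (3, 3)}
batch_shape = resolve_batch_shape(shapes, atomic_ndim)
```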

The vmap_over utility

The vmap_over function makes it easy to create a vmap-style axes tree prefix to vmap a function over only specific attributes of a Distribution. vmap_over is currently defined on a distribution-by-distribution basis (we could make its implementation fully generic if the primitive parameter proto-specification model (arg_constraints, reparametrized_args, pytree_data_fields, ...) were changed to handle parameter dependencies, but this is not necessary for now). This utility is used to test that all distributions are indeed vmappable across arbitrary in/out axes specifications.
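
The idea of an axes tree prefix can be sketched as follows. This is an illustrative stand-in, not the vmap_over from the PR: the toy Normal dataclass and the generic, keyword-based signature are assumptions, and the snippet avoids JAX so it runs on its own.

```python
import copy
from dataclasses import dataclass


@dataclass
class Normal:
    """Toy stand-in for a distribution with two data fields."""
    loc: object
    scale: object


def vmap_over(d, **axes):
    """Return a copy of d whose fields hold axis specs (int or None).

    Fields not mentioned in axes are marked None, i.e. "do not map over this".
    """
    unknown = set(axes) - set(vars(d))
    if unknown:
        raise TypeError(f"unknown attribute(s): {sorted(unknown)}")
    spec = copy.copy(d)
    for name in vars(d):
        setattr(spec, name, axes.get(name))
    return spec


d = Normal(loc=[0.0, 1.0, 2.0], scale=1.0)
# Map over axis 0 of loc only; scale is shared across the batch.
in_axes_spec = vmap_over(d, loc=0)
```

With a distribution class registered as a JAX pytree, an object like in_axes_spec has the same tree structure as d, so it could be passed as in_axes or out_axes to jax.vmap.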

Finally, I made a couple of minor changes to some tested distributions (turned function-based constructors into classes).

I believe that in addition to increasing the compliance of numpyro with jax's function transformations, this PR actually simplifies the parts of the code it modifies. The large diff is mostly due to the per-class implementations of vmap_over, and changes in test files.

@fehiepsi, I would be keen to get your input on what I should do w.r.t. promote_batch_shape. Once this is done, I think we can get the final review process started!

@fehiepsi
Member

The way that tree_flatten figures out which attributes are pytree node and which are auxiliary data is through two static class attributes: pytree_data_fields and pytree_aux_fields,

Brilliant! This solution resolves a lot of issues with attributes like base_dist, support, etc. It might be even simpler if by default, pytree_data_fields are arg_constraints keys.

Regarding promote_batch_shape, I think you can rename vmap_util to batch_util and perform dispatch like the vmap_over function. In addition, you can infer attr_atomic_ndim of a parameter foo by looking at arg_constraints[foo].event_dim.
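
The last suggestion can be sketched like this, assuming each constraint object exposes an event_dim attribute as stated above. The Constraint class and its instances are stand-ins so the snippet runs without numpyro, and atomic_ndim_from_constraints is a hypothetical helper name.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Constraint:
    """Stand-in for a constraint object carrying an event_dim."""
    event_dim: int


# Stand-ins mirroring scalar, vector and matrix-valued parameters.
real = Constraint(event_dim=0)
real_vector = Constraint(event_dim=1)
lower_cholesky = Constraint(event_dim=2)


def atomic_ndim_from_constraints(arg_constraints):
    """Rank of each parameter when batch_shape is (), read off event_dim."""
    return {name: c.event_dim for name, c in arg_constraints.items()}


mvn_arg_constraints = {"loc": real_vector, "scale_tril": lower_cholesky}
mvn_atomic_ndim = atomic_ndim_from_constraints(mvn_arg_constraints)
```

This removes the need to write a separate attr_atomic_ndim dict by hand for every parameter that already appears in arg_constraints.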

@pierreglaser
Contributor Author

are arg_constraints keys

Mmh, the issue is that arg_constraints currently only contains fields that are part of the public interface of the distribution subclass's __init__, but this set is different from the set of attributes of the distribution (the two might actually be disjoint):

For that reason, we cannot just have pytree_data_fields piggy-back on arg_constraints: we need a more modular solution, like an AttributeConfig object, which stores, for each argument/attribute, whether it is part of the public interface, whether it is an attribute, its event dimension, etc.

Does that make sense?

@fehiepsi
Member

I just meant that by default, we set it to None. In case it is None, we use arg_constraints.keys() in the methods that use it. This way, we don't need to set pytree_data_fields for many distributions.
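
A minimal sketch of this fallback, with hypothetical names (the data_fields method, the Wrapped class and the string constraint placeholders are for illustration only):

```python
class Distribution:
    arg_constraints = {}
    # None means "fall back to the keys of arg_constraints".
    pytree_data_fields = None

    @classmethod
    def data_fields(cls):
        if cls.pytree_data_fields is None:
            return tuple(cls.arg_constraints.keys())
        return tuple(cls.pytree_data_fields)


class Normal(Distribution):
    # No pytree_data_fields needed: the parameters are the data fields.
    arg_constraints = {"loc": "real", "scale": "positive"}


class Wrapped(Distribution):
    # Attributes differ from the constructor arguments, so declare explicitly.
    arg_constraints = {"concentration": "positive"}
    pytree_data_fields = ("base_dist", "concentration")


normal_fields = Normal.data_fields()
wrapped_fields = Wrapped.data_fields()
```

Only distributions whose attributes differ from their constructor parameters need an explicit declaration under this scheme.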

@pierreglaser
Contributor Author

pierreglaser commented Jul 12, 2023

I find it a bit confusing to have pytree_data_fields default to the keys of arg_constraints, since these two things are just distinct features of attributes and IMO should not be hierarchically arranged. But I also see your point that there might be cases in which these two iterables reference the same attributes, and that arg_constraints is already defined for most distributions. I'm fine with doing this the way you suggest for this PR, and we can make a more substantial clean-up of that part of the code later on :-)

@pierreglaser
Contributor Author

OK @fehiepsi, I implemented the changes you suggested (using event_dim for batch shape promotion was a great idea!). Wanna review? :D

Member

@fehiepsi fehiepsi left a comment

The changes are super nice!!! I'll take a final pass at the expand logic tomorrow. The rest looks awesome to me.

@fehiepsi fehiepsi merged commit 902623c into pyro-ppl:master Jul 30, 2023
@pierreglaser
Contributor Author

🥳

@martinjankowiak
Collaborator

martinjankowiak commented Jul 30, 2023

@pierreglaser i've only been watching from afar but thanks for the heroic effort!

not sure if it's directly related to this PR but it looks like we now have a flake8 error in mixtures.py

flake8
./numpyro/distributions/mixtures.py:323:16: E721 do not compare types, for exact checks use is / is not, for instance checks use isinstance()

edit: this should be fixed in #1622
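
For reference, E721 is raised by comparisons of the form type(a) == type(b). The offending line in mixtures.py is not shown in this thread, so the classes below are purely illustrative; the snippet only shows the two alternatives that the flake8 message recommends.

```python
class Base:
    pass


class Child(Base):
    pass


obj = Child()

# Exact type check: use `is` rather than `==` between types.
is_exactly_child = type(obj) is Child
is_exactly_base = type(obj) is Base

# Instance check: isinstance also accepts subclasses.
is_a_base = isinstance(obj, Base)
```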
