-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
vmap
-able Distribution
s
#1529
vmap
-able Distribution
s
#1529
Conversation
Normal
w.r.t jax.vmap
operations.Normal
instances w.r.t jax.vmap
operations.
Thanks @pierreglaser! Could you make a version for Independent and MultivariateNormal? Currently, Normal is working so it might be easier to discuss design choices for non-working distributions. Re batch shape: You're right, we can simply initialize |
This is based on your implementation but without populating the def tree_flatten(self):
return (
tuple(getattr(self, param) for param in self.arg_constraints.keys()),
self.event_shape,
)
@classmethod
def tree_unflatten(cls, aux_data, params):
event_shape = aux_data
d = cls.__new__(cls)
batch_shapes = []
for n, p in dict(zip(cls.arg_constraints.keys(), params)).items():
event_dim = cls.arg_constraints[n].event_dim
shape = jnp.shape(p)
batch_shapes.append(shape[:len(shape) - event_dim])
setattr(d, n, p)
batch_shape = jax.lax.broadcast_shapes(*batch_shapes)
Distribution.__init__(d, batch_shape, event_shape)
# setattr(d, "_batch_shape", batch_shape)
# setattr(d, "_event_shape", event_shape)
return d This might serve as a default implementation. Special classes can have their own flatten, unflatten logic (as currently). |
Just to be clear, as of now, def make_normal_dist(mean, std) -> Normal:
# print(mean.ndim)
d = numpyro.distributions.Normal(mean, std)
# print(d.batch_shape)
return d
loc = jnp.ones((2,))
scale = jnp.ones((2,))
d = make_normal_dist(loc, scale)
locs = jnp.ones((3, 2))
scales = jnp.ones((3, 2))
assert loc.shape == d.loc.shape
assert scale.shape == d.scale.shape
print("vmapping normal dist creation over first arg and out first arg")
dist_axes = copy.deepcopy(d)
dist_axes.loc = 0
dist_axes.scale = None
dist_axes._batch_shape = (2,)
batched_d = jax.vmap(make_normal_dist, in_axes=(0, None), out_axes=dist_axes)(locs, scale)
samples_batched_dist = jax.vmap(sample, in_axes=(dist_axes,))(batched_d)
assert locs.shape == batched_d.loc.shape
assert scale.shape == batched_d.scale.shape fails on Moreover, this PR even changed the previous partial compatibility features of |
OK, Importantly when a I also added a |
I see, thanks! From MVN distribution, things are clearer to me (I'm still not sure why Normal is not working in complicated patterns when looking at the code). How about we still pass the correct shapes around? IIRC the errors happen when parameters are objects, which make most operators in the constructor not working. I'm not sure if the solution in #1529 (comment) works. Let me give it a try and let you know |
It seems that this works for both Normal and MVN. So I guess this will work for most non-special distributions. But I don't quite understand the semantics of def tree_flatten(self):
# for MVN, we return (self.loc, self.scale_tril), ("loc", "scale_tril")
return (self.loc, self.scale), ("loc", "scale")
@classmethod
def tree_unflatten(cls, aux_data, params):
params = dict(zip(aux_data, params))
d = cls.__new__(cls)
for name, value in params.items():
setattr(d, name, value)
if any(type(x) in [object, type(None)] for x in params.values()):
batch_shape, event_shape = None, None
else:
batch_shape, event_shape = cls.infer_shapes(
**{name: jnp.shape(value) for name, value in params.items()})
dist.Distribution.__init__(d, batch_shape, event_shape)
return d |
Just tried your solution. I like the fact that is does not interact with However, the current import copy
from typing import cast
from absl.app import pdb
import jax
from jax import Array, random
import jax.numpy as jnp
import numpyro
from numpyro.distributions import Distribution, Independent, Normal
from numpyro.distributions.continuous import MultivariateNormal
def make_multivariate_normal_dist(mean, covariance_matrix) -> MultivariateNormal:
d = numpyro.distributions.MultivariateNormal(
loc=mean, covariance_matrix=covariance_matrix
)
return d
def sample(d: Distribution) -> Array:
return cast(Array, d.sample(random.PRNGKey(0)))
loc = jnp.ones((2,))
# covariance_matrix = jnp.eye(2)
_rot_mat = jnp.array(
[[1 / jnp.sqrt(2), -1 / jnp.sqrt(2)], [1 / jnp.sqrt(2), 1 / jnp.sqrt(2)]]
)
covariance_matrix = jnp.matmul(
_rot_mat, jnp.matmul(jnp.diag(jnp.array([1.0, 2.0])), _rot_mat.T)
)
locs = jnp.ones((3, 2))
# covariance_matrices = jnp.stack([jnp.eye(2), jnp.eye(2), jnp.eye(2)])
covariance_matrices = jnp.stack([covariance_matrix] * 3)
d = make_multivariate_normal_dist(loc, covariance_matrix)
print("double vmapping normal dist creation over first arg and out no second arg")
dist_axes = copy.deepcopy(d)
dist_axes.loc = None
dist_axes.scale_tril = 1
dist_axes.covariance_matrix = None
dist_axes.precision_matrix = None
vmap_once = jax.vmap(
make_multivariate_normal_dist, in_axes=(None, 0), out_axes=dist_axes
)
dist_axes2 = copy.deepcopy(d)
dist_axes2.loc = 1
dist_axes2.scale_tril = None
dist_axes2.covariance_matrix = None
dist_axes2.precision_matrix = None
vmap_twice = jax.vmap(vmap_once, in_axes=(0, None), out_axes=dist_axes2)
locs2 = jnp.ones((6, 2))
double_batched_d = vmap_twice(locs2, covariance_matrices)
samples_double_batched_dist = jax.vmap(
jax.vmap(sample, in_axes=(dist_axes,)), in_axes=(dist_axes2,)
)(double_batched_d)
assert samples_double_batched_dist.shape == (6, 3, 2) The pair of @fehiepsi let me explain |
A quick explanation of While the well-known structure of In particular, quoting this section of the
("tree prefix of the positional arguments tuple" is the compact takeaway to retain here for MultivariateNormal(loc=0, scale_tril=None) or MultivariateNormal(loc=None, scale_tril=1) Which are both valid tree prefixes (in that case, maximal tree prefixes, since the treedef of In these examples, Similarly as in Note that in the case where the
That flexible axis specification model is useful, but makes it very complicated to perform Hope that makes sense! |
(I tested the case discussed in #1529 (comment) in cf9f966) |
Wow, I learned a long way with vmap out_axes. Thanks for the detailed explanation. What if we just simply do
which means to set shapes to None whenever things are not working. This way is equivalent to doing nothing in the constructor, but potentially give us correct batch dimensions if possible (I'm not sure, I still don't understand why we need to play with tricky batch axes in vmap - but I guess you have some good applications for such support). |
No worries, this specification model is actually less known that it should be IMO.
The very exotic choices of In my use-case, my distributions are posteriorm parametrized by neural networks parameters, and typically, I often work with posteriors given different settings of the conditioned variable, but I don't want to duplicate the neural networks parameters, which are the same for each observation - that would lead to memory errors. Thus, I would want to set
About this solution: I like it more than the prior option since now, I am still personally quite strongly in favor of not doing any business logic apart from passthrough attribute restoration in There are multiple advantages to this:
I can concede that my |
So if I understand correctly, by skipping updating batch shape, vmapped distribution is no longer a distribution in the usual sense, but it needs to be used in subsequence vmap operators? (E.g., vmapped_dist.sample will not work because its implementation assumes the original batch shape). That seems like a limitation to me. If you want to use the "fix_vmap" pattern, how about storing something like _original_batch_shape in the unflattened dist, then fix_vmap method will fix batch shape using such information. I'm not sure how to implement fix_vmap with single (or nested) out axes though. infer_shapes is guaranteed to give us correct batch shape, event shape given parameter shapes. It is not clear to me how things look under special out axes. If things turn out to be wrong, we can resort to fix_vmap method I guess. |
This seems like an important application. Do you have some sketch in psuedo code for this example? |
Sure, here it is: https://github.com/pierreglaser/unle/blob/main/sbi_ebm/sbi_ebm/likelihood_trainer.py#L59-L159 A few important remarks:
|
Exactly.
Not in my opinion: vmapping functions change the nature of their return values. If a function Sometimes, it is possible to convert a vmapped object into some other version of that object. In the case of |
Now I see why those structured out axes are needed. Yeah, agree that we shouldn't broadcast things unless necessary. I can see why this saves memory in your case. I like the example about two views of 2D arrays. I agree both views make sense, but treating them as different objects might be unfortunate. If needed, we can introduce VmapDistribution to distinguish two interpretations. But let's see what we can do if we consider them to be the same object first. I guess we now agree that we can avoid populating the constructor by avoiding using it in the unflatten method. But we disagree on whether to compute batch shape during unflatten or just use the auxiliary data from flatten dist. Let's see pros and cons of each one (please correct me if I'm wrong)
What do you think? We might introduce VmapDistribution if inferred batch shape is different from the flatten/original batch shape. For both options, we can provide fix vmap utility if needed (I don't know how to implement it - seems tricky to me). |
That for sure, I believe this is a pure implementation detail question, and I'm fine with whatever solution you find more maintainable.
This would be very confusing from a user point of view, because the type of the output of the
The "no infer cons" and the "infer batch shape pros" should more be understood a pro/con of
I'm happy with that. I don't think it will be too tricky as long as we have access to the |
I think what i'm the most concerned about with changing
Just to be clear - I'm fine with enabling people to do post-hoc conversion of vmapped distribution so that they can treat it the way they want. I'm less fine with having that done automatically, since IMO it violates idiomatic |
Another possibility (not a big fan of it but maybe it'll help us to find a common ground), is to introduce a |
I was just trying to see if we can come up with a solution given different view points. Keeping original batch shape under vmap is surprised to me because to me, batch_shape of a distribution is like shape of a numpy array (you might disagree), so under vmap, the output would have new shape. Returning a new type of distribution under vmap is also surprised to me. I have been wanted to use
Yeah, I'm trying to see when things can go wrong. Does your code break if we update batch shape during flatten?
This is great! I guess that pattern will also work for other distributions. I hope we can come up with a solution that works for most distributions. If things are tricky, we can make a singledispatch function to have separate implementations for various distributions. It would be nice to put the implementation in a separate file because the current |
Just to clarify, you mean that Another nice thing about doing automatic |
Yes, the utility vmap will incorporate fixes if needed (the fixes would be fairly complicated to implement to me, given fn with arbitrary pytree input, output axes, but you would a clearer vision how it looks like). Regarding requirements, as long as it works for your application, it is good enough. A complicated system will be hard to maintain. Do you want to make an end-2-end PR or we split out the tasks? I can work on flatten/unflatten logic but don't have enough background to work on vmap fixes. Rather than adding new stuffs to Distribution, I think we can move batch-shape-fix method to a standable utility and dispatch the implementation for each distribution. Maybe we can put everything in |
Normal
instances w.r.t jax.vmap
operations.vmap
-able Distribution
s
OK @fehiepsi, A new unified pair of (un)flattening method for all distributions
Manual Batch shape resolving for
|
Brilliant! This solution resolves a lot of issues with attributes like Regarding |
Mmh, the issue is that
For that reason, we cannot just have the Does that make sense? |
I just meant that by default, we set it to None. In case it is None, we use |
I find it a bit confusing to have |
OK @fehiepsi, I implemented the changes you suggested (using |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes are super nice!!! I'll take a final pass at the expand logic tomorrow. The rest looks awesome to me.
🥳 |
@pierreglaser i've only been watching from afar but thanks for the heroic effort! not sure if it's directly related to this PR but it looks like we now have a flake8 error in mixtures.py
edit: this should be fixed in #1622 |
Does what the title say, for context, I am following up on my #1317 (comment), and figured it might be easier to discuss this issue over an implementation, so here it is.
In this PR,
Normal
overrides thetree_flatten
/tree_unflatten
method of theDistribution
base class by following the design of flax.struct.PyTreeNode, which implementstree_flatten
/tree_unflatten
fordataclass
-style classes:tree_flatten
builds the relevanttree_flatten
tuple by assigning each dataclass field either adata
or anaux
status.tree_unflatten
just callscls.__init__
, which, for dataclasses, is autogenerated, and only sets the attribute of thevmap
-ed object as given indata
andaux
, with zero (0) additional business logic.This design is important if we want full compatibility with the
vmap
axes specification model, in whichin/out_axes
can be any tree prefix (where the leaves areint
s orNone
) of the input/outputtreedef
, see this section of thejax
docs for instance.Under that model, the attributes of
vmap
-edDistribution
s will have a structure that may counternumpyro
's expectations: for instance, in the case ofNormal
instances, theloc
andscale
attributes may have non-broadastable shapes. Because of this, we cannot rely onnumpyro
sbatch_shape
inference mechanism (which assumes broadcastable attribute shapes) to auto-generate distribution'sbatch_shape
attribute, and passingbatch_shape
as auxiliary data duringtree_flatten
becomes necessary if we want to retain information about the non-vmappedNormal
batch_shape
information. Re-establishingbatch_shape
that actually matches the structure of thevmapped
distribution, were it to be understood as, for instance, a batched version of the original distribution, could be done optionally in ainfer_batch_shape
method ofDistribution
class:The implementation is rough: I used
in_vmap
instead ofvalidate_args
(whose role is to tell__init__
to not perform any attribute checking) to not spuriously override thevalidate_args
attribute of thevmapped
distribution toFalse
aftervmap
ing, but obviously,in_vmap
should never be part of the__init__
API. at this point, this PR is just a proof-of-concept/showcase to exemplify what I had in mind. But I don't see any reason as of why a polished version of this model could not be rolled out for allDistributions
andTransform
subclasses, which would be nice (at least for me!) to have.cc @fehiepsi