Make MultivariateNormal compatible with vmap
pierreglaser committed Jan 31, 2023
1 parent 22cd9e3 commit fc96664
Showing 2 changed files with 259 additions and 23 deletions.
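
For context, this is the kind of usage the commit is meant to enable: constructing a MultivariateNormal inside jax.vmap and then sampling from the resulting batched distribution. The sketch below is adapted from the new test added at the bottom of this diff; the helper names make_mvn and sample are illustrative, not library API.

import jax
import jax.numpy as jnp
from jax import random

import numpyro.distributions as dist


def make_mvn(loc, covariance_matrix) -> dist.MultivariateNormal:
    # Construction now survives vmap: tree_unflatten rebuilds the instance with
    # in_vmap=True instead of re-running shape promotion on traced leaves.
    return dist.MultivariateNormal(loc=loc, covariance_matrix=covariance_matrix)


def sample(d: dist.Distribution):
    return d.sample(random.PRNGKey(0))


locs = jnp.ones((3, 2))             # batch of 3 mean vectors in R^2
covs = jnp.stack([jnp.eye(2)] * 3)  # batch of 3 covariance matrices, shape (3, 2, 2)

# vmap over both constructor arguments, then over the resulting distribution pytree.
batched_d = jax.vmap(make_mvn, in_axes=(0, 0))(locs, covs)
samples = jax.vmap(sample, in_axes=(0,))(batched_d)
assert samples.shape == (3, 2)
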
112 changes: 89 additions & 23 deletions numpyro/distributions/continuous.py
@@ -1403,29 +1403,40 @@ def __init__(
         precision_matrix=None,
         scale_tril=None,
         validate_args=None,
+        batch_shape=None,
+        event_shape=None,
+        in_vmap=False,
     ):
-        if jnp.ndim(loc) == 0:
-            (loc,) = promote_shapes(loc, shape=(1,))
-        # temporary append a new axis to loc
-        loc = loc[..., jnp.newaxis]
-        if covariance_matrix is not None:
-            loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
-            self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
-        elif precision_matrix is not None:
-            loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
-            self.scale_tril = cholesky_of_inverse(self.precision_matrix)
-        elif scale_tril is not None:
-            loc, self.scale_tril = promote_shapes(loc, scale_tril)
-        else:
-            raise ValueError(
-                "One of `covariance_matrix`, `precision_matrix`, `scale_tril`"
-                " must be specified."
+        if not in_vmap:
+            if jnp.ndim(loc) == 0:
+                (loc,) = promote_shapes(loc, shape=(1,))
+            # temporary append a new axis to loc
+            loc = loc[..., jnp.newaxis]
+            if covariance_matrix is not None:
+                loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix)
+                self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix)
+            elif precision_matrix is not None:
+                loc, self.precision_matrix = promote_shapes(loc, precision_matrix)
+                self.scale_tril = cholesky_of_inverse(self.precision_matrix)
+            elif scale_tril is not None:
+                loc, self.scale_tril = promote_shapes(loc, scale_tril)
+            else:
+                raise ValueError(
+                    "One of `covariance_matrix`, `precision_matrix`, `scale_tril`"
+                    " must be specified."
+                )
-        batch_shape = lax.broadcast_shapes(
-            jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2]
-        )
+            batch_shape = lax.broadcast_shapes(
+                jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2]
+            )
-        event_shape = jnp.shape(self.scale_tril)[-1:]
-        self.loc = loc[..., 0]
+            event_shape = jnp.shape(self.scale_tril)[-1:]
+            self.loc = loc[..., 0]
+        else:
+            assert batch_shape is not None
+            assert event_shape is not None
+            # these arguments are always resolved before `vmap`ping.
+            self.scale_tril = scale_tril
+            self.loc = loc

         super(MultivariateNormal, self).__init__(
             batch_shape=batch_shape,
             event_shape=event_shape,
@@ -1473,13 +1484,68 @@ def variance(self):
             jnp.sum(self.scale_tril**2, axis=-1), self.batch_shape + self.event_shape
         )

+    def infer_post_vmap_shapes(self, vmap_axes):
+        """
+        Transform a `vmap`-ed `MultivariateNormal` mapped according to
+        :param:`vmap_axes` into a batched `MultivariateNormal`.
+        .. note:: The vmapped axis turned into a batch axis is placed
+            at the leftmost position of the batch shape.
+        """
+        # TODO: take into account the case of multiple vmap transformations
+        # in a row by allowing `vmap_axes` to be a tuple of tree prefixes.
+        assert isinstance(vmap_axes, MultivariateNormal)
+        # handle loc
+        scale_tril_shape = self.scale_tril.shape
+        loc_vmap_axis = vmap_axes.loc
+        scale_tril_vmap_axis = vmap_axes.scale_tril
+        pre_vmap_batch_shape = self._batch_shape
+        pre_vmap_event_shape = self._event_shape
+
+        if loc_vmap_axis is None:
+            new_loc = self.loc
+        else:
+            assert isinstance(loc_vmap_axis, int)
+            new_loc = jnp.moveaxis(
+                self.loc,
+                source=tuple(range(len(self.loc.shape))),
+                destination=(loc_vmap_axis,)
+                + tuple(t for t in range(len(self.loc.shape)) if t != loc_vmap_axis),
+            )
+
+        if scale_tril_vmap_axis is None:
+            new_scale_tril = self.scale_tril
+        else:
+            assert isinstance(scale_tril_vmap_axis, int)
+            new_scale_tril = jnp.moveaxis(
+                self.scale_tril,
+                source=tuple(range(len(scale_tril_shape))),
+                destination=(scale_tril_vmap_axis,)
+                + tuple(
+                    t for t in range(len(scale_tril_shape)) if t != scale_tril_vmap_axis
+                ),
+            )
+
+        return MultivariateNormal(loc=new_loc, scale_tril=new_scale_tril)
+
     def tree_flatten(self):
-        return (self.loc, self.scale_tril), None
+        return (
+            self.loc,
+            self.scale_tril,
+        ), (self.batch_shape, self.event_shape)

     @classmethod
     def tree_unflatten(cls, aux_data, params):
         loc, scale_tril = params
-        return cls(loc, scale_tril=scale_tril)
+        batch_shape, event_shape = aux_data
+        return cls(
+            loc=loc,
+            scale_tril=scale_tril,
+            batch_shape=batch_shape,
+            event_shape=event_shape,
+            in_vmap=True,
+        )

     @staticmethod
     def infer_shapes(
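
To make the shape bookkeeping of infer_post_vmap_shapes concrete, here is a condensed sketch of the scenario exercised in the new test below: scale_tril is mapped out along axis 1, and the helper then moves that axis back to the leftmost batch position to recover an ordinary batched distribution. This is a sketch of the behaviour asserted by the test, with illustrative variable names, not additional API.

import copy

import jax
import jax.numpy as jnp

import numpyro.distributions as dist

loc = jnp.ones((2,))
covs = jnp.stack([jnp.eye(2)] * 3)  # (3, 2, 2)
d = dist.MultivariateNormal(loc=loc, covariance_matrix=jnp.eye(2))

# Pytree of axes: only the scale_tril leaf carries the mapped axis, at position 1.
dist_axes = copy.deepcopy(d)
dist_axes.loc = None
dist_axes.scale_tril = 1
dist_axes.covariance_matrix = None
dist_axes.precision_matrix = None

batched_d = jax.vmap(
    lambda c: dist.MultivariateNormal(loc=loc, covariance_matrix=c),
    in_axes=(0,),
    out_axes=dist_axes,
)(covs)
# The mapped axis now sits in the middle of scale_tril: shape (2, 3, 2).

batched = batched_d.infer_post_vmap_shapes(dist_axes)
# The mapped axis is moved to the front of the batch shape, and loc (which was not
# mapped) picks up a broadcast batch dimension of size 1: shape (1, 2).
assert batched.batch_shape == (3,)
assert batched.event_shape == (2,)
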
170 changes: 170 additions & 0 deletions test/test_distributions.py
@@ -2832,3 +2832,173 @@ def sample(d: dist.Distribution):

     samples_batched_dist = jax.vmap(sample, in_axes=(dist_axes,))(batched_d)
     assert samples_batched_dist.shape == (3, 2)
+
+
+def test_vmap_multivariate_normal_dist():
+    def make_multivariate_normal_dist(
+        mean, covariance_matrix
+    ) -> dist.MultivariateNormal:
+        d = dist.MultivariateNormal(loc=mean, covariance_matrix=covariance_matrix)
+        return d
+
+    def sample(d: dist.Distribution):
+        return d.sample(random.PRNGKey(0))
+
+    loc = jnp.ones((2,))
+    # covariance_matrix = jnp.eye(2)
+    _rot_mat = jnp.array(
+        [[1 / jnp.sqrt(2), -1 / jnp.sqrt(2)], [1 / jnp.sqrt(2), 1 / jnp.sqrt(2)]]
+    )
+    covariance_matrix = jnp.matmul(
+        _rot_mat, jnp.matmul(jnp.diag(jnp.array([1.0, 2.0])), _rot_mat.T)
+    )
+
+    d = make_multivariate_normal_dist(loc, covariance_matrix)
+
+    locs = jnp.ones((3, 2))
+    # covariance_matrices = jnp.stack([jnp.eye(2), jnp.eye(2), jnp.eye(2)])
+    covariance_matrices = jnp.stack([covariance_matrix] * 3)
+
+    assert loc.shape == d.loc.shape
+    assert covariance_matrix.shape == d.covariance_matrix.shape
+
+    print("vmapping normal dist creation over both args")
+    batched_d = jax.vmap(make_multivariate_normal_dist, in_axes=(0, 0))(
+        locs, covariance_matrices
+    )
+
+    assert locs.shape == batched_d.loc.shape
+    assert covariance_matrices.shape == batched_d.covariance_matrix.shape
+    assert covariance_matrices.shape == batched_d.scale_tril.shape
+    assert covariance_matrices.shape == batched_d.precision_matrix.shape
+
+    assert batched_d.batch_shape == d.batch_shape
+    assert batched_d.event_shape == d.event_shape
+
+    samples_batched_dist = jax.vmap(sample, in_axes=(0,))(batched_d)
+    assert samples_batched_dist.shape == (3, 2)
+
+    print("vmapping normal dist creation over first arg")
+    batched_d = jax.vmap(make_multivariate_normal_dist, in_axes=(0, None))(
+        locs, covariance_matrix
+    )
+    samples_batched_dist = jax.vmap(sample, in_axes=(0,))(batched_d)
+
+    assert locs.shape == batched_d.loc.shape
+    assert covariance_matrices.shape == batched_d.covariance_matrix.shape
+    assert covariance_matrices.shape == batched_d.scale_tril.shape
+    assert covariance_matrices.shape == batched_d.precision_matrix.shape
+
+    assert batched_d.batch_shape == d.batch_shape
+    assert batched_d.event_shape == d.event_shape
+
+    assert samples_batched_dist.shape == (3, 2)
+
+    print("vmapping normal dist creation over second arg")
+    batched_d = jax.vmap(make_multivariate_normal_dist, in_axes=(None, 0))(
+        loc, covariance_matrices
+    )
+    samples_batched_dist = jax.vmap(sample, in_axes=(0,))(batched_d)
+
+    assert locs.shape == batched_d.loc.shape
+    assert covariance_matrices.shape == batched_d.covariance_matrix.shape
+    assert covariance_matrices.shape == batched_d.scale_tril.shape
+    assert covariance_matrices.shape == batched_d.precision_matrix.shape
+
+    assert batched_d.batch_shape == d.batch_shape
+    assert batched_d.event_shape == d.event_shape
+
+    assert samples_batched_dist.shape == (3, 2)
+
+    print("vmapping normal dist creation over first arg and out first arg")
+    dist_axes = copy.deepcopy(d)
+    dist_axes.loc = 0
+    dist_axes.scale_tril = None
+    dist_axes.covariance_matrix = None
+    dist_axes.precision_matrix = None
+
+    batched_d = jax.vmap(
+        make_multivariate_normal_dist, in_axes=(0, None), out_axes=dist_axes
+    )(locs, covariance_matrix)
+    samples_batched_dist = jax.vmap(sample, in_axes=(dist_axes,))(batched_d)
+
+    assert locs.shape == batched_d.loc.shape
+    assert covariance_matrix.shape == batched_d.covariance_matrix.shape
+    assert covariance_matrix.shape == batched_d.scale_tril.shape
+    assert covariance_matrix.shape == batched_d.precision_matrix.shape
+
+    assert batched_d.batch_shape == d.batch_shape
+    assert batched_d.event_shape == d.event_shape
+
+    assert samples_batched_dist.shape == (3, 2)
+
+    print(
+        "vmapping normal dist creation over second arg and out second arg with "
+        "out_axes=1"
+    )
+    dist_axes = copy.deepcopy(d)
+    dist_axes.loc = None
+    dist_axes.scale_tril = 1
+    dist_axes.covariance_matrix = None
+    dist_axes.precision_matrix = None
+
+    batched_d = jax.vmap(
+        make_multivariate_normal_dist, in_axes=(None, 0), out_axes=dist_axes
+    )(loc, covariance_matrices)
+
+    samples_batched_dist = jax.vmap(sample, in_axes=(dist_axes,))(batched_d)
+    assert samples_batched_dist.shape == (3, 2)
+
+    assert batched_d.batch_shape == d.batch_shape
+    assert batched_d.event_shape == d.event_shape
+
+    assert loc.shape == batched_d.loc.shape
+    assert covariance_matrices.swapaxes(0, 1).shape == batched_d.scale_tril.shape
+
+    # accessing property-based arguments should work fine when wrapped inside a `vmap`
+    # transformation
+    assert (
+        covariance_matrices.swapaxes(0, 1).shape
+        == jax.vmap(lambda bd: bd.covariance_matrix, in_axes=(dist_axes,), out_axes=1)(
+            batched_d
+        ).shape
+    )
+    assert (
+        covariance_matrices.swapaxes(0, 1).shape
+        == jax.vmap(lambda bd: bd.precision_matrix, in_axes=(dist_axes,), out_axes=1)(
+            batched_d
+        ).shape
+    )
+
+    # However, non-wrapped property-based attribute access may fail to evaluate outside
+    # of the `vmapped` context
+    try:
+        assert (
+            covariance_matrices.swapaxes(0, 1).shape
+            == batched_d.covariance_matrix.shape
+        )
+        assert (
+            covariance_matrices.swapaxes(0, 1).shape == batched_d.precision_matrix.shape
+        )
+    except Exception:
+        pass
+
+    print("testing transformation of `vmap`-ed distribution into a batched distribution")
+    # turn the `vmap`-ed distribution into a batched distribution
+    batched_d_infered = batched_d.infer_post_vmap_shapes(dist_axes)
+
+    # the transformation mechanism behind `infer_post_vmap_shapes` broadcasts
+    # loc/scale_tril to a common shape. loc was not vmapped: account for
+    # post-inference implicit broadcasting by adding a leading dimension of size 1
+    assert (1,) + loc.shape == batched_d_infered.loc.shape
+
+    # after inference, both property and non-property based attribute access should
+    # work outside of the `vmapped` context
+    assert covariance_matrices.shape == batched_d_infered.covariance_matrix.shape
+    assert covariance_matrices.shape == batched_d_infered.scale_tril.shape
+    assert covariance_matrices.shape == batched_d_infered.precision_matrix.shape
+
+    assert batched_d_infered.batch_shape == (3,)
+    assert batched_d_infered.event_shape == (2,)
+
+    # TODO: test application of multiple `vmap` transformations.