From fc96664d33625784a57a787a677f4c5e92fd3070 Mon Sep 17 00:00:00 2001 From: Pierre Glaser Date: Tue, 31 Jan 2023 00:04:22 +0000 Subject: [PATCH] Make `MultivariateNormal` compatible with `vmap` --- numpyro/distributions/continuous.py | 112 ++++++++++++++---- test/test_distributions.py | 170 ++++++++++++++++++++++++++++ 2 files changed, 259 insertions(+), 23 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 0111dcdfb..b732c0bd6 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -1403,29 +1403,40 @@ def __init__( precision_matrix=None, scale_tril=None, validate_args=None, + batch_shape=None, + event_shape=None, + in_vmap=False, ): - if jnp.ndim(loc) == 0: - (loc,) = promote_shapes(loc, shape=(1,)) - # temporary append a new axis to loc - loc = loc[..., jnp.newaxis] - if covariance_matrix is not None: - loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix) - self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix) - elif precision_matrix is not None: - loc, self.precision_matrix = promote_shapes(loc, precision_matrix) - self.scale_tril = cholesky_of_inverse(self.precision_matrix) - elif scale_tril is not None: - loc, self.scale_tril = promote_shapes(loc, scale_tril) - else: - raise ValueError( - "One of `covariance_matrix`, `precision_matrix`, `scale_tril`" - " must be specified." 
+ if not in_vmap: + if jnp.ndim(loc) == 0: + (loc,) = promote_shapes(loc, shape=(1,)) + # temporary append a new axis to loc + loc = loc[..., jnp.newaxis] + if covariance_matrix is not None: + loc, self.covariance_matrix = promote_shapes(loc, covariance_matrix) + self.scale_tril = jnp.linalg.cholesky(self.covariance_matrix) + elif precision_matrix is not None: + loc, self.precision_matrix = promote_shapes(loc, precision_matrix) + self.scale_tril = cholesky_of_inverse(self.precision_matrix) + elif scale_tril is not None: + loc, self.scale_tril = promote_shapes(loc, scale_tril) + else: + raise ValueError( + "One of `covariance_matrix`, `precision_matrix`, `scale_tril`" + " must be specified." + ) + batch_shape = lax.broadcast_shapes( + jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2] ) - batch_shape = lax.broadcast_shapes( - jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2] - ) - event_shape = jnp.shape(self.scale_tril)[-1:] - self.loc = loc[..., 0] + event_shape = jnp.shape(self.scale_tril)[-1:] + self.loc = loc[..., 0] + else: + assert batch_shape is not None + assert event_shape is not None + # these arguments are always resolved before `vmap`ping. + self.scale_tril = scale_tril + self.loc = loc + super(MultivariateNormal, self).__init__( batch_shape=batch_shape, event_shape=event_shape, @@ -1473,13 +1484,68 @@ def variance(self): jnp.sum(self.scale_tril**2, axis=-1), self.batch_shape + self.event_shape ) + def infer_post_vmap_shapes(self, vmap_axes): + """ + Transform a `vmap`-ed `MultivariateNormal` mapped according to + :param:`vmap_axes` into a batched `MultivariateNormal`. + + + .. note:: The vmapped axis turned into a batch axis is placed + at the leftmost position of the batch shape. + """ + # TODO: take into account the case of multiple vmap transformations + # in a row by allowing `vmap_axes` to be a tuple of tree prefixes. 
+ assert isinstance(vmap_axes, MultivariateNormal) + # handle loc + scale_tril_shape = self.scale_tril.shape + loc_vmap_axis = vmap_axes.loc + scale_tril_vmap_axis = vmap_axes.scale_tril + pre_vmap_batch_shape = self._batch_shape + pre_vmap_event_shape = self._event_shape + + if loc_vmap_axis is None: + new_loc = self.loc + else: + assert isinstance(loc_vmap_axis, int) + new_loc = jnp.moveaxis( + self.loc, + source=tuple(range(len(self.loc.shape))), + destination=(loc_vmap_axis,) + + tuple(t for t in range(len(self.loc.shape)) if t != loc_vmap_axis), + ) + + if scale_tril_vmap_axis is None: + new_scale_tril = self.scale_tril + else: + assert isinstance(scale_tril_vmap_axis, int) + new_scale_tril = jnp.moveaxis( + self.scale_tril, + source=tuple(range(len(scale_tril_shape))), + destination=(scale_tril_vmap_axis,) + + tuple( + t for t in range(len(scale_tril_shape)) if t != scale_tril_vmap_axis + ), + ) + + return MultivariateNormal(loc=new_loc, scale_tril=new_scale_tril) + def tree_flatten(self): - return (self.loc, self.scale_tril), None + return ( + self.loc, + self.scale_tril, + ), (self.batch_shape, self.event_shape) @classmethod def tree_unflatten(cls, aux_data, params): loc, scale_tril = params - return cls(loc, scale_tril=scale_tril) + batch_shape, event_shape = aux_data + return cls( + loc=loc, + scale_tril=scale_tril, + batch_shape=batch_shape, + event_shape=event_shape, + in_vmap=True, + ) @staticmethod def infer_shapes( diff --git a/test/test_distributions.py b/test/test_distributions.py index 6d4c95a79..195fdd585 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -2832,3 +2832,173 @@ def sample(d: dist.Distribution): samples_batched_dist = jax.vmap(sample, in_axes=(dist_axes,))(batched_d) assert samples_batched_dist.shape == (3, 2) + + +def test_vmap_multivariate_normal_dist(): + def make_multivariate_normal_dist( + mean, covariance_matrix + ) -> dist.MultivariateNormal: + d = dist.MultivariateNormal(loc=mean, 
covariance_matrix=covariance_matrix) + return d + + def sample(d: dist.Distribution): + return d.sample(random.PRNGKey(0)) + + loc = jnp.ones((2,)) + # covariance_matrix = jnp.eye(2) + _rot_mat = jnp.array( + [[1 / jnp.sqrt(2), -1 / jnp.sqrt(2)], [1 / jnp.sqrt(2), 1 / jnp.sqrt(2)]] + ) + covariance_matrix = jnp.matmul( + _rot_mat, jnp.matmul(jnp.diag(jnp.array([1.0, 2.0])), _rot_mat.T) + ) + + d = make_multivariate_normal_dist(loc, covariance_matrix) + + locs = jnp.ones((3, 2)) + # covariance_matrices = jnp.stack([jnp.eye(2), jnp.eye(2), jnp.eye(2)]) + covariance_matrices = jnp.stack([covariance_matrix] * 3) + + assert loc.shape == d.loc.shape + assert covariance_matrix.shape == d.covariance_matrix.shape + + print("vmapping normal dist creation over both args") + batched_d = jax.vmap(make_multivariate_normal_dist, in_axes=(0, 0))( + locs, covariance_matrices + ) + + assert locs.shape == batched_d.loc.shape + assert covariance_matrices.shape == batched_d.covariance_matrix.shape + assert covariance_matrices.shape == batched_d.scale_tril.shape + assert covariance_matrices.shape == batched_d.precision_matrix.shape + + assert batched_d.batch_shape == d.batch_shape + assert batched_d.event_shape == d.event_shape + + samples_batched_dist = jax.vmap(sample, in_axes=(0,))(batched_d) + assert samples_batched_dist.shape == (3, 2) + + print("vmapping normal dist creation over first arg") + batched_d = jax.vmap(make_multivariate_normal_dist, in_axes=(0, None))( + locs, covariance_matrix + ) + samples_batched_dist = jax.vmap(sample, in_axes=(0,))(batched_d) + + assert locs.shape == batched_d.loc.shape + assert covariance_matrices.shape == batched_d.covariance_matrix.shape + assert covariance_matrices.shape == batched_d.scale_tril.shape + assert covariance_matrices.shape == batched_d.precision_matrix.shape + + assert batched_d.batch_shape == d.batch_shape + assert batched_d.event_shape == d.event_shape + + assert samples_batched_dist.shape == (3, 2) + + print("vmapping normal 
dist creation over second arg") + batched_d = jax.vmap(make_multivariate_normal_dist, in_axes=(None, 0))( + loc, covariance_matrices + ) + samples_batched_dist = jax.vmap(sample, in_axes=(0,))(batched_d) + + assert locs.shape == batched_d.loc.shape + assert covariance_matrices.shape == batched_d.covariance_matrix.shape + assert covariance_matrices.shape == batched_d.scale_tril.shape + assert covariance_matrices.shape == batched_d.precision_matrix.shape + + assert batched_d.batch_shape == d.batch_shape + assert batched_d.event_shape == d.event_shape + + assert samples_batched_dist.shape == (3, 2) + + print("vmapping normal dist creation over first arg and out first arg") + dist_axes = copy.deepcopy(d) + dist_axes.loc = 0 + dist_axes.scale_tril = None + dist_axes.covariance_matrix = None + dist_axes.precision_matrix = None + + batched_d = jax.vmap( + make_multivariate_normal_dist, in_axes=(0, None), out_axes=dist_axes + )(locs, covariance_matrix) + samples_batched_dist = jax.vmap(sample, in_axes=(dist_axes,))(batched_d) + + assert locs.shape == batched_d.loc.shape + assert covariance_matrix.shape == batched_d.covariance_matrix.shape + assert covariance_matrix.shape == batched_d.scale_tril.shape + assert covariance_matrix.shape == batched_d.precision_matrix.shape + + assert batched_d.batch_shape == d.batch_shape + assert batched_d.event_shape == d.event_shape + + assert samples_batched_dist.shape == (3, 2) + + print( + "vmapping normal dist creation over second arg and out second arg with " + "out_axes=1" + ) + dist_axes = copy.deepcopy(d) + dist_axes.loc = None + dist_axes.scale_tril = 1 + dist_axes.covariance_matrix = None + dist_axes.precision_matrix = None + + batched_d = jax.vmap( + make_multivariate_normal_dist, in_axes=(None, 0), out_axes=dist_axes + )(loc, covariance_matrices) + + samples_batched_dist = jax.vmap(sample, in_axes=(dist_axes,))(batched_d) + assert samples_batched_dist.shape == (3, 2) + + assert batched_d.batch_shape == d.batch_shape + assert 
batched_d.event_shape == d.event_shape + + assert loc.shape == batched_d.loc.shape + assert covariance_matrices.swapaxes(0, 1).shape == batched_d.scale_tril.shape + + # accessing property-based arguments should work fine when wrapped inside a `vmap` + # transformation + assert ( + covariance_matrices.swapaxes(0, 1).shape + == jax.vmap(lambda bd: bd.covariance_matrix, in_axes=(dist_axes,), out_axes=1)( + batched_d + ).shape + ) + assert ( + covariance_matrices.swapaxes(0, 1).shape + == jax.vmap(lambda bd: bd.precision_matrix, in_axes=(dist_axes,), out_axes=1)( + batched_d + ).shape + ) + + # However, non-wrapped property-based attribute access may fail to evaluate outside + # of the `vmapped` context + try: + assert ( + covariance_matrices.swapaxes(0, 1).shape + == batched_d.covariance_matrix.shape + ) + assert ( + covariance_matrices.swapaxes(0, 1).shape == batched_d.precision_matrix.shape + ) + except Exception: + pass + + print("testing transformation of `vmap`-ed distibution into a batched disribution") + # turn the `vmap`-ed distribution into a batched distribution + batched_d_infered = batched_d.infer_post_vmap_shapes(dist_axes) + + # the transformation mechanism behind `infer_post_vmap_shapes` broadcasts + # loc/scale_tril to a common shape; loc was not vmapped: account for + # post-inference implicit broadcasting by adding a leading dimension of size 1 + assert (1,) + loc.shape == batched_d_infered.loc.shape + + # after inference, both property and non-property based attribute access should + # work outside of the `vmapped` context + assert covariance_matrices.shape == batched_d_infered.covariance_matrix.shape + assert covariance_matrices.shape == batched_d_infered.scale_tril.shape + assert covariance_matrices.shape == batched_d_infered.precision_matrix.shape + + assert batched_d_infered.batch_shape == (3,) + assert batched_d_infered.event_shape == (2,) + + # TODO: test application of multiple `vmap` transformations.