Skip to content

Commit

Permalink
Fix a regression bug for ExpandedDistribution (#972)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Mar 29, 2021
1 parent 7493e35 commit b6acb19
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
27 changes: 19 additions & 8 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,17 +501,28 @@ def _sample(self, sample_fn, key, sample_shape=()):
interstitial_sizes = tuple(self._interstitial_sizes.values())
expanded_sizes = tuple(self._expanded_sizes.values())
batch_shape = expanded_sizes + interstitial_sizes
# shape = sample_shape + expanded_sizes + interstitial_sizes + base_dist.shape()
samples, intermediates = sample_fn(key, sample_shape=sample_shape + batch_shape)

interstitial_dims = tuple(self._interstitial_sizes.keys())
event_dim = len(self.event_shape)
batch_ndims = jnp.ndim(samples) - event_dim
interstitial_dims = tuple(batch_ndims + i for i in interstitial_dims)
interstitial_idx = len(sample_shape) + len(expanded_sizes)
interstitial_sample_dims = range(interstitial_idx, interstitial_idx + len(interstitial_dims))
permutation = list(range(batch_ndims))
for dim1, dim2 in zip(interstitial_dims, interstitial_sample_dims):
permutation[dim1], permutation[dim2] = permutation[dim2], permutation[dim1]

def reshape_sample(x):
""" Reshapes samples and intermediates to ensure that the output
shape is correct: This implicitly replaces the interstitial dims
of size 1 in the original batch_shape of base_dist with those
in the expanded dims. While it somewhat 'shuffles' over batch
dimensions, we don't care because they are considered independent."""
subshape = x.shape[len(sample_shape) + len(batch_shape):]
# subshape == base_dist.batch_shape + event_shape of x (latter unknown for intermediates)
event_shape = subshape[len(self.base_dist.batch_shape):]
"""
Reshapes samples and intermediates to ensure that the output
shape is correct: This implicitly replaces the interstitial dims
of size 1 in the original batch_shape of base_dist with those
in the expanded dims.
"""
x = jnp.transpose(x, permutation + list(range(batch_ndims, jnp.ndim(x))))
event_shape = jnp.shape(x)[batch_ndims:]
return x.reshape(sample_shape + self.batch_shape + event_shape)

intermediates = tree_util.tree_map(reshape_sample, intermediates)
Expand Down
14 changes: 14 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,20 @@ def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape):
assert expanded_dist.expand((3,) + jax_dist.batch_shape)


@pytest.mark.parametrize('base_shape', [(2, 1, 5), (3, 1), (2, 1, 1), (1, 1, 5)])
@pytest.mark.parametrize('event_dim', [0, 1, 2, 3])
@pytest.mark.parametrize('sample_shape', [(1000,), (1000, 7, 1), (1000, 1, 7)])
def test_expand_shuffle_regression(base_shape, event_dim, sample_shape):
    """Check that samples from an expanded Normal stay aligned with its locs.

    A Normal with random locations of shape ``base_shape`` and scale 0.1 is
    converted with ``to_event`` and then expanded to the leading part of
    ``(2, 3, 5)``. Samples are drawn with ``sample_shape`` and averaged over
    the first sample axis; the result must match the locations broadcast to
    ``sample_shape[1:] + expanded.shape()`` within an absolute tolerance
    of 0.1.
    """
    full_shape = (2, 3, 5)
    # The requested number of event dims is capped at the rank of the base shape.
    n_event = min(event_dim, len(base_shape))
    n_batch = len(full_shape) - n_event

    # Locations are scaled by 10 -- presumably so that entries landing in the
    # wrong position differ by far more than the 0.1 tolerance.
    means = random.normal(random.PRNGKey(0), base_shape) * 10
    expanded = dist.Normal(means, 0.1).to_event(n_event).expand(full_shape[:n_batch])

    draws = expanded.sample(random.PRNGKey(1), sample_shape)
    target = jnp.broadcast_to(means, sample_shape[1:] + expanded.shape())
    assert_allclose(draws.mean(0), target, atol=0.1)


@pytest.mark.parametrize('batch_shape', [
(),
(4,),
Expand Down

0 comments on commit b6acb19

Please sign in to comment.