Skip to content

Commit

Permalink
address comment: apply transpose at one time
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Mar 29, 2021
1 parent b9d0edb commit 0837ee4
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,10 @@ def _sample(self, sample_fn, key, sample_shape=()):
batch_ndims = jnp.ndim(samples) - event_dim
interstitial_dims = tuple(batch_ndims + i for i in interstitial_dims)
interstitial_idx = len(sample_shape) + len(expanded_sizes)
interstitial_sample_dims = tuple(
range(interstitial_idx, interstitial_idx + len(interstitial_sizes))
)
interstitial_sample_dims = range(interstitial_idx, interstitial_idx + len(interstitial_dims))
permutation = list(range(batch_ndims))
for dim1, dim2 in zip(interstitial_dims, interstitial_sample_dims):
permutation[dim1], permutation[dim2] = permutation[dim2], permutation[dim1]

def reshape_sample(x):
"""
Expand All @@ -520,8 +521,7 @@ def reshape_sample(x):
of size 1 in the original batch_shape of base_dist with those
in the expanded dims.
"""
for dim1, dim2 in zip(interstitial_dims, interstitial_sample_dims):
x = jnp.swapaxes(x, dim1, dim2)
x = jnp.transpose(x, permutation + list(range(batch_ndims, jnp.ndim(x))))
event_shape = jnp.shape(x)[batch_ndims:]
return x.reshape(sample_shape + self.batch_shape + event_shape)

Expand Down

0 comments on commit 0837ee4

Please sign in to comment.