diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index ade4e9910..0f7b8136a 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -583,7 +583,7 @@ def _broadcast_shape(existing_shape, new_shape): ) return ( tuple(reversed(reversed_shape)), - OrderedDict(expanded_sizes), + OrderedDict(reversed(expanded_sizes)), OrderedDict(interstitial_sizes), ) @@ -601,6 +601,8 @@ def _sample(self, sample_fn, key, sample_shape=()): batch_shape = expanded_sizes + interstitial_sizes # shape = sample_shape + expanded_sizes + interstitial_sizes + base_dist.shape() samples, intermediates = sample_fn(key, sample_shape=sample_shape + batch_shape) + if not interstitial_sizes: + return samples, intermediates interstitial_dims = tuple(self._interstitial_sizes.keys()) event_dim = len(self.event_shape)