diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index ade4e9910..8bda77b18 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -1138,6 +1138,8 @@ def support(self): return constraints.independent(constraints.real, self.event_dim) def sample(self, key, sample_shape=()): + if not sample_shape: + return self.v shape = sample_shape + self.batch_shape + self.event_shape return jnp.broadcast_to(self.v, shape) diff --git a/test/test_distributions.py b/test/test_distributions.py index 6ebb990f2..2f7170d42 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1102,15 +1102,16 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)): @pytest.mark.parametrize( - "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL + "jax_dist_cls, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL ) @pytest.mark.parametrize("prepend_shape", [(), (2,), (2, 3)]) -def test_dist_shape(jax_dist, sp_dist, params, prepend_shape): - jax_dist = jax_dist(*params) +def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape): + jax_dist = jax_dist_cls(*params) rng_key = random.PRNGKey(0) expected_shape = prepend_shape + jax_dist.batch_shape + jax_dist.event_shape samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape) - assert isinstance(samples, jnp.ndarray) + if jax_dist_cls is not dist.Delta: + assert isinstance(samples, jnp.ndarray) assert jnp.shape(samples) == expected_shape if ( sp_dist @@ -2620,7 +2621,7 @@ def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape): rng_key = random.PRNGKey(0) samples = expanded_dist.sample(rng_key, sample_shape) assert expanded_dist.batch_shape == new_batch_shape - assert samples.shape == sample_shape + new_batch_shape + jax_dist.event_shape + assert jnp.shape(samples) == sample_shape + new_batch_shape + jax_dist.event_shape assert expanded_dist.log_prob(samples).shape == sample_shape + new_batch_shape # test expand of expand assert (