From f204d53576722264153bf1e5c23619bff585cfc0 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 10 Apr 2024 15:00:04 -0400 Subject: [PATCH 1/3] allow to use Delta on numpy arrays without moving them to devices --- numpyro/distributions/distribution.py | 2 ++ test/test_distributions.py | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index ade4e9910..8bda77b18 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -1138,6 +1138,8 @@ def support(self): return constraints.independent(constraints.real, self.event_dim) def sample(self, key, sample_shape=()): + if not sample_shape: + return self.v shape = sample_shape + self.batch_shape + self.event_shape return jnp.broadcast_to(self.v, shape) diff --git a/test/test_distributions.py b/test/test_distributions.py index 6ebb990f2..446fc436f 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1110,7 +1110,8 @@ def test_dist_shape(jax_dist, sp_dist, params, prepend_shape): rng_key = random.PRNGKey(0) expected_shape = prepend_shape + jax_dist.batch_shape + jax_dist.event_shape samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape) - assert isinstance(samples, jnp.ndarray) + if not (jax_dist is dist.Delta): + assert isinstance(samples, jnp.ndarray) assert jnp.shape(samples) == expected_shape if ( sp_dist @@ -2620,7 +2621,7 @@ def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape): rng_key = random.PRNGKey(0) samples = expanded_dist.sample(rng_key, sample_shape) assert expanded_dist.batch_shape == new_batch_shape - assert samples.shape == sample_shape + new_batch_shape + jax_dist.event_shape + assert jnp.shape(samples) == sample_shape + new_batch_shape + jax_dist.event_shape assert expanded_dist.log_prob(samples).shape == sample_shape + new_batch_shape # test expand of expand assert ( From a4f118fe8862909e4d7438cfa4be550b43d28d6d Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 10 Apr 2024 15:19:07 -0400 Subject: [PATCH 2/3] fix lint --- test/test_distributions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 446fc436f..dabf02095 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1110,7 +1110,7 @@ def test_dist_shape(jax_dist, sp_dist, params, prepend_shape): rng_key = random.PRNGKey(0) expected_shape = prepend_shape + jax_dist.batch_shape + jax_dist.event_shape samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape) - if not (jax_dist is dist.Delta): + if jax_dist is not dist.Delta: assert isinstance(samples, jnp.ndarray) assert jnp.shape(samples) == expected_shape if ( From a58cdc1fa4b070c3d5221629438f4ebf19e1f402 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Wed, 10 Apr 2024 16:38:20 -0400 Subject: [PATCH 3/3] fix test --- test/test_distributions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index dabf02095..2f7170d42 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1102,15 +1102,15 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)): @pytest.mark.parametrize( - "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL + "jax_dist_cls, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL ) @pytest.mark.parametrize("prepend_shape", [(), (2,), (2, 3)]) -def test_dist_shape(jax_dist, sp_dist, params, prepend_shape): - jax_dist = jax_dist(*params) +def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape): + jax_dist = jax_dist_cls(*params) rng_key = random.PRNGKey(0) expected_shape = prepend_shape + jax_dist.batch_shape + jax_dist.event_shape samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape) - if jax_dist is not dist.Delta: + if jax_dist_cls is not dist.Delta: assert isinstance(samples, jnp.ndarray) assert jnp.shape(samples) == expected_shape if (