Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to use Delta on numpy arrays without moving them to jax devices #1777

Merged
merged 3 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,8 @@ def support(self):
return constraints.independent(constraints.real, self.event_dim)

def sample(self, key, sample_shape=()):
if not sample_shape:
return self.v
shape = sample_shape + self.batch_shape + self.event_shape
return jnp.broadcast_to(self.v, shape)

Expand Down
11 changes: 6 additions & 5 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,15 +1102,16 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)):


@pytest.mark.parametrize(
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
"jax_dist_cls, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
@pytest.mark.parametrize("prepend_shape", [(), (2,), (2, 3)])
def test_dist_shape(jax_dist, sp_dist, params, prepend_shape):
jax_dist = jax_dist(*params)
def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape):
jax_dist = jax_dist_cls(*params)
rng_key = random.PRNGKey(0)
expected_shape = prepend_shape + jax_dist.batch_shape + jax_dist.event_shape
samples = jax_dist.sample(key=rng_key, sample_shape=prepend_shape)
assert isinstance(samples, jnp.ndarray)
if jax_dist_cls is not dist.Delta:
assert isinstance(samples, jnp.ndarray)
assert jnp.shape(samples) == expected_shape
if (
sp_dist
Expand Down Expand Up @@ -2620,7 +2621,7 @@ def test_expand(jax_dist, sp_dist, params, prepend_shape, sample_shape):
rng_key = random.PRNGKey(0)
samples = expanded_dist.sample(rng_key, sample_shape)
assert expanded_dist.batch_shape == new_batch_shape
assert samples.shape == sample_shape + new_batch_shape + jax_dist.event_shape
assert jnp.shape(samples) == sample_shape + new_batch_shape + jax_dist.event_shape
assert expanded_dist.log_prob(samples).shape == sample_shape + new_batch_shape
# test expand of expand
assert (
Expand Down
Loading