Skip to content

Commit

Permalink
Fix DiscreteUniform.enumerate_support with non-trivial batch shape (#…
Browse files Browse the repository at this point in the history
…1859)

* fix DiscreteUniform enumerate_support

* make sure that low and high are concrete values
  • Loading branch information
fehiepsi authored Sep 16, 2024
1 parent a7a2f31 commit ca8fb39
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
6 changes: 3 additions & 3 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,9 +469,9 @@ def enumerate_support(self, expand=True):
raise NotImplementedError(
"Inhomogeneous `high` not supported by `enumerate_support`."
)
values = (self.low + jnp.arange(np.amax(self.high - self.low) + 1)).reshape(
(-1,) + (1,) * len(self.batch_shape)
)
low = np.reshape(self.low, -1)[0]
high = np.reshape(self.high, -1)[0]
values = jnp.arange(low, high + 1).reshape((-1,) + (1,) * len(self.batch_shape))
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values
Expand Down
19 changes: 10 additions & 9 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2742,13 +2742,14 @@ def test_generated_sample_distribution(
@pytest.mark.parametrize(
"jax_dist, params, support",
[
(dist.BernoulliLogits, (5.0,), jnp.arange(2)),
(dist.BernoulliProbs, (0.5,), jnp.arange(2)),
(dist.BinomialLogits, (4.5, 10), jnp.arange(11)),
(dist.BinomialProbs, (0.5, 11), jnp.arange(12)),
(dist.BetaBinomial, (2.0, 0.5, 12), jnp.arange(13)),
(dist.CategoricalLogits, (np.array([3.0, 4.0, 5.0]),), jnp.arange(3)),
(dist.CategoricalProbs, (np.array([0.1, 0.5, 0.4]),), jnp.arange(3)),
(dist.BernoulliLogits, (5.0,), np.arange(2)),
(dist.BernoulliProbs, (0.5,), np.arange(2)),
(dist.BinomialLogits, (4.5, 10), np.arange(11)),
(dist.BinomialProbs, (0.5, 11), np.arange(12)),
(dist.BetaBinomial, (2.0, 0.5, 12), np.arange(13)),
(dist.CategoricalLogits, (np.array([3.0, 4.0, 5.0]),), np.arange(3)),
(dist.CategoricalProbs, (np.array([0.1, 0.5, 0.4]),), np.arange(3)),
(dist.DiscreteUniform, (2, 4), np.arange(2, 5)),
],
)
@pytest.mark.parametrize("batch_shape", [(5,), ()])
Expand Down Expand Up @@ -3333,8 +3334,8 @@ def test_normal_log_cdf():
"value",
[
-15.0,
jnp.array([[-15.0], [-10.0], [-5.0]]),
jnp.array([[[-15.0], [-10.0], [-5.0]], [[-14.0], [-9.0], [-4.0]]]),
np.array([[-15.0], [-10.0], [-5.0]]),
np.array([[[-15.0], [-10.0], [-5.0]], [[-14.0], [-9.0], [-4.0]]]),
],
)
def test_truncated_normal_log_prob_in_tail(value):
Expand Down

0 comments on commit ca8fb39

Please sign in to comment.