From ca8fb39da530c9148d02838134507cf05ffea2bb Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 16 Sep 2024 19:52:37 -0400 Subject: [PATCH] Fix DiscreteUniform.enumerate_support with non-trivial batch shape (#1859) * fix DiscreteUniform enumerate_support * make sure that low and high are concrete values --- numpyro/distributions/discrete.py | 6 +++--- test/test_distributions.py | 19 ++++++++++--------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index 6fdff43a2..a8576a063 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -469,9 +469,9 @@ def enumerate_support(self, expand=True): raise NotImplementedError( "Inhomogeneous `high` not supported by `enumerate_support`." ) - values = (self.low + jnp.arange(np.amax(self.high - self.low) + 1)).reshape( - (-1,) + (1,) * len(self.batch_shape) - ) + low = np.reshape(self.low, -1)[0] + high = np.reshape(self.high, -1)[0] + values = jnp.arange(low, high + 1).reshape((-1,) + (1,) * len(self.batch_shape)) if expand: values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values diff --git a/test/test_distributions.py b/test/test_distributions.py index ac89f5982..f074eae0d 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -2742,13 +2742,14 @@ def test_generated_sample_distribution( @pytest.mark.parametrize( "jax_dist, params, support", [ - (dist.BernoulliLogits, (5.0,), jnp.arange(2)), - (dist.BernoulliProbs, (0.5,), jnp.arange(2)), - (dist.BinomialLogits, (4.5, 10), jnp.arange(11)), - (dist.BinomialProbs, (0.5, 11), jnp.arange(12)), - (dist.BetaBinomial, (2.0, 0.5, 12), jnp.arange(13)), - (dist.CategoricalLogits, (np.array([3.0, 4.0, 5.0]),), jnp.arange(3)), - (dist.CategoricalProbs, (np.array([0.1, 0.5, 0.4]),), jnp.arange(3)), + (dist.BernoulliLogits, (5.0,), np.arange(2)), + (dist.BernoulliProbs, (0.5,), np.arange(2)), + (dist.BinomialLogits, (4.5, 10), np.arange(11)), + (dist.BinomialProbs, (0.5, 11), np.arange(12)), + (dist.BetaBinomial, (2.0, 0.5, 12), np.arange(13)), + (dist.CategoricalLogits, (np.array([3.0, 4.0, 5.0]),), np.arange(3)), + (dist.CategoricalProbs, (np.array([0.1, 0.5, 0.4]),), np.arange(3)), + (dist.DiscreteUniform, (2, 4), np.arange(2, 5)), ], ) @pytest.mark.parametrize("batch_shape", [(5,), ()]) @@ -3333,8 +3334,8 @@ def test_normal_log_cdf(): "value", [ -15.0, - jnp.array([[-15.0], [-10.0], [-5.0]]), - jnp.array([[[-15.0], [-10.0], [-5.0]], [[-14.0], [-9.0], [-4.0]]]), + np.array([[-15.0], [-10.0], [-5.0]]), + np.array([[[-15.0], [-10.0], [-5.0]], [[-14.0], [-9.0], [-4.0]]]), ], ) def test_truncated_normal_log_prob_in_tail(value):