From e2c3f136072adaafb6ffd9812fa16e0834e2703e Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 5 Sep 2024 10:44:55 -0400 Subject: [PATCH] make sure that low and high are concrete values --- numpyro/distributions/discrete.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index d2a3b5a0c..a8576a063 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -469,8 +469,8 @@ def enumerate_support(self, expand=True): raise NotImplementedError( "Inhomogeneous `high` not supported by `enumerate_support`." ) - low = jnp.reshape(self.low, -1)[0] - high = jnp.reshape(self.high, -1)[0] + low = np.reshape(self.low, -1)[0] + high = np.reshape(self.high, -1)[0] values = jnp.arange(low, high + 1).reshape((-1,) + (1,) * len(self.batch_shape)) if expand: values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)