Skip to content

Commit

Permalink
make sure that low and high are concrete values
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Sep 5, 2024
1 parent 016ba85 commit e2c3f13
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,8 @@ def enumerate_support(self, expand=True):
raise NotImplementedError(
"Inhomogeneous `high` not supported by `enumerate_support`."
)
low = jnp.reshape(self.low, -1)[0]
high = jnp.reshape(self.high, -1)[0]
low = np.reshape(self.low, -1)[0]
high = np.reshape(self.high, -1)[0]
values = jnp.arange(low, high + 1).reshape((-1,) + (1,) * len(self.batch_shape))
if expand:
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
Expand Down

0 comments on commit e2c3f13

Please sign in to comment.