diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 8280482c0..237cb633a 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -703,6 +703,15 @@ def variance(self):
     def cdf(self, x):
         return 1 - self.base_dist.cdf(1 / x)
 
+    def entropy(self):
+        """Closed-form entropy computed from ``concentration`` and ``rate``."""
+        return (
+            self.concentration
+            + jnp.log(self.rate)
+            + gammaln(self.concentration)
+            - (1 + self.concentration) * digamma(self.concentration)
+        )
+
 
 class Gompertz(Distribution):
     r"""Gompertz Distribution.
diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py
index 1add0fcba..fd5b1596c 100644
--- a/numpyro/distributions/directional.py
+++ b/numpyro/distributions/directional.py
@@ -235,8 +235,8 @@ def __repr__(self):
                 "{}: {}".format(
                     p,
                     getattr(self, p)
-                    if getattr(self, p).numel() == 1
-                    else getattr(self, p).size(),
+                    if getattr(self, p).size == 1
+                    else getattr(self, p).shape,
                 )
                 for p in self.arg_constraints.keys()
             ]
diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py
index 6eec9f4d3..cba2994b8 100644
--- a/numpyro/distributions/distribution.py
+++ b/numpyro/distributions/distribution.py
@@ -686,6 +686,10 @@ def variance(self):
             self.base_dist.variance, self.batch_shape + self.event_shape
         )
 
+    def entropy(self):
+        """Entropy of the base distribution broadcast to ``batch_shape``."""
+        return jnp.broadcast_to(self.base_dist.entropy(), self.batch_shape)
+
 
 class ImproperUniform(Distribution):
     """
@@ -851,6 +855,11 @@ def expand(self, batch_shape):
             self.reinterpreted_batch_ndims
         )
 
+    def entropy(self):
+        """Sum the base entropy over the reinterpreted batch dimensions."""
+        axes = range(-self.reinterpreted_batch_ndims, 0)
+        return self.base_dist.entropy().sum(axes)
+
 
 class MaskedDistribution(Distribution):
     """
@@ -1168,6 +1177,10 @@ def mean(self):
     def variance(self):
         return jnp.zeros(self.batch_shape + self.event_shape)
 
+    def entropy(self):
+        """Negated ``log_density`` broadcast to ``batch_shape``."""
+        return -jnp.broadcast_to(self.log_density, self.batch_shape)
+
 
 class Unit(Distribution):
     """
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 3487ab745..a4dc3528b 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -1041,6 +1041,16 @@ def get_sp_dist(jax_dist):
     ),
 ]
 
+# Wrapped Normal cases (`to_event` / `expand`) for the sample-based entropy test.
+BASE = [
+    T(lambda *args: dist.Normal(*args).to_event(2), np.arange(24).reshape(3, 4, 2)),
+    T(lambda *args: dist.Normal(*args).expand((3, 4, 7)), np.arange(7)),
+    T(
+        lambda *args: dist.Normal(*args).to_event(2).expand((7, 3)),
+        np.arange(24).reshape(3, 4, 2),
+    ),
+]
+
 
 def _is_batched_multivariate(jax_dist):
     return len(jax_dist.event_shape) > 0 and len(jax_dist.batch_shape) > 0
@@ -1494,7 +1504,7 @@ def test_entropy_scipy(jax_dist, sp_dist, params):
     try:
         actual = jax_dist.entropy()
     except NotImplementedError:
-        pytest.skip(reason="distribution does not implement `entropy`")
+        pytest.skip(reason=f"distribution {jax_dist} does not implement `entropy`")
     if _is_batched_multivariate(jax_dist):
         pytest.skip("batching not allowed in multivariate distns.")
     if sp_dist is None:
@@ -1506,7 +1516,7 @@ def test_entropy_scipy(jax_dist, sp_dist, params):
 
 
 @pytest.mark.parametrize(
-    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
+    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL + BASE
 )
 def test_entropy_samples(jax_dist, sp_dist, params):
     jax_dist = jax_dist(*params)
@@ -1514,7 +1524,7 @@ def test_entropy_samples(jax_dist, sp_dist, params):
     try:
         actual = jax_dist.entropy()
     except NotImplementedError:
-        pytest.skip(reason="distribution does not implement `entropy`")
+        pytest.skip(reason=f"distribution {jax_dist} does not implement `entropy`")
 
     samples = jax_dist.sample(jax.random.key(8), (1000,))
     neg_log_probs = -jax_dist.log_prob(samples)