Skip to content

Commit

Permalink
Add further entropy implementations. (#1800)
Browse files Browse the repository at this point in the history
* Add entropy for `Independent`, `Expanded`, and `Delta` distributions.

* Fix `__repr__` for `SineSkewed`.

* Add entropy for `InverseGamma` distribution.
  • Loading branch information
tillahoffmann authored May 14, 2024
1 parent b373831 commit f572f2b
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 5 deletions.
8 changes: 8 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,14 @@ def variance(self):
def cdf(self, x):
return 1 - self.base_dist.cdf(1 / x)

def entropy(self):
return (
self.concentration
+ jnp.log(self.rate)
+ gammaln(self.concentration)
- (1 + self.concentration) * digamma(self.concentration)
)


class Gompertz(Distribution):
r"""Gompertz Distribution.
Expand Down
4 changes: 2 additions & 2 deletions numpyro/distributions/directional.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ def __repr__(self):
"{}: {}".format(
p,
getattr(self, p)
if getattr(self, p).numel() == 1
else getattr(self, p).size(),
if getattr(self, p).size == 1
else getattr(self, p).size,
)
for p in self.arg_constraints.keys()
]
Expand Down
10 changes: 10 additions & 0 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,9 @@ def variance(self):
self.base_dist.variance, self.batch_shape + self.event_shape
)

def entropy(self):
return jnp.broadcast_to(self.base_dist.entropy(), self.batch_shape)


class ImproperUniform(Distribution):
"""
Expand Down Expand Up @@ -851,6 +854,10 @@ def expand(self, batch_shape):
self.reinterpreted_batch_ndims
)

def entropy(self):
axes = range(-self.reinterpreted_batch_ndims, 0)
return self.base_dist.entropy().sum(axes)


class MaskedDistribution(Distribution):
"""
Expand Down Expand Up @@ -1168,6 +1175,9 @@ def mean(self):
def variance(self):
return jnp.zeros(self.batch_shape + self.event_shape)

def entropy(self):
return -jnp.broadcast_to(self.log_density, self.batch_shape)


class Unit(Distribution):
"""
Expand Down
15 changes: 12 additions & 3 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,15 @@ def get_sp_dist(jax_dist):
),
]

BASE = [
T(lambda *args: dist.Normal(*args).to_event(2), np.arange(24).reshape(3, 4, 2)),
T(lambda *args: dist.Normal(*args).expand((3, 4, 7)), np.arange(7)),
T(
lambda *args: dist.Normal(*args).to_event(2).expand((7, 3)),
np.arange(24).reshape(3, 4, 2),
),
]


def _is_batched_multivariate(jax_dist):
return len(jax_dist.event_shape) > 0 and len(jax_dist.batch_shape) > 0
Expand Down Expand Up @@ -1494,7 +1503,7 @@ def test_entropy_scipy(jax_dist, sp_dist, params):
try:
actual = jax_dist.entropy()
except NotImplementedError:
pytest.skip(reason="distribution does not implement `entropy`")
pytest.skip(reason=f"distribution {jax_dist} does not implement `entropy`")
if _is_batched_multivariate(jax_dist):
pytest.skip("batching not allowed in multivariate distns.")
if sp_dist is None:
Expand All @@ -1506,15 +1515,15 @@ def test_entropy_scipy(jax_dist, sp_dist, params):


@pytest.mark.parametrize(
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL + BASE
)
def test_entropy_samples(jax_dist, sp_dist, params):
jax_dist = jax_dist(*params)

try:
actual = jax_dist.entropy()
except NotImplementedError:
pytest.skip(reason="distribution does not implement `entropy`")
pytest.skip(reason=f"distribution {jax_dist} does not implement `entropy`")

samples = jax_dist.sample(jax.random.key(8), (1000,))
neg_log_probs = -jax_dist.log_prob(samples)
Expand Down

0 comments on commit f572f2b

Please sign in to comment.