Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add entropy implementations. #1787

Merged
merged 2 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from jax.scipy.linalg import cho_solve, solve_triangular
from jax.scipy.special import (
betaln,
digamma,
expi,
expit,
gammainc,
Expand Down Expand Up @@ -198,6 +199,15 @@ def cdf(self, value):
def icdf(self, q):
return betaincinv(self.concentration1, self.concentration0, q)

def entropy(self):
total = self.concentration0 + self.concentration1
return (
betaln(self.concentration0, self.concentration1)
- (self.concentration0 - 1) * digamma(self.concentration0)
- (self.concentration1 - 1) * digamma(self.concentration1)
+ (total - 2) * digamma(total)
)


class Cauchy(Distribution):
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
Expand Down Expand Up @@ -239,6 +249,9 @@ def cdf(self, value):
def icdf(self, q):
return self.loc + self.scale * jnp.tan(jnp.pi * (q - 0.5))

def entropy(self):
return jnp.broadcast_to(jnp.log(4 * np.pi * self.scale), self.batch_shape)


class Dirichlet(Distribution):
arg_constraints = {
Expand Down Expand Up @@ -293,6 +306,16 @@ def infer_shapes(concentration):
event_shape = concentration[-1:]
return batch_shape, event_shape

def entropy(self):
(n,) = self.event_shape
total = self.concentration.sum(axis=-1)
return (
gammaln(self.concentration).sum(axis=-1)
- gammaln(total)
+ (total - n) * digamma(total)
- ((self.concentration - 1) * digamma(self.concentration)).sum(axis=-1)
)


class EulerMaruyama(Distribution):
"""
Expand Down Expand Up @@ -458,6 +481,9 @@ def cdf(self, value):
def icdf(self, q):
return -jnp.log1p(-q) / self.rate

def entropy(self):
return 1 - jnp.log(self.rate)


class Gamma(Distribution):
arg_constraints = {
Expand Down Expand Up @@ -504,6 +530,14 @@ def cdf(self, x):
def icdf(self, q):
return gammaincinv(self.concentration, q) / self.rate

def entropy(self):
return (
self.concentration
- jnp.log(self.rate)
+ gammaln(self.concentration)
+ (1 - self.concentration) * digamma(self.concentration)
)


class Chi2(Gamma):
arg_constraints = {"df": constraints.positive}
Expand Down Expand Up @@ -861,6 +895,9 @@ def icdf(self, q):
a = q - 0.5
return self.loc - self.scale * jnp.sign(a) * jnp.log1p(-2 * jnp.abs(a))

def entropy(self):
return jnp.log(2 * self.scale) + 1


class LKJ(TransformedDistribution):
r"""
Expand Down Expand Up @@ -1161,6 +1198,9 @@ def variance(self):
def cdf(self, x):
return self.base_dist.cdf(jnp.log(x))

def entropy(self):
return (1 + jnp.log(2 * jnp.pi)) / 2 + self.loc + jnp.log(self.scale)


class Logistic(Distribution):
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
Expand Down Expand Up @@ -1201,6 +1241,9 @@ def cdf(self, value):
def icdf(self, q):
return self.loc + self.scale * logit(q)

def entropy(self):
return jnp.broadcast_to(jnp.log(self.scale) + 2, self.batch_shape)


class LogUniform(TransformedDistribution):
arg_constraints = {"low": constraints.positive, "high": constraints.positive}
Expand Down Expand Up @@ -1233,6 +1276,11 @@ def variance(self):
def cdf(self, x):
return self.base_dist.cdf(jnp.log(x))

def entropy(self):
log_low = jnp.log(self.low)
log_high = jnp.log(self.high)
return (log_low + log_high) / 2 + jnp.log(log_high - log_low)


def _batch_solve_triangular(A, B):
"""
Expand Down Expand Up @@ -1521,6 +1569,13 @@ def infer_shapes(
event_shape = lax.broadcast_shapes(event_shape, matrix[-1:])
return batch_shape, event_shape

def entropy(self):
(n,) = self.event_shape
half_log_det = jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(
-1
)
return n * (jnp.log(2 * np.pi) + 1) / 2 + half_log_det


def _is_sparse(A):
from scipy import sparse
Expand Down Expand Up @@ -2062,6 +2117,11 @@ def mean(self):
def variance(self):
return jnp.broadcast_to(self.scale**2, self.batch_shape)

def entropy(self):
return jnp.broadcast_to(
(jnp.log(2 * np.pi * self.scale**2) + 1) / 2, self.batch_shape
)


class Pareto(TransformedDistribution):
arg_constraints = {"scale": constraints.positive, "alpha": constraints.positive}
Expand Down Expand Up @@ -2103,6 +2163,9 @@ def cdf(self, value):
def icdf(self, q):
return self.scale / jnp.power(1 - q, 1 / self.alpha)

def entropy(self):
return jnp.log(self.scale / self.alpha) + 1 + 1 / self.alpha


class RelaxedBernoulliLogits(TransformedDistribution):
arg_constraints = {"temperature": constraints.positive, "logits": constraints.real}
Expand Down Expand Up @@ -2257,6 +2320,15 @@ def icdf(self, q):
scaled = jnp.sign(q - 0.5) * jnp.sqrt(scaled_squared)
return scaled * self.scale + self.loc

def entropy(self):
return jnp.broadcast_to(
(self.df + 1) / 2 * (digamma((self.df + 1) / 2) - digamma(self.df / 2))
+ jnp.log(self.df) / 2
+ betaln(self.df / 2, 0.5)
+ jnp.log(self.scale),
self.batch_shape,
)


class Uniform(Distribution):
arg_constraints = {"low": constraints.dependent, "high": constraints.dependent}
Expand Down Expand Up @@ -2303,6 +2375,9 @@ def infer_shapes(low=(), high=()):
event_shape = ()
return batch_shape, event_shape

def entropy(self):
return jnp.log(self.high - self.low)


class Weibull(Distribution):
arg_constraints = {
Expand Down Expand Up @@ -2348,6 +2423,13 @@ def variance(self):
- jnp.exp(gammaln(1.0 + 1.0 / self.concentration)) ** 2
)

def entropy(self):
return (
jnp.euler_gamma * (1 - 1 / self.concentration)
+ jnp.log(self.scale / self.concentration)
+ 1
)


class BetaProportion(Beta):
"""
Expand Down
31 changes: 31 additions & 0 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def enumerate_support(self, expand=True):
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values

def entropy(self):
return -self.probs * jnp.log(self.probs) - (1 - self.probs) * jnp.log1p(
-self.probs
)


class BernoulliLogits(Distribution):
arg_constraints = {"logits": constraints.real}
Expand Down Expand Up @@ -149,6 +154,10 @@ def enumerate_support(self, expand=True):
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values

def entropy(self):
nexp = jnp.exp(-self.logits)
return ((1 + nexp) * jnp.log1p(nexp) + nexp * self.logits) / (1 + nexp)


def Bernoulli(probs=None, logits=None, *, validate_args=None):
if probs is not None:
Expand Down Expand Up @@ -341,6 +350,9 @@ def enumerate_support(self, expand=True):
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values

def entropy(self):
return -(self.probs * jnp.log(self.probs)).sum(axis=-1)


class CategoricalLogits(Distribution):
arg_constraints = {"logits": constraints.real_vector}
Expand Down Expand Up @@ -393,6 +405,10 @@ def enumerate_support(self, expand=True):
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values

def entropy(self):
probs = softmax(self.logits, axis=-1)
return -(probs * self.logits).sum(axis=-1) + logsumexp(self.logits, axis=-1)


def Categorical(probs=None, logits=None, *, validate_args=None):
if probs is not None:
Expand Down Expand Up @@ -462,6 +478,9 @@ def enumerate_support(self, expand=True):
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values

def entropy(self):
return jnp.log(self.high - self.low + 1)


class OrderedLogistic(CategoricalProbs):
"""
Expand Down Expand Up @@ -498,6 +517,9 @@ def infer_shapes(predictor, cutpoints):
event_shape = ()
return batch_shape, event_shape

def entropy(self):
raise NotImplementedError


class MultinomialProbs(Distribution):
arg_constraints = {
Expand Down Expand Up @@ -879,6 +901,11 @@ def mean(self):
def variance(self):
return (1.0 / self.probs - 1.0) / self.probs

def entropy(self):
return -(1 - self.probs) * jnp.log1p(-self.probs) / self.probs - jnp.log(
self.probs
)


class GeometricLogits(Distribution):
arg_constraints = {"logits": constraints.real}
Expand Down Expand Up @@ -914,6 +941,10 @@ def mean(self):
def variance(self):
return (1.0 / self.probs - 1.0) / self.probs

def entropy(self):
nexp = jnp.exp(-self.logits)
return nexp * self.logits + jnp.log1p(nexp) * (1 + nexp)


def Geometric(probs=None, logits=None, *, validate_args=None):
if probs is not None:
Expand Down
6 changes: 6 additions & 0 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,12 @@ def enumerate_support(self, expand=True):
"""
raise NotImplementedError

def entropy(self):
"""
Returns the entropy of the distribution.
"""
raise NotImplementedError

def expand(self, batch_shape):
"""
Returns a new :class:`ExpandedDistribution` instance with batch
Expand Down
31 changes: 31 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def __init__(
dist.Cauchy: lambda loc, scale: osp.cauchy(loc=loc, scale=scale),
dist.Chi2: lambda df: osp.chi2(df),
dist.Dirichlet: lambda conc: osp.dirichlet(conc),
dist.DiscreteUniform: lambda low, high: osp.randint(low, high + 1),
dist.Exponential: lambda rate: osp.expon(scale=jnp.reciprocal(rate)),
dist.Gamma: lambda conc, rate: osp.gamma(conc, scale=1.0 / rate),
dist.GeometricProbs: lambda probs: osp.geom(p=probs, loc=-1),
Expand Down Expand Up @@ -1390,6 +1391,36 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)


@pytest.mark.parametrize(
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_entropy(jax_dist, sp_dist, params):
jax_dist = jax_dist(*params)

if _is_batched_multivariate(jax_dist):
pytest.skip("batching not allowed in multivariate distns.")
if sp_dist is None:
pytest.skip(reason="no corresponding scipy distribution")
try:
actual = jax_dist.entropy()
except NotImplementedError:
pytest.skip(reason="distribution does not implement `entropy`")

sp_dist = sp_dist(*params)
expected = sp_dist.entropy()
assert_allclose(actual, expected, atol=1e-5)


def test_entropy_categorical():
# There is no scipy mapping for categorical distributions, but the multinomial with
# one trial has the same entropy--which we check here.
logits = jax.random.normal(jax.random.key(9), (7,))
probs = _to_probs_multinom(logits)
sp_dist = osp.multinomial(1, probs)
for jax_dist in [dist.CategoricalLogits(logits), dist.CategoricalProbs(probs)]:
assert_allclose(jax_dist.entropy(), sp_dist.entropy())


def test_mixture_log_prob():
gmm = dist.MixtureSameFamily(
dist.Categorical(logits=np.zeros(2)), dist.Normal(0, 1).expand([2])
Expand Down
Loading