Skip to content

Commit

Permalink
Add entropy implementations.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Apr 25, 2024
1 parent 2f1bccd commit d7cc8ff
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 0 deletions.
84 changes: 84 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from jax.scipy.linalg import cho_solve, solve_triangular
from jax.scipy.special import (
betaln,
digamma,
expi,
expit,
gammainc,
Expand Down Expand Up @@ -198,6 +199,17 @@ def cdf(self, value):
def icdf(self, q):
    """Inverse CDF: evaluate ``betaincinv`` at ``q`` for the two concentrations."""
    first, second = self.concentration1, self.concentration0
    return betaincinv(first, second, q)

def entropy(self):
    """
    Closed-form entropy built from ``gammaln`` and ``digamma`` of the two
    concentration parameters and of their sum.
    """
    conc0 = self.concentration0
    conc1 = self.concentration1
    conc_sum = conc0 + conc1
    # Log of the normalising constant: gammaln(a) + gammaln(b) - gammaln(a + b).
    log_normalizer = gammaln(conc0) + gammaln(conc1) - gammaln(conc_sum)
    return (
        log_normalizer
        - (conc0 - 1) * digamma(conc0)
        - (conc1 - 1) * digamma(conc1)
        + (conc_sum - 2) * digamma(conc_sum)
    )


class Cauchy(Distribution):
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
Expand Down Expand Up @@ -239,6 +251,9 @@ def cdf(self, value):
def icdf(self, q):
    """Inverse CDF: ``loc + scale * tan(pi * (q - 0.5))``."""
    angle = jnp.pi * (q - 0.5)
    return self.loc + self.scale * jnp.tan(angle)

def entropy(self):
    """Entropy ``log(4 * pi * scale)``, broadcast to ``batch_shape``."""
    per_scale = jnp.log(4 * np.pi * self.scale)
    return jnp.broadcast_to(per_scale, self.batch_shape)


class Dirichlet(Distribution):
arg_constraints = {
Expand Down Expand Up @@ -293,6 +308,16 @@ def infer_shapes(concentration):
event_shape = concentration[-1:]
return batch_shape, event_shape

def entropy(self):
    """
    Closed-form entropy from ``gammaln`` / ``digamma`` of the concentration
    vector, reduced over the last axis.
    """
    (dim,) = self.event_shape
    conc = self.concentration
    conc_sum = conc.sum(axis=-1)
    log_normalizer = gammaln(conc).sum(axis=-1) - gammaln(conc_sum)
    weighted_digamma = ((conc - 1) * digamma(conc)).sum(axis=-1)
    return log_normalizer + (conc_sum - dim) * digamma(conc_sum) - weighted_digamma


class EulerMaruyama(Distribution):
"""
Expand Down Expand Up @@ -458,6 +483,9 @@ def cdf(self, value):
def icdf(self, q):
    """Inverse CDF: ``-log1p(-q) / rate``."""
    neg_log_survival = -jnp.log1p(-q)
    return neg_log_survival / self.rate

def entropy(self):
    """Entropy ``1 - log(rate)``."""
    log_rate = jnp.log(self.rate)
    return 1 - log_rate


class Gamma(Distribution):
arg_constraints = {
Expand Down Expand Up @@ -504,6 +532,14 @@ def cdf(self, x):
def icdf(self, q):
    """Inverse CDF: ``gammaincinv(concentration, q)`` divided by ``rate``."""
    unit_rate_quantile = gammaincinv(self.concentration, q)
    return unit_rate_quantile / self.rate

def entropy(self):
    """
    Closed-form entropy:
    ``concentration - log(rate) + gammaln(concentration)
    + (1 - concentration) * digamma(concentration)``.
    """
    conc = self.concentration
    shifted = conc - jnp.log(self.rate)
    return shifted + gammaln(conc) + (1 - conc) * digamma(conc)


class Chi2(Gamma):
arg_constraints = {"df": constraints.positive}
Expand Down Expand Up @@ -861,6 +897,9 @@ def icdf(self, q):
a = q - 0.5
return self.loc - self.scale * jnp.sign(a) * jnp.log1p(-2 * jnp.abs(a))

def entropy(self):
    """
    Entropy ``log(2 * scale) + 1``, broadcast to ``batch_shape``.

    :return: array of shape ``self.batch_shape``.
    """
    # The value depends on `scale` only. Broadcasting keeps the result at the
    # full batch shape even when `loc` carries batch dimensions that `scale`
    # does not, in line with the other loc/scale entropies in this file.
    return jnp.broadcast_to(jnp.log(2 * self.scale) + 1, self.batch_shape)


class LKJ(TransformedDistribution):
r"""
Expand Down Expand Up @@ -1161,6 +1200,9 @@ def variance(self):
def cdf(self, x):
    """CDF at ``x``: the base distribution's CDF evaluated at ``log(x)``."""
    log_x = jnp.log(x)
    return self.base_dist.cdf(log_x)

def entropy(self):
    """Entropy ``(1 + log(2 * pi)) / 2 + loc + log(scale)``."""
    unit_term = (1 + jnp.log(2 * jnp.pi)) / 2
    return unit_term + self.loc + jnp.log(self.scale)


class Logistic(Distribution):
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
Expand Down Expand Up @@ -1201,6 +1243,9 @@ def cdf(self, value):
def icdf(self, q):
    """Inverse CDF: ``loc + scale * logit(q)``."""
    log_odds = logit(q)
    return self.loc + self.scale * log_odds

def entropy(self):
    """Entropy ``log(scale) + 2``, broadcast to ``batch_shape``."""
    per_scale = jnp.log(self.scale) + 2
    return jnp.broadcast_to(per_scale, self.batch_shape)


class LogUniform(TransformedDistribution):
arg_constraints = {"low": constraints.positive, "high": constraints.positive}
Expand Down Expand Up @@ -1233,6 +1278,11 @@ def variance(self):
def cdf(self, x):
    """CDF at ``x``: the base distribution's CDF evaluated at ``log(x)``."""
    log_x = jnp.log(x)
    return self.base_dist.cdf(log_x)

def entropy(self):
    """
    Entropy: the midpoint of ``log(low)`` and ``log(high)`` plus the log of
    their difference.
    """
    lower, upper = jnp.log(self.low), jnp.log(self.high)
    midpoint = (lower + upper) / 2
    return midpoint + jnp.log(upper - lower)


def _batch_solve_triangular(A, B):
"""
Expand Down Expand Up @@ -1521,6 +1571,13 @@ def infer_shapes(
event_shape = lax.broadcast_shapes(event_shape, matrix[-1:])
return batch_shape, event_shape

def entropy(self):
    """
    Entropy ``n * (log(2 * pi) + 1) / 2`` plus the summed log-diagonal of
    ``scale_tril``, where ``n`` is the single event dimension.

    :return: array of shape ``self.batch_shape``.
    """
    (n,) = self.event_shape
    half_log_det = jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(
        -1
    )
    # `half_log_det` only has the batch dimensions of `scale_tril`. Broadcast
    # so the result has the full batch shape even when another parameter
    # contributes batch dimensions that `scale_tril` does not have.
    return jnp.broadcast_to(
        n * (jnp.log(2 * np.pi) + 1) / 2 + half_log_det, self.batch_shape
    )


def _is_sparse(A):
from scipy import sparse
Expand Down Expand Up @@ -2062,6 +2119,11 @@ def mean(self):
def variance(self):
    """Variance ``scale**2``, broadcast to ``batch_shape``."""
    squared_scale = self.scale**2
    return jnp.broadcast_to(squared_scale, self.batch_shape)

def entropy(self):
    """Entropy ``(log(2 * pi * scale**2) + 1) / 2``, broadcast to ``batch_shape``."""
    per_scale = (jnp.log(2 * np.pi * self.scale**2) + 1) / 2
    return jnp.broadcast_to(per_scale, self.batch_shape)


class Pareto(TransformedDistribution):
arg_constraints = {"scale": constraints.positive, "alpha": constraints.positive}
Expand Down Expand Up @@ -2103,6 +2165,9 @@ def cdf(self, value):
def icdf(self, q):
    """Inverse CDF: ``scale / (1 - q) ** (1 / alpha)``."""
    tail_power = jnp.power(1 - q, 1 / self.alpha)
    return self.scale / tail_power

def entropy(self):
    """Entropy ``log(scale / alpha) + 1 + 1 / alpha``."""
    log_ratio = jnp.log(self.scale / self.alpha)
    return log_ratio + 1 + 1 / self.alpha


class RelaxedBernoulliLogits(TransformedDistribution):
arg_constraints = {"temperature": constraints.positive, "logits": constraints.real}
Expand Down Expand Up @@ -2257,6 +2322,15 @@ def icdf(self, q):
scaled = jnp.sign(q - 0.5) * jnp.sqrt(scaled_squared)
return scaled * self.scale + self.loc

def entropy(self):
    """
    Closed-form entropy from ``digamma`` and ``betaln`` of ``df`` plus
    ``log(scale)``, broadcast to ``batch_shape``.
    """
    df = self.df
    half_df_plus_one = (df + 1) / 2
    digamma_gap = digamma((df + 1) / 2) - digamma(df / 2)
    value = (
        half_df_plus_one * digamma_gap
        + jnp.log(df) / 2
        + betaln(df / 2, 0.5)
        + jnp.log(self.scale)
    )
    return jnp.broadcast_to(value, self.batch_shape)


class Uniform(Distribution):
arg_constraints = {"low": constraints.dependent, "high": constraints.dependent}
Expand Down Expand Up @@ -2303,6 +2377,9 @@ def infer_shapes(low=(), high=()):
event_shape = ()
return batch_shape, event_shape

def entropy(self):
    """Entropy ``log(high - low)``."""
    width = self.high - self.low
    return jnp.log(width)


class Weibull(Distribution):
arg_constraints = {
Expand Down Expand Up @@ -2348,6 +2425,13 @@ def variance(self):
- jnp.exp(gammaln(1.0 + 1.0 / self.concentration)) ** 2
)

def entropy(self):
    """
    Entropy ``euler_gamma * (1 - 1 / concentration)
    + log(scale / concentration) + 1``.
    """
    conc = self.concentration
    euler_term = jnp.euler_gamma * (1 - 1 / conc)
    return euler_term + jnp.log(self.scale / conc) + 1


class BetaProportion(Beta):
"""
Expand Down
31 changes: 31 additions & 0 deletions numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ def enumerate_support(self, expand=True):
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values

def entropy(self):
    """
    Entropy ``-p * log(p) - (1 - p) * log(1 - p)`` with ``p = self.probs``.

    :return: entropy with the shape of ``probs``; ``0`` at ``p == 0`` and
        ``p == 1``.
    """
    from jax.scipy.special import xlog1py, xlogy

    # xlogy / xlog1py return 0 when the multiplier is 0, so the boundary
    # values give 0 instead of the NaN from `0 * log(0)`.
    return -xlogy(self.probs, self.probs) - xlog1py(1 - self.probs, -self.probs)


class BernoulliLogits(Distribution):
arg_constraints = {"logits": constraints.real}
Expand Down Expand Up @@ -149,6 +154,10 @@ def enumerate_support(self, expand=True):
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values

def entropy(self):
    """
    Entropy as a function of the logit ``l``:
    ``log1p(exp(-l)) + l * exp(-l) / (1 + exp(-l))``.

    :return: entropy with the shape of ``logits``.
    """
    # The expression is unchanged under l -> -l, so evaluate it at |l|. That
    # keeps the argument of `exp` non-positive; the direct form overflows to
    # inf / inf = NaN for very negative logits.
    abs_logits = jnp.abs(self.logits)
    nexp = jnp.exp(-abs_logits)
    return jnp.log1p(nexp) + abs_logits * nexp / (1 + nexp)


def Bernoulli(probs=None, logits=None, *, validate_args=None):
if probs is not None:
Expand Down Expand Up @@ -341,6 +350,9 @@ def enumerate_support(self, expand=True):
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values

def entropy(self):
    """
    Entropy ``-sum(p * log(p))`` over the last axis of ``self.probs``.

    :return: entropy with the leading (all but last) dimensions of ``probs``.
    """
    from jax.scipy.special import xlogy

    # xlogy(0, 0) == 0, so categories with zero probability contribute
    # nothing instead of turning the whole sum into NaN.
    return -xlogy(self.probs, self.probs).sum(axis=-1)


class CategoricalLogits(Distribution):
arg_constraints = {"logits": constraints.real_vector}
Expand Down Expand Up @@ -393,6 +405,10 @@ def enumerate_support(self, expand=True):
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values

def entropy(self):
    """
    Entropy from logits: ``logsumexp(logits)`` minus the softmax-weighted sum
    of the logits, both over the last axis.
    """
    logits = self.logits
    weights = softmax(logits, axis=-1)
    expected_logit = (weights * logits).sum(axis=-1)
    return -expected_logit + logsumexp(logits, axis=-1)


def Categorical(probs=None, logits=None, *, validate_args=None):
if probs is not None:
Expand Down Expand Up @@ -462,6 +478,9 @@ def enumerate_support(self, expand=True):
values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape)
return values

def entropy(self):
    """Entropy ``log(high - low + 1)``."""
    num_values = self.high - self.low + 1
    return jnp.log(num_values)


class OrderedLogistic(CategoricalProbs):
"""
Expand Down Expand Up @@ -498,6 +517,9 @@ def infer_shapes(predictor, cutpoints):
event_shape = ()
return batch_shape, event_shape

def entropy(self):
    # Entropy is explicitly unsupported here: calling this always raises
    # NotImplementedError.
    # NOTE(review): presumably this overrides an `entropy` inherited from the
    # parent class so that it is not used for this distribution — confirm.
    raise NotImplementedError


class MultinomialProbs(Distribution):
arg_constraints = {
Expand Down Expand Up @@ -879,6 +901,11 @@ def mean(self):
def variance(self):
    """Variance ``(1 / probs - 1) / probs``."""
    inverse_minus_one = 1.0 / self.probs - 1.0
    return inverse_minus_one / self.probs

def entropy(self):
    """
    Entropy ``-(1 - p) * log(1 - p) / p - log(p)`` with ``p = self.probs``.

    :return: entropy with the shape of ``probs``; ``0`` at ``p == 1``.
    """
    from jax.scipy.special import xlog1py

    # xlog1py(0, -1) == 0, so p == 1 gives 0 instead of the NaN from
    # `0 * log(0)`.
    return -xlog1py(1 - self.probs, -self.probs) / self.probs - jnp.log(self.probs)


class GeometricLogits(Distribution):
arg_constraints = {"logits": constraints.real}
Expand Down Expand Up @@ -914,6 +941,10 @@ def mean(self):
def variance(self):
    """Variance ``(1 / probs - 1) / probs``."""
    inverse_minus_one = 1.0 / self.probs - 1.0
    return inverse_minus_one / self.probs

def entropy(self):
    """
    Entropy from logits: with ``e = exp(-logits)``,
    ``e * logits + log1p(e) * (1 + e)``.
    """
    logits = self.logits
    neg_exp = jnp.exp(-logits)
    linear_term = neg_exp * logits
    return linear_term + jnp.log1p(neg_exp) * (1 + neg_exp)


def Geometric(probs=None, logits=None, *, validate_args=None):
if probs is not None:
Expand Down
6 changes: 6 additions & 0 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,12 @@ def enumerate_support(self, expand=True):
"""
raise NotImplementedError

def entropy(self):
    """
    Returns the entropy of the distribution.

    :raises NotImplementedError: always, in this implementation; it contains
        no entropy computation of its own.
    """
    raise NotImplementedError

def expand(self, batch_shape):
"""
Returns a new :class:`ExpandedDistribution` instance with batch
Expand Down
31 changes: 31 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def __init__(
dist.Cauchy: lambda loc, scale: osp.cauchy(loc=loc, scale=scale),
dist.Chi2: lambda df: osp.chi2(df),
dist.Dirichlet: lambda conc: osp.dirichlet(conc),
dist.DiscreteUniform: lambda low, high: osp.randint(low, high + 1),
dist.Exponential: lambda rate: osp.expon(scale=jnp.reciprocal(rate)),
dist.Gamma: lambda conc, rate: osp.gamma(conc, scale=1.0 / rate),
dist.GeometricProbs: lambda probs: osp.geom(p=probs, loc=-1),
Expand Down Expand Up @@ -1390,6 +1391,36 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5)


@pytest.mark.parametrize(
    "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_entropy(jax_dist, sp_dist, params):
    """Compare ``entropy()`` with the scipy counterpart, skipping unsupported cases."""
    distribution = jax_dist(*params)

    if _is_batched_multivariate(distribution):
        pytest.skip("batching not allowed in multivariate distns.")
    if sp_dist is None:
        pytest.skip(reason="no corresponding scipy distribution")

    try:
        actual = distribution.entropy()
    except NotImplementedError:
        pytest.skip(reason="distribution does not implement `entropy`")

    # Build the scipy reference only after all skip conditions have passed.
    reference = sp_dist(*params)
    assert_allclose(actual, reference.entropy(), atol=1e-5)


def test_entropy_categorical():
    """Check both categorical parameterisations against a one-trial multinomial."""
    # There is no scipy mapping for categorical distributions, but the multinomial with
    # one trial has the same entropy--which we check here.
    logits = jax.random.normal(jax.random.key(9), (7,))
    probs = _to_probs_multinom(logits)
    expected = osp.multinomial(1, probs).entropy()
    candidates = (dist.CategoricalLogits(logits), dist.CategoricalProbs(probs))
    for candidate in candidates:
        assert_allclose(candidate.entropy(), expected)


def test_mixture_log_prob():
gmm = dist.MixtureSameFamily(
dist.Categorical(logits=np.zeros(2)), dist.Normal(0, 1).expand([2])
Expand Down

0 comments on commit d7cc8ff

Please sign in to comment.