diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 0e895c827..493afa500 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -36,6 +36,7 @@ from jax.scipy.linalg import cho_solve, solve_triangular from jax.scipy.special import ( betaln, + digamma, expi, expit, gammainc, @@ -198,6 +199,17 @@ def cdf(self, value): def icdf(self, q): return betaincinv(self.concentration1, self.concentration0, q) + def entropy(self): + total = self.concentration0 + self.concentration1 + return ( + gammaln(self.concentration0) + + gammaln(self.concentration1) + - gammaln(total) + - (self.concentration0 - 1) * digamma(self.concentration0) + - (self.concentration1 - 1) * digamma(self.concentration1) + + (total - 2) * digamma(total) + ) + class Cauchy(Distribution): arg_constraints = {"loc": constraints.real, "scale": constraints.positive} @@ -239,6 +251,9 @@ def cdf(self, value): def icdf(self, q): return self.loc + self.scale * jnp.tan(jnp.pi * (q - 0.5)) + def entropy(self): + return jnp.broadcast_to(jnp.log(4 * np.pi * self.scale), self.batch_shape) + class Dirichlet(Distribution): arg_constraints = { @@ -293,6 +308,16 @@ def infer_shapes(concentration): event_shape = concentration[-1:] return batch_shape, event_shape + def entropy(self): + (n,) = self.event_shape + total = self.concentration.sum(axis=-1) + return ( + gammaln(self.concentration).sum(axis=-1) + - gammaln(total) + + (total - n) * digamma(total) + - ((self.concentration - 1) * digamma(self.concentration)).sum(axis=-1) + ) + class EulerMaruyama(Distribution): """ @@ -458,6 +483,9 @@ def cdf(self, value): def icdf(self, q): return -jnp.log1p(-q) / self.rate + def entropy(self): + return 1 - jnp.log(self.rate) + class Gamma(Distribution): arg_constraints = { @@ -504,6 +532,14 @@ def cdf(self, x): def icdf(self, q): return gammaincinv(self.concentration, q) / self.rate + def entropy(self): + return ( + self.concentration + - jnp.log(self.rate) + + gammaln(self.concentration) + + (1 - self.concentration) * digamma(self.concentration) + ) + class Chi2(Gamma): arg_constraints = {"df": constraints.positive} @@ -861,6 +897,9 @@ def icdf(self, q): a = q - 0.5 return self.loc - self.scale * jnp.sign(a) * jnp.log1p(-2 * jnp.abs(a)) + def entropy(self): + return jnp.log(2 * self.scale) + 1 + class LKJ(TransformedDistribution): r""" @@ -1161,6 +1200,9 @@ def variance(self): def cdf(self, x): return self.base_dist.cdf(jnp.log(x)) + def entropy(self): + return (1 + jnp.log(2 * jnp.pi)) / 2 + self.loc + jnp.log(self.scale) + class Logistic(Distribution): arg_constraints = {"loc": constraints.real, "scale": constraints.positive} @@ -1201,6 +1243,9 @@ def cdf(self, value): def icdf(self, q): return self.loc + self.scale * logit(q) + def entropy(self): + return jnp.broadcast_to(jnp.log(self.scale) + 2, self.batch_shape) + class LogUniform(TransformedDistribution): arg_constraints = {"low": constraints.positive, "high": constraints.positive} @@ -1233,6 +1278,11 @@ def variance(self): def cdf(self, x): return self.base_dist.cdf(jnp.log(x)) + def entropy(self): + log_low = jnp.log(self.low) + log_high = jnp.log(self.high) + return (log_low + log_high) / 2 + jnp.log(log_high - log_low) + def _batch_solve_triangular(A, B): """ @@ -1521,6 +1571,13 @@ def infer_shapes( event_shape = lax.broadcast_shapes(event_shape, matrix[-1:]) return batch_shape, event_shape + def entropy(self): + (n,) = self.event_shape + half_log_det = jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum( + -1 + ) + return n * (jnp.log(2 * np.pi) + 1) / 2 + half_log_det + def _is_sparse(A): from scipy import sparse @@ -2062,6 +2119,11 @@ def mean(self): def variance(self): return jnp.broadcast_to(self.scale**2, self.batch_shape) + def entropy(self): + return jnp.broadcast_to( + (jnp.log(2 * np.pi * self.scale**2) + 1) / 2, self.batch_shape + ) + class Pareto(TransformedDistribution): arg_constraints = {"scale": constraints.positive, "alpha": constraints.positive} @@ -2103,6 +2165,9 @@ def cdf(self, value): def icdf(self, q): return self.scale / jnp.power(1 - q, 1 / self.alpha) + def entropy(self): + return jnp.log(self.scale / self.alpha) + 1 + 1 / self.alpha + class RelaxedBernoulliLogits(TransformedDistribution): arg_constraints = {"temperature": constraints.positive, "logits": constraints.real} @@ -2257,6 +2322,15 @@ def icdf(self, q): scaled = jnp.sign(q - 0.5) * jnp.sqrt(scaled_squared) return scaled * self.scale + self.loc + def entropy(self): + return jnp.broadcast_to( + (self.df + 1) / 2 * (digamma((self.df + 1) / 2) - digamma(self.df / 2)) + + jnp.log(self.df) / 2 + + betaln(self.df / 2, 0.5) + + jnp.log(self.scale), + self.batch_shape, + ) + class Uniform(Distribution): arg_constraints = {"low": constraints.dependent, "high": constraints.dependent} @@ -2303,6 +2377,9 @@ def infer_shapes(low=(), high=()): event_shape = () return batch_shape, event_shape + def entropy(self): + return jnp.log(self.high - self.low) + class Weibull(Distribution): arg_constraints = { @@ -2348,6 +2425,13 @@ def variance(self): - jnp.exp(gammaln(1.0 + 1.0 / self.concentration)) ** 2 ) + def entropy(self): + return ( + jnp.euler_gamma * (1 - 1 / self.concentration) + + jnp.log(self.scale / self.concentration) + + 1 + ) + class BetaProportion(Beta): """ diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index a5fd12536..fc0c81dd9 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -108,6 +108,11 @@ def enumerate_support(self, expand=True): values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values + def entropy(self): + return -self.probs * jnp.log(self.probs) - (1 - self.probs) * jnp.log1p( + -self.probs + ) + class BernoulliLogits(Distribution): arg_constraints = {"logits": constraints.real} @@ -149,6 +154,10 @@ def enumerate_support(self, expand=True): values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values + def entropy(self): + nexp = jnp.exp(-self.logits) + return ((1 + nexp) * jnp.log1p(nexp) + nexp * self.logits) / (1 + nexp) + def Bernoulli(probs=None, logits=None, *, validate_args=None): if probs is not None: @@ -341,6 +350,9 @@ def enumerate_support(self, expand=True): values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values + def entropy(self): + return -(self.probs * jnp.log(self.probs)).sum(axis=-1) + class CategoricalLogits(Distribution): arg_constraints = {"logits": constraints.real_vector} @@ -393,6 +405,10 @@ def enumerate_support(self, expand=True): values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values + def entropy(self): + probs = softmax(self.logits, axis=-1) + return -(probs * self.logits).sum(axis=-1) + logsumexp(self.logits, axis=-1) + def Categorical(probs=None, logits=None, *, validate_args=None): if probs is not None: @@ -462,6 +478,9 @@ def enumerate_support(self, expand=True): values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values + def entropy(self): + return jnp.log(self.high - self.low + 1) + class OrderedLogistic(CategoricalProbs): """ @@ -498,6 +517,9 @@ def infer_shapes(predictor, cutpoints): event_shape = () return batch_shape, event_shape + def entropy(self): + raise NotImplementedError + class MultinomialProbs(Distribution): arg_constraints = { @@ -879,6 +901,11 @@ def mean(self): def variance(self): return (1.0 / self.probs - 1.0) / self.probs + def entropy(self): + return -(1 - self.probs) * jnp.log1p(-self.probs) / self.probs - jnp.log( + self.probs + ) + class GeometricLogits(Distribution): arg_constraints = {"logits": constraints.real} @@ -914,6 +941,10 @@ def mean(self): def variance(self): return (1.0 / self.probs - 1.0) / self.probs + def entropy(self): + nexp = jnp.exp(-self.logits) + return nexp * self.logits + jnp.log1p(nexp) * (1 + nexp) + def Geometric(probs=None, logits=None, *, validate_args=None): if probs is not None: diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 8463c10c2..fa9e4613c 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -391,6 +391,12 @@ def enumerate_support(self, expand=True): """ raise NotImplementedError + def entropy(self): + """ + Returns the entropy of the distribution. + """ + raise NotImplementedError + def expand(self, batch_shape): """ Returns a new :class:`ExpandedDistribution` instance with batch diff --git a/test/test_distributions.py b/test/test_distributions.py index 961ab449f..43360b74b 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -408,6 +408,7 @@ def __init__( dist.Cauchy: lambda loc, scale: osp.cauchy(loc=loc, scale=scale), dist.Chi2: lambda df: osp.chi2(df), dist.Dirichlet: lambda conc: osp.dirichlet(conc), + dist.DiscreteUniform: lambda low, high: osp.randint(low, high + 1), dist.Exponential: lambda rate: osp.expon(scale=jnp.reciprocal(rate)), dist.Gamma: lambda conc, rate: osp.gamma(conc, scale=1.0 / rate), dist.GeometricProbs: lambda probs: osp.geom(p=probs, loc=-1), @@ -1390,6 +1391,36 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): assert_allclose(jit_fn(jax_dist.log_prob)(samples), expected, atol=1e-5) +@pytest.mark.parametrize( + "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL +) +def test_entropy(jax_dist, sp_dist, params): + jax_dist = jax_dist(*params) + + if _is_batched_multivariate(jax_dist): + pytest.skip("batching not allowed in multivariate distns.") + if sp_dist is None: + pytest.skip(reason="no corresponding scipy distribution") + try: + actual = jax_dist.entropy() + except NotImplementedError: + pytest.skip(reason="distribution does not implement `entropy`") + + sp_dist = sp_dist(*params) + expected = sp_dist.entropy() + assert_allclose(actual, expected, atol=1e-5) + + +def test_entropy_categorical(): + # There is no scipy mapping for categorical distributions, but the multinomial with + # one trial has the same entropy--which we check here. + logits = jax.random.normal(jax.random.key(9), (7,)) + probs = _to_probs_multinom(logits) + sp_dist = osp.multinomial(1, probs) + for jax_dist in [dist.CategoricalLogits(logits), dist.CategoricalProbs(probs)]: + assert_allclose(jax_dist.entropy(), sp_dist.entropy()) + + def test_mixture_log_prob(): gmm = dist.MixtureSameFamily( dist.Categorical(logits=np.zeros(2)), dist.Normal(0, 1).expand([2])