Skip to content

Commit

Permalink
Add entropy for Wishart distribution.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed May 2, 2024
1 parent 90cc7b6 commit bd069fc
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
11 changes: 11 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
gammaincinv,
lazy_property,
matrix_to_tril_vec,
multidigamma,
promote_shapes,
signed_stick_breaking_tril,
validate_sample,
Expand Down Expand Up @@ -2706,6 +2707,16 @@ def infer_shapes(
concentration, scale_matrix, rate_matrix, scale_tril
)

def entropy(self):
p = self.event_shape[-1]
return (
(p + 1) * jnp.linalg.slogdet(self.scale_tril).logabsdet
+ p * (p + 1) / 2 * jnp.log(2)
+ multigammaln(self.concentration / 2, p)
- (self.concentration - p - 1) / 2 * multidigamma(self.concentration / 2, p)
+ self.concentration * p / 2
)


class WishartCholesky(Distribution):
"""
Expand Down
8 changes: 8 additions & 0 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from jax import jit, lax, random, vmap
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import digamma

# Parameters for Transformed Rejection with Squeeze (TRS) algorithm - page 3.
_tr_params = namedtuple(
Expand Down Expand Up @@ -634,6 +635,13 @@ def assert_one_of(**kwargs):
)


def multidigamma(a: jnp.ndarray, d: jnp.ndarray) -> jnp.ndarray:
"""
Derivative of the log of multivariate gamma.
"""
return digamma(a[..., None] - 0.5 * jnp.arange(d)).sum(axis=-1)


# The is sourced from: torch.distributions.util.py
#
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Expand Down

0 comments on commit bd069fc

Please sign in to comment.