diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index daabb309b..262295b3e 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2710,7 +2710,8 @@ def infer_shapes( def entropy(self): p = self.event_shape[-1] return ( - (p + 1) * jnp.linalg.slogdet(self.scale_tril).logabsdet + (p + 1) + * jnp.log(jnp.diagonal(self.scale_tril, axis1=-1, axis2=-2)).sum(axis=-1) + p * (p + 1) / 2 * jnp.log(2) + multigammaln(self.concentration / 2, p) - (self.concentration - p - 1) / 2 * multidigamma(self.concentration / 2, p)