Skip to content

Commit

Permalink
Add tri_logabsdet utility function.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed May 13, 2024
1 parent 4a471d1 commit 34aa3e6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 20 deletions.
30 changes: 10 additions & 20 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
multidigamma,
promote_shapes,
signed_stick_breaking_tril,
tri_logabsdet,
validate_sample,
vec_to_tril_matrix,
)
Expand Down Expand Up @@ -1400,12 +1401,8 @@ def sample(self, key, sample_shape=()):
def log_prob(self, values):
n, p = self.event_shape

row_log_det = jnp.log(
jnp.diagonal(self.scale_tril_row, axis1=-2, axis2=-1)
).sum(-1)
col_log_det = jnp.log(
jnp.diagonal(self.scale_tril_column, axis1=-2, axis2=-1)
).sum(-1)
row_log_det = tri_logabsdet(self.scale_tril_row)
col_log_det = tri_logabsdet(self.scale_tril_column)
log_det_term = (
p * row_log_det + n * col_log_det + 0.5 * n * p * jnp.log(2 * jnp.pi)
)
Expand Down Expand Up @@ -1532,9 +1529,7 @@ def sample(self, key, sample_shape=()):
@validate_sample
def log_prob(self, value):
M = _batch_mahalanobis(self.scale_tril, value - self.loc)
half_log_det = jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(
-1
)
half_log_det = tri_logabsdet(self.scale_tril)
normalize_term = half_log_det + 0.5 * self.scale_tril.shape[-1] * jnp.log(
2 * jnp.pi
)
Expand Down Expand Up @@ -1579,9 +1574,7 @@ def infer_shapes(

def entropy(self):
(n,) = self.event_shape
half_log_det = jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(
-1
)
half_log_det = tri_logabsdet(self.scale_tril)
return n * (jnp.log(2 * np.pi) + 1) / 2 + half_log_det


Expand Down Expand Up @@ -1857,7 +1850,7 @@ def sample(self, key, sample_shape=()):
def log_prob(self, value):
n = self.scale_tril.shape[-1]
Z = (
jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1)
tri_logabsdet(self.scale_tril)
+ 0.5 * n * jnp.log(self.df)
+ 0.5 * n * jnp.log(jnp.pi)
+ gammaln(0.5 * self.df)
Expand Down Expand Up @@ -1932,9 +1925,7 @@ def _batch_lowrank_logdet(W, D, capacitance_tril):
where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute
the log determinant.
"""
return 2 * jnp.sum(
jnp.log(jnp.diagonal(capacitance_tril, axis1=-2, axis2=-1)), axis=-1
) + jnp.log(D).sum(-1)
return 2 * tri_logabsdet(capacitance_tril) + jnp.log(D).sum(-1)


def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril):
Expand Down Expand Up @@ -2710,8 +2701,7 @@ def infer_shapes(
def entropy(self):
p = self.event_shape[-1]
return (
(p + 1)
* jnp.log(jnp.diagonal(self.scale_tril, axis1=-1, axis2=-2)).sum(axis=-1)
(p + 1) * tri_logabsdet(self.scale_tril)
+ p * (p + 1) / 2 * jnp.log(2)
+ multigammaln(self.concentration / 2, p)
- (self.concentration - p - 1) / 2 * multidigamma(self.concentration / 2, p)
Expand Down Expand Up @@ -2799,11 +2789,11 @@ def log_prob(self, value):
trace = jnp.square(x).sum(axis=(-1, -2))
p = value.shape[-1]
return (
(self.concentration - p - 1) * jnp.linalg.slogdet(value).logabsdet
(self.concentration - p - 1) * tri_logabsdet(value)
- trace / 2
+ p * (1 - self.concentration / 2) * jnp.log(2)
- multigammaln(self.concentration / 2, p)
- self.concentration * jnp.linalg.slogdet(self.scale_tril).logabsdet
- self.concentration * tri_logabsdet(self.scale_tril)
# Part of the Jacobian of the Cholesky transformation.
+ jnp.sum(
jnp.arange(p, 0, -1) * jnp.log(jnp.diagonal(value, axis1=-2, axis2=-1)),
Expand Down
7 changes: 7 additions & 0 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,13 @@ def multidigamma(a: jnp.ndarray, d: jnp.ndarray) -> jnp.ndarray:
return digamma(a[..., None] - 0.5 * jnp.arange(d)).sum(axis=-1)


def tri_logabsdet(a: jnp.ndarray) -> jnp.ndarray:
"""
Evaluate the `logabsdet` of a triangular positive-definite matrix.
"""
return jnp.log(jnp.diagonal(a, axis1=-1, axis2=-2)).sum(axis=-1)


# The is sourced from: torch.distributions.util.py
#
# Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Expand Down

0 comments on commit 34aa3e6

Please sign in to comment.