diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 220c23b92..7d1ee97a0 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -60,6 +60,7 @@
     SigmoidTransform,
 )
 from numpyro.distributions.util import (
+    add_diag,
     betainc,
     betaincinv,
     cholesky_of_inverse,
@@ -1081,10 +1082,7 @@ def _onion(self, key, size):
         # correct the diagonal
         # NB: beta_sample = sum(w ** 2) because norm 2 of u is 1.
         diag = jnp.ones(cholesky.shape[:-1]).at[..., 1:].set(jnp.sqrt(1 - beta_sample))
-        cholesky = cholesky + jnp.expand_dims(diag, axis=-1) * jnp.identity(
-            self.dimension
-        )
-        return cholesky
+        return add_diag(cholesky, diag)
 
     def sample(self, key, sample_shape=()):
         assert is_prng_key(key)
@@ -1860,7 +1858,7 @@ def _batch_capacitance_tril(W, D):
     Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
     K = jnp.matmul(Wt_Dinv, W)
     # could be inefficient
-    return jnp.linalg.cholesky(jnp.add(K, jnp.identity(K.shape[-1])))
+    return jnp.linalg.cholesky(add_diag(K, 1))
 
 
 def _batch_lowrank_logdet(W, D, capacitance_tril):
@@ -1957,17 +1955,15 @@ def scale_tril(self):
         cov_diag_sqrt_unsqueeze = jnp.expand_dims(jnp.sqrt(self.cov_diag), axis=-1)
         Dinvsqrt_W = self.cov_factor / cov_diag_sqrt_unsqueeze
         K = jnp.matmul(Dinvsqrt_W, jnp.swapaxes(Dinvsqrt_W, -1, -2))
-        K = jnp.add(K, jnp.identity(K.shape[-1]))
+        K = add_diag(K, 1)
         scale_tril = cov_diag_sqrt_unsqueeze * jnp.linalg.cholesky(K)
         return scale_tril
 
     @lazy_property
     def covariance_matrix(self):
-        # TODO: find a better solution to create a diagonal matrix
-        new_diag = self.cov_diag[..., jnp.newaxis] * jnp.identity(self.loc.shape[-1])
-        covariance_matrix = new_diag + jnp.matmul(
+        covariance_matrix = add_diag(jnp.matmul(
             self.cov_factor, jnp.swapaxes(self.cov_factor, -1, -2)
-        )
+        ), self.cov_diag)
         return covariance_matrix
 
     @lazy_property
@@ -1979,12 +1975,8 @@ def precision_matrix(self):
             self.cov_diag, axis=-2
         )
         A = solve_triangular(Wt_Dinv, self._capacitance_tril, lower=True)
-        # TODO: find a better solution to create a diagonal matrix
         inverse_cov_diag = jnp.reciprocal(self.cov_diag)
-        diag_embed = inverse_cov_diag[..., jnp.newaxis] * jnp.identity(
-            self.loc.shape[-1]
-        )
-        return diag_embed - jnp.matmul(jnp.swapaxes(A, -1, -2), A)
+        return add_diag(-jnp.matmul(jnp.swapaxes(A, -1, -2), A), inverse_cov_diag)
 
     def sample(self, key, sample_shape=()):
         assert is_prng_key(key)
diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py
index b7c7605c4..a908a7eff 100644
--- a/numpyro/distributions/transforms.py
+++ b/numpyro/distributions/transforms.py
@@ -17,6 +17,7 @@
 
 from numpyro.distributions import constraints
 from numpyro.distributions.util import (
+    add_diag,
     matrix_to_tril_vec,
     signed_stick_breaking_tril,
     sum_rightmost,
@@ -753,7 +754,7 @@ def __call__(self, x):
         n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
         z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
         diag = jnp.exp(x[..., -n:])
-        return z + jnp.expand_dims(diag, axis=-1) * jnp.identity(n)
+        return add_diag(z, diag)
 
     def _inverse(self, y):
         z = matrix_to_tril_vec(y, diagonal=-1)
@@ -792,7 +793,7 @@ def __call__(self, x):
         n = round((math.sqrt(1 + 8 * x.shape[-1]) - 1) / 2)
         z = vec_to_tril_matrix(x[..., :-n], diagonal=-1)
         diag = softplus(x[..., -n:])
-        return (z + jnp.identity(n)) * diag[..., None]
+        return add_diag(z, 1) * diag[..., None]
 
     def _inverse(self, y):
         diag = jnp.diagonal(y, axis1=-2, axis2=-1)
diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py
index f59e336db..c2e7aafc2 100644
--- a/numpyro/distributions/util.py
+++ b/numpyro/distributions/util.py
@@ -400,7 +400,7 @@ def signed_stick_breaking_tril(t):
     z1m_cumprod_sqrt_shifted = jnp.pad(
         z1m_cumprod_sqrt[..., :-1], pad_width, mode="constant", constant_values=1.0
     )
-    y = (r + jnp.identity(r.shape[-1])) * z1m_cumprod_sqrt_shifted
+    y = add_diag(r, 1) * z1m_cumprod_sqrt_shifted
     return y
 
 
@@ -680,3 +680,11 @@ def wrapper(self, *args, **kwargs):
         return log_prob
 
     return wrapper
+
+
+def add_diag(matrix: jnp.ndarray, diag: jnp.ndarray) -> jnp.ndarray:
+    """
+    Add `diag` to the trailing diagonal of `matrix`.
+    """
+    idx = jnp.arange(matrix.shape[-1])
+    return matrix.at[..., idx, idx].add(diag)
diff --git a/test/test_distributions.py b/test/test_distributions.py
index f80f6bcf6..462a461a0 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -7,6 +7,7 @@
 from itertools import product
 import math
 import os
+from typing import Any, Callable
 
 import numpy as np
 from numpy.testing import assert_allclose, assert_array_equal
@@ -3070,9 +3071,11 @@ def sample(d: dist.Distribution):
 
     for in_axes, out_axes in in_out_axes_cases:
         batched_params = [
-            jax.tree_map(lambda x: jnp.expand_dims(x, ax), arg)
-            if isinstance(ax, int)
-            else arg
+            (
+                jax.tree_map(lambda x: jnp.expand_dims(x, ax), arg)
+                if isinstance(ax, int)
+                else arg
+            )
             for arg, ax in zip(params, in_axes)
         ]
         # Recreate the jax_dist to avoid side effects coming from `d.sample`
@@ -3169,3 +3172,52 @@ def test_sample_truncated_normal_in_tail():
 def test_jax_custom_prng():
     samples = dist.Normal(0, 5).sample(random.PRNGKey(0), sample_shape=(1000,))
     assert ~jnp.isinf(samples).any()
+
+
+def _assert_not_jax_issue_19885(
+    capfd: pytest.CaptureFixture, func: Callable, *args, **kwargs
+) -> Any:
+    # jit-ing identity plus matrix multiplication leads to performance degradation as
+    # discussed in https://github.com/google/jax/issues/19885. This assertion verifies
+    # that the issue does not affect performance in numpyro.
+    for jit in [True, False]:
+        result = (jax.jit(func) if jit else func)(*args, **kwargs)
+        block_until_ready = getattr(result, "block_until_ready", None)
+        if block_until_ready:
+            result = block_until_ready()
+        _, err = capfd.readouterr()
+        assert (
+            "MatMul reference implementation being executed" not in err
+        ), f"jit: {jit}"
+    return result
+
+
+@pytest.mark.xfail
+def test_jax_issue_19885(capfd: pytest.CaptureFixture) -> None:
+    def func_with_warning(y) -> jnp.ndarray:
+        return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)
+
+    _assert_not_jax_issue_19885(capfd, func_with_warning, jnp.ones((20, 100, 100)))
+
+
+def test_lowrank_mvn_19885(capfd: pytest.CaptureFixture) -> None:
+    # Create parameters.
+    batch_size = 100
+    event_size = 200
+    sample_size = 40
+    rank = 40
+    loc, cov_diag = random.normal(random.key(0), (2, batch_size, event_size))
+    cov_diag = jnp.exp(cov_diag)
+    cov_factor = random.normal(random.key(1), (batch_size, event_size, rank))
+
+    distribution = _assert_not_jax_issue_19885(
+        capfd, dist.LowRankMultivariateNormal, loc, cov_factor, cov_diag
+    )
+    x = _assert_not_jax_issue_19885(
+        capfd,
+        lambda x: distribution.sample(random.key(0), x.shape),
+        jnp.empty(sample_size),
+    )
+    assert x.shape == (sample_size, batch_size, event_size)
+    log_prob = _assert_not_jax_issue_19885(capfd, distribution.log_prob, x)
+    assert log_prob.shape == (sample_size, batch_size)
diff --git a/test/test_distributions_util.py b/test/test_distributions_util.py
index d97d1d5ac..ab15966f8 100644
--- a/test/test_distributions_util.py
+++ b/test/test_distributions_util.py
@@ -13,6 +13,7 @@
 from jax.scipy.special import expit, xlog1py, xlogy
 
 from numpyro.distributions.util import (
+    add_diag,
     binary_cross_entropy_with_logits,
     binomial,
     categorical,
@@ -164,3 +165,20 @@ def test_safe_normalize(dim):
     data = jnp.zeros((10, dim))
     x = safe_normalize(data)
     assert_allclose((x * x).sum(-1), jnp.ones(x.shape[:-1]), rtol=1e-6)
+
+
+@pytest.mark.parametrize(
+    "matrix_shape, diag_shape",
+    [
+        ((5, 5), ()),
+        ((7, 7), (7,)),
+        ((10, 3, 3), (10, 3)),
+        ((7, 5, 9, 9), (5, 1)),
+    ],
+)
+def test_add_diag(matrix_shape: tuple, diag_shape: tuple) -> None:
+    matrix = random.normal(random.key(0), matrix_shape)
+    diag = random.normal(random.key(1), diag_shape)
+    expected = matrix + diag[..., None] * jnp.eye(matrix.shape[-1])
+    actual = add_diag(matrix, diag)
+    np.testing.assert_allclose(actual, expected)
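Illustrative sketch (not part of the patch): the snippet below contrasts the broadcasted-identity pattern that this diff removes with the batched diagonal update it introduces. The standalone add_diag here simply mirrors the helper added to numpyro/distributions/util.py; the shapes and random keys are arbitrary example data.

import jax.numpy as jnp
from jax import random


def add_diag(matrix: jnp.ndarray, diag: jnp.ndarray) -> jnp.ndarray:
    # Mirror of the new helper: scatter-add `diag` onto the trailing diagonal
    # instead of materializing an identity matrix.
    idx = jnp.arange(matrix.shape[-1])
    return matrix.at[..., idx, idx].add(diag)


matrix = random.normal(random.PRNGKey(0), (10, 3, 3))
diag = random.normal(random.PRNGKey(1), (10, 3))

# Old pattern used throughout the touched modules: build an identity matrix and
# broadcast it against `diag`. Combined with matmuls under jit, this is the kind
# of graph that can hit the slow reference MatMul path (JAX issue 19885).
old = matrix + diag[..., None] * jnp.identity(matrix.shape[-1])

# New pattern: an indexed update on the diagonal; no identity matrix is created.
new = add_diag(matrix, diag)

assert jnp.allclose(old, new)

The indexed .at[..., idx, idx].add update only touches the diagonal entries, so no n-by-n identity is materialized and no broadcasted product is handed to XLA, which is the pattern the xfail test above demonstrates and the other tests guard against.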