Skip to content

Commit

Permalink
avoids setting jax tracer as lazy property attribute (#1843)
Browse files Browse the repository at this point in the history
* remove tracer as attribute of truncated dist

* streamline test

* fix CI test run failure

* Update test name

Co-authored-by: Dylan H. Morris <[email protected]>

* Move test from test_distributions.py to test_distributions_util.py

* Direct tests of tracer leaks

* pre-commit changes

---------

Co-authored-by: Dylan H. Morris <[email protected]>
Co-authored-by: Dylan H. Morris <[email protected]>
  • Loading branch information
3 people authored Aug 9, 2024
1 parent b6e4629 commit d61f15c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
5 changes: 4 additions & 1 deletion numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from jax.scipy.linalg import solve_triangular
from jax.scipy.special import digamma

from numpyro.util import not_jax_tracer

# Parameters for Transformed Rejection with Squeeze (TRS) algorithm - page 3.
_tr_params = namedtuple(
"tr_params", ["c", "b", "a", "alpha", "u_r", "v_r", "m", "log_p", "log1_p", "log_h"]
Expand Down Expand Up @@ -692,7 +694,8 @@ def __get__(self, instance, obj_type=None):
if instance is None:
return self
value = self.wrapped(instance)
setattr(instance, self.wrapped.__name__, value)
if not_jax_tracer(value):
setattr(instance, self.wrapped.__name__, value)
return value


Expand Down
44 changes: 44 additions & 0 deletions test/test_distributions_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
import pytest
import scipy

import jax
from jax import lax, random, vmap
import jax.numpy as jnp
from jax.scipy.special import expit, xlog1py, xlogy

import numpyro.distributions as dist
from numpyro.distributions.util import (
add_diag,
binary_cross_entropy_with_logits,
Expand Down Expand Up @@ -182,3 +184,45 @@ def test_add_diag(matrix_shape: tuple, diag_shape: tuple) -> None:
expected = matrix + diag[..., None] * jnp.eye(matrix.shape[-1])
actual = add_diag(matrix, diag)
np.testing.assert_allclose(actual, expected)


@pytest.mark.parametrize(
"my_dist",
[
dist.TruncatedNormal(low=-1.0, high=2.0),
dist.TruncatedCauchy(low=-5, high=10),
dist.TruncatedDistribution(dist.StudentT(3), low=1.5),
],
)
def test_no_tracer_leak_at_lazy_property_log_prob(my_dist):
"""
Tests that truncated distributions, which use @lazy_property
values in their log_prob() methods, do not
have tracer leakage when log_prob() is called.
Reference: https://github.com/pyro-ppl/numpyro/issues/1836, and
https://github.com/CDCgov/multisignal-epi-inference/issues/282
"""
jit_lp = jax.jit(my_dist.log_prob)
with jax.check_tracer_leaks():
jit_lp(1.0)


@pytest.mark.parametrize(
"my_dist",
[
dist.TruncatedNormal(low=-1.0, high=2.0),
dist.TruncatedCauchy(low=-5, high=10),
dist.TruncatedDistribution(dist.StudentT(3), low=1.5),
],
)
def test_no_tracer_leak_at_lazy_property_sample(my_dist):
"""
Tests that truncated distributions, which use @lazy_property
values in their sample() methods, do not
have tracer leakage when sample() is called.
Reference: https://github.com/pyro-ppl/numpyro/issues/1836, and
https://github.com/CDCgov/multisignal-epi-inference/issues/282
"""
jit_sample = jax.jit(my_dist.sample)
with jax.check_tracer_leaks():
jit_sample(jax.random.key(5))

0 comments on commit d61f15c

Please sign in to comment.