Skip to content

Commit

Permalink
Do not unflatten unevaluated lazy properties.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Apr 11, 2024
1 parent d7159b8 commit 7ccfa3b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
6 changes: 4 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,12 @@ def tree_unflatten(cls, aux_data, params):
d = cls.__new__(cls)

for k, v in pytree_data_fields_dict.items():
setattr(d, k, v)
if v is not None or not isinstance(getattr(cls, k, None), lazy_property):
setattr(d, k, v)

for k, v in pytree_aux_fields_dict.items():
setattr(d, k, v)
if v is not None or not isinstance(getattr(cls, k, None), lazy_property):
setattr(d, k, v)

# disable args validation during `tree_unflatten` it is called by jax with
# placeholder attributes that would make validation fail
Expand Down
11 changes: 11 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2768,6 +2768,17 @@ def f(x):
# Test that parameters do not change after flattening.
expected_dist = f(0)
actual_dist = jax.jit(f)(0)
for name in expected_dist.arg_constraints:
expected_arg = getattr(expected_dist, name)
actual_arg = getattr(actual_dist, name)
assert actual_arg is not None, f"arg {name} is None"
if np.issubdtype(np.asarray(expected_arg).dtype, np.number):
assert_allclose(actual_arg, expected_arg)
else:
assert (
actual_arg.shape == expected_arg.shape
and actual_arg.dtype == expected_arg.dtype
)
expected_sample = expected_dist.sample(random.PRNGKey(0))
actual_sample = actual_dist.sample(random.PRNGKey(0))
expected_log_prob = expected_dist.log_prob(expected_sample)
Expand Down

0 comments on commit 7ccfa3b

Please sign in to comment.