Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not unflatten unevaluated lazy properties. #1778

Merged
merged 1 commit into from
Apr 12, 2024

Conversation

tillahoffmann
Copy link
Contributor

The tree_flatten function uses self.__dict__.get(name) to obtain the data field with the given name. Unevaluated lazy properties do not appear in self.__dict__, and they are set to None when the representation is unflattened. Here is an example.

>>> import jax
>>> from jax import numpy as jnp
>>> import numpyro
>>> 
>>> 
>>> @jax.jit
>>> def f1(x):
...     return numpyro.distributions.MultivariateNormal(jnp.zeros(3), jnp.eye(3))
>>> 
>>> 
>>> @jax.jit
>>> def f2(x):
...     dist = numpyro.distributions.MultivariateNormal(jnp.zeros(3), jnp.eye(3))
...     dist.precision_matrix
...     return dist


>>> print(f1(0).precision_matrix)
None
>>> print(f2(0).precision_matrix)
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]

The changes in this PR only set the attribute on the reconstructed instance if the value is not None or if the attribute is not a lazy property. While there is ambiguity between None representing an unevaluated property and None being the value of a lazy property, the implementation remains correct: If the evaluated value is None it is not cached and re-evaluates to None the first time the attribute is accessed on the reconstructed instance.

I've also added a test.

@tillahoffmann tillahoffmann changed the title Exclude lazy_property from data fields. Do not unflatten unevaluated lazy properties. Apr 11, 2024
@fehiepsi
Copy link
Member

Hi @tillahoffmann, could you point me to an example where the value is None but it is not a lazy property?

@tillahoffmann
Copy link
Contributor Author

I'm not actually aware of an example where the value is None. We could also just not set the value if it's None without the extra lazy_property check.

I added it to prevent surprises. E.g., if a user implemented a custom distribution where one of the fields has a None value and we didn't check the lazy_property, they would get an AttributeError if they accessed the attribute after a flatten/unflatten cycle. It's probably an unlikely scenario, however.

@fehiepsi
Copy link
Member

Oh, I think I understand your implementation now. That makes sense to me. Thanks!!

@fehiepsi fehiepsi closed this Apr 12, 2024
@fehiepsi fehiepsi reopened this Apr 12, 2024
@fehiepsi fehiepsi merged commit 7facd8c into pyro-ppl:master Apr 12, 2024
8 checks passed
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request May 6, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants