get_model_relations and get_dependencies give UnexpectedTracerError on seeded models #1886

Closed
danielward27 opened this issue Oct 10, 2024 · 1 comment · Fixed by #1926

Comments

@danielward27
Contributor

I'm assuming these functions are expected to work with seeded models, but they raise UnexpectedTracerErrors.

```python
import numpyro
import jax
import jax.random as jr
import numpyro.distributions as dist
from numpyro.infer.inspect import get_model_relations
from numpyro import handlers

def model():
    m = numpyro.sample('m', dist.Normal(0, 1))
    numpyro.sample('sd', dist.LogNormal(m, 1))

seeded_model = handlers.seed(model, jr.key(0))

with jax.checking_leaks():
    get_model_relations(seeded_model)  # UnexpectedTracerError, same for get_dependencies
```

The traceback is:

```
Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "/home/dw16200/miniconda3/envs/softcvi_env/lib/python3.12/site-packages/numpyro/infer/inspect.py", line 323, in get_model_relations
    trace = jax.eval_shape(get_trace).trace
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dw16200/miniconda3/envs/softcvi_env/lib/python3.12/contextlib.py", line 144, in __exit__
    next(self.gen)
^^^^^^^^^^^^^^^^^^
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(key<fry>[])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line <stdin>:3:4 (model)
<DynamicJaxprTracer 129514377205776> is referred to by <seed 129515008443408>.rng_key
<seed 129515008443408> is referred to by __main__.seeded_model

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
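The last lines of the traceback point at the mechanism: the leaked tracer is held by the `rng_key` attribute of the `seed` object that was created once at module level. The handler appears to split its key at each sample site and store the new key on itself, and under `jax.eval_shape` that stored key is a tracer, so it outlives the trace. The sketch below reproduces the same pattern in plain JAX. `KeyHolder`, `next_key` and `draw` are hypothetical names, and the class only mimics the split-and-store behaviour suggested by the traceback; it is not NumPyro's implementation.

```python
import jax
import jax.random as jr


class KeyHolder:
    """Hypothetical stand-in for a handler that keeps a mutable PRNG key."""

    def __init__(self, key):
        self.rng_key = key

    def next_key(self):
        # Split the stored key and write the new one back onto the object.
        # Inside a JAX trace, the value written back is a tracer.
        self.rng_key, subkey = jr.split(self.rng_key)
        return subkey


holder = KeyHolder(jr.key(0))  # created once, at global scope


def draw():
    return jr.normal(holder.next_key())


jax.eval_shape(draw)  # traces draw(); holder.rng_key now refers to a tracer

try:
    jr.split(holder.rng_key)  # reuse the stored key after the trace has finished
    leaked = False
except jax.errors.UnexpectedTracerError:
    leaked = True

print("leaked tracer:", leaked)
```

Any later use of `holder.rng_key` fails the same way, which is presumably why reusing a globally seeded model after `get_model_relations` surfaces as an `UnexpectedTracerError`.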
@fehiepsi
Member

You can avoid the leak by wrapping the seeded model in a function, so that a new `seed` handler is created each time the model is called, e.g.

```python
def seeded_model():
    return handlers.seed(model, jr.key(0))()
```

We can avoid the issue for the `seed` handler, but other handlers also carry mutable state, such as the `trace` handler, which records the sample sites it sees during each run. A good practice is to avoid applying those handlers globally, i.e. creating a handler instance once at module level and reusing it across calls.
