Leak when running examples with JAX_CHECK_TRACER_LEAKS=1 #1467

Closed
hyperfra opened this issue Aug 12, 2022 · 4 comments · Fixed by #1469
Labels
bug Something isn't working

Comments

@hyperfra

Hi,
when I run ar2.py from your examples (or bnn.py; I did not try the others) with the environment variable JAX_CHECK_TRACER_LEAKS=1, they fail. (I had to set it while trying to track down an issue in a function I had written.)

The exception raised is:

Exception: Leaked level MainTrace(1,DynamicJaxprTrace). Leaked tracer(s): [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].

Here is the complete log:

Traceback (most recent call last):
File "/home/ffranco/Downloads/ar2.py", line 138, in <module>
main(args)
File "/home/ffranco/Downloads/ar2.py", line 117, in main
run_inference(model, args, rng_key, y)
File "/home/ffranco/Downloads/ar2.py", line 96, in run_inference
mcmc.run(rng_key, y=y)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/mcmc.py", line 593, in run
states_flat, last_state = partial_map_fn(map_args)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/mcmc.py", line 386, in _single_chain_mcmc
model_kwargs=kwargs,
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/hmc.py", line 707, in init
rng_key_init_model, model_args, model_kwargs, init_params
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/hmc.py", line 659, in _init_state
forward_mode_differentiation=self._forward_mode_differentiation,
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/util.py", line 606, in initialize_model
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/util.py", line 404, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/handlers.py", line 171, in get_trace
self(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/Downloads/ar2.py", line 67, in ar2_scan
scan(transition, init, timesteps)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/contrib/control_flow/scan.py", line 438, in scan
msg = apply_stack(initial_msg)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 53, in apply_stack
default_process_message(msg)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 28, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/contrib/control_flow/scan.py", line 306, in scan_wrapper
body_fn, wrapped_carry, xs, length=length, reverse=reverse
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/lax/control_flow.py", line 1345, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/lax/control_flow.py", line 1332, in _create_jaxpr
f, in_tree, carry_avals + x_avals, "scan")
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/util.py", line 185, in wrapper
return f(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/lax/control_flow.py", line 78, in _initial_style_jaxpr
fun, in_tree, in_avals, primitive_name)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/util.py", line 185, in wrapper
return f(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/lax/control_flow.py", line 71, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1511, in trace_to_jaxpr_dynamic
del main, fun
File "/home/ffranco/anaconda3/lib/python3.7/contextlib.py", line 119, in __exit__
next(self.gen)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/core.py", line 810, in new_main
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
jax._src.traceback_util.UnfilteredStackTrace: Exception: Leaked level MainTrace(1,DynamicJaxprTrace). Leaked tracer(s): [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/ffranco/Downloads/ar2.py", line 138, in <module>
main(args)
File "/home/ffranco/Downloads/ar2.py", line 117, in main
run_inference(model, args, rng_key, y)
File "/home/ffranco/Downloads/ar2.py", line 96, in run_inference
mcmc.run(rng_key, y=y)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/mcmc.py", line 593, in run
states_flat, last_state = partial_map_fn(map_args)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/mcmc.py", line 386, in _single_chain_mcmc
model_kwargs=kwargs,
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/hmc.py", line 707, in init
rng_key_init_model, model_args, model_kwargs, init_params
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/hmc.py", line 659, in _init_state
forward_mode_differentiation=self._forward_mode_differentiation,
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/util.py", line 606, in initialize_model
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/util.py", line 404, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/handlers.py", line 171, in get_trace
self(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/Downloads/ar2.py", line 67, in ar2_scan
scan(transition, init, timesteps)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/contrib/control_flow/scan.py", line 438, in scan
msg = apply_stack(initial_msg)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 53, in apply_stack
default_process_message(msg)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 28, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/contrib/control_flow/scan.py", line 306, in scan_wrapper
body_fn, wrapped_carry, xs, length=length, reverse=reverse
File "/home/ffranco/anaconda3/lib/python3.7/contextlib.py", line 119, in __exit__
next(self.gen)
Exception: Leaked level MainTrace(1,DynamicJaxprTrace). Leaked tracer(s): [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].

fehiepsi added the bug label Aug 12, 2022
@hyperfra
Author

hyperfra commented Aug 14, 2022 via email

@fehiepsi
Member

Could you try uninstalling numpyro first and then installing the dev version (probably also with the latest jax/jaxlib)? You can use the jax.check_tracer_leaks() context manager to check for the leak (instead of setting the environment variable).
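
For readers following along, here is a minimal sketch of the context-manager approach suggested above. It uses jax.checking_leaks(), the shorthand documented in JAX for enabling the leak checker, and assumes it is available in the installed JAX version. The names stash, leaky, clean, and leak_message are hypothetical and only serve to illustrate; this is not the code from ar2.py or from NumPyro's scan.

```python
import jax
import jax.numpy as jnp

# Hypothetical module-level list, used only to provoke a leak.
stash = []


def leaky(x):
    # The tracer for x escapes the traced function through this side effect.
    stash.append(x)
    return x + 1


def clean(x):
    # No side effects, so nothing outlives the trace.
    return x + 1


def leak_message(fn, *args):
    """Trace fn under the leak checker; return the error text, or None."""
    try:
        # Scoped alternative to setting JAX_CHECK_TRACER_LEAKS=1 globally.
        with jax.checking_leaks():
            jax.jit(fn)(*args)
    except Exception as err:  # the traceback above shows a plain Exception
        return str(err)
    return None


if __name__ == "__main__":
    print("leaky:", leak_message(leaky, jnp.arange(3)))
    print("clean:", leak_message(clean, jnp.arange(3)))
```

Because the check is scoped to the with block, it can be wrapped around a single call (for example the mcmc.run(...) call in a script) rather than enabled for the whole process.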

@hyperfra
Author

hyperfra commented Aug 21, 2022 via email

@fehiepsi
Member

@hyperfra I think the issue Exception: Leaked level MainTrace(1,DynamicJaxprTrace). Leaked tracer(s): [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>]. is fixed in #1469. You are probably seeing a different leaked tracer. If so, could you open a new issue for it?
