Leak when running examples with JAX_CHECK_TRACER_LEAKS=1 #1467
Labels: bug (Something isn't working)
Comments
Hi, sorry for the late response; I was away. I have reinstalled everything and used the context manager:
* No issues when running on my Mac laptop (jax 0.3.16, jaxlib 0.3.15, numpyro 0.10.0)
* But I still have the same issue on my Linux machine (jax 0.3.16, jaxlib 0.3.15+cuda11.cudnn82, numpyro 0.10.0)
Best,
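Because the result differs between the Mac (CPU) and the Linux/CUDA machine, it helps to paste the exact environment into the report. A minimal sketch, assuming Python 3.8+ for the standard library's `importlib.metadata`; the helper name `environment_report` is hypothetical, and `jax.default_backend()` is only queried if jax imports cleanly.

```python
from importlib.metadata import PackageNotFoundError, version


def environment_report():
    """Collect versions of the packages discussed in this thread (hypothetical helper)."""
    report = {}
    for pkg in ("jax", "jaxlib", "numpyro"):
        try:
            report[pkg] = version(pkg)  # installed distribution version string
        except PackageNotFoundError:
            report[pkg] = None  # package not installed in this environment
    try:
        import jax

        report["backend"] = jax.default_backend()  # e.g. "cpu" or "gpu"
    except ImportError:
        report["backend"] = None
    return report


print(environment_report())
```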
From: Du Phan ***@***.***>
Sent: Sunday, August 14, 2022 14:03
To: pyro-ppl/numpyro ***@***.***>
Cc: Francesco Franco ***@***.***>; Author ***@***.***>
Subject: Re: [pyro-ppl/numpyro] Leak when running examples with JAX_CHECK_TRACER_LEAKS=1 (Issue #1467)
Could you try uninstalling numpyro first, then installing the dev version (probably also with the latest jax/jaxlib)? You can use the jax.check_tracer_leaks() context manager to check for the leak (instead of setting the environment variable).
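To scope leak checking to a single call rather than the whole process, the check can be wrapped in a context manager. A minimal sketch: the reply above names `jax.check_tracer_leaks()`; this example uses `jax.checking_leaks()`, which JAX also exports and which toggles the same `jax_check_tracer_leaks` flag, so treat the exact spelling as version-dependent. The function `running_total` is a hypothetical stand-in for a model that uses `lax.scan`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def running_total(xs):
    """Cumulative sum via lax.scan; no tracer escapes the scan body."""
    def body(carry, x):
        carry = carry + x
        return carry, carry  # (new carry, per-step output)

    _, totals = lax.scan(body, jnp.zeros((), xs.dtype), xs)
    return totals


# Leak checking is active only inside this block, unlike the
# process-wide JAX_CHECK_TRACER_LEAKS=1 environment variable.
with jax.checking_leaks():
    totals = running_total(jnp.arange(5.0))

print(totals.tolist())  # [0.0, 1.0, 3.0, 6.0, 10.0]
```

A clean function like this should run unchanged under the check; only code that keeps a tracer alive after tracing finishes triggers the error.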
Hi
When I run ar2.py from your examples (or bnn.py; I did not try the others) with the environment variable JAX_CHECK_TRACER_LEAKS=1, they fail. (I enabled it to track down an issue in a function I had written.)
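JAX maps each `JAX_*` environment variable to a config option of the same lower-case name, so the environment variable used here corresponds to `jax_check_tracer_leaks`. A minimal sketch of setting it from Python instead, assuming `jax.config` exposes the current value as an attribute; this is a process-wide switch, unlike the context manager.

```python
import jax

# Shell equivalent: JAX_CHECK_TRACER_LEAKS=1 python ar2.py
# (the environment variable must be set before jax is imported).
jax.config.update("jax_check_tracer_leaks", True)
leak_checks_on = jax.config.jax_check_tracer_leaks

# The setting stays on for the rest of the process; turn it off when done.
jax.config.update("jax_check_tracer_leaks", False)
print(leak_checks_on)
```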
The exception raised is:
Exception: Leaked level MainTrace(1,DynamicJaxprTrace). Leaked tracer(s): [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
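For context, a "leaked tracer" is a tracer that is still referenced after the trace that created it has finished, typically because it was stored in a container that outlives the traced function. A minimal sketch of a deliberate leak that the checker should catch; the names `jit_leak_detected` and `stash` are hypothetical, and the exact message wording ("Leaked level" in older JAX) varies by version, so only the word "Leaked" is checked.

```python
import jax


def jit_leak_detected():
    stash = []  # outlives the trace, so any tracer stored here is leaked

    @jax.jit
    def leaky(x):
        stash.append(x)  # x is a tracer while jit traces this function
        return x + 1

    with jax.checking_leaks():
        try:
            leaky(3)  # a Python int traces as an int32[] scalar (x64 disabled)
        except Exception as err:
            return "Leaked" in str(err)
    return False


print(jit_leak_detected())
```

Defining `leaky` inside the helper gives each call a fresh function, so jit's trace cache cannot skip the traced (and checked) path.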
Here is the complete log:
Traceback (most recent call last):
File "/home/ffranco/Downloads/ar2.py", line 138, in <module>
main(args)
File "/home/ffranco/Downloads/ar2.py", line 117, in main
run_inference(model, args, rng_key, y)
File "/home/ffranco/Downloads/ar2.py", line 96, in run_inference
mcmc.run(rng_key, y=y)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/mcmc.py", line 593, in run
states_flat, last_state = partial_map_fn(map_args)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/mcmc.py", line 386, in _single_chain_mcmc
model_kwargs=kwargs,
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/hmc.py", line 707, in init
rng_key_init_model, model_args, model_kwargs, init_params
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/hmc.py", line 659, in _init_state
forward_mode_differentiation=self._forward_mode_differentiation,
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/util.py", line 606, in initialize_model
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/util.py", line 404, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/handlers.py", line 171, in get_trace
self(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/Downloads/ar2.py", line 67, in ar2_scan
scan(transition, init, timesteps)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/contrib/control_flow/scan.py", line 438, in scan
msg = apply_stack(initial_msg)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 53, in apply_stack
default_process_message(msg)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 28, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/contrib/control_flow/scan.py", line 306, in scan_wrapper
body_fn, wrapped_carry, xs, length=length, reverse=reverse
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/lax/control_flow.py", line 1345, in scan
init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/lax/control_flow.py", line 1332, in _create_jaxpr
f, in_tree, carry_avals + x_avals, "scan")
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/util.py", line 185, in wrapper
return f(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/lax/control_flow.py", line 78, in _initial_style_jaxpr
fun, in_tree, in_avals, primitive_name)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/util.py", line 185, in wrapper
return f(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/_src/lax/control_flow.py", line 71, in _initial_style_open_jaxpr
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/interpreters/partial_eval.py", line 1511, in trace_to_jaxpr_dynamic
del main, fun
File "/home/ffranco/anaconda3/lib/python3.7/contextlib.py", line 119, in __exit__
next(self.gen)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/jax/core.py", line 810, in new_main
raise Exception(f'Leaked level {t()}. Leaked tracer(s): {leaked_tracers}.')
jax._src.traceback_util.UnfilteredStackTrace: Exception: Leaked level MainTrace(1,DynamicJaxprTrace). Leaked tracer(s): [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/ffranco/Downloads/ar2.py", line 138, in <module>
main(args)
File "/home/ffranco/Downloads/ar2.py", line 117, in main
run_inference(model, args, rng_key, y)
File "/home/ffranco/Downloads/ar2.py", line 96, in run_inference
mcmc.run(rng_key, y=y)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/mcmc.py", line 593, in run
states_flat, last_state = partial_map_fn(map_args)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/mcmc.py", line 386, in _single_chain_mcmc
model_kwargs=kwargs,
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/hmc.py", line 707, in init
rng_key_init_model, model_args, model_kwargs, init_params
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/hmc.py", line 659, in _init_state
forward_mode_differentiation=self._forward_mode_differentiation,
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/util.py", line 606, in initialize_model
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/infer/util.py", line 404, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/handlers.py", line 171, in get_trace
self(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File "/home/ffranco/Downloads/ar2.py", line 67, in ar2_scan
scan(transition, init, timesteps)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/contrib/control_flow/scan.py", line 438, in scan
msg = apply_stack(initial_msg)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 53, in apply_stack
default_process_message(msg)
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/primitives.py", line 28, in default_process_message
msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
File "/home/ffranco/numpyro08/lib/python3.7/site-packages/numpyro/contrib/control_flow/scan.py", line 306, in scan_wrapper
body_fn, wrapped_carry, xs, length=length, reverse=reverse
File "/home/ffranco/anaconda3/lib/python3.7/contextlib.py", line 119, in __exit__
next(self.gen)
Exception: Leaked level MainTrace(1,DynamicJaxprTrace). Leaked tracer(s): [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
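The log shows the leak being detected while `lax.scan` traces its body (`_create_jaxpr` → `trace_to_jaxpr_dynamic`), with an `int32[]` tracer left behind. A minimal pure-JAX sketch of the same failure shape, for comparison with a clean scan; it does not claim where numpyro's `scan` wrapper actually keeps the reference, and the names `scan_leak_detected` and `scan_clean` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def scan_leak_detected():
    stash = []  # lives past the scan trace

    def body(carry, t):
        stash.append(t)  # t is an int32[] tracer while scan traces the body
        return carry + t, None

    with jax.checking_leaks():
        try:
            lax.scan(body, jnp.int32(0), jnp.arange(3, dtype=jnp.int32))
        except Exception as err:
            return "Leaked" in str(err)
    return False


def scan_clean():
    def body(carry, t):
        return carry + t, None  # nothing escapes the body

    with jax.checking_leaks():
        total, _ = lax.scan(body, jnp.int32(0), jnp.arange(3, dtype=jnp.int32))
    return int(total)


print(scan_leak_detected(), scan_clean())
```

The explicit `int32` dtype keeps the carry type fixed even if 64-bit mode is enabled, so the clean version fails only if something genuinely leaks.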