You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Since upgrading to 0.5.0, I am getting the following error message:
File "test.py", line 21, in <module>
mcmc.run(random.PRNGKey(1), x)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 499, in run
states_flat, last_state = partial_map_fn(map_args)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 327, in _single_chain_mcmc
init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 468, in init
init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 423, in _init_state
init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 462, in initialize_model
(init_params, pe, grad), is_valid = find_valid_initial_params(
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 270, in find_valid_initial_params
(init_params, pe, z_grad), is_valid = _find_valid_params(rng_key, exit_early=True)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 258, in _find_valid_params
_, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 248, in body_fn
pe, z_grad = value_and_grad(potential_fn)(params)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 165, in potential_energy
log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 49, in log_density
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/handlers.py", line 161, in get_trace
self(*args, **kwargs)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
return self.fn(*args, **kwargs)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
return self.fn(*args, **kwargs)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
return self.fn(*args, **kwargs)
[Previous line repeated 2 more times]
File "test.py", line 9, in model
c = numpyro.param("c", jnp.array(1.), constraint=dist.constraints.positive)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 181, in param
msg = apply_stack(initial_msg)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 25, in apply_stack
handler.process_message(msg)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/handlers.py", line 713, in process_message
value = self.substitute_fn(msg)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 123, in _unconstrain_reparam
i = site['infer'].get('_scan_current_index', None)
jax._src.traceback_util.FilteredStackTrace: KeyError: 'infer'
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "test.py", line 21, in <module>
mcmc.run(random.PRNGKey(1), x)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 499, in run
states_flat, last_state = partial_map_fn(map_args)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 327, in _single_chain_mcmc
init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 468, in init
init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 423, in _init_state
init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 462, in initialize_model
(init_params, pe, grad), is_valid = find_valid_initial_params(
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 270, in find_valid_initial_params
(init_params, pe, z_grad), is_valid = _find_valid_params(rng_key, exit_early=True)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 258, in _find_valid_params
_, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 248, in body_fn
pe, z_grad = value_and_grad(potential_fn)(params)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/api.py", line 805, in value_and_grad_f
ans, vjp_py = _vjp(f_partial, *dyn_args)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/api.py", line 1874, in _vjp
out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/interpreters/ad.py", line 114, in vjp
out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/interpreters/ad.py", line 101, in linearize
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 506, in trace_to_jaxpr
jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/linear_util.py", line 160, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 165, in potential_energy
log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 49, in log_density
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/handlers.py", line 161, in get_trace
self(*args, **kwargs)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
return self.fn(*args, **kwargs)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
return self.fn(*args, **kwargs)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
return self.fn(*args, **kwargs)
[Previous line repeated 2 more times]
File "test.py", line 9, in model
c = numpyro.param("c", jnp.array(1.), constraint=dist.constraints.positive)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 181, in param
msg = apply_stack(initial_msg)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 25, in apply_stack
handler.process_message(msg)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/handlers.py", line 713, in process_message
value = self.substitute_fn(msg)
File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 123, in _unconstrain_reparam
i = site['infer'].get('_scan_current_index', None)
KeyError: 'infer'
If I replace the `param` statement with a `sample` statement, the example works without problems. Has anything about the usage of `handlers.lift` changed in 0.5.0? Or is this a bug?
The text was updated successfully, but these errors were encountered:
Thanks, @dominikstrb! This is a bug: the `infer` key is missing from the message when a `param` primitive is converted to a `sample` primitive here. Do you want to contribute a fix? I believe we just need to add `msg["infer"] = msg.get("infer", {})` there.
Hi all,
I have been using `numpyro.handlers.lift` to replace parameters in my model with priors. Here is a minimal example (it works at least with 0.4.1):
Since upgrading to 0.5.0, I am getting the following error message:
If I replace the `param` statement with a `sample` statement, the example works without problems. Has anything about the usage of `handlers.lift` changed in 0.5.0? Or is this a bug?
The text was updated successfully, but these errors were encountered: