
KeyError: 'infer' when using handlers.lift #891

Closed
dominikstrb opened this issue Jan 25, 2021 · 3 comments

@dominikstrb
Contributor

Hi all,

I have been using numpyro.handlers.lift to replace parameters in my model with priors. Here is a minimal example (it works with 0.4.1):

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def model(data):
    c = numpyro.param("c", jnp.array(1.), constraint=dist.constraints.positive)

    x = numpyro.sample("x", dist.LogNormal(c, 1.), obs=data)

    return x


with numpyro.handlers.seed(rng_seed=1):
    x = model(None)

nuts_kernel = NUTS(numpyro.handlers.lift(model, prior={"c": dist.Gamma(0.01, 0.01)}))
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(1), x)
```

Since upgrading to 0.5.0, I get the following error:

  File "test.py", line 21, in <module>
    mcmc.run(random.PRNGKey(1), x)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 499, in run
    states_flat, last_state = partial_map_fn(map_args)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 327, in _single_chain_mcmc
    init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 468, in init
    init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 423, in _init_state
    init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 462, in initialize_model
    (init_params, pe, grad), is_valid = find_valid_initial_params(
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 270, in find_valid_initial_params
    (init_params, pe, z_grad), is_valid = _find_valid_params(rng_key, exit_early=True)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 258, in _find_valid_params
    _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 248, in body_fn
    pe, z_grad = value_and_grad(potential_fn)(params)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 165, in potential_energy
    log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 49, in log_density
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/handlers.py", line 161, in get_trace
    self(*args, **kwargs)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
    return self.fn(*args, **kwargs)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
    return self.fn(*args, **kwargs)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
    return self.fn(*args, **kwargs)
  [Previous line repeated 2 more times]
  File "test.py", line 9, in model
    c = numpyro.param("c", jnp.array(1.), constraint=dist.constraints.positive)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 181, in param
    msg = apply_stack(initial_msg)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 25, in apply_stack
    handler.process_message(msg)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/handlers.py", line 713, in process_message
    value = self.substitute_fn(msg)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 123, in _unconstrain_reparam
    i = site['infer'].get('_scan_current_index', None)
jax._src.traceback_util.FilteredStackTrace: KeyError: 'infer'

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "test.py", line 21, in <module>
    mcmc.run(random.PRNGKey(1), x)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 499, in run
    states_flat, last_state = partial_map_fn(map_args)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 327, in _single_chain_mcmc
    init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 468, in init
    init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 423, in _init_state
    init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 462, in initialize_model
    (init_params, pe, grad), is_valid = find_valid_initial_params(
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 270, in find_valid_initial_params
    (init_params, pe, z_grad), is_valid = _find_valid_params(rng_key, exit_early=True)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 258, in _find_valid_params
    _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 248, in body_fn
    pe, z_grad = value_and_grad(potential_fn)(params)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/api.py", line 805, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/api.py", line 1874, in _vjp
    out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/interpreters/ad.py", line 114, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/interpreters/ad.py", line 101, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 506, in trace_to_jaxpr
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/jax/linear_util.py", line 160, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 165, in potential_energy
    log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 49, in log_density
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/handlers.py", line 161, in get_trace
    self(*args, **kwargs)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
    return self.fn(*args, **kwargs)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
    return self.fn(*args, **kwargs)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 81, in __call__
    return self.fn(*args, **kwargs)
  [Previous line repeated 2 more times]
  File "test.py", line 9, in model
    c = numpyro.param("c", jnp.array(1.), constraint=dist.constraints.positive)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 181, in param
    msg = apply_stack(initial_msg)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/primitives.py", line 25, in apply_stack
    handler.process_message(msg)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/handlers.py", line 713, in process_message
    value = self.substitute_fn(msg)
  File "/home/dominik/venvs/default/lib/python3.8/site-packages/numpyro/infer/util.py", line 123, in _unconstrain_reparam
    i = site['infer'].get('_scan_current_index', None)
KeyError: 'infer'

If I replace the param statement with a sample statement, the example works without problems. Has the usage of handlers.lift changed in 0.5.0, or is this a bug?

@fehiepsi
Member

Thanks, @dominikstrb! This is a bug: the infer key is missing from the message when handlers.lift converts a param primitive into a sample primitive. Do you want to contribute a fix? I believe we just need to add msg["infer"] = msg.get("infer", {}) at that point.
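The diagnosis can be illustrated with a minimal pure-Python sketch. It does not use NumPyro. The message dictionaries, `make_param_msg`, `lift_process_message`, `read_scan_index` and `run` are all hypothetical simplifications, and the prior is a placeholder string rather than a real distribution. The sketch assumes only what the traceback shows: the consuming code reads `site['infer']` directly, so a param site turned into a sample site without an `infer` dict raises `KeyError: 'infer'`. It then applies the one-line fix proposed above.

```python
from typing import Any, Dict, Optional

# Hypothetical, simplified stand-in for an effect-handler message.
# Real NumPyro messages carry more fields (fn, args, kwargs, ...).
Msg = Dict[str, Any]


def make_param_msg(name: str, init_value: float) -> Msg:
    # Assumption: a param message is built without an "infer" key.
    return {"type": "param", "name": name, "value": init_value}


def lift_process_message(msg: Msg, prior: Dict[str, Any], fix: bool) -> None:
    """Toy lift handler: turn a param site that has a prior into a sample site."""
    if msg["type"] == "param" and msg["name"] in prior:
        msg["type"] = "sample"
        msg["fn"] = prior[msg["name"]]
        msg["value"] = None
        if fix:
            # Fix suggested in the thread: guarantee an "infer" dict,
            # keeping any existing one.
            msg["infer"] = msg.get("infer", {})


def read_scan_index(site: Msg) -> Optional[int]:
    # Same lookup as the failing line in the traceback.
    return site["infer"].get("_scan_current_index", None)


def run(fix: bool) -> str:
    msg = make_param_msg("c", 1.0)
    # The prior is a placeholder string here, not a real distribution.
    lift_process_message(msg, prior={"c": "Gamma(0.01, 0.01)"}, fix=fix)
    try:
        return f"ok, scan index = {read_scan_index(msg)}"
    except KeyError as err:
        return f"KeyError: {err}"


print(run(fix=False))  # KeyError: 'infer'
print(run(fix=True))   # ok, scan index = None
```

Using `msg.get("infer", {})` instead of assigning `{}` unconditionally keeps any `infer` options already attached to the message.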

@dominikstrb
Contributor Author

Hey @fehiepsi, thanks for the quick reply! Your suggested fix works. I submitted PR #892.

@fehiepsi
Member

Solved in #892. Thanks @dominikstrb for addressing this issue!
