Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NumPyro doesn't raise if a determinstic variable has the same name as a random variable #1281

Merged

Conversation

MarcoGorelli
Copy link
Contributor

@MarcoGorelli MarcoGorelli commented Jan 9, 2022

closes #1280

With this change, the traceback from the same example gives:

$ python t.py 
Traceback (most recent call last):
  File "t.py", line 13, in <module>
    mcmc.run(jax.random.PRNGKey(0))
  File "/home/marco/numpyro-dev/numpyro/infer/mcmc.py", line 599, in run
    states_flat, last_state = partial_map_fn(map_args)
  File "/home/marco/numpyro-dev/numpyro/infer/mcmc.py", line 387, in _single_chain_mcmc
    init_state = self.sampler.init(
  File "/home/marco/numpyro-dev/numpyro/infer/hmc.py", line 696, in init
    init_params = self._init_state(
  File "/home/marco/numpyro-dev/numpyro/infer/hmc.py", line 642, in _init_state
    init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
  File "/home/marco/numpyro-dev/numpyro/infer/util.py", line 608, in initialize_model
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
  File "/home/marco/numpyro-dev/numpyro/infer/util.py", line 399, in _get_model_transforms
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
  File "/home/marco/numpyro-dev/numpyro/handlers.py", line 171, in get_trace
    self(*args, **kwargs)
  File "/home/marco/numpyro-dev/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  File "/home/marco/numpyro-dev/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  File "/home/marco/numpyro-dev/numpyro/primitives.py", line 87, in __call__
    return self.fn(*args, **kwargs)
  File "t.py", line 9, in model
    numpyro.deterministic("alpha", alpha * 2)
  File "/home/marco/numpyro-dev/numpyro/primitives.py", line 279, in deterministic
    msg = apply_stack(initial_msg)
  File "/home/marco/numpyro-dev/numpyro/primitives.py", line 41, in apply_stack
    handler.postprocess_message(msg)
  File "/home/marco/numpyro-dev/numpyro/handlers.py", line 156, in postprocess_message
    assert not (
AssertionError: all sites must have unique names but got `alpha` duplicated

@MarcoGorelli MarcoGorelli force-pushed the raise-if-duplicate-in-deterministic branch from 15ba7aa to 65beb08 Compare January 9, 2022 17:16
@MarcoGorelli MarcoGorelli marked this pull request as ready for review January 9, 2022 17:16
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @MarcoGorelli!!!

@fehiepsi fehiepsi merged commit 6a0856b into pyro-ppl:master Jan 9, 2022
@MarcoGorelli MarcoGorelli deleted the raise-if-duplicate-in-deterministic branch January 9, 2022 18:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

NumPyro doesn't raise if a determinstic variable has the same name as a random variable
2 participants