Skip to content

Commit

Permalink
raise if determinstic site name is duplicated
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jan 9, 2022
1 parent cbc371e commit 65beb08
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
2 changes: 1 addition & 1 deletion numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def postprocess_message(self, msg):
# which has no name
return
assert not (
msg["type"] == "sample" and msg["name"] in self.trace
msg["type"] in ("sample", "deterministic") and msg["name"] in self.trace
), "all sites must have unique names but got `{}` duplicated".format(
msg["name"]
)
Expand Down
13 changes: 12 additions & 1 deletion test/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.util import log_density
import numpyro.optim as optim
from numpyro.util import not_jax_tracer, optional
Expand Down Expand Up @@ -778,3 +778,14 @@ def subsample_fn(rng_key):

# test that values are not duplicated
assert len(set(subsamples[k].copy())) == subsample_size


def test_sites_have_unique_names():
def model():
alpha = numpyro.sample("alpha", dist.Normal())
numpyro.deterministic("alpha", alpha * 2)

mcmc = MCMC(NUTS(model), num_chains=1, num_samples=10, num_warmup=10)
msg = "all sites must have unique names but got `alpha` duplicated"
with pytest.raises(AssertionError, match=msg):
mcmc.run(random.PRNGKey(0))

0 comments on commit 65beb08

Please sign in to comment.