Skip to content

Commit

Permalink
fix singleton plate bug (#1792)
Browse files Browse the repository at this point in the history
  • Loading branch information
amifalk authored May 5, 2024
1 parent 7c3ec50 commit b500936
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
6 changes: 5 additions & 1 deletion numpyro/contrib/funsor/infer_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict
from collections import OrderedDict, defaultdict
from contextlib import contextmanager
import functools
import re
Expand Down Expand Up @@ -220,6 +220,10 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
log_prob = scale * log_prob

dim_to_name = site["infer"]["dim_to_name"]

if all(dim == 1 for dim in log_prob.shape) and dim_to_name == OrderedDict():
log_prob = log_prob.squeeze()

log_prob_factor = funsor.to_funsor(
log_prob, output=funsor.Real, dim_to_name=dim_to_name
)
Expand Down
15 changes: 15 additions & 0 deletions test/contrib/test_funsor.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,3 +575,18 @@ def gmm(data):
with pytest.raises(Exception):
mcmc.run(random.PRNGKey(2), data)
assert len(_PYRO_STACK) == 0


@pytest.mark.parametrize(
"i_size, j_size, k_size", [(1, 1, 1), (1, 2, 1), (2, 1, 1), (1, 1, 2)]
)
def test_singleton_plate_works(i_size, j_size, k_size):
def model():
with numpyro.plate("i", i_size, dim=-3):
with numpyro.plate("j", j_size, dim=-2):
with numpyro.plate("k", k_size, dim=-1):
numpyro.sample("a", dist.Normal())

model = enum(numpyro.handlers.seed(model, rng_seed=0), first_available_dim=-4)

log_density(model, (), {}, {})

0 comments on commit b500936

Please sign in to comment.