From ce1d39cbd75b1fadcd85d3c050a9779debe986a3 Mon Sep 17 00:00:00 2001 From: Ami Falk Date: Sat, 4 May 2024 10:16:14 -0400 Subject: [PATCH] fix singleton plate bug --- numpyro/contrib/funsor/infer_util.py | 6 +++++- test/contrib/test_funsor.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/numpyro/contrib/funsor/infer_util.py b/numpyro/contrib/funsor/infer_util.py index 5ee0af278..5e97f82d2 100644 --- a/numpyro/contrib/funsor/infer_util.py +++ b/numpyro/contrib/funsor/infer_util.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import defaultdict +from collections import OrderedDict, defaultdict from contextlib import contextmanager import functools import re @@ -220,6 +220,10 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op): log_prob = scale * log_prob dim_to_name = site["infer"]["dim_to_name"] + + if all(dim == 1 for dim in log_prob.shape) and dim_to_name == OrderedDict(): + log_prob = log_prob.squeeze() + log_prob_factor = funsor.to_funsor( log_prob, output=funsor.Real, dim_to_name=dim_to_name ) diff --git a/test/contrib/test_funsor.py b/test/contrib/test_funsor.py index ad6037f17..0d403c59b 100644 --- a/test/contrib/test_funsor.py +++ b/test/contrib/test_funsor.py @@ -575,3 +575,18 @@ def gmm(data): with pytest.raises(Exception): mcmc.run(random.PRNGKey(2), data) assert len(_PYRO_STACK) == 0 + + +@pytest.mark.parametrize( + "i_size, j_size, k_size", [(1, 1, 1), (1, 2, 1), (2, 1, 1), (1, 1, 2)] +) +def test_singleton_plate_works(i_size, j_size, k_size): + def model(): + with numpyro.plate("i", i_size, dim=-3): + with numpyro.plate("j", j_size, dim=-2): + with numpyro.plate("k", k_size, dim=-1): + numpyro.sample("a", dist.Normal()) + + model = enum(numpyro.handlers.seed(model, rng_seed=0), first_available_dim=-4) + + log_density(model, (), {}, {})