initialize_model() fails with batch size of 1 for model with discrete variables #1448
Thanks @rafaol! I guess we mixed up a size-1 plate with singleton dimensions created by promoting and broadcasting somewhere. Let us take a closer look at this.
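The ambiguity can be reproduced with plain NumPy. This is not NumPyro's internal code, only an illustration of the shapes involved: a log-probability from a plate of size 1 and a scalar log-probability that picked up a singleton dimension through broadcasting end up with the same shape, so the shape alone cannot tell them apart.

```python
import numpy as np

# Log-probability from a plate of size 1: one entry per plate element.
plate_log_prob = np.zeros(1)

# A scalar log-probability promoted against a shape-(1,) array by broadcasting.
scalar_log_prob = np.zeros(())
promoted = scalar_log_prob + np.zeros(1)

# Both arrays now have shape (1,); nothing in the shape records where the dim came from.
print(plate_log_prob.shape, promoted.shape)
print(np.broadcast_shapes((), (1,)))
```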
@fehiepsi I think I uncovered the source of the bug. In this line of `numpyro/numpyro/contrib/funsor/infer_util.py` (lines 222 to 225 at commit 7c3ec50): …
Over in `numpyro/numpyro/contrib/funsor/enum_messenger.py` (lines 275 to 285 at commit 7c3ec50): …
As a result, …

Here's a minimal example showing the difference in `dim_to_name` allocation depending on plate size:

```python
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor.infer_util import plate_to_enum_plate
from numpyro.contrib.funsor.enum_messenger import enum
from numpyro.contrib.funsor.enum_messenger import trace as packed_trace


def working_model():
    with numpyro.plate('n', 2):
        numpyro.sample('a', dist.Normal())


def failing_model():
    with numpyro.plate('n', 1):
        numpyro.sample('a', dist.Normal())


with enum(first_available_dim=-2):
    with plate_to_enum_plate():
        good_trace = packed_trace(numpyro.handlers.seed(working_model, rng_seed=0)).get_trace()

with enum(first_available_dim=-2):
    with plate_to_enum_plate():
        bad_trace = packed_trace(numpyro.handlers.seed(failing_model, rng_seed=0)).get_trace()

# Compare the dim <-> name mappings recorded for site 'a' in each trace.
print(good_trace['a']['infer']['dim_to_name'], good_trace['a']['infer']['name_to_dim'])
print(bad_trace['a']['infer']['dim_to_name'], bad_trace['a']['infer']['name_to_dim'])
```

```python
if log_prob.shape == (1,) and dim_to_name == OrderedDict():
    log_prob = log_prob.squeeze()
```
Thanks, @amifalk! Do you know what happens if the log_prob shape has a non-trivial singleton dimension, e.g., (2, 1, 3), (2, 1), (1, 2), (1, 1)?
It only fails when every plate dimension for the variable is 1. That is, (2, 1, 3), (2, 1), and (1, 2) work, but (1, 1, ..., 1) fails. So I guess we just need to check that …
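A minimal sketch of that check, assuming the same `dim_to_name` condition as the one-line fix proposed earlier and replacing only the `shape == (1,)` test with "every dimension has size 1". The helper name `squeeze_unnamed_singletons` is hypothetical; this is not the code that was merged into NumPyro.

```python
from collections import OrderedDict

import numpy as np


def squeeze_unnamed_singletons(log_prob, dim_to_name):
    """Hypothetical helper: squeeze log_prob to a scalar only when no batch
    dims are named and every dimension has size 1 (e.g. (1,), (1, 1), ...)."""
    if dim_to_name == OrderedDict() and all(d == 1 for d in log_prob.shape):
        return log_prob.squeeze()
    return log_prob


# Shapes from the discussion: only the all-ones shapes are squeezed.
for shape in [(2, 1, 3), (2, 1), (1, 2), (1,), (1, 1), (1, 1, 1)]:
    out = squeeze_unnamed_singletons(np.zeros(shape), OrderedDict())
    print(shape, "->", out.shape)
```

Unlike the `log_prob.shape == (1,)` test, the `all(...)` condition also covers variables nested in several size-1 plates, while shapes with any dimension larger than 1 are left untouched.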
Woohoo, that's good news. Thanks for digging into this, @amifalk! Could you send a PR?
Hi,

I'm having issues using NumPyro's latest version of `initialize_model()` on a model with discrete variables. The model samples random variables within a plate context with a data-dependent size. When the given data is of size $N > 1$, everything works fine. However, if I pass a dataset with a single data point ($N = 1$), the execution of `initialize_model` fails. Apparently, the funsor API expected the computed log-probability of a sampled variable to be a scalar, when in fact it's a single-element 1D array in this case.

Minimal working example:
Execution output: