initialize_model() fails with batch size of 1 for model with discrete variables #1448

Closed
rafaol opened this issue Jul 11, 2022 · 5 comments · Fixed by #1792
Labels
bug Something isn't working

Comments

@rafaol

rafaol commented Jul 11, 2022

Hi,

I'm having issues using initialize_model() from the latest version of NumPyro on a model with discrete variables. The model samples random variables within a plate context whose size depends on the data. When the given data is of size $N > 1$, everything works fine. However, if I pass a dataset with a single data point ($N = 1$), initialize_model() fails. Apparently, the funsor API expects the computed log-probability of a sampled variable to be a scalar, when in this case it is a single-element 1D array.

Minimal working example:

import jax
import jax.numpy as jnp
import numpyro
from numpyro.distributions import Bernoulli, MultivariateNormal
from numpyro.infer.util import initialize_model


def model(X: jnp.ndarray, y: jnp.ndarray, z: jnp.ndarray):
    phi = numpyro.sample('phi', MultivariateNormal(jnp.zeros(2), jnp.eye(2)))

    n_data = X.shape[-2]
    with numpyro.plate('individual', n_data):
        weights = jnp.tensordot(X, phi, axes=1)
        numpyro.sample('labels', Bernoulli(probs=(1/(1 + jnp.exp(-weights)))), obs=z, infer={'enumerate': 'parallel'})
        coefficients = numpyro.sample('coefficients', MultivariateNormal(jnp.zeros(2), jnp.eye(2)))
        numpyro.sample('responses', MultivariateNormal(X * coefficients, jnp.eye(2)), obs=y)


if __name__ == '__main__':
    rng_key = jax.random.PRNGKey(0)

    # This line runs without problems:
    res = initialize_model(rng_key, model, model_kwargs={'X': jnp.zeros((10, 2)), 'z': None, 'y': None})

    # This line fails
    res1 = initialize_model(rng_key, model, model_kwargs={'X': jnp.zeros((1, 2)), 'z': None, 'y': None})

    print('Done')

Execution output:

Traceback (most recent call last):
  File "/home/rafael/Projects/trajectory-clustering/funsors_issue.py", line 26, in <module>
    res1 = initialize_model(rng_key, model, model_kwargs={'X': jnp.zeros((1, 2)), 'z': None, 'y': None})
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 654, in initialize_model
    (init_params, pe, grad), is_valid = find_valid_initial_params(
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 395, in find_valid_initial_params
    (init_params, pe, z_grad), is_valid = _find_valid_params(
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 381, in _find_valid_params
    _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 366, in body_fn
    pe, z_grad = value_and_grad(potential_fn)(params)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/api.py", line 1063, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/api.py", line 2558, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/interpreters/ad.py", line 133, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/interpreters/ad.py", line 122, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/_src/profiler.py", line 312, in wrapper
    return func(*args, **kwargs)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 621, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 248, in potential_energy
    log_joint, model_trace = log_density_(
  File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 270, in log_density
    result, model_trace, _ = _enum_log_density(
  File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 181, in _enum_log_density
    log_prob_factor = funsor.to_funsor(
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/functools.py", line 888, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/funsor/tensor.py", line 491, in tensor_to_funsor
    raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Invalid shape: expected (), actual (1,)

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/rafael/Projects/numpyro/numpyro/infer/util.py", line 248, in potential_energy
    log_joint, model_trace = log_density_(
  File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 270, in log_density
    result, model_trace, _ = _enum_log_density(
  File "/home/rafael/Projects/numpyro/numpyro/contrib/funsor/infer_util.py", line 181, in _enum_log_density
    log_prob_factor = funsor.to_funsor(
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/functools.py", line 888, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/home/rafael/opt/anaconda3/envs/numpyro_dev/lib/python3.9/site-packages/funsor/tensor.py", line 491, in tensor_to_funsor
    raise ValueError(
ValueError: Invalid shape: expected (), actual (1,)

fehiepsi added the bug label Jul 17, 2022

@fehiepsi
Member

Thanks @rafaol! I guess we are confusing a size-1 plate with the singleton dimensions created by promotion and broadcasting somewhere. Let us take a closer look at this.

@amifalk
Contributor

amifalk commented May 3, 2024

@fehiepsi I think I uncovered the source of the bug.

In this line of _enum_log_density, we try to convert log_prob to a funsor, which defaults to a funsor of shape () if dim_to_name is empty:

dim_to_name = site["infer"]["dim_to_name"]
log_prob_factor = funsor.to_funsor(
    log_prob, output=funsor.Real, dim_to_name=dim_to_name
)

Over in NamedMessenger, we don't make a name request for a dimension if its batch size is 1 and it has no existing name, which is exactly the case for plates of size 1.

# interpret all names/dims as requests since we only run this function once
for dim in range(-batch_dim, 0):
    name = dim_to_name.get(dim, None)
    # the time dimension on the left sometimes necessitates empty dimensions appearing
    # before they have been assigned a name
    if batch_shape[dim] == 1 and name is None:
        continue
    dim_to_name[dim] = (
        name if isinstance(name, NameRequest) else NameRequest(name, dim_type)
    )

As a result, _enum_log_density tries to convert the log_prob of the variable in the size-1 plate, which has shape (1,), into a Real funsor of shape () and fails.
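
To make the mechanism above concrete, here is a rough pure-Python sketch of the skip logic. It is only a simplified stand-in, not the NumPyro or funsor code: `named_dims` and `cast_would_fail` are hypothetical helpers, the placeholder names are invented, and the failure condition is an assumption based on the error message ("expected (), actual (1,)").

```python
from collections import OrderedDict


def named_dims(batch_shape, existing=None):
    """Hypothetical stand-in for the naming loop quoted above.

    Maps each negative dim to a placeholder name, skipping size-1 dims
    that have no existing name (mirroring the `continue` branch).
    """
    existing = existing or {}
    dim_to_name = OrderedDict()
    for dim in range(-len(batch_shape), 0):
        name = existing.get(dim)
        if batch_shape[dim] == 1 and name is None:
            continue  # no name request for an unnamed singleton dim
        dim_to_name[dim] = name if name is not None else f"dim{dim}"
    return dim_to_name


def cast_would_fail(batch_shape):
    """Assumption: with no named dims, the conversion expects shape ()."""
    return len(named_dims(batch_shape)) == 0 and tuple(batch_shape) != ()


for shape in [(2,), (1,), (2, 1, 3), (1, 2), (1, 1)]:
    print(shape, dict(named_dims(shape)), cast_would_fail(shape))
```

Under these assumptions, only shapes made up entirely of singleton dimensions end up with an empty mapping and a non-scalar log_prob.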

Here's a minimal example showing the difference in dim_to_name allocation depending on plate size:

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.funsor.infer_util import plate_to_enum_plate
from numpyro.contrib.funsor.enum_messenger import enum, trace as packed_trace

def working_model():
    with numpyro.plate('n', 2):
        numpyro.sample('a', dist.Normal())

def failing_model():
    with numpyro.plate('n', 1):
        numpyro.sample('a', dist.Normal())

with enum(first_available_dim=-2):
    with plate_to_enum_plate():
        good_trace = packed_trace(numpyro.handlers.seed(working_model, rng_seed=0)).get_trace()
        
with enum(first_available_dim=-2):
    with plate_to_enum_plate():
        bad_trace = packed_trace(numpyro.handlers.seed(failing_model, rng_seed=0)).get_trace()

print(good_trace['a']['infer']['dim_to_name'], good_trace['a']['infer']['name_to_dim'])
print(bad_trace['a']['infer']['dim_to_name'], bad_trace['a']['infer']['name_to_dim'])

A quick fix is to insert this code right before trying to cast to funsor:

if log_prob.shape == (1,) and dim_to_name == OrderedDict():
    log_prob = log_prob.squeeze()

@fehiepsi
Member

fehiepsi commented May 3, 2024

Thanks, @amifalk! Do you know what happens if the log_prob shape has non-trivial singleton dimensions, e.g., (2, 1, 3), (2, 1), (1, 2), (1, 1)?

@amifalk
Contributor

amifalk commented May 4, 2024

It only fails when every plate dimension for the variable has size 1, i.e., (2, 1, 3), (2, 1), and (1, 2) work, but (1, 1, ..., 1) fails. So I guess we just need to check that all(dim == 1 for dim in log_prob.shape).
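
A minimal sketch of that generalized check, working on plain shape tuples so it runs without JAX. The helper name `should_squeeze` is hypothetical, and this is not necessarily the form the eventual fix takes.

```python
from collections import OrderedDict


def should_squeeze(log_prob_shape, dim_to_name):
    """Hypothetical guard: True when no dim was named and the log-prob
    has at least one dim, all of size 1, e.g. (1,) or (1, 1)."""
    return (
        len(dim_to_name) == 0
        and len(log_prob_shape) > 0
        and all(size == 1 for size in log_prob_shape)
    )


# Shapes mentioned in the discussion, with an empty dim_to_name.
for shape in [(1,), (1, 1), (2, 1), (1, 2), (2, 1, 3)]:
    print(shape, should_squeeze(shape, OrderedDict()))
```

In _enum_log_density this would presumably be used as `if should_squeeze(log_prob.shape, dim_to_name): log_prob = log_prob.squeeze()` right before the conversion to a funsor.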

@fehiepsi
Member

fehiepsi commented May 4, 2024

Woohoo, that's good news. Thanks for digging into this, @amifalk! Could you send a PR?
