Skip to content

Commit

Permalink
Fix faulty interaction between jax.vmap and validate_args=True (p…
Browse files Browse the repository at this point in the history
…yro-ppl#1686)

* add initial bug reproducer

* disable arg validation during `tree_unflatten`
  • Loading branch information
pierreglaser authored and OlaRonning committed May 6, 2024
1 parent 8ddfeab commit 521a1e7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
8 changes: 6 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ def gather_pytree_data_fields(cls):
return all_pytree_data_fields

@classmethod
def gather_pytree_aux_fields(cls):
def gather_pytree_aux_fields(cls) -> tuple:
bases = inspect.getmro(cls)

all_pytree_aux_fields = ()
all_pytree_aux_fields = ("_validate_args",)
for base in bases:
if issubclass(base, Distribution):
all_pytree_aux_fields += base.__dict__.get("pytree_aux_fields", ())
Expand Down Expand Up @@ -203,11 +203,15 @@ def tree_unflatten(cls, aux_data, params):
for k, v in pytree_aux_fields_dict.items():
setattr(d, k, v)

# disable args validation during `tree_unflatten` it is called by jax with
# placeholder attributes that would make validation fail
d._validate_args = False
Distribution.__init__(
d,
pytree_aux_fields_dict["_batch_shape"],
pytree_aux_fields_dict["_event_shape"],
)
d._validate_args = pytree_aux_fields_dict["_validate_args"]
return d

@staticmethod
Expand Down
15 changes: 15 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3092,6 +3092,21 @@ def sample(d: dist.Distribution):
assert samples_batched_dist.shape == (1, *samples_dist.shape)


def test_vmap_validate_args():
# Test for #1684: vmapping distributions whould work when `validate_args=True`
v_dist = jax.vmap(
lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=True),
in_axes=(0, 0),
)(jnp.zeros((2,)), jnp.zeros((2,)))

# non-regression test
v_dist = jax.vmap(
lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=False),
in_axes=(0, 0),
)(jnp.zeros((2,)), jnp.zeros((2,)))
assert not v_dist._validate_args


def test_multinomial_abstract_total_count():
probs = jnp.array([0.2, 0.5, 0.3])
key = random.PRNGKey(0)
Expand Down

0 comments on commit 521a1e7

Please sign in to comment.