Make omnistaging work with validate_args=True #775
Changes from 1 commit
```diff
@@ -29,6 +29,8 @@
 from contextlib import contextmanager
 import warnings
 
+import numpy as np
+
 from jax import lax, tree_util
 import jax.numpy as jnp
 
```
```diff
@@ -138,9 +140,9 @@ def __init__(self, batch_shape=(), event_shape=(), validate_args=None):
                     continue
                 if is_dependent(constraint):
                     continue  # skip constraints that cannot be checked
-                is_valid = jnp.all(constraint(getattr(self, param)))
+                is_valid = constraint(getattr(self, param))
                 if not_jax_tracer(is_valid):
-                    if not is_valid:
+                    if not np.all(is_valid):
```
Review comment: For my understanding, if we were to always cast a Python bool to a numpy or jax bool type (i.e. the default mask would be a …)?

Reply: We need this change because under omnistaging:

```python
from jax import device_put, jit
import jax.numpy as jnp

from numpyro.util import not_jax_tracer


def f():
    assert not not_jax_tracer(jnp.all(True))
    assert not not_jax_tracer(device_put(True))


jit(f)()  # passes
```

You are right that if we cast …

Reply: Thanks for explaining about …

Reply: Yes, that's true.
```diff
                         raise ValueError("{} distribution got invalid {} parameter.".format(
                             self.__class__.__name__, param))
         super(Distribution, self).__init__()
```
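For context, here is a minimal standalone sketch (not code from the PR; `check_positive` and `f` are made-up names) of why the pattern above still validates concrete numpy parameters even while the caller is being traced by `jit`: the raw constraint output is reduced with `np.all`, which never returns a JAX tracer.

```python
import numpy as np
from jax import jit

from numpyro.util import not_jax_tracer


def check_positive(value):
    # Mirror the Distribution.__init__ pattern above: keep the raw constraint
    # result and reduce it with np.all, which stays a concrete numpy value.
    is_valid = value > 0
    if not_jax_tracer(is_valid):
        if not np.all(is_valid):
            raise ValueError("got invalid parameter")
    return value


def f(x):
    # The "scale" here is a concrete numpy array, so the check fires eagerly
    # at trace time even though f itself is jit-compiled.
    check_positive(-np.ones(2))
    return x


try:
    jit(f)(0.)
except ValueError as e:
    print(e)  # got invalid parameter
```

With the previous `is_valid = jnp.all(...)`, the reduction would have produced a tracer under omnistaging, `not_jax_tracer(is_valid)` would have been False, and the check would have been silently skipped.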
```diff
@@ -243,7 +245,7 @@ def variance(self):
     def _validate_sample(self, value):
         mask = self.support(value)
         if not_jax_tracer(mask):
-            if not jnp.all(mask):
+            if not np.all(mask):
                 warnings.warn('Out-of-support values provided to log prob method. '
                               'The value argument should be within the support.')
         return mask
```
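As a usage sketch (assuming, as the tests added below exercise, that `log_prob` runs `_validate_sample` when `validate_args=True`): a concrete out-of-support value now triggers the warning even during a `jit` trace, while a traced value yields a tracer mask and the eager warning is skipped.

```python
import jax

import numpyro.distributions as dist


d = dist.LogNormal(0., 1., validate_args=True)

# Concrete out-of-support value: the support mask is a plain Python bool,
# np.all(mask) is False, and the "Out-of-support values ..." UserWarning is
# emitted. The same holds when this call sits inside a jit-traced function,
# as long as the value itself is a concrete constant (see the tests below).
d.log_prob(-1.)

# Traced value: the mask is a tracer, not_jax_tracer(mask) is False, so no
# eager warning can be emitted; tracing still succeeds (the result is nan).
jax.jit(d.log_prob)(-1.)
```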
```diff
@@ -748,6 +748,22 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
         d.log_prob(oob_samples)
 
 
+def test_omnistaging_invalid_param():
```
Review comment: We have tests that check whether arg constraints are satisfied, but I suppose we do not have tests that check validation for out-of-support values in distributions? If so, it will be good to add tests for all relevant distributions and do a patch release. WDYT?

Reply: This test is fine, but my concern is whether we have missed something like this for other distributions.

Reply: Yeah, I think we already have some tests that check for both constraints and support. Let me also add …

Reply: Actually, you raised a very good point. It is just lucky that my test passed for …

(A hypothetical parametrized version of this kind of test is sketched after the hunk below.)
```diff
+    def f(x):
+        return dist.LogNormal(x, -np.ones(2), validate_args=True).log_prob(0)
+
+    with pytest.raises(ValueError, match="got invalid"):
+        jax.jit(f)(0)
+
+
+def test_omnistaging_invalid_sample():
+    def f(x):
+        return dist.LogNormal(x, np.ones(2), validate_args=True).log_prob(-1)
+
+    with pytest.warns(UserWarning, match="Out-of-support"):
+        jax.jit(f)(0)
+
+
 def test_categorical_log_prob_grad():
     data = jnp.repeat(jnp.arange(3), 10)
```
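Following up on the review discussion above about covering more distributions: a hypothetical sketch of what a parametrized version of these tests could look like. The distribution/parameter pairs below are illustrative assumptions, not cases taken from the PR.

```python
import jax
import pytest

import numpyro.distributions as dist


# Illustrative (distribution, invalid kwargs) pairs; each violates an
# arg_constraint such as a positive scale or rate.
INVALID_PARAMS = [
    (dist.Normal, {"loc": 0., "scale": -1.}),
    (dist.Exponential, {"rate": -1.}),
    (dist.Gamma, {"concentration": -1., "rate": 1.}),
]


@pytest.mark.parametrize("jax_dist, kwargs", INVALID_PARAMS)
def test_invalid_param_under_jit(jax_dist, kwargs):
    def f(x):
        # x is the traced argument; the invalid parameters stay concrete
        # Python floats, so validation can run eagerly during tracing.
        return jax_dist(validate_args=True, **kwargs).log_prob(x)

    with pytest.raises(ValueError, match="got invalid"):
        jax.jit(f)(1.)
```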
Review comment: Same as my comment below - is this `np` needed because `self.total_count` could be a Python scalar? In that case, does it make sense to use `jnp.array` or `device_put` everywhere? That way, we can freely use `jnp` functions instead of a mix of `np` and `jnp`, which becomes hard to reason about.
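For reference, the earlier thread's experiment speaks to this question; the sketch below (standalone, not code from the PR) restates it: when a distribution is constructed inside a `jit`-traced function, values going through `device_put` or `jnp` ops become tracers under omnistaging, so an eager Python-level check on them is skipped, while plain `np` reductions stay concrete.

```python
import numpy as np
from jax import device_put, jit
import jax.numpy as jnp

from numpyro.util import not_jax_tracer


def g():
    # Inside a jit trace under omnistaging, jnp ops and device_put on
    # constants are staged out, so their results are tracers ...
    assert not not_jax_tracer(jnp.all(np.ones(2) > 0))
    assert not not_jax_tracer(device_put(True))
    # ... while a pure-numpy reduction stays concrete and can still gate an
    # eager raise/warn at trace time.
    assert not_jax_tracer(np.all(np.ones(2) > 0))


jit(g)()  # traces and runs without assertion errors
```

So casting parameters up front with `jnp.array` or `device_put` would unify the types, but for distributions constructed inside `jit` it would also turn the validation inputs into tracers and skip the eager check that this PR restores, which is presumably why `np` is used for the reductions here.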