-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make omnistaging work with validate_args=True #775
Changes from all commits
db72bc1
57a76bc
afa505b
ac265df
7cffa27
23c1f6f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,8 @@ | |
from contextlib import contextmanager | ||
import warnings | ||
|
||
import numpy as np | ||
|
||
from jax import lax, tree_util | ||
import jax.numpy as jnp | ||
|
||
|
@@ -150,9 +152,9 @@ def __init__(self, batch_shape=(), event_shape=(), validate_args=None): | |
continue | ||
if is_dependent(constraint): | ||
continue # skip constraints that cannot be checked | ||
is_valid = jnp.all(constraint(getattr(self, param))) | ||
is_valid = constraint(getattr(self, param)) | ||
if not_jax_tracer(is_valid): | ||
if not is_valid: | ||
if not np.all(is_valid): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For my understanding, if we were to always cast a python bool to a numpy or jax bool type (i.e. the default mask would be a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We we need this change because under from jax import device_put, jit
import jax.numpy as jnp
from numpyro.util import not_jax_tracer
def f():
assert not not_jax_tracer(jnp.all(True))
assert not not_jax_tracer(device_put(True))
jit(f)() # pass You are right that if we cast There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for explaining about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that's true. |
||
raise ValueError("{} distribution got invalid {} parameter.".format( | ||
self.__class__.__name__, param)) | ||
super(Distribution, self).__init__() | ||
|
@@ -255,7 +257,7 @@ def variance(self): | |
def _validate_sample(self, value): | ||
mask = self.support(value) | ||
if not_jax_tracer(mask): | ||
if not jnp.all(mask): | ||
if not np.all(mask): | ||
warnings.warn('Out-of-support values provided to log prob method. ' | ||
'The value argument should be within the support.') | ||
return mask | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as my comment below - is this
np
needed becauseself.total_count
could be a python scalar. In that case does it make sense to usejnp.array
ordevice_put
everywhere? That way, we can freely usejnp
functions instead of a mix ofnp
andjnp
which becomes hard to reason about.