Make omnistaging work with validate_args=True #775
@@ -748,6 +748,22 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
d.log_prob(oob_samples)


def test_omnistaging_invalid_param():
We have tests that check whether arg constraints are satisfied, but I suppose we do not have tests that check validation for out-of-support values in distributions? If so, it would be good to add tests for all relevant distributions and do a patch release. WDYT?
This test is fine, but my concern is whether we have missed something like this for other distributions.
Yeah, I think we already have some tests that check for both constraints and support. Let me also add a jit check for them.
Actually, you raised a very good point. I was just lucky that my test passed for the positive constraint. For many other constraints, no warning is raised! :( The issue is that if the constraint uses jax.numpy to check, the output will be a traced array (under jit), regardless of whether the input is a DeviceArray or not.
@neerajprad Omnistaging is a bit annoying for validation code. Currently, I have found no simple way to have some "check" under
@fehiepsi can you explain more about the difficulties and the need for a dispatch mechanism? I'm just curious because I have been thinking about what additional metadata might be useful to add to
EDIT: do you mean lines like the following?
jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
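For readers following the dispatch idea quoted above, here is a minimal sketch of how such a line behaves under jit. The helper name `all_positive` and the `report` dictionary are illustrative assumptions, not part of NumPyro; the point is only that a concrete NumPy input stays concrete when routed through `np`, while a traced input yields a tracer.

```python
import numpy as np
import jax
import jax.numpy as jnp

def all_positive(x):
    # Hypothetical helper: pick NumPy for concrete NumPy inputs, jax.numpy otherwise.
    xp = np if isinstance(x, (np.ndarray, np.generic)) else jnp
    return xp.all(x > 0)

report = {}

@jax.jit
def probe(traced):
    concrete_result = all_positive(np.array([1.0, -2.0]))  # evaluated eagerly by NumPy
    traced_result = all_positive(traced)                   # staged out by jit
    # Record only plain Python bools so that no tracer escapes the jit scope.
    report["concrete_is_tracer"] = isinstance(concrete_result, jax.core.Tracer)
    report["concrete_value"] = bool(concrete_result)
    report["traced_is_tracer"] = isinstance(traced_result, jax.core.Tracer)
    return traced

probe(jnp.ones(2))
```

With this dispatch, the concrete check can be turned into a Python bool at trace time, which is what validation code needs in order to raise or warn.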
Yes, I meant so. With the new behavior of JAX, outside
Assume In this PR, I used If the code is run outside of Re what additional metadata might be useful: Currently, we have
In addition, I think having a method
if not_jax_tracer(is_valid):
if not is_valid:
if not np.all(is_valid):
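A rough sketch of the guard pattern in this hunk, written without NumPyro so it is self-contained. `check_positive` is a hypothetical validator, and the `isinstance(..., jax.core.Tracer)` test is assumed to play the role of `not_jax_tracer`. Under these assumptions, a concrete parameter is validated at trace time, while a traced one is skipped silently.

```python
import numpy as np
import jax
import jax.numpy as jnp

def check_positive(value):
    # Hypothetical validator, not NumPyro's implementation.
    is_valid = value > 0
    # Only inspect the result when it is concrete; a tracer has no truth value.
    if not isinstance(is_valid, jax.core.Tracer):
        if not np.all(is_valid):
            raise ValueError("value must be positive")

@jax.jit
def with_concrete_param(x):
    check_positive(np.array([1.0, -1.0]))  # concrete NumPy array: checked while tracing
    return x

@jax.jit
def with_traced_param(x):
    check_positive(x)  # x is a tracer here, so the check is skipped
    return x
```

Calling `with_concrete_param` raises `ValueError` during tracing, whereas `with_traced_param(jnp.array([-1.0]))` returns normally even though the value is invalid.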
For my understanding, if we were to always cast a python bool to a numpy or jax bool type (i.e. the default mask would be a jnp.array(True)), we would not have needed any change here. Is that correct?
We need this change because under jit, jnp.all(True) and device_put(True) will create a tracer, so the validation code won't activate.
from jax import device_put, jit
import jax.numpy as jnp
from numpyro.util import not_jax_tracer

def f():
    assert not not_jax_tracer(jnp.all(True))
    assert not not_jax_tracer(device_put(True))

jit(f)()  # pass
You are right that if we cast mask in the _validate_sample method to a device array, then the error will be fixed. In that case, no validation of parameters/samples is activated under jit. That is fine for single-chain MCMC because we have some code to trace the model outside of the jit context, so if there is some wrong specification, the users will get the warning/error. But for multi-chain MCMC, no validation will be activated because we run everything under pmap.
Thanks for explaining about pmap, I missed that. In that case, is it correct to say that for distributions that have all parameters as jax devices (no python scalars), there won't be any validation performed under pmap because not_jax_tracer(param) will be true?
Yes, that's true.
total_count = jnp.amax(self.total_count)
if not_jax_tracer(total_count):
if not_jax_tracer(self.total_count):
total_count = np.amax(self.total_count)
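To illustrate why this hunk reaches for `np.amax` once the input is known to be concrete, here is a small sketch under stated assumptions. `support_values` is a hypothetical helper, not the NumPyro method being changed; it only shows that a NumPy reduction on a concrete array remains usable as a Python integer inside jit, for example as a static length.

```python
import numpy as np
import jax
import jax.numpy as jnp

def support_values(total_count):
    # Hypothetical helper: needs a concrete total_count to build a static-length range.
    if isinstance(total_count, jax.core.Tracer):
        raise NotImplementedError("total_count must be concrete here")
    # np.amax runs eagerly, so the result can be converted to a Python int even under jit.
    n = int(np.amax(total_count))
    return jnp.arange(n + 1)

counts = np.array([3, 3])  # concrete NumPy array captured by the closure below
values = jax.jit(lambda: support_values(counts))()
```

Had `jnp.amax` been used on the same concrete array inside jit, the result would be a tracer and the `int(...)` conversion would fail.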
Same as my comment below: is this np needed because self.total_count could be a python scalar? In that case, does it make sense to use jnp.array or device_put everywhere? That way, we can freely use jnp functions instead of a mix of np and jnp, which becomes hard to reason about.
If we can avoid this by using
Thanks for explaining your changes, @fehiepsi.
Fixes #771. The issue is that with omnistaging, under jit, operations such as jnp.all(concrete_array) return a traced array, which cannot be converted to a Python bool object. I think "omnistaging" is a bit annoying, but given that it is the default behavior in JAX 0.2, we have to live with it. :(
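The behavior described here can be reproduced with a short sketch that needs only JAX and NumPy. The names `concrete`, `seen`, and `probe` are illustrative assumptions; the snippet records what happens at trace time when `jnp.all` and `np.all` are applied to the same concrete array inside a jitted function.

```python
import numpy as np
import jax
import jax.numpy as jnp

concrete = np.array([1.0, 2.0, 3.0])  # a plain NumPy array, not a traced value
seen = {}

def probe(x):
    jax_result = jnp.all(concrete > 0)  # staged out under jit, so it becomes a tracer
    np_result = np.all(concrete > 0)    # evaluated eagerly by NumPy
    seen["jax_is_tracer"] = isinstance(jax_result, jax.core.Tracer)
    seen["np_is_tracer"] = isinstance(np_result, jax.core.Tracer)
    try:
        bool(jax_result)  # a tracer cannot be turned into a Python bool
        seen["bool_failed"] = False
    except TypeError:
        # JAX's concretization errors derive from TypeError.
        seen["bool_failed"] = True
    return x

jax.jit(probe)(0.0)
```

After the call, `seen` shows that only the `jnp` result is a tracer and that converting it to `bool` fails, which is the failure mode `validate_args=True` ran into.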