Make omnistaging work with validate_args=True #775

Merged on Oct 14, 2020 (6 commits)

fehiepsi
Member

@fehiepsi fehiepsi commented Oct 6, 2020

Fixes #771. The issue is that with omnistaging, under jit, operators such as jnp.all(concrete_array) return a traced array, which cannot be converted to a Python bool object. I think "omnistaging" is a bit annoying, but given that it is the default behavior in JAX 0.2, we have to live with it. :(
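
The behavior described above can be reproduced with a minimal sketch. It assumes a recent JAX release in which jax.core.Tracer and jax.errors.ConcretizationTypeError are available; the names concrete, seen, and f are illustrative and do not come from this PR.

```python
import numpy as np
import jax
import jax.numpy as jnp

concrete = np.array([True, True, False])  # plain NumPy array, not a jit argument
seen = {}


@jax.jit
def f(x):
    # Under jit, jax.numpy operations are staged out even for concrete inputs.
    reduced = jnp.all(concrete)
    seen["is_tracer"] = isinstance(reduced, jax.core.Tracer)
    try:
        bool(reduced)  # a tracer has no concrete truth value
        seen["bool_failed"] = False
    except jax.errors.ConcretizationTypeError:
        seen["bool_failed"] = True
    # The NumPy reduction stays concrete while tracing.
    seen["numpy_result"] = bool(np.all(concrete))
    return x


f(1.0)
```

After the call, seen records what happened at trace time: the jax.numpy result was a tracer, converting it to bool failed, and the NumPy result was an ordinary Python bool.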

@@ -748,6 +748,22 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
d.log_prob(oob_samples)


def test_omnistaging_invalid_param():
Member

We have tests that check whether arg constraints are satisfied, but I suppose we do not have tests that check validation for out of support values in distributions? If so, it will be good to add tests for all relevant distributions and do a patch release. WDYT?

Member

This test is fine, but my concern is if we have missed something like this for other distributions.

Member Author

Yeah, I think we already have some tests that check for both constraints and support. Let me also add jit check for them.

Member Author

Actually, you raised a very good point. It was just luck that my test passed for the positive constraint. For many other constraints, no warning is raised! :( The issue is that if the constraint uses jax.numpy for its check, the output will be a traced array (under jit), regardless of whether the input is a DeviceArray or not.

@fehiepsi
Member Author

@neerajprad Omnistaging is a bit annoying for validation code. Currently, I found no simple way to have some "check" under jit without adding a dispatching mechanism to some of our constraint implementations. For arg validation, there are a few places that we use jnp.broadcast_to, which will create Tracer output. To fix it, I tried to use promote_shapes for all of them. :(

@fritzo
Member

fritzo commented Oct 12, 2020

@fehiepsi can you explain more about the difficulties and the need for a dispatch mechanism? I'm just curious because I have been thinking about what additional metadata might be useful to add to Constraint objects, e.g. maybe an .event_dim attribute or richer shape attributes like (.batch_shape, .event_shape).

EDIT: do you mean lines like the following?

jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy

@fehiepsi
Member Author

fehiepsi commented Oct 12, 2020

Yes, that is what I meant. With the new behavior of JAX, outside jit we have numpy.ndarray and DeviceArray, while under jit they become numpy.ndarray and ShapedArray (i.e. tracer). For validation, we currently use the following code:

mask = support.check(value)
if not_tracer(mask):
    if not jax.numpy.all(mask):
        raise ...

Assume mask is a numpy ndarray (or a Python True/False), so not_tracer(mask) will be True. Then we need to check whether all values in mask are True. If we use jax.numpy.all, jax.numpy.all(mask) will be a tracer under jit. So evaluating "if not jax.numpy.all(mask)" will raise a JAX error because we cannot get the actual value of jax.numpy.all(mask) there.

In this PR, I used numpy.all(mask) instead. In addition, to be able to raise the validation warning/error, it is required that mask is not a tracer (i.e. an abstract shaped array). If we use jax.numpy operators in the support.check implementation, the output will always be a tracer, so I added the above dispatching mechanism. With that, if value is a numpy ndarray, we will get a numpy ndarray mask. Of course, under jit the validation code still does not work if value is a device array. However, I think it is still useful for detecting issues such as using a 0 value for a scale parameter or data that contains nan.
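
As an illustration of the dispatching idea, here is a minimal sketch of a constraint check that picks its array library from the input type. The function check_simplex and its tolerance are hypothetical and are not NumPyro's implementation; the sketch also assumes jax.core.Tracer is available for the isinstance checks.

```python
import numpy as np
import jax
import jax.numpy as jnp


def check_simplex(x):
    # Hypothetical constraint check (not NumPyro's code): choose the array
    # library from the input type so NumPy inputs yield NumPy outputs.
    xp = np if isinstance(x, (np.ndarray, np.generic)) else jnp
    return xp.all(x >= 0, axis=-1) & (xp.abs(xp.sum(x, axis=-1) - 1) < 1e-6)


probs = np.array([0.2, 0.3, 0.5])  # concrete NumPy value, closed over by f
kinds = {}


@jax.jit
def f(x):
    # Record, at trace time, what kind of mask each input produces.
    kinds["numpy_input_is_tracer"] = isinstance(check_simplex(probs), jax.core.Tracer)
    kinds["traced_input_is_tracer"] = isinstance(check_simplex(x), jax.core.Tracer)
    return x


f(jnp.array([0.2, 0.3, 0.5]))
numpy_mask = check_simplex(probs)
```

With a NumPy input the mask stays a concrete NumPy boolean even while tracing, so it can be inspected; with a traced input the mask is a tracer and cannot.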

If the code is run outside of a jit context, then the validation will still work because numpy.all(...) can be applied to a DeviceArray.
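
The validation pattern discussed here can be sketched as follows. This is an assumption-laden illustration, not the code merged in this PR: not_tracer is a hypothetical stand-in for a helper like numpyro.util.not_jax_tracer, and validate_sample and its error message are invented for the example.

```python
import numpy as np
import jax
import jax.numpy as jnp


def not_tracer(x):
    # Hypothetical stand-in for a helper such as numpyro.util.not_jax_tracer.
    return not isinstance(x, jax.core.Tracer)


def validate_sample(mask):
    # Returns True if validation ran, False if it was skipped for a tracer.
    if not_tracer(mask):
        if not np.all(mask):  # np.all yields a concrete boolean
            raise ValueError("value is outside the support")
        return True
    return False


# Outside jit: a JAX array mask is concrete, so np.all can inspect it.
try:
    validate_sample(jnp.array([True, False]))
    raised_outside_jit = False
except ValueError:
    raised_outside_jit = True

outcome = {}


@jax.jit
def g(x):
    # x is a tracer here, so the mask is a tracer and validation is skipped.
    outcome["ran_for_traced_mask"] = validate_sample(x > 0)
    # A NumPy mask stays concrete even while tracing.
    outcome["ran_for_numpy_mask"] = validate_sample(np.array([True, True]))
    return x


g(jnp.array([-1.0]))
```

The sketch shows the trade-off from the discussion: validation fires for concrete masks, inside or outside jit, and is silently skipped whenever the mask is a tracer.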

Re what additional metadata might be useful: currently, we have event_dim, but I think having something like TFP's forward_event_shape method would be useful to:

  • determine changes in event_dim of domain and codomain (e.g. LKJ)
  • determine changes in shapes of domain and codomain (e.g. Simplex)

Currently, we can't give a correct event_shape for a TransformedDistribution because such a method is missing.
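
To make the forward_event_shape idea concrete, here is a small pure-Python sketch. Every class and function in it is hypothetical and only models shape bookkeeping; it is not an existing NumPyro or TFP API, and the stick-breaking and Cholesky shape rules are stated as assumptions for the two examples mentioned (Simplex and LKJ).

```python
import math


class StickBreakingShape:
    """Hypothetical shape metadata for a map from R^n to the (n+1)-simplex."""

    def forward_event_shape(self, shape):
        return shape[:-1] + (shape[-1] + 1,)


class CorrCholeskyShape:
    """Hypothetical shape metadata for a vector -> Cholesky factor map."""

    def forward_event_shape(self, shape):
        m = shape[-1]  # number of free entries, m = n * (n - 1) / 2
        n = (1 + math.isqrt(1 + 8 * m)) // 2
        if n * (n - 1) // 2 != m:
            raise ValueError(f"{m} is not a valid number of free entries")
        return shape[:-1] + (n, n)


def transformed_event_shape(base_event_shape, transforms):
    # Push the base event shape through each transform in order.
    shape = tuple(base_event_shape)
    for t in transforms:
        shape = t.forward_event_shape(shape)
    return shape


simplex_shape = transformed_event_shape((3,), [StickBreakingShape()])
cholesky_shape = transformed_event_shape((3,), [CorrCholeskyShape()])
```

In the first case the event shape changes size while keeping one event dimension; in the second the event dimension itself grows from one to two, which a bare event_dim attribute on the constraint cannot express on its own.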

In addition, I think having a forward_domain method would also be useful, in case we want to use AffineTransform with a positive domain. Currently, users need to define a correct support in e.g. Weibull because we don't have such logic for the composition of PowerTransform and AffineTransform.

  if not_jax_tracer(is_valid):
-     if not is_valid:
+     if not np.all(is_valid):
Member

For my understanding, if we were to always cast a python bool to a numpy or jax bool type (i.e. the default mask would be a jnp.array(True)), we would not have needed any change here. Is that correct?

Member Author

@fehiepsi fehiepsi Oct 13, 2020

We need this change because, under jit, jnp.all(True) and device_put(True) will create a tracer, so the validation code won't activate.

from jax import device_put, jit
import jax.numpy as jnp
from numpyro.util import not_jax_tracer

def f():
    assert not not_jax_tracer(jnp.all(True))
    assert not not_jax_tracer(device_put(True))

jit(f)()  # pass

You are right that if we cast mask in the _validate_sample method to a device array, then the error will be fixed. In that case, however, no validation for parameters/samples is activated under jit. That is fine for single-chain MCMC because we have some code that traces the model outside of the jit context, so if there is a wrong specification, the users will get the warning/error. But for multi-chain MCMC, no validation will be activated because we run everything under pmap.

Member

Thanks for explaining about pmap, I missed that. In that case, is it correct to say that for distributions that have all parameters as JAX device arrays (no Python scalars), there won't be any validation performed under pmap because not_jax_tracer(param) will be false?

Member Author

Yes, that's true.

- total_count = jnp.amax(self.total_count)
- if not_jax_tracer(total_count):
+ if not_jax_tracer(self.total_count):
+     total_count = np.amax(self.total_count)
Member

Same as my comment below: is this np needed because self.total_count could be a Python scalar? In that case, does it make sense to use jnp.array or device_put everywhere? That way, we can freely use jnp functions instead of a mix of np and jnp, which becomes hard to reason about.

@neerajprad
Member

> For arg validation, there are a few places that we use jnp.broadcast_to, which will create Tracer output. To fix it, I tried to use promote_shapes for all of them. :(

If we can avoid this by using device_put (basically not having any python scalar types in distributions and masking etc.), then that would seem preferable and will help avoid future regressions. That said, my understanding of the actual issues could be quite off, so correct me if I am mistaken.

@neerajprad
Member

Thanks for explaining your changes, @fehiepsi.

@neerajprad neerajprad merged commit 494bcbb into pyro-ppl:master Oct 14, 2020
Successfully merging this pull request may close these issues.

Omnistaging does not work when enabling distributions' validation
3 participants