Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jittable transforms #1575

Merged
merged 15 commits into from
May 31, 2023
Merged

Conversation

pierreglaser
Copy link
Contributor

@pierreglaser pierreglaser commented Apr 11, 2023

Standalone PR that will vastly ease the flattening/unflattening logic of TransformedDistributions that will eventually get implemented in #1529 or one of its follow-ups.

cc @fehiepsi.

@pierreglaser pierreglaser changed the title [WIP] jittable transforms Jittable transforms Apr 16, 2023
@fehiepsi fehiepsi self-requested a review April 20, 2023 15:56
fehiepsi
fehiepsi previously approved these changes May 18, 2023
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks so much @pierreglaser! I just have a small nit. Super nice PR!!!

numpyro/distributions/constraints.py Outdated Show resolved Hide resolved
numpyro/distributions/transforms.py Show resolved Hide resolved
@pierreglaser
Copy link
Contributor Author

Thanks for the review @fehiepsi! I'm currently in the process of extending #1529 to all distributions, I'm around 20% there :-)

@fehiepsi
Copy link
Member

Supporting all distributions is complicated I would say. But with jittable support/transform, things might be easier, hopefully. Thanks for working on it!!

@fehiepsi
Copy link
Member

Hi @pierreglaser, I would like to make a release this week. Do you want to have this PR in?

@pierreglaser
Copy link
Contributor Author

Yes, let me finish the PR today and we can merge this.

@pierreglaser
Copy link
Contributor Author

pierreglaser commented May 25, 2023

Ok @fehiepsi, I addressed #1575 (comment) and added a few more tests to make sure that one could vmap over constraints and transforms. As a side effect, I modified a bit the __eq__ methods of these objects to make them robust to inputs containing tracers. I think this PR is good to go!

@pierreglaser
Copy link
Contributor Author

pierreglaser commented May 25, 2023

I think the linting failure is not coming from this PR....

EDIT: it did.


class _Positive(_GreaterThan, _SingletonConstraint):
def __eq__(self, other):
return isinstance(other, _GreaterThan) & jnp.array_equal(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmh, I think the use of the "bitwise and" operator as a jax-compatible and (that does not try to coerce an abstract boolean to a concrete value) is causing __eq__ to attempt accessing attribute like lower_bound even when other is not of type _GreaterThan. For some reason this happens during the tests, but it could have happened otherwise, so good the tests caught it!

@pierreglaser
Copy link
Contributor Author

OK @fehiepsi, I pushed a fix to the doc errors, that actually revealed some lack of robustness in the updated equality checks for constraints and transform in this PR. I added tests so hopefully that won't be a problem in the future :-)

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woohoo, thanks for the fix!!

@fehiepsi fehiepsi merged commit eab63ed into pyro-ppl:master May 31, 2023
OlaRonning pushed a commit to aleatory-science/numpyro that referenced this pull request Jun 2, 2023
* [WIP] jittable transforms

* add licence to new test file

* turn BijectorConstraint into pytree

* test flattening/unflattening of parametrized constraints

* cosmetic edits

* fix typo

* implement tree_flatten/unflatten for transforms

* attempt to avoid confusing black

* add (un)flattening meths for BijectorTransform

* fixup! implement tree_flatten/unflatten for transforms

* test vmapping over transforms/constraints

* Make constraints `__eq__` checks robust to arbitrary inputs

* make transforms equality check robust to arbitrary inputs

* test constraints and transforms equality checks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants