Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify the provenance logic to prepare for the removal of jax named shape #1837

Merged
merged 1 commit into from
Jul 29, 2024

Conversation

fehiepsi
Copy link
Member

@fehiepsi fehiepsi commented Jul 25, 2024

See jax-ml/jax#21069 as a context. The current tests fail with jax's dev branch.

Current tests at test/test_model_rendering and test/ops/test_provenance pass locally.

Copy link
Member

@ordabayevy ordabayevy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Just to clarify for myself: These code changes don't use any new functionality, right? I just wonder why did we use named_shape in the first place.

@fehiepsi
Copy link
Member Author

Yes, this just simplifies the logic. Previously we used named shape because we need some functionality to flatten unflatten the provenance. Turns out that those jax tree utilities also work for frozenset leaves.

@fehiepsi fehiepsi merged commit 3e41320 into pyro-ppl:master Jul 29, 2024
4 checks passed
@fehiepsi fehiepsi mentioned this pull request Jul 30, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants