Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix provenance logic for jax 0.2.28 #1320

Merged
merged 1 commit into from
Feb 4, 2022
Merged

Conversation

fehiepsi
Copy link
Member

@fehiepsi fehiepsi commented Feb 4, 2022

Fixes #1318

tldr; I have tested the change and it works for the old jax versions.

The recent jax release made provenance logic failing because the cached abstract value might be a different object w.r.t. the one stored in the output of process_primitive, while we only updated provenance for the output previously.

I think we'll need to make a patch release after this fix.

Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, that's subtle!

Do you think it's worth adding a regression test, e.g. the model in #1318?

@fehiepsi
Copy link
Member Author

fehiepsi commented Feb 4, 2022

Hi Fritz, all the tests involving provenance failed so we have to pin the jax version to make CI pass. In this pr, I just removed the restriction because the tests pass again.

@fehiepsi
Copy link
Member Author

fehiepsi commented Feb 4, 2022

Thanks for reviewing! :)

@fehiepsi fehiepsi merged commit 9719466 into pyro-ppl:master Feb 4, 2022
@fehiepsi fehiepsi mentioned this pull request Feb 27, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

numpyro.render_model behaviour changed in 0.9.0
2 participants