Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add format_shapes utility #1116

Merged
merged 8 commits into from
Aug 16, 2021
Merged

Add format_shapes utility #1116

merged 8 commits into from
Aug 16, 2021

Conversation

tcbegley
Copy link
Contributor

@tcbegley tcbegley commented Aug 3, 2021

This is a proposed solution to #1104. It's a pretty direct port of .format_shapes from Pyro.

I've created a gist with examples from the Pyro docs.

Would love some feedback on a few things in particular:

  • Function name + location. I just put it in numpyro.util for now.
  • The main difference between this and Pyro that I'm aware of is that in Pyro plates are traced as "sample" sites, which means they show up in the result of format_shapes. I can add some logic to add "plate" sites from NumPyro traces to the result. Should they be mixed in with the sample sites like in Pyro or should they get their own section?
  • Added a log_prob argument to mimic the effect of running trace.compute_log_prob() before trace.format_shapes() in Pyro. Is this a reasonable solution?
  • Couldn't see any tests in the Pyro repo for this to replicate. Is it worth writing tests for?

TODO:

  • Update docstring

numpyro/util.py Outdated Show resolved Hide resolved
@fehiepsi
Copy link
Member

fehiepsi commented Aug 6, 2021

Is it worth writing tests for?

I think you can compare the output with expected strings

Should they be mixed in with the sample sites like in Pyro

I think this is the way to go and I like all other resolutions. :)

@tcbegley tcbegley marked this pull request as ready for review August 8, 2021 08:39
fehiepsi
fehiepsi previously approved these changes Aug 12, 2021
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great to me! Thank you, this utility is really helpful!

test/test_util.py Outdated Show resolved Hide resolved
@tcbegley tcbegley changed the title Add format_trace_shapes utility Add format_shapes utility Aug 12, 2021
test/test_util.py Outdated Show resolved Hide resolved
numpyro/util.py Outdated Show resolved Hide resolved
numpyro/util.py Outdated Show resolved Hide resolved
numpyro/util.py Outdated Show resolved Hide resolved
@fehiepsi fehiepsi merged commit 935d4f5 into pyro-ppl:master Aug 16, 2021
@tcbegley tcbegley deleted the format-trace branch August 16, 2021 12:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants