Skip to content

Commit

Permalink
Allow compute_log_prob to be callable, add to docs
Browse files Browse the repository at this point in the history
  • Loading branch information
tcbegley committed Aug 16, 2021
1 parent 4ae0efc commit 9f59232
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 9 deletions.
4 changes: 4 additions & 0 deletions docs/source/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,7 @@ Visualization Utilities
render_model
------------
.. autofunction:: numpyro.contrib.render.render_model

Format trace shapes
-------------------
.. autofunction:: numpyro.util.format_shapes
38 changes: 30 additions & 8 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,17 +410,37 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
return tree_map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]), ys)


def format_shapes(trace, *, title="Trace Shapes:", last_site=None, log_prob=False):
def format_shapes(
trace,
*,
compute_log_prob=False,
title="Trace Shapes:",
last_site=None,
):
"""
Returns a string showing a table of the shapes of all sites in the
trace.
:param dict trace: The model trace to format. Use ``numpyro.handlers.trace`` to
produce the trace.
Given the trace of a function, returns a string showing a table of the shapes of
all sites in the trace.
Use :class:`~numpyro.handlers.trace` handler (or funsor
:class:`~numpyro.contrib.funsor.enum_messenger.trace` handler for enumeration) to
produce the trace.
:param dict trace: The model trace to format.
:param compute_log_prob: Compute log probabilities and display the shapes in the
table. Accepts True / False or a function which when given a dictionary
containing site-level metadata returns whether the log probability should be
calculated and included in the table.
:param str title: Title for the table of shapes.
:param str last_site: Name of a site in the model. If supplied, subsequent sites
are not displayed in the table.
:param bool log_prob: Display shapes of log probabilities for each sample site.
Usage::
def model(*args, **kwargs):
...
trace = numpyro.handlers.trace(model).get_trace(*args, **kwargs)
numpyro.util.format_shapes(trace)
"""
if not trace.keys():
return title
Expand Down Expand Up @@ -462,7 +482,9 @@ def format_shapes(trace, *, title="Trace Shapes:", last_site=None, log_prob=Fals
)

# log_prob shape
if log_prob:
if (not callable(compute_log_prob) and compute_log_prob) or (
callable(compute_log_prob) and compute_log_prob(site)
):
batch_shape = getattr(site["fn"].log_prob(site["value"]), "shape", ())
rows.append(
["log_prob", None]
Expand Down
15 changes: 14 additions & 1 deletion test/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def model_test():
" value 10 | "
)
assert (
format_shapes(t, log_prob=True) == "Trace Shapes: \n"
format_shapes(t, compute_log_prob=True) == "Trace Shapes: \n"
" Param Sites: \n"
" mean 100 \n"
"Sample Sites: \n"
Expand All @@ -144,6 +144,19 @@ def model_test():
" value 10 | \n"
" log_prob 10 | "
)
assert (
format_shapes(t, compute_log_prob=lambda site: site["name"] == "scale")
== "Trace Shapes: \n"
" Param Sites: \n"
" mean 100 \n"
"Sample Sites: \n"
" scale dist | 3\n"
" value | 3\n"
" log_prob | \n"
" data plate 10 | \n"
" x dist 10 | \n"
" value 10 | "
)
assert (
format_shapes(t, last_site="data") == "Trace Shapes: \n"
" Param Sites: \n"
Expand Down

0 comments on commit 9f59232

Please sign in to comment.