diff --git a/docs/source/utilities.rst b/docs/source/utilities.rst index b40a04fc2..9a531c8b3 100644 --- a/docs/source/utilities.rst +++ b/docs/source/utilities.rst @@ -104,3 +104,7 @@ Visualization Utilities render_model ------------ .. autofunction:: numpyro.contrib.render.render_model + +Trace inspection +---------------- +.. autofunction:: numpyro.util.format_shapes diff --git a/numpyro/util.py b/numpyro/util.py index f7ad4059d..f26933b93 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -408,3 +408,143 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None): lambda y: jnp.reshape(y, (-1,) + jnp.shape(y)[map_ndims:])[:batch_size], ys ) return tree_map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]), ys) + + +def format_shapes( + trace, + *, + compute_log_prob=False, + title="Trace Shapes:", + last_site=None, +): + """ + Given the trace of a function, returns a string showing a table of the shapes of + all sites in the trace. + + Use :class:`~numpyro.handlers.trace` handler (or funsor + :class:`~numpyro.contrib.funsor.enum_messenger.trace` handler for enumeration) to + produce the trace. + + :param dict trace: The model trace to format. + :param compute_log_prob: Compute log probabilities and display the shapes in the + table. Accepts True / False or a function which when given a dictionary + containing site-level metadata returns whether the log probability should be + calculated and included in the table. + :param str title: Title for the table of shapes. + :param str last_site: Name of a site in the model. If supplied, subsequent sites + are not displayed in the table. + + Usage:: + + def model(*args, **kwargs): + ... + + trace = numpyro.handlers.trace(model).get_trace(*args, **kwargs) + numpyro.util.format_shapes(trace) + """ + if not trace.keys(): + return title + rows = [[title]] + + rows.append(["Param Sites:"]) + for name, site in trace.items(): + if site["type"] == "param": + rows.append( + [name, None] + + [str(size) for size in getattr(site["value"], "shape", ())] + ) + if name == last_site: + break + + rows.append(["Sample Sites:"]) + for name, site in trace.items(): + if site["type"] == "sample": + # param shape + batch_shape = getattr(site["fn"], "batch_shape", ()) + event_shape = getattr(site["fn"], "event_shape", ()) + rows.append( + [f"{name} dist", None] + + [str(size) for size in batch_shape] + + ["|", None] + + [str(size) for size in event_shape] + ) + + # value shape + event_dim = len(event_shape) + shape = getattr(site["value"], "shape", ()) + batch_shape = shape[: len(shape) - event_dim] + event_shape = shape[len(shape) - event_dim :] + rows.append( + ["value", None] + + [str(size) for size in batch_shape] + + ["|", None] + + [str(size) for size in event_shape] + ) + + # log_prob shape + if (not callable(compute_log_prob) and compute_log_prob) or ( + callable(compute_log_prob) and compute_log_prob(site) + ): + batch_shape = getattr(site["fn"].log_prob(site["value"]), "shape", ()) + rows.append( + ["log_prob", None] + + [str(size) for size in batch_shape] + + ["|", None] + ) + elif site["type"] == "plate": + shape = getattr(site["value"], "shape", ()) + rows.append( + [f"{name} plate", None] + [str(size) for size in shape] + ["|", None] + ) + + if name == last_site: + break + + return _format_table(rows) + + +def _format_table(rows): + """ + Formats a right justified table using None as column separator. + """ + # compute column widths + column_widths = [0, 0, 0] + for row in rows: + widths = [0, 0, 0] + j = 0 + for cell in row: + if cell is None: + j += 1 + else: + widths[j] += 1 + for j in range(3): + column_widths[j] = max(column_widths[j], widths[j]) + + # justify columns + for i, row in enumerate(rows): + cols = [[], [], []] + j = 0 + for cell in row: + if cell is None: + j += 1 + else: + cols[j].append(cell) + cols = [ + [""] * (width - len(col)) + col + if direction == "r" + else col + [""] * (width - len(col)) + for width, col, direction in zip(column_widths, cols, "rrl") + ] + rows[i] = sum(cols, []) + + # compute cell widths + cell_widths = [0] * len(rows[0]) + for row in rows: + for j, cell in enumerate(row): + cell_widths[j] = max(cell_widths[j], len(cell)) + + # justify cells + return "\n".join( + " ".join(cell.rjust(width) for cell, width in zip(row, cell_widths)) + for row in rows + ) diff --git a/test/test_util.py b/test/test_util.py index 7ee8ae9aa..9152aa145 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -9,7 +9,9 @@ from jax.test_util import check_eq from jax.tree_util import tree_flatten, tree_multimap -from numpyro.util import fori_collect, soft_vmap +import numpyro +import numpyro.distributions as dist +from numpyro.util import fori_collect, format_shapes, soft_vmap def test_fori_collect_thinning(): @@ -101,3 +103,66 @@ def f(x): assert set(ys.keys()) == {"a", "b"} assert_allclose(ys["a"], xs["a"][..., None] * jnp.ones(4)) assert_allclose(ys["b"], ~xs["b"]) + + +def test_format_shapes(): + data = jnp.arange(100) + + def model_test(): + mean = numpyro.param("mean", jnp.zeros(len(data))) + scale = numpyro.sample("scale", dist.Normal(0, 1).expand([3]).to_event(1)) + scale = scale.sum() + with numpyro.plate("data", len(data), subsample_size=10) as ind: + batch = data[ind] + mean_batch = mean[ind] + numpyro.sample("x", dist.Normal(mean_batch, scale), obs=batch) + + with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as t: + model_test() + + assert ( + format_shapes(t) == "Trace Shapes: \n" + " Param Sites: \n" + " mean 100 \n" + "Sample Sites: \n" + " scale dist | 3\n" + " value | 3\n" + " data plate 10 | \n" + " x dist 10 | \n" + " value 10 | " + ) + assert ( + format_shapes(t, compute_log_prob=True) == "Trace Shapes: \n" + " Param Sites: \n" + " mean 100 \n" + "Sample Sites: \n" + " scale dist | 3\n" + " value | 3\n" + " log_prob | \n" + " data plate 10 | \n" + " x dist 10 | \n" + " value 10 | \n" + " log_prob 10 | " + ) + assert ( + format_shapes(t, compute_log_prob=lambda site: site["name"] == "scale") + == "Trace Shapes: \n" + " Param Sites: \n" + " mean 100 \n" + "Sample Sites: \n" + " scale dist | 3\n" + " value | 3\n" + " log_prob | \n" + " data plate 10 | \n" + " x dist 10 | \n" + " value 10 | " + ) + assert ( + format_shapes(t, last_site="data") == "Trace Shapes: \n" + " Param Sites: \n" + " mean 100 \n" + "Sample Sites: \n" + " scale dist | 3\n" + " value | 3\n" + " data plate 10 | " + )