Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add format_shapes utility #1116

Merged
merged 8 commits into from
Aug 16, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add tests
  • Loading branch information
tcbegley committed Aug 8, 2021
commit de9c3ec8f2295933c7c3ab4af80eeb51fc2f01a8
36 changes: 35 additions & 1 deletion test/test_util.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,9 @@
from jax.test_util import check_eq
from jax.tree_util import tree_flatten, tree_multimap

from numpyro.util import fori_collect, soft_vmap
import numpyro
import numpyro.distributions as dist
from numpyro.util import fori_collect, format_shapes, soft_vmap


def test_fori_collect_thinning():
@@ -101,3 +103,35 @@ def f(x):
assert set(ys.keys()) == {"a", "b"}
assert_allclose(ys["a"], xs["a"][..., None] * jnp.ones(4))
assert_allclose(ys["b"], ~xs["b"])


def test_format_shapes():
data = jnp.arange(100)

def model_test():
mean = numpyro.param("mean", jnp.zeros(len(data)))
with numpyro.plate("data", len(data), subsample_size=10) as ind:
batch = data[ind]
mean_batch = mean[ind]
numpyro.sample("x", dist.Normal(mean_batch, 1), obs=batch)

with numpyro.handlers.seed(rng_seed=0), numpyro.handlers.trace() as t:
model_test()

assert (
format_shapes(t)
== "Trace Shapes: \n Param Sites: \n mean 100\n"
"Sample Sites: \n data plate 10 |\n x dist 10 |\n "
tcbegley marked this conversation as resolved.
Show resolved Hide resolved
"value 10 |"
)
assert (
format_shapes(t, log_prob=True)
== "Trace Shapes: \n Param Sites: \n mean 100\n"
"Sample Sites: \n data plate 10 |\n x dist 10 |\n "
"value 10 |\n log_prob 10 |"
)
assert (
format_shapes(t, last_site="data")
== "Trace Shapes: \n Param Sites: \n mean 100\n"
"Sample Sites: \n data plate 10 |"
)