From 558bd02d687d2b3fe4163827dbc00926d0ae0a04 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 29 Jan 2022 02:04:18 -0500 Subject: [PATCH] Fix failing tests --- examples/prodlda.py | 2 +- numpyro/contrib/tfp/distributions.py | 6 ++++-- test/test_distributions.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/prodlda.py b/examples/prodlda.py index bad0c9ea1..f232e74a2 100644 --- a/examples/prodlda.py +++ b/examples/prodlda.py @@ -235,7 +235,7 @@ def load_data(): docs = jnp.array(vectorizer.fit_transform(news["data"]).toarray()) vocab = pd.DataFrame(columns=["word", "index"]) - vocab["word"] = vectorizer.get_feature_names() + vocab["word"] = vectorizer.get_feature_names_out() vocab["index"] = vocab.index return docs, vocab diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index 06660dce4..c1245d72f 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -221,11 +221,13 @@ def __getattr__(self, name): @property def batch_shape(self): - return self.tfp_dist.batch_shape + # TFP shapes are special tuples that can not be used directly + # with lax.broadcast_shapes. So we convert them to tuple. + return tuple(self.tfp_dist.batch_shape) @property def event_shape(self): - return self.tfp_dist.event_shape + return tuple(self.tfp_dist.event_shape) @property def has_rsample(self): diff --git a/test/test_distributions.py b/test/test_distributions.py index 5d39e4dfd..05b289f7a 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -968,7 +968,7 @@ def log_likelihood(*params): expected = log_likelihood(*params) actual = jax.jit(log_likelihood)(*params) - assert_allclose(actual, expected, atol=2e-5) + assert_allclose(actual, expected, atol=2e-5, rtol=2e-5) @pytest.mark.parametrize(