Skip to content

Commit

Permalink
Fix failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Jan 29, 2022
1 parent dd31124 commit 558bd02
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/prodlda.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def load_data():
docs = jnp.array(vectorizer.fit_transform(news["data"]).toarray())

vocab = pd.DataFrame(columns=["word", "index"])
vocab["word"] = vectorizer.get_feature_names()
vocab["word"] = vectorizer.get_feature_names_out()
vocab["index"] = vocab.index

return docs, vocab
Expand Down
6 changes: 4 additions & 2 deletions numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,11 +221,13 @@ def __getattr__(self, name):

@property
def batch_shape(self):
return self.tfp_dist.batch_shape
# TFP shapes are special tuples that can not be used directly
# with lax.broadcast_shapes. So we convert them to tuple.
return tuple(self.tfp_dist.batch_shape)

@property
def event_shape(self):
return self.tfp_dist.event_shape
return tuple(self.tfp_dist.event_shape)

@property
def has_rsample(self):
Expand Down
2 changes: 1 addition & 1 deletion test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,7 @@ def log_likelihood(*params):

expected = log_likelihood(*params)
actual = jax.jit(log_likelihood)(*params)
assert_allclose(actual, expected, atol=2e-5)
assert_allclose(actual, expected, atol=2e-5, rtol=2e-5)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 558bd02

Please sign in to comment.