Skip to content

Commit

Permalink
Add sampled entropy test for distribution without scipy equivalent.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed May 2, 2024
1 parent bd069fc commit 3f9c197
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1488,21 +1488,43 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit):
@pytest.mark.parametrize(
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_entropy(jax_dist, sp_dist, params):
def test_entropy_scipy(jax_dist, sp_dist, params):
jax_dist = jax_dist(*params)

try:
actual = jax_dist.entropy()
except NotImplementedError:
pytest.skip(reason="distribution does not implement `entropy`")
if _is_batched_multivariate(jax_dist):
pytest.skip("batching not allowed in multivariate distns.")
if sp_dist is None:
pytest.skip(reason="no corresponding scipy distribution")

sp_dist = sp_dist(*params)
expected = sp_dist.entropy()
assert_allclose(actual, expected, atol=1e-5)


@pytest.mark.parametrize(
"jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL
)
def test_entropy_samples(jax_dist, sp_dist, params):
jax_dist = jax_dist(*params)

try:
actual = jax_dist.entropy()
except NotImplementedError:
pytest.skip(reason="distribution does not implement `entropy`")

sp_dist = sp_dist(*params)
expected = sp_dist.entropy()
assert_allclose(actual, expected, atol=1e-5)
samples = jax_dist.sample(jax.random.key(8), (1000,))
neg_log_probs = -jax_dist.log_prob(samples)
mean = neg_log_probs.mean(axis=0)
stderr = neg_log_probs.std(axis=0) / jnp.sqrt(neg_log_probs.shape[-1] - 1)
z = (actual - mean) / stderr

# Check the z-score is small or that all values are close. This happens, for
# example, for uniform distributions with constant log prob and hence zero stderr.
assert (jnp.abs(z) < 5).all() or jnp.allclose(actual, neg_log_probs, atol=1e-5)


def test_entropy_categorical():
Expand Down

0 comments on commit 3f9c197

Please sign in to comment.