Skip to content

Commit

Permalink
aadd function assigment test
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Nov 10, 2024
1 parent aac66b7 commit 7489d8a
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions test/infer/test_svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TraceGraph_ELBO,
TraceMeanField_ELBO,
)
from numpyro.infer.elbo import _apply_vmap
from numpyro.primitives import mutable as numpyro_mutable
from numpyro.util import fori_loop

Expand Down Expand Up @@ -163,6 +164,14 @@ def get_renyi(n=N, k=K, fix_indices=True):
assert_allclose(atol, 0.0, atol=1e-5)


def test_assign_vectorize_particles_fn():
elbo = Trace_ELBO()
assert elbo._assign_vectorize_particles_fn(True) == _apply_vmap
assert elbo._assign_vectorize_particles_fn(False) == jax.lax.map
assert elbo._assign_vectorize_particles_fn(jax.pmap) == jax.pmap
assert callable(elbo._assign_vectorize_particles_fn(lambda x: x))


@pytest.mark.parametrize(
argnames="vectorize_particles",
argvalues=[True, False, jax.pmap, lambda x: x],
Expand Down

0 comments on commit 7489d8a

Please sign in to comment.