diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 99cf35902..ac02a8856 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -622,6 +622,10 @@ def prng_key(): :return: a PRNG key of shape (2,) and dtype unit32. """ if not _PYRO_STACK: + warnings.warn( + "Cannot generate JAX PRNG key outside of `seed` handler.", + stacklevel=find_stack_level(), + ) return initial_msg = { diff --git a/test/test_handlers.py b/test/test_handlers.py index e24e22890..518f856dc 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -778,7 +778,8 @@ def guide(): def test_prng_key(): - assert numpyro.prng_key() is None + with pytest.warns(Warning, match="outside of `seed`"): + assert numpyro.prng_key() is None with handlers.seed(rng_seed=0): rng_key = numpyro.prng_key()