diff --git a/pyroapi/tests/__init__.py b/pyroapi/tests/__init__.py index 8cf2462..ba83b00 100644 --- a/pyroapi/tests/__init__.py +++ b/pyroapi/tests/__init__.py @@ -1,4 +1,5 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from .test_mcmc import * # noqa F401 from .test_svi import * # noqa F401 diff --git a/pyroapi/tests/test_mcmc.py b/pyroapi/tests/test_mcmc.py new file mode 100644 index 0000000..dea4c62 --- /dev/null +++ b/pyroapi/tests/test_mcmc.py @@ -0,0 +1,29 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from pyroapi.dispatch import distributions as dist +from pyroapi.dispatch import infer, pyro + +# Note that the backend arg to these tests must be provided as a +# user-defined fixture that sets the pyro_backend. For demonstration, +# see test/conftest.py. + + +def assert_ok(model, *args, **kwargs): + """ + Assert that inference works without warnings or errors. + """ + pyro.get_param_store().clear() + kernel = infer.NUTS(model) + mcmc = infer.MCMC(kernel, num_samples=2, warmup_steps=2) + mcmc.run(*args, **kwargs) + + +def test_mcmc_run_ok(backend): + if backend not in ["pyro", "numpy"]: + return + + def model(): + pyro.sample("x", dist.Normal(0, 1)) + + assert_ok(model) diff --git a/test/test_tests.py b/test/test_tests.py index 779a5ae..d690b68 100644 --- a/test/test_tests.py +++ b/test/test_tests.py @@ -28,3 +28,17 @@ def backend(request): pytest.importorskip(PACKAGE_NAME[request.param]) with pyro_backend(request.param): yield + + +# TODO(fehiepsi): Remove the following when the test passes in numpyro. +_test_mcmc_run_ok = test_mcmc_run_ok # noqa F405 + + +@pytest.mark.parametrize("backend", [ + "pyro", + pytest.param("numpy", marks=[ + pytest.mark.xfail(reason="numpyro signature for MCMC is not consistent.")])]) +def test_mcmc_run_ok(backend): + pytest.importorskip(PACKAGE_NAME[backend]) + with pyro_backend(backend): + _test_mcmc_run_ok(backend)