From e46e7f4201b75186b3d30f0339d4ee6cc2801a35 Mon Sep 17 00:00:00 2001 From: Colin Carroll Date: Tue, 30 Jul 2024 08:20:45 -0700 Subject: [PATCH] Add nutpie sampler to bayeux. PiperOrigin-RevId: 657599766 --- bayeux/_src/debug.py | 2 +- bayeux/_src/mcmc/nutpie.py | 135 +++++++++++++++++++++++++++++++++++++ bayeux/mcmc/__init__.py | 6 ++ bayeux/tests/mcmc_test.py | 17 +++++ pyproject.toml | 1 + 5 files changed, 160 insertions(+), 1 deletion(-) create mode 100644 bayeux/_src/mcmc/nutpie.py diff --git a/bayeux/_src/debug.py b/bayeux/_src/debug.py index 32e84cd..81c0ed9 100644 --- a/bayeux/_src/debug.py +++ b/bayeux/_src/debug.py @@ -522,7 +522,7 @@ def debug_no_ildj( def _get_num_chains(default_kwargs): for v in default_kwargs.values(): - for key in ("num_chains", "num_particles", "batch_size"): + for key in ("num_chains", "num_particles", "batch_size", "chains"): if key in v: return v[key] raise KeyError("No `num_chains` in default kwargs!") diff --git a/bayeux/_src/mcmc/nutpie.py b/bayeux/_src/mcmc/nutpie.py new file mode 100644 index 0000000..2543759 --- /dev/null +++ b/bayeux/_src/mcmc/nutpie.py @@ -0,0 +1,135 @@ +# Copyright 2024 The bayeux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2024 The bayeux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Nutpie specific code.""" + +import arviz as az +from bayeux._src import shared +import jax +import numpy as np + +import nutpie +from nutpie.compiled_pyfunc import from_pyfunc + + +class _NutpieSampler(shared.Base): + """Base class for nutpie sampler.""" + name: str = "nutpie" + + def _get_aux(self): + flat, unflatten = jax.flatten_util.ravel_pytree(self.test_point) + + def flatten(pytree): + return jax.flatten_util.ravel_pytree(pytree)[0] + + def make_logp_fn(): + constrained_log_density = self.constrained_log_density() + def log_density(x): + return constrained_log_density(unflatten(x)).squeeze() + log_grad = jax.jit(jax.value_and_grad(log_density)) + def wrapper(x): + val, grad = log_grad(x) + return val, np.array(grad, dtype=np.float64) + return wrapper + return make_logp_fn, flatten, unflatten, flat.shape[0] + + def get_kwargs(self, **kwargs): + make_logp_fn, flatten, unflatten, ndim = self._get_aux() + + def make_expand_fn(*args, **kwargs): + del args + del kwargs + return lambda x: {"x": np.asarray(x, dtype="float64")} + + from_pyfunc_kwargs = { + "ndim": ndim, + "make_logp_fn": make_logp_fn, + "make_expand_fn": make_expand_fn, + "expanded_shapes": [(ndim,)], + "expanded_names": ["x"], + "expanded_dtypes": [np.float64], + } + from_pyfunc_kwargs = { + k: kwargs.get(k, v) for k, v in from_pyfunc_kwargs.items()} + + kwargs_with_defaults = { + "draws": 1_000, + "chains": 8, + } | kwargs + sample_kwargs, _ = shared.get_default_signature(nutpie.sample) + sample_kwargs.update({k: kwargs_with_defaults[k] for k in sample_kwargs if + k in kwargs_with_defaults}) + if "cores" not in kwargs: + sample_kwargs["cores"] = sample_kwargs["chains"] + extra_parameters = {"flatten": flatten, + "unflatten": unflatten, + "return_pytree": kwargs.get("return_pytree", False)} + + return {from_pyfunc: from_pyfunc_kwargs, + nutpie.sample: sample_kwargs, + "extra_parameters": extra_parameters} + + def __call__(self, seed, **kwargs): + kwargs = self.get_kwargs(**kwargs) + extra_parameters = kwargs["extra_parameters"] + compiled = from_pyfunc(**kwargs[from_pyfunc]) + idata = nutpie.sample(compiled_model=compiled, + **kwargs[nutpie.sample]) + return _postprocess_idata(idata, + extra_parameters["unflatten"], + self.transform_fn, + extra_parameters["return_pytree"]) + + +def _pytree_to_dict(draws): + if hasattr(draws, "_asdict"): + draws = draws._asdict() + elif not isinstance(draws, dict): + draws = {"var0": draws} + + return draws + + +def _postprocess_idata(idata, unflatten, transform_fn, return_pytree): + """Convert nutpie inference data back to pytree, transform, and put back.""" + unflatten = jax.vmap(jax.vmap(unflatten)) + posterior = transform_fn(unflatten(idata.posterior.x.values)) + + if return_pytree: + return posterior + + posterior = _pytree_to_dict(posterior) + warmup_posterior = _pytree_to_dict( + transform_fn(unflatten(idata.warmup_posterior.x.values))) + new_posterior = az.from_dict(posterior=posterior) + new_warmup_posterior = az.from_dict(posterior=warmup_posterior) + del idata.posterior + del idata.warmup_posterior + idata.add_groups(posterior=new_posterior.posterior, + warmup_posterior=new_warmup_posterior.posterior) + return idata diff --git a/bayeux/mcmc/__init__.py b/bayeux/mcmc/__init__.py index 9c8ed81..80a5315 100644 --- a/bayeux/mcmc/__init__.py +++ b/bayeux/mcmc/__init__.py @@ -52,3 +52,9 @@ from bayeux._src.mcmc.numpyro import NUTS as NUTSnumpyro __all__.extend(["HMCnumpyro", "NUTSnumpyro"]) + +if importlib.util.find_spec("nutpie") is not None: + from bayeux._src.mcmc.nutpie import _NutpieSampler as NutpieSampler + + __all__.extend(["NutpieSampler"]) + diff --git a/bayeux/tests/mcmc_test.py b/bayeux/tests/mcmc_test.py index e1fff71..e04634b 100644 --- a/bayeux/tests/mcmc_test.py +++ b/bayeux/tests/mcmc_test.py @@ -112,6 +112,23 @@ def test_return_pytree_flowmc(): assert pytree["x"]["y"].shape == (4, 10, 2) +@pytest.mark.skipif(importlib.util.find_spec("nutpie") is None, + reason="Test requires nutpie which is not installed") +def test_return_pytree_nutpie(): + model = bx.Model(log_density=lambda pt: -jnp.sum(pt["x"]["y"]**2), + test_point={"x": {"y": jnp.array([1., 1.])}}) + seed = jax.random.PRNGKey(0) + pytree = model.mcmc.nutpie( + seed=seed, + return_pytree=True, + chains=4, + draws=10, + tune=10, + ) + # 10 draws = (1 local + 1 global) * 5 loops + assert pytree["x"]["y"].shape == (4, 10, 2) + + @pytest.mark.parametrize("method", METHODS) def test_samplers(method): # flowMC samplers are broken for 0 or 1 dimensions, so just test diff --git a/pyproject.toml b/pyproject.toml index eaf6989..4b97405 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "numpyro", "jaxopt", "pymc", + "nutpie", ] # `version` is automatically set by flit to use `bayeux.__version__`