From d21a41c80ab22846f954a681bfead5c29cacff8e Mon Sep 17 00:00:00 2001
From: Du Phan
Date: Sun, 30 Jun 2024 23:43:41 -0400
Subject: [PATCH] allow for more general chain method

Accept a callable JAX transform (e.g. `jax.vmap`) as `chain_method`, in
addition to the strings 'parallel', 'sequential' and 'vectorized'.

* `MCMC.__init__` skips the string whitelist check when `chain_method`
  is callable, and the error message now mentions callables.
* `MCMC._single_chain_mcmc` passes `num_chains=self.num_chains` for a
  callable `chain_method`, exactly as it already does for 'parallel'.
* `MCMC.run` draws chains with `self.chain_method(partial_map_fn)(map_args)`
  when `chain_method` is callable.
* Add `test_callable_chain_method`, which runs an HMCGibbs kernel with
  two chains and `chain_method=vmap`.

---
 numpyro/infer/mcmc.py        | 17 +++++++++++++----
 test/infer/test_hmc_gibbs.py | 21 +++++++++++++++++++++
 2 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py
index ad016825b..e18b35bc0 100644
--- a/numpyro/infer/mcmc.py
+++ b/numpyro/infer/mcmc.py
@@ -277,7 +277,8 @@ def model(X, y):
         sample values returned from the sampler to constrained values that lie within the support
         of the sample sites. Additionally, this is used to return values at deterministic sites in
         the model.
-    :param str chain_method: One of 'parallel' (default), 'sequential', 'vectorized'. The method
+    :param chain_method: A callable jax transform like `jax.vmap` or one of
+        'parallel' (default), 'sequential', 'vectorized'. The method
         'parallel' is used to execute the drawing process in parallel on XLA devices (CPUs/GPUs/TPUs),
         If there are not enough devices for 'parallel', we fall back to 'sequential' method to draw
         chains sequentially. 'vectorized' method is an experimental feature which vectorizes the
@@ -340,7 +341,11 @@ def __init__(
             raise ValueError("thinning must be a positive integer")
         self.thinning = thinning
         self.postprocess_fn = postprocess_fn
-        if chain_method not in ["parallel", "vectorized", "sequential"]:
+        if not callable(chain_method) and chain_method not in [
+            "parallel",
+            "vectorized",
+            "sequential",
+        ]:
             raise ValueError(
-                "Only supporting the following methods to draw chains:"
+                "Only supporting a callable or the following methods to draw chains:"
                 ' "sequential", "parallel", or "vectorized"'
@@ -471,7 +476,9 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites):
             collection_size=collection_size,
             progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
             diagnostics_fn=diagnostics,
-            num_chains=self.num_chains if self.chain_method == "parallel" else 1,
+            num_chains=self.num_chains
+            if (callable(self.chain_method) or self.chain_method == "parallel")
+            else 1,
         )
         states, last_val = collect_vals
         # Get first argument of type `HMCState`
@@ -679,6 +686,8 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
                 states, last_state = _laxmap(partial_map_fn, map_args)
             elif self.chain_method == "parallel":
                 states, last_state = pmap(partial_map_fn)(map_args)
+            elif callable(self.chain_method):
+                states, last_state = self.chain_method(partial_map_fn)(map_args)
             else:
                 assert self.chain_method == "vectorized"
                 states, last_state = partial_map_fn(map_args)
diff --git a/test/infer/test_hmc_gibbs.py b/test/infer/test_hmc_gibbs.py
index 3cb7f02a3..5caef6f38 100644
--- a/test/infer/test_hmc_gibbs.py
+++ b/test/infer/test_hmc_gibbs.py
@@ -462,3 +462,24 @@ def model(data):
     kernel = HMCECS(NUTS(model), proxy=proxy_fn)
     mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
     mcmc.run(random.PRNGKey(0), data)
+
+
+def test_callable_chain_method():
+    def model():
+        x = numpyro.sample("x", dist.Normal(0.0, 2.0))
+        y = numpyro.sample("y", dist.Normal(0.0, 2.0))
+        numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))
+
+    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
+        y = hmc_sites["y"]
+        new_x = dist.Normal(0.8 * (1 - y), jnp.sqrt(0.8)).sample(rng_key)
+        return {"x": new_x}
+
+    hmc_kernel = NUTS(model)
+    kernel = HMCGibbs(hmc_kernel, gibbs_fn=gibbs_fn, gibbs_sites=["x"])
+    mcmc = MCMC(
+        kernel, num_warmup=100, num_chains=2, num_samples=100, chain_method=vmap
+    )
+    mcmc.run(random.PRNGKey(0))
+    samples = mcmc.get_samples()
+    assert set(samples.keys()) == {"x", "y"}