Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for more general chain_method in MCMC #1825

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ def model(X, y):
sample values returned from the sampler to constrained values that lie within the support
of the sample sites. Additionally, this is used to return values at deterministic sites in
the model.
:param str chain_method: One of 'parallel' (default), 'sequential', 'vectorized'. The method
:param str chain_method: A callable jax transform like `jax.vmap` or one of
'parallel' (default), 'sequential', 'vectorized'. The method
'parallel' is used to execute the drawing process in parallel on XLA devices (CPUs/GPUs/TPUs).
If there are not enough devices for 'parallel', we fall back to 'sequential' method to draw
chains sequentially. 'vectorized' method is an experimental feature which vectorizes the
Expand Down Expand Up @@ -340,7 +341,11 @@ def __init__(
raise ValueError("thinning must be a positive integer")
self.thinning = thinning
self.postprocess_fn = postprocess_fn
if chain_method not in ["parallel", "vectorized", "sequential"]:
if not callable(chain_method) and chain_method not in [
"parallel",
"vectorized",
"sequential",
]:
raise ValueError(
"Only supporting the following methods to draw chains:"
' "sequential", "parallel", or "vectorized"'
Expand Down Expand Up @@ -471,7 +476,9 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites):
collection_size=collection_size,
progbar_desc=partial(_get_progbar_desc_str, lower_idx, phase),
diagnostics_fn=diagnostics,
num_chains=self.num_chains if self.chain_method == "parallel" else 1,
num_chains=self.num_chains
if (callable(self.chain_method) or self.chain_method == "parallel")
else 1,
)
states, last_val = collect_vals
# Get first argument of type `HMCState`
Expand Down Expand Up @@ -679,6 +686,8 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
states, last_state = _laxmap(partial_map_fn, map_args)
elif self.chain_method == "parallel":
states, last_state = pmap(partial_map_fn)(map_args)
elif callable(self.chain_method):
states, last_state = self.chain_method(partial_map_fn)(map_args)
else:
assert self.chain_method == "vectorized"
states, last_state = partial_map_fn(map_args)
Expand Down
21 changes: 21 additions & 0 deletions test/infer/test_hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,24 @@ def model(data):
kernel = HMCECS(NUTS(model), proxy=proxy_fn)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples)
mcmc.run(random.PRNGKey(0), data)


def test_callable_chain_method():
    """Check that ``MCMC`` accepts a callable (``vmap``) as ``chain_method``.

    Runs an ``HMCGibbs`` kernel over two chains with ``chain_method=vmap`` and
    verifies that the collected samples contain exactly the latent sites
    ``"x"`` and ``"y"``.
    """

    def model():
        # Two latent sites with identical priors; their sum is the mean of
        # the single observed value.
        x = numpyro.sample("x", dist.Normal(0.0, 2.0))
        y = numpyro.sample("y", dist.Normal(0.0, 2.0))
        numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

    def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
        # Conditional of x given y for the model above: prior precision 1/4
        # plus likelihood precision 1 gives variance 0.8 and mean
        # 0.8 * (1 - y).
        current_y = hmc_sites["y"]
        conditional_x = dist.Normal(0.8 * (1 - current_y), jnp.sqrt(0.8))
        return {"x": conditional_x.sample(rng_key)}

    # "x" is updated by the Gibbs step; the remaining site is left to NUTS.
    gibbs_kernel = HMCGibbs(NUTS(model), gibbs_fn=gibbs_fn, gibbs_sites=["x"])
    sampler = MCMC(
        gibbs_kernel,
        num_warmup=100,
        num_samples=100,
        num_chains=2,
        chain_method=vmap,
    )
    sampler.run(random.PRNGKey(0))

    draws = sampler.get_samples()
    assert set(draws) == {"x", "y"}
Loading