HMCGibbs with chain_method="vectorized" #1725
Comments
Could you change this line to jax.vmap(...) with the defaults?
Choosing the "parallel" option and changing to jax.vmap did not work for me. It seems like it still processes the chains in sequential order when I do that.
Did you set the host device count to the number of chains: https://num.pyro.ai/en/stable/utilities.html#set-host-device-count?
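For reference, a minimal sketch of using that utility (it only takes effect if it runs before JAX initializes its backend):

```python
import numpyro

# Must run before any JAX computation initializes the XLA backend,
# i.e. at the very top of the script/notebook.
numpyro.set_host_device_count(2)

import jax

print(jax.local_device_count())  # should now report 2 (CPU) devices
```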
When I was writing a custom Gibbs sampler (one that does an HMC step for each conditional rather than drawing from a known distribution), I was able to get vectorized chains working by vmapping the init and sample functions. I imagine it would look a bit like:

```python
def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
    model_kwargs = {} if model_kwargs is None else model_kwargs.copy()

    def init_fn(init_params, rng_key):
        ...
        return HMCGibbsState(z, hmc_state, rng_key)

    if is_prng_key(rng_key):
        # Single chain: one key, one state.
        init_state = init_fn(init_params, rng_key)
        self._sample_fn = self._sample_one_chain
    else:
        # Batch of keys: map init and sampling over the leading chain axis.
        init_state = vmap(init_fn)(init_params, rng_key)
        self._sample_fn = vmap(self._sample_one_chain, in_axes=(0, None, None))
    return device_put(init_state)
```

and rename the current `sample` method to `_sample_one_chain`. It might need a bit of extra logic around it to work as expected, but I think that is what the solution would look like.
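To make the batching pattern in that sketch concrete: vmap applies the per-chain init once per chain, batching over the leading axis of both the initial-parameter pytree and the key array. A small self-contained illustration (the `init_fn` below is a hypothetical stand-in, not NumPyro's actual implementation):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a per-chain init: takes ONE chain's params and ONE key.
def init_fn(init_params, rng_key):
    noise = jax.random.normal(rng_key, init_params["z"].shape)
    return {"z": init_params["z"] + 0.1 * noise}

num_chains = 4
init_params = {"z": jnp.zeros((num_chains, 3))}             # leading axis = chains
rng_keys = jax.random.split(jax.random.key(0), num_chains)  # one key per chain

# vmap runs init_fn once per chain over the leading axis of every argument.
states = jax.vmap(init_fn)(init_params, rng_keys)
print(states["z"].shape)  # (4, 3)
```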
I have a similar issue with `DiscreteHMCGibbs`. Minimal example:

```python
import jax
import numpyro

def model():
    x = numpyro.sample('x', numpyro.distributions.Normal())
    n = numpyro.sample('n', numpyro.distributions.DiscreteUniform())

nuts = numpyro.infer.NUTS(model)
gibbs = numpyro.infer.DiscreteHMCGibbs(nuts)
mcmc = numpyro.infer.MCMC(
    gibbs,
    num_warmup=1_000,
    num_samples=1_000,
    num_chains=2,
    chain_method=jax.vmap,
    progress_bar=True,
)
mcmc.run(jax.random.key(0))
```

The output looks like:

```
Running chain 0: 100%|█████████████████████████████| 2000/2000 [00:05<00:00, 378.40it/s]
Compiling.. :   0%|                                | 0/2000 [00:05<?, ?it/s]
```

The same thing happens when running in a Jupyter notebook/lab, except the progress bar is the notebook version.
I am trying to use HMCGibbs sampling with more than one chain using chain_method="vectorized", but there appears to be some problem with splitting the random keys.
Consider this toy example that I copied from the numpyro documentation, where I only changed the chain_method and the number of chains:
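(Presumably the snippet was the `HMCGibbs` example from the documentation with `num_chains=2` and `chain_method="vectorized"` added; a sketch along those lines:)

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCGibbs

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 2.0))
    y = numpyro.sample("y", dist.Normal(0.0, 2.0))
    numpyro.sample("obs", dist.Normal(x + y, 1.0), obs=jnp.array([1.0]))

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    # Conjugate Gibbs update of x given y, as in the documentation example.
    y = hmc_sites["y"]
    new_x = dist.Normal(0.8 * (1 - y), jnp.sqrt(0.8)).sample(rng_key)
    return {"x": new_x}

kernel = HMCGibbs(NUTS(model), gibbs_fn=gibbs_fn, gibbs_sites=["x"])
# The only changes relative to the documentation: num_chains and chain_method.
mcmc = MCMC(kernel, num_warmup=100, num_samples=100,
            num_chains=2, chain_method="vectorized")
mcmc.run(random.PRNGKey(0))
```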
I find that I get the following error when running the above code:
TypeError: split accepts a single key, but was given a key array of shape (2,) != (). Use jax.vmap for batching.
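The error itself seems to come from handing the whole batch of per-chain keys to jax.random.split at once rather than vmapping over it; it can be reproduced outside NumPyro like this:

```python
import jax

keys = jax.random.split(jax.random.key(0), 2)  # one key per chain, key array of shape (2,)
# jax.random.split(keys)   # raises: TypeError: split accepts a single key, but was given a key array ...
per_chain = jax.vmap(jax.random.split)(keys)   # splitting each chain's key works; shape (2, 2)
```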
Is there a way to make the "vectorized" option available for HMCGibbs sampling?