Skip to content

Commit

Permalink
Post-process when no sample sites present.
Browse files Browse the repository at this point in the history
Current post-processing behaviour skips models with only deterministic variables. Applying this change will return consistent samples regardless of whether `sample` sites are present.
  • Loading branch information
hessammehr committed Nov 21, 2024
1 parent 4f2c9b2 commit 8e40be9
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
20 changes: 16 additions & 4 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ def collect_and_postprocess(x):
if collect_fields:
fields = nested_attrgetter(*collect_fields)(x[0])
fields = [fields] if len(collect_fields) == 1 else list(fields)
site_values = jax.tree.flatten(fields[0])[0]
if len(site_values) > 0:
fields[0] = postprocess_fn(fields[0], *x[1:])
fields[0] = postprocess_fn(fields[0], *x[1:])

if remove_sites != ():
assert isinstance(fields[0], dict)
Expand Down Expand Up @@ -400,13 +398,27 @@ def _get_cached_fns(self):
fns, key = None, None
if fns is None:

def ensure_vmap(fn, batch_size=None):
def wrapper(x):
x_arrays = jax.tree.flatten(x)[0]
if len(x_arrays) > 0:
return vmap(fn)(x)
else:
assert batch_size is not None
return jax.tree.map(
lambda x: jnp.broadcast_to(x, (batch_size,) + jnp.shape(x)),
fn(x),
)

return wrapper

def _postprocess_fn(state, args, kwargs):
if self.postprocess_fn is None:
body_fn = self.sampler.postprocess_fn(args, kwargs)
else:
body_fn = self.postprocess_fn
if self.chain_method == "vectorized" and self.num_chains > 1:
body_fn = vmap(body_fn)
body_fn = ensure_vmap(body_fn, batch_size=self.num_chains)

return body_fn(state)

Expand Down
26 changes: 26 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,3 +1208,29 @@ def model():
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0), extra_fields=("z.x",))
assert_allclose(mcmc.get_samples()["x"], jnp.exp(mcmc.get_extra_fields()["z.x"]))


def test_all_deterministic():
def model1():
numpyro.deterministic("x", 1.0)

def model2():
numpyro.deterministic("x", jnp.array([1.0, 2.0]))

num_samples = 10
shapes = {model1: (), model2: (2,)}

for model, shape in shapes.items():
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples)
mcmc.run(random.PRNGKey(0))
assert mcmc.get_samples()["x"].shape == (num_samples,) + shape


def test_empty_summary():
def model():
pass

mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0))

mcmc.print_summary()

0 comments on commit 8e40be9

Please sign in to comment.