Skip to content

Commit

Permalink
hmc_gibbs updated to work with different support sizes and batching, …
Browse files Browse the repository at this point in the history
…tests are passing when using changes from PR pyro-ppl#1859
  • Loading branch information
Deathn0t committed Sep 5, 2024
1 parent fe46ba1 commit 17147bb
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 24 deletions.
22 changes: 10 additions & 12 deletions numpyro/infer/hmc_gibbs.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,26 +476,24 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
# Each support is padded with zeros to have the same length
# ravel is used to maintain a consistent behaviour with `support_sizes`

max_length_support_enumerates = max(
size for size in self._support_sizes.values()
max_length_support_enumerates = np.max(
[size for size in self._support_sizes.values()]
)

support_enumerates = {}
for name, support_size in self._support_sizes.items():
site = self._prototype_trace[name]
enumerate_support = site["fn"].enumerate_support(False)
padded_enumerate_support = np.pad(
enumerate_support,
(0, max_length_support_enumerates - enumerate_support.shape[0]),
)
padded_enumerate_support = np.broadcast_to(
padded_enumerate_support,
support_size.shape + (max_length_support_enumerates,),
)
enumerate_support = site["fn"].enumerate_support(True).T
# Only the last dimension that corresponds to support size is padded
pad_width = [(0, 0) for _ in range(len(enumerate_support.shape) - 1)] + [
(0, max_length_support_enumerates - enumerate_support.shape[-1])
]
padded_enumerate_support = np.pad(enumerate_support, pad_width)

support_enumerates[name] = padded_enumerate_support

self._support_enumerates = jax.vmap(
lambda x: ravel_pytree(x)[0], in_axes=0, out_axes=1
lambda x: ravel_pytree(x)[0], in_axes=len(support_size.shape), out_axes=1
)(support_enumerates)

self._gibbs_sites = [
Expand Down
75 changes: 63 additions & 12 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3430,26 +3430,30 @@ def test_discrete_uniform_with_mixedhmc():
import numpyro.distributions as dist
from numpyro.infer import HMC, MCMC, MixedHMC

def model_1():
numpyro.sample("x0", dist.DiscreteUniform(10, 12))
numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25])))

def sample_mixedhmc(model_fn, num_samples, seed=0, **kwargs):
    """Run MixedHMC on ``model_fn`` and return the posterior samples.

    Parameters
    ----------
    model_fn : callable
        NumPyro model to sample from.
    num_samples : int
        Number of post-warmup samples to draw.
    seed : int, optional
        Seed for ``jax.random.PRNGKey`` (default 0); previously this was
        hard-coded, which prevented reuse with different randomness.
    **kwargs
        Extra keyword arguments forwarded to ``MixedHMC``
        (e.g. ``random_walk``, ``modified``).

    Returns
    -------
    dict
        Mapping from sample-site name to an array of drawn samples.
    """
    kernel = HMC(model_fn, trajectory_length=1.2)
    kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs)
    mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False)
    mcmc.run(jax.random.PRNGKey(seed))
    return mcmc.get_samples()

num_samples = 1000
mixed_hmc_kwargs = [
{"random_walk": False, "modified": False},
{"random_walk": True, "modified": False},
{"random_walk": True, "modified": True},
{"random_walk": False, "modified": True},
]

num_samples = 1000

# Case 1: one discrete uniform with one categorical
def model_1():
    """Model with one discrete-uniform site and one 4-way categorical site."""
    x0_dist = dist.DiscreteUniform(10, 12)
    numpyro.sample("x0", x0_dist)
    x1_dist = dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25]))
    numpyro.sample("x1", x1_dist)

for kwargs in mixed_hmc_kwargs:
kernel = HMC(model_1, trajectory_length=1.2)
kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs)
mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False)
key = jax.random.PRNGKey(0)
mcmc.run(key)
samples = mcmc.get_samples()
samples = sample_mixedhmc(model_1, num_samples, **kwargs)

assert jnp.all(
(samples["x0"] >= 10) & (samples["x0"] <= 12)
Expand All @@ -3458,6 +3462,53 @@ def model_1():
(samples["x1"] >= 0) & (samples["x1"] <= 3)
), f"Failed with {kwargs=}"

def model_2():
    """Two categorical sites whose supports have different sizes (4 vs 10)."""
    probs_x0 = 0.25 * jnp.ones((4,))
    probs_x1 = 0.1 * jnp.ones((10,))
    numpyro.sample("x0", dist.Categorical(probs_x0))
    numpyro.sample("x1", dist.Categorical(probs_x1))

# Case 2: 2 categorical with different support lengths
for kwargs in mixed_hmc_kwargs:
    samples = sample_mixedhmc(model_2, num_samples, **kwargs)

    # Every draw must lie inside each categorical's support.
    x0_in_support = (samples["x0"] >= 0) & (samples["x0"] <= 3)
    assert jnp.all(x0_in_support), f"Failed with {kwargs=}"
    x1_in_support = (samples["x1"] >= 0) & (samples["x1"] <= 9)
    assert jnp.all(x1_in_support), f"Failed with {kwargs=}"

def model_3():
    """Two categorical sites with different support sizes, each batched by 3."""
    probs_x0 = 0.25 * jnp.ones((3, 4))
    probs_x1 = 0.1 * jnp.ones((3, 10))
    numpyro.sample("x0", dist.Categorical(probs_x0))
    numpyro.sample("x1", dist.Categorical(probs_x1))

# Case 3: 2 categorical with different support lengths and batched by 3
for kwargs in mixed_hmc_kwargs:
    samples = sample_mixedhmc(model_3, num_samples, **kwargs)

    # Every draw must lie inside each categorical's support.
    x0_in_support = (samples["x0"] >= 0) & (samples["x0"] <= 3)
    assert jnp.all(x0_in_support), f"Failed with {kwargs=}"
    x1_in_support = (samples["x1"] >= 0) & (samples["x1"] <= 9)
    assert jnp.all(x1_in_support), f"Failed with {kwargs=}"

def model_4():
    """One categorical and one discrete-uniform site, both batched by 3."""
    numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((3, 4))))
    numpyro.sample(
        "x1", dist.DiscreteUniform(10 * jnp.ones((3,)), 19 * jnp.ones((3,)))
    )

# Case 4: 1 categorical and 1 discrete uniform, with different support lengths and batched by 3
for kwargs in mixed_hmc_kwargs:
    samples = sample_mixedhmc(model_4, num_samples, **kwargs)

    assert jnp.all(
        (samples["x0"] >= 0) & (samples["x0"] <= 3)
    ), f"Failed with {kwargs=}"
    # DiscreteUniform(10, 19) has inclusive support {10, ..., 19}; the previous
    # upper bound of 20 was looser than the actual support (Case 1 uses the
    # tight inclusive bound for DiscreteUniform(10, 12)).
    assert jnp.all(
        (samples["x1"] >= 10) & (samples["x1"] <= 19)
    ), f"Failed with {kwargs=}"


# Allow running this test module directly (without pytest).
if __name__ == "__main__":
    test_discrete_uniform_with_mixedhmc()

0 comments on commit 17147bb

Please sign in to comment.