From 17147bb9e1ca9ac3f26b0b7665b03530948ec871 Mon Sep 17 00:00:00 2001 From: Deathn0t Date: Thu, 5 Sep 2024 11:02:04 +0200 Subject: [PATCH] hmc_gibbs updated to work with different support sizes and batching, tests are passing when using changes from PR #1859 --- numpyro/infer/hmc_gibbs.py | 22 +++++------ test/test_distributions.py | 75 ++++++++++++++++++++++++++++++++------ 2 files changed, 73 insertions(+), 24 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index a94f0eac5..5a9c20beb 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -476,26 +476,24 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): # Each support is padded with zeros to have the same length # ravel is used to maintain a consistant behaviour with `support_sizes` - max_length_support_enumerates = max( - size for size in self._support_sizes.values() + max_length_support_enumerates = np.max( + [size for size in self._support_sizes.values()] ) support_enumerates = {} for name, support_size in self._support_sizes.items(): site = self._prototype_trace[name] - enumerate_support = site["fn"].enumerate_support(False) - padded_enumerate_support = np.pad( - enumerate_support, - (0, max_length_support_enumerates - enumerate_support.shape[0]), - ) - padded_enumerate_support = np.broadcast_to( - padded_enumerate_support, - support_size.shape + (max_length_support_enumerates,), - ) + enumerate_support = site["fn"].enumerate_support(True).T + # Only the last dimension that corresponds to support size is padded + pad_width = [(0, 0) for _ in range(len(enumerate_support.shape) - 1)] + [ + (0, max_length_support_enumerates - enumerate_support.shape[-1]) + ] + padded_enumerate_support = np.pad(enumerate_support, pad_width) + support_enumerates[name] = padded_enumerate_support self._support_enumerates = jax.vmap( - lambda x: ravel_pytree(x)[0], in_axes=0, out_axes=1 + lambda x: ravel_pytree(x)[0], in_axes=len(support_size.shape), out_axes=1 )(support_enumerates) self._gibbs_sites = [ diff --git a/test/test_distributions.py b/test/test_distributions.py index 6834194be..abf4f6b39 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3430,10 +3430,16 @@ def test_discrete_uniform_with_mixedhmc(): import numpyro.distributions as dist from numpyro.infer import HMC, MCMC, MixedHMC - def model_1(): - numpyro.sample("x0", dist.DiscreteUniform(10, 12)) - numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25]))) - + def sample_mixedhmc(model_fn, num_samples, **kwargs): + kernel = HMC(model_fn, trajectory_length=1.2) + kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs) + mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False) + key = jax.random.PRNGKey(0) + mcmc.run(key) + samples = mcmc.get_samples() + return samples + + num_samples = 1000 mixed_hmc_kwargs = [ {"random_walk": False, "modified": False}, {"random_walk": True, "modified": False}, @@ -3441,15 +3447,13 @@ def model_1(): {"random_walk": False, "modified": True}, ] - num_samples = 1000 - + # Case 1: one discrete uniform with one categorical + def model_1(): + numpyro.sample("x0", dist.DiscreteUniform(10, 12)) + numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25]))) + for kwargs in mixed_hmc_kwargs: - kernel = HMC(model_1, trajectory_length=1.2) - kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs) - mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False) - key = jax.random.PRNGKey(0) - mcmc.run(key) - samples = mcmc.get_samples() + samples = sample_mixedhmc(model_1, num_samples, **kwargs) assert jnp.all( (samples["x0"] >= 10) & (samples["x0"] <= 12) @@ -3458,6 +3462,53 @@ def model_1(): (samples["x1"] >= 0) & (samples["x1"] <= 3) ), f"Failed with {kwargs=}" + def model_2(): + numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((4,)))) + numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((10,)))) + + # Case 2: 2 categorical with different support lengths + for kwargs in mixed_hmc_kwargs: + samples = sample_mixedhmc(model_2, num_samples, **kwargs) + + assert jnp.all( + (samples["x0"] >= 0) & (samples["x0"] <= 3) + ), f"Failed with {kwargs=}" + assert jnp.all( + (samples["x1"] >= 0) & (samples["x1"] <= 9) + ), f"Failed with {kwargs=}" + + def model_3(): + numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((3, 4)))) + numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((3, 10)))) + + # Case 3: 2 categorical with different support lengths and batched by 3 + for kwargs in mixed_hmc_kwargs: + samples = sample_mixedhmc(model_3, num_samples, **kwargs) + + assert jnp.all( + (samples["x0"] >= 0) & (samples["x0"] <= 3) + ), f"Failed with {kwargs=}" + assert jnp.all( + (samples["x1"] >= 0) & (samples["x1"] <= 9) + ), f"Failed with {kwargs=}" + + def model_4(): + dist0 = dist.Categorical(0.25 * jnp.ones((3, 4))) + numpyro.sample("x0", dist0) + dist1 = dist.DiscreteUniform(10 * jnp.ones((3,)), 19 * jnp.ones((3,))) + numpyro.sample("x1", dist1) + + # Case 4: 1 categorical with different support lengths and batched by 3 + for kwargs in mixed_hmc_kwargs: + samples = sample_mixedhmc(model_4, num_samples, **kwargs) + + assert jnp.all( + (samples["x0"] >= 0) & (samples["x0"] <= 3) + ), f"Failed with {kwargs=}" + assert jnp.all( + (samples["x1"] >= 10) & (samples["x1"] <= 20) + ), f"Failed with {kwargs=}" + if __name__ == "__main__": test_discrete_uniform_with_mixedhmc()