From 59de90f50a12063e744417ae9c458aca4184ca95 Mon Sep 17 00:00:00 2001 From: Deathn0t Date: Wed, 24 Jul 2024 18:12:16 +0200 Subject: [PATCH 1/8] Fixing issue Samples are outside the support for DiscreteUniform distribution #1834 --- numpyro/infer/hmc_gibbs.py | 7 +++++++ numpyro/infer/mixed_hmc.py | 6 ++++++ 2 files changed, 13 insertions(+) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index f6b95389b..d3cd0e9e4 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -434,6 +434,13 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): and site["fn"].has_enumerate_support and not site["is_observed"] } + self._support_enumerates = { + name: site["fn"].enumerate_support(False) + for name, site in self._prototype_trace.items() + if site["type"] == "sample" + and site["fn"].has_enumerate_support + and not site["is_observed"] + } self._gibbs_sites = [ name for name, site in self._prototype_trace.items() diff --git a/numpyro/infer/mixed_hmc.py b/numpyro/infer/mixed_hmc.py index 3e3d2ae59..6a558293f 100644 --- a/numpyro/infer/mixed_hmc.py +++ b/numpyro/infer/mixed_hmc.py @@ -6,6 +6,7 @@ from jax import grad, jacfwd, lax, random from jax.flatten_util import ravel_pytree +import jax import jax.numpy as jnp from numpyro.infer.hmc import momentum_generator @@ -301,6 +302,11 @@ def body_fn(i, vals): adapt_state=adapt_state, ) + z_discrete = jax.tree.map( + lambda idx, support: support[idx], + z_discrete, + self._support_enumerates, + ) z = {**z_discrete, **hmc_state.z} return MixedHMCState(z, hmc_state, rng_key, accept_prob) From 0cd448f051ad178582c4bb7588ffd45491e4cabd Mon Sep 17 00:00:00 2001 From: Deathn0t Date: Fri, 26 Jul 2024 15:32:03 +0200 Subject: [PATCH 2/8] updated with enumerate support as padded zeros arrays --- numpyro/infer/hmc_gibbs.py | 89 ++++++++++++++++++++++++++++++-------- numpyro/infer/mixed_hmc.py | 6 +-- 2 files changed, 73 insertions(+), 22 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index d3cd0e9e4..e45e18805 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -192,12 +192,22 @@ def __getstate__(self): def _discrete_gibbs_proposal_body_fn( - z_init_flat, unravel_fn, pe_init, potential_fn, idx, i, val + z_init_flat, + unravel_fn, + pe_init, + potential_fn, + idx, + i, + val, + support_size, + support_enumerate, ): rng_key, z, pe, log_weight_sum = val rng_key, rng_transition = random.split(rng_key) - proposal = jnp.where(i >= z_init_flat[idx], i + 1, i) - z_new_flat = z_init_flat.at[idx].set(proposal) + proposal_index = jnp.where( + support_enumerate[i] == z_init_flat[idx], support_size - 1, i + ) + z_new_flat = z_init_flat.at[idx].set(support_enumerate[proposal_index]) z_new = unravel_fn(z_new_flat) pe_new = potential_fn(z_new) log_weight_new = pe_init - pe_new @@ -216,7 +226,9 @@ def _discrete_gibbs_proposal_body_fn( return rng_key, z, pe, log_weight_sum -def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size): +def _discrete_gibbs_proposal( + rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate +): # idx: current index of `z_discrete_flat` to update # support_size: support size of z_discrete at the index idx @@ -234,6 +246,8 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support pe, potential_fn, idx, + support_size=support_size, + support_enumerate=support_enumerate, ) init_val = (rng_key, z_discrete, pe, jnp.array(0.0)) rng_key, z_new, pe_new, _ = fori_loop(0, support_size - 1, body_fn, init_val) @@ -242,7 +256,14 @@ def _discrete_gibbs_proposal(rng_key, z_discrete, pe, potential_fn, idx, support def _discrete_modified_gibbs_proposal( - rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0 + rng_key, + z_discrete, + pe, + potential_fn, + idx, + support_size, + support_enumerate, + stay_prob=0.0, ): assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1 z_discrete_flat, unravel_fn = ravel_pytree(z_discrete) @@ -253,6 +274,8 @@ def _discrete_modified_gibbs_proposal( pe, potential_fn, idx, + support_size=support_size, + support_enumerate=support_enumerate, ) # like gibbs_step but here, weight of the current value is 0 init_val = (rng_key, z_discrete, pe, jnp.array(-jnp.inf)) @@ -276,12 +299,14 @@ def _discrete_modified_gibbs_proposal( return rng_key, z_new, pe_new, log_accept_ratio -def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_size): +def _discrete_rw_proposal( + rng_key, z_discrete, pe, potential_fn, idx, support_size, support_enumerate +): rng_key, rng_proposal = random.split(rng_key, 2) z_discrete_flat, unravel_fn = ravel_pytree(z_discrete) proposal = random.randint(rng_proposal, (), minval=0, maxval=support_size) - z_new_flat = z_discrete_flat.at[idx].set(proposal) + z_new_flat = z_discrete_flat.at[idx].set(support_enumerate[proposal]) z_new = unravel_fn(z_new_flat) pe_new = potential_fn(z_new) log_accept_ratio = pe - pe_new @@ -289,15 +314,26 @@ def _discrete_rw_proposal(rng_key, z_discrete, pe, potential_fn, idx, support_si def _discrete_modified_rw_proposal( - rng_key, z_discrete, pe, potential_fn, idx, support_size, stay_prob=0.0 + rng_key, + z_discrete, + pe, + potential_fn, + idx, + support_size, + support_enumerate, + stay_prob=0.0, ): assert isinstance(stay_prob, float) and stay_prob >= 0.0 and stay_prob < 1 rng_key, rng_proposal, rng_stay = random.split(rng_key, 3) z_discrete_flat, unravel_fn = ravel_pytree(z_discrete) i = random.randint(rng_proposal, (), minval=0, maxval=support_size - 1) - proposal = jnp.where(i >= z_discrete_flat[idx], i + 1, i) - proposal = jnp.where(random.bernoulli(rng_stay, stay_prob), idx, proposal) + proposal_index = jnp.where( + support_enumerate[i] == z_discrete_flat[idx], support_size - 1, i + ) + proposal = jnp.where( + random.bernoulli(rng_stay, stay_prob), idx, support_enumerate[proposal_index] + ) z_new_flat = z_discrete_flat.at[idx].set(proposal) z_new = unravel_fn(z_new_flat) pe_new = potential_fn(z_new) @@ -434,13 +470,32 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): and site["fn"].has_enumerate_support and not site["is_observed"] } - self._support_enumerates = { - name: site["fn"].enumerate_support(False) - for name, site in self._prototype_trace.items() - if site["type"] == "sample" - and site["fn"].has_enumerate_support - and not site["is_observed"] - } + max_length_support_enumerates = max( + ( + site["fn"].enumerate_support(False).shape[0] + for site in self._prototype_trace.values() + if site["type"] == "sample" + and site["fn"].has_enumerate_support + and not site["is_observed"] + ) + ) + # All support_enumerates should have the same length to be used in the loop + # Each support is padded with zeros to have the same length + self._support_enumerates = np.zeros( + (len(self._support_sizes), max_length_support_enumerates), dtype=int + ) + for i, (name, site) in enumerate(self._prototype_trace.items()): + if ( + site["type"] == "sample" + and site["fn"].has_enumerate_support + and not site["is_observed"] + ): + self._support_enumerates[ + i, : site["fn"].enumerate_support(False).shape[0] + ] = site["fn"].enumerate_support(False) + self._support_enumerates = jnp.asarray( + self._support_enumerates, dtype=jnp.int32 + ) self._gibbs_sites = [ name for name, site in self._prototype_trace.items() diff --git a/numpyro/infer/mixed_hmc.py b/numpyro/infer/mixed_hmc.py index 6a558293f..7817a5159 100644 --- a/numpyro/infer/mixed_hmc.py +++ b/numpyro/infer/mixed_hmc.py @@ -139,6 +139,7 @@ def update_discrete( partial(potential_fn, z_hmc=hmc_state.z), idx, self._support_sizes_flat[idx], + self._support_enumerates[idx], ) # Algo 1, line 20: depending on reject or refract, we will update # the discrete variable and its corresponding kinetic energy. In case of @@ -302,11 +303,6 @@ def body_fn(i, vals): adapt_state=adapt_state, ) - z_discrete = jax.tree.map( - lambda idx, support: support[idx], - z_discrete, - self._support_enumerates, - ) z = {**z_discrete, **hmc_state.z} return MixedHMCState(z, hmc_state, rng_key, accept_prob) From e14eea743e74b1a613285599efea247e2f1db518 Mon Sep 17 00:00:00 2001 From: Deathn0t Date: Mon, 29 Jul 2024 10:16:59 +0200 Subject: [PATCH 3/8] updating the logical using ravel to maintain a consistant behaviour --- numpyro/infer/hmc_gibbs.py | 42 +++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index e45e18805..823abab87 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -5,6 +5,7 @@ import copy from functools import partial +import jax import numpy as np from jax import device_put, grad, jacfwd, random, value_and_grad @@ -470,6 +471,11 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): and site["fn"].has_enumerate_support and not site["is_observed"] } + + # All support_enumerates should have the same length to be used in the loop + # Each support is padded with zeros to have the same length + # ravel is used to maintain a consistant behaviour with `support_sizes` + max_length_support_enumerates = max( ( site["fn"].enumerate_support(False).shape[0] @@ -479,23 +485,25 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): and not site["is_observed"] ) ) - # All support_enumerates should have the same length to be used in the loop - # Each support is padded with zeros to have the same length - self._support_enumerates = np.zeros( - (len(self._support_sizes), max_length_support_enumerates), dtype=int - ) - for i, (name, site) in enumerate(self._prototype_trace.items()): - if ( - site["type"] == "sample" - and site["fn"].has_enumerate_support - and not site["is_observed"] - ): - self._support_enumerates[ - i, : site["fn"].enumerate_support(False).shape[0] - ] = site["fn"].enumerate_support(False) - self._support_enumerates = jnp.asarray( - self._support_enumerates, dtype=jnp.int32 - ) + + support_enumerates = {} + for name, support_size in self._support_sizes.items(): + site = self._prototype_trace[name] + enumerate_support = site["fn"].enumerate_support(False) + padded_enumerate_support = np.pad( + enumerate_support, + (0, max_length_support_enumerates - enumerate_support.shape[0]), + ) + padded_enumerate_support = np.broadcast_to( + padded_enumerate_support, + support_size.shape + (max_length_support_enumerates,), + ) + support_enumerates[name] = padded_enumerate_support + + self._support_enumerates = jax.vmap( + lambda x: ravel_pytree(x)[0] , in_axes=0, out_axes=1 + )(support_enumerates) + self._gibbs_sites = [ name for name, site in self._prototype_trace.items() From 10548fe61acd3ebb45b73843fc68d3efef0d98ce Mon Sep 17 00:00:00 2001 From: Deathn0t Date: Mon, 29 Jul 2024 17:00:16 +0200 Subject: [PATCH 4/8] iterating of support_sizes --- numpyro/infer/hmc_gibbs.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 823abab87..f513c48ec 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -476,16 +476,8 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): # Each support is padded with zeros to have the same length # ravel is used to maintain a consistant behaviour with `support_sizes` - max_length_support_enumerates = max( - ( - site["fn"].enumerate_support(False).shape[0] - for site in self._prototype_trace.values() - if site["type"] == "sample" - and site["fn"].has_enumerate_support - and not site["is_observed"] - ) - ) - + max_length_support_enumerates = max(size for size in self._support_sizes.values()) + support_enumerates = {} for name, support_size in self._support_sizes.items(): site = self._prototype_trace[name] From f4b9d99f62b81801e7280077063f33b4c0a2e73f Mon Sep 17 00:00:00 2001 From: Deathn0t Date: Tue, 30 Jul 2024 09:32:04 +0200 Subject: [PATCH 5/8] fixed lint issues --- numpyro/infer/hmc_gibbs.py | 8 +++++--- numpyro/infer/mixed_hmc.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index f513c48ec..a94f0eac5 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -5,9 +5,9 @@ import copy from functools import partial -import jax import numpy as np +import jax from jax import device_put, grad, jacfwd, random, value_and_grad from jax.flatten_util import ravel_pytree import jax.numpy as jnp @@ -476,7 +476,9 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): # Each support is padded with zeros to have the same length # ravel is used to maintain a consistant behaviour with `support_sizes` - max_length_support_enumerates = max(size for size in self._support_sizes.values()) + max_length_support_enumerates = max( + size for size in self._support_sizes.values() + ) support_enumerates = {} for name, support_size in self._support_sizes.items(): @@ -493,7 +495,7 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): support_enumerates[name] = padded_enumerate_support self._support_enumerates = jax.vmap( - lambda x: ravel_pytree(x)[0] , in_axes=0, out_axes=1 + lambda x: ravel_pytree(x)[0], in_axes=0, out_axes=1 )(support_enumerates) self._gibbs_sites = [ diff --git a/numpyro/infer/mixed_hmc.py b/numpyro/infer/mixed_hmc.py index 7817a5159..e163fb75e 100644 --- a/numpyro/infer/mixed_hmc.py +++ b/numpyro/infer/mixed_hmc.py @@ -6,7 +6,6 @@ from jax import grad, jacfwd, lax, random from jax.flatten_util import ravel_pytree -import jax import jax.numpy as jnp from numpyro.infer.hmc import momentum_generator From fe46ba1e7b833d161590c492cb1ab40dbe763397 Mon Sep 17 00:00:00 2001 From: Deathn0t Date: Tue, 30 Jul 2024 09:46:10 +0200 Subject: [PATCH 6/8] adding test for mixed hmc sampling of distribution discrete uniform --- test/test_distributions.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/test/test_distributions.py b/test/test_distributions.py index e10fd7248..6834194be 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3423,3 +3423,41 @@ def test_gaussian_random_walk_linear_recursive_equivalence(): x2 = dist2.sample(random.PRNGKey(7)) assert jnp.allclose(x1, x2.squeeze()) assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2)) + + +def test_discrete_uniform_with_mixedhmc(): + import numpyro + import numpyro.distributions as dist + from numpyro.infer import HMC, MCMC, MixedHMC + + def model_1(): + numpyro.sample("x0", dist.DiscreteUniform(10, 12)) + numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25]))) + + mixed_hmc_kwargs = [ + {"random_walk": False, "modified": False}, + {"random_walk": True, "modified": False}, + {"random_walk": True, "modified": True}, + {"random_walk": False, "modified": True}, + ] + + num_samples = 1000 + + for kwargs in mixed_hmc_kwargs: + kernel = HMC(model_1, trajectory_length=1.2) + kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs) + mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False) + key = jax.random.PRNGKey(0) + mcmc.run(key) + samples = mcmc.get_samples() + + assert jnp.all( + (samples["x0"] >= 10) & (samples["x0"] <= 12) + ), f"Failed with {kwargs=}" + assert jnp.all( + (samples["x1"] >= 0) & (samples["x1"] <= 3) + ), f"Failed with {kwargs=}" + + +if __name__ == "__main__": + test_discrete_uniform_with_mixedhmc() From 17147bb9e1ca9ac3f26b0b7665b03530948ec871 Mon Sep 17 00:00:00 2001 From: Deathn0t Date: Thu, 5 Sep 2024 11:02:04 +0200 Subject: [PATCH 7/8] hmc_gibbs updated to work with different support sizes and batching, tests are passing when using changes from PR #1859 --- numpyro/infer/hmc_gibbs.py | 22 +++++------ test/test_distributions.py | 75 ++++++++++++++++++++++++++++++++------ 2 files changed, 73 insertions(+), 24 deletions(-) diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index a94f0eac5..5a9c20beb 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -476,26 +476,24 @@ def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs): # Each support is padded with zeros to have the same length # ravel is used to maintain a consistant behaviour with `support_sizes` - max_length_support_enumerates = max( - size for size in self._support_sizes.values() + max_length_support_enumerates = np.max( + [size for size in self._support_sizes.values()] ) support_enumerates = {} for name, support_size in self._support_sizes.items(): site = self._prototype_trace[name] - enumerate_support = site["fn"].enumerate_support(False) - padded_enumerate_support = np.pad( - enumerate_support, - (0, max_length_support_enumerates - enumerate_support.shape[0]), - ) - padded_enumerate_support = np.broadcast_to( - padded_enumerate_support, - support_size.shape + (max_length_support_enumerates,), - ) + enumerate_support = site["fn"].enumerate_support(True).T + # Only the last dimension that corresponds to support size is padded + pad_width = [(0, 0) for _ in range(len(enumerate_support.shape) - 1)] + [ + (0, max_length_support_enumerates - enumerate_support.shape[-1]) + ] + padded_enumerate_support = np.pad(enumerate_support, pad_width) + support_enumerates[name] = padded_enumerate_support self._support_enumerates = jax.vmap( - lambda x: ravel_pytree(x)[0], in_axes=0, out_axes=1 + lambda x: ravel_pytree(x)[0], in_axes=len(support_size.shape), out_axes=1 )(support_enumerates) self._gibbs_sites = [ diff --git a/test/test_distributions.py b/test/test_distributions.py index 6834194be..abf4f6b39 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3430,10 +3430,16 @@ def test_discrete_uniform_with_mixedhmc(): import numpyro.distributions as dist from numpyro.infer import HMC, MCMC, MixedHMC - def model_1(): - numpyro.sample("x0", dist.DiscreteUniform(10, 12)) - numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25]))) - + def sample_mixedhmc(model_fn, num_samples, **kwargs): + kernel = HMC(model_fn, trajectory_length=1.2) + kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs) + mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False) + key = jax.random.PRNGKey(0) + mcmc.run(key) + samples = mcmc.get_samples() + return samples + + num_samples = 1000 mixed_hmc_kwargs = [ {"random_walk": False, "modified": False}, {"random_walk": True, "modified": False}, @@ -3441,15 +3447,13 @@ def model_1(): {"random_walk": False, "modified": True}, ] - num_samples = 1000 - + # Case 1: one discrete uniform with one categorical + def model_1(): + numpyro.sample("x0", dist.DiscreteUniform(10, 12)) + numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25]))) + for kwargs in mixed_hmc_kwargs: - kernel = HMC(model_1, trajectory_length=1.2) - kernel = MixedHMC(kernel, num_discrete_updates=20, **kwargs) - mcmc = MCMC(kernel, num_warmup=100, num_samples=num_samples, progress_bar=False) - key = jax.random.PRNGKey(0) - mcmc.run(key) - samples = mcmc.get_samples() + samples = sample_mixedhmc(model_1, num_samples, **kwargs) assert jnp.all( (samples["x0"] >= 10) & (samples["x0"] <= 12) @@ -3458,6 +3462,53 @@ def model_1(): (samples["x1"] >= 0) & (samples["x1"] <= 3) ), f"Failed with {kwargs=}" + def model_2(): + numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((4,)))) + numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((10,)))) + + # Case 2: 2 categorical with different support lengths + for kwargs in mixed_hmc_kwargs: + samples = sample_mixedhmc(model_2, num_samples, **kwargs) + + assert jnp.all( + (samples["x0"] >= 0) & (samples["x0"] <= 3) + ), f"Failed with {kwargs=}" + assert jnp.all( + (samples["x1"] >= 0) & (samples["x1"] <= 9) + ), f"Failed with {kwargs=}" + + def model_3(): + numpyro.sample("x0", dist.Categorical(0.25 * jnp.ones((3, 4)))) + numpyro.sample("x1", dist.Categorical(0.1 * jnp.ones((3, 10)))) + + # Case 3: 2 categorical with different support lengths and batched by 3 + for kwargs in mixed_hmc_kwargs: + samples = sample_mixedhmc(model_3, num_samples, **kwargs) + + assert jnp.all( + (samples["x0"] >= 0) & (samples["x0"] <= 3) + ), f"Failed with {kwargs=}" + assert jnp.all( + (samples["x1"] >= 0) & (samples["x1"] <= 9) + ), f"Failed with {kwargs=}" + + def model_4(): + dist0 = dist.Categorical(0.25 * jnp.ones((3, 4))) + numpyro.sample("x0", dist0) + dist1 = dist.DiscreteUniform(10 * jnp.ones((3,)), 19 * jnp.ones((3,))) + numpyro.sample("x1", dist1) + + # Case 4: 1 categorical with different support lengths and batched by 3 + for kwargs in mixed_hmc_kwargs: + samples = sample_mixedhmc(model_4, num_samples, **kwargs) + + assert jnp.all( + (samples["x0"] >= 0) & (samples["x0"] <= 3) + ), f"Failed with {kwargs=}" + assert jnp.all( + (samples["x1"] >= 10) & (samples["x1"] <= 20) + ), f"Failed with {kwargs=}" + if __name__ == "__main__": test_discrete_uniform_with_mixedhmc() From b1f36dc71cd3fb0743f5cacb28c73ef2e66b9e94 Mon Sep 17 00:00:00 2001 From: Deathn0t Date: Mon, 9 Sep 2024 16:51:12 +0200 Subject: [PATCH 8/8] applying format and changes similar to PR 1859 --- numpyro/distributions/discrete.py | 6 +++--- test/test_distributions.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index 7d7358a5d..537f21504 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -469,9 +469,9 @@ def enumerate_support(self, expand=True): raise NotImplementedError( "Inhomogeneous `high` not supported by `enumerate_support`." ) - values = (self.low + jnp.arange(np.amax(self.high - self.low) + 1)).reshape( - (-1,) + (1,) * len(self.batch_shape) - ) + low = jnp.reshape(self.low, -1)[0] + high = jnp.reshape(self.high, -1)[0] + values = jnp.arange(low, high + 1).reshape((-1,) + (1,) * len(self.batch_shape)) if expand: values = jnp.broadcast_to(values, values.shape[:1] + self.batch_shape) return values diff --git a/test/test_distributions.py b/test/test_distributions.py index abf4f6b39..c1f2e5f5b 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3438,7 +3438,7 @@ def sample_mixedhmc(model_fn, num_samples, **kwargs): mcmc.run(key) samples = mcmc.get_samples() return samples - + num_samples = 1000 mixed_hmc_kwargs = [ {"random_walk": False, "modified": False}, @@ -3451,7 +3451,7 @@ def sample_mixedhmc(model_fn, num_samples, **kwargs): def model_1(): numpyro.sample("x0", dist.DiscreteUniform(10, 12)) numpyro.sample("x1", dist.Categorical(np.asarray([0.25, 0.25, 0.25, 0.25]))) - + for kwargs in mixed_hmc_kwargs: samples = sample_mixedhmc(model_1, num_samples, **kwargs)