diff --git a/numpyro/contrib/funsor/infer_util.py b/numpyro/contrib/funsor/infer_util.py
index ad2e40963..5ee0af278 100644
--- a/numpyro/contrib/funsor/infer_util.py
+++ b/numpyro/contrib/funsor/infer_util.py
@@ -42,6 +42,17 @@ def plate_to_enum_plate():
     numpyro.plate.__new__ = lambda *args, **kwargs: object.__new__(numpyro.plate)
 
 
+def _config_enumerate_fn(site, default):
+    """Helper function used internally in config_enumerate."""
+    if (
+        site["type"] == "sample"
+        and (not site["is_observed"])
+        and site["fn"].has_enumerate_support
+    ):
+        return {"enumerate": site["infer"].get("enumerate", default)}
+    return {}
+
+
 def config_enumerate(fn=None, default="parallel"):
     """
     Configures enumeration for all relevant sites in a NumPyro model.
@@ -69,16 +80,18 @@ def model(*args, **kwargs):
     if fn is None:  # support use as a decorator
         return functools.partial(config_enumerate, default=default)
 
-    def config_fn(site):
-        if (
-            site["type"] == "sample"
-            and (not site["is_observed"])
-            and site["fn"].has_enumerate_support
-        ):
-            return {"enumerate": site["infer"].get("enumerate", default)}
-        return {}
+    return infer_config(fn, functools.partial(_config_enumerate_fn, default=default))
 
-    return infer_config(fn, config_fn)
+
+def _config_kl_fn(site, sites):
+    """Helper function used internally in config_kl."""
+    if (
+        site["type"] == "sample"
+        and (not site["is_observed"])
+        and (sites is None or site["name"] in sites)
+    ):
+        return {"kl": site["infer"].get("kl", "analytic")}
+    return {}
 
 
 def config_kl(fn=None, sites=None):
@@ -107,16 +120,7 @@ def model(*args, **kwargs):
     if fn is None:  # support use as a decorator
         return functools.partial(config_kl, sites=sites)
 
-    def config_fn(site):
-        if (
-            site["type"] == "sample"
-            and (not site["is_observed"])
-            and (sites is None or site["name"] in sites)
-        ):
-            return {"kl": site["infer"].get("kl", "analytic")}
-        return {}
-
-    return infer_config(fn, config_fn)
+    return infer_config(fn, functools.partial(_config_kl_fn, sites=sites))
 
 
 def _get_shift(name):
@@ -225,7 +229,8 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
                 if name.startswith("_time"):
                     time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]])
                     history = max(
-                        history, max(_get_shift(s) for s in dim_to_name.values())
+                        history,
+                        max(_get_shift(s) for s in dim_to_name.values()),
                     )
             if history == 0:
                 log_factors.append(log_prob_factor)
@@ -282,7 +287,8 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
         raise ValueError(
             "Expected the joint log density is a scalar, but got {}. "
" "There seems to be something wrong at the following sites: {}.".format( - result.data.shape, {k.split("__BOUND")[0] for k in result.inputs} + result.data.shape, + {k.split("__BOUND")[0] for k in result.inputs}, ) ) return result, model_trace, log_measures @@ -310,6 +316,11 @@ def model(*args, **kwargs): :return: log of joint density and a corresponding model trace """ result, model_trace, _ = _enum_log_density( - model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add + model, + model_args, + model_kwargs, + params, + funsor.ops.logaddexp, + funsor.ops.add, ) return result.data, model_trace diff --git a/test/test_pickle.py b/test/test_pickle.py index af908f79f..a54479be2 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -12,7 +12,9 @@ from jax.tree_util import tree_all, tree_map import numpyro +from numpyro.contrib.funsor import config_kl import numpyro.distributions as dist +from numpyro.distributions import constraints from numpyro.distributions.constraints import ( boolean, circular, @@ -49,6 +51,7 @@ DiscreteHMCGibbs, MixedHMC, Predictive, + TraceEnum_ELBO, ) from numpyro.infer.autoguide import AutoDelta, AutoDiagonalNormal, AutoNormal @@ -69,6 +72,19 @@ def logistic_regression(): numpyro.sample("obs", dist.Bernoulli(logits=x), obs=batch) +def gmm(data, K): + mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K))) + with numpyro.plate("num_clusters", K, dim=-1): + cluster_means = numpyro.sample("cluster_means", dist.Normal(jnp.arange(K), 1.0)) + with numpyro.plate("data", data.shape[0], dim=-1): + assignments = numpyro.sample( + "assignments", + dist.Categorical(mix_proportions), + infer={"enumerate": "parallel"}, + ) + numpyro.sample("obs", dist.Normal(cluster_means[assignments], 1.0), obs=data) + + @pytest.mark.parametrize("kernel", [BarkerMH, HMC, NUTS, SA]) def test_pickle_hmc(kernel): mcmc = MCMC(kernel(normal_model), num_warmup=10, num_samples=10) @@ -77,6 +93,24 @@ def test_pickle_hmc(kernel): tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples())) +@pytest.mark.parametrize("kernel", [BarkerMH, HMC, NUTS, SA]) +def test_pickle_hmc_enumeration(kernel): + K, N = 3, 1000 + + true_cluster_means = jnp.array([1.0, 5.0, 10.0]) + true_mix_proportions = jnp.array([0.1, 0.3, 0.6]) + cluster_assignments = dist.Categorical(true_mix_proportions).sample( + random.PRNGKey(0), (N,) + ) + data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample( + random.PRNGKey(1) + ) + mcmc = MCMC(kernel(gmm), num_warmup=10, num_samples=10) + mcmc.run(random.PRNGKey(0), data, K) + pickled_mcmc = pickle.loads(pickle.dumps(mcmc)) + tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples())) + + @pytest.mark.parametrize("kernel", [DiscreteHMCGibbs, MixedHMC]) def test_pickle_discrete_hmc(kernel): mcmc = MCMC(kernel(HMC(bernoulli_model)), num_warmup=10, num_samples=10) @@ -176,3 +210,30 @@ def test_mcmc_pickle_post_warmup(): pickled_mcmc = pickle.loads(pickle.dumps(mcmc)) pickled_mcmc.post_warmup_state = pickled_mcmc.last_state pickled_mcmc.run(random.PRNGKey(1)) + + +def bernoulli_regression(data): + f = numpyro.sample("beta", dist.Beta(1.0, 1.0)) + with numpyro.plate("N", len(data)): + numpyro.sample("obs", dist.Bernoulli(f), obs=data) + + +def test_beta_bernoulli(): + data = jnp.array([1.0] * 8 + [0.0] * 2) + + def guide(data): + alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive) + beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive) + numpyro.sample("beta", 
+
+    pickled_model = pickle.loads(pickle.dumps(config_kl(bernoulli_regression)))
+    optim = numpyro.optim.Adam(1e-2)
+    svi = SVI(config_kl(bernoulli_regression), guide, optim, TraceEnum_ELBO())
+    svi_result = svi.run(random.PRNGKey(0), 3, data)
+    params = svi_result.params
+
+    svi = SVI(pickled_model, guide, optim, TraceEnum_ELBO())
+    svi_result = svi.run(random.PRNGKey(0), 3, data)
+    pickled_params = svi_result.params
+
+    tree_all(tree_map(assert_allclose, params, pickled_params))
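
Note on the config_enumerate/config_kl refactor above: pickle serializes a function by its
qualified name, so a config_fn defined inside config_enumerate or config_kl cannot be
serialized, while a functools.partial over a module-level helper can. Below is a minimal
standalone sketch of that difference; the _helper and make_closure names are illustrative
only and are not part of the patch.

    import functools
    import pickle


    def _helper(site, default):
        # Module-level function: picklable by qualified name.
        return {"enumerate": default}


    def make_closure(default):
        # Function defined inside another function: not picklable.
        def config_fn(site):
            return {"enumerate": default}

        return config_fn


    # A partial over a module-level helper round-trips through pickle.
    fn = pickle.loads(pickle.dumps(functools.partial(_helper, default="parallel")))
    assert fn(site=None) == {"enumerate": "parallel"}

    # The closure variant fails, typically with
    # "AttributeError: Can't pickle local object 'make_closure.<locals>.config_fn'".
    try:
        pickle.dumps(make_closure("parallel"))
    except (AttributeError, pickle.PicklingError) as exc:
        print(exc)

This is what allows the new tests to round-trip config_kl(bernoulli_regression), and MCMC
objects holding an enumerated model, through pickle.loads(pickle.dumps(...)).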