Support pickling MCMC objects with enumeration #1577

Merged · 4 commits · Apr 19, 2023
55 changes: 33 additions & 22 deletions numpyro/contrib/funsor/infer_util.py
@@ -42,6 +42,17 @@ def plate_to_enum_plate():
         numpyro.plate.__new__ = lambda *args, **kwargs: object.__new__(numpyro.plate)


+def _config_enumerate_fn(site, default):
+    """helper function used internally in config_enumerate"""
+    if (
+        site["type"] == "sample"
+        and (not site["is_observed"])
+        and site["fn"].has_enumerate_support
+    ):
+        return {"enumerate": site["infer"].get("enumerate", default)}
+    return {}
+
+
 def config_enumerate(fn=None, default="parallel"):
     """
     Configures enumeration for all relevant sites in a NumPyro model.
@@ -69,16 +80,18 @@ def model(*args, **kwargs):
     if fn is None:  # support use as a decorator
         return functools.partial(config_enumerate, default=default)

-    def config_fn(site):
-        if (
-            site["type"] == "sample"
-            and (not site["is_observed"])
-            and site["fn"].has_enumerate_support
-        ):
-            return {"enumerate": site["infer"].get("enumerate", default)}
-        return {}
+    return infer_config(fn, functools.partial(_config_enumerate_fn, default=default))

-    return infer_config(fn, config_fn)

+def _config_kl_fn(site, sites):
+    """helper function used internally in config_kl"""
+    if (
+        site["type"] == "sample"
+        and (not site["is_observed"])
+        and (sites is None or site["name"] in sites)
+    ):
+        return {"kl": site["infer"].get("kl", "analytic")}
+    return {}


 def config_kl(fn=None, sites=None):
@@ -107,16 +120,7 @@ def model(*args, **kwargs):
     if fn is None:  # support use as a decorator
         return functools.partial(config_kl, sites=sites)

-    def config_fn(site):
-        if (
-            site["type"] == "sample"
-            and (not site["is_observed"])
-            and (sites is None or site["name"] in sites)
-        ):
-            return {"kl": site["infer"].get("kl", "analytic")}
-        return {}
-
-    return infer_config(fn, config_fn)
+    return infer_config(fn, functools.partial(_config_kl_fn, sites=sites))


 def _get_shift(name):
@@ -225,7 +229,8 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
                 if name.startswith("_time"):
                     time_dim = funsor.Variable(name, funsor.Bint[log_prob.shape[dim]])
                     history = max(
-                        history, max(_get_shift(s) for s in dim_to_name.values())
+                        history,
+                        max(_get_shift(s) for s in dim_to_name.values()),
                     )
                     if history == 0:
                         log_factors.append(log_prob_factor)
@@ -282,7 +287,8 @@ def _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op):
         raise ValueError(
             "Expected the joint log density is a scalar, but got {}. "
             "There seems to be something wrong at the following sites: {}.".format(
-                result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}
+                result.data.shape,
+                {k.split("__BOUND")[0] for k in result.inputs},
             )
         )
     return result, model_trace, log_measures
@@ -310,6 +316,11 @@ def model(*args, **kwargs):
     :return: log of joint density and a corresponding model trace
     """
     result, model_trace, _ = _enum_log_density(
-        model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
+        model,
+        model_args,
+        model_kwargs,
+        params,
+        funsor.ops.logaddexp,
+        funsor.ops.add,
    )
     return result.data, model_trace
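The refactor above exists because `pickle` serializes a function by its qualified name: a `config_fn` defined inside `config_enumerate` or `config_kl` is a local object that `pickle` cannot resolve, so any MCMC or SVI object holding a reference to it fails to serialize. Hoisting the logic into the module-level `_config_enumerate_fn` / `_config_kl_fn` and binding arguments with `functools.partial` keeps the behavior identical while making the result picklable. A minimal standalone sketch of the difference (illustrative names, not part of the diff):

```python
import functools
import pickle


def module_level_fn(site, default):
    # Module-level: pickle stores a reference by qualified name.
    return {"enumerate": default}


def make_closure(default):
    def config_fn(site):  # local object: pickle cannot look it up by name
        return {"enumerate": default}

    return config_fn


# A functools.partial over a module-level function round-trips fine.
restored = pickle.loads(pickle.dumps(functools.partial(module_level_fn, default="parallel")))
assert restored(site={}) == {"enumerate": "parallel"}

# The closure raises, e.g. "Can't pickle local object 'make_closure.<locals>.config_fn'".
try:
    pickle.dumps(make_closure("parallel"))
except (AttributeError, pickle.PicklingError) as exc:
    print(exc)
```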
61 changes: 61 additions & 0 deletions test/test_pickle.py
@@ -12,7 +12,9 @@
 from jax.tree_util import tree_all, tree_map

 import numpyro
+from numpyro.contrib.funsor import config_kl
 import numpyro.distributions as dist
+from numpyro.distributions import constraints
 from numpyro.distributions.constraints import (
     boolean,
     circular,
@@ -49,6 +51,7 @@
     DiscreteHMCGibbs,
     MixedHMC,
     Predictive,
+    TraceEnum_ELBO,
 )
 from numpyro.infer.autoguide import AutoDelta, AutoDiagonalNormal, AutoNormal

@@ -69,6 +72,19 @@ def logistic_regression():
         numpyro.sample("obs", dist.Bernoulli(logits=x), obs=batch)


+def gmm(data, K):
+    mix_proportions = numpyro.sample("phi", dist.Dirichlet(jnp.ones(K)))
+    with numpyro.plate("num_clusters", K, dim=-1):
+        cluster_means = numpyro.sample("cluster_means", dist.Normal(jnp.arange(K), 1.0))
+    with numpyro.plate("data", data.shape[0], dim=-1):
+        assignments = numpyro.sample(
+            "assignments",
+            dist.Categorical(mix_proportions),
+            infer={"enumerate": "parallel"},
+        )
+        numpyro.sample("obs", dist.Normal(cluster_means[assignments], 1.0), obs=data)
+
+
 @pytest.mark.parametrize("kernel", [BarkerMH, HMC, NUTS, SA])
 def test_pickle_hmc(kernel):
     mcmc = MCMC(kernel(normal_model), num_warmup=10, num_samples=10)
@@ -77,6 +93,24 @@ def test_pickle_hmc(kernel):
     tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()))


+@pytest.mark.parametrize("kernel", [BarkerMH, HMC, NUTS, SA])
+def test_pickle_hmc_enumeration(kernel):
+    K, N = 3, 1000
+
+    true_cluster_means = jnp.array([1.0, 5.0, 10.0])
+    true_mix_proportions = jnp.array([0.1, 0.3, 0.6])
+    cluster_assignments = dist.Categorical(true_mix_proportions).sample(
+        random.PRNGKey(0), (N,)
+    )
+    data = dist.Normal(true_cluster_means[cluster_assignments], 1.0).sample(
+        random.PRNGKey(1)
+    )
+    mcmc = MCMC(kernel(gmm), num_warmup=10, num_samples=10)
+    mcmc.run(random.PRNGKey(0), data, K)
+    pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
+    tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()))
+
+
 @pytest.mark.parametrize("kernel", [DiscreteHMCGibbs, MixedHMC])
 def test_pickle_discrete_hmc(kernel):
     mcmc = MCMC(kernel(HMC(bernoulli_model)), num_warmup=10, num_samples=10)
@@ -176,3 +210,30 @@ def test_mcmc_pickle_post_warmup():
     pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
     pickled_mcmc.post_warmup_state = pickled_mcmc.last_state
     pickled_mcmc.run(random.PRNGKey(1))
+
+
+def bernoulli_regression(data):
+    f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
+    with numpyro.plate("N", len(data)):
+        numpyro.sample("obs", dist.Bernoulli(f), obs=data)
+
+
+def test_beta_bernoulli():
+    data = jnp.array([1.0] * 8 + [0.0] * 2)
+
+    def guide(data):
+        alpha_q = numpyro.param("alpha_q", 1.0, constraint=constraints.positive)
+        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
+        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))
+
+    pickled_model = pickle.loads(pickle.dumps(config_kl(bernoulli_regression)))
+    optim = numpyro.optim.Adam(1e-2)
+    svi = SVI(config_kl(bernoulli_regression), guide, optim, TraceEnum_ELBO())
+    svi_result = svi.run(random.PRNGKey(0), 3, data)
+    params = svi_result.params
+
+    svi = SVI(pickled_model, guide, optim, TraceEnum_ELBO())
+    svi_result = svi.run(random.PRNGKey(0), 3, data)
+    pickled_params = svi_result.params
+
+    tree_all(tree_map(assert_allclose, params, pickled_params))
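For context, a minimal sketch of the same round-trip that `test_pickle_hmc_enumeration` exercises, but written out to disk, reusing the `gmm` model and the `data`/`K` setup from the test above; the file name is illustrative:

```python
import pickle

from jax import random

from numpyro.infer import MCMC, NUTS

# Run MCMC on a model with an enumerated discrete site, serialize, restore.
mcmc = MCMC(NUTS(gmm), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0), data, K)

with open("mcmc.pkl", "wb") as f:  # illustrative path
    pickle.dump(mcmc, f)
with open("mcmc.pkl", "rb") as f:
    restored = pickle.load(f)

# The restored object is fully usable, e.g. for inspecting samples.
assert set(restored.get_samples()) == set(mcmc.get_samples())
```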