From 003424bb3c57e44b433991cc73ddbb557bf31f3c Mon Sep 17 00:00:00 2001
From: Du Phan
Date: Sat, 10 Jul 2021 15:35:14 -0400
Subject: [PATCH] Support infer_discrete for Predictive (#1086)

* support infer_discrete for Predictive

* revise docs

* use infer_discrete_temperature

* use temperature=1 by default
---
 examples/annotation.py             | 22 ++----------
 numpyro/contrib/funsor/discrete.py | 15 +++++---
 numpyro/infer/util.py              | 55 +++++++++++++++++++++++++-----
 3 files changed, 60 insertions(+), 32 deletions(-)

diff --git a/examples/annotation.py b/examples/annotation.py
index 6b7ada33e..264d14ac7 100644
--- a/examples/annotation.py
+++ b/examples/annotation.py
@@ -42,10 +42,9 @@
 
 import numpyro
 from numpyro import handlers
-from numpyro.contrib.funsor import config_enumerate, infer_discrete
 from numpyro.contrib.indexing import Vindex
 import numpyro.distributions as dist
-from numpyro.infer import MCMC, NUTS
+from numpyro.infer import MCMC, NUTS, Predictive
 from numpyro.infer.reparam import LocScaleReparam
 
 
@@ -313,24 +312,9 @@ def main(args):
     mcmc.run(random.PRNGKey(0), *data)
     mcmc.print_summary()
 
-    def infer_discrete_model(rng_key, samples):
-        conditioned_model = handlers.condition(model, data=samples)
-        infer_discrete_model = infer_discrete(
-            config_enumerate(conditioned_model), rng_key=rng_key
-        )
-        with handlers.trace() as tr:
-            infer_discrete_model(*data)
-
-        return {
-            name: site["value"]
-            for name, site in tr.items()
-            if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel"
-        }
-
     posterior_samples = mcmc.get_samples()
-    discrete_samples = vmap(infer_discrete_model)(
-        random.split(random.PRNGKey(1), args.num_samples), posterior_samples
-    )
+    predictive = Predictive(model, posterior_samples, infer_discrete=True)
+    discrete_samples = predictive(random.PRNGKey(1), *data)
 
     item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)(
         discrete_samples["c"].squeeze(-1)
diff --git a/numpyro/contrib/funsor/discrete.py b/numpyro/contrib/funsor/discrete.py
index 72767ff84..36462e28a 100644
--- a/numpyro/contrib/funsor/discrete.py
+++ b/numpyro/contrib/funsor/discrete.py
@@ -118,8 +118,7 @@ def _sample_posterior(
         values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
         data[root_name] = jnp.concatenate(values)
 
-    with substitute(data=data):
-        return model(*args, **kwargs)
+    return data
 
 
 def infer_discrete(fn=None, first_available_dim=None, temperature=1, rng_key=None):
@@ -169,6 +168,12 @@ def viterbi_decoder(data, hidden_dim=10):
             temperature=temperature,
             rng_key=rng_key,
         )
-    return functools.partial(
-        _sample_posterior, fn, first_available_dim, temperature, rng_key
-    )
+
+    def wrap_fn(*args, **kwargs):
+        samples = _sample_posterior(
+            fn, first_available_dim, temperature, rng_key, *args, **kwargs
+        )
+        with substitute(data=samples):
+            return fn(*args, **kwargs)
+
+    return wrap_fn
diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py
index b4aee6ad5..d3c41187a 100644
--- a/numpyro/infer/util.py
+++ b/numpyro/infer/util.py
@@ -11,12 +11,13 @@
 from jax import device_get, jacfwd, lax, random, value_and_grad
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
+from jax.tree_util import tree_map
 
 import numpyro
 from numpyro.distributions import constraints
 from numpyro.distributions.transforms import biject_to
 from numpyro.distributions.util import is_identically_one, sum_rightmost
-from numpyro.handlers import replay, seed, substitute, trace
+from numpyro.handlers import condition, replay, seed, substitute, trace
 from numpyro.infer.initialization import init_to_uniform, init_to_value
 from numpyro.util import not_jax_tracer, soft_vmap, while_loop
 
@@ -673,17 +674,47 @@ def _predictive(
     posterior_samples,
     batch_shape,
     return_sites=None,
+    infer_discrete=False,
     parallel=True,
     model_args=(),
     model_kwargs={},
 ):
-    model = numpyro.handlers.mask(model, mask=False)
+    masked_model = numpyro.handlers.mask(model, mask=False)
+    if infer_discrete:
+        # inspect the model to get some structure
+        rng_key, subkey = random.split(rng_key)
+        batch_ndim = len(batch_shape)
+        prototype_sample = tree_map(
+            lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[batch_ndim:])[0],
+            posterior_samples,
+        )
+        prototype_trace = trace(
+            seed(substitute(masked_model, prototype_sample), subkey)
+        ).get_trace(*model_args, **model_kwargs)
+        first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1
 
     def single_prediction(val):
         rng_key, samples = val
-        model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
-            *model_args, **model_kwargs
-        )
+        if infer_discrete:
+            from numpyro.contrib.funsor import config_enumerate
+            from numpyro.contrib.funsor.discrete import _sample_posterior
+
+            model_trace = prototype_trace
+            temperature = 1
+            pred_samples = _sample_posterior(
+                config_enumerate(condition(model, samples)),
+                first_available_dim,
+                temperature,
+                rng_key,
+                *model_args,
+                **model_kwargs,
+            )
+        else:
+            model_trace = trace(
+                seed(substitute(masked_model, samples), rng_key)
+            ).get_trace(*model_args, **model_kwargs)
+            pred_samples = {name: site["value"] for name, site in model_trace.items()}
+
         if return_sites is not None:
             if return_sites == "":
                 sites = {
@@ -698,9 +729,7 @@ def single_prediction(val):
                 if (site["type"] == "sample" and k not in samples)
                 or (site["type"] == "deterministic")
             }
-        return {
-            name: site["value"] for name, site in model_trace.items() if name in sites
-        }
+        return {name: value for name, value in pred_samples.items() if name in sites}
 
     num_samples = int(np.prod(batch_shape))
     if num_samples > 1:
@@ -729,6 +758,12 @@ class Predictive(object):
     :param int num_samples: number of samples
    :param list return_sites: sites to return; by default only sample sites not present
        in `posterior_samples` are returned.
+    :param bool infer_discrete: whether or not to sample discrete sites from the
+        posterior, conditioned on observations and other latent values in
+        ``posterior_samples``. Under the hood, those sites will be marked with
+        ``site["infer"]["enumerate"] = "parallel"``. See how `infer_discrete` works at
+        the `Pyro enumeration tutorial <https://pyro.ai/examples/enumeration.html>`_.
+        Note that this requires ``funsor`` installation.
     :param bool parallel: whether to predict in parallel using JAX vectorized
         map :func:`jax.vmap`. Defaults to False.
     :param batch_ndims: the number of batch dimensions in posterior samples. Some usages:
@@ -749,10 +784,12 @@ def __init__(
         self,
         model,
         posterior_samples=None,
+        *,
         guide=None,
         params=None,
         num_samples=None,
         return_sites=None,
+        infer_discrete=False,
         parallel=False,
         batch_ndims=1,
     ):
@@ -801,6 +838,7 @@ def __init__(
         self.num_samples = num_samples
         self.guide = guide
         self.params = {} if params is None else params
+        self.infer_discrete = infer_discrete
         self.return_sites = return_sites
         self.parallel = parallel
         self.batch_ndims = batch_ndims
@@ -838,6 +876,7 @@ def __call__(self, rng_key, *args, **kwargs):
             posterior_samples,
             self._batch_shape,
             return_sites=self.return_sites,
+            infer_discrete=self.infer_discrete,
             parallel=self.parallel,
             model_args=args,
             model_kwargs=kwargs,
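
Note (not part of the patch): a minimal usage sketch of the new infer_discrete option on Predictive. The two-component mixture model, data, and sample counts below are illustrative assumptions rather than code from this PR; the call pattern mirrors the updated examples/annotation.py and requires funsor to be installed.

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive


def model(data):
    # Hypothetical two-component Gaussian mixture with one discrete assignment
    # per data point; NUTS marginalizes "z" by enumeration during sampling.
    locs = numpyro.sample("locs", dist.Normal(0.0, 10.0).expand([2]).to_event(1))
    with numpyro.plate("data", data.shape[0]):
        z = numpyro.sample("z", dist.Categorical(probs=jnp.array([0.5, 0.5])))
        numpyro.sample("obs", dist.Normal(locs[z], 1.0), obs=data)


# Synthetic bimodal data, purely for illustration.
data = jnp.concatenate(
    [random.normal(random.PRNGKey(0), (50,)) - 2.0,
     random.normal(random.PRNGKey(1), (50,)) + 2.0]
)

mcmc = MCMC(NUTS(model), num_warmup=200, num_samples=200)
mcmc.run(random.PRNGKey(2), data)

# New in this patch: Predictive(..., infer_discrete=True) draws the enumerated
# discrete sites conditioned on the observations and each continuous posterior sample.
predictive = Predictive(model, mcmc.get_samples(), infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(3), data)
print(discrete_samples["z"].shape)  # expected: (num_samples, num_data) = (200, 100)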