Skip to content

Commit

Permalink
Support infer_discrete for Predictive (#1086)
Browse files Browse the repository at this point in the history
* support infer_discrete for Predictive

* revise docs

* use infer_discrete_temperature

* use temperature=1 by default
  • Loading branch information
fehiepsi authored Jul 10, 2021
1 parent 1b517b0 commit 003424b
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 32 deletions.
22 changes: 3 additions & 19 deletions examples/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,9 @@

import numpyro
from numpyro import handlers
from numpyro.contrib.funsor import config_enumerate, infer_discrete
from numpyro.contrib.indexing import Vindex
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.infer.reparam import LocScaleReparam


Expand Down Expand Up @@ -313,24 +312,9 @@ def main(args):
mcmc.run(random.PRNGKey(0), *data)
mcmc.print_summary()

def infer_discrete_model(rng_key, samples):
conditioned_model = handlers.condition(model, data=samples)
infer_discrete_model = infer_discrete(
config_enumerate(conditioned_model), rng_key=rng_key
)
with handlers.trace() as tr:
infer_discrete_model(*data)

return {
name: site["value"]
for name, site in tr.items()
if site["type"] == "sample" and site["infer"].get("enumerate") == "parallel"
}

posterior_samples = mcmc.get_samples()
discrete_samples = vmap(infer_discrete_model)(
random.split(random.PRNGKey(1), args.num_samples), posterior_samples
)
predictive = Predictive(model, posterior_samples, infer_discrete=True)
discrete_samples = predictive(random.PRNGKey(1), *data)

item_class = vmap(lambda x: jnp.bincount(x, length=4), in_axes=1)(
discrete_samples["c"].squeeze(-1)
Expand Down
15 changes: 10 additions & 5 deletions numpyro/contrib/funsor/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ def _sample_posterior(
values = [v.reshape((-1,) + prototype_shape[1:]) for v in values]
data[root_name] = jnp.concatenate(values)

with substitute(data=data):
return model(*args, **kwargs)
return data


def infer_discrete(fn=None, first_available_dim=None, temperature=1, rng_key=None):
Expand Down Expand Up @@ -169,6 +168,12 @@ def viterbi_decoder(data, hidden_dim=10):
temperature=temperature,
rng_key=rng_key,
)
return functools.partial(
_sample_posterior, fn, first_available_dim, temperature, rng_key
)

def wrap_fn(*args, **kwargs):
samples = _sample_posterior(
fn, first_available_dim, temperature, rng_key, *args, **kwargs
)
with substitute(data=samples):
return fn(*args, **kwargs)

return wrap_fn
55 changes: 47 additions & 8 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
from jax import device_get, jacfwd, lax, random, value_and_grad
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from jax.tree_util import tree_map

import numpyro
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to
from numpyro.distributions.util import is_identically_one, sum_rightmost
from numpyro.handlers import replay, seed, substitute, trace
from numpyro.handlers import condition, replay, seed, substitute, trace
from numpyro.infer.initialization import init_to_uniform, init_to_value
from numpyro.util import not_jax_tracer, soft_vmap, while_loop

Expand Down Expand Up @@ -673,17 +674,47 @@ def _predictive(
posterior_samples,
batch_shape,
return_sites=None,
infer_discrete=False,
parallel=True,
model_args=(),
model_kwargs={},
):
model = numpyro.handlers.mask(model, mask=False)
masked_model = numpyro.handlers.mask(model, mask=False)
if infer_discrete:
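# trace the model once on the first posterior sample to discover its plate
# nesting, which determines the first dim available for enumeration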
rng_key, subkey = random.split(rng_key)
batch_ndim = len(batch_shape)
prototype_sample = tree_map(
lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[batch_ndim:])[0],
posterior_samples,
)
prototype_trace = trace(
seed(substitute(masked_model, prototype_sample), subkey)
).get_trace(*model_args, **model_kwargs)
first_available_dim = -_guess_max_plate_nesting(prototype_trace) - 1

def single_prediction(val):
rng_key, samples = val
model_trace = trace(seed(substitute(model, samples), rng_key)).get_trace(
*model_args, **model_kwargs
)
if infer_discrete:
from numpyro.contrib.funsor import config_enumerate
from numpyro.contrib.funsor.discrete import _sample_posterior

model_trace = prototype_trace
temperature = 1
pred_samples = _sample_posterior(
config_enumerate(condition(model, samples)),
first_available_dim,
temperature,
rng_key,
*model_args,
**model_kwargs,
)
else:
model_trace = trace(
seed(substitute(masked_model, samples), rng_key)
).get_trace(*model_args, **model_kwargs)
pred_samples = {name: site["value"] for name, site in model_trace.items()}

if return_sites is not None:
if return_sites == "":
sites = {
Expand All @@ -698,9 +729,7 @@ def single_prediction(val):
if (site["type"] == "sample" and k not in samples)
or (site["type"] == "deterministic")
}
return {
name: site["value"] for name, site in model_trace.items() if name in sites
}
return {name: value for name, value in pred_samples.items() if name in sites}

num_samples = int(np.prod(batch_shape))
if num_samples > 1:
Expand Down Expand Up @@ -729,6 +758,12 @@ class Predictive(object):
:param int num_samples: number of samples
:param list return_sites: sites to return; by default only sample sites not present
in `posterior_samples` are returned.
:param bool infer_discrete: whether or not to sample discrete sites from the
posterior, conditioned on observations and other latent values in
``posterior_samples``. Under the hood, those sites will be marked with
``site["infer"]["enumerate"] = "parallel"``. For details on how `infer_discrete`
works, see the `Pyro enumeration tutorial <https://pyro.ai/examples/enumeration.html>`_.
Note that this requires ``funsor`` to be installed. Defaults to False.
:param bool parallel: whether to predict in parallel using JAX vectorized map :func:`jax.vmap`.
Defaults to False.
:param batch_ndims: the number of batch dimensions in posterior samples. Some usages:
Expand All @@ -749,10 +784,12 @@ def __init__(
self,
model,
posterior_samples=None,
*,
guide=None,
params=None,
num_samples=None,
return_sites=None,
infer_discrete=False,
parallel=False,
batch_ndims=1,
):
Expand Down Expand Up @@ -801,6 +838,7 @@ def __init__(
self.num_samples = num_samples
self.guide = guide
self.params = {} if params is None else params
self.infer_discrete = infer_discrete
self.return_sites = return_sites
self.parallel = parallel
self.batch_ndims = batch_ndims
Expand Down Expand Up @@ -838,6 +876,7 @@ def __call__(self, rng_key, *args, **kwargs):
posterior_samples,
self._batch_shape,
return_sites=self.return_sites,
infer_discrete=self.infer_discrete,
parallel=self.parallel,
model_args=args,
model_kwargs=kwargs,
Expand Down

0 comments on commit 003424b

Please sign in to comment.