DOC Add Predictive examples #1084

Merged — merged 8 commits on Jul 11, 2021
Changes from 3 commits
6 changes: 5 additions & 1 deletion numpyro/infer/mcmc.py
@@ -64,7 +64,9 @@ class MCMCKernel(ABC):
>>> kernel = MetropolisHastings(f)
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
>>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.]))
>>> samples = mcmc.get_samples()
>>> posterior_samples = mcmc.get_samples()
>>> predictive = Predictive(model, posterior_samples=posterior_samples)
>>> samples = predictive(rng_key1, *model_args, **model_kwargs)
Member
I don't think we have a model here. I think you can add a small code snippet to the get_samples method of MCMC.
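For reference, a minimal self-contained sketch of the kind of snippet the comment asks for (the model, data, and NUTS sampler below are illustrative placeholders, not part of this PR):

    import jax.numpy as jnp
    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS, Predictive

    # toy model: a shared mean with three observed points
    def model(y=None):
        mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
        with numpyro.plate("data", 3):
            numpyro.sample("y", dist.Normal(mu, 1.0), obs=y)

    mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(0), y=jnp.array([0.5, 1.0, 1.5]))
    posterior_samples = mcmc.get_samples()

    # posterior predictive: resample "y" given the posterior draws
    predictive = Predictive(model, posterior_samples=posterior_samples)
    y_pred = predictive(random.PRNGKey(1))["y"]  # shape (500, 3)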

>>> mcmc.print_summary() # doctest: +SKIP
"""

@@ -509,6 +511,8 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
"""
Run the MCMC samplers and collect samples.

See :class:`~numpyro.infer.util.Predictive` for how to use them to collect posterior predictive samples.
Member
I think get_samples is a better place for this. Also, maybe change "them" to "those samples"?


:param random.PRNGKey rng_key: Random number generator key to be used for the sampling.
For multi-chains, a batch of `num_chains` keys can be supplied. If `rng_key`
does not have batch_size, it will be split in to a batch of `num_chains` keys.
4 changes: 3 additions & 1 deletion numpyro/infer/svi.py
@@ -83,7 +83,7 @@ class SVI(object):
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.distributions import constraints
>>> from numpyro.infer import SVI, Trace_ELBO
>>> from numpyro.infer import Predictive, SVI, Trace_ELBO

>>> def model(data):
... f = numpyro.sample("latent_fairness", dist.Beta(10, 10))
@@ -101,6 +101,8 @@ class SVI(object):
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
>>> svi_result = svi.run(random.PRNGKey(0), 2000, data)
>>> params = svi_result.params
>>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
>>> samples = predictive(random.PRNGKey(0), data)
Member
This code does not return any useful information unless you set data=None (which will give you posterior predictive samples). If you want to get posterior samples, you can move this code to the end and add a comment:

>>> # get posterior samples
>>> predictive = Predictive(guide, params=params, num_samples=1000)
>>> samples = predictive(random.PRNGKey(1), data)
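For context, a rough sketch of the two variants the comment describes, continuing the coin-fairness example above (this assumes the model is written so it tolerates data=None, which the collapsed snippet does not guarantee; the variable names are illustrative):

    # posterior samples of the latent sites, drawn from the fitted guide
    predictive = Predictive(guide, params=params, num_samples=1000)
    posterior_samples = predictive(random.PRNGKey(1), data)

    # posterior predictive samples: pass data=None so the observed site is
    # resampled instead of being clamped to the observed data
    predictive = Predictive(model, guide=guide, params=params, num_samples=1000)
    posterior_predictive = predictive(random.PRNGKey(1), None)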

>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"])

:param model: Python callable with Pyro primitives for the model.
15 changes: 15 additions & 0 deletions numpyro/infer/util.py
@@ -743,6 +743,21 @@ class Predictive(object):
argument is not None, its value should be equal to `num_chains x N`.

:return: dict of samples from the predictive distribution.

**Example:**

Given a model, you can sample from the prior predictive:

>>> predictive = Predictive(model, num_samples=num_samples)
Member
nit: num_samples=1000

>>> samples = predictive(rng_key1, *model_args, **model_kwargs)
Member (@fehiepsi, Jul 4, 2021)
nit: rng_key1 -> rng_key

To make it clearer that we need to set data=None, I think you can provide a signature for model:

Given a model

    def model(X, y=None):
        ...
        return numpyro.sample("obs", likelihood, obs=y)

you can sample from the prior predictive:

    predictive = Predictive(model, num_samples=1000)
    y_pred = predictive(rng_key, X)["obs"]

If you also have posterior samples, you can sample from the posterior predictive:

    predictive = Predictive(model, posterior_samples=posterior_samples)
    y_pred = predictive(rng_key, X)["obs"]
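A self-contained sketch of that pattern end to end (the linear-regression model, data shapes, and sampler settings below are purely illustrative, not part of the PR):

    import jax.numpy as jnp
    from jax import random
    import numpyro
    import numpyro.distributions as dist
    from numpyro.infer import MCMC, NUTS, Predictive

    def model(X, y=None):
        w = numpyro.sample("w", dist.Normal(jnp.zeros(X.shape[1]), 1.0).to_event(1))
        sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
        mu = jnp.dot(X, w)
        with numpyro.plate("data", X.shape[0]):
            numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)

    X = random.normal(random.PRNGKey(0), (100, 3))
    y = X @ jnp.array([1.0, -2.0, 0.5]) + 0.1 * random.normal(random.PRNGKey(1), (100,))

    # prior predictive: no posterior samples, so everything is drawn from the prior
    prior_pred = Predictive(model, num_samples=1000)(random.PRNGKey(2), X)["obs"]

    # posterior predictive: condition on posterior samples from MCMC
    mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=500)
    mcmc.run(random.PRNGKey(3), X, y=y)
    posterior_pred = Predictive(model, posterior_samples=mcmc.get_samples())(
        random.PRNGKey(4), X
    )["obs"]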


If you also have posterior samples, you can sample from the posterior predictive:

>>> predictive = Predictive(model, posterior_samples=posterior_samples)
>>> samples = predictive(rng_key1, *model_args, **model_kwargs)

See docstrings for :class:`~numpyro.infer.svi.SVI` and :class:`~numpyro.infer.mcmc.MCMCKernel`
to see example code of this in context.
"""

def __init__(