-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DOC Add Predictive examples #1084
Changes from 3 commits
65e858b
865027c
c10e4f3
a5e87f4
995796a
4efe6fa
984073e
1512f67
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,7 +64,9 @@ class MCMCKernel(ABC): | |
>>> kernel = MetropolisHastings(f) | ||
>>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000) | ||
>>> mcmc.run(random.PRNGKey(0), init_params=jnp.array([1., 2.])) | ||
>>> samples = mcmc.get_samples() | ||
>>> posterior_samples = mcmc.get_samples() | ||
>>> predictive = Predictive(model, posterior_samples=posterior_samples) | ||
>>> samples = predictive(rng_key1, *model_args, **model_kwargs) | ||
>>> mcmc.print_summary() # doctest: +SKIP | ||
""" | ||
|
||
|
@@ -509,6 +511,8 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): | |
""" | ||
Run the MCMC samplers and collect samples. | ||
|
||
See :class:`~numpyro.infer.util.Predictive` for how to use them to collect posterior predictive samples. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||
|
||
:param random.PRNGKey rng_key: Random number generator key to be used for the sampling. | ||
For multi-chains, a batch of `num_chains` keys can be supplied. If `rng_key` | ||
does not have batch_size, it will be split in to a batch of `num_chains` keys. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -83,7 +83,7 @@ class SVI(object): | |
>>> import numpyro | ||
>>> import numpyro.distributions as dist | ||
>>> from numpyro.distributions import constraints | ||
>>> from numpyro.infer import SVI, Trace_ELBO | ||
>>> from numpyro.infer import Predictive, SVI, Trace_ELBO | ||
|
||
>>> def model(data): | ||
... f = numpyro.sample("latent_fairness", dist.Beta(10, 10)) | ||
|
@@ -101,6 +101,8 @@ class SVI(object): | |
>>> svi = SVI(model, guide, optimizer, loss=Trace_ELBO()) | ||
>>> svi_result = svi.run(random.PRNGKey(0), 2000, data) | ||
>>> params = svi_result.params | ||
>>> predictive = Predictive(model, guide=guide, params=params, num_samples=1000) | ||
>>> samples = predictive(random.PRNGKey(0), data) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code does not return any useful information unless you set >>> # get posterior samples
>>> predictive = Predictive(guide, params=params, num_samples=1000)
>>> samples = predictive(random.PRNGKey(1), data) |
||
>>> inferred_mean = params["alpha_q"] / (params["alpha_q"] + params["beta_q"]) | ||
|
||
:param model: Python callable with Pyro primitives for the model. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -743,6 +743,21 @@ class Predictive(object): | |
argument is not None, its value should be equal to `num_chains x N`. | ||
|
||
:return: dict of samples from the predictive distribution. | ||
|
||
**Example:** | ||
|
||
Given a model, you can sample from the prior predictive: | ||
|
||
>>> predictive = Predictive(model, num_samples=num_samples) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit |
||
>>> samples = predictive(rng_key1, *model_args, **model_kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: rng_key1 -> rng_key To make it clearer that we need to set
|
||
|
||
If you also have posterior samples, you can sample from the posterior predictive: | ||
|
||
>>> predictive = Predictive(model, posterior_samples=posterior_samples) | ||
>>> samples = predictive(rng_key1, *model_args, **model_kwargs) | ||
|
||
See docstrings for :class:`~numpyro.infer.svi.SVI` and :class:`~numpyro.infer.mcmc.MCMCKernel` | ||
to see example code of this in context. | ||
""" | ||
|
||
def __init__( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we have a model here. I think you can add a small code snippet to
get_samples
method ofMCMC
.