Predictive: better docs for setting obs to None #1845

Closed
tomwallis opened this issue Aug 7, 2024 · 4 comments

@tomwallis
Contributor

Hello,

It wasn't clear to me, until I dug into some strange behaviour, that the Predictive class requires obs to be set to None. I see now that the examples do this, but nowhere is it stated explicitly that this is required for the class to work as intended. Could this be made more explicit in the docs? Thanks in advance!

@fehiepsi
Member

fehiepsi commented Aug 7, 2024

You are right. That part of the API is not obvious. Would you like to open a PR to improve the docs?

@tomwallis
Contributor Author

I can try. To do this, should I build the docs locally, or are the changes to the source files sufficient?

@tillahoffmann
Contributor

Expanding the documentation in the source file linked below should do the trick. You don't need to build the documentation locally, but it can be helpful for checking that the output looks like what you intended.

class Predictive(object):
    """
    This class is used to construct predictive distribution. The predictive distribution is obtained
    by running model conditioned on latent samples from `posterior_samples`.

    .. warning::
        The interface for the `Predictive` class is experimental, and
        might change in the future.

    :param model: Python callable containing Pyro primitives.
    :param dict posterior_samples: dictionary of samples from the posterior.
    :param callable guide: optional guide to get posterior samples of sites not present
        in `posterior_samples`.
    :param dict params: dictionary of values for param sites of model/guide.
    :param int num_samples: number of samples
    :param list return_sites: sites to return; by default only sample sites not present
        in `posterior_samples` are returned.
    :param bool infer_discrete: whether or not to sample discrete sites from the
        posterior, conditioned on observations and other latent values in
        ``posterior_samples``. Under the hood, those sites will be marked with
        ``site["infer"]["enumerate"] = "parallel"``. See how `infer_discrete` works at
        the `Pyro enumeration tutorial <https://pyro.ai/examples/enumeration.html>`_.
        Note that this requires ``funsor`` installation.
    :param bool parallel: whether to predict in parallel using JAX vectorized map :func:`jax.vmap`.
        Defaults to False.
    :param batch_ndims: the number of batch dimensions in posterior samples or parameters. If `None` defaults
        to 0 if guide is set (i.e. not `None`) and 1 otherwise. Usages for batched posterior samples:

        + set `batch_ndims=0` to get prediction for 1 single sample

        + set `batch_ndims=1` to get prediction for `posterior_samples`
          with shapes `(num_samples x ...)` (same as `batch_ndims=None` with `guide=None`)

        + set `batch_ndims=2` to get prediction for `posterior_samples`
          with shapes `(num_chains x N x ...)`. Note that if `num_samples`
          argument is not None, its value should be equal to `num_chains x N`.

        Usages for batched parameters:

        + set `batch_ndims=0` to get 1 sample from the guide and parameters (same as `batch_ndims=None` with guide)

        + set `batch_ndims=1` to get predictions from a one dimensional batch of the guide and parameters
          with shapes `(num_samples x batch_size x ...)`

    :param exclude_deterministic: indicates whether to ignore deterministic sites from the posterior samples.

    :return: dict of samples from the predictive distribution.

    **Example:**

    Given a model::

        def model(X, y=None):
            ...
            return numpyro.sample("obs", likelihood, obs=y)

    you can sample from the prior predictive::

        predictive = Predictive(model, num_samples=1000)
        y_pred = predictive(rng_key, X)["obs"]

    If you also have posterior samples, you can sample from the posterior predictive::

        predictive = Predictive(model, posterior_samples=posterior_samples)
        y_pred = predictive(rng_key, X)["obs"]

    See docstrings for :class:`~numpyro.infer.svi.SVI` and :class:`~numpyro.infer.mcmc.MCMCKernel`
    to see example code of this in context.
    """

@tomwallis
Contributor Author

Thanks @tillahoffmann. I have created a PR that hopefully does the trick. I was not able to build the docs locally without more time investment, so I have not checked the intended output.

@fehiepsi closed this as completed Dec 3, 2024