-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TraceGraph_ELBO implementation using provenance tracking #1412
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
9568317
TraceGraph_ELBO using provenance tracking
2d50c94
update docstring
a1d0c5a
simplify
24c7b4a
add tests back
e70bc9a
misc
532e779
test costs
3c32150
simplify
21d578c
address comments
b751569
update docstring
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
# Copyright Contributors to the Pyro project. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from collections import defaultdict | ||
from functools import partial | ||
from operator import itemgetter | ||
import warnings | ||
|
||
|
@@ -11,8 +13,9 @@ | |
|
||
from numpyro.distributions.kl import kl_divergence | ||
from numpyro.distributions.util import scale_and_mask | ||
from numpyro.handlers import replay, seed, substitute, trace | ||
from numpyro.handlers import Messenger, replay, seed, substitute, trace | ||
from numpyro.infer.util import get_importance_trace, log_density | ||
from numpyro.ops.provenance import eval_provenance, get_provenance | ||
from numpyro.util import _validate_model, check_model_guide_match, find_stack_level | ||
|
||
|
||
|
@@ -526,24 +529,99 @@ def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes): | |
return downstream_costs, downstream_guide_cost_nodes | ||
|
||
|
||
class track_nonreparam(Messenger): | ||
""" | ||
Track non-reparameterizable sample sites. Intended to be used with ``eval_provenance``. | ||
|
||
**References:** | ||
|
||
1. *Nonstandard Interpretations of Probabilistic Programs for Efficient Inference*, | ||
David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind | ||
|
||
**Example:** | ||
|
||
.. doctest:: | ||
|
||
>>> import jax.numpy as jnp | ||
>>> import numpyro | ||
>>> import numpyro.distributions as dist | ||
>>> from numpyro.infer.elbo import track_nonreparam | ||
>>> from numpyro.ops.provenance import eval_provenance, get_provenance | ||
>>> from numpyro.handlers import seed, trace | ||
|
||
>>> def model(): | ||
... probs_a = jnp.array([0.3, 0.7]) | ||
... probs_b = jnp.array([[0.1, 0.9], [0.8, 0.2]]) | ||
... probs_c = jnp.array([[0.5, 0.5], [0.6, 0.4]]) | ||
... a = numpyro.sample("a", dist.Categorical(probs_a)) | ||
... b = numpyro.sample("b", dist.Categorical(probs_b[a])) | ||
... numpyro.sample("c", dist.Categorical(probs_c[b]), obs=jnp.array(0)) | ||
|
||
>>> def get_log_probs(): | ||
... seeded_model = seed(model, rng_seed=0) | ||
... model_tr = trace(seeded_model).get_trace() | ||
... return { | ||
... name: site["fn"].log_prob(site["value"]) | ||
... for name, site in model_tr.items() | ||
... if site["type"] == "sample" | ||
... } | ||
|
||
>>> model_deps = get_provenance(eval_provenance(track_nonreparam(get_log_probs))) | ||
>>> print(model_deps) # doctest: +SKIP | ||
{'a': frozenset({'a'}), 'b': frozenset({'a', 'b'}), 'c': frozenset({'a', 'b'})} | ||
""" | ||
|
||
def postprocess_message(self, msg): | ||
if ( | ||
msg["type"] == "sample" | ||
and (not msg["is_observed"]) | ||
and (not msg["fn"].has_rsample) | ||
): | ||
new_provenance = frozenset({msg["name"]}) | ||
old_provenance = msg["value"].aval.named_shape.get( | ||
"_provenance", frozenset() | ||
) | ||
fehiepsi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
msg["value"].aval.named_shape["_provenance"] = ( | ||
old_provenance | new_provenance | ||
) | ||
|
||
|
||
def get_importance_log_probs(model, guide, args, kwargs, params): | ||
""" | ||
Returns log probabilities at each site for the guide and the model that is run against it. | ||
""" | ||
model_tr, guide_tr = get_importance_trace(model, guide, args, kwargs, params) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Changed since yesterday: Instead of manually tracing guide and then replaying/tracing model I changed this to just reuse the |
||
model_log_probs = { | ||
name: site["log_prob"] | ||
for name, site in model_tr.items() | ||
if site["type"] == "sample" | ||
} | ||
guide_log_probs = { | ||
name: site["log_prob"] | ||
for name, site in guide_tr.items() | ||
if site["type"] == "sample" | ||
} | ||
return model_log_probs, guide_log_probs | ||
|
||
|
||
class TraceGraph_ELBO(ELBO): | ||
""" | ||
A TraceGraph implementation of ELBO-based SVI. The gradient estimator | ||
is constructed along the lines of reference [1] specialized to the case | ||
of the ELBO. It supports arbitrary dependency structure for the model | ||
and guide. | ||
Where possible, conditional dependency information as recorded in the | ||
Fine-grained conditional dependency information as recorded in the | ||
trace is used to reduce the variance of the gradient estimator. | ||
In particular two kinds of conditional dependency information are | ||
used to reduce variance: | ||
|
||
- the sequential order of samples (z is sampled after y => y does not depend on z) | ||
- :class:`~numpyro.plate` generators | ||
In particular provenance tracking [2] is used to find the ``cost`` terms | ||
that depend on each non-reparameterizable sample site. | ||
|
||
References | ||
|
||
[1] `Gradient Estimation Using Stochastic Computation Graphs`, | ||
John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel | ||
|
||
[2] `Nonstandard Interpretations of Probabilistic Programs for Efficient Inference`, | ||
David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind | ||
""" | ||
|
||
can_infer_discrete = True | ||
|
@@ -577,34 +655,51 @@ def single_particle_elbo(rng_key): | |
check_model_guide_match(model_trace, guide_trace) | ||
_validate_model(model_trace, plate_warning="strict") | ||
|
||
# XXX: different from Pyro, we don't support baseline_loss here | ||
non_reparam_nodes = { | ||
name | ||
for name, site in guide_trace.items() | ||
if site["type"] == "sample" | ||
and (not site["is_observed"]) | ||
and (not site["fn"].has_rsample) | ||
} | ||
if non_reparam_nodes: | ||
downstream_costs, _ = _compute_downstream_costs( | ||
model_trace, guide_trace, non_reparam_nodes | ||
# Find dependencies on non-reparameterizable sample sites for | ||
# each cost term in the model and the guide. | ||
model_deps, guide_deps = get_provenance( | ||
eval_provenance( | ||
partial( | ||
track_nonreparam(get_importance_log_probs), | ||
seeded_model, | ||
seeded_guide, | ||
args, | ||
kwargs, | ||
param_map, | ||
) | ||
fehiepsi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) | ||
) | ||
|
||
elbo = 0.0 | ||
for site in model_trace.values(): | ||
# mapping from non-reparameterizable sample sites to cost terms influenced by each of them | ||
downstream_costs = defaultdict(lambda: MultiFrameTensor()) | ||
for name, site in model_trace.items(): | ||
if site["type"] == "sample": | ||
elbo = elbo + jnp.sum(site["log_prob"]) | ||
# add the log_prob to each non-reparam sample site upstream | ||
for key in model_deps[name]: | ||
downstream_costs[key].add( | ||
(site["cond_indep_stack"], site["log_prob"]) | ||
) | ||
for name, site in guide_trace.items(): | ||
if site["type"] == "sample": | ||
log_prob_sum = jnp.sum(site["log_prob"]) | ||
if name in non_reparam_nodes: | ||
surrogate = jnp.sum( | ||
site["log_prob"] * stop_gradient(downstream_costs[name]) | ||
) | ||
log_prob_sum = ( | ||
stop_gradient(log_prob_sum + surrogate) - surrogate | ||
) | ||
if not site["fn"].has_rsample: | ||
log_prob_sum = stop_gradient(log_prob_sum) | ||
fehiepsi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
elbo = elbo - log_prob_sum | ||
# add the -log_prob to each non-reparam sample site upstream | ||
for key in guide_deps[name]: | ||
downstream_costs[key].add( | ||
(site["cond_indep_stack"], -site["log_prob"]) | ||
) | ||
|
||
for node, downstream_cost in downstream_costs.items(): | ||
guide_site = guide_trace[node] | ||
downstream_cost = downstream_cost.sum_to(guide_site["cond_indep_stack"]) | ||
surrogate = jnp.sum( | ||
guide_site["log_prob"] * stop_gradient(downstream_cost) | ||
) | ||
elbo = elbo + surrogate - stop_gradient(surrogate) | ||
|
||
return elbo | ||
|
||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add a comment that this is intended to be used in
eval_provenance