Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TraceGraph_ELBO implementation using provenance tracking #1412

Merged
merged 9 commits into from
May 31, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 81 additions & 26 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict
from functools import partial
from operator import itemgetter
import warnings

Expand All @@ -11,8 +13,9 @@

from numpyro.distributions.kl import kl_divergence
from numpyro.distributions.util import scale_and_mask
from numpyro.handlers import replay, seed, substitute, trace
from numpyro.handlers import Messenger, replay, seed, substitute, trace
from numpyro.infer.util import get_importance_trace, log_density
from numpyro.ops.provenance import eval_provenance, get_provenance
from numpyro.util import _validate_model, check_model_guide_match, find_stack_level


Expand Down Expand Up @@ -526,24 +529,59 @@ def _compute_downstream_costs(model_trace, guide_trace, non_reparam_nodes):
return downstream_costs, downstream_guide_cost_nodes


class track_nonreparam(Messenger):
    """Effect handler that tags values produced at non-reparameterizable
    sample sites with their site name.

    Intended to be used together with ``eval_provenance``: each
    non-reparameterizable latent sample site stores its name in the
    ``_provenance`` entry of the value's abstract ``named_shape``, so that
    downstream cost terms can later be attributed to the upstream
    non-reparameterizable sites they depend on.
    """

    def postprocess_message(self, msg):
        # Only tag latent (non-observed) sample sites whose distribution has
        # no reparameterized sampler — these are the sites that need
        # score-function (REINFORCE-style) gradient estimation.
        if (
            msg["type"] == "sample"
            and (not msg["is_observed"])
            and (not msg["fn"].has_rsample)
        ):
            new_provenance = frozenset({msg["name"]})
            # Merge with any provenance already accumulated on the value so
            # tags from multiple upstream sites are preserved.
            old_provenance = msg["value"].aval.named_shape.get(
                "_provenance", frozenset()
            )
            msg["value"].aval.named_shape["_provenance"] = (
                old_provenance | new_provenance
            )


def get_importance_log_probs(model, guide, args, kwargs, params):
    """Return per-site log probabilities for a guide and the model replayed
    against it.

    Runs ``get_importance_trace`` to obtain the guide trace and the model
    trace (with the model replayed against guide samples), then extracts
    ``log_prob`` for every sample site.

    :param model: the model program.
    :param guide: the guide program.
    :param args: positional arguments passed to model and guide.
    :param kwargs: keyword arguments passed to model and guide.
    :param params: dictionary of current parameter values.
    :return: a pair ``(model_log_probs, guide_log_probs)`` of dicts mapping
        site name to its log probability.
    """
    model_tr, guide_tr = get_importance_trace(model, guide, args, kwargs, params)
    model_log_probs = {
        name: site["log_prob"]
        for name, site in model_tr.items()
        if site["type"] == "sample"
    }
    guide_log_probs = {
        name: site["log_prob"]
        for name, site in guide_tr.items()
        if site["type"] == "sample"
    }
    return model_log_probs, guide_log_probs


class TraceGraph_ELBO(ELBO):
"""
A TraceGraph implementation of ELBO-based SVI. The gradient estimator
is constructed along the lines of reference [1] specialized to the case
of the ELBO. It supports arbitrary dependency structure for the model
and guide.
Where possible, conditional dependency information as recorded in the
Fine-grained conditional dependency information as recorded in the
trace is used to reduce the variance of the gradient estimator.
In particular two kinds of conditional dependency information are
used to reduce variance:

- the sequential order of samples (z is sampled after y => y does not depend on z)
- :class:`~numpyro.plate` generators
In particular provenance tracking [2] is used to find the ``cost`` terms
that depend on each non-reparameterizable sample site.

References

[1] `Gradient Estimation Using Stochastic Computation Graphs`,
John Schulman, Nicolas Heess, Theophane Weber, Pieter Abbeel

[2] `Nonstandard Interpretations of Probabilistic Programs for Efficient Inference`,
David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
"""

can_infer_discrete = True
Expand Down Expand Up @@ -577,34 +615,51 @@ def single_particle_elbo(rng_key):
check_model_guide_match(model_trace, guide_trace)
_validate_model(model_trace, plate_warning="strict")

# XXX: different from Pyro, we don't support baseline_loss here
non_reparam_nodes = {
name
for name, site in guide_trace.items()
if site["type"] == "sample"
and (not site["is_observed"])
and (not site["fn"].has_rsample)
}
if non_reparam_nodes:
downstream_costs, _ = _compute_downstream_costs(
model_trace, guide_trace, non_reparam_nodes
# Find dependencies on non-reparameterizable sample sites for
# each cost term in the model and the guide.
model_deps, guide_deps = get_provenance(
eval_provenance(
partial(
track_nonreparam(get_importance_log_probs),
seeded_model,
seeded_guide,
args,
kwargs,
param_map,
)
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
)
)

elbo = 0.0
for site in model_trace.values():
# mapping from non-reparameterizable sample sites to cost terms influenced by each of them
downstream_costs = defaultdict(lambda: MultiFrameTensor())
for name, site in model_trace.items():
if site["type"] == "sample":
elbo = elbo + jnp.sum(site["log_prob"])
# add the log_prob to each non-reparam sample site upstream
for key in model_deps[name]:
downstream_costs[key].add(
(site["cond_indep_stack"], site["log_prob"])
)
for name, site in guide_trace.items():
if site["type"] == "sample":
log_prob_sum = jnp.sum(site["log_prob"])
if name in non_reparam_nodes:
surrogate = jnp.sum(
site["log_prob"] * stop_gradient(downstream_costs[name])
)
log_prob_sum = (
stop_gradient(log_prob_sum + surrogate) - surrogate
)
if not site["fn"].has_rsample:
log_prob_sum = stop_gradient(log_prob_sum)
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved
elbo = elbo - log_prob_sum
# add the -log_prob to each non-reparam sample site upstream
for key in guide_deps[name]:
downstream_costs[key].add(
(site["cond_indep_stack"], -site["log_prob"])
)

for node, downstream_cost in downstream_costs.items():
guide_site = guide_trace[node]
downstream_cost = downstream_cost.sum_to(guide_site["cond_indep_stack"])
surrogate = jnp.sum(
guide_site["log_prob"] * stop_gradient(downstream_cost)
)
elbo += surrogate - stop_gradient(surrogate)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: in jax it is clearer to use elbo = elbo + (...)


return elbo

Expand Down
139 changes: 107 additions & 32 deletions test/infer/test_compute_downstream_costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from collections import defaultdict
from functools import partial
import math

from numpy.testing import assert_allclose
Expand All @@ -17,7 +19,9 @@
_compute_downstream_costs,
_get_plate_stacks,
_identify_dense_edges,
track_nonreparam,
)
from numpyro.ops.provenance import eval_provenance, get_provenance


def _brute_force_compute_downstream_costs(
Expand Down Expand Up @@ -72,6 +76,46 @@ def _brute_force_compute_downstream_costs(
return downstream_costs, downstream_guide_cost_nodes


def _provenance_compute_downstream_costs(model_trace, guide_trace, get_log_probs):
    """Compute downstream costs via provenance tracking, replicating the
    logic from ``TraceGraph_ELBO`` for comparison against the brute-force
    and dense-edge implementations.

    :param model_trace: execution trace of the model.
    :param guide_trace: execution trace of the guide.
    :param get_log_probs: zero-argument callable returning
        ``(model_log_probs, guide_log_probs)`` dicts; evaluated under
        provenance tracking.
    :return: pair ``(downstream_costs, downstream_guide_cost_nodes)`` mapping
        each non-reparameterizable site to its accumulated cost tensor and to
        the set of site names contributing to it.
    """
    # Find, for every cost term, the non-reparameterizable sites it depends on.
    model_deps, guide_deps = get_provenance(
        eval_provenance(track_nonreparam(get_log_probs))
    )

    # defaultdict(MultiFrameTensor) / defaultdict(set): the classes themselves
    # are zero-arg factories, so wrapping them in lambdas is unnecessary.
    downstream_costs = defaultdict(MultiFrameTensor)
    downstream_guide_cost_nodes = defaultdict(set)
    for name, site in model_trace.items():
        if site["type"] == "sample":
            # add the log_prob to each non-reparam sample site upstream
            for key in model_deps[name]:
                downstream_costs[key].add((site["cond_indep_stack"], site["log_prob"]))
                downstream_guide_cost_nodes[key] |= {name}
    for name, site in guide_trace.items():
        if site["type"] == "sample":
            # add the -log_prob to each non-reparam sample site upstream
            for key in guide_deps[name]:
                downstream_costs[key].add((site["cond_indep_stack"], -site["log_prob"]))
                downstream_guide_cost_nodes[key] |= {name}

    return downstream_costs, downstream_guide_cost_nodes


def _get_log_probs(_get_traces):
model_tr, guide_tr = _get_traces()
model_log_probs = {
name: site["log_prob"]
for name, site in model_tr.items()
if site["type"] == "sample"
}
guide_log_probs = {
name: site["log_prob"]
for name, site in guide_tr.items()
if site["type"] == "sample"
}
return model_log_probs, guide_log_probs


def big_model_guide(
include_obs=True,
include_single=False,
Expand Down Expand Up @@ -154,28 +198,35 @@ def big_model_guide(
def test_compute_downstream_costs_big_model_guide_pair(
include_inner_1, include_single, flip_c23, include_triple, include_z1
):
seeded_guide = handlers.seed(big_model_guide, rng_seed=0)
guide_trace = handlers.trace(seeded_guide).get_trace(
include_obs=False,
include_inner_1=include_inner_1,
include_single=include_single,
flip_c23=flip_c23,
include_triple=include_triple,
include_z1=include_z1,
)
model_trace = handlers.trace(handlers.replay(seeded_guide, guide_trace)).get_trace(
include_obs=True,
include_inner_1=include_inner_1,
include_single=include_single,
flip_c23=flip_c23,
include_triple=include_triple,
include_z1=include_z1,
)
def _get_traces():
seeded_guide = handlers.seed(big_model_guide, rng_seed=0)
guide_trace = handlers.trace(seeded_guide).get_trace(
include_obs=False,
include_inner_1=include_inner_1,
include_single=include_single,
flip_c23=flip_c23,
include_triple=include_triple,
include_z1=include_z1,
)
model_trace = handlers.trace(
handlers.replay(seeded_guide, guide_trace)
).get_trace(
include_obs=True,
include_inner_1=include_inner_1,
include_single=include_single,
flip_c23=flip_c23,
include_triple=include_triple,
include_z1=include_z1,
)

for trace in (model_trace, guide_trace):
for site in trace.values():
if site["type"] == "sample":
site["log_prob"] = site["fn"].log_prob(site["value"])

for trace in (model_trace, guide_trace):
for site in trace.values():
if site["type"] == "sample":
site["log_prob"] = site["fn"].log_prob(site["value"])
return model_trace, guide_trace

model_trace, guide_trace = _get_traces()
non_reparam_nodes = set(
name
for name, site in guide_trace.items()
Expand All @@ -191,8 +242,16 @@ def test_compute_downstream_costs_big_model_guide_pair(
model_trace, guide_trace, non_reparam_nodes
)

dc_provenance, dc_nodes_provenance = _provenance_compute_downstream_costs(
model_trace, guide_trace, partial(_get_log_probs, _get_traces)
)

assert dc_nodes == dc_nodes_brute

for name, nodes in dc_nodes_provenance.items():
assert nodes.issubset(dc_nodes[name])
assert nodes == {name}
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved

expected_nodes_full_model = {
"a1": {"c2", "a1", "d1", "c1", "obs", "b1", "d2", "c3", "b0"},
"d2": {"obs", "d2"},
Expand Down Expand Up @@ -329,18 +388,23 @@ def plate_reuse_model_guide(include_obs=True, dim1=3, dim2=2):
@pytest.mark.parametrize("dim1", [2, 5])
@pytest.mark.parametrize("dim2", [3, 4])
def test_compute_downstream_costs_plate_reuse(dim1, dim2):
seeded_guide = handlers.seed(plate_reuse_model_guide, rng_seed=0)
guide_trace = handlers.trace(seeded_guide).get_trace(
include_obs=False, dim1=dim1, dim2=dim2
)
model_trace = handlers.trace(handlers.replay(seeded_guide, guide_trace)).get_trace(
include_obs=True, dim1=dim1, dim2=dim2
)
def _get_traces():
seeded_guide = handlers.seed(plate_reuse_model_guide, rng_seed=0)
guide_trace = handlers.trace(seeded_guide).get_trace(
include_obs=False, dim1=dim1, dim2=dim2
)
model_trace = handlers.trace(
handlers.replay(seeded_guide, guide_trace)
).get_trace(include_obs=True, dim1=dim1, dim2=dim2)

for trace in (model_trace, guide_trace):
for site in trace.values():
if site["type"] == "sample":
site["log_prob"] = site["fn"].log_prob(site["value"])

for trace in (model_trace, guide_trace):
for site in trace.values():
if site["type"] == "sample":
site["log_prob"] = site["fn"].log_prob(site["value"])
return model_trace, guide_trace

model_trace, guide_trace = _get_traces()
non_reparam_nodes = set(
name
for name, site in guide_trace.items()
Expand All @@ -356,8 +420,19 @@ def test_compute_downstream_costs_plate_reuse(dim1, dim2):
model_trace, guide_trace, non_reparam_nodes
)

dc_provenance, dc_nodes_provenance = _provenance_compute_downstream_costs(
model_trace, guide_trace, partial(_get_log_probs, _get_traces)
)

assert dc_nodes == dc_nodes_brute

for name, nodes in dc_nodes_provenance.items():
assert nodes.issubset(dc_nodes[name])
if name == "c2":
assert nodes == {"c2", "obs"}
else:
assert nodes == {name}
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved

for k in dc:
assert guide_trace[k]["log_prob"].shape == dc[k].shape
assert_allclose(dc[k], dc_brute[k], rtol=1e-6)
Expand Down