Properly handle contraction of guide plates in TraceEnum_ELBO #1537

Merged
merged 3 commits (Feb 9, 2023)
Changes from 2 commits
76 changes: 33 additions & 43 deletions numpyro/infer/elbo.py
@@ -2,7 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, defaultdict
- from functools import partial, reduce
+ from functools import partial
from operator import itemgetter
import warnings

@@ -871,15 +871,8 @@ def __init__(self, num_particles=1, max_plate_nesting=float("inf")):

def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
def single_particle_elbo(rng_key):
- from opt_einsum import shared_intermediates
-
import funsor
- from funsor.cnf import _eager_contract_tensors
- from numpyro.contrib.funsor import to_data, to_funsor
-
- logsumexp_backend = "funsor.einsum.numpy_log"
- with shared_intermediates() as cache:  # create a cache
-     pass
+ from numpyro.contrib.funsor import to_data

model_seed, guide_seed = random.split(rng_key)

@@ -935,7 +928,6 @@ def single_particle_elbo(rng_key):
cost = model_trace[name]["log_prob"]
scale = model_trace[name]["scale"]
deps = model_deps[name]
- dice_factors = [guide_trace[key]["log_measure"] for key in deps]
else:
# compute contracted cost term
group_factors = tuple(
@@ -952,13 +944,16 @@
*(frozenset(f.inputs) & group_plates for f in group_factors)
)
elim_plates = group_plates - outermost_plates
- cost = funsor.sum_product.sum_product(
-     funsor.ops.logaddexp,
-     funsor.ops.add,
-     group_factors,
-     plates=group_plates,
-     eliminate=group_sum_vars | elim_plates,
- )
+ with funsor.interpretations.normalize:
+     cost = funsor.sum_product.sum_product(
+         funsor.ops.logaddexp,
+         funsor.ops.add,
+         group_factors,
+         plates=group_plates,
+         eliminate=group_sum_vars | elim_plates,
+     )
+ # TODO: add memoization
+ cost = funsor.optimizer.apply_optimizer(cost)
# incorporate the effects of subsampling and handlers.scale through a common scale factor
scales_set = set()
for name in group_names | group_sum_vars:
@@ -992,43 +987,38 @@ def single_particle_elbo(rng_key):
f"model enumeration sites upstream of guide site '{key}' in plate('{plate}')."
"Try converting some model enumeration sites to guide enumeration sites."
)
- # combine dice factors
- dice_factors = [
-     guide_trace[key]["log_measure"].reduce(
-         funsor.ops.add,
-         frozenset(guide_trace[key]["log_measure"].inputs)
-         & elim_plates,
-     )
-     for key in deps
- ]
- cost_terms.append((cost, scale, dice_factors))
+ cost_terms.append((cost, scale, deps))

for name, deps in guide_deps.items():
# -logq cost term
cost = -guide_trace[name]["log_prob"]
scale = guide_trace[name]["scale"]
- dice_factors = [guide_trace[key]["log_measure"] for key in deps]
- cost_terms.append((cost, scale, dice_factors))
+ cost_terms.append((cost, scale, deps))

# compute elbo
elbo = 0.0
- for cost, scale, dice_factors in cost_terms:
-     if dice_factors:
-         reduced_vars = (
-             frozenset().union(*[f.input_vars for f in dice_factors])
-             - cost.input_vars
-         )
+ for cost, scale, deps in cost_terms:
+     if deps:
+         dice_factors = tuple(
+             guide_trace[key]["log_measure"] for key in deps
+         )
-         if reduced_vars:
-             # use opt_einsum to reduce vars not present in the cost term
-             with shared_intermediates(cache):
-                 dice_factor = _eager_contract_tensors(
-                     reduced_vars, dice_factors, backend=logsumexp_backend
-                 )
-         else:
-             dice_factor = reduce(lambda a, b: a + b, dice_factors)
+         dice_factor_vars = frozenset().union(
+             *[f.inputs for f in dice_factors]
+         )
+         cost_vars = frozenset(cost.inputs)
+         with funsor.interpretations.normalize:
+             dice_factor = funsor.sum_product.sum_product(
+                 funsor.ops.logaddexp,
+                 funsor.ops.add,
+                 dice_factors,
+                 plates=(dice_factor_vars | cost_vars) - model_vars,
+                 eliminate=dice_factor_vars - cost_vars,
+             )
+         # TODO: add memoization
+         dice_factor = funsor.optimizer.apply_optimizer(dice_factor)
cost = cost * funsor.ops.exp(dice_factor)
if (scale is not None) and (not is_identically_one(scale)):
-     cost = cost * to_funsor(scale)
+     cost = cost * scale

elbo = elbo + cost.reduce(funsor.ops.add)

41 changes: 41 additions & 0 deletions test/contrib/test_enum_elbo.py
@@ -2472,3 +2472,44 @@

assert_equal(actual_loss, expected_loss, prec=1e-3)
assert_equal(actual_grads, expected_grads, prec=1e-5)


def test_guide_plate_contraction():
def model(params):
with pyro.plate("a_axis", size=2):
a = pyro.sample("a", dist.Categorical(jnp.array([0.2, 0.8])))
pyro.sample("b", dist.Normal(jnp.sum(a), 1.0), obs=1)
Member

It seems to me that this violates this restriction: https://pyro.ai/examples/enumeration.html#Restriction-2:-no-downstream-coupling. Could you elaborate on why we can enumerate here?

Member Author (@ordabayevy), Feb 9, 2023

It's not being enumerated here; it's being used as a non-reparameterizable site. I should have used another distribution, like Poisson, to make that less confusing :)

But the point is that b depends on the non-reparameterizable site a, which lives in the a_axis plate. The dice factor needs to be product-contracted to eliminate the extra a_axis plate before it multiplies the cost term for site b. Instead, a_axis was being passed to _eager_contract_tensors as reduced_vars and sum-contracted with logsumexp. Hope this clarifies it.
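
A minimal numerical sketch of the distinction drawn in the comment above (plain JAX; the plate size and the probabilities are made up for illustration, not taken from the PR):

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical per-element log measure of the non-reparameterizable guide
# site "a" across the a_axis plate (size 2); the numbers are made up.
log_q_a = jnp.log(jnp.array([0.4, 0.7]))

# Product contraction: the cost for "b" carries no a_axis dimension, so the
# extra plate dimension is removed by multiplying probabilities across the
# plate, i.e. adding in log space. This is what the dice factor needs.
dice_add = jnp.sum(log_q_a)            # log(0.4 * 0.7)

# Sum contraction: logsumexp over the plate mixes the plate elements instead
# of multiplying them, which is what the old opt_einsum path effectively did.
dice_logsumexp = logsumexp(log_q_a)    # log(0.4 + 0.7)

print(dice_add, dice_logsumexp)        # the two disagree, hence wrong gradients
```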

Member

I see. So we are using TraceEnum_ELBO but enumeration is disabled for those cases. Could you add a warning for this?

Member Author

Do you mean in general or for this test?

Member

I mean in general. I think using TraceEnum_ELBO without enumeration is confusing. Maybe raise an error if we can't enumerate sites marked with infer={"enumerate": "parallel"}?
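
Purely as an illustration of this suggestion (not part of the PR), such a check might look roughly like the following, assuming numpyro's usual trace-site layout:

```python
def check_parallel_enumeration(model_trace):
    """Illustrative only, not part of this PR: raise if a sample site asks for
    parallel enumeration but its distribution has no enumerate support."""
    for name, site in model_trace.items():
        if site["type"] != "sample" or site.get("is_observed", False):
            continue
        wants_enum = site.get("infer", {}).get("enumerate") == "parallel"
        if wants_enum and not site["fn"].has_enumerate_support:
            raise ValueError(
                f"Site '{name}' requests parallel enumeration, but "
                f"{type(site['fn']).__name__} has no enumerate support."
            )
```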

Member Author

I changed the distribution to Poisson in the test since it can be any non-reparameterizable distribution. I can open another issue/PR for the enumeration configuration since that is a separate issue.

Member

sounds reasonable to me, thanks!
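
For reference, a sketch of what the Poisson variant mentioned above might look like (it follows the test's pyro alias for numpyro; the rate 3.0 is an arbitrary illustrative value, not taken from the PR):

```python
def model(params):
    with pyro.plate("a_axis", size=2):
        # Poisson is non-reparameterizable like Categorical, but it cannot be
        # enumerated, so there is no ambiguity about enumeration in the test.
        a = pyro.sample("a", dist.Poisson(3.0))
    pyro.sample("b", dist.Normal(jnp.sum(a), 1.0), obs=1)
```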


def guide(params):
probs_a = pyro.param(
"probs_a", params["probs_a"], constraint=constraints.simplex
)
with pyro.plate("a_axis", size=2):
pyro.sample("a", dist.Categorical(probs_a))

params = {
"probs_a": jnp.array([[0.4, 0.6], [0.7, 0.3]]),
}
transform = dist.biject_to(dist.constraints.simplex)
params_raw = jax.tree_util.tree_map(transform.inv, params)

# TraceGraph_ELBO grads averaged over num_particles
elbo = infer.TraceGraph_ELBO(num_particles=50_000)

def graph_loss_fn(params_raw):
params = jax.tree_util.tree_map(transform, params_raw)
return elbo.loss(random.PRNGKey(0), {}, model, guide, params)

graph_loss, graph_grads = jax.value_and_grad(graph_loss_fn)(params_raw)

# TraceEnum_ELBO grads averaged over num_particles (no enumeration)
elbo = infer.TraceEnum_ELBO(num_particles=50_000)

def enum_loss_fn(params_raw):
params = jax.tree_util.tree_map(transform, params_raw)
return elbo.loss(random.PRNGKey(0), {}, model, guide, params)

enum_loss, enum_grads = jax.value_and_grad(enum_loss_fn)(params_raw)

assert_equal(enum_loss, graph_loss, prec=1e-3)
assert_equal(enum_grads, graph_grads, prec=3e-3)