Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Display correct error message and cleanup _PYRO_STACK when error happens #818

Merged
merged 3 commits into from
Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions numpyro/contrib/funsor/infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,8 @@ def model(*args, **kwargs):
funsor.ops.logaddexp, funsor.ops.add, log_factors,
eliminate=sum_vars | prod_vars, plates=prod_vars)
result = funsor.optimizer.apply_optimizer(lazy_result)
if len(result.inputs) > 0:
raise ValueError("Expected the joint log density is a scalar, but got {}. "
"There seems to be something wrong at the following sites: {}."
.format(result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}))
return result.data, model_trace
9 changes: 7 additions & 2 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,18 +278,23 @@ def __enter__(self):
COERCIONS.append(self._coerce)
return super().__enter__()

def __exit__(self, *args, **kwargs):
def __exit__(self, exc_type, exc_value, traceback):
import funsor

_coerce = COERCIONS.pop()
assert _coerce is self._coerce
super().__exit__(*args, **kwargs)
super().__exit__(exc_type, exc_value, traceback)

if exc_type is not None:
return

# Convert delayed statements to pyro.factor()
reduced_vars = []
log_prob_terms = []
plates = frozenset()
for name, site in self.trace.items():
if site["type"] != "sample":
continue
if not site["is_observed"]:
reduced_vars.append(name)
dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
Expand Down
18 changes: 15 additions & 3 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,21 @@ def __init__(self, fn=None):
def __enter__(self):
_PYRO_STACK.append(self)

def __exit__(self, *args, **kwargs):
assert _PYRO_STACK[-1] is self
_PYRO_STACK.pop()
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
assert _PYRO_STACK[-1] is self
_PYRO_STACK.pop()
else:
# NB: this mimics Pyro exception handling
# the wrapped function or block raised an exception
# handler exception handling:
# when the callee or enclosed block raises an exception,
# find this handler's position in the stack,
# then remove it and everything below it in the stack.
if self in _PYRO_STACK:
loc = _PYRO_STACK.index(self)
for i in range(loc, len(_PYRO_STACK)):
_PYRO_STACK.pop()

def process_message(self, msg):
pass
Expand Down
8 changes: 7 additions & 1 deletion test/contrib/test_funsor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from numpyro.contrib.indexing import Vindex
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.primitives import _PYRO_STACK


def test_gaussian_mixture_model():
Expand Down Expand Up @@ -432,7 +433,7 @@ def transition_fn(name, probs, locs, x, y):
assert_allclose(actual_log_joint, expected_log_joint)


def test_missing_plate():
def test_missing_plate(monkeypatch):
K, N = 3, 1000

def gmm(data):
Expand All @@ -453,3 +454,8 @@ def gmm(data):
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
with pytest.raises(AssertionError, match="Missing plate statement"):
mcmc.run(random.PRNGKey(2), data)

monkeypatch.setattr(numpyro.infer.util, "_validate_model", lambda model_trace: None)
with pytest.raises(Exception):
mcmc.run(random.PRNGKey(2), data)
assert len(_PYRO_STACK) == 0
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eb8680 Before this PR, assert len(_PYRO_STACK) == 0 fails (7 vs 0).