Skip to content

Commit

Permalink
Display correct error message and cleanup _PYRO_STACK when error happ…
Browse files Browse the repository at this point in the history
…ens (#818)

* remove stacks when error happens

* remove rng_key raise

* add regression test
  • Loading branch information
fehiepsi authored Nov 19, 2020
1 parent 490e5eb commit 970dc48
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
4 changes: 4 additions & 0 deletions numpyro/contrib/funsor/infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,8 @@ def model(*args, **kwargs):
funsor.ops.logaddexp, funsor.ops.add, log_factors,
eliminate=sum_vars | prod_vars, plates=prod_vars)
result = funsor.optimizer.apply_optimizer(lazy_result)
if len(result.inputs) > 0:
raise ValueError("Expected the joint log density is a scalar, but got {}. "
"There seems to be something wrong at the following sites: {}."
.format(result.data.shape, {k.split("__BOUND")[0] for k in result.inputs}))
return result.data, model_trace
9 changes: 7 additions & 2 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,18 +278,23 @@ def __enter__(self):
COERCIONS.append(self._coerce)
return super().__enter__()

def __exit__(self, *args, **kwargs):
def __exit__(self, exc_type, exc_value, traceback):
import funsor

_coerce = COERCIONS.pop()
assert _coerce is self._coerce
super().__exit__(*args, **kwargs)
super().__exit__(exc_type, exc_value, traceback)

if exc_type is not None:
return

# Convert delayed statements to pyro.factor()
reduced_vars = []
log_prob_terms = []
plates = frozenset()
for name, site in self.trace.items():
if site["type"] != "sample":
continue
if not site["is_observed"]:
reduced_vars.append(name)
dim_to_name = {f.dim: f.name for f in site["cond_indep_stack"]}
Expand Down
18 changes: 15 additions & 3 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,21 @@ def __init__(self, fn=None):
def __enter__(self):
_PYRO_STACK.append(self)

def __exit__(self, *args, **kwargs):
assert _PYRO_STACK[-1] is self
_PYRO_STACK.pop()
def __exit__(self, exc_type, exc_value, traceback):
if exc_type is None:
assert _PYRO_STACK[-1] is self
_PYRO_STACK.pop()
else:
# NB: this mimics Pyro exception handling
# the wrapped function or block raised an exception
# handler exception handling:
# when the callee or enclosed block raises an exception,
# find this handler's position in the stack,
# then remove it and everything below it in the stack.
if self in _PYRO_STACK:
loc = _PYRO_STACK.index(self)
for i in range(loc, len(_PYRO_STACK)):
_PYRO_STACK.pop()

def process_message(self, msg):
pass
Expand Down
8 changes: 7 additions & 1 deletion test/contrib/test_funsor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from numpyro.contrib.indexing import Vindex
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.primitives import _PYRO_STACK


def test_gaussian_mixture_model():
Expand Down Expand Up @@ -432,7 +433,7 @@ def transition_fn(name, probs, locs, x, y):
assert_allclose(actual_log_joint, expected_log_joint)


def test_missing_plate():
def test_missing_plate(monkeypatch):
K, N = 3, 1000

def gmm(data):
Expand All @@ -453,3 +454,8 @@ def gmm(data):
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=500)
with pytest.raises(AssertionError, match="Missing plate statement"):
mcmc.run(random.PRNGKey(2), data)

monkeypatch.setattr(numpyro.infer.util, "_validate_model", lambda model_trace: None)
with pytest.raises(Exception):
mcmc.run(random.PRNGKey(2), data)
assert len(_PYRO_STACK) == 0

0 comments on commit 970dc48

Please sign in to comment.