diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py
index ff73dedde5..ee9f01b124 100644
--- a/pyro/infer/mcmc/hmc.py
+++ b/pyro/infer/mcmc/hmc.py
@@ -14,7 +14,7 @@
 from pyro.infer.mcmc.adaptation import WarmupAdapter
 from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
 from pyro.infer.mcmc.util import initialize_model
-from pyro.ops.integrator import potential_grad, velocity_verlet
+from pyro.ops.integrator import _EXCEPTION_HANDLERS, potential_grad, velocity_verlet
 from pyro.util import optional, torch_isnan
 
 
@@ -173,7 +173,16 @@ def _find_reasonable_step_size(self, z):
         # We are going to find a step_size which make accept_prob (Metropolis correction)
         # near the target_accept_prob. If accept_prob:=exp(-delta_energy) is small,
         # then we have to decrease step_size; otherwise, increase step_size.
-        potential_energy = self.potential_fn(z)
+        try:
+            potential_energy = self.potential_fn(z)
+        # handle exceptions as defined in the exception registry
+        except Exception as e:
+            if any(h(e) for h in _EXCEPTION_HANDLERS.values()):
+                # skip finding reasonable step size
+                return step_size
+            else:
+                raise e
+
         r, r_unscaled = self._sample_r(name="r_presample_0")
         energy_current = self._kinetic_energy(r_unscaled) + potential_energy
         # This is required so as to avoid issues with autograd when model
diff --git a/pyro/ops/integrator.py b/pyro/ops/integrator.py
index 750090248f..269519eff4 100644
--- a/pyro/ops/integrator.py
+++ b/pyro/ops/integrator.py
@@ -1,8 +1,15 @@
 # Copyright (c) 2017-2019 Uber Technologies, Inc.
 # SPDX-License-Identifier: Apache-2.0
 
+import warnings
+from typing import Callable, Dict
+
 from torch.autograd import grad
 
+# Registry for exception handlers that can be used to catch certain failures
+# during computation of `potential_fn` within `potential_grad`.
+_EXCEPTION_HANDLERS: Dict[str, Callable[[Exception], bool]] = {}
+
 
 def velocity_verlet(
     z, r, potential_fn, kinetic_grad, step_size, num_steps=1, z_grads=None
@@ -74,15 +81,49 @@ def potential_grad(potential_fn, z):
         node.requires_grad_(True)
     try:
         potential_energy = potential_fn(z)
-    # deal with singular matrices
-    except RuntimeError as e:
-        if "singular" in str(e) or "input is not positive-definite" in str(e):
+    # handle exceptions as defined in the exception registry
+    except Exception as e:
+        if any(h(e) for h in _EXCEPTION_HANDLERS.values()):
             grads = {k: v.new_zeros(v.shape) for k, v in z.items()}
             return grads, z_nodes[0].new_tensor(float("nan"))
         else:
             raise e
-
     grads = grad(potential_energy, z_nodes)
     for node in z_nodes:
         node.requires_grad_(False)
     return dict(zip(z_keys, grads)), potential_energy.detach()
+
+
+def register_exception_handler(
+    name: str, handler: Callable[[Exception], bool], warn_on_overwrite: bool = True
+) -> None:
+    """
+    Register an exception handler for handling (primarily numerical) errors
+    when evaluating the potential function.
+
+    :param name: name of the handler (must be unique).
+    :param handler: A callable mapping an Exception to a boolean. Exceptions
+        that evaluate to true in any of the handlers are handled in the computation
+        of the potential energy.
+    :param warn_on_overwrite: If True, warns when overwriting a handler already
+        registered under the provided name.
+    """
+    if name in _EXCEPTION_HANDLERS and warn_on_overwrite:
+        warnings.warn(
+            f"Overwriting Exception handler already registered under key {name}.",
+            RuntimeWarning,
+        )
+    _EXCEPTION_HANDLERS[name] = handler
+
+
+def _handle_torch_singular(exception: Exception) -> bool:
+    """Exception handler for errors thrown on (numerically) singular matrices."""
+    # the actual type of the exception thrown is torch._C._LinAlgError
+    if isinstance(exception, RuntimeError):
+        msg = str(exception)
+        return "singular" in msg or "input is not positive-definite" in msg
+    return False
+
+
+# Register default exception handler
+register_exception_handler("torch_singular", _handle_torch_singular)