Allow registering of custom exception handlers for potential_fn computations #3168

Merged
merged 8 commits on Jan 3, 2023
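In short, this PR adds a small handler registry to pyro.ops.integrator plus a public register_exception_handler function, so downstream code can declare which exceptions raised by potential_fn should be treated as recoverable. A minimal sketch of the intended usage (the handler name and matching logic below are illustrative, not part of this PR):

from pyro.ops.integrator import register_exception_handler


def _handle_my_cholesky_error(exception: Exception) -> bool:
    # Hypothetical handler: treat a Cholesky failure as recoverable so that
    # HMC/NUTS rejects the offending sample instead of crashing.
    return isinstance(exception, RuntimeError) and "cholesky" in str(exception).lower()


register_exception_handler("my_cholesky_error", _handle_my_cholesky_error)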
13 changes: 11 additions & 2 deletions pyro/infer/mcmc/hmc.py
@@ -14,7 +14,7 @@
from pyro.infer.mcmc.adaptation import WarmupAdapter
from pyro.infer.mcmc.mcmc_kernel import MCMCKernel
from pyro.infer.mcmc.util import initialize_model
from pyro.ops.integrator import potential_grad, velocity_verlet
from pyro.ops.integrator import _EXCEPTION_HANDLERS, potential_grad, velocity_verlet
from pyro.util import optional, torch_isnan


@@ -173,7 +173,16 @@ def _find_reasonable_step_size(self, z):
        # We are going to find a step_size which make accept_prob (Metropolis correction)
        # near the target_accept_prob. If accept_prob:=exp(-delta_energy) is small,
        # then we have to decrease step_size; otherwise, increase step_size.
        potential_energy = self.potential_fn(z)
        try:
            potential_energy = self.potential_fn(z)
        # handle exceptions as defined in the exception registry
        except Exception as e:
            if any(h(e) for h in _EXCEPTION_HANDLERS.values()):
                # skip finding reasonable step size
                return step_size
            else:
                raise e

        r, r_unscaled = self._sample_r(name="r_presample_0")
        energy_current = self._kinetic_energy(r_unscaled) + potential_energy
        # This is required so as to avoid issues with autograd when model
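For readers skimming the diff: the guard above is a membership test over the handler registry introduced in pyro/ops/integrator.py below. A minimal sketch of that dispatch, assuming only the default handler is registered (it imports the private registry purely for illustration):

from pyro.ops.integrator import _EXCEPTION_HANDLERS

# With the default "torch_singular" handler registered, a positive-definiteness
# failure is considered recoverable.
e = RuntimeError("input is not positive-definite")
recoverable = any(handler(e) for handler in _EXCEPTION_HANDLERS.values())
# recoverable is True, so _find_reasonable_step_size returns the current
# step_size instead of propagating the error.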
49 changes: 45 additions & 4 deletions pyro/ops/integrator.py
@@ -1,8 +1,15 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import warnings
from typing import Callable, Dict

from torch.autograd import grad

# Registry for exception handlers that can be used to catch certain failures
# during computation of `potential_fn` within `potential_grad`.
_EXCEPTION_HANDLERS: Dict[str, Callable[[Exception], bool]] = {}


def velocity_verlet(
z, r, potential_fn, kinetic_grad, step_size, num_steps=1, z_grads=None
@@ -74,15 +81,49 @@ def potential_grad(potential_fn, z):
        node.requires_grad_(True)
    try:
        potential_energy = potential_fn(z)
    # deal with singular matrices
    except RuntimeError as e:
        if "singular" in str(e) or "input is not positive-definite" in str(e):
    # handle exceptions as defined in the exception registry
    except Exception as e:
        if any(h(e) for h in _EXCEPTION_HANDLERS.values()):
            grads = {k: v.new_zeros(v.shape) for k, v in z.items()}
            return grads, z_nodes[0].new_tensor(float("nan"))
        else:
            raise e

    grads = grad(potential_energy, z_nodes)
    for node in z_nodes:
        node.requires_grad_(False)
    return dict(zip(z_keys, grads)), potential_energy.detach()


def register_exception_handler(
    name: str, handler: Callable[[Exception], bool], warn_on_overwrite: bool = True
) -> None:
    """
    Register an exception handler for handling (primarily numerical) errors
    when evaluating the potential function.

    :param name: name of the handler (must be unique).
    :param handler: A callable mapping an Exception to a boolean. Exceptions
        that evaluate to true in any of the handlers are handled in the computation
        of the potential energy.
    :param warn_on_overwrite: If True, warns when overwriting a handler already
        registered under the provided name.
    """
    if name in _EXCEPTION_HANDLERS and warn_on_overwrite:
        warnings.warn(
            f"Overwriting Exception handler already registered under key {name}.",
            RuntimeWarning,
        )
    _EXCEPTION_HANDLERS[name] = handler


def _handle_torch_singular(exception: Exception) -> bool:
    """Exception handler for errors thrown on (numerically) singular matrices."""
    # the actual type of the exception thrown is torch._C._LinAlgError
    if isinstance(exception, RuntimeError):
        msg = str(exception)
        return "singular" in msg or "input is not positive-definite" in msg
    return False


# Register default exception handler
register_exception_handler("torch_singular", _handle_torch_singular)
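
To illustrate (not part of the PR), a minimal sketch of the behaviour the default handler enables in potential_grad; the failing potential function below is hypothetical and only serves to raise a message matched by _handle_torch_singular:

import torch

from pyro.ops.integrator import potential_grad


def failing_potential_fn(z):
    # Hypothetical potential that fails with a message the default
    # "torch_singular" handler recognises.
    raise RuntimeError("input is not positive-definite")


z = {"x": torch.zeros(3)}
grads, energy = potential_grad(failing_potential_fn, z)
# Instead of raising, potential_grad returns zero gradients and NaN energy,
# which lets the MCMC kernel reject the sample and continue.
assert torch.isnan(energy)
assert torch.equal(grads["x"], torch.zeros(3))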