From fb580d1326bd01bcdeb2022b2ee75d9a9c1e3d71 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Wed, 28 Dec 2022 21:44:53 -0800 Subject: [PATCH 1/8] [RFC] Allow registering excpetion handlers for potential function computations In some cases evaluation the potential funciton may result in numerical issues. Currently the code hard-codes the handling of a RuntimeError raised when matrices are (numerically) singular. This PR adds the ability to register custom exception handlers. This allows other code depending on pyro to register custom exception handlers without having to modify core pyro code. There are some other places in which `potential_fn` is called that could benefit from being guarded by these handlers (one is `HMC._find_reasonable_step_size`). I'm not sure what the right thing to do there is when encountering numerical isssues, but happy to add this in as needed. --- pyro/ops/integrator.py | 48 ++++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/pyro/ops/integrator.py b/pyro/ops/integrator.py index 750090248f..b0a4ad4d8e 100644 --- a/pyro/ops/integrator.py +++ b/pyro/ops/integrator.py @@ -1,8 +1,14 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +from typing import Callable + from torch.autograd import grad +# Registry for exception handlers that can be used to catch certain failures +# during computation of `potential_fn` within `potential_grad`. +_EXCEPTION_HANDLERS = {} + def velocity_verlet( z, r, potential_fn, kinetic_grad, step_size, num_steps=1, z_grads=None @@ -74,15 +80,49 @@ def potential_grad(potential_fn, z): node.requires_grad_(True) try: potential_energy = potential_fn(z) - # deal with singular matrices - except RuntimeError as e: - if "singular" in str(e) or "input is not positive-definite" in str(e): + # handle exceptions as defined in the exception registry + except Exception as e: + if any(h(e) for h in _EXCEPTION_HANDLERS.values()): grads = {k: v.new_zeros(v.shape) for k, v in z.items()} return grads, z_nodes[0].new_tensor(float("nan")) else: raise e - grads = grad(potential_energy, z_nodes) for node in z_nodes: node.requires_grad_(False) return dict(zip(z_keys, grads)), potential_energy.detach() + + +def register_exception_handler( + name: str, handler: Callable[[Exception], bool], overwrite: bool = False +) -> None: + """ + Register an exception handler for handling (primarily numerical) errors + when evaluating the potential function. + + :param name: name of the handler (must be unique). + :param handler: A callable mapping an exception to a boolean. Exceptions + that evaluate to true in any of the handlers are handled in teh computation + of the potential energy. + :param overwrite: If True, overwrite handlers already registerd under the + provided name. + """ + if name in _EXCEPTION_HANDLERS and not overwrite: + raise RuntimeError( + f"Exception handler already registered under key {name}. " + "Use `overwrite=True` to force overwriting the handler." + ) + _EXCEPTION_HANDLERS[name] = handler + + +def _handle_torch_singular(exception: Exception) -> bool: + """Exception handler for errors thrown on (numerically) singular matrices.""" + if type(exception) == RuntimeError: + return "singular" in str(exception) or "input is not positive-definite" in str( + exception + ) + return False + + +# Register default exception handler +register_exception_handler("torch_singular", _handle_torch_singular) From 35dabe0335505f2d70fba5b9992ed568c8f0155e Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Thu, 29 Dec 2022 08:00:28 -0800 Subject: [PATCH 2/8] Fix typing lint, typos. --- pyro/ops/integrator.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/pyro/ops/integrator.py b/pyro/ops/integrator.py index b0a4ad4d8e..e595f9537a 100644 --- a/pyro/ops/integrator.py +++ b/pyro/ops/integrator.py @@ -7,7 +7,7 @@ # Registry for exception handlers that can be used to catch certain failures # during computation of `potential_fn` within `potential_grad`. -_EXCEPTION_HANDLERS = {} +_EXCEPTION_HANDLERS: Dict[str, Callable[[Exception], bool]] = {} def velocity_verlet( @@ -102,9 +102,9 @@ def register_exception_handler( :param name: name of the handler (must be unique). :param handler: A callable mapping an exception to a boolean. Exceptions - that evaluate to true in any of the handlers are handled in teh computation + that evaluate to true in any of the handlers are handled in the computation of the potential energy. - :param overwrite: If True, overwrite handlers already registerd under the + :param overwrite: If True, overwrite handlers already registered under the provided name. """ if name in _EXCEPTION_HANDLERS and not overwrite: @@ -118,9 +118,8 @@ def register_exception_handler( def _handle_torch_singular(exception: Exception) -> bool: """Exception handler for errors thrown on (numerically) singular matrices.""" if type(exception) == RuntimeError: - return "singular" in str(exception) or "input is not positive-definite" in str( - exception - ) + msg = str(exception) + return "singular" in msg or "input is not positive-definite" in msg return False From 7158583594153d6e4b22ff8793f8c2e911fb9b44 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Thu, 29 Dec 2022 10:02:54 -0800 Subject: [PATCH 3/8] Warn instead of raise, fix typing import. --- pyro/ops/integrator.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/pyro/ops/integrator.py b/pyro/ops/integrator.py index e595f9537a..15368eec3a 100644 --- a/pyro/ops/integrator.py +++ b/pyro/ops/integrator.py @@ -1,7 +1,9 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +import warnings + +from typing import Callable, Dict from torch.autograd import grad @@ -94,23 +96,23 @@ def potential_grad(potential_fn, z): def register_exception_handler( - name: str, handler: Callable[[Exception], bool], overwrite: bool = False + name: str, handler: Callable[[Exception], bool], warn_on_overwrite: bool = True ) -> None: """ Register an exception handler for handling (primarily numerical) errors when evaluating the potential function. :param name: name of the handler (must be unique). - :param handler: A callable mapping an exception to a boolean. Exceptions + :param handler: A callable mapping an Exception to a boolean. Exceptions that evaluate to true in any of the handlers are handled in the computation of the potential energy. - :param overwrite: If True, overwrite handlers already registered under the - provided name. + :param warn_on_overwrite: If True, warns when overwriting a handler already + registered under the provided name. """ - if name in _EXCEPTION_HANDLERS and not overwrite: - raise RuntimeError( - f"Exception handler already registered under key {name}. " - "Use `overwrite=True` to force overwriting the handler." + if name in _EXCEPTION_HANDLERS and warn_on_overwrite: + warnings.warn( + f"Overwriting Exception handler already registered under key {name}.", + RuntimeWarning, ) _EXCEPTION_HANDLERS[name] = handler From 4c0beb51677daa820bc4c4e95189e25f88c1b5c8 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Thu, 29 Dec 2022 11:55:16 -0800 Subject: [PATCH 4/8] Make isort happy (hopefully) --- pyro/ops/integrator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyro/ops/integrator.py b/pyro/ops/integrator.py index 15368eec3a..14d31ebf9f 100644 --- a/pyro/ops/integrator.py +++ b/pyro/ops/integrator.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 import warnings - from typing import Callable, Dict from torch.autograd import grad From 9b1b4294f547cec8ba5d6aaf98d339087ff38196 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Thu, 29 Dec 2022 14:26:44 -0800 Subject: [PATCH 5/8] Check for instance rather than type equality in _handle_torch_singular --- pyro/ops/integrator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyro/ops/integrator.py b/pyro/ops/integrator.py index 14d31ebf9f..269519eff4 100644 --- a/pyro/ops/integrator.py +++ b/pyro/ops/integrator.py @@ -118,7 +118,8 @@ def register_exception_handler( def _handle_torch_singular(exception: Exception) -> bool: """Exception handler for errors thrown on (numerically) singular matrices.""" - if type(exception) == RuntimeError: + # the actual type of the exception thrown is torch._C._LinAlgError + if isinstance(exception, RuntimeError): msg = str(exception) return "singular" in msg or "input is not positive-definite" in msg return False From f685eb07c9134b6851a7b6ac509b8585c17272c6 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Thu, 29 Dec 2022 14:40:00 -0800 Subject: [PATCH 6/8] Handle numerical issues also in HMC._find_reasonable_step_size --- pyro/infer/mcmc/hmc.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index ff73dedde5..dbda613ba3 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -14,7 +14,7 @@ from pyro.infer.mcmc.adaptation import WarmupAdapter from pyro.infer.mcmc.mcmc_kernel import MCMCKernel from pyro.infer.mcmc.util import initialize_model -from pyro.ops.integrator import potential_grad, velocity_verlet +from pyro.ops.integrator import potential_grad, velocity_verlet, _EXCEPTION_HANDLERS from pyro.util import optional, torch_isnan @@ -173,7 +173,16 @@ def _find_reasonable_step_size(self, z): # We are going to find a step_size which make accept_prob (Metropolis correction) # near the target_accept_prob. If accept_prob:=exp(-delta_energy) is small, # then we have to decrease step_size; otherwise, increase step_size. - potential_energy = self.potential_fn(z) + try: + potential_energy = self.potential_fn(z) + # handle exceptions as defined in the exception registry + except Exception as e: + if any(h(e) for h in _EXCEPTION_HANDLERS.values()): + # skip finding reasonable step size + return step_size + else: + raise e + r, r_unscaled = self._sample_r(name="r_presample_0") energy_current = self._kinetic_energy(r_unscaled) + potential_energy # This is required so as to avoid issues with autograd when model From 28dcfe4fdec7ce2a9b7b8303219212d50dee2f61 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Thu, 29 Dec 2022 14:51:24 -0800 Subject: [PATCH 7/8] isort once more --- pyro/infer/mcmc/hmc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index dbda613ba3..741bbf0f38 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -14,7 +14,8 @@ from pyro.infer.mcmc.adaptation import WarmupAdapter from pyro.infer.mcmc.mcmc_kernel import MCMCKernel from pyro.infer.mcmc.util import initialize_model -from pyro.ops.integrator import potential_grad, velocity_verlet, _EXCEPTION_HANDLERS +from pyro.ops.integrator import (_EXCEPTION_HANDLERS, potential_grad, + velocity_verlet) from pyro.util import optional, torch_isnan From 0c5452fefed1456a301e9a0ca1329dedea607d55 Mon Sep 17 00:00:00 2001 From: Max Balandat Date: Mon, 2 Jan 2023 16:21:34 -0800 Subject: [PATCH 8/8] Fix black format --- pyro/infer/mcmc/hmc.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyro/infer/mcmc/hmc.py b/pyro/infer/mcmc/hmc.py index 741bbf0f38..ee9f01b124 100644 --- a/pyro/infer/mcmc/hmc.py +++ b/pyro/infer/mcmc/hmc.py @@ -14,8 +14,7 @@ from pyro.infer.mcmc.adaptation import WarmupAdapter from pyro.infer.mcmc.mcmc_kernel import MCMCKernel from pyro.infer.mcmc.util import initialize_model -from pyro.ops.integrator import (_EXCEPTION_HANDLERS, potential_grad, - velocity_verlet) +from pyro.ops.integrator import _EXCEPTION_HANDLERS, potential_grad, velocity_verlet from pyro.util import optional, torch_isnan