diff --git a/numpyro/compat/infer.py b/numpyro/compat/infer.py index 5e3442cec..1b179b80c 100644 --- a/numpyro/compat/infer.py +++ b/numpyro/compat/infer.py @@ -132,7 +132,7 @@ def step(self, *args, rng_key=None, **kwargs): except TypeError as e: if 'not a valid JAX type' in str(e): raise TypeError('NumPyro backend requires args, kwargs to be arrays or tuples, ' - 'dicts of arrays.') + 'dicts of arrays.') from e else: raise e params = jit(super(SVI, self).get_params)(self.svi_state) diff --git a/numpyro/contrib/funsor/__init__.py b/numpyro/contrib/funsor/__init__.py index 647017717..e40844ad7 100644 --- a/numpyro/contrib/funsor/__init__.py +++ b/numpyro/contrib/funsor/__init__.py @@ -3,11 +3,11 @@ try: import funsor -except ImportError: +except ImportError as e: raise ImportError("Looking like you want to do inference for models with " "discrete latent variables. This is an experimental feature. " "You need to install `funsor` to be able to use this feature. " - "It can be installed with `pip install funsor`.") + "It can be installed with `pip install funsor`.") from e from numpyro.contrib.funsor.enum_messenger import enum, infer_config, markov, plate, to_data, to_funsor, trace from numpyro.contrib.funsor.infer_util import config_enumerate, log_density, plate_to_enum_plate diff --git a/numpyro/contrib/module.py b/numpyro/contrib/module.py index e0b795718..9cb7c4258 100644 --- a/numpyro/contrib/module.py +++ b/numpyro/contrib/module.py @@ -35,11 +35,11 @@ def flax_module(name, nn_module, *, input_shape=None): """ try: import flax # noqa: F401 - except ImportError: + except ImportError as e: raise ImportError("Looking like you want to use flax to declare " "nn modules. This is an experimental feature. " "You need to install `flax` to be able to use this feature. " - "It can be installed with `pip install flax`.") + "It can be installed with `pip install flax`.") from e module_key = name + '$params' nn_params = numpyro.param(module_key) if nn_params is None: @@ -71,11 +71,11 @@ def haiku_module(name, nn_module, *, input_shape=None): """ try: import haiku # noqa: F401 - except ImportError: + except ImportError as e: raise ImportError("Looking like you want to use haiku to declare " "nn modules. This is an experimental feature. " "You need to install `haiku` to be able to use this feature. " - "It can be installed with `pip install dm-haiku`.") + "It can be installed with `pip install dm-haiku`.") from e module_key = name + '$params' nn_params = numpyro.param(module_key) diff --git a/numpyro/contrib/tfp/__init__.py b/numpyro/contrib/tfp/__init__.py index ef67892d9..87e269656 100644 --- a/numpyro/contrib/tfp/__init__.py +++ b/numpyro/contrib/tfp/__init__.py @@ -3,7 +3,7 @@ try: import tensorflow_probability.substrates.jax as tfp # noqa: F401 -except ImportError: +except ImportError as e: raise ImportError("Looking like your installed tensorflow_probability does not" " support JAX backend. You might try to install the nightly" - " version with: `pip install tfp-nightly`") + " version with: `pip install tfp-nightly`") from e diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 4541cca9a..4e7be1010 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -555,8 +555,8 @@ def register(self, constraint, factory=None): def __call__(self, constraint): try: factory = self._registry[type(constraint)] - except KeyError: - raise NotImplementedError + except KeyError as e: + raise NotImplementedError from e return factory(constraint) diff --git a/numpyro/infer/reparam.py b/numpyro/infer/reparam.py index 4d94562dd..c17d80289 100644 --- a/numpyro/infer/reparam.py +++ b/numpyro/infer/reparam.py @@ -182,10 +182,10 @@ def __init__(self, guide, params): self.params = params try: self.transform = self.guide.get_transform(params) - except (NotImplementedError, TypeError): + except (NotImplementedError, TypeError) as e: raise ValueError("NeuTraReparam only supports guides that implement " "`get_transform` method that does not depend on the " - "model's `*args, **kwargs`") + "model's `*args, **kwargs`") from e self._x_unconstrained = {} def _reparam_config(self, site):