Skip to content

Commit

Permalink
add stacklevel to warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jan 1, 2022
1 parent 3347583 commit 0b316d4
Show file tree
Hide file tree
Showing 17 changed files with 99 additions and 30 deletions.
2 changes: 2 additions & 0 deletions numpyro/compat/pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings

from numpyro.compat.util import UnsupportedAPIWarning
from numpyro.util import find_stack_level

from numpyro.primitives import module, plate, sample # noqa: F401 isort:skip
from numpyro.primitives import param as _param # noqa: F401 isort:skip
Expand All @@ -16,6 +17,7 @@ def get_param_store():
"A limited parameter store is provided for compatibility with Pyro. "
"Value of SVI parameters should be obtained via SVI.get_params() method.",
category=UnsupportedAPIWarning,
stacklevel=find_stack_level(),
)
# Return an empty dict for compatibility
return _PARAM_STORE
Expand Down
2 changes: 2 additions & 0 deletions numpyro/contrib/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import warnings

from numpyro.ops import Vindex, vindex # noqa: F401
from numpyro.util import find_stack_level

warnings.warn(
"`indexing` module has been moved from `numpyro.contrib` to `numpyro.ops`."
" Please import Vindex or vindex functions from `numpyro.ops.indexing`.",
FutureWarning,
stacklevel=find_stack_level(),
)
3 changes: 2 additions & 1 deletion numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpyro.distributions as numpyro_dist
from numpyro.distributions import Distribution as NumPyroDistribution, constraints
from numpyro.distributions.transforms import Transform, biject_to
from numpyro.util import not_jax_tracer
from numpyro.util import find_stack_level, not_jax_tracer


def _get_codomain(bijector):
Expand Down Expand Up @@ -129,6 +129,7 @@ def init(self, *args, **kwargs):
"deprecated. You should import distributions directly from "
"tensorflow_probability.substrates.jax.distributions instead.",
FutureWarning,
stacklevel=find_stack_level(),
)
self.tfp_dist = tfd_class(*args, **kwargs)

Expand Down
3 changes: 2 additions & 1 deletion numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
promote_shapes,
validate_sample,
)
from numpyro.util import not_jax_tracer
from numpyro.util import find_stack_level, not_jax_tracer


def _to_probs_bernoulli(logits):
Expand Down Expand Up @@ -512,6 +512,7 @@ def __init__(self):
"PRNGIdentity distribution is deprecated. To get a random "
"PRNG key, you can use `numpyro.prng_key()` instead.",
FutureWarning,
stacklevel=find_stack_level(),
)
super(PRNGIdentity, self).__init__(event_shape=(2,))

Expand Down
5 changes: 3 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
sum_rightmost,
validate_sample,
)
from numpyro.util import not_jax_tracer
from numpyro.util import find_stack_level, not_jax_tracer

from . import constraints

Expand Down Expand Up @@ -291,7 +291,8 @@ def _validate_sample(self, value):
if not np.all(mask):
warnings.warn(
"Out-of-support values provided to log prob method. "
"The value argument should be within the support."
"The value argument should be within the support.",
stacklevel=find_stack_level(),
)
return mask

Expand Down
7 changes: 6 additions & 1 deletion numpyro/distributions/gof.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def test_my_distribution():

import jax

from numpyro.util import find_stack_level

HISTOGRAM_WIDTH = 60


Expand Down Expand Up @@ -117,7 +119,10 @@ def multinomial_goodness_of_fit(probs, counts, *, total_count=None, plot=False):
chi_squared += (c - mean) ** 2 / variance
dof += 1
else:
warnings.warn("Zero probability in goodness-of-fit test")
warnings.warn(
"Zero probability in goodness-of-fit test",
stacklevel=find_stack_level(),
)
if c > 0:
return math.inf

Expand Down
7 changes: 5 additions & 2 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
sum_rightmost,
vec_to_tril_matrix,
)
from numpyro.util import not_jax_tracer
from numpyro.util import find_stack_level, not_jax_tracer

__all__ = [
"biject_to",
Expand Down Expand Up @@ -67,6 +67,7 @@ def event_dim(self):
"transform.event_dim is deprecated. Please use Transform.domain.event_dim to "
"get input event dim or Transform.codomain.event_dim to get output event dim.",
FutureWarning,
stacklevel=find_stack_level(),
)
return self.domain.event_dim

Expand Down Expand Up @@ -560,6 +561,7 @@ def __init__(self, domain=constraints.lower_cholesky):
"InvCholeskyTransform is deprecated. Please use CholeskyTransform"
" or CorrMatrixCholeskyTransform instead.",
FutureWarning,
stacklevel=find_stack_level(),
)
assert domain in [constraints.lower_cholesky, constraints.corr_cholesky]
self.domain = domain
Expand Down Expand Up @@ -1021,7 +1023,8 @@ def _inverse(self, y):
if not_scalar and all(d == d0 for d in leading_dims[1:]):
warnings.warn(
"UnpackTransform.inv might lead to an unexpected behavior because it"
" cannot transform a batch of unpacked arrays."
" cannot transform a batch of unpacked arrays.",
stacklevel=find_stack_level(),
)
return ravel_pytree(y)[0]

Expand Down
7 changes: 6 additions & 1 deletion numpyro/examples/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from jax import lax

from numpyro.util import find_stack_level

if "CI" in os.environ:
DATA_DIR = os.path.expanduser("~/.data")
else:
Expand Down Expand Up @@ -258,7 +260,10 @@ def _load_jsb_chorales():


def _load_higgs(num_datapoints):
warnings.warn("Higgs is a 2.6 GB dataset")
warnings.warn(
"Higgs is a 2.6 GB dataset",
stacklevel=find_stack_level(),
)
_download(HIGGS)

file_path = os.path.join(DATA_DIR, "HIGGS.csv.gz")
Expand Down
4 changes: 3 additions & 1 deletion numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
apply_stack,
plate,
)
from numpyro.util import not_jax_tracer
from numpyro.util import find_stack_level, not_jax_tracer

__all__ = [
"block",
Expand Down Expand Up @@ -207,6 +207,7 @@ def __init__(self, fn=None, trace=None, guide_trace=None):
warnings.warn(
"`guide_trace` argument is deprecated. Please replace it by `trace`.",
FutureWarning,
stacklevel=find_stack_level(),
)
if guide_trace is not None:
trace = guide_trace
Expand Down Expand Up @@ -845,6 +846,7 @@ def process_message(self, msg):
"Attempting to intervene on variable {} multiple times,"
"this is almost certainly incorrect behavior".format(msg["name"]),
RuntimeWarning,
stacklevel=find_stack_level(),
)
msg["_intervener_id"] = self._intervener_id

Expand Down
10 changes: 8 additions & 2 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import jax
from jax import grad, hessian, lax, random, tree_map

from numpyro.util import _versiontuple
from numpyro.util import _versiontuple, find_stack_level

if _versiontuple(jax.__version__) >= (0, 2, 25):
from jax.example_libraries import stax
Expand Down Expand Up @@ -890,6 +890,7 @@ def __init__(
"`init_strategy` argument has been deprecated in favor of `init_loc_fn`"
" argument.",
FutureWarning,
stacklevel=find_stack_level(),
)
if init_scale <= 0:
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
Expand Down Expand Up @@ -959,6 +960,7 @@ def __init__(
"`init_strategy` argument has been deprecated in favor of `init_loc_fn`"
" argument.",
FutureWarning,
stacklevel=find_stack_level(),
)
if init_scale <= 0:
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
Expand Down Expand Up @@ -1032,6 +1034,7 @@ def __init__(
"`init_strategy` argument has been deprecated in favor of `init_loc_fn`"
" argument.",
FutureWarning,
stacklevel=find_stack_level(),
)
if init_scale <= 0:
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
Expand Down Expand Up @@ -1161,7 +1164,8 @@ def loss_fn(z):
warnings.warn(
"Hessian of log posterior at the MAP point is singular. Posterior"
" samples from AutoLaplaceApproxmiation will be constant (equal to"
" the MAP point). Please consider using an AutoNormal guide."
" the MAP point). Please consider using an AutoNormal guide.",
stacklevel=find_stack_level(),
)
scale_tril = jnp.where(jnp.isnan(scale_tril), 0.0, scale_tril)
return LowerCholeskyAffine(loc, scale_tril)
Expand Down Expand Up @@ -1233,6 +1237,7 @@ def __init__(
"`init_strategy` argument has been deprecated in favor of `init_loc_fn`"
" argument.",
FutureWarning,
stacklevel=find_stack_level(),
)
self.num_flows = num_flows
# 2-layer, stax.Elu, skip_connections=False by default following the experiments in
Expand Down Expand Up @@ -1319,6 +1324,7 @@ def __init__(
"`init_strategy` argument has been deprecated in favor of `init_loc_fn`"
" argument.",
FutureWarning,
stacklevel=find_stack_level(),
)
self.num_flows = num_flows
self._hidden_factors = hidden_factors
Expand Down
7 changes: 4 additions & 3 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from numpyro.distributions.util import scale_and_mask
from numpyro.handlers import replay, seed, substitute, trace
from numpyro.infer.util import get_importance_trace, log_density
from numpyro.util import check_model_guide_match
from numpyro.util import check_model_guide_match, find_stack_level


class ELBO:
Expand Down Expand Up @@ -190,8 +190,9 @@ def _check_mean_field_requirement(model_trace, guide_trace):
+ "Model sites:\n "
+ "\n ".join(model_sites)
+ "Guide sites:\n "
+ "\n ".join(guide_sites)
)
+ "\n ".join(guide_sites),
stacklevel=find_stack_level(),
),


class TraceMeanField_ELBO(ELBO):
Expand Down
7 changes: 5 additions & 2 deletions numpyro/infer/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpyro.distributions as dist
from numpyro.distributions import biject_to
from numpyro.util import find_stack_level


def init_to_median(site=None, num_samples=15):
Expand All @@ -28,7 +29,8 @@ def init_to_median(site=None, num_samples=15):
if site["value"] is not None:
warnings.warn(
f"init_to_median() skipping initialization of site '{site['name']}'"
" which already stores a value."
" which already stores a value.",
stacklevel=find_stack_level(),
)
return site["value"]

Expand Down Expand Up @@ -68,7 +70,8 @@ def init_to_uniform(site=None, radius=2):
if site["value"] is not None:
warnings.warn(
f"init_to_uniform() skipping initialization of site '{site['name']}'"
" which already stores a value."
" which already stores a value.",
stacklevel=find_stack_level(),
)
return site["value"]

Expand Down
5 changes: 3 additions & 2 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from jax.tree_util import tree_flatten, tree_map, tree_multimap

from numpyro.diagnostics import print_summary
from numpyro.util import cached_by, fori_collect, identity
from numpyro.util import cached_by, find_stack_level, fori_collect, identity

__all__ = [
"MCMCKernel",
Expand Down Expand Up @@ -304,7 +304,8 @@ def __init__(
" of your program. You can double-check how many devices are available in"
" your system using `jax.local_device_count()`.".format(
self.num_chains, local_device_count(), self.num_chains
)
),
stacklevel=find_stack_level(),
)
self.chain_method = chain_method
self.progress_bar = progress_bar
Expand Down
5 changes: 3 additions & 2 deletions numpyro/infer/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import jax

from numpyro.util import _versiontuple
from numpyro.util import _versiontuple, find_stack_level

if _versiontuple(jax.__version__) >= (0, 2, 25):
from jax.example_libraries import optimizers
Expand Down Expand Up @@ -204,7 +204,8 @@ def init(self, rng_key, *args, **kwargs):
):
s_name = type(self.loss).__name__
warnings.warn(
f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
f"Currently, SVI with {s_name} loss does not support models with discrete latent variables",
stacklevel=find_stack_level(),
)

if not mutable_state:
Expand Down
6 changes: 4 additions & 2 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from numpyro.distributions.util import is_identically_one, sum_rightmost
from numpyro.handlers import condition, replay, seed, substitute, trace
from numpyro.infer.initialization import init_to_uniform, init_to_value
from numpyro.util import not_jax_tracer, soft_vmap, while_loop
from numpyro.util import find_stack_level, not_jax_tracer, soft_vmap, while_loop

__all__ = [
"find_valid_initial_params",
Expand Down Expand Up @@ -426,6 +426,7 @@ def _get_model_transforms(model, model_args=(), model_kwargs=None):
" enumerated sites need to be marked with"
" `infer={'enumerate': 'parallel'}`.",
FutureWarning,
stacklevel=find_stack_level(),
)
else:
support = v["fn"].support
Expand Down Expand Up @@ -886,6 +887,7 @@ def __init__(
batch_size, num_samples, batch_size
),
UserWarning,
stacklevel=find_stack_level(),
)
num_samples = batch_size

Expand Down Expand Up @@ -1036,7 +1038,7 @@ def helpful_support_errors(site, raise_warnings=False):
+ "a reparameterizer, e.g. "
+ f"numpyro.handlers.reparam(config={{'{name}': CircularReparam()}})."
)
warnings.warn(msg, UserWarning)
warnings.warn(msg, UserWarning, stacklevel=find_stack_level())

# Exceptions
try:
Expand Down
5 changes: 3 additions & 2 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jax.numpy as jnp

import numpyro
from numpyro.util import identity
from numpyro.util import find_stack_level, identity

_PYRO_STACK = []

Expand Down Expand Up @@ -463,7 +463,8 @@ def _subsample(name, size, subsample_size, dim):
"subsample_size does not match len(subsample), {} vs {}.".format(
subsample_size, len(subsample)
)
+ " Did you accidentally use different subsample_size in the model and guide?"
+ " Did you accidentally use different subsample_size in the model and guide?",
stacklevel=find_stack_level(),
)
cond_indep_stack = msg["cond_indep_stack"]
occupied_dims = {f.dim for f in cond_indep_stack}
Expand Down
Loading

0 comments on commit 0b316d4

Please sign in to comment.