Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add stacklevel to warnings #1271

Merged
merged 1 commit into from
Jan 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions numpyro/compat/pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings

from numpyro.compat.util import UnsupportedAPIWarning
from numpyro.util import find_stack_level

from numpyro.primitives import module, plate, sample # noqa: F401 isort:skip
from numpyro.primitives import param as _param # noqa: F401 isort:skip
Expand All @@ -16,6 +17,7 @@ def get_param_store():
"A limited parameter store is provided for compatibility with Pyro. "
"Value of SVI parameters should be obtained via SVI.get_params() method.",
category=UnsupportedAPIWarning,
stacklevel=find_stack_level(),
)
# Return an empty dict for compatibility
return _PARAM_STORE
Expand Down
2 changes: 2 additions & 0 deletions numpyro/contrib/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import warnings

from numpyro.ops import Vindex, vindex # noqa: F401
from numpyro.util import find_stack_level

warnings.warn(
"`indexing` module has been moved from `numpyro.contrib` to `numpyro.ops`."
" Please import Vindex or vindex functions from `numpyro.ops.indexing`.",
FutureWarning,
stacklevel=find_stack_level(),
)
3 changes: 2 additions & 1 deletion numpyro/contrib/tfp/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpyro.distributions as numpyro_dist
from numpyro.distributions import Distribution as NumPyroDistribution, constraints
from numpyro.distributions.transforms import Transform, biject_to
from numpyro.util import not_jax_tracer
from numpyro.util import find_stack_level, not_jax_tracer


def _get_codomain(bijector):
Expand Down Expand Up @@ -129,6 +129,7 @@ def init(self, *args, **kwargs):
"deprecated. You should import distributions directly from "
"tensorflow_probability.substrates.jax.distributions instead.",
FutureWarning,
stacklevel=find_stack_level(),
)
self.tfp_dist = tfd_class(*args, **kwargs)

Expand Down
3 changes: 2 additions & 1 deletion numpyro/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
promote_shapes,
validate_sample,
)
from numpyro.util import not_jax_tracer
from numpyro.util import find_stack_level, not_jax_tracer


def _to_probs_bernoulli(logits):
Expand Down Expand Up @@ -512,6 +512,7 @@ def __init__(self):
"PRNGIdentity distribution is deprecated. To get a random "
"PRNG key, you can use `numpyro.prng_key()` instead.",
FutureWarning,
stacklevel=find_stack_level(),
)
super(PRNGIdentity, self).__init__(event_shape=(2,))

Expand Down
5 changes: 3 additions & 2 deletions numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
sum_rightmost,
validate_sample,
)
from numpyro.util import not_jax_tracer
from numpyro.util import find_stack_level, not_jax_tracer

from . import constraints

Expand Down Expand Up @@ -291,7 +291,8 @@ def _validate_sample(self, value):
if not np.all(mask):
warnings.warn(
"Out-of-support values provided to log prob method. "
"The value argument should be within the support."
"The value argument should be within the support.",
stacklevel=find_stack_level(),
)
return mask

Expand Down
7 changes: 6 additions & 1 deletion numpyro/distributions/gof.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def test_my_distribution():

import jax

from numpyro.util import find_stack_level

HISTOGRAM_WIDTH = 60


Expand Down Expand Up @@ -117,7 +119,10 @@ def multinomial_goodness_of_fit(probs, counts, *, total_count=None, plot=False):
chi_squared += (c - mean) ** 2 / variance
dof += 1
else:
warnings.warn("Zero probability in goodness-of-fit test")
warnings.warn(
"Zero probability in goodness-of-fit test",
stacklevel=find_stack_level(),
)
if c > 0:
return math.inf

Expand Down
7 changes: 5 additions & 2 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
sum_rightmost,
vec_to_tril_matrix,
)
from numpyro.util import not_jax_tracer
from numpyro.util import find_stack_level, not_jax_tracer

__all__ = [
"biject_to",
Expand Down Expand Up @@ -67,6 +67,7 @@ def event_dim(self):
"transform.event_dim is deprecated. Please use Transform.domain.event_dim to "
"get input event dim or Transform.codomain.event_dim to get output event dim.",
FutureWarning,
stacklevel=find_stack_level(),
)
return self.domain.event_dim

Expand Down Expand Up @@ -560,6 +561,7 @@ def __init__(self, domain=constraints.lower_cholesky):
"InvCholeskyTransform is deprecated. Please use CholeskyTransform"
" or CorrMatrixCholeskyTransform instead.",
FutureWarning,
stacklevel=find_stack_level(),
)
assert domain in [constraints.lower_cholesky, constraints.corr_cholesky]
self.domain = domain
Expand Down Expand Up @@ -1021,7 +1023,8 @@ def _inverse(self, y):
if not_scalar and all(d == d0 for d in leading_dims[1:]):
warnings.warn(
"UnpackTransform.inv might lead to an unexpected behavior because it"
" cannot transform a batch of unpacked arrays."
" cannot transform a batch of unpacked arrays.",
stacklevel=find_stack_level(),
)
return ravel_pytree(y)[0]

Expand Down
7 changes: 6 additions & 1 deletion numpyro/examples/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

from jax import lax

from numpyro.util import find_stack_level

if "CI" in os.environ:
DATA_DIR = os.path.expanduser("~/.data")
else:
Expand Down Expand Up @@ -258,7 +260,10 @@ def _load_jsb_chorales():


def _load_higgs(num_datapoints):
warnings.warn("Higgs is a 2.6 GB dataset")
warnings.warn(
"Higgs is a 2.6 GB dataset",
stacklevel=find_stack_level(),
)
_download(HIGGS)

file_path = os.path.join(DATA_DIR, "HIGGS.csv.gz")
Expand Down
4 changes: 3 additions & 1 deletion numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@
apply_stack,
plate,
)
from numpyro.util import not_jax_tracer
from numpyro.util import find_stack_level, not_jax_tracer

__all__ = [
"block",
Expand Down Expand Up @@ -207,6 +207,7 @@ def __init__(self, fn=None, trace=None, guide_trace=None):
warnings.warn(
"`guide_trace` argument is deprecated. Please replace it by `trace`.",
FutureWarning,
stacklevel=find_stack_level(),
)
if guide_trace is not None:
trace = guide_trace
Expand Down Expand Up @@ -845,6 +846,7 @@ def process_message(self, msg):
"Attempting to intervene on variable {} multiple times,"
"this is almost certainly incorrect behavior".format(msg["name"]),
RuntimeWarning,
stacklevel=find_stack_level(),
)
msg["_intervener_id"] = self._intervener_id

Expand Down
10 changes: 8 additions & 2 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import jax
from jax import grad, hessian, lax, random, tree_map

from numpyro.util import _versiontuple
from numpyro.util import _versiontuple, find_stack_level

if _versiontuple(jax.__version__) >= (0, 2, 25):
from jax.example_libraries import stax
Expand Down Expand Up @@ -890,6 +890,7 @@ def __init__(
"`init_strategy` argument has been deprecated in favor of `init_loc_fn`"
" argument.",
FutureWarning,
stacklevel=find_stack_level(),
)
if init_scale <= 0:
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
Expand Down Expand Up @@ -959,6 +960,7 @@ def __init__(
"`init_strategy` argument has been deprecated in favor of `init_loc_fn`"
" argument.",
FutureWarning,
stacklevel=find_stack_level(),
)
if init_scale <= 0:
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
Expand Down Expand Up @@ -1032,6 +1034,7 @@ def __init__(
"`init_strategy` argument has been deprecated in favor of `init_loc_fn`"
" argument.",
FutureWarning,
stacklevel=find_stack_level(),
)
if init_scale <= 0:
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
Expand Down Expand Up @@ -1161,7 +1164,8 @@ def loss_fn(z):
warnings.warn(
"Hessian of log posterior at the MAP point is singular. Posterior"
" samples from AutoLaplaceApproxmiation will be constant (equal to"
" the MAP point). Please consider using an AutoNormal guide."
" the MAP point). Please consider using an AutoNormal guide.",
stacklevel=find_stack_level(),
)
scale_tril = jnp.where(jnp.isnan(scale_tril), 0.0, scale_tril)
return LowerCholeskyAffine(loc, scale_tril)
Expand Down Expand Up @@ -1233,6 +1237,7 @@ def __init__(
"`init_strategy` argument has been deprecated in favor of `init_loc_fn`"
" argument.",
FutureWarning,
stacklevel=find_stack_level(),
)
self.num_flows = num_flows
# 2-layer, stax.Elu, skip_connections=False by default following the experiments in
Expand Down Expand Up @@ -1319,6 +1324,7 @@ def __init__(
"`init_strategy` argument has been deprecated in favor of `init_loc_fn`"
" argument.",
FutureWarning,
stacklevel=find_stack_level(),
)
self.num_flows = num_flows
self._hidden_factors = hidden_factors
Expand Down
7 changes: 4 additions & 3 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from numpyro.distributions.util import scale_and_mask
from numpyro.handlers import replay, seed, substitute, trace
from numpyro.infer.util import get_importance_trace, log_density
from numpyro.util import check_model_guide_match
from numpyro.util import check_model_guide_match, find_stack_level


class ELBO:
Expand Down Expand Up @@ -190,8 +190,9 @@ def _check_mean_field_requirement(model_trace, guide_trace):
+ "Model sites:\n "
+ "\n ".join(model_sites)
+ "Guide sites:\n "
+ "\n ".join(guide_sites)
)
+ "\n ".join(guide_sites),
stacklevel=find_stack_level(),
),


class TraceMeanField_ELBO(ELBO):
Expand Down
7 changes: 5 additions & 2 deletions numpyro/infer/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpyro.distributions as dist
from numpyro.distributions import biject_to
from numpyro.util import find_stack_level


def init_to_median(site=None, num_samples=15):
Expand All @@ -28,7 +29,8 @@ def init_to_median(site=None, num_samples=15):
if site["value"] is not None:
warnings.warn(
f"init_to_median() skipping initialization of site '{site['name']}'"
" which already stores a value."
" which already stores a value.",
stacklevel=find_stack_level(),
)
return site["value"]

Expand Down Expand Up @@ -68,7 +70,8 @@ def init_to_uniform(site=None, radius=2):
if site["value"] is not None:
warnings.warn(
f"init_to_uniform() skipping initialization of site '{site['name']}'"
" which already stores a value."
" which already stores a value.",
stacklevel=find_stack_level(),
)
return site["value"]

Expand Down
5 changes: 3 additions & 2 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from jax.tree_util import tree_flatten, tree_map, tree_multimap

from numpyro.diagnostics import print_summary
from numpyro.util import cached_by, fori_collect, identity
from numpyro.util import cached_by, find_stack_level, fori_collect, identity

__all__ = [
"MCMCKernel",
Expand Down Expand Up @@ -304,7 +304,8 @@ def __init__(
" of your program. You can double-check how many devices are available in"
" your system using `jax.local_device_count()`.".format(
self.num_chains, local_device_count(), self.num_chains
)
),
stacklevel=find_stack_level(),
)
self.chain_method = chain_method
self.progress_bar = progress_bar
Expand Down
5 changes: 3 additions & 2 deletions numpyro/infer/svi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import jax

from numpyro.util import _versiontuple
from numpyro.util import _versiontuple, find_stack_level

if _versiontuple(jax.__version__) >= (0, 2, 25):
from jax.example_libraries import optimizers
Expand Down Expand Up @@ -204,7 +204,8 @@ def init(self, rng_key, *args, **kwargs):
):
s_name = type(self.loss).__name__
warnings.warn(
f"Currently, SVI with {s_name} loss does not support models with discrete latent variables"
f"Currently, SVI with {s_name} loss does not support models with discrete latent variables",
stacklevel=find_stack_level(),
)

if not mutable_state:
Expand Down
6 changes: 4 additions & 2 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from numpyro.distributions.util import is_identically_one, sum_rightmost
from numpyro.handlers import condition, replay, seed, substitute, trace
from numpyro.infer.initialization import init_to_uniform, init_to_value
from numpyro.util import not_jax_tracer, soft_vmap, while_loop
from numpyro.util import find_stack_level, not_jax_tracer, soft_vmap, while_loop

__all__ = [
"find_valid_initial_params",
Expand Down Expand Up @@ -426,6 +426,7 @@ def _get_model_transforms(model, model_args=(), model_kwargs=None):
" enumerated sites need to be marked with"
" `infer={'enumerate': 'parallel'}`.",
FutureWarning,
stacklevel=find_stack_level(),
)
else:
support = v["fn"].support
Expand Down Expand Up @@ -886,6 +887,7 @@ def __init__(
batch_size, num_samples, batch_size
),
UserWarning,
stacklevel=find_stack_level(),
)
num_samples = batch_size

Expand Down Expand Up @@ -1036,7 +1038,7 @@ def helpful_support_errors(site, raise_warnings=False):
+ "a reparameterizer, e.g. "
+ f"numpyro.handlers.reparam(config={{'{name}': CircularReparam()}})."
)
warnings.warn(msg, UserWarning)
warnings.warn(msg, UserWarning, stacklevel=find_stack_level())

# Exceptions
try:
Expand Down
5 changes: 3 additions & 2 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jax.numpy as jnp

import numpyro
from numpyro.util import identity
from numpyro.util import find_stack_level, identity

_PYRO_STACK = []

Expand Down Expand Up @@ -463,7 +463,8 @@ def _subsample(name, size, subsample_size, dim):
"subsample_size does not match len(subsample), {} vs {}.".format(
subsample_size, len(subsample)
)
+ " Did you accidentally use different subsample_size in the model and guide?"
+ " Did you accidentally use different subsample_size in the model and guide?",
stacklevel=find_stack_level(),
)
cond_indep_stack = msg["cond_indep_stack"]
occupied_dims = {f.dim for f in cond_indep_stack}
Expand Down
Loading