diff --git a/numpyro/compat/pyro.py b/numpyro/compat/pyro.py index 07d94df97..18a443748 100644 --- a/numpyro/compat/pyro.py +++ b/numpyro/compat/pyro.py @@ -4,6 +4,7 @@ import warnings from numpyro.compat.util import UnsupportedAPIWarning +from numpyro.util import find_stack_level from numpyro.primitives import module, plate, sample # noqa: F401 isort:skip from numpyro.primitives import param as _param # noqa: F401 isort:skip @@ -16,6 +17,7 @@ def get_param_store(): "A limited parameter store is provided for compatibility with Pyro. " "Value of SVI parameters should be obtained via SVI.get_params() method.", category=UnsupportedAPIWarning, + stacklevel=find_stack_level(), ) # Return an empty dict for compatibility return _PARAM_STORE diff --git a/numpyro/contrib/indexing.py b/numpyro/contrib/indexing.py index 91a4505f0..e36760c30 100644 --- a/numpyro/contrib/indexing.py +++ b/numpyro/contrib/indexing.py @@ -4,9 +4,11 @@ import warnings from numpyro.ops import Vindex, vindex # noqa: F401 +from numpyro.util import find_stack_level warnings.warn( "`indexing` module has been moved from `numpyro.contrib` to `numpyro.ops`." " Please import Vindex or vindex functions from `numpyro.ops.indexing`.", FutureWarning, + stacklevel=find_stack_level(), ) diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index bb7504578..a21543e19 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -13,7 +13,7 @@ import numpyro.distributions as numpyro_dist from numpyro.distributions import Distribution as NumPyroDistribution, constraints from numpyro.distributions.transforms import Transform, biject_to -from numpyro.util import not_jax_tracer +from numpyro.util import find_stack_level, not_jax_tracer def _get_codomain(bijector): @@ -129,6 +129,7 @@ def init(self, *args, **kwargs): "deprecated. You should import distributions directly from " "tensorflow_probability.substrates.jax.distributions instead.", FutureWarning, + stacklevel=find_stack_level(), ) self.tfp_dist = tfd_class(*args, **kwargs) diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index dd70ba590..6e523a53a 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -49,7 +49,7 @@ promote_shapes, validate_sample, ) -from numpyro.util import not_jax_tracer +from numpyro.util import find_stack_level, not_jax_tracer def _to_probs_bernoulli(logits): @@ -512,6 +512,7 @@ def __init__(self): "PRNGIdentity distribution is deprecated. To get a random " "PRNG key, you can use `numpyro.prng_key()` instead.", FutureWarning, + stacklevel=find_stack_level(), ) super(PRNGIdentity, self).__init__(event_shape=(2,)) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 97fe10d60..2139981e8 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -44,7 +44,7 @@ sum_rightmost, validate_sample, ) -from numpyro.util import not_jax_tracer +from numpyro.util import find_stack_level, not_jax_tracer from . import constraints @@ -291,7 +291,8 @@ def _validate_sample(self, value): if not np.all(mask): warnings.warn( "Out-of-support values provided to log prob method. " - "The value argument should be within the support." + "The value argument should be within the support.", + stacklevel=find_stack_level(), ) return mask diff --git a/numpyro/distributions/gof.py b/numpyro/distributions/gof.py index d101c56ac..fd186c690 100644 --- a/numpyro/distributions/gof.py +++ b/numpyro/distributions/gof.py @@ -63,6 +63,8 @@ def test_my_distribution(): import jax +from numpyro.util import find_stack_level + HISTOGRAM_WIDTH = 60 @@ -117,7 +119,10 @@ def multinomial_goodness_of_fit(probs, counts, *, total_count=None, plot=False): chi_squared += (c - mean) ** 2 / variance dof += 1 else: - warnings.warn("Zero probability in goodness-of-fit test") + warnings.warn( + "Zero probability in goodness-of-fit test", + stacklevel=find_stack_level(), + ) if c > 0: return math.inf diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index e468b9b8c..ef433fceb 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -22,7 +22,7 @@ sum_rightmost, vec_to_tril_matrix, ) -from numpyro.util import not_jax_tracer +from numpyro.util import find_stack_level, not_jax_tracer __all__ = [ "biject_to", @@ -67,6 +67,7 @@ def event_dim(self): "transform.event_dim is deprecated. Please use Transform.domain.event_dim to " "get input event dim or Transform.codomain.event_dim to get output event dim.", FutureWarning, + stacklevel=find_stack_level(), ) return self.domain.event_dim @@ -560,6 +561,7 @@ def __init__(self, domain=constraints.lower_cholesky): "InvCholeskyTransform is deprecated. Please use CholeskyTransform" " or CorrMatrixCholeskyTransform instead.", FutureWarning, + stacklevel=find_stack_level(), ) assert domain in [constraints.lower_cholesky, constraints.corr_cholesky] self.domain = domain @@ -1021,7 +1023,8 @@ def _inverse(self, y): if not_scalar and all(d == d0 for d in leading_dims[1:]): warnings.warn( "UnpackTransform.inv might lead to an unexpected behavior because it" - " cannot transform a batch of unpacked arrays." + " cannot transform a batch of unpacked arrays.", + stacklevel=find_stack_level(), ) return ravel_pytree(y)[0] diff --git a/numpyro/examples/datasets.py b/numpyro/examples/datasets.py index d8f2e6784..126f37c38 100644 --- a/numpyro/examples/datasets.py +++ b/numpyro/examples/datasets.py @@ -17,6 +17,8 @@ from jax import lax +from numpyro.util import find_stack_level + if "CI" in os.environ: DATA_DIR = os.path.expanduser("~/.data") else: @@ -258,7 +260,10 @@ def _load_jsb_chorales(): def _load_higgs(num_datapoints): - warnings.warn("Higgs is a 2.6 GB dataset") + warnings.warn( + "Higgs is a 2.6 GB dataset", + stacklevel=find_stack_level(), + ) _download(HIGGS) file_path = os.path.join(DATA_DIR, "HIGGS.csv.gz") diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 337eadcab..bea52a623 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -93,7 +93,7 @@ apply_stack, plate, ) -from numpyro.util import not_jax_tracer +from numpyro.util import find_stack_level, not_jax_tracer __all__ = [ "block", @@ -207,6 +207,7 @@ def __init__(self, fn=None, trace=None, guide_trace=None): warnings.warn( "`guide_trace` argument is deprecated. Please replace it by `trace`.", FutureWarning, + stacklevel=find_stack_level(), ) if guide_trace is not None: trace = guide_trace @@ -845,6 +846,7 @@ def process_message(self, msg): "Attempting to intervene on variable {} multiple times," "this is almost certainly incorrect behavior".format(msg["name"]), RuntimeWarning, + stacklevel=find_stack_level(), ) msg["_intervener_id"] = self._intervener_id diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 57acf824a..7df4f35ec 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -12,7 +12,7 @@ import jax from jax import grad, hessian, lax, random, tree_map -from numpyro.util import _versiontuple +from numpyro.util import _versiontuple, find_stack_level if _versiontuple(jax.__version__) >= (0, 2, 25): from jax.example_libraries import stax @@ -890,6 +890,7 @@ def __init__( "`init_strategy` argument has been deprecated in favor of `init_loc_fn`" " argument.", FutureWarning, + stacklevel=find_stack_level(), ) if init_scale <= 0: raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) @@ -959,6 +960,7 @@ def __init__( "`init_strategy` argument has been deprecated in favor of `init_loc_fn`" " argument.", FutureWarning, + stacklevel=find_stack_level(), ) if init_scale <= 0: raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) @@ -1032,6 +1034,7 @@ def __init__( "`init_strategy` argument has been deprecated in favor of `init_loc_fn`" " argument.", FutureWarning, + stacklevel=find_stack_level(), ) if init_scale <= 0: raise ValueError("Expected init_scale > 0. but got {}".format(init_scale)) @@ -1161,7 +1164,8 @@ def loss_fn(z): warnings.warn( "Hessian of log posterior at the MAP point is singular. Posterior" " samples from AutoLaplaceApproxmiation will be constant (equal to" - " the MAP point). Please consider using an AutoNormal guide." + " the MAP point). Please consider using an AutoNormal guide.", + stacklevel=find_stack_level(), ) scale_tril = jnp.where(jnp.isnan(scale_tril), 0.0, scale_tril) return LowerCholeskyAffine(loc, scale_tril) @@ -1233,6 +1237,7 @@ def __init__( "`init_strategy` argument has been deprecated in favor of `init_loc_fn`" " argument.", FutureWarning, + stacklevel=find_stack_level(), ) self.num_flows = num_flows # 2-layer, stax.Elu, skip_connections=False by default following the experiments in @@ -1319,6 +1324,7 @@ def __init__( "`init_strategy` argument has been deprecated in favor of `init_loc_fn`" " argument.", FutureWarning, + stacklevel=find_stack_level(), ) self.num_flows = num_flows self._hidden_factors = hidden_factors diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index 6e1e1d0e4..6580ccea0 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -13,7 +13,7 @@ from numpyro.distributions.util import scale_and_mask from numpyro.handlers import replay, seed, substitute, trace from numpyro.infer.util import get_importance_trace, log_density -from numpyro.util import check_model_guide_match +from numpyro.util import check_model_guide_match, find_stack_level class ELBO: @@ -190,8 +190,9 @@ def _check_mean_field_requirement(model_trace, guide_trace): + "Model sites:\n " + "\n ".join(model_sites) + "Guide sites:\n " - + "\n ".join(guide_sites) - ) + + "\n ".join(guide_sites), + stacklevel=find_stack_level(), + ), class TraceMeanField_ELBO(ELBO): diff --git a/numpyro/infer/initialization.py b/numpyro/infer/initialization.py index b6e00e9ca..78be69363 100644 --- a/numpyro/infer/initialization.py +++ b/numpyro/infer/initialization.py @@ -8,6 +8,7 @@ import numpyro.distributions as dist from numpyro.distributions import biject_to +from numpyro.util import find_stack_level def init_to_median(site=None, num_samples=15): @@ -28,7 +29,8 @@ def init_to_median(site=None, num_samples=15): if site["value"] is not None: warnings.warn( f"init_to_median() skipping initialization of site '{site['name']}'" - " which already stores a value." + " which already stores a value.", + stacklevel=find_stack_level(), ) return site["value"] @@ -68,7 +70,8 @@ def init_to_uniform(site=None, radius=2): if site["value"] is not None: warnings.warn( f"init_to_uniform() skipping initialization of site '{site['name']}'" - " which already stores a value." + " which already stores a value.", + stacklevel=find_stack_level(), ) return site["value"] diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 7ceb43ab1..649b33951 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -16,7 +16,7 @@ from jax.tree_util import tree_flatten, tree_map, tree_multimap from numpyro.diagnostics import print_summary -from numpyro.util import cached_by, fori_collect, identity +from numpyro.util import cached_by, find_stack_level, fori_collect, identity __all__ = [ "MCMCKernel", @@ -304,7 +304,8 @@ def __init__( " of your program. You can double-check how many devices are available in" " your system using `jax.local_device_count()`.".format( self.num_chains, local_device_count(), self.num_chains - ) + ), + stacklevel=find_stack_level(), ) self.chain_method = chain_method self.progress_bar = progress_bar diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index f73698e4a..2238d549b 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -8,7 +8,7 @@ import jax -from numpyro.util import _versiontuple +from numpyro.util import _versiontuple, find_stack_level if _versiontuple(jax.__version__) >= (0, 2, 25): from jax.example_libraries import optimizers @@ -204,7 +204,8 @@ def init(self, rng_key, *args, **kwargs): ): s_name = type(self.loss).__name__ warnings.warn( - f"Currently, SVI with {s_name} loss does not support models with discrete latent variables" + f"Currently, SVI with {s_name} loss does not support models with discrete latent variables", + stacklevel=find_stack_level(), ) if not mutable_state: diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 27f238885..615ac8df9 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -22,7 +22,7 @@ from numpyro.distributions.util import is_identically_one, sum_rightmost from numpyro.handlers import condition, replay, seed, substitute, trace from numpyro.infer.initialization import init_to_uniform, init_to_value -from numpyro.util import not_jax_tracer, soft_vmap, while_loop +from numpyro.util import find_stack_level, not_jax_tracer, soft_vmap, while_loop __all__ = [ "find_valid_initial_params", @@ -426,6 +426,7 @@ def _get_model_transforms(model, model_args=(), model_kwargs=None): " enumerated sites need to be marked with" " `infer={'enumerate': 'parallel'}`.", FutureWarning, + stacklevel=find_stack_level(), ) else: support = v["fn"].support @@ -886,6 +887,7 @@ def __init__( batch_size, num_samples, batch_size ), UserWarning, + stacklevel=find_stack_level(), ) num_samples = batch_size @@ -1036,7 +1038,7 @@ def helpful_support_errors(site, raise_warnings=False): + "a reparameterizer, e.g. " + f"numpyro.handlers.reparam(config={{'{name}': CircularReparam()}})." ) - warnings.warn(msg, UserWarning) + warnings.warn(msg, UserWarning, stacklevel=find_stack_level()) # Exceptions try: diff --git a/numpyro/primitives.py b/numpyro/primitives.py index 670564a9d..3caad05d8 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -11,7 +11,7 @@ import jax.numpy as jnp import numpyro -from numpyro.util import identity +from numpyro.util import find_stack_level, identity _PYRO_STACK = [] @@ -463,7 +463,8 @@ def _subsample(name, size, subsample_size, dim): "subsample_size does not match len(subsample), {} vs {}.".format( subsample_size, len(subsample) ) - + " Did you accidentally use different subsample_size in the model and guide?" + + " Did you accidentally use different subsample_size in the model and guide?", + stacklevel=find_stack_level(), ) cond_indep_stack = msg["cond_indep_stack"] occupied_dims = {f.dim for f in cond_indep_stack} diff --git a/numpyro/util.py b/numpyro/util.py index 472fb3289..3cccef8c1 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -311,7 +311,8 @@ def fori_collect( # See: https://github.com/google/jax/issues/6447 if num_chains > 1 and jax.default_backend() == "gpu": warnings.warn( - "We will disable progress bar because it does not work yet on multi-GPUs platforms." + "We will disable progress bar because it does not work yet on multi-GPUs platforms.", + stacklevel=find_stack_level(), ) progbar = False @@ -547,20 +548,23 @@ def check_model_guide_match(model_trace, guide_trace): if aux_vars & model_vars: warnings.warn( - "Found auxiliary vars in the model: {}".format(aux_vars & model_vars) + "Found auxiliary vars in the model: {}".format(aux_vars & model_vars), + stacklevel=find_stack_level(), ) if not (guide_vars <= model_vars | aux_vars): warnings.warn( "Found non-auxiliary vars in guide but not model, " "consider marking these infer={{'is_auxiliary': True}}:\n{}".format( guide_vars - aux_vars - model_vars - ) + ), + stacklevel=find_stack_level(), ) if not (model_vars <= guide_vars | enum_vars): warnings.warn( "Found vars in model but not guide: {}".format( model_vars - guide_vars - enum_vars - ) + ), + stacklevel=find_stack_level(), ) # Check shapes agree. @@ -605,7 +609,8 @@ def check_model_guide_match(model_trace, guide_trace): warnings.warn( "Found plate statements in guide but not model: {}".format( guide_vars - model_vars - ) + ), + stacklevel=find_stack_level(), ) # Check if plate is missing in the model. @@ -627,7 +632,8 @@ def check_model_guide_match(model_trace, guide_trace): warnings.warn( f"Missing a plate statement for batch dimension {dim}" f" at site '{name}'. You can use `numpyro.util.format_shapes`" - " utility to check shapes at all sites of your model." + " utility to check shapes at all sites of your model.", + stacklevel=find_stack_level(), ) @@ -689,3 +695,29 @@ def _versiontuple(version): Source: https://stackoverflow.com/a/11887825/4451315 """ return tuple([int(number) for number in version.split(".")]) + + +def find_stack_level() -> int: + """ + Find the first place in the stack that is not inside numpyro + (tests notwithstanding). + + Source: + https://github.com/pandas-dev/pandas/blob/9a4fcea8de798938a434fcaf67a0aa5a46b76b5b/pandas/util/_exceptions.py#L27-L45 + """ + import inspect + + stack = inspect.stack() + + import numpyro + + pkg_dir = os.path.dirname(numpyro.__file__) + test_dir = os.path.join(pkg_dir, "tests") + + for n in range(len(stack)): + fname = stack[n].filename + if fname.startswith(pkg_dir) and not fname.startswith(test_dir): + continue + else: + break + return n