From 3ef82e8d98c28a65a044a351357fe4cd1b07109f Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 16 Oct 2024 16:37:30 -0400 Subject: [PATCH 1/3] build: switch to pyproject.toml Signed-off-by: nstarman --- pyproject.toml | 88 ++++++++++++++++++++++++++++++++++++++++++ setup.py | 103 ------------------------------------------------- 2 files changed, 88 insertions(+), 103 deletions(-) delete mode 100644 setup.py diff --git a/pyproject.toml b/pyproject.toml index 8e4a9df63..e4bb01e59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,91 @@ +[project] +name = "numpyro" +version = "0.1.53" +description="Pyro PPL on NumPy" +readme = "README.md" +requires-python = ">=3.9" +authors = [ + { name = "Uber AI Labs", email = "npradhan@uber.com" }, +] +keywords = ["probabilistic", "machine learning", "bayesian", "statistics"] +license.file = "LICENSE.md" +classifiers=[ + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS :: MacOS X", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", +] +dependencies = [ + "jax>=0.4.25", + "jaxlib>=0.4.25", + "multipledispatch>=1.0.0", + "numpy>=1.22", + "tqdm>=4.60", +] + +[project.optional-dependencies] +doc = [ + "ipython", # sphinx needs this to render codes + "nbsphinx>=0.8.9", + "readthedocs-sphinx-search>=0.3.2", + "sphinx>=5", + "sphinx_rtd_theme", + "sphinx-gallery", +] +test = [ + "importlib-metadata<5.0", + "ruff>=0.1.8", + "pytest>=4.1", + "pyro-api>=0.1.1", + "scikit-learn", + "scipy>=1.9", +] +dev = [ + "dm-haiku", + "flax", + "funsor>=0.4.1", + "graphviz", + "jaxns==2.6.3", + "matplotlib", + "optax>=0.0.6", + "pylab-sdk", # jaxns dependency + "pyyaml", # flax dependency + "requests", # pylab dependency + "tensorflow_probability>=0.18.0", +] +examples = [ + "arviz", + "jupyter", + "matplotlib", + "pandas", + "seaborn", + "scikit-learn", + "wordcloud", +] +cpu = ["numpyro", "jax[cpu]"] +gpu = ["numpyro", "jax[tpu]"] +cuda = ["numpyro", "jax[cuda]"] + +[project.urls] +Documentation = "https://num.pyro.ai/en/stable/" +Repository = "https://github.com/pyro-ppl/numpyro" + + +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + + +# NOTE: this can be simplified using src-layout +[tool.setuptools.packages.find] +include = ["numpyro*"] + + [tool.ruff] # Exclude a variety of commonly ignored directories. exclude = [ diff --git a/setup.py b/setup.py deleted file mode 100644 index 25a13af84..000000000 --- a/setup.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import absolute_import, division, print_function - -import os -import sys - -from setuptools import find_packages, setup - -PROJECT_PATH = os.path.dirname(os.path.abspath(__file__)) -_jax_version_constraints = ">=0.4.25" -_jaxlib_version_constraints = ">=0.4.25" - -# Find version -for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")): - if line.startswith("__version__ = "): - version = line.strip().split()[2][1:-1] - -# READ README.md for long description on PyPi. -try: - long_description = open("README.md", encoding="utf-8").read() -except Exception as e: - sys.stderr.write("Failed to read README.md:\n {}\n".format(e)) - sys.stderr.flush() - long_description = "" - -setup( - name="numpyro", - version=version, - description="Pyro PPL on NumPy", - packages=find_packages(include=["numpyro", "numpyro.*"]), - url="https://github.com/pyro-ppl/numpyro", - author="Uber AI Labs", - install_requires=[ - f"jax{_jax_version_constraints}", - f"jaxlib{_jaxlib_version_constraints}", - "multipledispatch", - "numpy", - "tqdm", - ], - extras_require={ - "doc": [ - "ipython", # sphinx needs this to render codes - "nbsphinx>=0.8.9", - "readthedocs-sphinx-search>=0.3.2", - "sphinx>=5", - "sphinx_rtd_theme", - "sphinx-gallery", - ], - "test": [ - "importlib-metadata<5.0", - "ruff>=0.1.8", - "pytest>=4.1", - "pyro-api>=0.1.1", - "scikit-learn", - "scipy>=1.9", - ], - "dev": [ - "dm-haiku", - "flax", - "funsor>=0.4.1", - "graphviz", - "jaxns==2.6.3", - "matplotlib", - "optax>=0.0.6", - "pylab-sdk", # jaxns dependency - "pyyaml", # flax dependency - "requests", # pylab dependency - "tensorflow_probability>=0.18.0", - ], - "examples": [ - "arviz", - "jupyter", - "matplotlib", - "pandas", - "seaborn", - "scikit-learn", - "wordcloud", - ], - "cpu": f"jax[cpu]{_jax_version_constraints}", - # TPU and CUDA installations, currently require to add package repository URL, i.e., - # pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html - "tpu": f"jax[tpu]{_jax_version_constraints}", - "cuda": f"jax[cuda]{_jax_version_constraints}", - }, - python_requires=">=3.9", - long_description=long_description, - long_description_content_type="text/markdown", - keywords="probabilistic machine learning bayesian statistics", - license="Apache License 2.0", - classifiers=[ - "Intended Audience :: Developers", - "Intended Audience :: Education", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: Apache Software License", - "Operating System :: POSIX :: Linux", - "Operating System :: MacOS :: MacOS X", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - ], -) From f3fe8a9388a09310b7624ed835a578b1fbd803ff Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 16 Oct 2024 17:24:45 -0400 Subject: [PATCH 2/3] build: bump setuptools Signed-off-by: nstarman --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e4bb01e59..abd2b7fad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ Repository = "https://github.com/pyro-ppl/numpyro" [build-system] -requires = ["setuptools>=42", "wheel"] +requires = ["setuptools>=61", "wheel"] build-backend = "setuptools.build_meta" From e33c95d863f755975a75834324af63be075f3278 Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 16 Oct 2024 17:02:39 -0400 Subject: [PATCH 3/3] ci: apply fixes Signed-off-by: nstarman --- numpyro/contrib/control_flow/scan.py | 7 ++++--- numpyro/infer/autoguide.py | 15 ++++++++++----- numpyro/infer/elbo.py | 7 +++++-- numpyro/infer/inspect.py | 11 +++++++---- test/infer/test_hmc_util.py | 5 +++-- test/test_distributions.py | 8 ++++++-- test/test_handlers.py | 6 ++++-- 7 files changed, 39 insertions(+), 20 deletions(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 7a2689cf1..957cebe6b 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -197,9 +197,10 @@ def body_fn(wrapped_carry, x, prefix=None): ) return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y) - with handlers.block( - hide_fn=lambda site: not site["name"].startswith("_PREV_") - ), enum(first_available_dim=first_available_dim): + with ( + handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")), + enum(first_available_dim=first_available_dim), + ): wrapped_carry = (0, rng_key, init) y0s = [] # We run unroll_steps + 1 where the last step is used for rolling with `lax.scan` diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 72b2df3bf..6df31055a 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -1463,8 +1463,10 @@ def _sample_latent(self, *args, **kwargs): if self.global_guide is not None: global_latents = self.global_guide(*args, **kwargs) rng_key = numpyro.prng_key() - with handlers.block(), handlers.seed(rng_seed=rng_key), handlers.substitute( - data=global_latents + with ( + handlers.block(), + handlers.seed(rng_seed=rng_key), + handlers.substitute(data=global_latents), ): global_outputs = self.global_guide.model(*args, **kwargs) local_args = (global_outputs,) @@ -1575,9 +1577,12 @@ def fn(x): if self.local_guide is not None: key = numpyro.prng_key() subsample_guide = partial(_subsample_model, self.local_guide) - with handlers.block(), handlers.trace() as tr, handlers.seed( - rng_seed=key - ), handlers.substitute(data=local_guide_params): + with ( + handlers.block(), + handlers.trace() as tr, + handlers.seed(rng_seed=key), + handlers.substitute(data=local_guide_params), + ): with warnings.catch_warnings(): warnings.simplefilter("ignore") subsample_guide(*local_args, **local_kwargs) diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py index c0aaa3118..4fcd97b67 100644 --- a/numpyro/infer/elbo.py +++ b/numpyro/infer/elbo.py @@ -874,8 +874,11 @@ def get_importance_trace_enum( trace as _trace, ) - with plate_to_enum_plate(), enum( - first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None + with ( + plate_to_enum_plate(), + enum( + first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None + ), ): guide = substitute(guide, data=params) with _without_rsample_stop_gradient(): diff --git a/numpyro/infer/inspect.py b/numpyro/infer/inspect.py index 7b999ede5..1e4b52b75 100644 --- a/numpyro/infer/inspect.py +++ b/numpyro/infer/inspect.py @@ -58,8 +58,10 @@ def get_trace(): def _get_log_probs(model, model_args, model_kwargs, **sample): # Note: We use seed 0 for parameter initialization. - with handlers.trace() as tr, handlers.seed(rng_seed=0), handlers.substitute( - data=sample + with ( + handlers.trace() as tr, + handlers.seed(rng_seed=0), + handlers.substitute(data=sample), ): model(*model_args, **model_kwargs) return { @@ -370,8 +372,9 @@ def process_message(self, msg): # Note: We use seed 0 for parameter initialization. with handlers.trace() as tr, handlers.seed(rng_seed=0): - with handlers.substitute(data=sample), substitute_deterministic( - data=sample + with ( + handlers.substitute(data=sample), + substitute_deterministic(data=sample), ): model(*model_args, **model_kwargs) provenance_arrays = {} diff --git a/test/infer/test_hmc_util.py b/test/infer/test_hmc_util.py index f1a81e436..080c6a6f4 100644 --- a/test/infer/test_hmc_util.py +++ b/test/infer/test_hmc_util.py @@ -57,8 +57,9 @@ def optimize(f): @pytest.mark.parametrize("regularize", [True, False]) @pytest.mark.filterwarnings("ignore:numpy.linalg support is experimental:UserWarning") def test_welford_covariance(jitted, diagonal, regularize): - with optional(jitted, disable_jit()), optional( - jitted, control_flow_prims_disabled() + with ( + optional(jitted, disable_jit()), + optional(jitted, control_flow_prims_disabled()), ): np.random.seed(0) loc = np.random.randn(3) diff --git a/test/test_distributions.py b/test/test_distributions.py index 20da4165d..8ed9e78e5 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -2221,8 +2221,12 @@ def g(x): def test_beta_proportion_invalid_mean(): - with dist.distribution.validation_enabled(), pytest.raises( - ValueError, match=r"^BetaProportion distribution got invalid mean parameter\.$" + with ( + dist.distribution.validation_enabled(), + pytest.raises( + ValueError, + match=r"^BetaProportion distribution got invalid mean parameter\.$", + ), ): dist.BetaProportion(1.0, 1.0) diff --git a/test/test_handlers.py b/test/test_handlers.py index 4ef449237..15121eb46 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -372,8 +372,10 @@ def test_subsample_substitute(): data = jnp.arange(100.0) subsample_size = 7 subsample = jnp.array([13, 3, 30, 4, 1, 68, 5]) - with handlers.trace() as tr, handlers.seed(rng_seed=0), handlers.substitute( - data={"a": subsample} + with ( + handlers.trace() as tr, + handlers.seed(rng_seed=0), + handlers.substitute(data={"a": subsample}), ): with numpyro.plate("a", len(data), subsample_size=subsample_size) as idx: assert data[idx].shape == (subsample_size,)