From dc1bac91998646599df337a7cf868ec42eabe9cd Mon Sep 17 00:00:00 2001 From: Du Phan Date: Thu, 4 Apr 2024 17:40:02 -0400 Subject: [PATCH 1/2] remove deprecated import jax.tree_map --- numpyro/contrib/einstein/mixture_guide_predictive.py | 4 ++-- numpyro/distributions/batch_util.py | 2 +- test/test_constraints.py | 3 ++- test/test_transforms.py | 3 ++- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/numpyro/contrib/einstein/mixture_guide_predictive.py b/numpyro/contrib/einstein/mixture_guide_predictive.py index 3d9849224..c3b2fdc5b 100644 --- a/numpyro/contrib/einstein/mixture_guide_predictive.py +++ b/numpyro/contrib/einstein/mixture_guide_predictive.py @@ -5,8 +5,8 @@ from functools import partial from typing import Optional -from jax import numpy as jnp, random, tree_map, vmap -from jax.tree_util import tree_flatten +from jax import numpy as jnp, random, vmap +from jax.tree_util import tree_flatten, tree_map from numpyro.handlers import substitute from numpyro.infer import Predictive diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index 0292e02e1..83698b00e 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -5,8 +5,8 @@ from functools import singledispatch from typing import Union -from jax import tree_map import jax.numpy as jnp +from jax.tree_util import tree_map from numpyro.distributions import constraints from numpyro.distributions.conjugate import ( diff --git a/test/test_constraints.py b/test/test_constraints.py index 735969fa6..d73325d84 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -5,8 +5,9 @@ import pytest -from jax import jit, tree_map, vmap +from jax import jit, vmap import jax.numpy as jnp +from jax.tree_util import tree_map from numpyro.distributions import constraints diff --git a/test/test_transforms.py b/test/test_transforms.py index 1a706bbc6..3faf07815 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -7,8 +7,9 @@ import pytest -from jax import jacfwd, jit, random, tree_map, vmap +from jax import jacfwd, jit, random, vmap import jax.numpy as jnp +from jax.tree_util import tree_map from numpyro.distributions.flows import ( BlockNeuralAutoregressiveTransform, From 96aaf4b76502e75676426acff4d39bd086569de3 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 6 Apr 2024 08:52:14 -0400 Subject: [PATCH 2/2] fix enable_x64 logic --- numpyro/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/util.py b/numpyro/util.py index ccc8c09ae..c4c8c53a3 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -45,7 +45,7 @@ def enable_x64(use_x64=True): """ if not use_x64: use_x64 = os.getenv("JAX_ENABLE_X64", 0) - jax.config.update("jax_enable_x64", use_x64) + jax.config.update("jax_enable_x64", bool(use_x64)) def set_platform(platform=None):