diff --git a/numpyro/contrib/einstein/mixture_guide_predictive.py b/numpyro/contrib/einstein/mixture_guide_predictive.py index 3d9849224..c3b2fdc5b 100644 --- a/numpyro/contrib/einstein/mixture_guide_predictive.py +++ b/numpyro/contrib/einstein/mixture_guide_predictive.py @@ -5,8 +5,8 @@ from functools import partial from typing import Optional -from jax import numpy as jnp, random, tree_map, vmap -from jax.tree_util import tree_flatten +from jax import numpy as jnp, random, vmap +from jax.tree_util import tree_flatten, tree_map from numpyro.handlers import substitute from numpyro.infer import Predictive diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index 0292e02e1..83698b00e 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -5,8 +5,8 @@ from functools import singledispatch from typing import Union -from jax import tree_map import jax.numpy as jnp +from jax.tree_util import tree_map from numpyro.distributions import constraints from numpyro.distributions.conjugate import ( diff --git a/numpyro/util.py b/numpyro/util.py index ccc8c09ae..c4c8c53a3 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -45,7 +45,7 @@ def enable_x64(use_x64=True): """ if not use_x64: use_x64 = os.getenv("JAX_ENABLE_X64", 0) - jax.config.update("jax_enable_x64", use_x64) + jax.config.update("jax_enable_x64", bool(use_x64)) def set_platform(platform=None): diff --git a/test/test_constraints.py b/test/test_constraints.py index fb34bf6b8..67832c476 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -5,8 +5,9 @@ import pytest -from jax import jit, tree_map, vmap +from jax import jit, vmap import jax.numpy as jnp +from jax.tree_util import tree_map from numpyro.distributions import constraints diff --git a/test/test_transforms.py b/test/test_transforms.py index 15a6fa394..4329d8b63 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -8,8 +8,9 @@ import numpy as np import pytest -from jax import jacfwd, jit, random, tree_map, vmap +from jax import jacfwd, jit, random, vmap import jax.numpy as jnp +from jax.tree_util import tree_map from numpyro.distributions import constraints from numpyro.distributions.flows import (