Skip to content

Commit

Permalink
Remove the deprecated import jax.tree_map (pyro-ppl#1775)
Browse files Browse the repository at this point in the history
* remove deprecated import jax.tree_map

* fix enable_x64 logic
  • Loading branch information
fehiepsi authored and OlaRonning committed May 6, 2024
1 parent 82b130c commit a822bce
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 6 deletions.
4 changes: 2 additions & 2 deletions numpyro/contrib/einstein/mixture_guide_predictive.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from functools import partial
from typing import Optional

from jax import numpy as jnp, random, tree_map, vmap
from jax.tree_util import tree_flatten
from jax import numpy as jnp, random, vmap
from jax.tree_util import tree_flatten, tree_map

from numpyro.handlers import substitute
from numpyro.infer import Predictive
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/batch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from functools import singledispatch
from typing import Union

from jax import tree_map
import jax.numpy as jnp
from jax.tree_util import tree_map

from numpyro.distributions import constraints
from numpyro.distributions.conjugate import (
Expand Down
2 changes: 1 addition & 1 deletion numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def enable_x64(use_x64=True):
"""
if not use_x64:
use_x64 = os.getenv("JAX_ENABLE_X64", 0)
jax.config.update("jax_enable_x64", use_x64)
jax.config.update("jax_enable_x64", bool(use_x64))


def set_platform(platform=None):
Expand Down
3 changes: 2 additions & 1 deletion test/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@

import pytest

from jax import jit, tree_map, vmap
from jax import jit, vmap
import jax.numpy as jnp
from jax.tree_util import tree_map

from numpyro.distributions import constraints

Expand Down
3 changes: 2 additions & 1 deletion test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import numpy as np
import pytest

from jax import jacfwd, jit, random, tree_map, vmap
from jax import jacfwd, jit, random, vmap
import jax.numpy as jnp
from jax.tree_util import tree_map

from numpyro.distributions import constraints
from numpyro.distributions.flows import (
Expand Down

0 comments on commit a822bce

Please sign in to comment.