Skip to content

Commit

Permalink
Replace deprecated jax.tree_* functions with jax.tree.*
Browse files Browse the repository at this point in the history
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25.

PiperOrigin-RevId: 634343669
  • Loading branch information
Jake VanderPlas authored and copybara-github committed May 16, 2024
1 parent d887ed2 commit f593e4c
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion python/jax/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ grad_jax, new_state = gfunc(weights_jax, run_seg_input, hypers_jax, state_jax)
## Now one can update `Weights` based on `grad_jax`.
# One can use a library like `Optax`. Here, for illustration, we can just do the
# gradient descent (stepsize=0.1) by,
new_weights = jax.tree_map(lambda x,y: x-0.1*y, weights_jax, grad_jax)
new_weights = jax.tree.map(lambda x,y: x-0.1*y, weights_jax, grad_jax)
# Please note that: we currently put as many coefficients as possible into
# `Weights` but normally we don't need to train all of them (and some of the
# weights are much more sensitive than the others). One can selectively update
Expand Down
12 changes: 6 additions & 6 deletions python/jax/carfac_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,12 @@ def test_same_outputs_parallel_for_pmap(self):
self.assertTrue((combined_output[1][4] == agc_out_b).all())
self.assertTrue(
jax.tree_util.tree_all(
jax.tree_map(jnp.allclose, state_out_a, combined_output[0][1])
jax.tree.map(jnp.allclose, state_out_a, combined_output[0][1])
)
)
self.assertTrue(
jax.tree_util.tree_all(
jax.tree_map(jnp.allclose, state_out_b, combined_output[1][1])
jax.tree.map(jnp.allclose, state_out_b, combined_output[1][1])
)
)

Expand Down Expand Up @@ -154,25 +154,25 @@ def test_same_outputs_parallel_for_shmap(self):
self.assertTrue((combined_output[1][4] == agc_out_b).all())
self.assertTrue(
jax.tree_util.tree_all(
jax.tree_map(jnp.allclose, state_out_a, combined_output[0][1])
jax.tree.map(jnp.allclose, state_out_a, combined_output[0][1])
)
)
self.assertTrue(
jax.tree_util.tree_all(
jax.tree_map(jnp.allclose, state_out_b, combined_output[1][1])
jax.tree.map(jnp.allclose, state_out_b, combined_output[1][1])
)
)

# The state is the most unusual one in all the outputs: ensure that
# equality is complete and double sided.
self.assertTrue(
jax.tree_util.tree_all(
jax.tree_map(jnp.allclose, combined_output[0][1], state_out_a)
jax.tree.map(jnp.allclose, combined_output[0][1], state_out_a)
)
)
self.assertTrue(
jax.tree_util.tree_all(
jax.tree_map(jnp.allclose, combined_output[1][1], state_out_b)
jax.tree.map(jnp.allclose, combined_output[1][1], state_out_b)
)
)

Expand Down

0 comments on commit f593e4c

Please sign in to comment.