From 5af9ebda72bd7aeb08c61e4248ecd0d982473224 Mon Sep 17 00:00:00 2001
From: Du Phan <fehiepsi@gmail.com>
Date: Wed, 26 Jun 2024 05:10:07 -0400
Subject: [PATCH] Update jax.tree_util.tree_map to jax.tree.map (#1821)

* update jax.tree_util.tree_foo to jax.tree.foo

* bump the minimum jax version to 0.4.25, which provides the jax.tree module

* fix lint issues

* also fix the deprecation warning from passing a_min/a_max keyword arguments to jnp.clip (see the sketch below)
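
For reference, a minimal sketch of the two migrations applied throughout this
patch (the pytree and array values here are illustrative only, not taken from
the diff); it assumes jax>=0.4.25, where jax.tree.map/flatten/unflatten are
thin aliases of their jax.tree_util counterparts:

    import jax
    import jax.numpy as jnp

    params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

    # old spellings replaced by this patch:
    #   jax.tree_util.tree_map(f, tree), jax.tree_util.tree_flatten(tree), ...
    #   jnp.clip(x, a_min=0.0, a_max=1.0)

    # new spellings:
    doubled = jax.tree.map(lambda x: 2 * x, params)
    leaves, treedef = jax.tree.flatten(params)
    restored = jax.tree.unflatten(treedef, leaves)

    x = jnp.linspace(-2.0, 2.0, 5)
    clipped = jnp.clip(x, 0.0, 1.0)      # positional bounds instead of a_min/a_max
    upper_only = jnp.clip(x, None, 1.0)  # None leaves that side unbounded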
---
 examples/annotation.py                        |  2 +-
 examples/gp.py                                |  2 +-
 .../source/time_series_forecasting.ipynb      |  4 +-
 numpyro/contrib/control_flow/scan.py          | 40 ++++++++--------
 .../einstein/mixture_guide_predictive.py      |  6 +--
 numpyro/contrib/einstein/stein_util.py        | 10 ++--
 numpyro/contrib/einstein/steinvi.py           | 10 ++--
 numpyro/contrib/module.py                     | 11 +++--
 numpyro/contrib/tfp/distributions.py          |  4 +-
 numpyro/contrib/tfp/mcmc.py                   |  4 +-
 numpyro/diagnostics.py                        | 12 ++---
 numpyro/distributions/batch_util.py           |  4 +-
 numpyro/distributions/continuous.py           | 12 ++---
 numpyro/distributions/directional.py          |  2 +-
 numpyro/distributions/discrete.py             |  4 +-
 numpyro/distributions/distribution.py         |  3 +-
 numpyro/distributions/flows.py                |  2 +-
 numpyro/distributions/transforms.py           | 22 ++++-----
 numpyro/distributions/truncated.py            | 10 ++--
 numpyro/distributions/util.py                 |  6 +--
 numpyro/infer/autoguide.py                    | 17 ++++---
 numpyro/infer/barker.py                       |  2 +-
 numpyro/infer/ensemble.py                     |  2 +-
 numpyro/infer/ensemble_util.py                |  3 +-
 numpyro/infer/hmc.py                          |  2 +-
 numpyro/infer/hmc_gibbs.py                    |  2 +-
 numpyro/infer/hmc_util.py                     | 26 +++++------
 numpyro/infer/mcmc.py                         | 32 ++++++-------
 numpyro/infer/mixed_hmc.py                    |  2 +-
 numpyro/infer/svi.py                          |  3 +-
 numpyro/infer/util.py                         |  5 +-
 numpyro/ops/provenance.py                     | 14 +++---
 numpyro/optim.py                              |  7 ++-
 numpyro/util.py                               | 13 +++---
 setup.py                                      |  4 +-
 test/contrib/einstein/test_steinvi_util.py    |  8 ++--
 test/contrib/test_enum_elbo.py                | 46 +++++++++----------
 test/contrib/test_module.py                   |  6 +--
 test/infer/test_autoguide.py                  |  8 ++--
 test/infer/test_ensemble_util.py              |  2 +-
 test/infer/test_gradient.py                   | 16 +++----
 test/infer/test_hmc_util.py                   |  4 +-
 test/infer/test_mcmc.py                       | 17 ++++---
 test/infer/test_svi.py                        |  5 +-
 test/ops/test_provenance.py                   |  8 ++--
 test/test_constraints.py                      |  6 +--
 test/test_distributions.py                    |  6 +--
 test/test_handlers.py                         |  3 +-
 test/test_pickle.py                           | 20 +++++---
 test/test_transforms.py                       |  6 +--
 test/test_util.py                             | 14 +++---
 51 files changed, 236 insertions(+), 243 deletions(-)

diff --git a/examples/annotation.py b/examples/annotation.py
index 881825316..3341dbe72 100644
--- a/examples/annotation.py
+++ b/examples/annotation.py
@@ -309,7 +309,7 @@ def main(args):
 #     is stored in `discrete_samples`. To merge those discrete samples into the `mcmc`
 #     instance, we can use the following pattern::
 #
-#         chain_discrete_samples = jax.tree_util.tree_map(
+#         chain_discrete_samples = jax.tree.map(
 #             lambda x: x.reshape((args.num_chains, args.num_samples) + x.shape[1:]),
 #             discrete_samples)
 #         mcmc.get_samples().update(discrete_samples)
diff --git a/examples/gp.py b/examples/gp.py
index aac9632f8..0f70400db 100644
--- a/examples/gp.py
+++ b/examples/gp.py
@@ -116,7 +116,7 @@ def predict(rng_key, X, Y, X_test, var, length, noise, use_cholesky=True):
         K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
         mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))
 
-    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(
+    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), 0.0)) * jax.random.normal(
         rng_key, X_test.shape[:1]
     )
 
diff --git a/notebooks/source/time_series_forecasting.ipynb b/notebooks/source/time_series_forecasting.ipynb
index 3ff5667d6..40f467ee0 100644
--- a/notebooks/source/time_series_forecasting.ipynb
+++ b/notebooks/source/time_series_forecasting.ipynb
@@ -206,7 +206,7 @@
     "        level, s, moving_sum = carry\n",
     "        season = s[0] * level**pow_season\n",
     "        exp_val = level + coef_trend * level**pow_trend + season\n",
-    "        exp_val = jnp.clip(exp_val, a_min=0)\n",
+    "        exp_val = jnp.clip(exp_val, 0)\n",
     "        # use expected vale when forecasting\n",
     "        y_t = jnp.where(t >= N, exp_val, y[t])\n",
     "\n",
@@ -215,7 +215,7 @@
     "        )\n",
     "        level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)\n",
     "        level = level_sm * level_p + (1 - level_sm) * level\n",
-    "        level = jnp.clip(level, a_min=0)\n",
+    "        level = jnp.clip(level, 0)\n",
     "\n",
     "        new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]\n",
     "        # repeat s when forecasting\n",
diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py
index 6b657b494..7a2689cf1 100644
--- a/numpyro/contrib/control_flow/scan.py
+++ b/numpyro/contrib/control_flow/scan.py
@@ -4,9 +4,9 @@
 from collections import OrderedDict
 from functools import partial
 
+import jax
 from jax import device_put, lax, random
 import jax.numpy as jnp
-from jax.tree_util import tree_flatten, tree_map, tree_unflatten
 
 from numpyro import handlers
 from numpyro.distributions.batch_util import promote_batch_shape
@@ -98,7 +98,7 @@ def postprocess_message(self, msg):
             fn_batch_ndim = len(fn.batch_shape)
             if fn_batch_ndim < value_batch_ndims:
                 prepend_shapes = (1,) * (value_batch_ndims - fn_batch_ndim)
-                msg["fn"] = tree_map(
+                msg["fn"] = jax.tree.map(
                     lambda x: jnp.reshape(x, prepend_shapes + jnp.shape(x)), fn
                 )
 
@@ -140,11 +140,11 @@ def scan_enum(
     history = min(history, length)
     unroll_steps = min(2 * history - 1, length)
     if reverse:
-        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
-        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
+        x0 = jax.tree.map(lambda x: x[-unroll_steps:][::-1], xs)
+        xs_ = jax.tree.map(lambda x: x[:-unroll_steps], xs)
     else:
-        x0 = tree_map(lambda x: x[:unroll_steps], xs)
-        xs_ = tree_map(lambda x: x[unroll_steps:], xs)
+        x0 = jax.tree.map(lambda x: x[:unroll_steps], xs)
+        xs_ = jax.tree.map(lambda x: x[unroll_steps:], xs)
 
     carry_shapes = []
 
@@ -187,10 +187,12 @@ def body_fn(wrapped_carry, x, prefix=None):
 
             # store shape of new_carry at a global variable
             if len(carry_shapes) < (history + 1):
-                carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
+                carry_shapes.append(
+                    [jnp.shape(x) for x in jax.tree.flatten(new_carry)[0]]
+                )
             # make new_carry have the same shape as carry
             # FIXME: is this rigorous?
-            new_carry = tree_map(
+            new_carry = jax.tree.map(
                 lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry
             )
         return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
@@ -204,11 +206,11 @@ def body_fn(wrapped_carry, x, prefix=None):
         for i in markov(range(unroll_steps + 1), history=history):
             if i < unroll_steps:
                 wrapped_carry, (_, y0) = body_fn(
-                    wrapped_carry, tree_map(lambda z: z[i], x0)
+                    wrapped_carry, jax.tree.map(lambda z: z[i], x0)
                 )
                 if i > 0:
                     # reshape y1, y2,... to have the same shape as y0
-                    y0 = tree_map(
+                    y0 = jax.tree.map(
                         lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0
                     )
                 y0s.append(y0)
@@ -216,15 +218,15 @@ def body_fn(wrapped_carry, x, prefix=None):
                 # shape so we don't need to record them here
                 if (i >= history - 1) and (len(carry_shapes) < history + 1):
                     carry_shapes.append(
-                        jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]
+                        jnp.shape(x) for x in jax.tree.flatten(wrapped_carry[-1])[0]
                     )
             else:
                 # this is the last rolling step
-                y0s = tree_map(lambda *z: jnp.stack(z, axis=0), *y0s)
+                y0s = jax.tree.map(lambda *z: jnp.stack(z, axis=0), *y0s)
                 # return early if length = unroll_steps
                 if length == unroll_steps:
                     return wrapped_carry, (PytreeTrace({}), y0s)
-                wrapped_carry = tree_map(device_put, wrapped_carry)
+                wrapped_carry = jax.tree.map(device_put, wrapped_carry)
                 wrapped_carry, (pytree_trace, ys) = lax.scan(
                     body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
                 )
@@ -251,20 +253,20 @@ def body_fn(wrapped_carry, x, prefix=None):
         site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var)
 
     # similar to carry, we need to reshape due to shape alternating in markov
-    ys = tree_map(
+    ys = jax.tree.map(
         lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys
     )
     # then join with y0s
-    ys = tree_map(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
+    ys = jax.tree.map(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
     # we also need to reshape `carry` to match sequential behavior
     i = (length + 1) % (history + 1)
     t, rng_key, carry = wrapped_carry
     carry_shape = carry_shapes[i]
-    flatten_carry, treedef = tree_flatten(carry)
+    flatten_carry, treedef = jax.tree.flatten(carry)
     flatten_carry = [
         jnp.reshape(x, t1_shape) for x, t1_shape in zip(flatten_carry, carry_shape)
     ]
-    carry = tree_unflatten(treedef, flatten_carry)
+    carry = jax.tree.unflatten(treedef, flatten_carry)
     wrapped_carry = (t, rng_key, carry)
     return wrapped_carry, (pytree_trace, ys)
 
@@ -282,7 +284,7 @@ def scan_wrapper(
     first_available_dim=None,
 ):
     if length is None:
-        length = jnp.shape(tree_flatten(xs)[0][0])[0]
+        length = jnp.shape(jax.tree.flatten(xs)[0][0])[0]
 
     if enum and history > 0:
         return scan_enum(  # TODO: replay for enum
@@ -324,7 +326,7 @@ def body_fn(wrapped_carry, x):
 
         return (i + 1, rng_key, carry), (PytreeTrace(trace), y)
 
-    wrapped_carry = tree_map(device_put, (0, rng_key, init))
+    wrapped_carry = jax.tree.map(device_put, (0, rng_key, init))
     last_carry, (pytree_trace, ys) = lax.scan(
         body_fn, wrapped_carry, xs, length=length, reverse=reverse
     )
diff --git a/numpyro/contrib/einstein/mixture_guide_predictive.py b/numpyro/contrib/einstein/mixture_guide_predictive.py
index c3b2fdc5b..2a2a8ed51 100644
--- a/numpyro/contrib/einstein/mixture_guide_predictive.py
+++ b/numpyro/contrib/einstein/mixture_guide_predictive.py
@@ -5,8 +5,8 @@
 from functools import partial
 from typing import Optional
 
+import jax
 from jax import numpy as jnp, random, vmap
-from jax.tree_util import tree_flatten, tree_map
 
 from numpyro.handlers import substitute
 from numpyro.infer import Predictive
@@ -63,7 +63,7 @@ def __init__(
 
         self.guide = guide
         self.return_sites = return_sites
-        self.num_mixture_components = jnp.shape(tree_flatten(params)[0][0])[0]
+        self.num_mixture_components = jnp.shape(jax.tree.flatten(params)[0][0])[0]
         self.mixture_assignment_sitename = mixture_assignment_sitename
 
     def _call_with_params(self, rng_key, params, args, kwargs):
@@ -99,7 +99,7 @@ def __call__(self, rng_key, *args, **kwargs):
             minval=0,
             maxval=self.num_mixture_components,
         )
-        predictive_assign = tree_map(
+        predictive_assign = jax.tree.map(
             lambda arr: vmap(lambda i, assign: arr[i, assign])(
                 jnp.arange(self._batch_shape[0]), assigns
             ),
diff --git a/numpyro/contrib/einstein/stein_util.py b/numpyro/contrib/einstein/stein_util.py
index e8cb80372..741e7c816 100644
--- a/numpyro/contrib/einstein/stein_util.py
+++ b/numpyro/contrib/einstein/stein_util.py
@@ -1,9 +1,9 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+import jax
 from jax import numpy as jnp, vmap
 from jax.flatten_util import ravel_pytree
-from jax.tree_util import tree_map
 
 from numpyro.distributions import biject_to
 from numpyro.distributions.constraints import real
@@ -64,14 +64,14 @@ def batch_ravel_pytree(pytree, nbatch_dims=0):
         flat, unravel_fn = ravel_pytree(pytree)
         return flat, unravel_fn, unravel_fn
 
-    shapes = tree_map(lambda x: x.shape, pytree)
-    flat_pytree = tree_map(lambda x: x.reshape(*x.shape[:-nbatch_dims], -1), pytree)
+    shapes = jax.tree.map(lambda x: x.shape, pytree)
+    flat_pytree = jax.tree.map(lambda x: x.reshape(*x.shape[:-nbatch_dims], -1), pytree)
     flat = vmap(lambda x: ravel_pytree(x)[0])(flat_pytree)
-    unravel_fn = ravel_pytree(tree_map(lambda x: x[0], flat_pytree))[1]
+    unravel_fn = ravel_pytree(jax.tree.map(lambda x: x[0], flat_pytree))[1]
     return (
         flat,
         unravel_fn,
-        lambda _flat: tree_map(
+        lambda _flat: jax.tree.map(
             lambda x, shape: x.reshape(shape), vmap(unravel_fn)(_flat), shapes
         ),
     )
diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py
index 7f2b47f70..98c055db1 100644
--- a/numpyro/contrib/einstein/steinvi.py
+++ b/numpyro/contrib/einstein/steinvi.py
@@ -9,9 +9,9 @@
 from itertools import chain
 import operator
 
+import jax
 from jax import grad, jacfwd, numpy as jnp, random, vmap
 from jax.flatten_util import ravel_pytree
-from jax.tree_util import tree_map
 
 from numpyro import handlers
 from numpyro.contrib.einstein.stein_kernels import SteinKernel
@@ -340,10 +340,10 @@ def _update_force(attr_force, rep_force, jac):
                 return force.reshape(attr_force.shape)
 
             reparam_jac = {
-                name: tree_map(lambda var: _nontrivial_jac(name, var), variables)
+                name: jax.tree.map(lambda var: _nontrivial_jac(name, var), variables)
                 for name, variables in unravel_pytree(particle).items()
             }
-            jac_params = tree_map(
+            jac_params = jax.tree.map(
                 _update_force,
                 unravel_pytree(attr_forces),
                 unravel_pytree(rep_forces),
@@ -363,7 +363,7 @@ def _update_force(attr_force, rep_force, jac):
         stein_param_grads = unravel_pytree_batched(particle_grads)
 
         # 6. Return loss and gradients (based on parameter forces)
-        res_grads = tree_map(
+        res_grads = jax.tree.map(
             lambda x: -x, {**non_mixture_param_grads, **stein_param_grads}
         )
         return jnp.linalg.norm(particle_grads), res_grads
@@ -427,7 +427,7 @@ def init(self, rng_key, *args, **kwargs):
                 if site["name"] in guide_init_params:
                     pval = guide_init_params[site["name"]]
                     if self.non_mixture_params_fn(site["name"]):
-                        pval = tree_map(lambda x: x[0], pval)
+                        pval = jax.tree.map(lambda x: x[0], pval)
                 else:
                     pval = site["value"]
                 params[site["name"]] = transform.inv(pval)
diff --git a/numpyro/contrib/module.py b/numpyro/contrib/module.py
index f370b4330..3f40363e6 100644
--- a/numpyro/contrib/module.py
+++ b/numpyro/contrib/module.py
@@ -5,9 +5,10 @@
 from copy import deepcopy
 from functools import partial
 
+import jax
 from jax import random
 import jax.numpy as jnp
-from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten
+from jax.tree_util import register_pytree_node
 
 import numpyro
 import numpyro.distributions as dist
@@ -106,8 +107,8 @@ def flax_module(
             assert set(mutable) == set(nn_state)
             numpyro_mutable(name + "$state", nn_state)
         # make sure that nn_params keep the same order after unflatten
-        params_flat, tree_def = tree_flatten(nn_params)
-        nn_params = tree_unflatten(tree_def, params_flat)
+        params_flat, tree_def = jax.tree.flatten(nn_params)
+        nn_params = jax.tree.unflatten(tree_def, params_flat)
         numpyro.param(module_key, nn_params)
 
     def apply_with_state(params, *args, **kwargs):
@@ -195,8 +196,8 @@ def haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kw
         nn_params = hk.data_structures.to_mutable_dict(nn_params)
         # we cast it to a mutable one to be able to set priors for parameters
         # make sure that nn_params keep the same order after unflatten
-        params_flat, tree_def = tree_flatten(nn_params)
-        nn_params = tree_unflatten(tree_def, params_flat)
+        params_flat, tree_def = jax.tree.flatten(nn_params)
+        nn_params = jax.tree.unflatten(tree_def, params_flat)
         numpyro.param(module_key, nn_params)
 
     def apply_with_state(params, *args, **kwargs):
diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py
index 4db875cfe..4c3a76009 100644
--- a/numpyro/contrib/tfp/distributions.py
+++ b/numpyro/contrib/tfp/distributions.py
@@ -282,11 +282,11 @@ def is_discrete(self):
         return self.support is None
 
     def tree_flatten(self):
-        return jax.tree_util.tree_flatten(self.tfp_dist)
+        return jax.tree.flatten(self.tfp_dist)
 
     @classmethod
     def tree_unflatten(cls, aux_data, params):
-        fn = jax.tree_util.tree_unflatten(aux_data, params)
+        fn = jax.tree.unflatten(aux_data, params)
         with warnings.catch_warnings():
             warnings.simplefilter("ignore", category=FutureWarning)
             return TFPDistribution[fn.__class__](**fn.parameters)
diff --git a/numpyro/contrib/tfp/mcmc.py b/numpyro/contrib/tfp/mcmc.py
index b660af837..7a2312e70 100644
--- a/numpyro/contrib/tfp/mcmc.py
+++ b/numpyro/contrib/tfp/mcmc.py
@@ -5,10 +5,10 @@
 from collections import namedtuple
 import inspect
 
+import jax
 from jax import random, vmap
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
-from jax.tree_util import tree_map
 import tensorflow_probability.substrates.jax as tfp
 
 from numpyro.infer import init_to_uniform
@@ -44,7 +44,7 @@ def log_prob_fn(x):
             flatten_result = vmap(lambda a: -potential_fn(unravel_fn(a)))(
                 jnp.reshape(x, (-1,) + jnp.shape(x)[-1:])
             )
-            return tree_map(
+            return jax.tree.map(
                 lambda a: jnp.reshape(a, batch_shape + jnp.shape(a)[1:]), flatten_result
             )
         else:
diff --git a/numpyro/diagnostics.py b/numpyro/diagnostics.py
index 99aa9761c..cf3d04a9e 100644
--- a/numpyro/diagnostics.py
+++ b/numpyro/diagnostics.py
@@ -10,8 +10,8 @@
 
 import numpy as np
 
+import jax
 from jax import device_get
-from jax.tree_util import tree_flatten, tree_map
 
 __all__ = [
     "autocorrelation",
@@ -182,7 +182,7 @@ def effective_sample_size(x):
     Rho_k = np.concatenate(
         [
             Rho_init,
-            np.minimum.accumulate(np.clip(Rho_k[1:, ...], a_min=0, a_max=None), axis=0),
+            np.minimum.accumulate(np.clip(Rho_k[1:, ...], 0, None), axis=0),
         ],
         axis=0,
     )
@@ -238,10 +238,10 @@ def summary(samples, prob=0.90, group_by_chain=True):
         chain dimension).
     """
     if not group_by_chain:
-        samples = tree_map(lambda x: x[None, ...], samples)
+        samples = jax.tree.map(lambda x: x[None, ...], samples)
     if not isinstance(samples, dict):
         samples = {
-            "Param:{}".format(i): v for i, v in enumerate(tree_flatten(samples)[0])
+            "Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
         }
 
     summary_dict = {}
@@ -288,10 +288,10 @@ def print_summary(samples, prob=0.90, group_by_chain=True):
         chain dimension).
     """
     if not group_by_chain:
-        samples = tree_map(lambda x: x[None, ...], samples)
+        samples = jax.tree.map(lambda x: x[None, ...], samples)
     if not isinstance(samples, dict):
         samples = {
-            "Param:{}".format(i): v for i, v in enumerate(tree_flatten(samples)[0])
+            "Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
         }
     summary_dict = summary(samples, prob, group_by_chain=True)
 
diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py
index 83698b00e..235127335 100644
--- a/numpyro/distributions/batch_util.py
+++ b/numpyro/distributions/batch_util.py
@@ -5,8 +5,8 @@
 from functools import singledispatch
 from typing import Union
 
+import jax
 import jax.numpy as jnp
-from jax.tree_util import tree_map
 
 from numpyro.distributions import constraints
 from numpyro.distributions.conjugate import (
@@ -547,7 +547,7 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution):
         len(new_shapes_elems),
         len(new_shapes_elems) + len(orig_delta_batch_shape),
     )
-    new_base_dist = tree_map(
+    new_base_dist = jax.tree.map(
         lambda x: jnp.expand_dims(x, axis=new_axes_locs), new_self.base_dist
     )
 
diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py
index 237cb633a..652941746 100644
--- a/numpyro/distributions/continuous.py
+++ b/numpyro/distributions/continuous.py
@@ -281,9 +281,7 @@ def sample(self, key, sample_shape=()):
         assert is_prng_key(key)
         shape = sample_shape + self.batch_shape
         samples = random.dirichlet(key, self.concentration, shape=shape)
-        return jnp.clip(
-            samples, a_min=jnp.finfo(samples).tiny, a_max=1 - jnp.finfo(samples).eps
-        )
+        return jnp.clip(samples, jnp.finfo(samples).tiny, 1 - jnp.finfo(samples).eps)
 
     @validate_sample
     def log_prob(self, value):
@@ -840,15 +838,15 @@ def sample(self, key, sample_shape=()):
         u = random.uniform(
             key, shape=sample_shape + self.batch_shape, minval=finfo.tiny
         )
-        u_con0 = jnp.clip(u ** (1 / self.concentration0), a_max=1 - finfo.eps)
+        u_con0 = jnp.clip(u ** (1 / self.concentration0), None, 1 - finfo.eps)
         log_sample = jnp.log1p(-u_con0) / self.concentration1
-        return jnp.clip(jnp.exp(log_sample), a_min=finfo.tiny, a_max=1 - finfo.eps)
+        return jnp.clip(jnp.exp(log_sample), finfo.tiny, 1 - finfo.eps)
 
     @validate_sample
     def log_prob(self, value):
         finfo = jnp.finfo(jnp.result_type(float))
         normalize_term = jnp.log(self.concentration0) + jnp.log(self.concentration1)
-        value_con1 = jnp.clip(value**self.concentration1, a_max=1 - finfo.eps)
+        value_con1 = jnp.clip(value**self.concentration1, None, 1 - finfo.eps)
         return (
             xlogy(self.concentration1 - 1, value)
             + xlog1py(self.concentration0 - 1, -value_con1)
@@ -2363,7 +2361,7 @@ def log_prob(self, value):
 
     def cdf(self, value):
         cdf = (value - self.low) / (self.high - self.low)
-        return jnp.clip(cdf, a_min=0.0, a_max=1.0)
+        return jnp.clip(cdf, 0.0, 1.0)
 
     def icdf(self, value):
         return self.low + value * (self.high - self.low)
diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py
index fd5b1596c..8156855b1 100644
--- a/numpyro/distributions/directional.py
+++ b/numpyro/distributions/directional.py
@@ -401,7 +401,7 @@ def norm_const(self):
         lbinoms = num - 2 * den
 
         fs = lbinoms.reshape(-1, 1) + m * (
-            jnp.log(jnp.clip(corr**2, a_min=jnp.finfo(jnp.result_type(float)).tiny))
+            jnp.log(jnp.clip(corr**2, jnp.finfo(jnp.result_type(float)).tiny))
             - jnp.log(4 * jnp.prod(conc, axis=-1))
         )
         fs += log_I1(49, conc, terms=51).sum(-1)
diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py
index 0ee140620..7d7358a5d 100644
--- a/numpyro/distributions/discrete.py
+++ b/numpyro/distributions/discrete.py
@@ -65,7 +65,7 @@ def _to_probs_multinom(logits):
 
 def _to_logits_multinom(probs):
     minval = jnp.finfo(jnp.result_type(probs)).min
-    return jnp.clip(jnp.log(probs), a_min=minval)
+    return jnp.clip(jnp.log(probs), minval)
 
 
 class BernoulliProbs(Distribution):
@@ -443,7 +443,7 @@ def log_prob(self, value):
 
     def cdf(self, value):
         cdf = (jnp.floor(value) + 1 - self.low) / (self.high - self.low + 1)
-        return jnp.clip(cdf, a_min=0.0, a_max=1.0)
+        return jnp.clip(cdf, 0.0, 1.0)
 
     def icdf(self, value):
         return self.low + value * (self.high - self.low + 1) - 1
diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py
index cba2994b8..04b213a3e 100644
--- a/numpyro/distributions/distribution.py
+++ b/numpyro/distributions/distribution.py
@@ -33,6 +33,7 @@
 
 import numpy as np
 
+import jax
 from jax import lax, tree_util
 import jax.numpy as jnp
 from jax.scipy.special import logsumexp
@@ -636,7 +637,7 @@ def reshape_sample(x):
             event_shape = jnp.shape(x)[batch_ndims:]
             return x.reshape(sample_shape + self.batch_shape + event_shape)
 
-        intermediates = tree_util.tree_map(reshape_sample, intermediates)
+        intermediates = jax.tree.map(reshape_sample, intermediates)
         samples = reshape_sample(samples)
         return samples, intermediates
 
diff --git a/numpyro/distributions/flows.py b/numpyro/distributions/flows.py
index cd9b21c35..2980587b4 100644
--- a/numpyro/distributions/flows.py
+++ b/numpyro/distributions/flows.py
@@ -10,7 +10,7 @@
 
 
 def _clamp_preserve_gradients(x, min, max):
-    return x + lax.stop_gradient(jnp.clip(x, a_min=min, a_max=max) - x)
+    return x + lax.stop_gradient(jnp.clip(x, min, max) - x)
 
 
 # adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/transforms/iaf.py
diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py
index 564c628d8..927c2b096 100644
--- a/numpyro/distributions/transforms.py
+++ b/numpyro/distributions/transforms.py
@@ -6,15 +6,15 @@
 import weakref
 
 import numpy as np
-from numpy.core.numeric import normalize_axis_tuple
 
+import jax
 from jax import lax, vmap
 from jax.flatten_util import ravel_pytree
 from jax.nn import log_sigmoid, softplus
 import jax.numpy as jnp
 from jax.scipy.linalg import solve_triangular
 from jax.scipy.special import expit, logit
-from jax.tree_util import register_pytree_node, tree_flatten, tree_map
+from jax.tree_util import register_pytree_node
 
 from numpyro.distributions import constraints
 from numpyro.distributions.util import (
@@ -57,7 +57,7 @@
 
 def _clipped_expit(x):
     finfo = jnp.finfo(jnp.result_type(x))
-    return jnp.clip(expit(x), a_min=finfo.tiny, a_max=1.0 - finfo.eps)
+    return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps)
 
 
 class Transform(object):
@@ -650,11 +650,11 @@ def _inverse(self, y):
         pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)]
         remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0)
         finfo = jnp.finfo(y.dtype)
-        remainder = jnp.clip(remainder, a_min=finfo.tiny)
+        remainder = jnp.clip(remainder, finfo.tiny)
         t = y / remainder
 
         # inverse of tanh
-        t = jnp.clip(t, a_min=-1 + finfo.eps, a_max=1 - finfo.eps)
+        t = jnp.clip(t, -1 + finfo.eps, 1 - finfo.eps)
         return jnp.arctanh(t)
 
     def log_abs_det_jacobian(self, x, y, intermediates=None):
@@ -666,7 +666,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
         # of the diagonal part of the jacobian
         one_minus_remainder = jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1)
         eps = jnp.finfo(y.dtype).eps
-        one_minus_remainder = jnp.clip(one_minus_remainder, a_max=1 - eps)
+        one_minus_remainder = jnp.clip(one_minus_remainder, None, 1 - eps)
         # log(remainder) = log1p(remainder - 1)
         stick_breaking_logdet = jnp.sum(jnp.log1p(-one_minus_remainder), axis=-1)
 
@@ -1074,9 +1074,7 @@ def __call__(self, x):
 
     def _inverse(self, y):
         y_crop = y[..., :-1]
-        z1m_cumprod = jnp.clip(
-            1 - jnp.cumsum(y_crop, axis=-1), a_min=jnp.finfo(y.dtype).tiny
-        )
+        z1m_cumprod = jnp.clip(1 - jnp.cumsum(y_crop, axis=-1), jnp.finfo(y.dtype).tiny)
         # hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod
         x = jnp.log(y_crop / z1m_cumprod)
         return x + jnp.log(x.shape[-1] - jnp.arange(x.shape[-1]))
@@ -1116,7 +1114,7 @@ def __call__(self, x):
         batch_shape = x.shape[:-1]
         if batch_shape:
             unpacked = vmap(self.unpack_fn)(x.reshape((-1,) + x.shape[-1:]))
-            return tree_map(
+            return jax.tree.map(
                 lambda z: jnp.reshape(z, batch_shape + z.shape[1:]), unpacked
             )
         else:
@@ -1124,7 +1122,7 @@ def __call__(self, x):
 
     def _inverse(self, y):
         leading_dims = [
-            v.shape[0] if jnp.ndim(v) > 0 else 0 for v in tree_flatten(y)[0]
+            v.shape[0] if jnp.ndim(v) > 0 else 0 for v in jax.tree.flatten(y)[0]
         ]
         d0 = leading_dims[0]
         not_scalar = d0 > 0 or len(leading_dims) > 1
@@ -1417,7 +1415,7 @@ def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
         return y
 
     def extend_axis_rev(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
-        normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]
+        normalized_axis = axis if axis >= 0 else jnp.ndim(array) + axis
 
         n = array.shape[normalized_axis]
         last = jnp.take(array, jnp.array([-1]), axis=normalized_axis)
diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py
index 078ea83bc..236e2a000 100644
--- a/numpyro/distributions/truncated.py
+++ b/numpyro/distributions/truncated.py
@@ -1,11 +1,11 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+import jax
 from jax import lax
 import jax.numpy as jnp
 import jax.random as random
 from jax.scipy.special import logsumexp
-from jax.tree_util import tree_map
 
 from numpyro.distributions import constraints
 from numpyro.distributions.continuous import (
@@ -38,7 +38,7 @@ def __init__(self, base_dist, low=0.0, *, validate_args=None):
             base_dist.support is constraints.real
         ), "The base distribution should be univariate and have real support."
         batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(low))
-        self.base_dist = tree_map(
+        self.base_dist = jax.tree.map(
             lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
         )
         (self.low,) = promote_shapes(low, shape=batch_shape)
@@ -117,7 +117,7 @@ def __init__(self, base_dist, high=0.0, *, validate_args=None):
             base_dist.support is constraints.real
         ), "The base distribution should be univariate and have real support."
         batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(high))
-        self.base_dist = tree_map(
+        self.base_dist = jax.tree.map(
             lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
         )
         (self.high,) = promote_shapes(high, shape=batch_shape)
@@ -186,7 +186,7 @@ def __init__(self, base_dist, low=0.0, high=1.0, *, validate_args=None):
         batch_shape = lax.broadcast_shapes(
             base_dist.batch_shape, jnp.shape(low), jnp.shape(high)
         )
-        self.base_dist = tree_map(
+        self.base_dist = jax.tree.map(
             lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
         )
         (self.low,) = promote_shapes(low, shape=batch_shape)
@@ -348,7 +348,7 @@ def sample(self, key, sample_shape=()):
             key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,))
         )
         x = jnp.sum(x / denom, axis=-1)
-        return jnp.clip(x * (0.5 / jnp.pi**2), a_max=self.truncation_point)
+        return jnp.clip(x * (0.5 / jnp.pi**2), None, self.truncation_point)
 
     @validate_sample
     def log_prob(self, value):
diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py
index aca32b1f7..c83efb701 100644
--- a/numpyro/distributions/util.py
+++ b/numpyro/distributions/util.py
@@ -386,7 +386,7 @@ def scan_fn(carry, val):
 def signed_stick_breaking_tril(t):
     # make sure that t in (-1, 1)
     eps = jnp.finfo(t.dtype).eps
-    t = jnp.clip(t, a_min=(-1 + eps), a_max=(1 - eps))
+    t = jnp.clip(t, -1 + eps, 1 - eps)
     # transform t to tril matrix with identity diagonal
     r = vec_to_tril_matrix(t, diagonal=-1)
 
@@ -417,7 +417,7 @@ def logmatmulexp(x, y):
 
 def clamp_probs(probs):
     finfo = jnp.finfo(jnp.result_type(probs, float))
-    return jnp.clip(probs, a_min=finfo.tiny, a_max=1.0 - finfo.eps)
+    return jnp.clip(probs, finfo.tiny, 1.0 - finfo.eps)
 
 
 def betainc(a, b, x):
@@ -607,7 +607,7 @@ def safe_normalize(x, *, p=2):
     assert isinstance(p, (float, int))
     assert p >= 0
     norm = jnp.linalg.norm(x, p, axis=-1, keepdims=True)
-    x = x / jnp.clip(norm, a_min=jnp.finfo(x).tiny)
+    x = x / jnp.clip(norm, jnp.finfo(x).tiny)
     # Avoid the singularity.
     mask = jnp.all(x == 0, axis=-1, keepdims=True)
     x = jnp.where(mask, x.shape[-1] ** (-1 / p), x)
diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py
index 351cb6390..624865a75 100644
--- a/numpyro/infer/autoguide.py
+++ b/numpyro/infer/autoguide.py
@@ -14,7 +14,6 @@
 from jax import grad, hessian, lax, random
 from jax.example_libraries import stax
 import jax.numpy as jnp
-from jax.tree_util import tree_map
 
 import numpyro
 from numpyro import handlers
@@ -454,12 +453,12 @@ def _constrain(self, latent_samples):
             : jnp.ndim(latent_samples[name]) - jnp.ndim(self._init_locs[name])
         ]
         if sample_shape:
-            flatten_samples = tree_map(
+            flatten_samples = jax.tree.map(
                 lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[len(sample_shape) :]),
                 latent_samples,
             )
             contrained_samples = lax.map(self._postprocess_fn, flatten_samples)
-            return tree_map(
+            return jax.tree.map(
                 lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
                 contrained_samples,
             )
@@ -751,7 +750,7 @@ def unpack_single_latent(latent):
                 latent_sample, (-1, jnp.shape(latent_sample)[-1])
             )
             unpacked_samples = lax.map(unpack_single_latent, latent_sample)
-            return tree_map(
+            return jax.tree.map(
                 lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
                 unpacked_samples,
             )
@@ -968,7 +967,7 @@ def log_density(x):
         def scan_body(carry, eps_beta):
             eps, beta = eps_beta
             eta = eta0 + eta_coeff * beta
-            eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max)
+            eta = jnp.clip(eta, 0.0, self.eta_max)
             z_prev, v_prev, log_factor = carry
             z_half = z_prev + v_prev * eta * inv_mass_matrix
             q_grad = (1.0 - beta) * grad(base_z_dist.log_prob)(z_half)
@@ -997,7 +996,7 @@ def _single_sample(_rng_key):
         if sample_shape:
             rng_key = random.split(rng_key, int(np.prod(sample_shape)))
             samples = lax.map(_single_sample, rng_key)
-            return tree_map(
+            return jax.tree.map(
                 lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
                 samples,
             )
@@ -1187,7 +1186,7 @@ def blocked_surrogate_model(x):
         def scan_body(carry, eps_beta):
             eps, beta = eps_beta
             eta = eta0 + eta_coeff * beta
-            eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max)
+            eta = jnp.clip(eta, 0.0, self.eta_max)
             z_prev, v_prev, log_factor = carry
             z_half = z_prev + v_prev * eta * inv_mass_matrix
             q_grad = (1.0 - beta) * grad(base_z_dist_log_prob)(z_half)
@@ -1642,7 +1641,7 @@ def base_z_dist_log_prob(x):
             def scan_body(carry, eps_beta):
                 eps, beta = eps_beta
                 eta = eta0 + eta_coeff * beta
-                eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max)
+                eta = jnp.clip(eta, 0.0, self.eta_max)
                 assert eps.shape == (subsample_size, D)
                 assert eta.shape == beta.shape == (subsample_size,)
                 z_prev, v_prev, log_factor = carry
@@ -1697,7 +1696,7 @@ def _single_sample(_rng_key):
         if sample_shape:
             rng_key = random.split(rng_key, int(np.prod(sample_shape)))
             samples = lax.map(_single_sample, rng_key)
-            return tree_map(
+            return jax.tree.map(
                 lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]),
                 samples,
             )
diff --git a/numpyro/infer/barker.py b/numpyro/infer/barker.py
index 5496e7baa..9d5fa0f2b 100644
--- a/numpyro/infer/barker.py
+++ b/numpyro/infer/barker.py
@@ -260,7 +260,7 @@ def sample(self, state, model_args, model_kwargs):
                 - softplus(-dx_flat * y_grad_flat_scaled)
             )
         )
-        accept_prob = jnp.clip(jnp.exp(log_accept_ratio), a_max=1.0)
+        accept_prob = jnp.clip(jnp.exp(log_accept_ratio), None, 1.0)
 
         x, x_flat, pe, x_grad = jax.lax.cond(
             random.bernoulli(key_accept, accept_prob),
diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py
index 33e3f50bf..033b2c7b0 100644
--- a/numpyro/infer/ensemble.py
+++ b/numpyro/infer/ensemble.py
@@ -160,7 +160,7 @@ def init(
             assert all(
                 [
                     param.shape[0] == self._num_chains
-                    for param in jax.tree_util.tree_leaves(init_params)
+                    for param in jax.tree.leaves(init_params)
                 ]
             ), "The batch dimension of each param must match n_chains"
 
diff --git a/numpyro/infer/ensemble_util.py b/numpyro/infer/ensemble_util.py
index 9f213ea5c..a5a3dd569 100644
--- a/numpyro/infer/ensemble_util.py
+++ b/numpyro/infer/ensemble_util.py
@@ -6,7 +6,6 @@
 import jax
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
-from jax.tree_util import tree_map
 
 
 def get_nondiagonal_indices(n):
@@ -41,6 +40,6 @@ def batch_ravel_pytree(pytree):
       component of the output.
     """
     flat = jax.vmap(lambda x: ravel_pytree(x)[0])(pytree)
-    unravel_fn = jax.vmap(ravel_pytree(tree_map(lambda z: z[0], pytree))[1])
+    unravel_fn = jax.vmap(ravel_pytree(jax.tree.map(lambda z: z[0], pytree))[1])
 
     return flat, unravel_fn
diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py
index 709c16824..a299dc89c 100644
--- a/numpyro/infer/hmc.py
+++ b/numpyro/infer/hmc.py
@@ -400,7 +400,7 @@ def _hmc_next(
         )
         delta_energy = energy_new - energy_old
         delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
-        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
+        accept_prob = jnp.clip(jnp.exp(-delta_energy), None, 1.0)
         diverging = delta_energy > max_delta_energy
         transition = random.bernoulli(rng_key, accept_prob)
         vv_state, energy = cond(
diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py
index 53622dcf2..f6b95389b 100644
--- a/numpyro/infer/hmc_gibbs.py
+++ b/numpyro/infer/hmc_gibbs.py
@@ -657,7 +657,7 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc):
         # given a fixed hmc_sites, pe_new - pe_curr = loglik_new - loglik_curr
         pe = state.hmc_state.potential_energy
         pe_new = potential_fn(z_gibbs_new, gibbs_state_new, state.hmc_state.z)
-        accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0)
+        accept_prob = jnp.clip(jnp.exp(pe - pe_new), None, 1.0)
         transition = random.bernoulli(rng_key, accept_prob)
         grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad
         z_gibbs, gibbs_state, pe, z_grad = cond(
diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py
index b331540d1..a21e7329e 100644
--- a/numpyro/infer/hmc_util.py
+++ b/numpyro/infer/hmc_util.py
@@ -3,12 +3,12 @@
 
 from collections import OrderedDict, namedtuple
 
+import jax
 from jax import grad, jacfwd, random, value_and_grad, vmap
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 from jax.scipy.linalg import solve_triangular
 from jax.scipy.special import expit
-from jax.tree_util import tree_flatten, tree_map
 
 import numpyro.distributions as dist
 from numpyro.util import cond, identity, while_loop
@@ -295,15 +295,15 @@ def update_fn(step_size, inverse_mass_matrix, state):
         :return: new state for the integrator.
         """
         z, r, _, z_grad = state
-        r = tree_map(
+        r = jax.tree.map(
             lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad
         )  # r(n+1/2)
         r_grad = _kinetic_grad(kinetic_fn, inverse_mass_matrix, r)
-        z = tree_map(lambda z, r_grad: z + step_size * r_grad, z, r_grad)  # z(n+1)
+        z = jax.tree.map(lambda z, r_grad: z + step_size * r_grad, z, r_grad)  # z(n+1)
         potential_energy, z_grad = _value_and_grad(
             potential_fn, z, forward_mode_differentiation
         )
-        r = tree_map(
+        r = jax.tree.map(
             lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad
         )  # r(n+1)
         return IntegratorState(z, r, potential_energy, z_grad)
@@ -669,7 +669,7 @@ def update_fn(t, accept_prob, z_info, state):
             )
             # account the the case log_step_size is an extreme number
             finfo = jnp.finfo(jnp.result_type(step_size))
-            step_size = jnp.clip(step_size, a_min=finfo.tiny, a_max=finfo.max)
+            step_size = jnp.clip(step_size, finfo.tiny, finfo.max)
 
         # update mass matrix state
         is_middle_window = (0 < window_idx) & (window_idx < (num_windows - 1))
@@ -759,7 +759,7 @@ def _biased_transition_kernel(current_tree, new_tree):
     # If new tree is turning or diverging, we won't move the proposal
     # to the new tree.
     transition_prob = jnp.where(
-        new_tree.turning | new_tree.diverging, 0.0, jnp.clip(transition_prob, a_max=1.0)
+        new_tree.turning | new_tree.diverging, 0.0, jnp.clip(transition_prob, None, 1.0)
     )
     return transition_prob
 
@@ -790,7 +790,7 @@ def _combine_tree(
             trees[1].z_right_grad,
         ),
     )
-    r_sum = tree_map(jnp.add, current_tree.r_sum, new_tree.r_sum)
+    r_sum = jax.tree.map(jnp.add, current_tree.r_sum, new_tree.r_sum)
 
     if biased_transition:
         transition_prob = _biased_transition_kernel(current_tree, new_tree)
@@ -872,7 +872,7 @@ def _build_basetree(
     tree_weight = -delta_energy
 
     diverging = delta_energy > max_delta_energy
-    accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
+    accept_prob = jnp.clip(jnp.exp(-delta_energy), None, 1.0)
     return TreeInfo(
         z_new,
         r_new,
@@ -1242,7 +1242,7 @@ def consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None):
         a collection of `num_draws` samples with the same data structure as each subposterior.
     """
     # stack subposteriors
-    joined_subposteriors = tree_map(lambda *args: jnp.stack(args), *subposteriors)
+    joined_subposteriors = jax.tree.map(lambda *args: jnp.stack(args), *subposteriors)
     # shape of joined_subposteriors: n_subs x n_samples x sample_shape
     joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(
         joined_subposteriors
@@ -1252,7 +1252,7 @@ def consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None):
         rng_key = random.PRNGKey(0) if rng_key is None else rng_key
         # randomly gets num_draws from subposteriors
         n_subs = len(subposteriors)
-        n_samples = tree_flatten(subposteriors[0])[0][0].shape[0]
+        n_samples = jax.tree.flatten(subposteriors[0])[0][0].shape[0]
         # shape of draw_idxs: n_subs x num_draws x sample_shape
         draw_idxs = random.randint(
             rng_key, shape=(n_subs, num_draws), minval=0, maxval=n_samples
@@ -1279,7 +1279,7 @@ def consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None):
         )
 
     # unravel_fn acts on 1 sample of a subposterior
-    _, unravel_fn = ravel_pytree(tree_map(lambda x: x[0], subposteriors[0]))
+    _, unravel_fn = ravel_pytree(jax.tree.map(lambda x: x[0], subposteriors[0]))
     return vmap(lambda x: unravel_fn(x))(samples_flat)
 
 
@@ -1297,7 +1297,7 @@ def parametric(subposteriors, diagonal=False):
         `False` (using covariance).
     :return: the estimated mean and variance/covariance parameters of the joined posterior
     """
-    joined_subposteriors = tree_map(lambda *args: jnp.stack(args), *subposteriors)
+    joined_subposteriors = jax.tree.map(lambda *args: jnp.stack(args), *subposteriors)
     joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))(
         joined_subposteriors
     )
@@ -1345,5 +1345,5 @@ def parametric_draws(subposteriors, num_draws, diagonal=False, rng_key=None):
         mean, cov = parametric(subposteriors, diagonal=False)
         samples_flat = dist.MultivariateNormal(mean, cov).sample(rng_key, (num_draws,))
 
-    _, unravel_fn = ravel_pytree(tree_map(lambda x: x[0], subposteriors[0]))
+    _, unravel_fn = ravel_pytree(jax.tree.map(lambda x: x[0], subposteriors[0]))
     return vmap(lambda x: unravel_fn(x))(samples_flat)
diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py
index f1e0a7013..ad016825b 100644
--- a/numpyro/infer/mcmc.py
+++ b/numpyro/infer/mcmc.py
@@ -9,9 +9,9 @@
 
 import numpy as np
 
+import jax
 from jax import device_get, jit, lax, local_device_count, pmap, random, vmap
 import jax.numpy as jnp
-from jax.tree_util import tree_flatten, tree_map
 
 from numpyro.diagnostics import print_summary
 from numpyro.util import (
@@ -164,18 +164,18 @@ def _get_progbar_desc_str(num_warmup, phase, i):
 
 
 def _get_value_from_index(xs, i):
-    return tree_map(lambda x: x[i], xs)
+    return jax.tree.map(lambda x: x[i], xs)
 
 
 def _laxmap(f, xs):
-    n = tree_flatten(xs)[0][0].shape[0]
+    n = jax.tree.flatten(xs)[0][0].shape[0]
 
     ys = []
     for i in range(n):
         x = jit(_get_value_from_index)(xs, i)
         ys.append(f(x))
 
-    return tree_map(lambda *args: jnp.stack(args), *ys)
+    return jax.tree.map(lambda *args: jnp.stack(args), *ys)
 
 
 def _sample_fn_jit_args(state, sampler):
@@ -378,8 +378,8 @@ def _get_cached_fns(self):
         if self._jit_model_args:
             args, kwargs = (None,), (None,)
         else:
-            args = tree_map(lambda x: _hashable(x), self._args)
-            kwargs = tree_map(
+            args = jax.tree.map(lambda x: _hashable(x), self._args)
+            kwargs = jax.tree.map(
                 lambda x: _hashable(x), tuple(sorted(self._kwargs.items()))
             )
         key = args + kwargs
@@ -422,8 +422,8 @@ def laxmap_postprocess_fn(states, args, kwargs):
 
     def _get_cached_init_state(self, rng_key, args, kwargs):
         rng_key = (_hashable(rng_key),)
-        args = tree_map(lambda x: _hashable(x), args)
-        kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
+        args = jax.tree.map(lambda x: _hashable(x), args)
+        kwargs = jax.tree.map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
         key = rng_key + args + kwargs
         try:
             return self._init_state_cache.get(key, None)
@@ -480,7 +480,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites):
             states = (states,)
         states = dict(zip(collect_fields, states))
         # Apply constraints if number of samples is non-zero
-        site_values = tree_flatten(states[self._sample_field])[0]
+        site_values = jax.tree.flatten(states[self._sample_field])[0]
         # XXX: lax.map still works if some arrays have 0 size
         # so we only need to filter out the case site_value.shape[0] == 0
         # (which happens when lower_idx==upper_idx)
@@ -509,8 +509,8 @@ def _compile(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
             rng_key, *args, extra_fields=extra_fields, init_params=init_params, **kwargs
         )
         rng_key = (_hashable(rng_key),)
-        args = tree_map(lambda x: _hashable(x), args)
-        kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
+        args = jax.tree.map(lambda x: _hashable(x), args)
+        kwargs = jax.tree.map(lambda x: _hashable(x), tuple(sorted(kwargs.items())))
         key = rng_key + args + kwargs
         try:
             self._init_state_cache[key] = self._last_state
@@ -520,7 +520,7 @@ def _compile(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
 
     def _get_states_flat(self):
         if self._states_flat is None:
-            self._states_flat = tree_map(
+            self._states_flat = jax.tree.map(
                 # need to calculate first dimension manually; see issue #1328
                 lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:]),
                 self._states,
@@ -629,7 +629,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
             See https://jax.readthedocs.io/en/latest/async_dispatch.html and
             https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling jax programs.
         """
-        init_params = tree_map(
+        init_params = jax.tree.map(
             lambda x: lax.convert_element_type(x, jnp.result_type(x)), init_params
         )
         self._args = args
@@ -643,7 +643,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
             init_state = self._warmup_state._replace(rng_key=rng_key)
 
         if init_params is not None and self.num_chains > 1:
-            prototype_init_val = tree_flatten(init_params)[0][0]
+            prototype_init_val = jax.tree.flatten(init_params)[0][0]
             if jnp.shape(prototype_init_val)[0] != self.num_chains:
                 raise ValueError(
                     "`init_params` must have the same leading dimension"
@@ -673,7 +673,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
         map_args = (rng_key, init_state, init_params)
         if self.num_chains == 1:
             states_flat, last_state = partial_map_fn(map_args)
-            states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
+            states = jax.tree.map(lambda x: x[jnp.newaxis, ...], states_flat)
         else:
             if self.chain_method == "sequential":
                 states, last_state = _laxmap(partial_map_fn, map_args)
@@ -683,7 +683,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
                 assert self.chain_method == "vectorized"
                 states, last_state = partial_map_fn(map_args)
                 # swap num_samples x num_chains to num_chains x num_samples
-                states = tree_map(lambda x: jnp.swapaxes(x, 0, 1), states)
+                states = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), states)
 
         self._last_state = last_state
         self._states = states
diff --git a/numpyro/infer/mixed_hmc.py b/numpyro/infer/mixed_hmc.py
index deea69a79..3e3d2ae59 100644
--- a/numpyro/infer/mixed_hmc.py
+++ b/numpyro/infer/mixed_hmc.py
@@ -263,7 +263,7 @@ def body_fn(i, vals):
         # Algo 1, line 11: perform MH correction
         delta_energy = energy_new - energy_old - delta_pe_sum
         delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy)
-        accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0)
+        accept_prob = jnp.clip(jnp.exp(-delta_energy), None, 1.0)
 
         # record the correct new num_steps
         hmc_state = hmc_state._replace(num_steps=hmc_state_new.num_steps)
diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py
index b21b531de..c70f0d91a 100644
--- a/numpyro/infer/svi.py
+++ b/numpyro/infer/svi.py
@@ -11,7 +11,6 @@
 from jax import jit, lax, random
 from jax.example_libraries import optimizers
 import jax.numpy as jnp
-from jax.tree_util import tree_map
 
 from numpyro.distributions import constraints
 from numpyro.distributions.transforms import biject_to
@@ -240,7 +239,7 @@ def init(self, rng_key, *args, init_params=None, **kwargs):
         self.constrain_fn = partial(transform_fn, inv_transforms)
         # we convert weak types like float to float32/float64
         # to avoid recompiling body_fn in svi.run
-        params, mutable_state = tree_map(
+        params, mutable_state = jax.tree.map(
             lambda x: lax.convert_element_type(x, jnp.result_type(x)),
             (params, mutable_state),
         )
diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py
index 7d57cbbd0..3775893a7 100644
--- a/numpyro/infer/util.py
+++ b/numpyro/infer/util.py
@@ -14,7 +14,6 @@
 from jax.flatten_util import ravel_pytree
 from jax.lax import broadcast_shapes
 import jax.numpy as jnp
-from jax.tree_util import tree_flatten, tree_map
 
 import numpyro
 from numpyro.distributions import constraints
@@ -770,7 +769,7 @@ def _predictive(
         # inspect the model to get some structure
         rng_key, subkey = random.split(rng_key)
         batch_ndim = len(batch_shape)
-        prototype_sample = tree_map(
+        prototype_sample = jax.tree.map(
             lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[batch_ndim:])[0],
             posterior_samples,
         )
@@ -1027,7 +1026,7 @@ def __call__(self, rng_key, *args, **kwargs):
         if self.batch_ndims == 0 or self.params == {} or self.guide is None:
             return self._call_with_params(rng_key, self.params, args, kwargs)
         elif self.batch_ndims == 1:  # batch over parameters
-            batch_size = jnp.shape(tree_flatten(self.params)[0][0])[0]
+            batch_size = jnp.shape(jax.tree.flatten(self.params)[0][0])[0]
             rng_keys = random.split(rng_key, batch_size)
             return jax.vmap(
                 partial(self._call_with_params, args=args, kwargs=kwargs),
diff --git a/numpyro/ops/provenance.py b/numpyro/ops/provenance.py
index dd88afdf8..68797396e 100644
--- a/numpyro/ops/provenance.py
+++ b/numpyro/ops/provenance.py
@@ -39,7 +39,7 @@ def eval_provenance(fn, **kwargs):
     :returns: A pytree of :class:`frozenset` indicating the dependency on the inputs.
     """
     # Flatten the function and its arguments
-    args, in_tree = jax.tree_util.tree_flatten(((), kwargs))
+    args, in_tree = jax.tree.flatten(((), kwargs))
     wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fn), in_tree)
     # Abstract eval to get output pytree
     avals = util.safe_map(shaped_abstractify, args)
@@ -54,19 +54,17 @@ def eval_provenance(fn, **kwargs):
     aval_kwargs = {}
     for n, v in kwargs.items():
         aval = jax.ShapeDtypeStruct((), jnp.bool_, {"provenance": frozenset({n})})
-        aval_kwargs[n] = jax.tree_util.tree_map(lambda _: aval, v)
-    aval_args, _ = jax.tree_util.tree_flatten(((), aval_kwargs))
-    provenance_inputs = jax.tree_util.tree_map(
-        lambda x: x.named_shape["provenance"], aval_args
-    )
+        aval_kwargs[n] = jax.tree.map(lambda _: aval, v)
+    aval_args, _ = jax.tree.flatten(((), aval_kwargs))
+    provenance_inputs = jax.tree.map(lambda x: x.named_shape["provenance"], aval_args)
 
     provenance_outputs = track_deps_jaxpr(jaxpr, provenance_inputs)
     out_flat = []
     for v, p in zip(avals_out, provenance_outputs):
         val = jax.ShapeDtypeStruct(jnp.shape(v), jnp.result_type(v), {"provenance": p})
         out_flat.append(val)
-    out = jax.tree_util.tree_unflatten(out_tree(), out_flat)
-    return jax.tree_util.tree_map(lambda x: x.named_shape["provenance"], out)
+    out = jax.tree.unflatten(out_tree(), out_flat)
+    return jax.tree.map(lambda x: x.named_shape["provenance"], out)
 
 
 def track_deps_jaxpr(jaxpr, provenance_inputs):
diff --git a/numpyro/optim.py b/numpyro/optim.py
index 225a6f6bb..0abc90ee3 100644
--- a/numpyro/optim.py
+++ b/numpyro/optim.py
@@ -11,12 +11,13 @@
 from collections.abc import Callable
 from typing import Any, TypeVar
 
+import jax
 from jax import jacfwd, lax, value_and_grad
 from jax.example_libraries import optimizers
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 from jax.scipy.optimize import minimize
-from jax.tree_util import register_pytree_node, tree_map
+from jax.tree_util import register_pytree_node
 
 __all__ = [
     "Adam",
@@ -176,9 +177,7 @@ def __init__(self, *args, clip_norm=10.0, **kwargs):
     def update(self, g, state):
         i, opt_state = state
         # clip norm
-        g = tree_map(
-            lambda g_: jnp.clip(g_, a_min=-self.clip_norm, a_max=self.clip_norm), g
-        )
+        g = jax.tree.map(lambda g_: jnp.clip(g_, -self.clip_norm, self.clip_norm), g)
         opt_state = self.update_fn(i, g, opt_state)
         return i + 1, opt_state
 
diff --git a/numpyro/util.py b/numpyro/util.py
index 23a210d03..a30c09a53 100644
--- a/numpyro/util.py
+++ b/numpyro/util.py
@@ -20,7 +20,6 @@
 from jax.core import Tracer
 from jax.experimental import host_callback
 import jax.numpy as jnp
-from jax.tree_util import tree_flatten, tree_map
 
 _DISABLE_CONTROL_FLOW_PRIM = False
 _CHAIN_RE = re.compile(r"\d+$")  # e.g. get '3' from 'TFRT_CPU_3'
@@ -423,7 +422,7 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
         Defaults to the size of batch dimensions.
     :returns: output of `fn(xs)`.
     """
-    flatten_xs = tree_flatten(xs)[0]
+    flatten_xs = jax.tree.flatten(xs)[0]
     batch_shape = np.shape(flatten_xs[0])[:batch_ndims]
     for x in flatten_xs[1:]:
         assert np.shape(x)[:batch_ndims] == batch_shape
@@ -431,7 +430,7 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
     # we'll do map(vmap(fn), xs) and make xs.shape = (num_chunks, chunk_size, ...)
     num_chunks = batch_size = int(np.prod(batch_shape))
     prepend_shape = (batch_size,) if batch_size > 1 else ()
-    xs = tree_map(
+    xs = jax.tree.map(
         lambda x: jnp.reshape(x, prepend_shape + jnp.shape(x)[batch_ndims:]), xs
     )
     # XXX: probably for the default behavior with chunk_size=None,
@@ -439,12 +438,12 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
     chunk_size = batch_size if chunk_size is None else min(batch_size, chunk_size)
     if chunk_size > 1:
         pad = chunk_size - (batch_size % chunk_size)
-        xs = tree_map(
+        xs = jax.tree.map(
             lambda x: jnp.pad(x, ((0, pad),) + ((0, 0),) * (np.ndim(x) - 1)), xs
         )
         num_chunks = batch_size // chunk_size + int(pad > 0)
         prepend_shape = (-1,) if num_chunks > 1 else ()
-        xs = tree_map(
+        xs = jax.tree.map(
             lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]),
             xs,
         )
@@ -452,13 +451,13 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
 
     ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs)
     map_ndims = int(num_chunks > 1) + int(chunk_size > 1)
-    ys = tree_map(
+    ys = jax.tree.map(
         lambda y: jnp.reshape(
             y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:]
         )[:batch_size],
         ys,
     )
-    return tree_map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]), ys)
+    return jax.tree.map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]), ys)
 
 
 def format_shapes(
diff --git a/setup.py b/setup.py
index 0d9e4fb02..b47b0bc92 100644
--- a/setup.py
+++ b/setup.py
@@ -9,8 +9,8 @@
 from setuptools import find_packages, setup
 
 PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
-_jax_version_constraints = ">=0.4.14"
-_jaxlib_version_constraints = ">=0.4.14"
+_jax_version_constraints = ">=0.4.25"
+_jaxlib_version_constraints = ">=0.4.25"
 
 # Find version
 for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")):
diff --git a/test/contrib/einstein/test_steinvi_util.py b/test/contrib/einstein/test_steinvi_util.py
index 617effb71..38fbd0603 100644
--- a/test/contrib/einstein/test_steinvi_util.py
+++ b/test/contrib/einstein/test_steinvi_util.py
@@ -8,8 +8,8 @@
 import pytest
 import scipy
 
+import jax
 from jax import numpy as jnp
-from jax.tree_util import tree_flatten, tree_map
 
 from numpyro.contrib.einstein.stein_util import batch_ravel_pytree, posdef, sqrth
 
@@ -82,10 +82,10 @@ def test_sqrth_shape(batch_shape):
 def test_ravel_pytree_batched(pytree, nbatch_dims):
     flat, _, unravel_fn = batch_ravel_pytree(pytree, nbatch_dims)
     unravel = unravel_fn(flat)
-    tree_flatten(tree_map(lambda x, y: assert_allclose(x, y), unravel, pytree))
+    jax.tree.flatten(jax.tree.map(lambda x, y: assert_allclose(x, y), unravel, pytree))
     assert all(
-        tree_flatten(
-            tree_map(
+        jax.tree.flatten(
+            jax.tree.map(
                 lambda x, y: jnp.result_type(x) == jnp.result_type(y), unravel, pytree
             )
         )[0]
diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py
index d464fb33a..52c270cea 100644
--- a/test/contrib/test_enum_elbo.py
+++ b/test/contrib/test_enum_elbo.py
@@ -32,9 +32,7 @@
 
 
 def assert_equal(a, b, prec=0):
-    return jax.tree_util.tree_map(
-        lambda a, b: np.testing.assert_allclose(a, b, atol=prec), a, b
-    )
+    return jax.tree.map(lambda a, b: np.testing.assert_allclose(a, b, atol=prec), a, b)
 
 
 def xfail_param(*args, **kwargs):
@@ -122,16 +120,16 @@ def guide(params):
         pyro.sample("x", dist.Categorical(probs_x), infer={"enumerate": "parallel"})
 
     def auto_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params)
 
     def hand_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params)
 
-    params_raw = jax.tree_util.tree_map(transform.inv, params)
+    params_raw = jax.tree.map(transform.inv, params)
     auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw)
     hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw)
 
@@ -174,16 +172,16 @@ def guide(params):
         pyro.sample("x", dist.Categorical(probs_x))
 
     def auto_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params)
 
     def hand_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params)
 
-    params_raw = jax.tree_util.tree_map(transform.inv, params)
+    params_raw = jax.tree.map(transform.inv, params)
     auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw)
     hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw)
 
@@ -225,16 +223,16 @@ def guide(params):
         pyro.sample("x", dist.Categorical(probs_x))
 
     def auto_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params)
 
     def hand_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params)
 
-    params_raw = jax.tree_util.tree_map(transform.inv, params)
+    params_raw = jax.tree.map(transform.inv, params)
     auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw)
     hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw)
 
@@ -296,16 +294,16 @@ def guide(data, params):
     )
 
     def auto_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, data, params)
 
     def hand_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, data, params)
 
-    params_raw = jax.tree_util.tree_map(transform.inv, params)
+    params_raw = jax.tree.map(transform.inv, params)
     auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw)
     hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw)
 
@@ -375,16 +373,16 @@ def guide(data, params):
     )
 
     def auto_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, data, params)
 
     def hand_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, data, params)
 
-    params_raw = jax.tree_util.tree_map(transform.inv, params)
+    params_raw = jax.tree.map(transform.inv, params)
     auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw)
     hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw)
 
@@ -467,16 +465,16 @@ def hand_guide(data, params):
     )
 
     def auto_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, auto_model, auto_guide, data, params)
 
     def hand_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         elbo = infer.TraceEnum_ELBO()
         return elbo.loss(random.PRNGKey(0), {}, hand_model, hand_guide, data, params)
 
-    params_raw = jax.tree_util.tree_map(transform.inv, params)
+    params_raw = jax.tree.map(transform.inv, params)
     auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw)
     hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw)
 
@@ -2491,13 +2489,13 @@ def guide(params):
         "probs_a": jnp.array([3.0, 2.5]),
     }
     transform = dist.biject_to(dist.constraints.positive)
-    params_raw = jax.tree_util.tree_map(transform.inv, params)
+    params_raw = jax.tree.map(transform.inv, params)
 
     # TraceGraph_ELBO grads averaged over num_particles
     elbo = infer.TraceGraph_ELBO(num_particles=50_000)
 
     def graph_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         return elbo.loss(random.PRNGKey(0), {}, model, guide, params)
 
     graph_loss, graph_grads = jax.value_and_grad(graph_loss_fn)(params_raw)
@@ -2506,7 +2504,7 @@ def graph_loss_fn(params_raw):
     elbo = infer.TraceEnum_ELBO(num_particles=50_000)
 
     def enum_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         return elbo.loss(random.PRNGKey(0), {}, model, guide, params)
 
     enum_loss, enum_grads = jax.value_and_grad(enum_loss_fn)(params_raw)
diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py
index 5f43dcb3e..d3f27f17e 100644
--- a/test/contrib/test_module.py
+++ b/test/contrib/test_module.py
@@ -7,8 +7,8 @@
 from numpy.testing import assert_allclose
 import pytest
 
+import jax
 from jax import random
-from jax.tree_util import tree_all, tree_map
 
 import numpyro
 from numpyro import handlers
@@ -141,8 +141,8 @@ def test_update_params():
         "a": {"b": {"c": {"d": ParamShape(())}, "e": 2}, "f": ParamShape((4,))}
     }
 
-    tree_all(
-        tree_map(
+    jax.tree.all(
+        jax.tree.map(
             assert_allclose,
             new_params,
             {
diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py
index 9394e17a1..0e2e53aa4 100644
--- a/test/infer/test_autoguide.py
+++ b/test/infer/test_autoguide.py
@@ -7,10 +7,10 @@
 from numpy.testing import assert_allclose
 import pytest
 
+import jax
 from jax import jacobian, jit, lax, random, vmap
 from jax.example_libraries.stax import Dense
 import jax.numpy as jnp
-from jax.tree_util import tree_all, tree_map
 import optax
 from optax import piecewise_constant_schedule
 
@@ -270,7 +270,7 @@ def model(data, labels):
         transforms.biject_to(constraints.interval(-1, 1))(expected_sample["offset"]),
     )
 
-    tree_all(tree_map(assert_allclose, actual_output, expected_output))
+    jax.tree.all(jax.tree.map(assert_allclose, actual_output, expected_output))
 
 
 def test_uniform_normal():
@@ -391,8 +391,8 @@ def expected_model(data):
     expected_loss = svi.evaluate(svi_state, data)
 
     # test auto_loc, auto_scale
-    tree_all(tree_map(assert_allclose, actual_opt_params, expected_opt_params))
-    tree_all(tree_map(assert_allclose, actual_params, expected_params))
+    jax.tree.all(jax.tree.map(assert_allclose, actual_opt_params, expected_opt_params))
+    jax.tree.all(jax.tree.map(assert_allclose, actual_params, expected_params))
     # test latent values
     assert_allclose(actual_values["alpha"], expected_values["alpha"])
     assert_allclose(actual_values["loc_base"], expected_values["loc"])
diff --git a/test/infer/test_ensemble_util.py b/test/infer/test_ensemble_util.py
index ad28a76ad..94fb9e560 100644
--- a/test/infer/test_ensemble_util.py
+++ b/test/infer/test_ensemble_util.py
@@ -26,6 +26,6 @@ def test_batch_ravel_pytree():
     assert flattened.shape == (5, 2 + 3 + 4)
 
     for unflattened_leaf, original_leaf in zip(
-        jax.tree_util.tree_leaves(unflattened), jax.tree_util.tree_leaves(tree)
+        jax.tree.leaves(unflattened), jax.tree.leaves(tree)
     ):
         assert jnp.all(unflattened_leaf == original_leaf)
diff --git a/test/infer/test_gradient.py b/test/infer/test_gradient.py
index 2cb74dbe2..dec977909 100644
--- a/test/infer/test_gradient.py
+++ b/test/infer/test_gradient.py
@@ -22,9 +22,7 @@
 
 
 def assert_equal(a, b, prec=0):
-    return jax.tree_util.tree_map(
-        lambda a, b: np.testing.assert_allclose(a, b, atol=prec), a, b
-    )
+    return jax.tree.map(lambda a, b: np.testing.assert_allclose(a, b, atol=prec), a, b)
 
 
 def model_0(data, params):
@@ -107,13 +105,13 @@ def guide_2(data, params):
 )
 def test_gradient(model, guide, params, data):
     transform = dist.biject_to(dist.constraints.simplex)
-    params_raw = jax.tree_util.tree_map(transform.inv, params)
+    params_raw = jax.tree.map(transform.inv, params)
 
     # Expected grads based on exact integration
     elbo = infer.TraceEnum_ELBO()
 
     def expected_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         return elbo.loss(
             random.PRNGKey(0), {}, model, config_enumerate(guide), data, params
         )
@@ -124,7 +122,7 @@ def expected_loss_fn(params_raw):
     elbo = infer.TraceGraph_ELBO(num_particles=10_000)
 
     def actual_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         return elbo.loss(random.PRNGKey(0), {}, model, guide, data, params)
 
     actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
@@ -336,20 +334,20 @@ def guide(params):
         "probs_z3": jnp.array([[[0.4, 0.6], [0.5, 0.5]], [[0.7, 0.3], [0.9, 0.1]]]),
     }
     transform = dist.biject_to(dist.constraints.simplex)
-    params_raw = jax.tree_util.tree_map(transform.inv, params)
+    params_raw = jax.tree.map(transform.inv, params)
 
     elbo = infer.TraceEnum_ELBO()
 
     # Exact integration based on enumeration
     def expected_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         return elbo.loss(random.PRNGKey(0), {}, model, guide, params)
 
     expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw)
 
     # Exact integration based on the mix of enumeration and analytic kl
     def actual_loss_fn(params_raw):
-        params = jax.tree_util.tree_map(transform, params_raw)
+        params = jax.tree.map(transform, params_raw)
         return elbo.loss(
             random.PRNGKey(0), {}, model, config_kl(guide, kl_sites), params
         )
diff --git a/test/infer/test_hmc_util.py b/test/infer/test_hmc_util.py
index 3b298c08d..f1a81e436 100644
--- a/test/infer/test_hmc_util.py
+++ b/test/infer/test_hmc_util.py
@@ -9,9 +9,9 @@
 from numpy.testing import assert_allclose
 import pytest
 
+import jax
 from jax import device_put, disable_jit, grad, jit, random
 import jax.numpy as jnp
-from jax.tree_util import tree_map
 
 import numpyro.distributions as dist
 from numpyro.infer.hmc_util import (
@@ -222,7 +222,7 @@ def get_final_state(model, step_size, num_steps, q_i, p_i):
     assert_allclose(energy_initial, energy_final, atol=1e-5)
 
     logger.info("Test time reversibility:")
-    p_reverse = tree_map(lambda x: -x, p_f)
+    p_reverse = jax.tree.map(lambda x: -x, p_f)
     q_i, p_i = get_final_state(model, args.step_size, args.num_steps, q_f, p_reverse)
     for node in args.q_i:
         assert_allclose(q_i[node], args.q_i[node], atol=1e-4)
diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py
index 480aba80d..6e5d31f4b 100644
--- a/test/infer/test_mcmc.py
+++ b/test/infer/test_mcmc.py
@@ -12,7 +12,6 @@
 from jax import device_get, jit, lax, pmap, random, vmap
 import jax.numpy as jnp
 from jax.scipy.special import logit
-from jax.tree_util import tree_all, tree_map
 
 import numpyro
 import numpyro.distributions as dist
@@ -450,8 +449,8 @@ def model(data):
     mcmc1.run(random.PRNGKey(2), data)
 
     with pytest.raises(AssertionError):
-        tree_all(
-            tree_map(
+        jax.tree.all(
+            jax.tree.map(
                 partial(assert_allclose, atol=1e-4, rtol=1e-4),
                 mcmc1.get_samples(),
                 mcmc.get_samples(),
@@ -459,21 +458,21 @@ def model(data):
         )
     mcmc1.warmup(random.PRNGKey(2), data)
     mcmc1.run(random.PRNGKey(3), data)
-    tree_all(
-        tree_map(
+    jax.tree.all(
+        jax.tree.map(
             partial(assert_allclose, atol=1e-4, rtol=1e-4),
             mcmc1.get_samples(),
             mcmc.get_samples(),
         )
     )
-    tree_all(
-        tree_map(
+    jax.tree.all(
+        jax.tree.map(
             partial(assert_allclose, atol=1e-4, rtol=1e-4),
-            tree_map(
+            jax.tree.map(
                 lambda x: random.key_data(x) if is_prng_key(x) else x,
                 mcmc1.post_warmup_state,
             ),
-            tree_map(
+            jax.tree.map(
                 lambda x: random.key_data(x) if is_prng_key(x) else x,
                 mcmc.post_warmup_state,
             ),
diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py
index 166499063..f52c1cef2 100644
--- a/test/infer/test_svi.py
+++ b/test/infer/test_svi.py
@@ -11,7 +11,6 @@
 from jax import jit, lax, random, value_and_grad
 from jax.example_libraries import optimizers
 import jax.numpy as jnp
-from jax.tree_util import tree_all, tree_map
 
 import numpyro
 from numpyro import optim
@@ -31,7 +30,7 @@
 
 
 def assert_equal(a, b, prec=0):
-    return jax.tree_util.tree_map(lambda a, b: assert_allclose(a, b, atol=prec), a, b)
+    return jax.tree.map(lambda a, b: assert_allclose(a, b, atol=prec), a, b)
 
 
 @pytest.mark.parametrize("alpha", [0.0, 2.0])
@@ -272,7 +271,7 @@ def guide(data):
     expected = svi.get_params(svi.update(svi_state, data)[0])
     actual = svi.get_params(jit(svi.update)(svi_state, data=data)[0])
 
-    tree_all(tree_map(partial(assert_allclose, atol=1e-5), actual, expected))
+    jax.tree.all(jax.tree.map(partial(assert_allclose, atol=1e-5), actual, expected))
 
 
 def test_param():
diff --git a/test/ops/test_provenance.py b/test/ops/test_provenance.py
index a64fcaadc..f40b846c1 100644
--- a/test/ops/test_provenance.py
+++ b/test/ops/test_provenance.py
@@ -72,19 +72,19 @@ def f(x, y):
 
 def test_provenance_call():
     def identity(x):
-        args, in_tree = jax.tree_util.tree_flatten((x,))
+        args, in_tree = jax.tree.flatten((x,))
         fn, out_tree = flatten_fun_nokwargs(lu.wrap_init(lambda x: x), in_tree)
         out = core.closed_call_p.bind(fn, *args)
-        return jax.tree_util.tree_unflatten(out_tree(), out)
+        return jax.tree.unflatten(out_tree(), out)
 
     assert eval_provenance(identity, x={"v": 2}) == {"v": frozenset({"x"})}
 
 
 def test_provenance_closed_call():
     def identity(x):
-        args, in_tree = jax.tree_util.tree_flatten((x,))
+        args, in_tree = jax.tree.flatten((x,))
         fn, out_tree = flatten_fun_nokwargs(lu.wrap_init(lambda x: x), in_tree)
         out = core.closed_call_p.bind(fn, *args)
-        return jax.tree_util.tree_unflatten(out_tree(), out)
+        return jax.tree.unflatten(out_tree(), out)
 
     assert eval_provenance(identity, x={"v": 2}) == {"v": frozenset({"x"})}
diff --git a/test/test_constraints.py b/test/test_constraints.py
index bfb459bdd..acd96732e 100644
--- a/test/test_constraints.py
+++ b/test/test_constraints.py
@@ -5,9 +5,9 @@
 
 import pytest
 
+import jax
 from jax import jit, vmap
 import jax.numpy as jnp
-from jax.tree_util import tree_map
 
 from numpyro.distributions import constraints
 
@@ -134,14 +134,14 @@ def out_cst(constraint, x):
 
     if len(cst_args) > 0:
         # test creating and manipulating vmapped constraints
-        vmapped_cst_args = tree_map(lambda x: x[None], cst_args)
+        vmapped_cst_args = jax.tree.map(lambda x: x[None], cst_args)
 
         vmapped_csts = jit(vmap(lambda args: cls(*args, **cst_kwargs), in_axes=(0,)))(
             vmapped_cst_args
         )
         assert vmap(lambda x: x == constraint, in_axes=0)(vmapped_csts).all()
 
-        twice_vmapped_cst_args = tree_map(lambda x: x[None], vmapped_cst_args)
+        twice_vmapped_cst_args = jax.tree.map(lambda x: x[None], vmapped_cst_args)
 
         vmapped_csts = jit(
             vmap(
diff --git a/test/test_distributions.py b/test/test_distributions.py
index a4dc3528b..7453e3f51 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -3189,7 +3189,7 @@ def _allclose_or_equal(a1, a2):
 
 
 def _tree_equal(t1, t2):
-    t = jax.tree_util.tree_map(_allclose_or_equal, t1, t2)
+    t = jax.tree.map(_allclose_or_equal, t1, t2)
     return jnp.all(jax.flatten_util.ravel_pytree(t)[0])
 
 
@@ -3216,7 +3216,7 @@ def sample(d: dist.Distribution):
         # In this case, since csr arrays are not jittable,
         # _SparseCAR has a csr_matrix as part of its pytree
         # definition (not as a pytree leaf). This causes pytree
-        # operations like tree_map to fail, since these functions
+        # operations like jax.tree.map to fail, since these functions
         # compare the pytree def of each of the arguments using ==
         # which is ambiguous for array-like objects.
         return
@@ -3261,7 +3261,7 @@ def sample(d: dist.Distribution):
     for in_axes, out_axes in in_out_axes_cases:
         batched_params = [
             (
-                jax.tree_map(lambda x: jnp.expand_dims(x, ax), arg)
+                jax.tree.map(lambda x: jnp.expand_dims(x, ax), arg)
                 if isinstance(ax, int)
                 else arg
             )
diff --git a/test/test_handlers.py b/test/test_handlers.py
index 518f856dc..4ef449237 100644
--- a/test/test_handlers.py
+++ b/test/test_handlers.py
@@ -8,7 +8,6 @@
 import jax
 from jax import jit, random, value_and_grad, vmap
 import jax.numpy as jnp
-from jax.tree_util import tree_map
 
 try:
     import funsor
@@ -441,7 +440,7 @@ def guide(subsample):
                 svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample
             )
         )(params)
-        grads = tree_map(lambda *vals: vals[0] + vals[1], grads1, grads2)
+        grads = jax.tree.map(lambda *vals: vals[0] + vals[1], grads1, grads2)
         loss = loss1 + loss2
     else:
         subsample = jnp.array([0, 1])
diff --git a/test/test_pickle.py b/test/test_pickle.py
index a54479be2..7ed338578 100644
--- a/test/test_pickle.py
+++ b/test/test_pickle.py
@@ -7,9 +7,9 @@
 from numpy.testing import assert_allclose
 import pytest
 
+import jax
 from jax import random
 import jax.numpy as jnp
-from jax.tree_util import tree_all, tree_map
 
 import numpyro
 from numpyro.contrib.funsor import config_kl
@@ -90,7 +90,9 @@ def test_pickle_hmc(kernel):
     mcmc = MCMC(kernel(normal_model), num_warmup=10, num_samples=10)
     mcmc.run(random.PRNGKey(0))
     pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
-    tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()))
+    jax.tree.all(
+        jax.tree.map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples())
+    )
 
 
 @pytest.mark.parametrize("kernel", [BarkerMH, HMC, NUTS, SA])
@@ -108,7 +110,9 @@ def test_pickle_hmc_enumeration(kernel):
     mcmc = MCMC(kernel(gmm), num_warmup=10, num_samples=10)
     mcmc.run(random.PRNGKey(0), data, K)
     pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
-    tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()))
+    jax.tree.all(
+        jax.tree.map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples())
+    )
 
 
 @pytest.mark.parametrize("kernel", [DiscreteHMCGibbs, MixedHMC])
@@ -116,14 +120,18 @@ def test_pickle_discrete_hmc(kernel):
     mcmc = MCMC(kernel(HMC(bernoulli_model)), num_warmup=10, num_samples=10)
     mcmc.run(random.PRNGKey(0))
     pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
-    tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()))
+    jax.tree.all(
+        jax.tree.map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples())
+    )
 
 
 def test_pickle_hmcecs():
     mcmc = MCMC(HMCECS(NUTS(logistic_regression)), num_warmup=10, num_samples=10)
     mcmc.run(random.PRNGKey(0))
     pickled_mcmc = pickle.loads(pickle.dumps(mcmc))
-    tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()))
+    jax.tree.all(
+        jax.tree.map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples())
+    )
 
 
 def poisson_regression(x, N):
@@ -236,4 +244,4 @@ def guide(data):
     svi_result = svi.run(random.PRNGKey(0), 3, data)
     pickled_params = svi_result.params
 
-    tree_all(tree_map(assert_allclose, params, pickled_params))
+    jax.tree.all(jax.tree.map(assert_allclose, params, pickled_params))
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 8f68eaf6e..997959244 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -8,9 +8,9 @@
 import numpy as np
 import pytest
 
+import jax
 from jax import jacfwd, jit, random, vmap
 import jax.numpy as jnp
-from jax.tree_util import tree_map
 
 from numpyro.distributions import constraints
 from numpyro.distributions.flows import (
@@ -175,14 +175,14 @@ def out_t(transform, x):
         # this test assumes jittable args, and non-jittable kwargs, which is
         # not suited for all transforms, see InverseAutoregressiveTransform.
         # TODO: split among jittable and non-jittable args/kwargs instead.
-        vmapped_transform_args = tree_map(lambda x: x[None], transform_args)
+        vmapped_transform_args = jax.tree.map(lambda x: x[None], transform_args)
 
         vmapped_transform = jit(
             vmap(lambda args: cls(*args, **transform_kwargs), in_axes=(0,))
         )(vmapped_transform_args)
         assert vmap(lambda x: x == transform, in_axes=0)(vmapped_transform).all()
 
-        twice_vmapped_transform_args = tree_map(
+        twice_vmapped_transform_args = jax.tree.map(
             lambda x: x[None], vmapped_transform_args
         )
 
diff --git a/test/test_util.py b/test/test_util.py
index 208749e26..f351b45f3 100644
--- a/test/test_util.py
+++ b/test/test_util.py
@@ -5,10 +5,10 @@
 from numpy.testing import assert_allclose
 import pytest
 
+import jax
 from jax import random
 from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
-from jax.tree_util import tree_all, tree_flatten, tree_map
 
 import numpyro
 import numpyro.distributions as dist
@@ -44,7 +44,7 @@ def f(x):
     expected_tree = {"i": np.array([[0.0], [2.0]])}
     actual_tree = fori_collect(1, 3, f, a, transform=lambda a: {"i": a["i"]})
 
-    tree_all(tree_map(assert_allclose, actual_tree, expected_tree))
+    jax.tree.all(jax.tree.map(assert_allclose, actual_tree, expected_tree))
 
 
 @pytest.mark.parametrize("progbar", [False, True])
@@ -64,8 +64,8 @@ def f(x):
     )
     expected_tree = {"i": np.array([3, 4])}
     expected_last_state = {"i": np.array(4)}
-    tree_all(tree_map(assert_allclose, init_state, expected_last_state))
-    tree_all(tree_map(assert_allclose, tree, expected_tree))
+    jax.tree.all(jax.tree.map(assert_allclose, init_state, expected_last_state))
+    jax.tree.all(jax.tree.map(assert_allclose, tree, expected_tree))
 
 
 @pytest.mark.parametrize(
@@ -82,10 +82,10 @@ def f(x):
 def test_ravel_pytree(pytree):
     flat, unravel_fn = ravel_pytree(pytree)
     unravel = unravel_fn(flat)
-    tree_flatten(tree_map(lambda x, y: assert_allclose(x, y), unravel, pytree))
+    jax.tree.flatten(jax.tree.map(lambda x, y: assert_allclose(x, y), unravel, pytree))
     assert all(
-        tree_flatten(
-            tree_map(
+        jax.tree.flatten(
+            jax.tree.map(
                 lambda x, y: jnp.result_type(x) == jnp.result_type(y), unravel, pytree
             )
         )[0]