From 5af9ebda72bd7aeb08c61e4248ecd0d982473224 Mon Sep 17 00:00:00 2001 From: Du Phan <fehiepsi@gmail.com> Date: Wed, 26 Jun 2024 05:10:07 -0400 Subject: [PATCH] Update jax.tree_util.tree_map to jax.tree.map (#1821) * update jax.tree_util.tree_foo to jax.tree.foo * bump minimal jax version to 0.4.25, which supports jax.tree * fix lint issues * also fix deprecation warning of using a_min, a_max in jnp.clip --- examples/annotation.py | 2 +- examples/gp.py | 2 +- .../source/time_series_forecasting.ipynb | 4 +- numpyro/contrib/control_flow/scan.py | 40 ++++++++-------- .../einstein/mixture_guide_predictive.py | 6 +-- numpyro/contrib/einstein/stein_util.py | 10 ++-- numpyro/contrib/einstein/steinvi.py | 10 ++-- numpyro/contrib/module.py | 11 +++-- numpyro/contrib/tfp/distributions.py | 4 +- numpyro/contrib/tfp/mcmc.py | 4 +- numpyro/diagnostics.py | 12 ++--- numpyro/distributions/batch_util.py | 4 +- numpyro/distributions/continuous.py | 12 ++--- numpyro/distributions/directional.py | 2 +- numpyro/distributions/discrete.py | 4 +- numpyro/distributions/distribution.py | 3 +- numpyro/distributions/flows.py | 2 +- numpyro/distributions/transforms.py | 22 ++++----- numpyro/distributions/truncated.py | 10 ++-- numpyro/distributions/util.py | 6 +-- numpyro/infer/autoguide.py | 17 ++++--- numpyro/infer/barker.py | 2 +- numpyro/infer/ensemble.py | 2 +- numpyro/infer/ensemble_util.py | 3 +- numpyro/infer/hmc.py | 2 +- numpyro/infer/hmc_gibbs.py | 2 +- numpyro/infer/hmc_util.py | 26 +++++------ numpyro/infer/mcmc.py | 32 ++++++------- numpyro/infer/mixed_hmc.py | 2 +- numpyro/infer/svi.py | 3 +- numpyro/infer/util.py | 5 +- numpyro/ops/provenance.py | 14 +++--- numpyro/optim.py | 7 ++- numpyro/util.py | 13 +++--- setup.py | 4 +- test/contrib/einstein/test_steinvi_util.py | 8 ++-- test/contrib/test_enum_elbo.py | 46 +++++++++---------- test/contrib/test_module.py | 6 +-- test/infer/test_autoguide.py | 8 ++-- test/infer/test_ensemble_util.py | 2 +- test/infer/test_gradient.py | 16 +++---- test/infer/test_hmc_util.py | 4 +- test/infer/test_mcmc.py | 17 ++++--- test/infer/test_svi.py | 5 +- test/ops/test_provenance.py | 8 ++-- test/test_constraints.py | 6 +-- test/test_distributions.py | 6 +-- test/test_handlers.py | 3 +- test/test_pickle.py | 20 +++++--- test/test_transforms.py | 6 +-- test/test_util.py | 14 +++--- 51 files changed, 236 insertions(+), 243 deletions(-) diff --git a/examples/annotation.py b/examples/annotation.py index 881825316..3341dbe72 100644 --- a/examples/annotation.py +++ b/examples/annotation.py @@ -309,7 +309,7 @@ def main(args): # is stored in `discrete_samples`. 
To merge those discrete samples into the `mcmc` # instance, we can use the following pattern:: # -# chain_discrete_samples = jax.tree_util.tree_map( +# chain_discrete_samples = jax.tree.map( # lambda x: x.reshape((args.num_chains, args.num_samples) + x.shape[1:]), # discrete_samples) # mcmc.get_samples().update(discrete_samples) diff --git a/examples/gp.py b/examples/gp.py index aac9632f8..0f70400db 100644 --- a/examples/gp.py +++ b/examples/gp.py @@ -116,7 +116,7 @@ def predict(rng_key, X, Y, X_test, var, length, noise, use_cholesky=True): K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX))) mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y)) - sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal( + sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), 0.0)) * jax.random.normal( rng_key, X_test.shape[:1] ) diff --git a/notebooks/source/time_series_forecasting.ipynb b/notebooks/source/time_series_forecasting.ipynb index 3ff5667d6..40f467ee0 100644 --- a/notebooks/source/time_series_forecasting.ipynb +++ b/notebooks/source/time_series_forecasting.ipynb @@ -206,7 +206,7 @@ " level, s, moving_sum = carry\n", " season = s[0] * level**pow_season\n", " exp_val = level + coef_trend * level**pow_trend + season\n", - " exp_val = jnp.clip(exp_val, a_min=0)\n", + " exp_val = jnp.clip(exp_val, 0)\n", " # use expected vale when forecasting\n", " y_t = jnp.where(t >= N, exp_val, y[t])\n", "\n", @@ -215,7 +215,7 @@ " )\n", " level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)\n", " level = level_sm * level_p + (1 - level_sm) * level\n", - " level = jnp.clip(level, a_min=0)\n", + " level = jnp.clip(level, 0)\n", "\n", " new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]\n", " # repeat s when forecasting\n", diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 6b657b494..7a2689cf1 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -4,9 +4,9 @@ from collections import OrderedDict from functools import partial +import jax from jax import device_put, lax, random import jax.numpy as jnp -from jax.tree_util import tree_flatten, tree_map, tree_unflatten from numpyro import handlers from numpyro.distributions.batch_util import promote_batch_shape @@ -98,7 +98,7 @@ def postprocess_message(self, msg): fn_batch_ndim = len(fn.batch_shape) if fn_batch_ndim < value_batch_ndims: prepend_shapes = (1,) * (value_batch_ndims - fn_batch_ndim) - msg["fn"] = tree_map( + msg["fn"] = jax.tree.map( lambda x: jnp.reshape(x, prepend_shapes + jnp.shape(x)), fn ) @@ -140,11 +140,11 @@ def scan_enum( history = min(history, length) unroll_steps = min(2 * history - 1, length) if reverse: - x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs) - xs_ = tree_map(lambda x: x[:-unroll_steps], xs) + x0 = jax.tree.map(lambda x: x[-unroll_steps:][::-1], xs) + xs_ = jax.tree.map(lambda x: x[:-unroll_steps], xs) else: - x0 = tree_map(lambda x: x[:unroll_steps], xs) - xs_ = tree_map(lambda x: x[unroll_steps:], xs) + x0 = jax.tree.map(lambda x: x[:unroll_steps], xs) + xs_ = jax.tree.map(lambda x: x[unroll_steps:], xs) carry_shapes = [] @@ -187,10 +187,12 @@ def body_fn(wrapped_carry, x, prefix=None): # store shape of new_carry at a global variable if len(carry_shapes) < (history + 1): - carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]]) + carry_shapes.append( + [jnp.shape(x) for x in jax.tree.flatten(new_carry)[0]] + ) # make new_carry have the same shape as carry # FIXME: is 
this rigorous? - new_carry = tree_map( + new_carry = jax.tree.map( lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry ) return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y) @@ -204,11 +206,11 @@ def body_fn(wrapped_carry, x, prefix=None): for i in markov(range(unroll_steps + 1), history=history): if i < unroll_steps: wrapped_carry, (_, y0) = body_fn( - wrapped_carry, tree_map(lambda z: z[i], x0) + wrapped_carry, jax.tree.map(lambda z: z[i], x0) ) if i > 0: # reshape y1, y2,... to have the same shape as y0 - y0 = tree_map( + y0 = jax.tree.map( lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0 ) y0s.append(y0) @@ -216,15 +218,15 @@ def body_fn(wrapped_carry, x, prefix=None): # shape so we don't need to record them here if (i >= history - 1) and (len(carry_shapes) < history + 1): carry_shapes.append( - jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0] + jnp.shape(x) for x in jax.tree.flatten(wrapped_carry[-1])[0] ) else: # this is the last rolling step - y0s = tree_map(lambda *z: jnp.stack(z, axis=0), *y0s) + y0s = jax.tree.map(lambda *z: jnp.stack(z, axis=0), *y0s) # return early if length = unroll_steps if length == unroll_steps: return wrapped_carry, (PytreeTrace({}), y0s) - wrapped_carry = tree_map(device_put, wrapped_carry) + wrapped_carry = jax.tree.map(device_put, wrapped_carry) wrapped_carry, (pytree_trace, ys) = lax.scan( body_fn, wrapped_carry, xs_, length - unroll_steps, reverse ) @@ -251,20 +253,20 @@ def body_fn(wrapped_carry, x, prefix=None): site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var) # similar to carry, we need to reshape due to shape alternating in markov - ys = tree_map( + ys = jax.tree.map( lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys ) # then join with y0s - ys = tree_map(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys) + ys = jax.tree.map(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys) # we also need to reshape `carry` to match sequential behavior i = (length + 1) % (history + 1) t, rng_key, carry = wrapped_carry carry_shape = carry_shapes[i] - flatten_carry, treedef = tree_flatten(carry) + flatten_carry, treedef = jax.tree.flatten(carry) flatten_carry = [ jnp.reshape(x, t1_shape) for x, t1_shape in zip(flatten_carry, carry_shape) ] - carry = tree_unflatten(treedef, flatten_carry) + carry = jax.tree.unflatten(treedef, flatten_carry) wrapped_carry = (t, rng_key, carry) return wrapped_carry, (pytree_trace, ys) @@ -282,7 +284,7 @@ def scan_wrapper( first_available_dim=None, ): if length is None: - length = jnp.shape(tree_flatten(xs)[0][0])[0] + length = jnp.shape(jax.tree.flatten(xs)[0][0])[0] if enum and history > 0: return scan_enum( # TODO: replay for enum @@ -324,7 +326,7 @@ def body_fn(wrapped_carry, x): return (i + 1, rng_key, carry), (PytreeTrace(trace), y) - wrapped_carry = tree_map(device_put, (0, rng_key, init)) + wrapped_carry = jax.tree.map(device_put, (0, rng_key, init)) last_carry, (pytree_trace, ys) = lax.scan( body_fn, wrapped_carry, xs, length=length, reverse=reverse ) diff --git a/numpyro/contrib/einstein/mixture_guide_predictive.py b/numpyro/contrib/einstein/mixture_guide_predictive.py index c3b2fdc5b..2a2a8ed51 100644 --- a/numpyro/contrib/einstein/mixture_guide_predictive.py +++ b/numpyro/contrib/einstein/mixture_guide_predictive.py @@ -5,8 +5,8 @@ from functools import partial from typing import Optional +import jax from jax import numpy as jnp, random, vmap -from jax.tree_util import tree_flatten, tree_map from numpyro.handlers import substitute from 
numpyro.infer import Predictive @@ -63,7 +63,7 @@ def __init__( self.guide = guide self.return_sites = return_sites - self.num_mixture_components = jnp.shape(tree_flatten(params)[0][0])[0] + self.num_mixture_components = jnp.shape(jax.tree.flatten(params)[0][0])[0] self.mixture_assignment_sitename = mixture_assignment_sitename def _call_with_params(self, rng_key, params, args, kwargs): @@ -99,7 +99,7 @@ def __call__(self, rng_key, *args, **kwargs): minval=0, maxval=self.num_mixture_components, ) - predictive_assign = tree_map( + predictive_assign = jax.tree.map( lambda arr: vmap(lambda i, assign: arr[i, assign])( jnp.arange(self._batch_shape[0]), assigns ), diff --git a/numpyro/contrib/einstein/stein_util.py b/numpyro/contrib/einstein/stein_util.py index e8cb80372..741e7c816 100644 --- a/numpyro/contrib/einstein/stein_util.py +++ b/numpyro/contrib/einstein/stein_util.py @@ -1,9 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import jax from jax import numpy as jnp, vmap from jax.flatten_util import ravel_pytree -from jax.tree_util import tree_map from numpyro.distributions import biject_to from numpyro.distributions.constraints import real @@ -64,14 +64,14 @@ def batch_ravel_pytree(pytree, nbatch_dims=0): flat, unravel_fn = ravel_pytree(pytree) return flat, unravel_fn, unravel_fn - shapes = tree_map(lambda x: x.shape, pytree) - flat_pytree = tree_map(lambda x: x.reshape(*x.shape[:-nbatch_dims], -1), pytree) + shapes = jax.tree.map(lambda x: x.shape, pytree) + flat_pytree = jax.tree.map(lambda x: x.reshape(*x.shape[:-nbatch_dims], -1), pytree) flat = vmap(lambda x: ravel_pytree(x)[0])(flat_pytree) - unravel_fn = ravel_pytree(tree_map(lambda x: x[0], flat_pytree))[1] + unravel_fn = ravel_pytree(jax.tree.map(lambda x: x[0], flat_pytree))[1] return ( flat, unravel_fn, - lambda _flat: tree_map( + lambda _flat: jax.tree.map( lambda x, shape: x.reshape(shape), vmap(unravel_fn)(_flat), shapes ), ) diff --git a/numpyro/contrib/einstein/steinvi.py b/numpyro/contrib/einstein/steinvi.py index 7f2b47f70..98c055db1 100644 --- a/numpyro/contrib/einstein/steinvi.py +++ b/numpyro/contrib/einstein/steinvi.py @@ -9,9 +9,9 @@ from itertools import chain import operator +import jax from jax import grad, jacfwd, numpy as jnp, random, vmap from jax.flatten_util import ravel_pytree -from jax.tree_util import tree_map from numpyro import handlers from numpyro.contrib.einstein.stein_kernels import SteinKernel @@ -340,10 +340,10 @@ def _update_force(attr_force, rep_force, jac): return force.reshape(attr_force.shape) reparam_jac = { - name: tree_map(lambda var: _nontrivial_jac(name, var), variables) + name: jax.tree.map(lambda var: _nontrivial_jac(name, var), variables) for name, variables in unravel_pytree(particle).items() } - jac_params = tree_map( + jac_params = jax.tree.map( _update_force, unravel_pytree(attr_forces), unravel_pytree(rep_forces), @@ -363,7 +363,7 @@ def _update_force(attr_force, rep_force, jac): stein_param_grads = unravel_pytree_batched(particle_grads) # 6. 
Return loss and gradients (based on parameter forces) - res_grads = tree_map( + res_grads = jax.tree.map( lambda x: -x, {**non_mixture_param_grads, **stein_param_grads} ) return jnp.linalg.norm(particle_grads), res_grads @@ -427,7 +427,7 @@ def init(self, rng_key, *args, **kwargs): if site["name"] in guide_init_params: pval = guide_init_params[site["name"]] if self.non_mixture_params_fn(site["name"]): - pval = tree_map(lambda x: x[0], pval) + pval = jax.tree.map(lambda x: x[0], pval) else: pval = site["value"] params[site["name"]] = transform.inv(pval) diff --git a/numpyro/contrib/module.py b/numpyro/contrib/module.py index f370b4330..3f40363e6 100644 --- a/numpyro/contrib/module.py +++ b/numpyro/contrib/module.py @@ -5,9 +5,10 @@ from copy import deepcopy from functools import partial +import jax from jax import random import jax.numpy as jnp -from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten +from jax.tree_util import register_pytree_node import numpyro import numpyro.distributions as dist @@ -106,8 +107,8 @@ def flax_module( assert set(mutable) == set(nn_state) numpyro_mutable(name + "$state", nn_state) # make sure that nn_params keep the same order after unflatten - params_flat, tree_def = tree_flatten(nn_params) - nn_params = tree_unflatten(tree_def, params_flat) + params_flat, tree_def = jax.tree.flatten(nn_params) + nn_params = jax.tree.unflatten(tree_def, params_flat) numpyro.param(module_key, nn_params) def apply_with_state(params, *args, **kwargs): @@ -195,8 +196,8 @@ def haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kw nn_params = hk.data_structures.to_mutable_dict(nn_params) # we cast it to a mutable one to be able to set priors for parameters # make sure that nn_params keep the same order after unflatten - params_flat, tree_def = tree_flatten(nn_params) - nn_params = tree_unflatten(tree_def, params_flat) + params_flat, tree_def = jax.tree.flatten(nn_params) + nn_params = jax.tree.unflatten(tree_def, params_flat) numpyro.param(module_key, nn_params) def apply_with_state(params, *args, **kwargs): diff --git a/numpyro/contrib/tfp/distributions.py b/numpyro/contrib/tfp/distributions.py index 4db875cfe..4c3a76009 100644 --- a/numpyro/contrib/tfp/distributions.py +++ b/numpyro/contrib/tfp/distributions.py @@ -282,11 +282,11 @@ def is_discrete(self): return self.support is None def tree_flatten(self): - return jax.tree_util.tree_flatten(self.tfp_dist) + return jax.tree.flatten(self.tfp_dist) @classmethod def tree_unflatten(cls, aux_data, params): - fn = jax.tree_util.tree_unflatten(aux_data, params) + fn = jax.tree.unflatten(aux_data, params) with warnings.catch_warnings(): warnings.simplefilter("ignore", category=FutureWarning) return TFPDistribution[fn.__class__](**fn.parameters) diff --git a/numpyro/contrib/tfp/mcmc.py b/numpyro/contrib/tfp/mcmc.py index b660af837..7a2312e70 100644 --- a/numpyro/contrib/tfp/mcmc.py +++ b/numpyro/contrib/tfp/mcmc.py @@ -5,10 +5,10 @@ from collections import namedtuple import inspect +import jax from jax import random, vmap from jax.flatten_util import ravel_pytree import jax.numpy as jnp -from jax.tree_util import tree_map import tensorflow_probability.substrates.jax as tfp from numpyro.infer import init_to_uniform @@ -44,7 +44,7 @@ def log_prob_fn(x): flatten_result = vmap(lambda a: -potential_fn(unravel_fn(a)))( jnp.reshape(x, (-1,) + jnp.shape(x)[-1:]) ) - return tree_map( + return jax.tree.map( lambda a: jnp.reshape(a, batch_shape + jnp.shape(a)[1:]), flatten_result ) else: diff --git 
a/numpyro/diagnostics.py b/numpyro/diagnostics.py index 99aa9761c..cf3d04a9e 100644 --- a/numpyro/diagnostics.py +++ b/numpyro/diagnostics.py @@ -10,8 +10,8 @@ import numpy as np +import jax from jax import device_get -from jax.tree_util import tree_flatten, tree_map __all__ = [ "autocorrelation", @@ -182,7 +182,7 @@ def effective_sample_size(x): Rho_k = np.concatenate( [ Rho_init, - np.minimum.accumulate(np.clip(Rho_k[1:, ...], a_min=0, a_max=None), axis=0), + np.minimum.accumulate(np.clip(Rho_k[1:, ...], 0, None), axis=0), ], axis=0, ) @@ -238,10 +238,10 @@ def summary(samples, prob=0.90, group_by_chain=True): chain dimension). """ if not group_by_chain: - samples = tree_map(lambda x: x[None, ...], samples) + samples = jax.tree.map(lambda x: x[None, ...], samples) if not isinstance(samples, dict): samples = { - "Param:{}".format(i): v for i, v in enumerate(tree_flatten(samples)[0]) + "Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0]) } summary_dict = {} @@ -288,10 +288,10 @@ def print_summary(samples, prob=0.90, group_by_chain=True): chain dimension). """ if not group_by_chain: - samples = tree_map(lambda x: x[None, ...], samples) + samples = jax.tree.map(lambda x: x[None, ...], samples) if not isinstance(samples, dict): samples = { - "Param:{}".format(i): v for i, v in enumerate(tree_flatten(samples)[0]) + "Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0]) } summary_dict = summary(samples, prob, group_by_chain=True) diff --git a/numpyro/distributions/batch_util.py b/numpyro/distributions/batch_util.py index 83698b00e..235127335 100644 --- a/numpyro/distributions/batch_util.py +++ b/numpyro/distributions/batch_util.py @@ -5,8 +5,8 @@ from functools import singledispatch from typing import Union +import jax import jax.numpy as jnp -from jax.tree_util import tree_map from numpyro.distributions import constraints from numpyro.distributions.conjugate import ( @@ -547,7 +547,7 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution): len(new_shapes_elems), len(new_shapes_elems) + len(orig_delta_batch_shape), ) - new_base_dist = tree_map( + new_base_dist = jax.tree.map( lambda x: jnp.expand_dims(x, axis=new_axes_locs), new_self.base_dist ) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 237cb633a..652941746 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -281,9 +281,7 @@ def sample(self, key, sample_shape=()): assert is_prng_key(key) shape = sample_shape + self.batch_shape samples = random.dirichlet(key, self.concentration, shape=shape) - return jnp.clip( - samples, a_min=jnp.finfo(samples).tiny, a_max=1 - jnp.finfo(samples).eps - ) + return jnp.clip(samples, jnp.finfo(samples).tiny, 1 - jnp.finfo(samples).eps) @validate_sample def log_prob(self, value): @@ -840,15 +838,15 @@ def sample(self, key, sample_shape=()): u = random.uniform( key, shape=sample_shape + self.batch_shape, minval=finfo.tiny ) - u_con0 = jnp.clip(u ** (1 / self.concentration0), a_max=1 - finfo.eps) + u_con0 = jnp.clip(u ** (1 / self.concentration0), None, 1 - finfo.eps) log_sample = jnp.log1p(-u_con0) / self.concentration1 - return jnp.clip(jnp.exp(log_sample), a_min=finfo.tiny, a_max=1 - finfo.eps) + return jnp.clip(jnp.exp(log_sample), finfo.tiny, 1 - finfo.eps) @validate_sample def log_prob(self, value): finfo = jnp.finfo(jnp.result_type(float)) normalize_term = jnp.log(self.concentration0) + jnp.log(self.concentration1) - value_con1 = jnp.clip(value**self.concentration1, 
a_max=1 - finfo.eps) + value_con1 = jnp.clip(value**self.concentration1, None, 1 - finfo.eps) return ( xlogy(self.concentration1 - 1, value) + xlog1py(self.concentration0 - 1, -value_con1) @@ -2363,7 +2361,7 @@ def log_prob(self, value): def cdf(self, value): cdf = (value - self.low) / (self.high - self.low) - return jnp.clip(cdf, a_min=0.0, a_max=1.0) + return jnp.clip(cdf, 0.0, 1.0) def icdf(self, value): return self.low + value * (self.high - self.low) diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index fd5b1596c..8156855b1 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -401,7 +401,7 @@ def norm_const(self): lbinoms = num - 2 * den fs = lbinoms.reshape(-1, 1) + m * ( - jnp.log(jnp.clip(corr**2, a_min=jnp.finfo(jnp.result_type(float)).tiny)) + jnp.log(jnp.clip(corr**2, jnp.finfo(jnp.result_type(float)).tiny)) - jnp.log(4 * jnp.prod(conc, axis=-1)) ) fs += log_I1(49, conc, terms=51).sum(-1) diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index 0ee140620..7d7358a5d 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -65,7 +65,7 @@ def _to_probs_multinom(logits): def _to_logits_multinom(probs): minval = jnp.finfo(jnp.result_type(probs)).min - return jnp.clip(jnp.log(probs), a_min=minval) + return jnp.clip(jnp.log(probs), minval) class BernoulliProbs(Distribution): @@ -443,7 +443,7 @@ def log_prob(self, value): def cdf(self, value): cdf = (jnp.floor(value) + 1 - self.low) / (self.high - self.low + 1) - return jnp.clip(cdf, a_min=0.0, a_max=1.0) + return jnp.clip(cdf, 0.0, 1.0) def icdf(self, value): return self.low + value * (self.high - self.low + 1) - 1 diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index cba2994b8..04b213a3e 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -33,6 +33,7 @@ import numpy as np +import jax from jax import lax, tree_util import jax.numpy as jnp from jax.scipy.special import logsumexp @@ -636,7 +637,7 @@ def reshape_sample(x): event_shape = jnp.shape(x)[batch_ndims:] return x.reshape(sample_shape + self.batch_shape + event_shape) - intermediates = tree_util.tree_map(reshape_sample, intermediates) + intermediates = jax.tree.map(reshape_sample, intermediates) samples = reshape_sample(samples) return samples, intermediates diff --git a/numpyro/distributions/flows.py b/numpyro/distributions/flows.py index cd9b21c35..2980587b4 100644 --- a/numpyro/distributions/flows.py +++ b/numpyro/distributions/flows.py @@ -10,7 +10,7 @@ def _clamp_preserve_gradients(x, min, max): - return x + lax.stop_gradient(jnp.clip(x, a_min=min, a_max=max) - x) + return x + lax.stop_gradient(jnp.clip(x, min, max) - x) # adapted from https://github.com/pyro-ppl/pyro/blob/dev/pyro/distributions/transforms/iaf.py diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 564c628d8..927c2b096 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -6,15 +6,15 @@ import weakref import numpy as np -from numpy.core.numeric import normalize_axis_tuple +import jax from jax import lax, vmap from jax.flatten_util import ravel_pytree from jax.nn import log_sigmoid, softplus import jax.numpy as jnp from jax.scipy.linalg import solve_triangular from jax.scipy.special import expit, logit -from jax.tree_util import register_pytree_node, tree_flatten, tree_map +from jax.tree_util 
import register_pytree_node from numpyro.distributions import constraints from numpyro.distributions.util import ( @@ -57,7 +57,7 @@ def _clipped_expit(x): finfo = jnp.finfo(jnp.result_type(x)) - return jnp.clip(expit(x), a_min=finfo.tiny, a_max=1.0 - finfo.eps) + return jnp.clip(expit(x), finfo.tiny, 1.0 - finfo.eps) class Transform(object): @@ -650,11 +650,11 @@ def _inverse(self, y): pad_width = [(0, 0)] * (y.ndim - 1) + [(1, 0)] remainder = jnp.pad(remainder, pad_width, mode="constant", constant_values=1.0) finfo = jnp.finfo(y.dtype) - remainder = jnp.clip(remainder, a_min=finfo.tiny) + remainder = jnp.clip(remainder, finfo.tiny) t = y / remainder # inverse of tanh - t = jnp.clip(t, a_min=-1 + finfo.eps, a_max=1 - finfo.eps) + t = jnp.clip(t, -1 + finfo.eps, 1 - finfo.eps) return jnp.arctanh(t) def log_abs_det_jacobian(self, x, y, intermediates=None): @@ -666,7 +666,7 @@ def log_abs_det_jacobian(self, x, y, intermediates=None): # of the diagonal part of the jacobian one_minus_remainder = jnp.cumsum(jnp.abs(y[..., :-1]), axis=-1) eps = jnp.finfo(y.dtype).eps - one_minus_remainder = jnp.clip(one_minus_remainder, a_max=1 - eps) + one_minus_remainder = jnp.clip(one_minus_remainder, None, 1 - eps) # log(remainder) = log1p(remainder - 1) stick_breaking_logdet = jnp.sum(jnp.log1p(-one_minus_remainder), axis=-1) @@ -1074,9 +1074,7 @@ def __call__(self, x): def _inverse(self, y): y_crop = y[..., :-1] - z1m_cumprod = jnp.clip( - 1 - jnp.cumsum(y_crop, axis=-1), a_min=jnp.finfo(y.dtype).tiny - ) + z1m_cumprod = jnp.clip(1 - jnp.cumsum(y_crop, axis=-1), jnp.finfo(y.dtype).tiny) # hence x = logit(z) = log(z / (1 - z)) = y[::-1] / z1m_cumprod x = jnp.log(y_crop / z1m_cumprod) return x + jnp.log(x.shape[-1] - jnp.arange(x.shape[-1])) @@ -1116,7 +1114,7 @@ def __call__(self, x): batch_shape = x.shape[:-1] if batch_shape: unpacked = vmap(self.unpack_fn)(x.reshape((-1,) + x.shape[-1:])) - return tree_map( + return jax.tree.map( lambda z: jnp.reshape(z, batch_shape + z.shape[1:]), unpacked ) else: @@ -1124,7 +1122,7 @@ def __call__(self, x): def _inverse(self, y): leading_dims = [ - v.shape[0] if jnp.ndim(v) > 0 else 0 for v in tree_flatten(y)[0] + v.shape[0] if jnp.ndim(v) > 0 else 0 for v in jax.tree.flatten(y)[0] ] d0 = leading_dims[0] not_scalar = d0 > 0 or len(leading_dims) > 1 @@ -1417,7 +1415,7 @@ def _inverse(self, y: jnp.ndarray) -> jnp.ndarray: return y def extend_axis_rev(self, array: jnp.ndarray, axis: int) -> jnp.ndarray: - normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] + normalized_axis = axis if axis >= 0 else jnp.ndim(array) + axis n = array.shape[normalized_axis] last = jnp.take(array, jnp.array([-1]), axis=normalized_axis) diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index 078ea83bc..236e2a000 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -1,11 +1,11 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import jax from jax import lax import jax.numpy as jnp import jax.random as random from jax.scipy.special import logsumexp -from jax.tree_util import tree_map from numpyro.distributions import constraints from numpyro.distributions.continuous import ( @@ -38,7 +38,7 @@ def __init__(self, base_dist, low=0.0, *, validate_args=None): base_dist.support is constraints.real ), "The base distribution should be univariate and have real support." 
batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(low)) - self.base_dist = tree_map( + self.base_dist = jax.tree.map( lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist ) (self.low,) = promote_shapes(low, shape=batch_shape) @@ -117,7 +117,7 @@ def __init__(self, base_dist, high=0.0, *, validate_args=None): base_dist.support is constraints.real ), "The base distribution should be univariate and have real support." batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(high)) - self.base_dist = tree_map( + self.base_dist = jax.tree.map( lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist ) (self.high,) = promote_shapes(high, shape=batch_shape) @@ -186,7 +186,7 @@ def __init__(self, base_dist, low=0.0, high=1.0, *, validate_args=None): batch_shape = lax.broadcast_shapes( base_dist.batch_shape, jnp.shape(low), jnp.shape(high) ) - self.base_dist = tree_map( + self.base_dist = jax.tree.map( lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist ) (self.low,) = promote_shapes(low, shape=batch_shape) @@ -348,7 +348,7 @@ def sample(self, key, sample_shape=()): key, jnp.ones(self.batch_shape + sample_shape + (self.num_gamma_variates,)) ) x = jnp.sum(x / denom, axis=-1) - return jnp.clip(x * (0.5 / jnp.pi**2), a_max=self.truncation_point) + return jnp.clip(x * (0.5 / jnp.pi**2), None, self.truncation_point) @validate_sample def log_prob(self, value): diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index aca32b1f7..c83efb701 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -386,7 +386,7 @@ def scan_fn(carry, val): def signed_stick_breaking_tril(t): # make sure that t in (-1, 1) eps = jnp.finfo(t.dtype).eps - t = jnp.clip(t, a_min=(-1 + eps), a_max=(1 - eps)) + t = jnp.clip(t, -1 + eps, 1 - eps) # transform t to tril matrix with identity diagonal r = vec_to_tril_matrix(t, diagonal=-1) @@ -417,7 +417,7 @@ def logmatmulexp(x, y): def clamp_probs(probs): finfo = jnp.finfo(jnp.result_type(probs, float)) - return jnp.clip(probs, a_min=finfo.tiny, a_max=1.0 - finfo.eps) + return jnp.clip(probs, finfo.tiny, 1.0 - finfo.eps) def betainc(a, b, x): @@ -607,7 +607,7 @@ def safe_normalize(x, *, p=2): assert isinstance(p, (float, int)) assert p >= 0 norm = jnp.linalg.norm(x, p, axis=-1, keepdims=True) - x = x / jnp.clip(norm, a_min=jnp.finfo(x).tiny) + x = x / jnp.clip(norm, jnp.finfo(x).tiny) # Avoid the singularity. 
mask = jnp.all(x == 0, axis=-1, keepdims=True) x = jnp.where(mask, x.shape[-1] ** (-1 / p), x) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 351cb6390..624865a75 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -14,7 +14,6 @@ from jax import grad, hessian, lax, random from jax.example_libraries import stax import jax.numpy as jnp -from jax.tree_util import tree_map import numpyro from numpyro import handlers @@ -454,12 +453,12 @@ def _constrain(self, latent_samples): : jnp.ndim(latent_samples[name]) - jnp.ndim(self._init_locs[name]) ] if sample_shape: - flatten_samples = tree_map( + flatten_samples = jax.tree.map( lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[len(sample_shape) :]), latent_samples, ) contrained_samples = lax.map(self._postprocess_fn, flatten_samples) - return tree_map( + return jax.tree.map( lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]), contrained_samples, ) @@ -751,7 +750,7 @@ def unpack_single_latent(latent): latent_sample, (-1, jnp.shape(latent_sample)[-1]) ) unpacked_samples = lax.map(unpack_single_latent, latent_sample) - return tree_map( + return jax.tree.map( lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]), unpacked_samples, ) @@ -968,7 +967,7 @@ def log_density(x): def scan_body(carry, eps_beta): eps, beta = eps_beta eta = eta0 + eta_coeff * beta - eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max) + eta = jnp.clip(eta, 0.0, self.eta_max) z_prev, v_prev, log_factor = carry z_half = z_prev + v_prev * eta * inv_mass_matrix q_grad = (1.0 - beta) * grad(base_z_dist.log_prob)(z_half) @@ -997,7 +996,7 @@ def _single_sample(_rng_key): if sample_shape: rng_key = random.split(rng_key, int(np.prod(sample_shape))) samples = lax.map(_single_sample, rng_key) - return tree_map( + return jax.tree.map( lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]), samples, ) @@ -1187,7 +1186,7 @@ def blocked_surrogate_model(x): def scan_body(carry, eps_beta): eps, beta = eps_beta eta = eta0 + eta_coeff * beta - eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max) + eta = jnp.clip(eta, 0.0, self.eta_max) z_prev, v_prev, log_factor = carry z_half = z_prev + v_prev * eta * inv_mass_matrix q_grad = (1.0 - beta) * grad(base_z_dist_log_prob)(z_half) @@ -1642,7 +1641,7 @@ def base_z_dist_log_prob(x): def scan_body(carry, eps_beta): eps, beta = eps_beta eta = eta0 + eta_coeff * beta - eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max) + eta = jnp.clip(eta, 0.0, self.eta_max) assert eps.shape == (subsample_size, D) assert eta.shape == beta.shape == (subsample_size,) z_prev, v_prev, log_factor = carry @@ -1697,7 +1696,7 @@ def _single_sample(_rng_key): if sample_shape: rng_key = random.split(rng_key, int(np.prod(sample_shape))) samples = lax.map(_single_sample, rng_key) - return tree_map( + return jax.tree.map( lambda x: jnp.reshape(x, sample_shape + jnp.shape(x)[1:]), samples, ) diff --git a/numpyro/infer/barker.py b/numpyro/infer/barker.py index 5496e7baa..9d5fa0f2b 100644 --- a/numpyro/infer/barker.py +++ b/numpyro/infer/barker.py @@ -260,7 +260,7 @@ def sample(self, state, model_args, model_kwargs): - softplus(-dx_flat * y_grad_flat_scaled) ) ) - accept_prob = jnp.clip(jnp.exp(log_accept_ratio), a_max=1.0) + accept_prob = jnp.clip(jnp.exp(log_accept_ratio), None, 1.0) x, x_flat, pe, x_grad = jax.lax.cond( random.bernoulli(key_accept, accept_prob), diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py index 33e3f50bf..033b2c7b0 100644 --- a/numpyro/infer/ensemble.py +++ b/numpyro/infer/ensemble.py @@ 
-160,7 +160,7 @@ def init( assert all( [ param.shape[0] == self._num_chains - for param in jax.tree_util.tree_leaves(init_params) + for param in jax.tree.leaves(init_params) ] ), "The batch dimension of each param must match n_chains" diff --git a/numpyro/infer/ensemble_util.py b/numpyro/infer/ensemble_util.py index 9f213ea5c..a5a3dd569 100644 --- a/numpyro/infer/ensemble_util.py +++ b/numpyro/infer/ensemble_util.py @@ -6,7 +6,6 @@ import jax from jax.flatten_util import ravel_pytree import jax.numpy as jnp -from jax.tree_util import tree_map def get_nondiagonal_indices(n): @@ -41,6 +40,6 @@ def batch_ravel_pytree(pytree): component of the output. """ flat = jax.vmap(lambda x: ravel_pytree(x)[0])(pytree) - unravel_fn = jax.vmap(ravel_pytree(tree_map(lambda z: z[0], pytree))[1]) + unravel_fn = jax.vmap(ravel_pytree(jax.tree.map(lambda z: z[0], pytree))[1]) return flat, unravel_fn diff --git a/numpyro/infer/hmc.py b/numpyro/infer/hmc.py index 709c16824..a299dc89c 100644 --- a/numpyro/infer/hmc.py +++ b/numpyro/infer/hmc.py @@ -400,7 +400,7 @@ def _hmc_next( ) delta_energy = energy_new - energy_old delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy) - accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0) + accept_prob = jnp.clip(jnp.exp(-delta_energy), None, 1.0) diverging = delta_energy > max_delta_energy transition = random.bernoulli(rng_key, accept_prob) vv_state, energy = cond( diff --git a/numpyro/infer/hmc_gibbs.py b/numpyro/infer/hmc_gibbs.py index 53622dcf2..f6b95389b 100644 --- a/numpyro/infer/hmc_gibbs.py +++ b/numpyro/infer/hmc_gibbs.py @@ -657,7 +657,7 @@ def potential_fn(z_gibbs, gibbs_state, z_hmc): # given a fixed hmc_sites, pe_new - pe_curr = loglik_new - loglik_curr pe = state.hmc_state.potential_energy pe_new = potential_fn(z_gibbs_new, gibbs_state_new, state.hmc_state.z) - accept_prob = jnp.clip(jnp.exp(pe - pe_new), a_max=1.0) + accept_prob = jnp.clip(jnp.exp(pe - pe_new), None, 1.0) transition = random.bernoulli(rng_key, accept_prob) grad_ = jacfwd if self.inner_kernel._forward_mode_differentiation else grad z_gibbs, gibbs_state, pe, z_grad = cond( diff --git a/numpyro/infer/hmc_util.py b/numpyro/infer/hmc_util.py index b331540d1..a21e7329e 100644 --- a/numpyro/infer/hmc_util.py +++ b/numpyro/infer/hmc_util.py @@ -3,12 +3,12 @@ from collections import OrderedDict, namedtuple +import jax from jax import grad, jacfwd, random, value_and_grad, vmap from jax.flatten_util import ravel_pytree import jax.numpy as jnp from jax.scipy.linalg import solve_triangular from jax.scipy.special import expit -from jax.tree_util import tree_flatten, tree_map import numpyro.distributions as dist from numpyro.util import cond, identity, while_loop @@ -295,15 +295,15 @@ def update_fn(step_size, inverse_mass_matrix, state): :return: new state for the integrator. 
""" z, r, _, z_grad = state - r = tree_map( + r = jax.tree.map( lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad ) # r(n+1/2) r_grad = _kinetic_grad(kinetic_fn, inverse_mass_matrix, r) - z = tree_map(lambda z, r_grad: z + step_size * r_grad, z, r_grad) # z(n+1) + z = jax.tree.map(lambda z, r_grad: z + step_size * r_grad, z, r_grad) # z(n+1) potential_energy, z_grad = _value_and_grad( potential_fn, z, forward_mode_differentiation ) - r = tree_map( + r = jax.tree.map( lambda r, z_grad: r - 0.5 * step_size * z_grad, r, z_grad ) # r(n+1) return IntegratorState(z, r, potential_energy, z_grad) @@ -669,7 +669,7 @@ def update_fn(t, accept_prob, z_info, state): ) # account the the case log_step_size is an extreme number finfo = jnp.finfo(jnp.result_type(step_size)) - step_size = jnp.clip(step_size, a_min=finfo.tiny, a_max=finfo.max) + step_size = jnp.clip(step_size, finfo.tiny, finfo.max) # update mass matrix state is_middle_window = (0 < window_idx) & (window_idx < (num_windows - 1)) @@ -759,7 +759,7 @@ def _biased_transition_kernel(current_tree, new_tree): # If new tree is turning or diverging, we won't move the proposal # to the new tree. transition_prob = jnp.where( - new_tree.turning | new_tree.diverging, 0.0, jnp.clip(transition_prob, a_max=1.0) + new_tree.turning | new_tree.diverging, 0.0, jnp.clip(transition_prob, None, 1.0) ) return transition_prob @@ -790,7 +790,7 @@ def _combine_tree( trees[1].z_right_grad, ), ) - r_sum = tree_map(jnp.add, current_tree.r_sum, new_tree.r_sum) + r_sum = jax.tree.map(jnp.add, current_tree.r_sum, new_tree.r_sum) if biased_transition: transition_prob = _biased_transition_kernel(current_tree, new_tree) @@ -872,7 +872,7 @@ def _build_basetree( tree_weight = -delta_energy diverging = delta_energy > max_delta_energy - accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0) + accept_prob = jnp.clip(jnp.exp(-delta_energy), None, 1.0) return TreeInfo( z_new, r_new, @@ -1242,7 +1242,7 @@ def consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None): a collection of `num_draws` samples with the same data structure as each subposterior. """ # stack subposteriors - joined_subposteriors = tree_map(lambda *args: jnp.stack(args), *subposteriors) + joined_subposteriors = jax.tree.map(lambda *args: jnp.stack(args), *subposteriors) # shape of joined_subposteriors: n_subs x n_samples x sample_shape joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))( joined_subposteriors @@ -1252,7 +1252,7 @@ def consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None): rng_key = random.PRNGKey(0) if rng_key is None else rng_key # randomly gets num_draws from subposteriors n_subs = len(subposteriors) - n_samples = tree_flatten(subposteriors[0])[0][0].shape[0] + n_samples = jax.tree.flatten(subposteriors[0])[0][0].shape[0] # shape of draw_idxs: n_subs x num_draws x sample_shape draw_idxs = random.randint( rng_key, shape=(n_subs, num_draws), minval=0, maxval=n_samples @@ -1279,7 +1279,7 @@ def consensus(subposteriors, num_draws=None, diagonal=False, rng_key=None): ) # unravel_fn acts on 1 sample of a subposterior - _, unravel_fn = ravel_pytree(tree_map(lambda x: x[0], subposteriors[0])) + _, unravel_fn = ravel_pytree(jax.tree.map(lambda x: x[0], subposteriors[0])) return vmap(lambda x: unravel_fn(x))(samples_flat) @@ -1297,7 +1297,7 @@ def parametric(subposteriors, diagonal=False): `False` (using covariance). 
:return: the estimated mean and variance/covariance parameters of the joined posterior """ - joined_subposteriors = tree_map(lambda *args: jnp.stack(args), *subposteriors) + joined_subposteriors = jax.tree.map(lambda *args: jnp.stack(args), *subposteriors) joined_subposteriors = vmap(vmap(lambda sample: ravel_pytree(sample)[0]))( joined_subposteriors ) @@ -1345,5 +1345,5 @@ def parametric_draws(subposteriors, num_draws, diagonal=False, rng_key=None): mean, cov = parametric(subposteriors, diagonal=False) samples_flat = dist.MultivariateNormal(mean, cov).sample(rng_key, (num_draws,)) - _, unravel_fn = ravel_pytree(tree_map(lambda x: x[0], subposteriors[0])) + _, unravel_fn = ravel_pytree(jax.tree.map(lambda x: x[0], subposteriors[0])) return vmap(lambda x: unravel_fn(x))(samples_flat) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index f1e0a7013..ad016825b 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -9,9 +9,9 @@ import numpy as np +import jax from jax import device_get, jit, lax, local_device_count, pmap, random, vmap import jax.numpy as jnp -from jax.tree_util import tree_flatten, tree_map from numpyro.diagnostics import print_summary from numpyro.util import ( @@ -164,18 +164,18 @@ def _get_progbar_desc_str(num_warmup, phase, i): def _get_value_from_index(xs, i): - return tree_map(lambda x: x[i], xs) + return jax.tree.map(lambda x: x[i], xs) def _laxmap(f, xs): - n = tree_flatten(xs)[0][0].shape[0] + n = jax.tree.flatten(xs)[0][0].shape[0] ys = [] for i in range(n): x = jit(_get_value_from_index)(xs, i) ys.append(f(x)) - return tree_map(lambda *args: jnp.stack(args), *ys) + return jax.tree.map(lambda *args: jnp.stack(args), *ys) def _sample_fn_jit_args(state, sampler): @@ -378,8 +378,8 @@ def _get_cached_fns(self): if self._jit_model_args: args, kwargs = (None,), (None,) else: - args = tree_map(lambda x: _hashable(x), self._args) - kwargs = tree_map( + args = jax.tree.map(lambda x: _hashable(x), self._args) + kwargs = jax.tree.map( lambda x: _hashable(x), tuple(sorted(self._kwargs.items())) ) key = args + kwargs @@ -422,8 +422,8 @@ def laxmap_postprocess_fn(states, args, kwargs): def _get_cached_init_state(self, rng_key, args, kwargs): rng_key = (_hashable(rng_key),) - args = tree_map(lambda x: _hashable(x), args) - kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(kwargs.items()))) + args = jax.tree.map(lambda x: _hashable(x), args) + kwargs = jax.tree.map(lambda x: _hashable(x), tuple(sorted(kwargs.items()))) key = rng_key + args + kwargs try: return self._init_state_cache.get(key, None) @@ -480,7 +480,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites): states = (states,) states = dict(zip(collect_fields, states)) # Apply constraints if number of samples is non-zero - site_values = tree_flatten(states[self._sample_field])[0] + site_values = jax.tree.flatten(states[self._sample_field])[0] # XXX: lax.map still works if some arrays have 0 size # so we only need to filter out the case site_value.shape[0] == 0 # (which happens when lower_idx==upper_idx) @@ -509,8 +509,8 @@ def _compile(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): rng_key, *args, extra_fields=extra_fields, init_params=init_params, **kwargs ) rng_key = (_hashable(rng_key),) - args = tree_map(lambda x: _hashable(x), args) - kwargs = tree_map(lambda x: _hashable(x), tuple(sorted(kwargs.items()))) + args = jax.tree.map(lambda x: _hashable(x), args) + kwargs = jax.tree.map(lambda x: _hashable(x), tuple(sorted(kwargs.items()))) key = 
rng_key + args + kwargs try: self._init_state_cache[key] = self._last_state @@ -520,7 +520,7 @@ def _compile(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): def _get_states_flat(self): if self._states_flat is None: - self._states_flat = tree_map( + self._states_flat = jax.tree.map( # need to calculate first dimension manually; see issue #1328 lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:]), self._states, @@ -629,7 +629,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): See https://jax.readthedocs.io/en/latest/async_dispatch.html and https://jax.readthedocs.io/en/latest/profiling.html for pointers on profiling jax programs. """ - init_params = tree_map( + init_params = jax.tree.map( lambda x: lax.convert_element_type(x, jnp.result_type(x)), init_params ) self._args = args @@ -643,7 +643,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): init_state = self._warmup_state._replace(rng_key=rng_key) if init_params is not None and self.num_chains > 1: - prototype_init_val = tree_flatten(init_params)[0][0] + prototype_init_val = jax.tree.flatten(init_params)[0][0] if jnp.shape(prototype_init_val)[0] != self.num_chains: raise ValueError( "`init_params` must have the same leading dimension" @@ -673,7 +673,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): map_args = (rng_key, init_state, init_params) if self.num_chains == 1: states_flat, last_state = partial_map_fn(map_args) - states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat) + states = jax.tree.map(lambda x: x[jnp.newaxis, ...], states_flat) else: if self.chain_method == "sequential": states, last_state = _laxmap(partial_map_fn, map_args) @@ -683,7 +683,7 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs): assert self.chain_method == "vectorized" states, last_state = partial_map_fn(map_args) # swap num_samples x num_chains to num_chains x num_samples - states = tree_map(lambda x: jnp.swapaxes(x, 0, 1), states) + states = jax.tree.map(lambda x: jnp.swapaxes(x, 0, 1), states) self._last_state = last_state self._states = states diff --git a/numpyro/infer/mixed_hmc.py b/numpyro/infer/mixed_hmc.py index deea69a79..3e3d2ae59 100644 --- a/numpyro/infer/mixed_hmc.py +++ b/numpyro/infer/mixed_hmc.py @@ -263,7 +263,7 @@ def body_fn(i, vals): # Algo 1, line 11: perform MH correction delta_energy = energy_new - energy_old - delta_pe_sum delta_energy = jnp.where(jnp.isnan(delta_energy), jnp.inf, delta_energy) - accept_prob = jnp.clip(jnp.exp(-delta_energy), a_max=1.0) + accept_prob = jnp.clip(jnp.exp(-delta_energy), None, 1.0) # record the correct new num_steps hmc_state = hmc_state._replace(num_steps=hmc_state_new.num_steps) diff --git a/numpyro/infer/svi.py b/numpyro/infer/svi.py index b21b531de..c70f0d91a 100644 --- a/numpyro/infer/svi.py +++ b/numpyro/infer/svi.py @@ -11,7 +11,6 @@ from jax import jit, lax, random from jax.example_libraries import optimizers import jax.numpy as jnp -from jax.tree_util import tree_map from numpyro.distributions import constraints from numpyro.distributions.transforms import biject_to @@ -240,7 +239,7 @@ def init(self, rng_key, *args, init_params=None, **kwargs): self.constrain_fn = partial(transform_fn, inv_transforms) # we convert weak types like float to float32/float64 # to avoid recompiling body_fn in svi.run - params, mutable_state = tree_map( + params, mutable_state = jax.tree.map( lambda x: lax.convert_element_type(x, jnp.result_type(x)), (params, 
mutable_state), ) diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 7d57cbbd0..3775893a7 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -14,7 +14,6 @@ from jax.flatten_util import ravel_pytree from jax.lax import broadcast_shapes import jax.numpy as jnp -from jax.tree_util import tree_flatten, tree_map import numpyro from numpyro.distributions import constraints @@ -770,7 +769,7 @@ def _predictive( # inspect the model to get some structure rng_key, subkey = random.split(rng_key) batch_ndim = len(batch_shape) - prototype_sample = tree_map( + prototype_sample = jax.tree.map( lambda x: jnp.reshape(x, (-1,) + jnp.shape(x)[batch_ndim:])[0], posterior_samples, ) @@ -1027,7 +1026,7 @@ def __call__(self, rng_key, *args, **kwargs): if self.batch_ndims == 0 or self.params == {} or self.guide is None: return self._call_with_params(rng_key, self.params, args, kwargs) elif self.batch_ndims == 1: # batch over parameters - batch_size = jnp.shape(tree_flatten(self.params)[0][0])[0] + batch_size = jnp.shape(jax.tree.flatten(self.params)[0][0])[0] rng_keys = random.split(rng_key, batch_size) return jax.vmap( partial(self._call_with_params, args=args, kwargs=kwargs), diff --git a/numpyro/ops/provenance.py b/numpyro/ops/provenance.py index dd88afdf8..68797396e 100644 --- a/numpyro/ops/provenance.py +++ b/numpyro/ops/provenance.py @@ -39,7 +39,7 @@ def eval_provenance(fn, **kwargs): :returns: A pytree of :class:`frozenset` indicating the dependency on the inputs. """ # Flatten the function and its arguments - args, in_tree = jax.tree_util.tree_flatten(((), kwargs)) + args, in_tree = jax.tree.flatten(((), kwargs)) wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fn), in_tree) # Abstract eval to get output pytree avals = util.safe_map(shaped_abstractify, args) @@ -54,19 +54,17 @@ def eval_provenance(fn, **kwargs): aval_kwargs = {} for n, v in kwargs.items(): aval = jax.ShapeDtypeStruct((), jnp.bool_, {"provenance": frozenset({n})}) - aval_kwargs[n] = jax.tree_util.tree_map(lambda _: aval, v) - aval_args, _ = jax.tree_util.tree_flatten(((), aval_kwargs)) - provenance_inputs = jax.tree_util.tree_map( - lambda x: x.named_shape["provenance"], aval_args - ) + aval_kwargs[n] = jax.tree.map(lambda _: aval, v) + aval_args, _ = jax.tree.flatten(((), aval_kwargs)) + provenance_inputs = jax.tree.map(lambda x: x.named_shape["provenance"], aval_args) provenance_outputs = track_deps_jaxpr(jaxpr, provenance_inputs) out_flat = [] for v, p in zip(avals_out, provenance_outputs): val = jax.ShapeDtypeStruct(jnp.shape(v), jnp.result_type(v), {"provenance": p}) out_flat.append(val) - out = jax.tree_util.tree_unflatten(out_tree(), out_flat) - return jax.tree_util.tree_map(lambda x: x.named_shape["provenance"], out) + out = jax.tree.unflatten(out_tree(), out_flat) + return jax.tree.map(lambda x: x.named_shape["provenance"], out) def track_deps_jaxpr(jaxpr, provenance_inputs): diff --git a/numpyro/optim.py b/numpyro/optim.py index 225a6f6bb..0abc90ee3 100644 --- a/numpyro/optim.py +++ b/numpyro/optim.py @@ -11,12 +11,13 @@ from collections.abc import Callable from typing import Any, TypeVar +import jax from jax import jacfwd, lax, value_and_grad from jax.example_libraries import optimizers from jax.flatten_util import ravel_pytree import jax.numpy as jnp from jax.scipy.optimize import minimize -from jax.tree_util import register_pytree_node, tree_map +from jax.tree_util import register_pytree_node __all__ = [ "Adam", @@ -176,9 +177,7 @@ def __init__(self, *args, clip_norm=10.0, **kwargs): def 
update(self, g, state): i, opt_state = state # clip norm - g = tree_map( - lambda g_: jnp.clip(g_, a_min=-self.clip_norm, a_max=self.clip_norm), g - ) + g = jax.tree.map(lambda g_: jnp.clip(g_, -self.clip_norm, self.clip_norm), g) opt_state = self.update_fn(i, g, opt_state) return i + 1, opt_state diff --git a/numpyro/util.py b/numpyro/util.py index 23a210d03..a30c09a53 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -20,7 +20,6 @@ from jax.core import Tracer from jax.experimental import host_callback import jax.numpy as jnp -from jax.tree_util import tree_flatten, tree_map _DISABLE_CONTROL_FLOW_PRIM = False _CHAIN_RE = re.compile(r"\d+$") # e.g. get '3' from 'TFRT_CPU_3' @@ -423,7 +422,7 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None): Defaults to the size of batch dimensions. :returns: output of `fn(xs)`. """ - flatten_xs = tree_flatten(xs)[0] + flatten_xs = jax.tree.flatten(xs)[0] batch_shape = np.shape(flatten_xs[0])[:batch_ndims] for x in flatten_xs[1:]: assert np.shape(x)[:batch_ndims] == batch_shape @@ -431,7 +430,7 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None): # we'll do map(vmap(fn), xs) and make xs.shape = (num_chunks, chunk_size, ...) num_chunks = batch_size = int(np.prod(batch_shape)) prepend_shape = (batch_size,) if batch_size > 1 else () - xs = tree_map( + xs = jax.tree.map( lambda x: jnp.reshape(x, prepend_shape + jnp.shape(x)[batch_ndims:]), xs ) # XXX: probably for the default behavior with chunk_size=None, @@ -439,12 +438,12 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None): chunk_size = batch_size if chunk_size is None else min(batch_size, chunk_size) if chunk_size > 1: pad = chunk_size - (batch_size % chunk_size) - xs = tree_map( + xs = jax.tree.map( lambda x: jnp.pad(x, ((0, pad),) + ((0, 0),) * (np.ndim(x) - 1)), xs ) num_chunks = batch_size // chunk_size + int(pad > 0) prepend_shape = (-1,) if num_chunks > 1 else () - xs = tree_map( + xs = jax.tree.map( lambda x: jnp.reshape(x, prepend_shape + (chunk_size,) + jnp.shape(x)[1:]), xs, ) @@ -452,13 +451,13 @@ def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None): ys = lax.map(fn, xs) if num_chunks > 1 else fn(xs) map_ndims = int(num_chunks > 1) + int(chunk_size > 1) - ys = tree_map( + ys = jax.tree.map( lambda y: jnp.reshape( y, (int(np.prod(jnp.shape(y)[:map_ndims])),) + jnp.shape(y)[map_ndims:] )[:batch_size], ys, ) - return tree_map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]), ys) + return jax.tree.map(lambda y: jnp.reshape(y, batch_shape + jnp.shape(y)[1:]), ys) def format_shapes( diff --git a/setup.py b/setup.py index 0d9e4fb02..b47b0bc92 100644 --- a/setup.py +++ b/setup.py @@ -9,8 +9,8 @@ from setuptools import find_packages, setup PROJECT_PATH = os.path.dirname(os.path.abspath(__file__)) -_jax_version_constraints = ">=0.4.14" -_jaxlib_version_constraints = ">=0.4.14" +_jax_version_constraints = ">=0.4.25" +_jaxlib_version_constraints = ">=0.4.25" # Find version for line in open(os.path.join(PROJECT_PATH, "numpyro", "version.py")): diff --git a/test/contrib/einstein/test_steinvi_util.py b/test/contrib/einstein/test_steinvi_util.py index 617effb71..38fbd0603 100644 --- a/test/contrib/einstein/test_steinvi_util.py +++ b/test/contrib/einstein/test_steinvi_util.py @@ -8,8 +8,8 @@ import pytest import scipy +import jax from jax import numpy as jnp -from jax.tree_util import tree_flatten, tree_map from numpyro.contrib.einstein.stein_util import batch_ravel_pytree, posdef, sqrth @@ -82,10 +82,10 @@ def test_sqrth_shape(batch_shape): def test_ravel_pytree_batched(pytree, 
nbatch_dims): flat, _, unravel_fn = batch_ravel_pytree(pytree, nbatch_dims) unravel = unravel_fn(flat) - tree_flatten(tree_map(lambda x, y: assert_allclose(x, y), unravel, pytree)) + jax.tree.flatten(jax.tree.map(lambda x, y: assert_allclose(x, y), unravel, pytree)) assert all( - tree_flatten( - tree_map( + jax.tree.flatten( + jax.tree.map( lambda x, y: jnp.result_type(x) == jnp.result_type(y), unravel, pytree ) )[0] diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index d464fb33a..52c270cea 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -32,9 +32,7 @@ def assert_equal(a, b, prec=0): - return jax.tree_util.tree_map( - lambda a, b: np.testing.assert_allclose(a, b, atol=prec), a, b - ) + return jax.tree.map(lambda a, b: np.testing.assert_allclose(a, b, atol=prec), a, b) def xfail_param(*args, **kwargs): @@ -122,16 +120,16 @@ def guide(params): pyro.sample("x", dist.Categorical(probs_x), infer={"enumerate": "parallel"}) def auto_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params) def hand_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params) - params_raw = jax.tree_util.tree_map(transform.inv, params) + params_raw = jax.tree.map(transform.inv, params) auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw) hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw) @@ -174,16 +172,16 @@ def guide(params): pyro.sample("x", dist.Categorical(probs_x)) def auto_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params) def hand_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params) - params_raw = jax.tree_util.tree_map(transform.inv, params) + params_raw = jax.tree.map(transform.inv, params) auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw) hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw) @@ -225,16 +223,16 @@ def guide(params): pyro.sample("x", dist.Categorical(probs_x)) def auto_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params) def hand_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params) - params_raw = jax.tree_util.tree_map(transform.inv, params) + params_raw = jax.tree.map(transform.inv, params) auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw) hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw) @@ -296,16 +294,16 @@ def guide(data, params): ) def auto_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, 
auto_model, guide, data, params) def hand_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, data, params) - params_raw = jax.tree_util.tree_map(transform.inv, params) + params_raw = jax.tree.map(transform.inv, params) auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw) hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw) @@ -375,16 +373,16 @@ def guide(data, params): ) def auto_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, data, params) def hand_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, data, params) - params_raw = jax.tree_util.tree_map(transform.inv, params) + params_raw = jax.tree.map(transform.inv, params) auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw) hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw) @@ -467,16 +465,16 @@ def hand_guide(data, params): ) def auto_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, auto_model, auto_guide, data, params) def hand_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) elbo = infer.TraceEnum_ELBO() return elbo.loss(random.PRNGKey(0), {}, hand_model, hand_guide, data, params) - params_raw = jax.tree_util.tree_map(transform.inv, params) + params_raw = jax.tree.map(transform.inv, params) auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw) hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw) @@ -2491,13 +2489,13 @@ def guide(params): "probs_a": jnp.array([3.0, 2.5]), } transform = dist.biject_to(dist.constraints.positive) - params_raw = jax.tree_util.tree_map(transform.inv, params) + params_raw = jax.tree.map(transform.inv, params) # TraceGraph_ELBO grads averaged over num_particles elbo = infer.TraceGraph_ELBO(num_particles=50_000) def graph_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) return elbo.loss(random.PRNGKey(0), {}, model, guide, params) graph_loss, graph_grads = jax.value_and_grad(graph_loss_fn)(params_raw) @@ -2506,7 +2504,7 @@ def graph_loss_fn(params_raw): elbo = infer.TraceEnum_ELBO(num_particles=50_000) def enum_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) return elbo.loss(random.PRNGKey(0), {}, model, guide, params) enum_loss, enum_grads = jax.value_and_grad(enum_loss_fn)(params_raw) diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index 5f43dcb3e..d3f27f17e 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -7,8 +7,8 @@ from numpy.testing import assert_allclose import pytest +import jax from jax import random -from jax.tree_util import tree_all, tree_map import numpyro from numpyro import handlers @@ -141,8 +141,8 @@ def test_update_params(): "a": {"b": {"c": {"d": ParamShape(())}, "e": 2}, "f": ParamShape((4,))} } - 
tree_all( - tree_map( + jax.tree.all( + jax.tree.map( assert_allclose, new_params, { diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 9394e17a1..0e2e53aa4 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -7,10 +7,10 @@ from numpy.testing import assert_allclose import pytest +import jax from jax import jacobian, jit, lax, random, vmap from jax.example_libraries.stax import Dense import jax.numpy as jnp -from jax.tree_util import tree_all, tree_map import optax from optax import piecewise_constant_schedule @@ -270,7 +270,7 @@ def model(data, labels): transforms.biject_to(constraints.interval(-1, 1))(expected_sample["offset"]), ) - tree_all(tree_map(assert_allclose, actual_output, expected_output)) + jax.tree.all(jax.tree.map(assert_allclose, actual_output, expected_output)) def test_uniform_normal(): @@ -391,8 +391,8 @@ def expected_model(data): expected_loss = svi.evaluate(svi_state, data) # test auto_loc, auto_scale - tree_all(tree_map(assert_allclose, actual_opt_params, expected_opt_params)) - tree_all(tree_map(assert_allclose, actual_params, expected_params)) + jax.tree.all(jax.tree.map(assert_allclose, actual_opt_params, expected_opt_params)) + jax.tree.all(jax.tree.map(assert_allclose, actual_params, expected_params)) # test latent values assert_allclose(actual_values["alpha"], expected_values["alpha"]) assert_allclose(actual_values["loc_base"], expected_values["loc"]) diff --git a/test/infer/test_ensemble_util.py b/test/infer/test_ensemble_util.py index ad28a76ad..94fb9e560 100644 --- a/test/infer/test_ensemble_util.py +++ b/test/infer/test_ensemble_util.py @@ -26,6 +26,6 @@ def test_batch_ravel_pytree(): assert flattened.shape == (5, 2 + 3 + 4) for unflattened_leaf, original_leaf in zip( - jax.tree_util.tree_leaves(unflattened), jax.tree_util.tree_leaves(tree) + jax.tree.leaves(unflattened), jax.tree.leaves(tree) ): assert jnp.all(unflattened_leaf == original_leaf) diff --git a/test/infer/test_gradient.py b/test/infer/test_gradient.py index 2cb74dbe2..dec977909 100644 --- a/test/infer/test_gradient.py +++ b/test/infer/test_gradient.py @@ -22,9 +22,7 @@ def assert_equal(a, b, prec=0): - return jax.tree_util.tree_map( - lambda a, b: np.testing.assert_allclose(a, b, atol=prec), a, b - ) + return jax.tree.map(lambda a, b: np.testing.assert_allclose(a, b, atol=prec), a, b) def model_0(data, params): @@ -107,13 +105,13 @@ def guide_2(data, params): ) def test_gradient(model, guide, params, data): transform = dist.biject_to(dist.constraints.simplex) - params_raw = jax.tree_util.tree_map(transform.inv, params) + params_raw = jax.tree.map(transform.inv, params) # Expected grads based on exact integration elbo = infer.TraceEnum_ELBO() def expected_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) return elbo.loss( random.PRNGKey(0), {}, model, config_enumerate(guide), data, params ) @@ -124,7 +122,7 @@ def expected_loss_fn(params_raw): elbo = infer.TraceGraph_ELBO(num_particles=10_000) def actual_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) return elbo.loss(random.PRNGKey(0), {}, model, guide, data, params) actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw) @@ -336,20 +334,20 @@ def guide(params): "probs_z3": jnp.array([[[0.4, 0.6], [0.5, 0.5]], [[0.7, 0.3], [0.9, 0.1]]]), } transform = dist.biject_to(dist.constraints.simplex) - params_raw = 
jax.tree_util.tree_map(transform.inv, params) + params_raw = jax.tree.map(transform.inv, params) elbo = infer.TraceEnum_ELBO() # Exact integration based on enumeration def expected_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) return elbo.loss(random.PRNGKey(0), {}, model, guide, params) expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw) # Exact integration based on the mix of enumeration and analytic kl def actual_loss_fn(params_raw): - params = jax.tree_util.tree_map(transform, params_raw) + params = jax.tree.map(transform, params_raw) return elbo.loss( random.PRNGKey(0), {}, model, config_kl(guide, kl_sites), params ) diff --git a/test/infer/test_hmc_util.py b/test/infer/test_hmc_util.py index 3b298c08d..f1a81e436 100644 --- a/test/infer/test_hmc_util.py +++ b/test/infer/test_hmc_util.py @@ -9,9 +9,9 @@ from numpy.testing import assert_allclose import pytest +import jax from jax import device_put, disable_jit, grad, jit, random import jax.numpy as jnp -from jax.tree_util import tree_map import numpyro.distributions as dist from numpyro.infer.hmc_util import ( @@ -222,7 +222,7 @@ def get_final_state(model, step_size, num_steps, q_i, p_i): assert_allclose(energy_initial, energy_final, atol=1e-5) logger.info("Test time reversibility:") - p_reverse = tree_map(lambda x: -x, p_f) + p_reverse = jax.tree.map(lambda x: -x, p_f) q_i, p_i = get_final_state(model, args.step_size, args.num_steps, q_f, p_reverse) for node in args.q_i: assert_allclose(q_i[node], args.q_i[node], atol=1e-4) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 480aba80d..6e5d31f4b 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -12,7 +12,6 @@ from jax import device_get, jit, lax, pmap, random, vmap import jax.numpy as jnp from jax.scipy.special import logit -from jax.tree_util import tree_all, tree_map import numpyro import numpyro.distributions as dist @@ -450,8 +449,8 @@ def model(data): mcmc1.run(random.PRNGKey(2), data) with pytest.raises(AssertionError): - tree_all( - tree_map( + jax.tree.all( + jax.tree.map( partial(assert_allclose, atol=1e-4, rtol=1e-4), mcmc1.get_samples(), mcmc.get_samples(), @@ -459,21 +458,21 @@ def model(data): ) mcmc1.warmup(random.PRNGKey(2), data) mcmc1.run(random.PRNGKey(3), data) - tree_all( - tree_map( + jax.tree.all( + jax.tree.map( partial(assert_allclose, atol=1e-4, rtol=1e-4), mcmc1.get_samples(), mcmc.get_samples(), ) ) - tree_all( - tree_map( + jax.tree.all( + jax.tree.map( partial(assert_allclose, atol=1e-4, rtol=1e-4), - tree_map( + jax.tree.map( lambda x: random.key_data(x) if is_prng_key(x) else x, mcmc1.post_warmup_state, ), - tree_map( + jax.tree.map( lambda x: random.key_data(x) if is_prng_key(x) else x, mcmc.post_warmup_state, ), diff --git a/test/infer/test_svi.py b/test/infer/test_svi.py index 166499063..f52c1cef2 100644 --- a/test/infer/test_svi.py +++ b/test/infer/test_svi.py @@ -11,7 +11,6 @@ from jax import jit, lax, random, value_and_grad from jax.example_libraries import optimizers import jax.numpy as jnp -from jax.tree_util import tree_all, tree_map import numpyro from numpyro import optim @@ -31,7 +30,7 @@ def assert_equal(a, b, prec=0): - return jax.tree_util.tree_map(lambda a, b: assert_allclose(a, b, atol=prec), a, b) + return jax.tree.map(lambda a, b: assert_allclose(a, b, atol=prec), a, b) @pytest.mark.parametrize("alpha", [0.0, 2.0]) @@ -272,7 +271,7 @@ def guide(data): expected = 
svi.get_params(svi.update(svi_state, data)[0]) actual = svi.get_params(jit(svi.update)(svi_state, data=data)[0]) - tree_all(tree_map(partial(assert_allclose, atol=1e-5), actual, expected)) + jax.tree.all(jax.tree.map(partial(assert_allclose, atol=1e-5), actual, expected)) def test_param(): diff --git a/test/ops/test_provenance.py b/test/ops/test_provenance.py index a64fcaadc..f40b846c1 100644 --- a/test/ops/test_provenance.py +++ b/test/ops/test_provenance.py @@ -72,19 +72,19 @@ def f(x, y): def test_provenance_call(): def identity(x): - args, in_tree = jax.tree_util.tree_flatten((x,)) + args, in_tree = jax.tree.flatten((x,)) fn, out_tree = flatten_fun_nokwargs(lu.wrap_init(lambda x: x), in_tree) out = core.closed_call_p.bind(fn, *args) - return jax.tree_util.tree_unflatten(out_tree(), out) + return jax.tree.unflatten(out_tree(), out) assert eval_provenance(identity, x={"v": 2}) == {"v": frozenset({"x"})} def test_provenance_closed_call(): def identity(x): - args, in_tree = jax.tree_util.tree_flatten((x,)) + args, in_tree = jax.tree.flatten((x,)) fn, out_tree = flatten_fun_nokwargs(lu.wrap_init(lambda x: x), in_tree) out = core.closed_call_p.bind(fn, *args) - return jax.tree_util.tree_unflatten(out_tree(), out) + return jax.tree.unflatten(out_tree(), out) assert eval_provenance(identity, x={"v": 2}) == {"v": frozenset({"x"})} diff --git a/test/test_constraints.py b/test/test_constraints.py index bfb459bdd..acd96732e 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -5,9 +5,9 @@ import pytest +import jax from jax import jit, vmap import jax.numpy as jnp -from jax.tree_util import tree_map from numpyro.distributions import constraints @@ -134,14 +134,14 @@ def out_cst(constraint, x): if len(cst_args) > 0: # test creating and manipulating vmapped constraints - vmapped_cst_args = tree_map(lambda x: x[None], cst_args) + vmapped_cst_args = jax.tree.map(lambda x: x[None], cst_args) vmapped_csts = jit(vmap(lambda args: cls(*args, **cst_kwargs), in_axes=(0,)))( vmapped_cst_args ) assert vmap(lambda x: x == constraint, in_axes=0)(vmapped_csts).all() - twice_vmapped_cst_args = tree_map(lambda x: x[None], vmapped_cst_args) + twice_vmapped_cst_args = jax.tree.map(lambda x: x[None], vmapped_cst_args) vmapped_csts = jit( vmap( diff --git a/test/test_distributions.py b/test/test_distributions.py index a4dc3528b..7453e3f51 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3189,7 +3189,7 @@ def _allclose_or_equal(a1, a2): def _tree_equal(t1, t2): - t = jax.tree_util.tree_map(_allclose_or_equal, t1, t2) + t = jax.tree.map(_allclose_or_equal, t1, t2) return jnp.all(jax.flatten_util.ravel_pytree(t)[0]) @@ -3216,7 +3216,7 @@ def sample(d: dist.Distribution): # In this case, since csr arrays are not jittable, # _SparseCAR has a csr_matrix as part of its pytree # definition (not as a pytree leaf). This causes pytree - # operations like tree_map to fail, since these functions + # operations like jax.tree.map to fail, since these functions # compare the pytree def of each of the arguments using == # which is ambiguous for array-like objects. 
return @@ -3261,7 +3261,7 @@ def sample(d: dist.Distribution): for in_axes, out_axes in in_out_axes_cases: batched_params = [ ( - jax.tree_map(lambda x: jnp.expand_dims(x, ax), arg) + jax.tree.map(lambda x: jnp.expand_dims(x, ax), arg) if isinstance(ax, int) else arg ) diff --git a/test/test_handlers.py b/test/test_handlers.py index 518f856dc..4ef449237 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -8,7 +8,6 @@ import jax from jax import jit, random, value_and_grad, vmap import jax.numpy as jnp -from jax.tree_util import tree_map try: import funsor @@ -441,7 +440,7 @@ def guide(subsample): svi_state.rng_key, svi.constrain_fn(x), svi.model, svi.guide, subsample ) )(params) - grads = tree_map(lambda *vals: vals[0] + vals[1], grads1, grads2) + grads = jax.tree.map(lambda *vals: vals[0] + vals[1], grads1, grads2) loss = loss1 + loss2 else: subsample = jnp.array([0, 1]) diff --git a/test/test_pickle.py b/test/test_pickle.py index a54479be2..7ed338578 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -7,9 +7,9 @@ from numpy.testing import assert_allclose import pytest +import jax from jax import random import jax.numpy as jnp -from jax.tree_util import tree_all, tree_map import numpyro from numpyro.contrib.funsor import config_kl @@ -90,7 +90,9 @@ def test_pickle_hmc(kernel): mcmc = MCMC(kernel(normal_model), num_warmup=10, num_samples=10) mcmc.run(random.PRNGKey(0)) pickled_mcmc = pickle.loads(pickle.dumps(mcmc)) - tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples())) + jax.tree.all( + jax.tree.map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()) + ) @pytest.mark.parametrize("kernel", [BarkerMH, HMC, NUTS, SA]) @@ -108,7 +110,9 @@ def test_pickle_hmc_enumeration(kernel): mcmc = MCMC(kernel(gmm), num_warmup=10, num_samples=10) mcmc.run(random.PRNGKey(0), data, K) pickled_mcmc = pickle.loads(pickle.dumps(mcmc)) - tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples())) + jax.tree.all( + jax.tree.map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()) + ) @pytest.mark.parametrize("kernel", [DiscreteHMCGibbs, MixedHMC]) @@ -116,14 +120,18 @@ def test_pickle_discrete_hmc(kernel): mcmc = MCMC(kernel(HMC(bernoulli_model)), num_warmup=10, num_samples=10) mcmc.run(random.PRNGKey(0)) pickled_mcmc = pickle.loads(pickle.dumps(mcmc)) - tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples())) + jax.tree.all( + jax.tree.map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()) + ) def test_pickle_hmcecs(): mcmc = MCMC(HMCECS(NUTS(logistic_regression)), num_warmup=10, num_samples=10) mcmc.run(random.PRNGKey(0)) pickled_mcmc = pickle.loads(pickle.dumps(mcmc)) - tree_all(tree_map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples())) + jax.tree.all( + jax.tree.map(assert_allclose, mcmc.get_samples(), pickled_mcmc.get_samples()) + ) def poisson_regression(x, N): @@ -236,4 +244,4 @@ def guide(data): svi_result = svi.run(random.PRNGKey(0), 3, data) pickled_params = svi_result.params - tree_all(tree_map(assert_allclose, params, pickled_params)) + jax.tree.all(jax.tree.map(assert_allclose, params, pickled_params)) diff --git a/test/test_transforms.py b/test/test_transforms.py index 8f68eaf6e..997959244 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -8,9 +8,9 @@ import numpy as np import pytest +import jax from jax import jacfwd, jit, random, vmap import jax.numpy as jnp -from jax.tree_util import tree_map from 
numpyro.distributions import constraints from numpyro.distributions.flows import ( @@ -175,14 +175,14 @@ def out_t(transform, x): # this test assumes jittable args, and non-jittable kwargs, which is # not suited for all transforms, see InverseAutoregressiveTransform. # TODO: split among jittable and non-jittable args/kwargs instead. - vmapped_transform_args = tree_map(lambda x: x[None], transform_args) + vmapped_transform_args = jax.tree.map(lambda x: x[None], transform_args) vmapped_transform = jit( vmap(lambda args: cls(*args, **transform_kwargs), in_axes=(0,)) )(vmapped_transform_args) assert vmap(lambda x: x == transform, in_axes=0)(vmapped_transform).all() - twice_vmapped_transform_args = tree_map( + twice_vmapped_transform_args = jax.tree.map( lambda x: x[None], vmapped_transform_args ) diff --git a/test/test_util.py b/test/test_util.py index 208749e26..f351b45f3 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -5,10 +5,10 @@ from numpy.testing import assert_allclose import pytest +import jax from jax import random from jax.flatten_util import ravel_pytree import jax.numpy as jnp -from jax.tree_util import tree_all, tree_flatten, tree_map import numpyro import numpyro.distributions as dist @@ -44,7 +44,7 @@ def f(x): expected_tree = {"i": np.array([[0.0], [2.0]])} actual_tree = fori_collect(1, 3, f, a, transform=lambda a: {"i": a["i"]}) - tree_all(tree_map(assert_allclose, actual_tree, expected_tree)) + jax.tree.all(jax.tree.map(assert_allclose, actual_tree, expected_tree)) @pytest.mark.parametrize("progbar", [False, True]) @@ -64,8 +64,8 @@ def f(x): ) expected_tree = {"i": np.array([3, 4])} expected_last_state = {"i": np.array(4)} - tree_all(tree_map(assert_allclose, init_state, expected_last_state)) - tree_all(tree_map(assert_allclose, tree, expected_tree)) + jax.tree.all(jax.tree.map(assert_allclose, init_state, expected_last_state)) + jax.tree.all(jax.tree.map(assert_allclose, tree, expected_tree)) @pytest.mark.parametrize( @@ -82,10 +82,10 @@ def f(x): def test_ravel_pytree(pytree): flat, unravel_fn = ravel_pytree(pytree) unravel = unravel_fn(flat) - tree_flatten(tree_map(lambda x, y: assert_allclose(x, y), unravel, pytree)) + jax.tree.flatten(jax.tree.map(lambda x, y: assert_allclose(x, y), unravel, pytree)) assert all( - tree_flatten( - tree_map( + jax.tree.flatten( + jax.tree.map( lambda x, y: jnp.result_type(x) == jnp.result_type(y), unravel, pytree ) )[0]
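For reference (not part of the patch itself): a minimal, self-contained sketch of the jax.tree namespace that the test changes above migrate to. It assumes jax>=0.4.25, where jax.tree aliases the jax.tree_util helpers (tree_map -> jax.tree.map, tree_flatten -> jax.tree.flatten, tree_unflatten -> jax.tree.unflatten, tree_leaves -> jax.tree.leaves, tree_all -> jax.tree.all). The example pytree is made up for illustration only.

# Illustrative sketch, not part of the patch. Assumes jax>=0.4.25.
import jax
import jax.numpy as jnp
from numpy.testing import assert_allclose

# A made-up pytree with three array leaves.
pytree = {"loc": jnp.zeros(3), "scale": (jnp.ones(2), jnp.array(0.5))}

# jax.tree.map replaces jax.tree_util.tree_map / tree_map.
doubled = jax.tree.map(lambda x: 2 * x, pytree)

# jax.tree.flatten / jax.tree.unflatten replace tree_flatten / tree_unflatten.
leaves, treedef = jax.tree.flatten(pytree)
roundtrip = jax.tree.unflatten(treedef, leaves)

# jax.tree.leaves replaces tree_leaves.
assert len(jax.tree.leaves(doubled)) == 3

# The tree_all(tree_map(assert_allclose, ...)) idiom used throughout the
# tests becomes jax.tree.all(jax.tree.map(assert_allclose, ...)); the map
# raises on mismatch, and tree.all simply forces evaluation of the result.
jax.tree.all(jax.tree.map(assert_allclose, roundtrip, pytree))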