Update jax.tree_util.tree_map to jax.tree.map #1821

Merged: 5 commits, Jun 26, 2024

Changes from all commits
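This PR mechanically replaces `jax.tree_util.tree_map`, `tree_flatten`, and `tree_unflatten` with their `jax.tree.map`, `jax.tree.flatten`, and `jax.tree.unflatten` aliases across the examples, notebooks, and library internals, and also drops the deprecated `a_min`/`a_max` keywords of `jnp.clip`. A minimal sketch of the rename, with an illustrative pytree (not taken from the repo):

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}  # hypothetical pytree

# Older spelling, via the jax.tree_util namespace:
doubled_old = jax.tree_util.tree_map(lambda x: 2 * x, params)

# Newer spelling used throughout this PR:
doubled_new = jax.tree.map(lambda x: 2 * x, params)

# The two are aliases, so results match leaf-for-leaf.
assert all(
    bool(jnp.all(a == b))
    for a, b in zip(jax.tree.leaves(doubled_old), jax.tree.leaves(doubled_new))
)
```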
2 changes: 1 addition & 1 deletion examples/annotation.py
@@ -309,7 +309,7 @@ def main(args):
# is stored in `discrete_samples`. To merge those discrete samples into the `mcmc`
# instance, we can use the following pattern::
#
-# chain_discrete_samples = jax.tree_util.tree_map(
+# chain_discrete_samples = jax.tree.map(
# lambda x: x.reshape((args.num_chains, args.num_samples) + x.shape[1:]),
# discrete_samples)
# mcmc.get_samples().update(discrete_samples)
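The comment above documents NumPyro's pattern for folding per-chain structure back into flat discrete samples; a standalone sketch of that reshape, with made-up shapes:

```python
import jax
import jax.numpy as jnp

num_chains, num_samples = 4, 250
# Hypothetical discrete samples, flattened across chains:
discrete_samples = {"annotator_class": jnp.zeros((num_chains * num_samples, 7))}

chain_discrete_samples = jax.tree.map(
    lambda x: x.reshape((num_chains, num_samples) + x.shape[1:]),
    discrete_samples,
)
assert chain_discrete_samples["annotator_class"].shape == (4, 250, 7)
```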
2 changes: 1 addition & 1 deletion examples/gp.py
@@ -116,7 +116,7 @@ def predict(rng_key, X, Y, X_test, var, length, noise, use_cholesky=True):
K = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, Y))

-sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(
+sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), 0.0)) * jax.random.normal(
rng_key, X_test.shape[:1]
)

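The second edit in this file is independent of the pytree rename: newer JAX releases deprecate `jnp.clip`'s `a_min`/`a_max` keywords in favor of positional bounds (the same change appears in the notebook below). A small sketch:

```python
import jax.numpy as jnp

x = jnp.array([-1.0, 0.5, 2.0])

# Old: jnp.clip(x, a_min=0.0)  -- deprecated keyword
# New: positional lower bound, no upper bound:
y = jnp.clip(x, 0.0)  # equivalent to jnp.clip(x, 0.0, None)
assert bool(jnp.all(y >= 0.0))
```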
4 changes: 2 additions & 2 deletions notebooks/source/time_series_forecasting.ipynb
@@ -206,7 +206,7 @@
" level, s, moving_sum = carry\n",
" season = s[0] * level**pow_season\n",
" exp_val = level + coef_trend * level**pow_trend + season\n",
" exp_val = jnp.clip(exp_val, a_min=0)\n",
" exp_val = jnp.clip(exp_val, 0)\n",
" # use expected vale when forecasting\n",
" y_t = jnp.where(t >= N, exp_val, y[t])\n",
"\n",
@@ -215,7 +215,7 @@
" )\n",
" level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)\n",
" level = level_sm * level_p + (1 - level_sm) * level\n",
" level = jnp.clip(level, a_min=0)\n",
" level = jnp.clip(level, 0)\n",
"\n",
" new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]\n",
" # repeat s when forecasting\n",
40 changes: 21 additions & 19 deletions numpyro/contrib/control_flow/scan.py
@@ -4,9 +4,9 @@
from collections import OrderedDict
from functools import partial

+import jax
from jax import device_put, lax, random
import jax.numpy as jnp
-from jax.tree_util import tree_flatten, tree_map, tree_unflatten

from numpyro import handlers
from numpyro.distributions.batch_util import promote_batch_shape
@@ -98,7 +98,7 @@ def postprocess_message(self, msg):
fn_batch_ndim = len(fn.batch_shape)
if fn_batch_ndim < value_batch_ndims:
prepend_shapes = (1,) * (value_batch_ndims - fn_batch_ndim)
msg["fn"] = tree_map(
msg["fn"] = jax.tree.map(
lambda x: jnp.reshape(x, prepend_shapes + jnp.shape(x)), fn
)

@@ -140,11 +140,11 @@ def scan_enum(
history = min(history, length)
unroll_steps = min(2 * history - 1, length)
if reverse:
-x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
-xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
+x0 = jax.tree.map(lambda x: x[-unroll_steps:][::-1], xs)
+xs_ = jax.tree.map(lambda x: x[:-unroll_steps], xs)
else:
-x0 = tree_map(lambda x: x[:unroll_steps], xs)
-xs_ = tree_map(lambda x: x[unroll_steps:], xs)
+x0 = jax.tree.map(lambda x: x[:unroll_steps], xs)
+xs_ = jax.tree.map(lambda x: x[unroll_steps:], xs)

carry_shapes = []

@@ -187,10 +187,12 @@ def body_fn(wrapped_carry, x, prefix=None):

# store shape of new_carry at a global variable
if len(carry_shapes) < (history + 1):
-carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
+carry_shapes.append(
+    [jnp.shape(x) for x in jax.tree.flatten(new_carry)[0]]
+)
# make new_carry have the same shape as carry
# FIXME: is this rigorous?
-new_carry = tree_map(
+new_carry = jax.tree.map(
lambda a, b: jnp.reshape(a, jnp.shape(b)), new_carry, carry
)
return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
@@ -204,27 +206,27 @@ def body_fn(wrapped_carry, x, prefix=None):
for i in markov(range(unroll_steps + 1), history=history):
if i < unroll_steps:
wrapped_carry, (_, y0) = body_fn(
-wrapped_carry, tree_map(lambda z: z[i], x0)
+wrapped_carry, jax.tree.map(lambda z: z[i], x0)
)
if i > 0:
# reshape y1, y2,... to have the same shape as y0
-y0 = tree_map(
+y0 = jax.tree.map(
lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0
)
y0s.append(y0)
# shapes of the first `history - 1` steps are not useful to interpret the last carry
# shape so we don't need to record them here
if (i >= history - 1) and (len(carry_shapes) < history + 1):
carry_shapes.append(
-jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]
+jnp.shape(x) for x in jax.tree.flatten(wrapped_carry[-1])[0]
)
else:
# this is the last rolling step
-y0s = tree_map(lambda *z: jnp.stack(z, axis=0), *y0s)
+y0s = jax.tree.map(lambda *z: jnp.stack(z, axis=0), *y0s)
# return early if length = unroll_steps
if length == unroll_steps:
return wrapped_carry, (PytreeTrace({}), y0s)
-wrapped_carry = tree_map(device_put, wrapped_carry)
+wrapped_carry = jax.tree.map(device_put, wrapped_carry)
wrapped_carry, (pytree_trace, ys) = lax.scan(
body_fn, wrapped_carry, xs_, length - unroll_steps, reverse
)
@@ -251,20 +253,20 @@ def body_fn(wrapped_carry, x, prefix=None):
site["infer"]["dim_to_name"][time_dim] = "_time_{}".format(first_var)

# similar to carry, we need to reshape due to shape alternating in markov
-ys = tree_map(
+ys = jax.tree.map(
lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys
)
# then join with y0s
-ys = tree_map(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
+ys = jax.tree.map(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
# we also need to reshape `carry` to match sequential behavior
i = (length + 1) % (history + 1)
t, rng_key, carry = wrapped_carry
carry_shape = carry_shapes[i]
-flatten_carry, treedef = tree_flatten(carry)
+flatten_carry, treedef = jax.tree.flatten(carry)
flatten_carry = [
jnp.reshape(x, t1_shape) for x, t1_shape in zip(flatten_carry, carry_shape)
]
-carry = tree_unflatten(treedef, flatten_carry)
+carry = jax.tree.unflatten(treedef, flatten_carry)
wrapped_carry = (t, rng_key, carry)
return wrapped_carry, (pytree_trace, ys)

@@ -282,7 +284,7 @@ def scan_wrapper(
first_available_dim=None,
):
if length is None:
-length = jnp.shape(tree_flatten(xs)[0][0])[0]
+length = jnp.shape(jax.tree.flatten(xs)[0][0])[0]

if enum and history > 0:
return scan_enum( # TODO: replay for enum
@@ -324,7 +326,7 @@ def body_fn(wrapped_carry, x):

return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

-wrapped_carry = tree_map(device_put, (0, rng_key, init))
+wrapped_carry = jax.tree.map(device_put, (0, rng_key, init))
last_carry, (pytree_trace, ys) = lax.scan(
body_fn, wrapped_carry, xs, length=length, reverse=reverse
)
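Beyond one-to-one renames, `scan.py` leans on the flatten/unflatten pair and on mapping `device_put` over a carry; a sketch of both idioms under the new namespace, with an invented carry structure:

```python
import jax
import jax.numpy as jnp

carry = (0, {"level": jnp.ones(3), "season": jnp.zeros((2, 3))})  # hypothetical

# Round trip: decompose into leaves + structure, then rebuild.
leaves, treedef = jax.tree.flatten(carry)
carry_rebuilt = jax.tree.unflatten(treedef, leaves)

# Map a function over every leaf, as scan.py does with device_put:
carry_on_device = jax.tree.map(jax.device_put, carry)
```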
6 changes: 3 additions & 3 deletions numpyro/contrib/einstein/mixture_guide_predictive.py
@@ -5,8 +5,8 @@
from functools import partial
from typing import Optional

+import jax
from jax import numpy as jnp, random, vmap
-from jax.tree_util import tree_flatten, tree_map

from numpyro.handlers import substitute
from numpyro.infer import Predictive
@@ -63,7 +63,7 @@ def __init__(

self.guide = guide
self.return_sites = return_sites
-self.num_mixture_components = jnp.shape(tree_flatten(params)[0][0])[0]
+self.num_mixture_components = jnp.shape(jax.tree.flatten(params)[0][0])[0]
self.mixture_assignment_sitename = mixture_assignment_sitename

def _call_with_params(self, rng_key, params, args, kwargs):
@@ -99,7 +99,7 @@ def __call__(self, rng_key, *args, **kwargs):
minval=0,
maxval=self.num_mixture_components,
)
-predictive_assign = tree_map(
+predictive_assign = jax.tree.map(
lambda arr: vmap(lambda i, assign: arr[i, assign])(
jnp.arange(self._batch_shape[0]), assigns
),
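The `num_mixture_components` line reads the particle count off the leading axis of the first leaf; a sketch of that idiom with hypothetical particle parameters:

```python
import jax
import jax.numpy as jnp

# Hypothetical particles: every leaf carries the mixture axis first.
params = {"loc": jnp.zeros((8, 2)), "scale": jnp.ones((8,))}

num_mixture_components = jnp.shape(jax.tree.flatten(params)[0][0])[0]
assert num_mixture_components == 8
```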
10 changes: 5 additions & 5 deletions numpyro/contrib/einstein/stein_util.py
@@ -1,9 +1,9 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

+import jax
from jax import numpy as jnp, vmap
from jax.flatten_util import ravel_pytree
-from jax.tree_util import tree_map

from numpyro.distributions import biject_to
from numpyro.distributions.constraints import real
@@ -64,14 +64,14 @@ def batch_ravel_pytree(pytree, nbatch_dims=0):
flat, unravel_fn = ravel_pytree(pytree)
return flat, unravel_fn, unravel_fn

-shapes = tree_map(lambda x: x.shape, pytree)
-flat_pytree = tree_map(lambda x: x.reshape(*x.shape[:-nbatch_dims], -1), pytree)
+shapes = jax.tree.map(lambda x: x.shape, pytree)
+flat_pytree = jax.tree.map(lambda x: x.reshape(*x.shape[:-nbatch_dims], -1), pytree)
flat = vmap(lambda x: ravel_pytree(x)[0])(flat_pytree)
-unravel_fn = ravel_pytree(tree_map(lambda x: x[0], flat_pytree))[1]
+unravel_fn = ravel_pytree(jax.tree.map(lambda x: x[0], flat_pytree))[1]
return (
flat,
unravel_fn,
-lambda _flat: tree_map(
+lambda _flat: jax.tree.map(
lambda x, shape: x.reshape(shape), vmap(unravel_fn)(_flat), shapes
),
)
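`batch_ravel_pytree` composes `jax.tree.map`, `vmap`, and `ravel_pytree` to ravel each batch element; a condensed sketch of the forward direction only, assuming a single leading batch axis (shapes invented):

```python
import jax
import jax.numpy as jnp
from jax import vmap
from jax.flatten_util import ravel_pytree

pytree = {"a": jnp.ones((8, 3)), "b": jnp.ones((8, 2, 2))}  # batch axis of 8

# Flatten everything but the batch axis, then ravel per batch element:
flat_pytree = jax.tree.map(lambda x: x.reshape(x.shape[0], -1), pytree)
flat = vmap(lambda x: ravel_pytree(x)[0])(flat_pytree)
assert flat.shape == (8, 3 + 4)  # 3 from "a", 2*2 from "b"
```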
10 changes: 5 additions & 5 deletions numpyro/contrib/einstein/steinvi.py
@@ -9,9 +9,9 @@
from itertools import chain
import operator

+import jax
from jax import grad, jacfwd, numpy as jnp, random, vmap
from jax.flatten_util import ravel_pytree
-from jax.tree_util import tree_map

from numpyro import handlers
from numpyro.contrib.einstein.stein_kernels import SteinKernel
@@ -340,10 +340,10 @@ def _update_force(attr_force, rep_force, jac):
return force.reshape(attr_force.shape)

reparam_jac = {
-name: tree_map(lambda var: _nontrivial_jac(name, var), variables)
+name: jax.tree.map(lambda var: _nontrivial_jac(name, var), variables)
for name, variables in unravel_pytree(particle).items()
}
-jac_params = tree_map(
+jac_params = jax.tree.map(
_update_force,
unravel_pytree(attr_forces),
unravel_pytree(rep_forces),
@@ -363,7 +363,7 @@ def _update_force(attr_force, rep_force, jac):
stein_param_grads = unravel_pytree_batched(particle_grads)

# 6. Return loss and gradients (based on parameter forces)
-res_grads = tree_map(
+res_grads = jax.tree.map(
lambda x: -x, {**non_mixture_param_grads, **stein_param_grads}
)
return jnp.linalg.norm(particle_grads), res_grads
@@ -427,7 +427,7 @@ def init(self, rng_key, *args, **kwargs):
if site["name"] in guide_init_params:
pval = guide_init_params[site["name"]]
if self.non_mixture_params_fn(site["name"]):
-pval = tree_map(lambda x: x[0], pval)
+pval = jax.tree.map(lambda x: x[0], pval)
else:
pval = site["value"]
params[site["name"]] = transform.inv(pval)
11 changes: 6 additions & 5 deletions numpyro/contrib/module.py
@@ -5,9 +5,10 @@
from copy import deepcopy
from functools import partial

+import jax
from jax import random
import jax.numpy as jnp
-from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten
+from jax.tree_util import register_pytree_node

import numpyro
import numpyro.distributions as dist
@@ -106,8 +107,8 @@ def flax_module(
assert set(mutable) == set(nn_state)
numpyro_mutable(name + "$state", nn_state)
# make sure that nn_params keep the same order after unflatten
-params_flat, tree_def = tree_flatten(nn_params)
-nn_params = tree_unflatten(tree_def, params_flat)
+params_flat, tree_def = jax.tree.flatten(nn_params)
+nn_params = jax.tree.unflatten(tree_def, params_flat)
numpyro.param(module_key, nn_params)

def apply_with_state(params, *args, **kwargs):
@@ -195,8 +196,8 @@ def haiku_module(name, nn_module, *args, input_shape=None, apply_rng=False, **kw
nn_params = hk.data_structures.to_mutable_dict(nn_params)
# we cast it to a mutable one to be able to set priors for parameters
# make sure that nn_params keep the same order after unflatten
-params_flat, tree_def = tree_flatten(nn_params)
-nn_params = tree_unflatten(tree_def, params_flat)
+params_flat, tree_def = jax.tree.flatten(nn_params)
+nn_params = jax.tree.unflatten(tree_def, params_flat)
numpyro.param(module_key, nn_params)

def apply_with_state(params, *args, **kwargs):
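The flatten/unflatten round trip in both module helpers exists to canonicalize parameter-dict ordering, since pytree flattening sorts dict keys; a sketch:

```python
import jax
import jax.numpy as jnp

nn_params = {"w2": jnp.ones(2), "w1": jnp.zeros(3)}  # insertion order: w2, w1

params_flat, tree_def = jax.tree.flatten(nn_params)
nn_params = jax.tree.unflatten(tree_def, params_flat)

# Dict keys come back in canonical (sorted) pytree order.
assert list(nn_params.keys()) == ["w1", "w2"]
```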
4 changes: 2 additions & 2 deletions numpyro/contrib/tfp/distributions.py
@@ -282,11 +282,11 @@ def is_discrete(self):
return self.support is None

def tree_flatten(self):
-return jax.tree_util.tree_flatten(self.tfp_dist)
+return jax.tree.flatten(self.tfp_dist)

@classmethod
def tree_unflatten(cls, aux_data, params):
-fn = jax.tree_util.tree_unflatten(aux_data, params)
+fn = jax.tree.unflatten(aux_data, params)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=FutureWarning)
return TFPDistribution[fn.__class__](**fn.parameters)
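These two methods delegate the distribution's pytree protocol to the wrapped TFP object; the same delegation pattern on a toy wrapper (the class below is hypothetical, not NumPyro API, and registration via `jax.tree_util.register_pytree_node_class` is omitted):

```python
import jax

class Wrapper:
    """Toy stand-in for TFPDistribution: flattening defers to the inner object."""

    def __init__(self, inner):
        self.inner = inner

    def tree_flatten(self):
        # Returns (leaves, treedef), exactly what JAX expects.
        return jax.tree.flatten(self.inner)

    @classmethod
    def tree_unflatten(cls, treedef, leaves):
        return cls(jax.tree.unflatten(treedef, leaves))
```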
4 changes: 2 additions & 2 deletions numpyro/contrib/tfp/mcmc.py
@@ -5,10 +5,10 @@
from collections import namedtuple
import inspect

+import jax
from jax import random, vmap
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
-from jax.tree_util import tree_map
import tensorflow_probability.substrates.jax as tfp

from numpyro.infer import init_to_uniform
@@ -44,7 +44,7 @@ def log_prob_fn(x):
flatten_result = vmap(lambda a: -potential_fn(unravel_fn(a)))(
jnp.reshape(x, (-1,) + jnp.shape(x)[-1:])
)
-return tree_map(
+return jax.tree.map(
lambda a: jnp.reshape(a, batch_shape + jnp.shape(a)[1:]), flatten_result
)
else:
12 changes: 6 additions & 6 deletions numpyro/diagnostics.py
@@ -10,8 +10,8 @@

import numpy as np

+import jax
from jax import device_get
-from jax.tree_util import tree_flatten, tree_map

__all__ = [
"autocorrelation",
@@ -182,7 +182,7 @@ def effective_sample_size(x):
Rho_k = np.concatenate(
[
Rho_init,
-np.minimum.accumulate(np.clip(Rho_k[1:, ...], a_min=0, a_max=None), axis=0),
+np.minimum.accumulate(np.clip(Rho_k[1:, ...], 0, None), axis=0),
],
axis=0,
)
@@ -238,10 +238,10 @@ def summary(samples, prob=0.90, group_by_chain=True):
chain dimension).
"""
if not group_by_chain:
-samples = tree_map(lambda x: x[None, ...], samples)
+samples = jax.tree.map(lambda x: x[None, ...], samples)
if not isinstance(samples, dict):
samples = {
"Param:{}".format(i): v for i, v in enumerate(tree_flatten(samples)[0])
"Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
}

summary_dict = {}
@@ -288,10 +288,10 @@ def print_summary(samples, prob=0.90, group_by_chain=True):
chain dimension).
"""
if not group_by_chain:
-samples = tree_map(lambda x: x[None, ...], samples)
+samples = jax.tree.map(lambda x: x[None, ...], samples)
if not isinstance(samples, dict):
samples = {
"Param:{}".format(i): v for i, v in enumerate(tree_flatten(samples)[0])
"Param:{}".format(i): v for i, v in enumerate(jax.tree.flatten(samples)[0])
}
summary_dict = summary(samples, prob, group_by_chain=True)

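Both `summary` and `print_summary` normalize their input so downstream code can assume a leading chain axis and a dict of named sites; a sketch of those two normalizations:

```python
import jax
import jax.numpy as jnp

samples = jnp.zeros((1000, 3))  # ungrouped draws from a single site

# Add a singleton chain dimension (the group_by_chain=False branch):
samples = jax.tree.map(lambda x: x[None, ...], samples)

# Wrap non-dict inputs so every leaf gets a name:
if not isinstance(samples, dict):
    samples = {
        "Param:{}".format(i): v
        for i, v in enumerate(jax.tree.flatten(samples)[0])
    }
assert samples["Param:0"].shape == (1, 1000, 3)
```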
4 changes: 2 additions & 2 deletions numpyro/distributions/batch_util.py
@@ -5,8 +5,8 @@
from functools import singledispatch
from typing import Union

+import jax
import jax.numpy as jnp
-from jax.tree_util import tree_map

from numpyro.distributions import constraints
from numpyro.distributions.conjugate import (
@@ -547,7 +547,7 @@ def _promote_batch_shape_expanded(d: ExpandedDistribution):
len(new_shapes_elems),
len(new_shapes_elems) + len(orig_delta_batch_shape),
)
-new_base_dist = tree_map(
+new_base_dist = jax.tree.map(
lambda x: jnp.expand_dims(x, axis=new_axes_locs), new_self.base_dist
)
