Commit
Use jax.tree_util.* functions to avoid deprecation warnings
NTT123 authored Jul 30, 2022
1 parent ec6018f commit 821438b
Showing 30 changed files with 169 additions and 122 deletions.
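For context, recent JAX releases deprecate the top-level aliases jax.tree_map, jax.tree_flatten, jax.tree_unflatten, and jax.tree_leaves in favour of their jax.tree_util counterparts; the diff below applies that rename throughout. A minimal sketch of the pattern, using a made-up pytree purely for illustration:

import jax
import jax.numpy as jnp

# Hypothetical pytree, used only to illustrate the rename.
params = {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)}

# Old spelling (emits DeprecationWarning on recent JAX versions):
#   doubled = jax.tree_map(lambda x: x * 2, params)
# New spelling used throughout this commit:
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)

leaves, treedef = jax.tree_util.tree_flatten(params)      # was jax.tree_flatten
restored = jax.tree_util.tree_unflatten(treedef, leaves)  # was jax.tree_unflatten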
31 changes: 31 additions & 0 deletions NOTE.md
@@ -0,0 +1,31 @@
Things that can be improved:

- [x] Create a core without kind
- [x] Improve how to register parameters/states.
- [x] Compute gradient with respect to trainable parameters.
- [x] How to manage random keys?
- [x] Improve mixed precision and flatten module.
- [x] Improve optimizer API.
- [x] Support mixed attributes.
- [ ] Performance penalty due to tree flatten, unflatten.
- [ ] Improve impure -> pure API.

Current solution:

pax.module_and_value(net)(x, y, z)

--> New solution:

pax.purecall(net, x, y, z)

Another ... approach is to transform a ... module into a pure ... init and apply.

We provide a better, more general solution.

t = pax.module_value(net)(x, y)

net, t = net % (x, y)

out = pax.unsafe(net)(x, y)
net, out = pax.pure(lambda net: net, net(x))
2 changes: 1 addition & 1 deletion docs/notebooks/basics.ipynb
@@ -315,7 +315,7 @@
"outputs": [],
"source": [
"def sgd(params: Linear, gradients: Linear, lr: float = 1e-1):\n",
" updated_params = jax.tree_map(lambda p, g: p - lr * g, params, gradients)\n",
" updated_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, gradients)\n",
" return updated_params"
]
},
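As a side note on the sgd helper above: the same jax.tree_util.tree_map pattern works on any pytree of arrays, not just the notebook's Linear module. A small sketch with a plain dict standing in for the parameters (the dict and its values are made up for illustration):

import jax
import jax.numpy as jnp

def sgd(params, gradients, lr: float = 1e-1):
    # One SGD step applied leaf-wise over matching pytrees.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, gradients)

params = {"weight": jnp.ones((2, 2)), "bias": jnp.zeros(2)}
grads = {"weight": jnp.full((2, 2), 0.5), "bias": jnp.ones(2)}
new_params = sgd(params, grads)  # weight -> 0.95, bias -> -0.1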
30 changes: 15 additions & 15 deletions docs/notebooks/performance.ipynb
@@ -108,8 +108,8 @@
"\n",
"start = time.perf_counter()\n",
"for i in range(10_000):\n",
" a, b = jax.tree_flatten((net, optimizer))\n",
" net, optimizer = jax.tree_unflatten(b, a)\n",
" a, b = jax.tree_util.tree_flatten((net, optimizer))\n",
" net, optimizer = jax.tree_util.tree_unflatten(b, a)\n",
"end = time.perf_counter()\n",
"print(\"Duration:\", end - start)"
]
@@ -156,7 +156,7 @@
"In this mode, the optimizer will automatically flatten the parameters and gradients to a list of leaves instead of dealing with the full tree structure. This reduces the `flatten` and `unflatten` time of the optimizer to almost zero.\n",
"\n",
"However, we are no longer able to access the optimizer's pytree objects. \n",
"Fortunately, we rarely need to access the optimizer's pytree objects, and one can easily convert the flatten list back to the pytree object using `jax.tree_unflatten` function.\n"
"Fortunately, we rarely need to access the optimizer's pytree objects, and one can easily convert the flattened list back to the pytree object using the `jax.tree_util.tree_unflatten` function.\n"
]
},
{
@@ -183,8 +183,8 @@
"\n",
"start = time.perf_counter()\n",
"for i in range(10_000):\n",
" a, b = jax.tree_flatten((net, optimizer))\n",
" net, optimizer = jax.tree_unflatten(b, a)\n",
" a, b = jax.tree_util.tree_flatten((net, optimizer))\n",
" net, optimizer = jax.tree_util.tree_unflatten(b, a)\n",
"end = time.perf_counter()\n",
"print(\"Duration:\", end - start)"
]
@@ -211,8 +211,8 @@
"source": [
"start = time.perf_counter()\n",
"for i in range(10_000):\n",
" a, b = jax.tree_flatten(optimizer)\n",
" optimizer = jax.tree_unflatten(b, a)\n",
" a, b = jax.tree_util.tree_flatten(optimizer)\n",
" optimizer = jax.tree_util.tree_unflatten(b, a)\n",
"end = time.perf_counter()\n",
"print(\"Duration:\", end - start)"
]
@@ -316,11 +316,11 @@
"\n",
"@partial(jax.jit, static_argnums=0)\n",
"def flatten_update(model_def, model_leaves, optimizer, inputs):\n",
" model = jax.tree_unflatten(model_def, model_leaves)\n",
" model = jax.tree_util.tree_unflatten(model_def, model_leaves)\n",
" params = model.parameters()\n",
" grads, (loss, model) = jax.grad(loss_fn, has_aux=True)(params, model, inputs)\n",
" model, optimizer = opax.apply_gradients(model, optimizer, grads=grads)\n",
" return jax.tree_leaves(model), optimizer, loss"
" return jax.tree_util.tree_leaves(model), optimizer, loss"
]
},
{
@@ -331,7 +331,7 @@
},
"outputs": [],
"source": [
"net_leaves, net_def = jax.tree_flatten(net)\n",
"net_leaves, net_def = jax.tree_util.tree_flatten(net)\n",
"# net_leaves, optimizer, loss = flatten_update(net_def, net_leaves, optimizer, (img, label))"
]
},
@@ -357,8 +357,8 @@
"source": [
"start = time.perf_counter()\n",
"for i in range(10_000):\n",
" a, b = jax.tree_flatten((net_leaves, optimizer))\n",
" (net_leaves, optimizer) = jax.tree_unflatten(b, a)\n",
" a, b = jax.tree_util.tree_flatten((net_leaves, optimizer))\n",
" (net_leaves, optimizer) = jax.tree_util.tree_unflatten(b, a)\n",
"end = time.perf_counter()\n",
"print(\"Duration:\", end - start)"
]
@@ -378,7 +378,7 @@
"metadata": {},
"outputs": [],
"source": [
"net = jax.tree_unflatten(net_def, net_leaves)"
"net = jax.tree_util.tree_unflatten(net_def, net_leaves)"
]
},
{
@@ -421,8 +421,8 @@
"source": [
"start = time.perf_counter()\n",
"for i in range(10_000):\n",
" a, b = jax.tree_flatten(flat_mods)\n",
" flat_mods = jax.tree_unflatten(b, a)\n",
" a, b = jax.tree_util.tree_flatten(flat_mods)\n",
" flat_mods = jax.tree_util.tree_unflatten(b, a)\n",
"end = time.perf_counter()\n",
"print(\"Duration:\", end - start)"
]
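On the flatten-mode note above (the flattened optimizer state can always be rebuilt with jax.tree_util.tree_unflatten), here is a minimal round-trip sketch; the state dict is a stand-in, not an actual opax optimizer:

import jax
import jax.numpy as jnp

# Stand-in pytree; opax internals are not reproduced here.
state = {"momentum": {"w": jnp.zeros((4, 4)), "b": jnp.zeros(4)},
         "count": jnp.array(0)}

leaves, treedef = jax.tree_util.tree_flatten(state)       # flat list of arrays + structure
restored = jax.tree_util.tree_unflatten(treedef, leaves)  # identical pytree rebuilt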
18 changes: 9 additions & 9 deletions docs/notebooks/understanding.ipynb
@@ -64,7 +64,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"JAX provides the `jax.tree_flatten` function that transforms an object into its tree representation that includes:\n",
"JAX provides the `jax.tree_util.tree_flatten` function that transforms an object into its tree representation that includes:\n",
"\n",
"- `leaves`: a list of tree leaves.\n",
"- `treedef`: information about the structure of the tree."
@@ -85,7 +85,7 @@
}
],
"source": [
"leaves, treedef = jax.tree_flatten(e)\n",
"leaves, treedef = jax.tree_util.tree_flatten(e)\n",
"print(\"Leaves:\", leaves)\n",
"print(\"TreeDef:\", treedef)"
]
@@ -96,7 +96,7 @@
"source": [
".. note:: Even though a pytree can have any object at its leaves, many jax functions such as ``jax.jit``, ``jax.lax.scan``, ``jax.grad``, etc. only support pytrees with `ndarray` leaves.\n",
"\n",
"We can reverse ``jax.tree_flatten`` transformation with ``jax.tree_unflatten``:"
"We can reverse ``jax.tree_util.tree_flatten`` transformation with ``jax.tree_util.tree_unflatten``:"
]
},
{
@@ -116,7 +116,7 @@
}
],
"source": [
"jax.tree_unflatten(treedef=treedef, leaves=leaves)"
"jax.tree_util.tree_unflatten(treedef=treedef, leaves=leaves)"
]
},
{
@@ -201,9 +201,9 @@
"source": [
"mod = ModuleV0([1, 2, 3])\n",
"print(mod)\n",
"leaves, tree_def = jax.tree_flatten(mod)\n",
"leaves, tree_def = jax.tree_util.tree_flatten(mod)\n",
"print(leaves, tree_def)\n",
"new_mod = jax.tree_unflatten(tree_def, leaves)\n",
"new_mod = jax.tree_util.tree_unflatten(tree_def, leaves)\n",
"new_mod"
]
},
@@ -331,9 +331,9 @@
"print(counter)\n",
"counter.step()\n",
"print(counter)\n",
"leaves, treedef = jax.tree_flatten(counter)\n",
"leaves, treedef = jax.tree_util.tree_flatten(counter)\n",
"print((leaves, treedef))\n",
"new_counter = jax.tree_unflatten(treedef, leaves)\n",
"new_counter = jax.tree_util.tree_unflatten(treedef, leaves)\n",
"print(new_counter)"
]
},
@@ -486,7 +486,7 @@
" def find_and_register_subtree(self):\n",
" for name, value in self.__dict__.items():\n",
" is_pytree = lambda x: isinstance(x, (np.ndarray, jnp.ndarray, ModuleV3))\n",
" leaves, _ = jax.tree_flatten(value, is_leaf=is_pytree)\n",
" leaves, _ = jax.tree_util.tree_flatten(value, is_leaf=is_pytree)\n",
" if any(map(is_pytree, leaves)):\n",
" self.register_subtree(name, value)"
]
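The understanding.ipynb cells above build custom modules that JAX can flatten and unflatten. For reference, a minimal standalone sketch of the same idea using JAX's public decorator jax.tree_util.register_pytree_node_class; this is illustrative only and is not PAX's actual Module implementation:

import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
class TinyModule:
    def __init__(self, weight, name="tiny"):
        self.weight = weight  # array leaf
        self.name = name      # static metadata, stored in the treedef

    def tree_flatten(self):
        # children (dynamic leaves), auxiliary static data
        return (self.weight,), self.name

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(children[0], name=aux)

mod = TinyModule(jnp.ones((2, 2)))
leaves, treedef = jax.tree_util.tree_flatten(mod)
mod2 = jax.tree_util.tree_unflatten(treedef, leaves)  # rebuilt TinyModule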
2 changes: 1 addition & 1 deletion examples/char_rnn.py
@@ -182,7 +182,7 @@ def detokenize(tokens):
loss = loss_accum[0] / loss_accum[1]
loss_accum = 0.0, 0
# eval on a single device
eval_net = jax.tree_map(lambda x: x[0], net.eval())
eval_net = jax.tree_util.tree_map(lambda x: x[0], net.eval())
out = eval_net.inference(
prompt=tokenize(test_prompt),
length=(100 if step < num_steps else 1000),
2 changes: 1 addition & 1 deletion examples/denoising_diffusion/train.py
@@ -77,7 +77,7 @@ def loss_fn(model, inputs):
total_loss = 0.0
tr = tqdm(dataloader)
for step, batch in enumerate(tr, 1):
batch = jax.tree_map(lambda x: x.numpy(), batch)
batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
diffusion, optimizer, loss = fast_update_fn(diffusion, optimizer, batch)
total_loss = total_loss + loss

4 changes: 2 additions & 2 deletions examples/mnist.py
@@ -117,15 +117,15 @@ def train(

# training
for batch in tqdm(train_data, desc="train", leave=False):
batch = jax.tree_map(lambda x: x.numpy(), batch)
batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
net, optimizer, loss = update_fn(net, optimizer, batch)
losses = losses + loss
loss = losses / len(train_data)

# testing
test_losses = 0.0
for batch in tqdm(test_data, desc="test", leave=False):
batch = jax.tree_map(lambda x: x.numpy(), batch)
batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
test_losses = test_losses + test_loss_fn(net, batch)
test_loss = test_losses / len(test_data)

4 changes: 2 additions & 2 deletions examples/mnist_mixed_precision.py
@@ -130,7 +130,7 @@ def train(batch_size=32, num_epochs=5, learning_rate=1e-4, weight_decay=1e-4):
for epoch in range(0, num_epochs):
losses = 0.0
for batch in tqdm(train_data, desc="train", leave=False):
batch = jax.tree_map(lambda x: x.numpy(), batch)
batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
net, optimizer, loss_scale, loss = update_fn(
net, optimizer, loss_scale, batch
)
@@ -139,7 +139,7 @@ def train(batch_size=32, num_epochs=5, learning_rate=1e-4, weight_decay=1e-4):

test_losses = 0.0
for batch in tqdm(test_data, desc="eval", leave=False):
batch = jax.tree_map(lambda x: x.numpy(), batch)
batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
test_losses = test_losses + test_loss_fn(net, batch)
test_loss = test_losses / len(test_data)

6 changes: 4 additions & 2 deletions examples/notebooks/DCGAN.ipynb
@@ -648,15 +648,17 @@
" (netG, netD, optG, optD, rng_key), data\n",
" )\n",
"\n",
" accum_train_record = jax.tree_map(\n",
" accum_train_record = jax.tree_util.tree_map(\n",
" lambda x, y: x + y, accum_train_record, train_record\n",
" )\n",
"\n",
" D_losses.append(train_record.errD)\n",
" G_losses.append(train_record.errG)\n",
"\n",
" if i % log_freq == 0:\n",
" avg: TrainRecord = jax.tree_map(lambda x: x / log_freq, accum_train_record)\n",
" avg: TrainRecord = jax.tree_util.tree_map(\n",
" lambda x: x / log_freq, accum_train_record\n",
" )\n",
" tr.write(\n",
" \"[Step {:>4}] errD {:.3f} errG {:.3f} D_real {:.3f} D_fake {:.3f} D_G {:.3f}\".format(\n",
" i, avg.errD, avg.errG, avg.D_real, avg.D_fake, avg.D_G\n",
8 changes: 4 additions & 4 deletions examples/notebooks/VAE.ipynb
@@ -242,12 +242,12 @@
" for epoch in range(50):\n",
" losses = LossInfo(0.0, 0.0, 0.0)\n",
" for batch in train_data:\n",
" batch = jax.tree_map(lambda x: x.numpy(), batch)\n",
" batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)\n",
" vae, optimizer, loss_info = fast_update_fn(vae, optimizer, batch)\n",
" training_losses.append(loss_info.loss)\n",
" losses = jax.tree_map(lambda x, y: x + y, losses, loss_info)\n",
" losses = jax.tree_util.tree_map(lambda x, y: x + y, losses, loss_info)\n",
"\n",
" losses = jax.tree_map(lambda x: x / len(train_data), losses)\n",
" losses = jax.tree_util.tree_map(lambda x: x / len(train_data), losses)\n",
" print(\n",
" f\"[Epoch {epoch:>2}] train loss {losses.loss:.3f} reconstruction loss {losses.reconstruction_loss:.3f} kl_loss {losses.kl_loss:.3f}\"\n",
" )\n",
@@ -351,7 +351,7 @@
" fast_encoder = jax.jit(lambda model, data: model.encoder(data))\n",
"\n",
" for batch in data:\n",
" batch = jax.tree_map(lambda x: x.numpy(), batch)\n",
" batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)\n",
" data, label = batch\n",
" z = fast_encoder(vae, data)\n",
" z_mean, _ = jnp.split(z, 2, axis=-1)\n",
8 changes: 4 additions & 4 deletions examples/notebooks/fine_tuning_resnet18.ipynb
@@ -220,7 +220,7 @@
"def test_accuracy(model, test_data):\n",
" num_correct_predictions, total = 0, 0\n",
" for batch in test_data:\n",
" batch = jax.tree_map(lambda x: x.numpy(), batch)\n",
" batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)\n",
" predicted_label = predict(model, batch[\"image\"])\n",
" num_correct_predictions = num_correct_predictions + jnp.sum(\n",
" predicted_label == batch[\"label\"]\n",
@@ -259,7 +259,7 @@
" losses = 0.0\n",
"\n",
" for step, batch in enumerate(tqdm(train_data), 1):\n",
" batch = jax.tree_map(lambda x: x.numpy(), batch)\n",
" batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)\n",
" resnet18, opt, loss = fast_update_fn(resnet18, opt, batch)\n",
" losses = losses + loss\n",
"\n",
@@ -269,7 +269,7 @@
"\n",
" total_test_loss = 0.0\n",
" for batch in test_data:\n",
" batch = jax.tree_map(lambda x: x.numpy(), batch)\n",
" batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)\n",
" loss = fast_test_loss_fn(resnet18, batch)\n",
" total_test_loss = total_test_loss + loss\n",
" test_loss = total_test_loss / len(test_data)\n",
@@ -299,7 +299,7 @@
"source": [
"def plot_model_prediction(model):\n",
" test_batch = next(iter(test_data))\n",
" test_batch = jax.tree_map(lambda x: x.numpy(), test_batch)\n",
" test_batch = jax.tree_util.tree_map(lambda x: x.numpy(), test_batch)\n",
" test_image, test_label = test_batch[\"image\"], test_batch[\"label\"]\n",
" test_image = jnp.transpose(test_image, axes=(0, 3, 1, 2))\n",
"\n",
6 changes: 3 additions & 3 deletions examples/notebooks/mixed_precision.ipynb
@@ -481,7 +481,7 @@
"total_loss = jnp.array(0.0, dtype=jnp.float32)\n",
"\n",
"for step, batch in tr:\n",
" batch = jax.tree_map(lambda x: x.numpy(), batch)\n",
" batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)\n",
" total_loss, unet, optimizer, loss_scale = update_fn(\n",
" unet, optimizer, loss_scale, batch, total_loss\n",
" )\n",
@@ -491,7 +491,7 @@
"\n",
" test_losses = 0.0\n",
" for test_batch in tqdm(test_batches, desc=\"evaluating\", leave=False):\n",
" test_batch = jax.tree_map(lambda x: x.numpy(), test_batch)\n",
" test_batch = jax.tree_util.tree_map(lambda x: x.numpy(), test_batch)\n",
" test_losses = test_losses + test_loss_fn(unet, loss_scale, test_batch)\n",
" test_loss = test_losses / len(test_batches)\n",
"\n",
@@ -554,7 +554,7 @@
},
"outputs": [],
"source": [
"test_batch = jax.tree_map(lambda x: x.numpy(), next(test_iter))\n",
"test_batch = jax.tree_util.tree_map(lambda x: x.numpy(), next(test_iter))\n",
"logits = unet.eval()(jnp.array(test_batch[0]))\n",
"label = jnp.argmax(logits, axis=-1)\n",
"for i in range(3):\n",
8 changes: 6 additions & 2 deletions examples/transformer/data.py
@@ -14,10 +14,14 @@ def detokenize(tokens):


def _device_put_sharded(sharded_tree, devices):
leaves, treedef = jax.tree_flatten(sharded_tree)
leaves, treedef = jax.tree_util.tree_flatten(sharded_tree)
n = leaves[0].shape[0]
return jax.device_put_sharded(
[jax.tree_unflatten(treedef, [l[i] for l in leaves]) for i in range(n)], devices
[
jax.tree_util.tree_unflatten(treedef, [l[i] for l in leaves])
for i in range(n)
],
devices,
)


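A hedged usage sketch for the _device_put_sharded helper above: it expects every leaf to carry a leading device axis, so a batch pre-split across len(devices) shards can be placed with one call (the batch shapes below are invented for illustration):

import jax
import numpy as np

devices = jax.local_devices()
n = len(devices)

# Hypothetical pre-sharded batch: leading axis indexes devices.
batch = {
    "tokens": np.zeros((n, 8, 128), dtype=np.int32),
    "labels": np.zeros((n, 8, 128), dtype=np.int32),
}
sharded_batch = _device_put_sharded(batch, devices)  # one shard per device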
2 changes: 1 addition & 1 deletion examples/transformer/train.py
@@ -112,7 +112,7 @@ def train():
loss = jnp.mean(total_losses) / (1000 if step > 0 else steps_per_update)
total_losses = jnp.zeros_like(total_losses)
# eval on a single device
eval_net = jax.tree_map(lambda x: x[0], net.eval())
eval_net = jax.tree_util.tree_map(lambda x: x[0], net.eval())
out = eval_net.inference(
prompt=tokenize(test_prompt),
length=(128 if step < num_steps else 1024),