diff --git a/docs/JAX_visualizing_models_metrics.ipynb b/docs/JAX_visualizing_models_metrics.ipynb
new file mode 100644
index 0000000..b28c603
--- /dev/null
+++ b/docs/JAX_visualizing_models_metrics.ipynb
@@ -0,0 +1,472 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# JAX and Tensorboard / NNX Display"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To keep things straightforward and familiar, we reuse the model and data from '[Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html)' - if you haven't read that yet and want the primer, start there before returning.\n",
+ "\n",
+ "All of the modeling and training code is the same here. What we have added are the tensorboard connections and the discussion around them."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "import io\n",
+ "from datetime import datetime"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "id": "hKhPLnNxfOHU",
+ "outputId": "ac3508f0-ccc6-409b-c719-99a4b8f94bd6"
+ },
+ "outputs": [],
+ "source": [
+ "from sklearn.datasets import load_digits\n",
+ "digits = load_digits()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Here we set the location of the tensorflow writer - the organization is somewhat arbitrary, though keeping a folder for each training run can make later navigation more straightforward."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "file_path = \"runs/test/\" + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
+ "test_summary_writer = tf.summary.create_file_writer(file_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Pulled from the official tensorboard examples, this convert function makes it simple to drop matplotlib figures directly into tensorboard"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def plot_to_image(figure):\n",
+ " \"\"\"Sourced from https://www.tensorflow.org/tensorboard/image_summaries\n",
+ " Converts the matplotlib plot specified by 'figure' to a PNG image and\n",
+ " returns it. The supplied figure is closed and inaccessible after this call.\"\"\"\n",
+ " # Save the plot to a PNG in memory.\n",
+ " buf = io.BytesIO()\n",
+ " plt.savefig(buf, format='png')\n",
+ " # Closing the figure prevents it from being displayed directly inside\n",
+ " # the notebook.\n",
+ " plt.close(figure)\n",
+ " buf.seek(0)\n",
+ " # Convert PNG buffer to TF image\n",
+ " image = tf.image.decode_png(buf.getvalue(), channels=4)\n",
+ " # Add the batch dimension\n",
+ " image = tf.expand_dims(image, 0)\n",
+ " return image"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Whereas previously the example displays the training data snapshot in the notebook, here we stash it in the tensorboard images. If a given training is to be repeated many, many times it can save space to stash the training data information as its own run and skip this step for each subsequent training, provided the input is static. Note that this pattern uses the writer in a `with` context manager. We are able to step into and out of this type of context through the run without losing the same file/folder experiment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "id": "Y8cMntSdfyyT",
+ "outputId": "9343a558-cd8c-473c-c109-aa8015c7ae7e"
+ },
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "fig, axes = plt.subplots(10, 10, figsize=(6, 6),\n",
+ " subplot_kw={'xticks':[], 'yticks':[]},\n",
+ " gridspec_kw=dict(hspace=0.1, wspace=0.1))\n",
+ "\n",
+ "for i, ax in enumerate(axes.flat):\n",
+ " ax.imshow(digits.images[i], cmap='binary', interpolation='gaussian')\n",
+ " ax.text(0.05, 0.05, str(digits.target[i]), transform=ax.transAxes, color='green')\n",
+ "with test_summary_writer.as_default():\n",
+ " tf.summary.image(\"Training Data\", plot_to_image(fig), step=0)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "After running all above and launching `tensorboard --logdir runs/test` from the same folder, you should see the following in the supplied URL:\n",
+ "\n",
+ "![image.png](./_static/training_data_example.png)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "6jrYisoPh6TL"
+ },
+ "outputs": [],
+ "source": [
+ "from sklearn.model_selection import train_test_split\n",
+ "splits = train_test_split(digits.images, digits.target, random_state=0)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "id": "oMRcwKd4hqOo",
+ "outputId": "0ad36290-397b-431d-eba2-ef114daf5ea6"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "images_train.shape=(1347, 8, 8) label_train.shape=(1347,)\n",
+ "images_test.shape=(450, 8, 8) label_test.shape=(450,)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import jax.numpy as jnp\n",
+ "images_train, images_test, label_train, label_test = map(jnp.asarray, splits)\n",
+ "print(f\"{images_train.shape=} {label_train.shape=}\")\n",
+ "print(f\"{images_test.shape=} {label_test.shape=}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "id": "U77VMQwRjTfH",
+ "outputId": "345fed7a-4455-4036-85ed-57e673a4de01"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from flax import nnx\n",
+ "\n",
+ "class SimpleNN(nnx.Module):\n",
+ "\n",
+ " def __init__(self, n_features: int = 64, n_hidden: int = 100, n_targets: int = 10,\n",
+ " *, rngs: nnx.Rngs):\n",
+ " self.n_features = n_features\n",
+ " self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)\n",
+ " self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)\n",
+ " self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " x = x.reshape(x.shape[0], self.n_features) # Flatten images.\n",
+ " x = nnx.selu(self.layer1(x))\n",
+ " x = nnx.selu(self.layer2(x))\n",
+ " x = self.layer3(x)\n",
+ " return x\n",
+ "\n",
+ "model = SimpleNN(rngs=nnx.Rngs(0))\n",
+ "\n",
+ "nnx.display(model) # Interactive display if penzai is installed."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We've now created the basic model - the above cell will render an interactive view of the model. Which, when fully expanded, should look something like this:\n",
+ "\n",
+ "![image.png](./_static/nnx_display_example.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In order to track loss across our training run, we've collected the loss function call inside the training step:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "id": "QwRvFPkYl5b2"
+ },
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import optax\n",
+ "\n",
+ "optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))\n",
+ "\n",
+ "def loss_fun(\n",
+ " model: nnx.Module,\n",
+ " data: jax.Array,\n",
+ " labels: jax.Array):\n",
+ " logits = model(data)\n",
+ " loss = optax.softmax_cross_entropy_with_integer_labels(\n",
+ " logits=logits, labels=labels\n",
+ " ).mean()\n",
+ " return loss, logits\n",
+ "\n",
+ "@nnx.jit # JIT-compile the function\n",
+ "def train_step(\n",
+ " model: nnx.Module,\n",
+ " optimizer: nnx.Optimizer,\n",
+ " data: jax.Array,\n",
+ " labels: jax.Array):\n",
+ " loss_gradient = nnx.grad(loss_fun, has_aux=True) # gradient transform!\n",
+ " grads, logits = loss_gradient(model, data, labels)\n",
+ " optimizer.update(grads) # inplace update\n",
+ "\n",
+ " # Calculate loss\n",
+ " loss, _ = loss_fun(model, images_test, label_test)\n",
+ " return loss"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we've collected the metrics that were previously computed once at the end of training and called them throughout the `for` loop, as you would in an eval stage.\n",
+ "\n",
+ "With the summary_writer context in place, we write out the `Loss` scalar every epoch, test the model accuracy every 10, and stash a accuracy test sheet every 500. Any custom metric can be added this way, through the tf.summary API."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "id": "l9mukT0eqmsr",
+ "outputId": "c6c7b2d6-8706-4bc3-d5a6-0396d7cfbf56"
+ },
+ "outputs": [],
+ "source": [
+ "max_epoch = 3000\n",
+ "with test_summary_writer.as_default():\n",
+ " for i in range(max_epoch):\n",
+ " loss = train_step(model, optimizer, images_train, label_train)\n",
+ " ## Store the training loss per epoch\n",
+ " tf.summary.scalar('Loss', loss.item(), step=i+1) #.item() because the loss coming out of train_step() is a tensor\n",
+ " if ((i+1)%10 == 0) or i == 0:\n",
+ " label_pred = model(images_test).argmax(axis=1)\n",
+ " num_matches = jnp.count_nonzero(label_pred == label_test)\n",
+ " num_total = len(label_test)\n",
+ " accuracy = num_matches / num_total\n",
+ " ## store the evaluated Accuracy every 10 epochs\n",
+ " tf.summary.scalar('Accuracy', accuracy.item(), step=i+1)\n",
+ " if ((i+1)%500 == 0) or i == 0:\n",
+ " fig, axes = None, None\n",
+ " fig, axes = plt.subplots(10, 10, figsize=(6, 6),\n",
+ " subplot_kw={'xticks':[], 'yticks':[]},\n",
+ " gridspec_kw=dict(hspace=0.1, wspace=0.1))\n",
+ "\n",
+ " label_pred = model(images_test).argmax(axis=1)\n",
+ "\n",
+ " for j, ax in enumerate(axes.flat):\n",
+ " ax.imshow(images_test[j], cmap='binary', interpolation='gaussian')\n",
+ " color = 'green' if label_pred[j] == label_test[j] else 'red'\n",
+ " ax.text(0.05, 0.05, str(label_pred[j]), transform=ax.transAxes, color=color)\n",
+ " ## store the Accuracy test sheet every 500 epochs - be sure to give each a different name, or they will overwrite the previous output.\n",
+ " tf.summary.image(f\"Step {i+1} Accuracy Testsheet\", plot_to_image(fig), step=i+1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "During the training has run, and after, the added `Loss` and `Accuracy` scalars are available in the tensorboard UI under the run folder we've dynamically created by the datetime.\n",
+ "\n",
+ "The output there should look something like the following:\n",
+ "\n",
+ "![image.png](./_static/loss_acc_example.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Since we've stored the example test sheet every 500 epochs, it's easy to go back and step through the progress. With each training step using all of the training data the steps and epochs are essentially the same here.\n",
+ "\n",
+ "At step 1, we see poor accuracy, as you would expect\n",
+ "\n",
+ "![image.png](./_static/testsheet_start_example.png)\n",
+ "\n",
+ "By 500, the model is essentially done, but we see the bottom row `7` get lost and recovered at higher epochs as we go far into an overfitting regime. This kind of stored data can be very useful when the training routines become automated and a human is potentially only looking when something has gone wrong.\n",
+ "\n",
+ "![image.png](./_static/testsheets_500_3000.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Finally, it can be useful to use nnx.display's ability to visualize networks and model output. Here we feed the top 35 test images into the model and display the final output vector for each - in the top plot, each row is an individual image prediction result: each column corresponds to a class, in this case the digits (0-9). Since we're calling the highest value in a given row the class prediction (`.argmax(axis=1)`), the final image predictions (bottom plot) simply match the largest value in each row in the upper plot."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "(None, None)"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "nnx.display(model(images_test[:35])), nnx.display(model(images_test[:35]).argmax(axis=1))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The above cell output will give you an interactive plot that looks like this image below, where here we've 'clicked' in the bottom plot for entry `7` and hover over the corresponding value in the top plot.\n",
+ "\n",
+ "![image.png](./_static/model_display_example.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Extra Resources\n",
+ "\n",
+ "For further information about `TensorBoard` see [https://www.tensorflow.org/tensorboard/get_started](https://www.tensorflow.org/tensorboard/get_started)\n",
+ "\n",
+ "For more about `nnx.display()`, which calls Treescope under the hood, see [https://treescope.readthedocs.io/en/stable/](https://treescope.readthedocs.io/en/stable/)"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "jupytext": {
+ "formats": "ipynb,md:myst"
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/JAX_visualizing_models_metrics.md b/docs/JAX_visualizing_models_metrics.md
new file mode 100644
index 0000000..2033d14
--- /dev/null
+++ b/docs/JAX_visualizing_models_metrics.md
@@ -0,0 +1,246 @@
+---
+jupytext:
+ formats: ipynb,md:myst
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.15.2
+kernelspec:
+ display_name: Python 3 (ipykernel)
+ language: python
+ name: python3
+---
+
+# JAX and Tensorboard / NNX Display
+
++++
+
+To keep things straightforward and familiar, we reuse the model and data from '[Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html)' - if you haven't read that yet and want the primer, start there before returning.
+
+All of the modeling and training code is the same here. What we have added are the tensorboard connections and the discussion around them.
+
+```{code-cell} ipython3
+import tensorflow as tf
+import io
+from datetime import datetime
+```
+
+```{code-cell} ipython3
+:id: hKhPLnNxfOHU
+:outputId: ac3508f0-ccc6-409b-c719-99a4b8f94bd6
+
+from sklearn.datasets import load_digits
+digits = load_digits()
+```
+
+Here we set the location of the tensorflow writer - the organization is somewhat arbitrary, though keeping a folder for each training run can make later navigation more straightforward.
+
+```{code-cell} ipython3
+file_path = "runs/test/" + datetime.now().strftime("%Y%m%d-%H%M%S")
+test_summary_writer = tf.summary.create_file_writer(file_path)
+```
+
+Pulled from the official tensorboard examples, this convert function makes it simple to drop matplotlib figures directly into tensorboard
+
+```{code-cell} ipython3
+def plot_to_image(figure):
+ """Sourced from https://www.tensorflow.org/tensorboard/image_summaries
+ Converts the matplotlib plot specified by 'figure' to a PNG image and
+ returns it. The supplied figure is closed and inaccessible after this call."""
+ # Save the plot to a PNG in memory.
+ buf = io.BytesIO()
+ plt.savefig(buf, format='png')
+ # Closing the figure prevents it from being displayed directly inside
+ # the notebook.
+ plt.close(figure)
+ buf.seek(0)
+ # Convert PNG buffer to TF image
+ image = tf.image.decode_png(buf.getvalue(), channels=4)
+ # Add the batch dimension
+ image = tf.expand_dims(image, 0)
+ return image
+```
+
+Whereas previously the example displays the training data snapshot in the notebook, here we stash it in the tensorboard images. If a given training is to be repeated many, many times it can save space to stash the training data information as its own run and skip this step for each subsequent training, provided the input is static. Note that this pattern uses the writer in a `with` context manager. We are able to step into and out of this type of context through the run without losing the same file/folder experiment.
+
+```{code-cell} ipython3
+:id: Y8cMntSdfyyT
+:outputId: 9343a558-cd8c-473c-c109-aa8015c7ae7e
+
+import matplotlib.pyplot as plt
+
+fig, axes = plt.subplots(10, 10, figsize=(6, 6),
+ subplot_kw={'xticks':[], 'yticks':[]},
+ gridspec_kw=dict(hspace=0.1, wspace=0.1))
+
+for i, ax in enumerate(axes.flat):
+ ax.imshow(digits.images[i], cmap='binary', interpolation='gaussian')
+ ax.text(0.05, 0.05, str(digits.target[i]), transform=ax.transAxes, color='green')
+with test_summary_writer.as_default():
+ tf.summary.image("Training Data", plot_to_image(fig), step=0)
+```
+
+After running all above and launching `tensorboard --logdir runs/test` from the same folder, you should see the following in the supplied URL:
+
+![image.png](./_static/training_data_example.png)
+
+```{code-cell} ipython3
+:id: 6jrYisoPh6TL
+
+from sklearn.model_selection import train_test_split
+splits = train_test_split(digits.images, digits.target, random_state=0)
+```
+
+```{code-cell} ipython3
+:id: oMRcwKd4hqOo
+:outputId: 0ad36290-397b-431d-eba2-ef114daf5ea6
+
+import jax.numpy as jnp
+images_train, images_test, label_train, label_test = map(jnp.asarray, splits)
+print(f"{images_train.shape=} {label_train.shape=}")
+print(f"{images_test.shape=} {label_test.shape=}")
+```
+
+```{code-cell} ipython3
+:id: U77VMQwRjTfH
+:outputId: 345fed7a-4455-4036-85ed-57e673a4de01
+
+from flax import nnx
+
+class SimpleNN(nnx.Module):
+
+ def __init__(self, n_features: int = 64, n_hidden: int = 100, n_targets: int = 10,
+ *, rngs: nnx.Rngs):
+ self.n_features = n_features
+ self.layer1 = nnx.Linear(n_features, n_hidden, rngs=rngs)
+ self.layer2 = nnx.Linear(n_hidden, n_hidden, rngs=rngs)
+ self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)
+
+ def __call__(self, x):
+ x = x.reshape(x.shape[0], self.n_features) # Flatten images.
+ x = nnx.selu(self.layer1(x))
+ x = nnx.selu(self.layer2(x))
+ x = self.layer3(x)
+ return x
+
+model = SimpleNN(rngs=nnx.Rngs(0))
+
+nnx.display(model) # Interactive display if penzai is installed.
+```
+
+We've now created the basic model - the above cell will render an interactive view of the model. Which, when fully expanded, should look something like this:
+
+![image.png](./_static/nnx_display_example.png)
+
++++
+
+In order to track loss across our training run, we've collected the loss function call inside the training step:
+
+```{code-cell} ipython3
+:id: QwRvFPkYl5b2
+
+import jax
+import optax
+
+optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))
+
+def loss_fun(
+ model: nnx.Module,
+ data: jax.Array,
+ labels: jax.Array):
+ logits = model(data)
+ loss = optax.softmax_cross_entropy_with_integer_labels(
+ logits=logits, labels=labels
+ ).mean()
+ return loss, logits
+
+@nnx.jit # JIT-compile the function
+def train_step(
+ model: nnx.Module,
+ optimizer: nnx.Optimizer,
+ data: jax.Array,
+ labels: jax.Array):
+ loss_gradient = nnx.grad(loss_fun, has_aux=True) # gradient transform!
+ grads, logits = loss_gradient(model, data, labels)
+ optimizer.update(grads) # inplace update
+
+ # Calculate loss
+ loss, _ = loss_fun(model, images_test, label_test)
+ return loss
+```
+
+Now, we've collected the metrics that were previously computed once at the end of training and called them throughout the `for` loop, as you would in an eval stage.
+
+With the summary_writer context in place, we write out the `Loss` scalar every epoch, test the model accuracy every 10, and stash a accuracy test sheet every 500. Any custom metric can be added this way, through the tf.summary API.
+
+```{code-cell} ipython3
+:id: l9mukT0eqmsr
+:outputId: c6c7b2d6-8706-4bc3-d5a6-0396d7cfbf56
+
+max_epoch = 3000
+with test_summary_writer.as_default():
+ for i in range(max_epoch):
+ loss = train_step(model, optimizer, images_train, label_train)
+ ## Store the training loss per epoch
+ tf.summary.scalar('Loss', loss.item(), step=i+1) #.item() because the loss coming out of train_step() is a tensor
+ if ((i+1)%10 == 0) or i == 0:
+ label_pred = model(images_test).argmax(axis=1)
+ num_matches = jnp.count_nonzero(label_pred == label_test)
+ num_total = len(label_test)
+ accuracy = num_matches / num_total
+ ## store the evaluated Accuracy every 10 epochs
+ tf.summary.scalar('Accuracy', accuracy.item(), step=i+1)
+ if ((i+1)%500 == 0) or i == 0:
+ fig, axes = None, None
+ fig, axes = plt.subplots(10, 10, figsize=(6, 6),
+ subplot_kw={'xticks':[], 'yticks':[]},
+ gridspec_kw=dict(hspace=0.1, wspace=0.1))
+
+ label_pred = model(images_test).argmax(axis=1)
+
+ for j, ax in enumerate(axes.flat):
+ ax.imshow(images_test[j], cmap='binary', interpolation='gaussian')
+ color = 'green' if label_pred[j] == label_test[j] else 'red'
+ ax.text(0.05, 0.05, str(label_pred[j]), transform=ax.transAxes, color=color)
+ ## store the Accuracy test sheet every 500 epochs - be sure to give each a different name, or they will overwrite the previous output.
+ tf.summary.image(f"Step {i+1} Accuracy Testsheet", plot_to_image(fig), step=i+1)
+```
+
+During the training has run, and after, the added `Loss` and `Accuracy` scalars are available in the tensorboard UI under the run folder we've dynamically created by the datetime.
+
+The output there should look something like the following:
+
+![image.png](./_static/loss_acc_example.png)
+
++++
+
+Since we've stored the example test sheet every 500 epochs, it's easy to go back and step through the progress. With each training step using all of the training data the steps and epochs are essentially the same here.
+
+At step 1, we see poor accuracy, as you would expect
+
+![image.png](./_static/testsheet_start_example.png)
+
+By 500, the model is essentially done, but we see the bottom row `7` get lost and recovered at higher epochs as we go far into an overfitting regime. This kind of stored data can be very useful when the training routines become automated and a human is potentially only looking when something has gone wrong.
+
+![image.png](./_static/testsheets_500_3000.png)
+
++++
+
+Finally, it can be useful to use nnx.display's ability to visualize networks and model output. Here we feed the top 35 test images into the model and display the final output vector for each - in the top plot, each row is an individual image prediction result: each column corresponds to a class, in this case the digits (0-9). Since we're calling the highest value in a given row the class prediction (`.argmax(axis=1)`), the final image predictions (bottom plot) simply match the largest value in each row in the upper plot.
+
+```{code-cell} ipython3
+nnx.display(model(images_test[:35])), nnx.display(model(images_test[:35]).argmax(axis=1))
+```
+
+The above cell output will give you an interactive plot that looks like this image below, where here we've 'clicked' in the bottom plot for entry `7` and hover over the corresponding value in the top plot.
+
+![image.png](./_static/model_display_example.png)
+
++++
+
+## Extra Resources
+
+For further information about `TensorBoard` see [https://www.tensorflow.org/tensorboard/get_started](https://www.tensorflow.org/tensorboard/get_started)
+
+For more about `nnx.display()`, which calls Treescope under the hood, see [https://treescope.readthedocs.io/en/stable/](https://treescope.readthedocs.io/en/stable/)
diff --git a/docs/_static/loss_acc_example.png b/docs/_static/loss_acc_example.png
new file mode 100644
index 0000000..fd73060
Binary files /dev/null and b/docs/_static/loss_acc_example.png differ
diff --git a/docs/_static/model_display_example.png b/docs/_static/model_display_example.png
new file mode 100644
index 0000000..bf06ac3
Binary files /dev/null and b/docs/_static/model_display_example.png differ
diff --git a/docs/_static/nnx_display_example.png b/docs/_static/nnx_display_example.png
new file mode 100644
index 0000000..6841dfb
Binary files /dev/null and b/docs/_static/nnx_display_example.png differ
diff --git a/docs/_static/testsheet_start_example.png b/docs/_static/testsheet_start_example.png
new file mode 100644
index 0000000..f275370
Binary files /dev/null and b/docs/_static/testsheet_start_example.png differ
diff --git a/docs/_static/testsheets_500_3000.png b/docs/_static/testsheets_500_3000.png
new file mode 100644
index 0000000..dc0bb6e
Binary files /dev/null and b/docs/_static/testsheets_500_3000.png differ
diff --git a/docs/_static/training_data_example.png b/docs/_static/training_data_example.png
new file mode 100644
index 0000000..699f7cf
Binary files /dev/null and b/docs/_static/training_data_example.png differ
diff --git a/docs/conf.py b/docs/conf.py
index f6de076..ceec641 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -57,6 +57,7 @@
'JAX_examples_image_segmentation.md',
'JAX_Vision_transformer.md',
'JAX_machine_translation.md',
+ 'JAX_visualizing_models_metrics.md',
'JAX_image_captioning.md',
'JAX_time_series_classification.md',
'JAX_transformer_text_classification.md',
@@ -91,6 +92,7 @@
'JAX_examples_image_segmentation.ipynb',
'JAX_Vision_transformer.ipynb',
'JAX_machine_translation.ipynb',
+ 'JAX_visualizing_models_metrics.ipynb',
'JAX_image_captioning.ipynb',
'JAX_time_series_classification.ipynb',
'JAX_transformer_text_classification.ipynb',
diff --git a/docs/tutorials.md b/docs/tutorials.md
index 8335115..dab201f 100644
--- a/docs/tutorials.md
+++ b/docs/tutorials.md
@@ -17,6 +17,7 @@ JAX_basic_text_classification
JAX_examples_image_segmentation
JAX_Vision_transformer
JAX_machine_translation
+JAX_visualizing_models_metrics
JAX_image_captioning
JAX_time_series_classification
JAX_transformer_text_classification