Updated ViT tutorial to fine-tune the model on Food101 dataset
vfdev-5 committed Dec 2, 2024
1 parent 0f09306 commit 68167c8
Showing 2 changed files with 585 additions and 954 deletions.
1,339 changes: 411 additions & 928 deletions docs/JAX_Vision_transformer.ipynb

Large diffs are not rendered by default.

200 changes: 174 additions & 26 deletions docs/JAX_Vision_transformer.md

# Vision Transformer with JAX & FLAX


In this tutorial we implement the Vision Transformer (ViT) model from scratch, based on the paper by Dosovitskiy et al.: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). We then load ImageNet-pretrained weights and fine-tune the model on the [Food 101](https://huggingface.co/datasets/ethz/food101) dataset.
This tutorial was originally inspired by the [HuggingFace image classification tutorial](https://huggingface.co/docs/transformers/tasks/image_classification).

+++
We will need to install the following Python packages:
- [TorchVision](https://pytorch.org/vision) will be used for image augmentations
- [grain](https://github.com/google/grain/) will be used for efficient data loading
- [tqdm](https://tqdm.github.io/) will be used to display a progress bar during training
- [Matplotlib](https://matplotlib.org/stable/) will be used for visualization purposes

```{code-cell} ipython3
# !pip install -U datasets grain torchvision tqdm matplotlib
# !pip install -U flax optax
```

```{code-cell} ipython3
class VisionTransformer(nnx.Module):
    # ...
            TransformerEncoder(hidden_size, mlp_dim, num_heads, dropout_rate, rngs=rngs)
            for i in range(num_layers)
        ])
        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)
        # Classification head
        self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)

    # ...
        # Encoder blocks
        x = self.encoder(embeddings)
        x = self.final_norm(x)
        # fetch the first token
        x = x[:, 0]
        # ...


class TransformerEncoder(nnx.Module):
    # ...
        return x


x = jnp.ones((4, 224, 224, 3))
model = VisionTransformer(num_classes=1000)
y = model(x)
print("Predictions shape: ", y.shape)
```
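
As a quick check that we really built the full ViT-Base configuration, we can count its parameters. This is a small sketch that reuses the `nnx.state(...).flat_state()` pattern also used by the weight-copy function further below:

```{code-cell} ipython3
import numpy as np

# Count trainable parameters; ViT-Base with a 1000-class head has roughly 86M.
params_fstate = nnx.state(model, nnx.Param).flat_state()
n_params = sum(int(np.prod(params_fstate[k].value.shape)) for k in params_fstate.keys())
print(f"{n_params / 1e6:.1f}M parameters")
```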

Let's now load the weights pretrained on the ImageNet dataset using HuggingFace Transformers. We copy all of the weights into our model and then check that it produces results consistent with the reference model.

```{code-cell} ipython3
from transformers import FlaxViTForImageClassification
tf_model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
```
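
Before writing the weight-copy function, it helps to see how the pretrained parameters are organized. A minimal sketch (it only assumes that the Flax HuggingFace model exposes its parameters as the nested `params` mapping used below):

```{code-cell} ipython3
# Peek at the HuggingFace parameter tree: the nested keys ("vit", "embeddings", ...)
# are exactly what the name mapping in the next cell refers to.
hf_params_flat = nnx.traversals.flatten_mapping(tf_model.params)
for path, value in list(hf_params_flat.items())[:5]:
    print(path, value.shape)
```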

```{code-cell} ipython3
def vit_inplace_copy_weights(*, src_model, dst_model):
    assert isinstance(src_model, FlaxViTForImageClassification)
    assert isinstance(dst_model, VisionTransformer)

    # Source (HuggingFace) and destination (our model) parameters as flat mappings
    tf_model_params = src_model.params
    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)

    flax_model_params = nnx.state(dst_model, nnx.Param)
    flax_model_params_fstate = flax_model_params.flat_state()

    # Mapping between our parameter paths and the HuggingFace parameter paths
    params_name_mapping = {
        ("cls_token",): ("vit", "embeddings", "cls_token"),
        ("position_embeddings",): ("vit", "embeddings", "position_embeddings"),
        **{
            ("patch_embeddings", x): ("vit", "embeddings", "patch_embeddings", "projection", x)
            for x in ["kernel", "bias"]
        },
        **{
            ("encoder", "layers", i, "attn", y, x): (
                "vit", "encoder", "layer", str(i), "attention", "attention", y, x
            )
            for x in ["kernel", "bias"]
            for y in ["key", "value", "query"]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, "attn", "out", x): (
                "vit", "encoder", "layer", str(i), "attention", "output", "dense", x
            )
            for x in ["kernel", "bias"]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, "mlp", "layers", y1, x): (
                "vit", "encoder", "layer", str(i), y2, "dense", x
            )
            for x in ["kernel", "bias"]
            for y1, y2 in [(0, "intermediate"), (3, "output")]
            for i in range(12)
        },
        **{
            ("encoder", "layers", i, y1, x): (
                "vit", "encoder", "layer", str(i), y2, x
            )
            for x in ["scale", "bias"]
            for y1, y2 in [("norm1", "layernorm_before"), ("norm2", "layernorm_after")]
            for i in range(12)
        },
        **{
            ("final_norm", x): ("vit", "layernorm", x)
            for x in ["scale", "bias"]
        },
        **{
            ("classifier", x): ("classifier", x)
            for x in ["kernel", "bias"]
        }
    }

    nonvisited = set(flax_model_params_fstate.keys())

    for key1, key2 in params_name_mapping.items():
        assert key1 in flax_model_params_fstate, key1
        assert key2 in tf_model_params_fstate, (key1, key2)

        nonvisited.remove(key1)

        src_value = tf_model_params_fstate[key2]
        # HuggingFace stores each attention projection as a single (hidden, hidden) matrix;
        # reshape it to the per-head layout used by our attention layers
        # (12 heads of size 64 for ViT-Base).
        if key2[-1] == "kernel" and key2[-2] in ("key", "value", "query"):
            shape = src_value.shape
            src_value = src_value.reshape((shape[0], 12, 64))

        if key2[-1] == "bias" and key2[-2] in ("key", "value", "query"):
            src_value = src_value.reshape((12, 64))

        if key2[-4:] == ("attention", "output", "dense", "kernel"):
            shape = src_value.shape
            src_value = src_value.reshape((12, 64, shape[-1]))

        dst_value = flax_model_params_fstate[key1]
        assert src_value.shape == dst_value.value.shape, (key2, src_value.shape, key1, dst_value.value.shape)

        dst_value.value = src_value.copy()
        assert dst_value.value.mean() == src_value.mean(), (dst_value.value, src_value.mean())

    # Every parameter of our model should have been visited exactly once
    assert len(nonvisited) == 0, nonvisited

    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))


vit_inplace_copy_weights(src_model=tf_model, dst_model=model)
```
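
As a quick sanity check we can compare one of the copied tensors with its source. This is a sketch: it assumes the model exposes the class token as `model.cls_token`, matching the `("cls_token",)` entry of the mapping above.

```{code-cell} ipython3
import numpy as np

# After the copy, the class token should be identical to the pretrained one.
src = np.asarray(tf_model.params["vit"]["embeddings"]["cls_token"])
dst = np.asarray(model.cls_token.value)
print("max abs difference:", np.abs(src - dst).max())  # expected: 0.0
```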

Let's now compare the predictions of our model (with the copied weights) against the reference model on a test image:

```{code-cell} ipython3
import matplotlib.pyplot as plt
from transformers import ViTImageProcessor
from PIL import Image
import requests
url = "https://farm2.staticflickr.com/1152/1151216944_1525126615_z.jpg"
image = Image.open(requests.get(url, stream=True).raw)
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
inputs = processor(images=image, return_tensors="np")
outputs = tf_model(**inputs)
logits = outputs.logits
model.eval()

# The HF processor returns NCHW arrays; our model expects NHWC, so transpose the input.
x = jnp.transpose(inputs["pixel_values"], axes=(0, 2, 3, 1))
output = model(x)
# model predicts one of the 1000 ImageNet classes
ref_class_idx = logits.argmax(-1).item()
pred_class_idx = output.argmax(-1).item()
assert jnp.abs(logits[0, :] - output[0, :]).max() < 0.1
fig, axs = plt.subplots(1, 2, figsize=(12, 8))
axs[0].set_title(
f"Reference model:\n{tf_model.config.id2label[ref_class_idx]}\nP={nnx.softmax(logits, axis=-1)[0, ref_class_idx]:.3f}"
)
axs[0].imshow(image)
axs[1].set_title(
f"Our model:\n{tf_model.config.id2label[pred_class_idx]}\nP={nnx.softmax(output, axis=-1)[0, pred_class_idx]:.3f}"
)
axs[1].imshow(image)
```
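
We can also look beyond the top-1 prediction. The small sketch below reuses `output` and `tf_model.config.id2label` from the cell above to print the five most likely ImageNet labels according to our ported model:

```{code-cell} ipython3
import jax

# Top-5 ImageNet predictions of our model for the test image.
top_probs, top_ids = jax.lax.top_k(nnx.softmax(output, axis=-1)[0], k=5)
for p, idx in zip(top_probs, top_ids):
    print(f"{tf_model.config.id2label[int(idx)]:<40s} P={float(p):.3f}")
```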

Now let's replace the classifier with a smaller fully-connected layer returning 20 classes instead of 1000:

```{code-cell} ipython3
model.classifier = nnx.Linear(model.classifier.in_features, 20, rngs=nnx.Rngs(0))
x = jnp.ones((4, 224, 224, 3))
y = model(x)
print("Predictions shape: ", y.shape)
```
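
Only this new head is randomly initialized; the rest of the network keeps its pretrained weights. A quick check of the new head's shapes (assuming, as in Flax NNX, that `nnx.Linear` stores its parameters as `.kernel` and `.bias`):

```{code-cell} ipython3
# The classifier now maps the 768-dimensional CLS embedding to 20 classes.
print("classifier kernel shape:", model.classifier.kernel.value.shape)
print("classifier bias shape:  ", model.classifier.bias.value.shape)
```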
In the following sections we set up an image classification dataset and train this model.

In this tutorial we use the [Food 101](https://huggingface.co/datasets/ethz/food101) dataset, which consists of 101 food categories and 101,000 images. For each class, 250 manually reviewed test images are provided, as well as 750 training images. On purpose, the training images were not cleaned and thus still contain some amount of noise, mostly in the form of intense colors and sometimes wrong labels. All images were rescaled to have a maximum side length of 512 pixels.

We will download the data using [HuggingFace Datasets](https://huggingface.co/docs/datasets/) and select 20 classes to reduce the dataset size and the model training time. We will be using [TorchVision](https://pytorch.org/vision) to transform input images and [`grain`](https://github.com/google/grain/) for efficient data loading.

```{code-cell} ipython3
from datasets import load_dataset
# Select first 20 classes to reduce the dataset size and the training time.
train_size = 20 * 750
val_size = 20 * 250
train_dataset = load_dataset("food101", split=f"train[:{train_size}]")
val_dataset = load_dataset("food101", split=f"validation[:{val_size}]")
# Create a labels mapping where we map the original labels to indices between 0 and 19
labels_mapping = {}
index = 0
for i in range(0, len(val_dataset), 250):
labels_mapping[label] = index
index += 1
inv_labels_mapping = {v: k for k, v in labels_mapping.items()}
print("Training dataset size:", len(train_dataset))
print("Validation dataset size:", len(val_dataset))
import numpy as np
from torchvision.transforms import v2 as T
img_size = 224
def to_np_array(pil_image):
    return np.asarray(pil_image.convert("RGB"))


def normalize(image):
    # Image preprocessing matches the one of pretrained ViT
    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    image = image.astype(np.float32) / 255.0
    return (image - mean) / std
tv_train_transforms = T.Compose([
    # ...

def get_transform(fn):
    # ...
        batch["image"] = [
            fn(pil_image) for pil_image in batch["image"]
        ]
        # map label index between 0 - 19
        batch["label"] = [
            labels_mapping[label] for label in batch["label"]
        ]
import grain.python as grain
seed = 12
train_batch_size = 32
val_batch_size = 2 * train_batch_size
```{code-cell} ipython3
display_datapoints(
    *[(train_batch["image"][i], train_batch["label"][i]) for i in range(5)],
    tag="(Training) ",
    names_map={k: train_dataset.features["label"].names[v] for k, v in inv_labels_mapping.items()}
)
```

```{code-cell} ipython3
display_datapoints(
    *[(val_batch["image"][i], val_batch["label"][i]) for i in range(5)],
    tag="(Validation) ",
    names_map={k: val_dataset.features["label"].names[v] for k, v in inv_labels_mapping.items()}
)
```

We have defined the training and validation datasets and the model. In this section we will train the model.
```{code-cell} ipython3
import optax
num_epochs = 3
learning_rate = 0.001
momentum = 0.8
total_steps = len(train_dataset) // train_batch_size
```{code-cell} ipython3
num_samples = len(test_indices)
names_map = train_dataset.features["label"].names
probas = nnx.softmax(preds, axis=1)
pred_labels = probas.argmax(axis=1)

## Further reading

In this tutorial we implemented the Vision Transformer model from scratch and fine-tuned it on a subset of the Food 101 dataset. The fine-tuned model reaches about 95% classification accuracy.

- Model checkpointing and exporting using [Orbax](https://orbax.readthedocs.io/en/latest/).
- Optimizers and the learning rate scheduling using [Optax](https://optax.readthedocs.io/en/latest/).
- Freezing the model's parameters using trainable-parameter filtering: [example 1](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html#flax.nnx.optimizer.Optimizer.update) and [example 2](https://github.com/google/flax/issues/4167#issuecomment-2324245208).
- Other Computer Vision tutorials in [jax-ai-stack](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html).
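
For instance, the learning-rate scheduling mentioned above could look like the sketch below (illustrative values only, not the exact configuration used in this tutorial):

```{code-cell} ipython3
import optax

# A warmup + cosine-decay schedule as an alternative to a constant learning rate.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=1e-3,
    warmup_steps=100,
    decay_steps=3000,
)
tx = optax.sgd(schedule, momentum=0.8)

# The schedule is just a function of the step number and can be inspected directly:
print([round(float(schedule(step)), 5) for step in (0, 50, 100, 1000, 3000)])
```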

