Small tweaks #152

Merged (4 commits), Jan 17, 2025
254 changes: 157 additions & 97 deletions docs/source/JAX_porting_PyTorch_model.ipynb

Large diffs are not rendered by default.

30 changes: 22 additions & 8 deletions docs/source/JAX_porting_PyTorch_model.md
@@ -16,8 +16,14 @@ kernelspec:

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb)

+**Note: On Colab we recommend running this on a T4 GPU instance. On Kaggle we recommend a T4x2 or P100 instance.**
+
In this tutorial we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). Flax provides an API very similar to the PyTorch `torch.nn` module and porting PyTorch models is rather straightforward. To install Flax, we can simply execute the following command: `pip install -U flax treescope`.

+```{code-cell} ipython3
+!pip install -Uq flax treescope
+```
+
Say we have a trained PyTorch computer-vision model to classify images that we would like to port to JAX. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet (MaxViT: Multi-Axis Vision Transformer, https://arxiv.org/abs/2204.01697).

First, we set up the model using TorchVision and briefly explore the model's architecture and the blocks we need to port. Next, we define equivalent blocks and the whole model using Flax. After that, we port the weights. Finally, we run some tests to ensure the correctness of the ported model.
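
The weight-porting and testing steps can be illustrated on a single layer. The following is a minimal sketch using NumPy only, not the tutorial's actual porting code: the array names and shapes are hypothetical, and it assumes the usual parameter layouts, where `torch.nn.Linear` stores its weight as `(out_features, in_features)` and `flax.nnx.Linear` stores its kernel as `(in_features, out_features)`, while `torch.nn.Conv2d` weights are `(out, in, kH, kW)` and `flax.nnx.Conv` kernels are `(kH, kW, in, out)`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical PyTorch-style parameters for a Linear(in=4, out=3) layer.
# Assumption: torch.nn.Linear keeps its weight as (out_features, in_features).
torch_weight = rng.normal(size=(3, 4)).astype(np.float32)
torch_bias = rng.normal(size=(3,)).astype(np.float32)

# Assumption: flax.nnx.Linear keeps its kernel as (in_features, out_features),
# so porting the weight amounts to a transpose; the bias is copied as-is.
flax_kernel = torch_weight.T
flax_bias = torch_bias

# Hypothetical Conv2d weight in PyTorch layout (out, in, kH, kW),
# moved to the (kH, kW, in, out) layout assumed for flax.nnx.Conv.
torch_conv_weight = rng.normal(size=(8, 3, 5, 7)).astype(np.float32)
flax_conv_kernel = torch_conv_weight.transpose(2, 3, 1, 0)

# Emulate both layers on the same input batch of shape (2, 4).
x = rng.normal(size=(2, 4)).astype(np.float32)
torch_out = x @ torch_weight.T + torch_bias  # y = x W^T + b
flax_out = x @ flax_kernel + flax_bias       # y = x K + b


def cosine_distance(a, b):
    """Return 1 - cosine similarity of two arrays, compared as flat vectors."""
    a, b = a.ravel(), b.ravel()
    return 1.0 - float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


# Correctness checks in the spirit of the final testing step.
max_abs_diff = float(np.abs(torch_out - flax_out).max())
cosine_dist = cosine_distance(torch_out, flax_out)
print(flax_kernel.shape, flax_conv_kernel.shape, max_abs_diff, cosine_dist)
```

In a real port the same comparison would be run on the outputs of the PyTorch model and the Flax model for an identical input, with a small tolerance to allow for floating-point differences between the two frameworks.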
@@ -44,11 +50,10 @@ torch_model = maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1)
```

We can use `flax.nnx.display` to display the model's architecture:
-```python
-nnx.display(torch_model)
-```

-+++
+```{code-cell} ipython3
+# nnx.display(torch_model)
+```

We can see that there are four MaxViT blocks in the model and each block contains:
- MaxViT layers: two layers for blocks 0, 1, and 3, and five layers for block 2
@@ -81,9 +86,18 @@ print(output.shape) # (2, 1000)

We can download an image of a Pembroke Corgi dog from [TorchVision's gallery](https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true) together with the [ImageNet class dictionary](https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json):

-```bash
-wget "https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true" -O dog1.jpg
-wget "https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json" -O imagenet_class_index.json
+```{code-cell} ipython3
+%%bash
+if [ -f "dog1.jpg" ]; then
+  echo "dog1.jpg already exists."
+else
+  wget -nv "https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true" -O dog1.jpg
+fi
+if [ -f "imagenet_class_index.json" ]; then
+  echo "imagenet_class_index.json already exists."
+else
+  wget -nv "https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json" -O imagenet_class_index.json
+fi
```

```{code-cell} ipython3
@@ -1615,5 +1629,5 @@ cosine_dist

## Further reading

-- [Flax documentation: Core Exampels](https://flax.readthedocs.io/en/latest/examples/core_examples.html)
+- [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html)
- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html)