Skip to content

Commit

Permalink
Small tweaks (#152)
Browse files Browse the repository at this point in the history
* Updates in Colab (small tweaks)

* More tweaks

* Oops, removing second Colab button.

* Following the instructions for once!
  • Loading branch information
rcrowe-google authored Jan 17, 2025
1 parent 6e8aeaf commit b33a446
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 105 deletions.
254 changes: 157 additions & 97 deletions docs/source/JAX_porting_PyTorch_model.ipynb

Large diffs are not rendered by default.

30 changes: 22 additions & 8 deletions docs/source/JAX_porting_PyTorch_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,14 @@ kernelspec:

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb)

**Note: On Colab we recommend running this on a T4 GPU instance. On Kaggle we recommend a T4x2 or P100 instance.**

In this tutorial we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). Flax provides an API very similar to the PyTorch `torch.nn` module and porting PyTorch models is rather straightforward. To install Flax, we can simply execute the following command: `pip install -U flax treescope`.

```{code-cell} ipython3
!pip install -Uq flax treescope
```

Say we have a trained PyTorch computer-vision model to classify images that we would like to port to JAX. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet (MaxViT: Multi-Axis Vision Transformer, https://arxiv.org/abs/2204.01697).

First, we set up the model using TorchVision and explore briefly the model's architecture and the blocks we need to port. Next, we define equivalent blocks and the whole model using Flax. After that, we port the weights. Finally, we run some tests to ensure the correctness of the ported model.
Expand All @@ -44,11 +50,10 @@ torch_model = maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1)
```

We can use `flax.nnx.display` to display the model's architecture:
```python
nnx.display(torch_model)
```

+++
```{code-cell} ipython3
# nnx.display(torch_model)
```

We can see that there are four MaxViT blocks in the model and each block contains:
- MaxViT layers: two layers for blocks 0, 1, 3 and five layers for the block 4
Expand Down Expand Up @@ -81,9 +86,18 @@ print(output.shape) # (2, 1000)

We can download an image of a Pembroke Corgy dog from [TorchVision's gallery](https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true) together with [ImageNet classes dictionary](https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json):

```bash
wget "https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true" -O dog1.jpg
wget "https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json" -O imagenet_class_index.json
```{code-cell} ipython3
%%bash
if [ -f "dog1.jpg" ]; then
echo "dog1.jpg already exists."
else
wget -nv "https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true" -O dog1.jpg
fi
if [ -f "imagenet_class_index.json" ]; then
echo "imagenet_class_index.json already exists."
else
wget -nv "https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json" -O imagenet_class_index.json
fi
```

```{code-cell} ipython3
Expand Down Expand Up @@ -1615,5 +1629,5 @@ cosine_dist

## Further reading

- [Flax documentation: Core Exampels](https://flax.readthedocs.io/en/latest/examples/core_examples.html)
- [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html)
- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html)

0 comments on commit b33a446

Please sign in to comment.