This is a PyTorch/GPU implementation of the paper Autoregressive Image Generation without Vector Quantization (NeurIPS 2024 Spotlight Presentation):
@article{li2024autoregressive,
title={Autoregressive Image Generation without Vector Quantization},
author={Li, Tianhong and Tian, Yonglong and Li, He and Deng, Mingyang and He, Kaiming},
journal={arXiv preprint arXiv:2406.11838},
year={2024}
}
This repo contains:
- 🪐 A simple PyTorch implementation of MAR and DiffLoss
- ⚡️ Pre-trained class-conditional MAR models trained on ImageNet 256x256
- 💥 A self-contained Colab notebook for running various pre-trained MAR models
- 🛸 An MAR+DiffLoss training and evaluation script using PyTorch DDP
- 🎉 Also check out our Hugging Face model cards and Gradio demo (thanks @jadechoghari).
Download the ImageNet dataset and place it in your `IMAGENET_PATH`.
Download the code:
git clone https://github.com/LTH14/mar.git
cd mar
A suitable conda environment named `mar` can be created and activated with:
conda env create -f environment.yaml
conda activate mar
Download pre-trained VAE and MAR models:
python util/download.py
For convenience, our pre-trained MAR models can be downloaded directly here as well:
MAR Model | FID-50K | Inception Score | #params |
---|---|---|---|
MAR-B | 2.31 | 281.7 | 208M |
MAR-L | 1.78 | 296.0 | 479M |
MAR-H | 1.55 | 303.7 | 943M |
Because our data augmentation consists only of center cropping and random horizontal flipping, the VAE latents can be pre-computed and saved to `CACHED_PATH` to save computation during MAR training:
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main_cache.py \
--img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 \
--batch_size 128 \
--data_path ${IMAGENET_PATH} --cached_path ${CACHED_PATH}
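Because center cropping is deterministic, the flip is the only random part of this augmentation, so it is enough to encode each image twice. A minimal sketch of that idea, assuming the cache stores one latent for the image as-is and one for its horizontal flip, and that the data loader picks one of the two at random; `build_cache`, `load_latent`, and the toy encoder/flip functions are hypothetical stand-ins, not the code in `main_cache.py`:

```python
import random
from typing import Callable, Dict, Sequence, Tuple

Image = Sequence[float]
Latent = Tuple[float, ...]


def build_cache(
    images: Sequence[Image],
    encode: Callable[[Image], Latent],
    flip: Callable[[Image], Image],
) -> Dict[int, Tuple[Latent, Latent]]:
    """Encode every (already center-cropped) image twice: as-is and flipped."""
    cache: Dict[int, Tuple[Latent, Latent]] = {}
    for idx, img in enumerate(images):
        cache[idx] = (encode(img), encode(flip(img)))
    return cache


def load_latent(
    cache: Dict[int, Tuple[Latent, Latent]], idx: int, rng: random.Random
) -> Latent:
    """Reproduce the random flip at load time by picking one cached latent."""
    plain, flipped = cache[idx]
    return flipped if rng.random() < 0.5 else plain


def toy_encode(img: Image) -> Latent:
    # Hypothetical stand-in for the VAE encoder: doubles every value.
    return tuple(2.0 * v for v in img)


def toy_flip(img: Image) -> Image:
    # Hypothetical horizontal flip of a single pixel row.
    return list(reversed(img))


cache = build_cache([[1.0, 2.0, 3.0]], toy_encode, toy_flip)
latent = load_latent(cache, 0, random.Random(0))
```

The expensive encoder runs once per orientation ahead of time, and training only pays for a random choice between two stored latents.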
Run our interactive visualization demo using the Colab notebook, or launch the local Gradio app:
python demo/gradio_app.py
Script for the default setting (MAR-L, DiffLoss MLP with 3 blocks and a width of 1024 channels, 400 epochs):
torchrun --nproc_per_node=8 --nnodes=4 --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
main_mar.py \
--img_size 256 --vae_path pretrained_models/vae/kl16.ckpt --vae_embed_dim 16 --vae_stride 16 --patch_size 1 \
--model mar_large --diffloss_d 3 --diffloss_w 1024 \
--epochs 400 --warmup_epochs 100 --batch_size 64 --blr 1.0e-4 --diffusion_batch_mul 4 \
--output_dir ${OUTPUT_DIR} --resume ${OUTPUT_DIR} \
--data_path ${IMAGENET_PATH}
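The flags above determine a few derived quantities. A minimal sketch of the arithmetic, assuming `main_mar.py` follows MAE's linear learning-rate scaling rule (absolute lr = `--blr` × effective batch size / 256) and that `--diffusion_batch_mul` repeats each token's target that many times, each copy with its own noise and timestep, before the DiffLoss MLP; the helper functions are hypothetical, not part of the repo:

```python
def effective_lr(blr: float, batch_size_per_gpu: int, num_gpus: int) -> float:
    """Assumed MAE-style linear scaling: lr = blr * effective_batch / 256."""
    effective_batch = batch_size_per_gpu * num_gpus
    return blr * effective_batch / 256


def num_tokens(img_size: int, vae_stride: int, patch_size: int) -> int:
    """Tokens per image: the token grid side is img_size / (vae_stride * patch_size)."""
    side = img_size // (vae_stride * patch_size)
    return side * side


def token_dim(vae_embed_dim: int, patch_size: int) -> int:
    """Each token packs patch_size x patch_size latent positions of vae_embed_dim channels."""
    return vae_embed_dim * patch_size * patch_size


def diffusion_rows(batch_size_per_gpu: int, seq_len: int, diffusion_batch_mul: int) -> int:
    """Rows fed to the DiffLoss MLP per GPU step if every token target is repeated
    (assumption: only masked tokens then contribute to the loss)."""
    return batch_size_per_gpu * seq_len * diffusion_batch_mul


num_gpus = 8 * 4  # --nproc_per_node=8 --nnodes=4
lr = effective_lr(1.0e-4, 64, num_gpus)
seq_len = num_tokens(256, 16, 1)  # --img_size 256 --vae_stride 16 --patch_size 1
dim = token_dim(16, 1)            # --vae_embed_dim 16 --patch_size 1
rows = diffusion_rows(64, seq_len, 4)
```

Under these assumptions, the default run uses an effective batch size of 2048 and a learning rate of 8e-4, and each image becomes a sequence of 256 tokens with 16 channels each.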
- Training time is ~1d7h on 32 H100 GPUs with `--batch_size 64`.
- Add `--online_eval` to evaluate FID during training (every 40 epochs).
- (Optional) To train with cached VAE latents, add `--use_cached --cached_path ${CACHED_PATH}` to the arguments. Training time with cached latents is ~1d11h on 16 H100 GPUs with `--batch_size 128` (nearly 2x faster than without caching, measured in GPU-hours).
- (Optional) To reduce GPU memory usage with gradient checkpointing (thanks to @Jiawei-Yang), add `--grad_checkpointing` to the arguments. Note that this may slightly reduce training speed.
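The two training runs above use different numbers of GPUs while keeping the same effective batch size of 2048 (32 × 64 and 16 × 128), so the caching speedup is clearest in GPU-hours. A quick check using the approximate times quoted above; `gpu_hours` is just an illustrative helper:

```python
def gpu_hours(days: int, hours: int, num_gpus: int) -> int:
    """Total GPU-hours for a run lasting `days` days plus `hours` hours."""
    return (days * 24 + hours) * num_gpus


baseline = gpu_hours(1, 7, 32)  # ~1d7h on 32 H100s, encoding images on the fly
cached = gpu_hours(1, 11, 16)   # ~1d11h on 16 H100s, cached VAE latents
speedup = baseline / cached
print(f"{baseline} vs {cached} GPU-hours, {speedup:.2f}x")
# prints: 992 vs 560 GPU-hours, 1.77x
```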
Evaluate MAR-B (DiffLoss MLP with 6 blocks and a width of 1024 channels, 800 epochs) with classifier-free guidance:
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main_mar.py \
--model mar_base --diffloss_d 6 --diffloss_w 1024 \
--eval_bsz 256 --num_images 50000 \
--num_iter 256 --num_sampling_steps 100 --cfg 2.9 --cfg_schedule linear --temperature 1.0 \
--output_dir pretrained_models/mar/mar_base \
--resume pretrained_models/mar/mar_base \
--data_path ${IMAGENET_PATH} --evaluate
Evaluate MAR-L (DiffLoss MLP with 8 blocks and a width of 1280 channels, 800 epochs) with classifier-free guidance:
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main_mar.py \
--model mar_large --diffloss_d 8 --diffloss_w 1280 \
--eval_bsz 256 --num_images 50000 \
--num_iter 256 --num_sampling_steps 100 --cfg 3.0 --cfg_schedule linear --temperature 1.0 \
--output_dir pretrained_models/mar/mar_large \
--resume pretrained_models/mar/mar_large \
--data_path ${IMAGENET_PATH} --evaluate
Evaluate MAR-H (DiffLoss MLP with 12 blocks and a width of 1536 channels, 800 epochs) with classifier-free guidance:
torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
main_mar.py \
--model mar_huge --diffloss_d 12 --diffloss_w 1536 \
--eval_bsz 128 --num_images 50000 \
--num_iter 256 --num_sampling_steps 100 --cfg 3.2 --cfg_schedule linear --temperature 1.0 \
--output_dir pretrained_models/mar/mar_huge \
--resume pretrained_models/mar/mar_huge \
--data_path ${IMAGENET_PATH} --evaluate
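All three evaluation commands above use `--cfg_schedule linear`. A minimal sketch of what such a schedule can look like, assuming the guidance scale grows linearly from 1 to `--cfg` with the fraction of tokens already generated, and that it is applied in the usual classifier-free-guidance form to the conditional and unconditional outputs of the DiffLoss denoiser. The function names are hypothetical and this is an illustration, not the repo's implementation:

```python
def linear_cfg_scale(cfg: float, num_generated: int, seq_len: int) -> float:
    """Guidance scale ramping linearly from 1 (no guidance) to `cfg`
    as the fraction of already-generated tokens goes from 0 to 1."""
    return 1.0 + (cfg - 1.0) * num_generated / seq_len


def guided_prediction(cond: list[float], uncond: list[float], scale: float) -> list[float]:
    """Classifier-free guidance: uncond + scale * (cond - uncond), elementwise."""
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]


seq_len = 256  # 16 x 16 tokens for a 256x256 image with a stride-16 VAE
for generated in (0, 64, 128, 256):
    print(generated, round(linear_cfg_scale(2.9, generated, seq_len), 4))
```

With `--cfg 1.0` the scale stays at 1 for every step and the guided output equals the conditional one, which is consistent with that setting disabling guidance.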
- Set `--cfg 1.0 --temperature 0.95` to evaluate without classifier-free guidance.
- Generation speed can be significantly increased by reducing the number of autoregressive iterations (e.g., `--num_iter 64`).
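Each autoregressive iteration fills in a group of still-masked tokens, so `--num_iter` controls how many tokens are produced per step. A minimal sketch, assuming a MaskGIT-style cosine masking schedule over the 256 tokens of a 256x256 image (a 16 x 16 grid from the stride-16 VAE): with `--num_iter 256` every step produces exactly one token, while `--num_iter 64` produces four per step on average. `tokens_per_step` is a hypothetical helper, not a function from the repo:

```python
import math


def tokens_per_step(seq_len: int, num_iter: int) -> list[int]:
    """Tokens generated at each iteration under an assumed cosine masking schedule."""
    masked = seq_len  # generation starts with every token masked
    counts = []
    for step in range(num_iter):
        if step == num_iter - 1:
            next_masked = 0  # the final iteration fills in all remaining tokens
        else:
            ratio = math.cos(math.pi / 2 * (step + 1) / num_iter)
            target = math.floor(seq_len * ratio)
            # Generate at least one token, but keep at least one masked for later steps.
            next_masked = max(1, min(masked - 1, target))
        counts.append(masked - next_masked)
        masked = next_masked
    return counts


for num_iter in (256, 64):
    counts = tokens_per_step(256, num_iter)
    print(f"--num_iter {num_iter}: first {counts[:4]}, last {counts[-4:]}")
```

The cosine shape generates few tokens in the early steps and larger groups toward the end, so fewer iterations mainly compress the later, larger groups.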
We thank Congyue Deng and Xinlei Chen for helpful discussion. We thank Google TPU Research Cloud (TRC) for granting us access to TPUs, and Google Cloud Platform for supporting GPU resources.
A large portion of the code in this repo is based on MAE, MAGE, and DiT.
If you have any questions, feel free to contact me by email. Enjoy!