Wentian Zhang* Haozhe Liu1* Jinheng Xie2* Francesco Faccio3
Mike Zheng Shou2 Jürgen Schmidhuber1,3
1 AI Initiative, King Abdullah University of Science and Technology
2 Show Lab, National University of Singapore 3 The Swiss AI Lab, IDSIA
We find that cross-attention outputs converge to a fixed point during the initial denoising steps. Consequently, the entire inference process can be divided into two stages: an initial semantics-planning stage, during which the model relies on text to plan visual semantics, and a subsequent fidelity-improving stage, during which the model generates images from the previously planned semantics. Surprisingly, ignoring text conditions in the fidelity-improving stage not only reduces computational complexity but also slightly improves (lowers) the FID score. This yields a simple, training-free method called TGATE for efficient generation, which caches the cross-attention output once it converges and keeps it fixed for the remaining inference steps.
Images generated by the diffusion model with and without TGATE. Our method accelerates the diffusion model without degrading generation quality. It is training-free and broadly complementary to existing work.
- Training-free.
- Easy to integrate into existing frameworks.
- Only a few lines of code are required.
- Supports CNN-based U-Nets, Transformers, and Consistency Models.
- 10%-50% speed-up across different diffusion models.
- 2024/04/04: The technical report is available on arXiv.
- 2024/04/04: TGATE for DeepCache (SD-XL) is released.
- 2024/03/30: TGATE for SD-1.5/2.1/XL is released.
- 2024/03/29: TGATE for LCM (SD-XL) and PixArt-Alpha is released.
- 2024/03/28: TGATE is open-sourced.
Images generated by the diffusion model at different denoising steps. The first row feeds the text embedding to the cross-attention modules at all steps. The second row uses the text embedding only from the 1st to the 10th step, and the third row uses it only from the 11th to the 25th step.
We summarize our observations as follows:
- Cross-attention converges early during inference, so the process can be characterized by two stages: semantics planning and fidelity improving. The impact of cross-attention differs between these two stages.
- In the semantics-planning stage, the model embeds text through cross-attention to obtain visual semantics.
- The fidelity-improving stage improves generation quality without requiring cross-attention. In fact, using a null text embedding in this stage can improve performance.
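- Step 1: At the gate step, the last step of the semantics-planning stage, TGATE caches the cross-attention output.
```python
# At the gate step, average the unconditional and text-conditioned halves of the CFG batch and cache the result
if gate_step == cur_step:
    hidden_uncond, hidden_pred_text = hidden_states.chunk(2)
    cache = (hidden_uncond + hidden_pred_text) / 2
```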
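- Step 2: TGATE reuses the cached output throughout the fidelity-improving stage.
```python
# After the gate step, skip cross-attention and use the cached output instead
if cross_attn and (gate_step < cur_step):
    hidden_states = cache
```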
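To make the two steps concrete, here is a minimal, framework-free sketch of the gating logic around a single cross-attention layer. It is an illustration under stated assumptions, not the repository's implementation: `GatedCrossAttention` and `attn_fn` are hypothetical names, `attn_fn` stands in for a real cross-attention module, vectors are plain Python lists instead of tensors, steps are counted from 1, and the classifier-free-guidance batch is modelled as an (unconditional, text-conditioned) pair.

```python
from typing import Callable, List, Optional, Tuple

Vector = List[float]


class GatedCrossAttention:
    """Toy TGATE-style gate around one cross-attention layer (hypothetical names).

    Steps 1..gate_step run the real cross-attention; at gate_step the two CFG
    halves are averaged and cached, and every later step returns the cache
    without evaluating attn_fn.
    """

    def __init__(self, attn_fn: Callable[[Vector], Vector], gate_step: int) -> None:
        self.attn_fn = attn_fn  # hypothetical stand-in for a cross-attention module
        self.gate_step = gate_step
        self.cache: Optional[Vector] = None
        self.attn_calls = 0  # how often attn_fn has actually been evaluated

    def __call__(self, cur_step: int, null_emb: Vector, text_emb: Vector) -> Tuple[Vector, Vector]:
        if cur_step > self.gate_step:
            # Step 2 (fidelity improving): skip cross-attention and reuse the cache.
            if self.cache is None:
                raise RuntimeError("gate step was never reached, so nothing is cached")
            return self.cache, self.cache
        # Semantics planning: evaluate cross-attention for both CFG branches.
        hidden_uncond = self.attn_fn(null_emb)
        hidden_pred_text = self.attn_fn(text_emb)
        self.attn_calls += 2
        if cur_step == self.gate_step:
            # Step 1: average the unconditional and text-conditioned outputs and cache them.
            self.cache = [(u + t) / 2 for u, t in zip(hidden_uncond, hidden_pred_text)]
        return hidden_uncond, hidden_pred_text


# Usage: 25 denoising steps with the gate at step 10, as in the example main.py command.
gate = GatedCrossAttention(attn_fn=lambda emb: [2.0 * x for x in emb], gate_step=10)
for step in range(1, 26):
    out = gate(step, null_emb=[0.0, 0.0], text_emb=[1.0, 3.0])
print(gate.attn_calls)
print(out)
```

After the loop, the toy attention function has been evaluated only during the first 10 steps. Every later step returns the same cached vector for both branches, so the output of this layer no longer depends on the prompt once the gate has closed.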
Model | MACs | Params | Latency | Zero-shot 10K-FID on MS-COCO |
---|---|---|---|---|
SD-1.5 | 16.938T | 859.520M | 7.032s | 23.927 |
SD-1.5 w/ TGATE | 9.875T | 815.557M | 4.313s | 20.789 |
SD-2.1 | 38.041T | 865.785M | 16.121s | 22.609 |
SD-2.1 w/ TGATE | 22.208T | 815.433M | 9.878s | 19.940 |
SD-XL | 149.438T | 2.570B | 53.187s | 24.628 |
SD-XL w/ TGATE | 84.438T | 2.024B | 27.932s | 22.738 |
PixArt-Alpha | 107.031T | 611.350M | 61.502s | 38.669 |
PixArt-Alpha w/ TGATE | 65.318T | 462.585M | 37.867s | 35.825 |
DeepCache (SD-XL) | 57.888T | - | 19.931s | 23.755 |
DeepCache (SD-XL) w/ TGATE | 43.868T | - | 14.666s | 23.999 |
LCM (SD-XL) | 11.955T | 2.570B | 3.805s | 25.044 |
LCM (SD-XL) w/ TGATE | 11.171T | 2.024B | 3.533s | 25.028 |
LCM (PixArt-Alpha) | 8.563T | 611.350M | 4.733s | 36.086 |
LCM (PixArt-Alpha) w/ TGATE | 7.623T | 462.585M | 4.543s | 37.048 |
Latency is measured on a commercial GTX 1080 Ti card.
MACs and parameter counts are computed with calflops.
FID is computed with PytorchFID.
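The relative latency saving for each model can be recomputed from the reported numbers. The short script below only re-derives ratios from the latency column of the table (baseline vs. TGATE); it measures nothing itself, and the dictionary keys are just the table's row labels. With these figures the saving ranges from roughly 4% for LCM (PixArt-Alpha) to roughly 47% for SD-XL.

```python
# Latencies in seconds on a GTX 1080 Ti, copied from the table: (baseline, with TGATE).
LATENCY = {
    "SD-1.5": (7.032, 4.313),
    "SD-2.1": (16.121, 9.878),
    "SD-XL": (53.187, 27.932),
    "PixArt-Alpha": (61.502, 37.867),
    "DeepCache (SD-XL)": (19.931, 14.666),
    "LCM (SD-XL)": (3.805, 3.533),
    "LCM (PixArt-Alpha)": (4.733, 4.543),
}


def latency_reduction(baseline: float, gated: float) -> float:
    """Fraction of wall-clock time saved by TGATE relative to the baseline."""
    return (baseline - gated) / baseline


reductions = {name: latency_reduction(*pair) for name, pair in LATENCY.items()}
for name, r in reductions.items():
    print(f"{name:20s} {r:6.1%}")
```

The few-step LCM models show the smallest relative saving, which is consistent with their very short total latency leaving less cross-attention work to skip.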
- diffusers==0.27.0.dev0
- pytorch==2.2.0
- transformers
To accelerate the denoising process with TGATE, simply run `main.py`. For example:
```bash
python main.py \
    --prompt 'your prompt (str)' \
    --model 'sd_xl' \
    --gate_step 10 \
    --saved_path 'file/to/save' \
    --inference_step 25
```
Here `--model` is one of `pixart`, `sd_xl`, `sd_2.1`, `sd_1.5`, `lcm_sdxl`, or `lcm_pixart`.
For LCMs, `gate_step` is set to 1 or 2, and `inference_step` is set to 4.
After running, the generated image is saved to 'file/to/save/test.png'.
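The `gate_step` and `inference_step` arguments determine how many denoising steps fall into each stage. The hypothetical helper below (not part of the repository) makes that split explicit. It assumes steps are counted from 1 and that cross-attention runs up to and including `gate_step`, consistent with the figure caption's 1st-to-10th-step setting; if the code indexes steps differently, the counts shift by one.

```python
from typing import Dict, Union


def stage_split(gate_step: int, inference_step: int) -> Dict[str, Union[int, float]]:
    """Hypothetical helper: count the steps in each TGATE stage.

    Assumes 1-indexed steps, with steps 1..gate_step running cross-attention
    (semantics planning) and steps gate_step+1..inference_step reusing the
    cache (fidelity improving).
    """
    if not 1 <= gate_step <= inference_step:
        raise ValueError("gate_step must lie in [1, inference_step]")
    improving = inference_step - gate_step
    return {
        "semantics_planning": gate_step,
        "fidelity_improving": improving,
        "cross_attn_skipped_fraction": improving / inference_step,
    }


print(stage_split(10, 25))  # example main.py command
print(stage_split(2, 4))    # LCM setting with gate_step 2
```

Under these assumptions, the example command skips cross-attention in 15 of 25 steps (60%), and the LCM setting with `gate_step` 2 skips it in half of the 4 steps. This is the fraction of cross-attention evaluations saved, not the end-to-end latency saving.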
For convenience, each subfolder can also be run independently.
SD-1.5 w/ TGATE: generate an image with the caption "A coral reef bustling with diverse marine life."
```bash
cd SD_1_5
python generate.py
```
SD-2.1 w/ TGATE: generate an image with the caption "High quality photo of an astronaut riding a horse in space"
```bash
cd SD_2_1
python generate.py
```
SD-XL w/ TGATE: generate an image with the caption "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
```bash
cd SDXL
python generate.py
```
PixArt-Alpha w/ TGATE: generate an image with the caption "An alpaca made of colorful building blocks, cyberpunk."
```bash
cd PixArt_alpha
python generate.py
```
LCM w/ TGATE: generate an image with the caption "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"
```bash
cd LCM
python generate.py
```
DeepCache w/ TGATE: generate an image with the caption "A haunted Victorian mansion under a full moon."
```bash
cd DeepCache
python generate.py
```
We encourage users to also read DeepCache and Adaptive Guidance, two related works on efficient diffusion inference.
Methods | U-Net | Transformer | Consistency Model |
---|---|---|---|
DeepCache | ✓ | ✗ | - |
Adaptive Guidance | ✓ | ✓ | ✗ |
TGATE (Ours) | ✓ | ✓ | ✓ |
Compared with DeepCache:
- TGATE caches once and reuses the cached feature until sampling ends.
- TGATE is friendlier to Transformer-based architectures and mobile devices, since it drops the high-resolution cross-attention.
- TGATE is complementary to DeepCache.
Compared with Adaptive Guidance:
- TGATE can reduce the parameters in the second stage.
- TGATE further improves inference efficiency.
- TGATE also works with frameworks that do not use classifier-free guidance (CFG), e.g., latent consistency models.
- TGATE is open source.
- We thank Prompt-to-Prompt and diffusers for their excellent code.
If you find our work inspiring or use our codebase in your research, please consider giving a star ⭐ and a citation.
```bibtex
@article{tgate,
  title={Cross-Attention Makes Inference Cumbersome in Text-to-Image Diffusion Models},
  author={Wentian Zhang and Haozhe Liu and Jinheng Xie and Francesco Faccio and Mike Zheng Shou and Jürgen Schmidhuber},
  journal={arXiv preprint arXiv:2404.02747},
  year={2024}
}
```