Add flash attention support for inference #367

Merged
3 commits merged into ymcui:main on Oct 26, 2023

Conversation

@GoGoJoestar (Collaborator) commented on Oct 25, 2023

Description

Flash-attention released an update that optimizes inference. This PR adds the adaptation for speeding up inference with flash-attention.
Specifically, this PR includes:

  1. Add a flash attention patch for inference. Users can enable flash attention during inference with inference_hf.py and gradio_demo.py by adding the --flash_attn parameter (see the sketch after this list).
  2. Add padding_mask to the flash attention and xformers patches. This parameter was added to the LlamaAttention.forward function in transformers v4.34 (see the signature sketch under Related Issue below).
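
A minimal sketch of the core operation the patch delegates to (not code from this PR; it assumes flash-attn 2.x, a CUDA GPU, and made-up tensor shapes): a patched attention forward ultimately calls flash_attn_func on half-precision query/key/value tensors laid out as (batch, seqlen, n_heads, head_dim).

```python
import torch
from flash_attn import flash_attn_func

# Dummy query/key/value in the (batch, seqlen, n_heads, head_dim) layout
# that a patched LlamaAttention would hand to flash-attention.
batch, seqlen, n_heads, head_dim = 1, 128, 32, 128
q = torch.randn(batch, seqlen, n_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# causal=True applies the decoder-style causal mask used during generation.
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
print(out.shape)  # torch.Size([1, 128, 32, 128])
```

Per the description, passing --flash_attn to inference_hf.py or gradio_demo.py enables the patch; the remaining arguments stay the same.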

Related Issue

Add padding_mask: #326
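
For reference, a minimal sketch (hypothetical, not the PR's patch code) of why the new keyword matters: transformers v4.34 began passing padding_mask into LlamaAttention.forward, so any monkey-patched forward that omits it fails with an unexpected-keyword error. The stub below only shows a replacement forward accepting the keyword and delegating back to the stock implementation; the actual flash-attention and xformers patches replace the attention computation itself.

```python
from typing import Optional

import torch
from transformers.models.llama.modeling_llama import LlamaAttention

_orig_forward = LlamaAttention.forward

def forward_accepting_padding_mask(
    self,
    hidden_states: torch.Tensor,
    *args,
    padding_mask: Optional[torch.Tensor] = None,  # keyword added in transformers v4.34
    **kwargs,
):
    # A real patch computes flash-attention / xformers attention here; this
    # stub only demonstrates accepting padding_mask before falling back to
    # the original forward.
    return _orig_forward(self, hidden_states, *args, **kwargs)

LlamaAttention.forward = forward_accepting_padding_mask
```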

@airaria self-requested a review on October 25, 2023, 07:35
@airaria (Contributor) commented on Oct 26, 2023

Flash-attention and speculative sampling work correctly, but inference becomes slow when CFG sampling and speculative sampling are enabled together.
We should advise users not to use CFG sampling and speculative sampling simultaneously.

@ymcui merged commit c20d308 into ymcui:main on Oct 26, 2023
1 check passed