
Replace FlashAttention with xformers #70

Merged 12 commits into main on May 5, 2023
Conversation

WoosukKwon (Collaborator)

This PR replaces FlashAttention with xformers.

Pros:

  • Richer features & higher compatibility. xformers supports attention bias, FP32, head size 256, and old GPUs (such as V100) while FlashAttention does not.
  • xformers provides pre-compiled Python wheels, while FlashAttention compiles all of its CUDA code during installation.
  • Future-proof, as the repository is maintained by many developers from Meta.

Cons:

  • xformers can be slower than FlashAttention for small inputs, because it incurs higher CPU overheads.
  • xformers internally creates a new tensor for the attention output. In our case, this leads to an extra copy overhead, because we concatenate the outputs of the two attention ops (see the sketch after this list).
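
A minimal sketch of the second drawback, assuming xformers' memory_efficient_attention API and hypothetical tensor shapes (illustrative only, not the PR's actual code):

```python
import torch
import xformers.ops as xops

# Hypothetical shapes for illustration only.
batch, seq_len, num_heads, head_size = 2, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_size,
                device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Pre-allocated buffer meant to hold the concatenated outputs of two attention ops.
out_buffer = torch.empty(2, batch, seq_len, num_heads, head_size,
                         device="cuda", dtype=torch.float16)

# xformers allocates and returns a fresh output tensor...
attn_out = xops.memory_efficient_attention(q, k, v)

# ...so writing it into the shared buffer requires an extra copy.
out_buffer[0].copy_(attn_out)
```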

@WoosukKwon requested a review from @zhuohan123 on May 4, 2023 at 10:31
-pip install sentencepiece # Required for LlamaTokenizer.
-pip install ninja # To parallelize the compilation of flash-attn.
-pip install flash-attn # This may take up to 10 mins.
+pip install ninja psutil numpy sentencepiece ray torch transformers xformers
WoosukKwon (Collaborator, Author):

TODO (in the next PR): specify the exact dependencies in setup.py.
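
A rough sketch of what that could look like, using a hypothetical package name and the dependency list from the install command above (the real list and version pins were deferred to the follow-up PR):

```python
from setuptools import setup, find_packages

setup(
    name="cacheflow",  # hypothetical; use the project's actual package name
    packages=find_packages(),
    install_requires=[
        "ninja",
        "psutil",
        "numpy",
        "sentencepiece",
        "ray",
        "torch",
        "transformers",
        "xformers",
    ],
)
```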

@zhisbug (Collaborator) commented May 4, 2023

Is the memory footprint the same as with FlashAttention?

@zhisbug (Collaborator) commented May 5, 2023

I did a test myself and found the memory saving is almost the same.
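
For context, a comparison like this can be sketched with PyTorch's peak-memory counters. This is illustrative only, not necessarily how the test above was run; run_with_xformers and run_with_flash_attn are hypothetical callables:

```python
import torch

def peak_memory_gib(run_generation):
    """Run a generation workload and report peak GPU memory in GiB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    run_generation()  # e.g. a fixed batch of prompts through the server
    return torch.cuda.max_memory_allocated() / 1024**3

# Compare the two backends by running the same workload under each build:
# print(peak_memory_gib(run_with_xformers), peak_memory_gib(run_with_flash_attn))
```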

@WoosukKwon (Collaborator, Author)

It seems the memory usage is comparable to FlashAttention's. @zhuohan123, please review the PR.

@zhuohan123 (Member) left a comment

LGTM! Thanks!

@@ -213,7 +213,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--use-np-cache', action='store_true',
help='save a numpy copy of model weights for faster loading')
parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights')
-# NOTE(woosuk): FlashAttention does not support float32.
+# TODO(woosuk): Support FP32 for debugging.
@zhuohan123 (Member):

Does xformers support FP32?

WoosukKwon (Collaborator, Author):

Yes, it does. It is our own attention kernel that does not support FP32; more precisely, it currently does not support some block sizes when FP32 is used. I will fix this in the future.
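
A quick way to confirm xformers' FP32 support (a sketch, not code from this PR):

```python
import torch
import xformers.ops as xops

# FP32 query/key/value; xformers dispatches to an op that supports float32,
# as the PR description notes.
q = torch.randn(1, 16, 4, 64, device="cuda", dtype=torch.float32)
k, v = torch.randn_like(q), torch.randn_like(q)

out = xops.memory_efficient_attention(q, k, v)
print(out.dtype)  # torch.float32
```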

@WoosukKwon mentioned this pull request May 5, 2023
@WoosukKwon merged commit c9d5b6d into main May 5, 2023
@WoosukKwon deleted the xformers branch May 5, 2023 09:01
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
yukavio pushed a commit to yukavio/vllm that referenced this pull request Jul 3, 2024
SUMMARY:
for Apache License 2.0, Section 4(b): "You must cause any modified files to carry prominent
notices stating that You changed the files"
https://www.apache.org/licenses/LICENSE-2.0

TEST PLAN:
GHA
dllehr-amd pushed a commit to dllehr-amd/vllm that referenced this pull request Jul 22, 2024
* Enabling some basic tests for ROCm 6.2

Use strict xfail for ROCm 6.2 test repairs

* Use lenient xfail instead

---------

Co-authored-by: Alexei V. Ivanov <[email protected]>
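
The commit message above distinguishes strict from lenient xfail; in pytest terms the difference looks roughly like this (test names and reasons are illustrative):

```python
import pytest

@pytest.mark.xfail(strict=True, reason="known ROCm 6.2 failure")
def test_strict_case():
    # An unexpected pass fails the test run.
    ...

@pytest.mark.xfail(strict=False, reason="flaky on ROCm 6.2")
def test_lenient_case():
    # An unexpected pass is only reported as XPASS.
    ...
```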