feat: append attention kernels for fp8 kv-cache #420

Merged: 19 commits, Aug 6, 2024


yzh119 (Collaborator) commented Aug 5, 2024

This implementation does not rely on fp8 tensor cores; it uses fp16 tensor cores instead (so sm_80 architectures can also use it), and the fp8 kv-cache is dequantized to fp16 on the fly.
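
To make "dequantized on the fly" concrete, here is a minimal scalar sketch in Python. It assumes the standard OCP FP8 encodings (e4m3: exponent bias 7, no infinities, only S.1111.111 is NaN; e5m2: bias 15 with IEEE-style inf/NaN) and no per-tensor scale factor; the function names are hypothetical, Python floats stand in for the fp16 values fed to the tensor cores, and this is not FlashInfer's kernel code.

```python
import math


def decode_fp8(byte, exp_bits, man_bits, bias, ieee_specials):
    """Decode one fp8 byte laid out as sign | exponent | mantissa into a float."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    exp_all_ones = (1 << exp_bits) - 1
    if ieee_specials and exp == exp_all_ones:
        # e5m2 follows IEEE-754: all-ones exponent is inf (mantissa 0) or NaN
        return sign * math.inf if man == 0 else math.nan
    if not ieee_specials and exp == exp_all_ones and man == (1 << man_bits) - 1:
        # e4m3 has no infinities; only S.1111.111 encodes NaN
        return math.nan
    if exp == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias
        return sign * (man / (1 << man_bits)) * 2.0 ** (1 - bias)
    # Normal number: implicit leading 1
    return sign * (1.0 + man / (1 << man_bits)) * 2.0 ** (exp - bias)


def decode_e4m3(byte):
    return decode_fp8(byte, exp_bits=4, man_bits=3, bias=7, ieee_specials=False)


def decode_e5m2(byte):
    return decode_fp8(byte, exp_bits=5, man_bits=2, bias=15, ieee_specials=True)


def dot_with_fp8_keys(q, k_bytes, decode=decode_e4m3):
    # Keys stay in fp8 storage; each element is widened right before it is
    # used, instead of materializing a full-precision copy of the KV cache.
    return sum(qi * decode(kb) for qi, kb in zip(q, k_bytes))
```

For example, `dot_with_fp8_keys([1.0, 2.0], [0x38, 0x40])` decodes the e4m3 keys to 1.0 and 2.0 and returns 5.0. The point of the design is that memory traffic stays at one byte per KV element while the arithmetic runs in a wider type.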

Append attention kernels for sm_89 and sm_90 that use native fp8 tensor cores will be available in later PRs.

Review comment on `CMakeLists.txt` (outdated):
```
@@ -91,6 +91,7 @@ set (IDTYPES "i32")
if(FLASHINFER_ENABLE_FP8)
list(APPEND DECODE_DTYPES "e4m3" "e5m2")
list(APPEND DECODE_FP8_DTYPES "e4m3" "e5m2")
list(APPEND PREFILL_FP8_DTYPES "e4m3")
```
Collaborator:

Do we also support e5m2?

Contributor:

e5m2 support would be great

Collaborator (author):

e5m2 has been added. Note, however, that this CMake setting only affects the C++ tests; it has nothing to do with the Python wheels.
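
For context on the e5m2 question: e5m2 is the cheaper of the two formats to widen to fp16, because (per the standard IEEE fp16 and OCP FP8 layouts, not something stated in this thread) its sign and exponent fields match fp16 exactly, including the bias of 15. A minimal sketch assuming those layouts; the function names are hypothetical and this is not FlashInfer's code.

```python
import struct


def e5m2_to_fp16_bits(byte):
    # e5m2 = 1 sign, 5 exponent, 2 mantissa bits, exponent bias 15.
    # IEEE fp16 = 1 sign, 5 exponent, 10 mantissa bits, exponent bias 15.
    # Sign and exponent fields line up exactly, so the fp8 byte is simply the
    # high byte of the fp16 bit pattern (the extra mantissa bits are zero).
    return (byte & 0xFF) << 8


def e5m2_to_float(byte):
    # Reinterpret the 16-bit pattern as a half-precision float ("e" format).
    bits = e5m2_to_fp16_bits(byte)
    return struct.unpack("<e", struct.pack("<H", bits))[0]
```

By contrast, e4m3 (bias 7, 3 mantissa bits) needs its exponent re-biased on the way to fp16, which is what makes its conversion more involved.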

yzh119 (Collaborator, author) commented Aug 6, 2024

@Yard1 @comaniac @cassiewilliam
To keep the binary size reasonable (<2 GB), I only kept fp8 support for `BatchPrefillWithPagedKVCacheWrapper`, which should be general enough for append attention.

The functionality tests have passed. Feel free to try it and report any possible issues.

yzh119 merged commit 906c2f5 into main on Aug 6, 2024
yzh119 added a commit that referenced this pull request Aug 6, 2024
The swizzling mode names in #420 are wrong; this PR aligns them with the PTX documentation:
32B -> 64B
64B -> 128B
yzh119 added a commit that referenced this pull request Aug 9, 2024
🤖 I have created a release *beep* *boop*
---
## [0.1.4](v0.1.3...v0.1.4) (2024-08-09)


### Features

* append attention kernels for fp8 kv-cache ([#420](#420)) ([906c2f5](906c2f5))
* support min_p sampling ([#422](#422)) ([d52f2da](d52f2da))
* deterministic sampling ([#417](#417)) ([0dd801d](0dd801d))
* more sampling operator options ([#431](#431)) ([68df9c4](68df9c4))
* support fused add rmsnorm ([#419](#419)) ([b781513](b781513))
* support fused silu mul ([#427](#427)) ([ea0ba9a](ea0ba9a))

### Bug Fixes

* fix dispatch fp16 type when enable fp8 ([#430](#430)) ([daa5566](daa5566))
* improve numerical stability of sampling kernels ([#429](#429)) ([898d8ea](898d8ea))

### Other improvements
* break up `_kernels` into multiple modules ([#428](#428)) ([8e482d9](8e482d9))

### Acknowledgement

We thank the community for their contributions and feedback:
[@comaniac](https://github.com/comaniac),
[@esmeetu](https://github.com/esmeetu),
[@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU),
[@peng1999](https://github.com/peng1999),
[@xslingcn](https://github.com/xslingcn),
[@Yard1](https://github.com/Yard1),
[@zhyncs](https://github.com/zhyncs).

---
This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye
yzh119 deleted the fp8-with-fp16-tc branch on August 10, 2024, 18:38
yzh119 added a commit that referenced this pull request Aug 11, 2024
A hardware instruction for fast fp8->fp16 conversion is not available on sm_80 and sm_89, which makes #420 slow on these architectures.

This PR uses Marlin's fast fp8->fp16x4 conversion algorithm (copied from the vLLM project) to accelerate these cases.

Co-authored-by: Antoni Baum
Co-authored-by: Cody Yu
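
The idea behind a Marlin-style fp8->fp16x4 conversion can be emulated in a few lines. This is a minimal Python sketch of the bit trick as we understand it (mask and shift each e4m3 byte into fp16 position, then re-bias with one multiply by 2^(15 - 7)); it is not the FlashInfer or vLLM source, the function names are hypothetical, the lane ordering in the real kernels may differ, and Python floats stand in for the packed fp16 registers a CUDA version would multiply.

```python
import struct


def fp16_bits_to_float(bits):
    # Reinterpret a 16-bit pattern as an IEEE half-precision value.
    return struct.unpack("<e", struct.pack("<H", bits & 0xFFFF))[0]


def e4m3x4_to_fp16x4(word):
    """Widen four e4m3 bytes packed little-endian in a 32-bit word."""
    sign_mask = 0x80008000  # sign bit of the high byte in each 16-bit lane
    body_mask = 0x7F007F00  # exponent + mantissa of the high byte in each lane
    # Lanes of `hi` take bytes 1 and 3: shifting the 7-bit exponent|mantissa
    # right by 1 drops the 4-bit e4m3 exponent into the low 4 bits of fp16's
    # 5-bit exponent field and the 3 mantissa bits into the top of its mantissa.
    hi = (word & sign_mask) | ((word & body_mask) >> 1)
    # Shift bytes 0 and 2 into the high byte of each lane and repeat.
    shifted = (word << 8) & 0xFFFFFFFF
    lo = (shifted & sign_mask) | ((shifted & body_mask) >> 1)
    # The fp16 patterns are read with bias 15 instead of e4m3's bias 7, so
    # every value is 2^-8 too small; one multiply by 2^(15 - 7) fixes it
    # (e4m3 subnormals land on fp16 subnormals and are fixed the same way).
    scale = 2.0 ** (15 - 7)
    return [
        fp16_bits_to_float(lo) * scale,        # byte 0
        fp16_bits_to_float(hi) * scale,        # byte 1
        fp16_bits_to_float(lo >> 16) * scale,  # byte 2
        fp16_bits_to_float(hi >> 16) * scale,  # byte 3
    ]
```

For example, packing the e4m3 bytes `0x38, 0x40, 0xB8, 0x7E` little-endian into one word and converting it yields `[1.0, 2.0, -1.0, 448.0]`. The whole conversion is masks, shifts, and a re-bias multiply with no per-element branching; one side effect is that the e4m3 NaN pattern (S.1111.111) is not special-cased and comes out as ±480.
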
zhyncs pushed a commit that referenced this pull request Aug 14, 2024
4 participants