feat: append attention kernels for fp8 kv-cache #420
CMakeLists.txt
Outdated
@@ -91,6 +91,7 @@ set (IDTYPES "i32")
if(FLASHINFER_ENABLE_FP8)
list(APPEND DECODE_DTYPES "e4m3" "e5m2")
list(APPEND DECODE_FP8_DTYPES "e4m3" "e5m2")
list(APPEND PREFILL_FP8_DTYPES "e4m3")
Do we also support e5m2?
e5m2 support would be great
e5m2 added. However, this argument only affects the C++ tests and has nothing to do with the Python wheels.
@Yard1 @comaniac @cassiewilliam The functionality tests have passed. Feel free to try it and report any possible issues.
The swizzling mode names in #420 are wrong; this PR aligns them with the PTX documentation: 32B -> 64B, 64B -> 128B.
🤖 I have created a release *beep* *boop*

---

## [0.1.4](v0.1.3...v0.1.4) (2024-08-09)

### Features

* append attention kernels for fp8 kv-cache ([#420](#420)) ([906c2f5](906c2f5))
* support min_p sampling ([#422](#422)) ([d52f2da](d52f2da))
* deterministic sampling ([#417](#417)) ([0dd801d](0dd801d))
* more sampling operator options ([#431](#431)) ([68df9c4](68df9c4))
* support fused add rmsnorm ([#419](#419)) ([b781513](b781513))
* support fused silu mul ([#427](#427)) ([ea0ba9a](ea0ba9a))

### Bug Fixes

* fix dispatch fp16 type when enable fp8 ([#430](#430)) ([daa5566](daa5566))
* improve numerical stability of sampling kernels ([#429](#429)) ([898d8ea](898d8ea))

### Other improvements

* break up `_kernels` into multiple modules ([#428](#428)) ([8e482d9](8e482d9))

### Acknowledgement

We thank contributions and feedbacks from the community: [@comaniac](https://github.com/comaniac), [@esmeetu](https://github.com/esmeetu), [@LiuXiaoxuanPKU](https://github.com/LiuXiaoxuanPKU), [@peng1999](https://github.com/peng1999), [@xslingcn](https://github.com/xslingcn), [@Yard1](https://github.com/Yard1), [@zhyncs](https://github.com/zhyncs).

---

This PR was generated with [Release Please](https://github.com/googleapis/release-please). See [documentation](https://github.com/googleapis/release-please#release-please).

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Zihao Ye <[email protected]>
The hardware fp8->fp16 fast conversion instruction is not available on sm_80 and sm_89, which makes #420 slow on these architectures. This PR uses Marlin's fast fp8->fp16x4 conversion algorithm (copied from the vLLM project) to accelerate such cases. Co-authored-by: Antoni Baum <[email protected]> Co-authored-by: Cody Yu <[email protected]>
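For readers unfamiliar with this kind of conversion, here is a rough Python sketch of the general shift-and-rescale idea behind software fp8->fp16 conversion for e4m3. It is an illustration under assumptions, not the Marlin/vLLM code: the function names are hypothetical, it converts one value at a time rather than four packed values, and it ignores the e4m3 NaN encoding.

```python
import struct


def e4m3_reference(byte):
    """Decode an FP8 E4M3 byte field by field (1 sign, 4 exponent, 3 mantissa bits, bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0:
        # Subnormal: no implicit leading one, fixed exponent of 1 - bias.
        return sign * (man / 8.0) * 2.0 ** -6
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def e4m3_shift_trick(byte):
    """Hypothetical sketch: build fp16 bits with a shift, then fix the exponent bias."""
    word = byte << 8  # place the fp8 byte in the upper byte of a 16-bit word
    # Keep the sign bit; shift exponent and mantissa right by one so the
    # 4-bit exponent lands in the low bits of fp16's 5-bit exponent field.
    bits = (word & 0x8000) | ((word & 0x7F00) >> 1)
    half = struct.unpack("<e", struct.pack("<H", bits))[0]
    # fp16 bias is 15 and e4m3 bias is 7, so rescale by 2**(15 - 7).
    return half * 256.0


if __name__ == "__main__":
    for b in (0x38, 0x40, 0x01, 0xB8):
        print(hex(b), e4m3_reference(b), e4m3_shift_trick(b))
```

The point of the trick is that the shift and mask can be applied to several packed bytes in one integer operation, with a single multiply to correct the bias. The NaN encodings (0x7F and 0xFF) do not survive this path and would need separate handling.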
This implementation does not rely on fp8 tensor cores; it uses fp16 tensor cores instead (so sm_80 architectures can also use it), and the fp8 kv-cache is dequantized on-the-fly.
sm_89 and sm_90 append attention kernels that use native fp8 tensor cores will be available in later PRs.
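To make "dequantized on-the-fly" concrete, here is a minimal pure-Python sketch of single-query attention over a KV cache stored as e4m3 bytes. It only illustrates the data flow (decode each cached element just before it is used, then do the math in higher precision). The names `decode_e4m3` and `attend`, and the `k_scale`/`v_scale` parameters, are assumptions made for this example and are not the FlashInfer API.

```python
import math


def decode_e4m3(byte):
    """Decode one FP8 E4M3 byte (bias 7); the NaN encoding is not handled here."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp, man = (byte >> 3) & 0xF, byte & 0x7
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def attend(q, k_cache, v_cache, k_scale=1.0, v_scale=1.0):
    """Single-query attention; k_cache and v_cache hold rows of e4m3 bytes.

    k_scale and v_scale are hypothetical dequantization scales.
    """
    scores = []
    for k_row in k_cache:
        # Dequantize the key row only at the moment it is consumed.
        k = [decode_e4m3(b) * k_scale for b in k_row]
        scores.append(sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(len(q)))
    # Numerically stable softmax over the scores.
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    out = [0.0] * len(v_cache[0])
    for w, v_row in zip(weights, v_cache):
        for j, b in enumerate(v_row):
            # Dequantize each value element as it is accumulated.
            out[j] += (w / denom) * decode_e4m3(b) * v_scale
    return out


if __name__ == "__main__":
    k_cache = [[0x38, 0x00], [0x00, 0x38]]  # keys (1, 0) and (0, 1)
    v_cache = [[0x40, 0x00], [0x00, 0x40]]  # values (2, 0) and (0, 2)
    print(attend([0.0, 0.0], k_cache, v_cache))
```

The cache stays in 8-bit form in memory, and the full-precision copy of each element exists only transiently inside the loop. In a real kernel the decoded values would feed fp16 tensor-core instructions rather than Python floats.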