-
Notifications
You must be signed in to change notification settings - Fork 198
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
perf: Dense and sparse customizable flashattention-3 template (#667)
This PR adds flashattention-3 template for improving prefill performance on hopper. Block/Vector-sparse support in FlashInfer early version are ported to FA-3 template with CustomStride abstraction in CuTE so that we can support PageAttention with any page size. The programming interface for FA3 template is exactly the same as our previous FA2 template while we add an argument `backend` to allow user to select their own backend. Functionalities that are missing in current template include custom mask and we plan to support it using JIT instead of AOT. H100 Reference performance on variable-length dense and sparse attention kernels (exposed through [BatchPrefillWithRaggedKVCacheWrapper](https://docs.flashinfer.ai/api/prefill.html#flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper) and [BatchDecodeWithPagedKVCacheWrapper](https://docs.flashinfer.ai/api/decode.html#flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper) API correspondingly, for sparse attention workload, we use PageAttention with `page_size=1`. ![image](https://github.com/user-attachments/assets/7e989f8c-8b0f-4c99-ad11-6102c2dc5090) FlashInfer's vector sparse (page_size=1) attention implementation can get 90% percent of the dense equivalent, reference benchmark: https://github.com/flashinfer-ai/flashinfer/blob/04ee9bceb5ab0a66c612c1abaee8fa28de2b2349/benchmarks/bench_hopper_attention . JIT support is left to the next PR because this PR is already heavy. For fp8 support, we will incorporate SageAttention-2 algorithm for numerical stability, and it's left to v0.2.1. Currently there is some discrepancy in attention variant interface for our FA2 and FA3 template and we will gradually fix the gap. cc @merrymercy @zhyncs @youkaichao @WoosukKwon @jason-huang03
- Loading branch information
Showing
52 changed files
with
5,730 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
""" | ||
Copyright (c) 2024 by FlashInfer team. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import re | ||
import sys | ||
from pathlib import Path | ||
|
||
from .literal_map import ( | ||
dtype_literal, | ||
idtype_literal, | ||
mask_mode_literal, | ||
pos_encoding_mode_literal, | ||
) | ||
|
||
|
||
def get_cu_file_str( | ||
head_dim, | ||
pos_encoding_mode, | ||
allow_fp16_qk_reduction, | ||
mask_mode, | ||
dtype_q, | ||
dtype_kv, | ||
dtype_out, | ||
idtype, | ||
): | ||
def get_insts(attention_variant): | ||
return "\n".join( | ||
[ | ||
"""template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>( | ||
Params& params, | ||
cudaStream_t stream); | ||
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>( | ||
Params& params, | ||
cudaStream_t stream); | ||
""".format( | ||
head_dim=head_dim, | ||
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], | ||
allow_fp16_qk_reduction=allow_fp16_qk_reduction, | ||
mask_mode=mask_mode_literal[int(mask_mode)], | ||
attention_variant=attention_variant, | ||
) | ||
] | ||
) | ||
|
||
dtype_q = dtype_literal[dtype_q] | ||
dtype_kv = dtype_literal[dtype_kv] | ||
dtype_out = dtype_literal[dtype_out] | ||
idtype = idtype_literal[idtype] | ||
|
||
content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh> | ||
#include <flashinfer/attention/hopper/variants.cuh> | ||
#include <flashinfer/cutlass_utils.cuh> | ||
namespace flashinfer {{ | ||
using DTypeQ = cutlass_dtype_t<{dtype_q}>; | ||
using DTypeKV = cutlass_dtype_t<{dtype_kv}>; | ||
using DTypeO = cutlass_dtype_t<{dtype_out}>; | ||
using Params = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>; | ||
{get_insts("LogitsSoftCap")} | ||
{get_insts("StandardAttention")} | ||
}}""" | ||
return content | ||
|
||
|
||
if __name__ == "__main__": | ||
pattern = ( | ||
r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_" | ||
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" | ||
) | ||
compiled_pattern = re.compile(pattern) | ||
path = Path(sys.argv[1]) | ||
fname = path.name | ||
match = compiled_pattern.match(fname) | ||
|
||
with open(path, "w") as f: | ||
f.write(get_cu_file_str(*match.groups())) |
97 changes: 97 additions & 0 deletions
97
aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
""" | ||
Copyright (c) 2024 by FlashInfer team. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import re | ||
import sys | ||
from pathlib import Path | ||
|
||
from .literal_map import ( | ||
dtype_literal, | ||
idtype_literal, | ||
mask_mode_literal, | ||
pos_encoding_mode_literal, | ||
) | ||
|
||
|
||
def get_cu_file_str( | ||
head_dim, | ||
pos_encoding_mode, | ||
allow_fp16_qk_reduction, | ||
mask_mode, | ||
dtype_q, | ||
dtype_kv, | ||
dtype_out, | ||
idtype, | ||
): | ||
|
||
def get_insts(attention_variant): | ||
return "\n".join( | ||
[ | ||
"""template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>( | ||
Params& params, | ||
cudaStream_t stream); | ||
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>( | ||
Params& params, | ||
cudaStream_t stream); | ||
""".format( | ||
head_dim=head_dim, | ||
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], | ||
allow_fp16_qk_reduction=allow_fp16_qk_reduction, | ||
mask_mode=mask_mode_literal[int(mask_mode)], | ||
attention_variant=attention_variant, | ||
) | ||
] | ||
) | ||
|
||
dtype_q = dtype_literal[dtype_q] | ||
dtype_kv = dtype_literal[dtype_kv] | ||
dtype_out = dtype_literal[dtype_out] | ||
idtype = idtype_literal[idtype] | ||
|
||
content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh> | ||
#include <flashinfer/attention/hopper/variants.cuh> | ||
#include <flashinfer/cutlass_utils.cuh> | ||
namespace flashinfer {{ | ||
using DTypeQ = cutlass_dtype_t<{dtype_q}>; | ||
using DTypeKV = cutlass_dtype_t<{dtype_kv}>; | ||
using DTypeO = cutlass_dtype_t<{dtype_out}>; | ||
using Params = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>; | ||
{get_insts("LogitsSoftCap")} | ||
{get_insts("StandardAttention")} | ||
}} | ||
""" | ||
return content | ||
|
||
|
||
if __name__ == "__main__": | ||
pattern = ( | ||
r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_" | ||
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu" | ||
) | ||
compiled_pattern = re.compile(pattern) | ||
path = Path(sys.argv[1]) | ||
fname = path.name | ||
match = compiled_pattern.match(fname) | ||
with open(path, "w") as f: | ||
f.write(get_cu_file_str(*match.groups())) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
""" | ||
Copyright (c) 2024 by FlashInfer team. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
|
||
import re | ||
import sys | ||
from pathlib import Path | ||
|
||
from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal | ||
|
||
|
||
def get_cu_file_str( | ||
head_dim, | ||
pos_encoding_mode, | ||
allow_fp16_qk_reduction, | ||
mask_mode, | ||
dtype_q, | ||
dtype_kv, | ||
dtype_out, | ||
): | ||
content = """#include <flashinfer/attention/hopper/prefill_sm90.cuh> | ||
#include <flashinfer/attention/hopper/variants.cuh> | ||
#include <flashinfer/cutlass_utils.cuh> | ||
namespace flashinfer {{ | ||
using DTypeQ = cutlass_dtype_t<{dtype_q}>; | ||
using DTypeKV = cutlass_dtype_t<{dtype_kv}>; | ||
using DTypeO = cutlass_dtype_t<{dtype_out}>; | ||
using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>; | ||
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>( | ||
Params& params, | ||
cudaStream_t stream); | ||
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>( | ||
Params& params, | ||
cudaStream_t stream); | ||
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>( | ||
Params& params, | ||
cudaStream_t stream); | ||
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>( | ||
Params& params, | ||
cudaStream_t stream); | ||
}} | ||
""".format( | ||
head_dim=head_dim, | ||
pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], | ||
allow_fp16_qk_reduction=allow_fp16_qk_reduction, | ||
mask_mode=mask_mode_literal[int(mask_mode)], | ||
dtype_q=dtype_literal[dtype_q], | ||
dtype_kv=dtype_literal[dtype_kv], | ||
dtype_out=dtype_literal[dtype_out], | ||
use_custom_mask="true" if int(mask_mode) == 2 else "false", | ||
) | ||
return content | ||
|
||
|
||
if __name__ == "__main__": | ||
pattern = ( | ||
r"single_prefill_head_([0-9]+)_posenc_([0-9]+)_" | ||
r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_sm90\.cu" | ||
) | ||
|
||
compiled_pattern = re.compile(pattern) | ||
path = Path(sys.argv[1]) | ||
fname = path.name | ||
match = compiled_pattern.match(fname) | ||
with open(path, "w") as f: | ||
f.write(get_cu_file_str(*match.groups())) |
Oops, something went wrong.