perf: Dense and sparse customizable flashattention-3 template (#667)
This PR adds a FlashAttention-3 template to improve prefill performance on
Hopper. Block/vector-sparse support from earlier FlashInfer versions is ported
to the FA-3 template with a CustomStride abstraction in CuTe, so that we can
support PageAttention with any page size. The programming interface of the FA-3
template is exactly the same as our previous FA-2 template; we add a `backend`
argument that lets users select the backend themselves.
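
As a rough, hedged sketch of how the new `backend` argument might be used (not
taken from this diff): the wrapper name comes from the docs linked below, while
the `"fa3"` value and the exact `plan()`/`run()` signatures are assumptions to
be checked against the current Python API.

```python
import torch
import flashinfer

# Illustrative shapes only; two variable-length requests of 100 and 37 tokens.
num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
total_tokens = 137
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# Select the Hopper FA-3 backend through the new `backend` argument
# ("fa2" would keep the previous template; the literal values are assumed).
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace, backend="fa3")

qo_indptr = torch.tensor([0, 100, 137], dtype=torch.int32, device="cuda")
kv_indptr = qo_indptr.clone()
wrapper.plan(qo_indptr, kv_indptr, num_qo_heads, num_kv_heads, head_dim, causal=True)

q = torch.randn(total_tokens, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
k = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
v = torch.randn(total_tokens, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
o = wrapper.run(q, k, v)
```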

One functionality still missing from the current template is custom masking;
we plan to support it through JIT compilation rather than AOT.

H100 reference performance on variable-length dense and sparse attention
kernels (exposed through the
[BatchPrefillWithRaggedKVCacheWrapper](https://docs.flashinfer.ai/api/prefill.html#flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper)
and
[BatchDecodeWithPagedKVCacheWrapper](https://docs.flashinfer.ai/api/decode.html#flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper)
APIs, respectively); for the sparse attention workload, we use PageAttention
with `page_size=1`:

![image](https://github.com/user-attachments/assets/7e989f8c-8b0f-4c99-ad11-6102c2dc5090)

FlashInfer's vector-sparse (`page_size=1`) attention implementation achieves
about 90% of the performance of its dense equivalent; reference benchmark:
https://github.com/flashinfer-ai/flashinfer/blob/04ee9bceb5ab0a66c612c1abaee8fa28de2b2349/benchmarks/bench_hopper_attention
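
To make the `page_size=1` setting concrete, here is a hedged sketch (not part
of this diff) of how a vector-sparse pattern reduces to PageAttention whose
page table is simply a per-token index list; the wrapper comes from the docs
linked above, but the `plan()` argument order is an assumption to verify
against the current API.

```python
import torch
import flashinfer

# With page_size=1, each KV "page" holds exactly one token, so the page table
# (kv_indices) degenerates into a gather list of selected token positions.
page_size = 1
num_qo_heads, num_kv_heads, head_dim = 32, 8, 128

# Request 0 attends to tokens {0, 5, 9}; request 1 attends to tokens {2, 3}.
kv_indices = torch.tensor([0, 5, 9, 2, 3], dtype=torch.int32, device="cuda")
kv_indptr = torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda")
kv_last_page_len = torch.ones(2, dtype=torch.int32, device="cuda")  # always 1 here

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")
wrapper.plan(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
)
```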

JIT support is left to the next PR because this PR is already heavy. For fp8
support, we will incorporate the SageAttention-2 algorithm for numerical
stability; that work is left to v0.2.1.

There is currently some discrepancy between the attention-variant interfaces of
our FA-2 and FA-3 templates; we will close the gap gradually.

cc @merrymercy @zhyncs @youkaichao @WoosukKwon @jason-huang03
yzh119 authored Dec 16, 2024
1 parent d9d8eb1 commit 51236c9
Showing 52 changed files with 5,730 additions and 132 deletions.
22 changes: 22 additions & 0 deletions LICENSE
@@ -199,3 +199,25 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

-------------------------------------------------------------------------------------------------
Some of the code in this project is adapted from other open-source projects with different
licenses. This product also bundles some third-party components under other open-source licenses.
This section summarizes those components and their licenses.
See licenses/ for text of these licenses.

BSD 3-Clause License
--------------------

include/flashinfer/attention/hopper/epilogue.cuh
include/flashinfer/attention/hopper/mainloop.cuh
include/flashinfer/attention/hopper/kernel_traits.cuh
include/flashinfer/attention/hopper/named_barrier.cuh
include/flashinfer/attention/hopper/tile_scheduler.cuh
include/flashinfer/attention/hopper/utils.cuh

BSD 3-Clause "New" License
--------------------------

3rdparty/cutlass
include/flashinfer/attention/hopper/block_sparse_gather.cuh
96 changes: 96 additions & 0 deletions aot_build_utils/generate_batch_paged_prefill_sm90_inst.py
@@ -0,0 +1,96 @@
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import (
    dtype_literal,
    idtype_literal,
    mask_mode_literal,
    pos_encoding_mode_literal,
)


def get_cu_file_str(
    head_dim,
    pos_encoding_mode,
    allow_fp16_qk_reduction,
    mask_mode,
    dtype_q,
    dtype_kv,
    dtype_out,
    idtype,
):
    def get_insts(attention_variant):
        return "\n".join(
            [
                """template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithPagedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
""".format(
                    head_dim=head_dim,
                    pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
                    allow_fp16_qk_reduction=allow_fp16_qk_reduction,
                    mask_mode=mask_mode_literal[int(mask_mode)],
                    attention_variant=attention_variant,
                )
            ]
        )

    dtype_q = dtype_literal[dtype_q]
    dtype_kv = dtype_literal[dtype_kv]
    dtype_out = dtype_literal[dtype_out]
    idtype = idtype_literal[idtype]

    content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>
namespace flashinfer {{
using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;
using Params = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;
{get_insts("LogitsSoftCap")}
{get_insts("StandardAttention")}
}}"""
    return content


if __name__ == "__main__":
    pattern = (
        r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
        r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
    )
    compiled_pattern = re.compile(pattern)
    path = Path(sys.argv[1])
    fname = path.name
    match = compiled_pattern.match(fname)

    with open(path, "w") as f:
        f.write(get_cu_file_str(*match.groups()))
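
For reference, the generator is driven entirely by the output filename: the AOT
build passes a path whose name encodes the instantiation parameters, and the
eight regex groups above feed `get_cu_file_str` positionally. A small
self-contained sketch (the filename below is hypothetical, chosen only to match
the pattern in this script):

```python
import re

pattern = re.compile(
    r"batch_paged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
    r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)"
    r"_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
)
fname = (
    "batch_paged_prefill_head_128_posenc_0_fp16qkred_false_mask_1_"
    "dtypeq_f16_dtypekv_f16_dtypeout_f16_idtype_i32_sm90.cu"
)
# Groups map to: head_dim, pos_encoding_mode, allow_fp16_qk_reduction,
# mask_mode, dtype_q, dtype_kv, dtype_out, idtype
print(pattern.match(fname).groups())
# ('128', '0', 'false', '1', 'f16', 'f16', 'f16', 'i32')
```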
97 changes: 97 additions & 0 deletions aot_build_utils/generate_batch_ragged_prefill_sm90_inst.py
@@ -0,0 +1,97 @@
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import (
    dtype_literal,
    idtype_literal,
    mask_mode_literal,
    pos_encoding_mode_literal,
)


def get_cu_file_str(
    head_dim,
    pos_encoding_mode,
    allow_fp16_qk_reduction,
    mask_mode,
    dtype_q,
    dtype_kv,
    dtype_out,
):

    def get_insts(attention_variant):
        return "\n".join(
            [
                """template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, {attention_variant}>(
Params& params,
cudaStream_t stream);
template cudaError_t BatchPrefillWithRaggedKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, {attention_variant}>(
Params& params,
cudaStream_t stream);
""".format(
                    head_dim=head_dim,
                    pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
                    allow_fp16_qk_reduction=allow_fp16_qk_reduction,
                    mask_mode=mask_mode_literal[int(mask_mode)],
                    attention_variant=attention_variant,
                )
            ]
        )

    dtype_q = dtype_literal[dtype_q]
    dtype_kv = dtype_literal[dtype_kv]
    dtype_out = dtype_literal[dtype_out]
    idtype = idtype_literal[idtype]

    content = f"""#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>
namespace flashinfer {{
using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;
using Params = BatchPrefillRaggedParams<DTypeQ, DTypeKV, DTypeO, {idtype}>;
{get_insts("LogitsSoftCap")}
{get_insts("StandardAttention")}
}}
"""
    return content


if __name__ == "__main__":
    pattern = (
        r"batch_ragged_prefill_head_([0-9]+)_posenc_([0-9]+)_"
        r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)_sm90\.cu"
    )
    compiled_pattern = re.compile(pattern)
    path = Path(sys.argv[1])
    fname = path.name
    match = compiled_pattern.match(fname)
    with open(path, "w") as f:
        f.write(get_cu_file_str(*match.groups()))
85 changes: 85 additions & 0 deletions aot_build_utils/generate_single_prefill_sm90_inst.py
@@ -0,0 +1,85 @@
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import re
import sys
from pathlib import Path

from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal


def get_cu_file_str(
    head_dim,
    pos_encoding_mode,
    allow_fp16_qk_reduction,
    mask_mode,
    dtype_q,
    dtype_kv,
    dtype_out,
):
    content = """#include <flashinfer/attention/hopper/prefill_sm90.cuh>
#include <flashinfer/attention/hopper/variants.cuh>
#include <flashinfer/cutlass_utils.cuh>
namespace flashinfer {{
using DTypeQ = cutlass_dtype_t<{dtype_q}>;
using DTypeKV = cutlass_dtype_t<{dtype_kv}>;
using DTypeO = cutlass_dtype_t<{dtype_out}>;
using Params = SinglePrefillParams<DTypeQ, DTypeKV, DTypeO>;
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, LogitsSoftCap>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, LogitsSoftCap>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/true, StandardAttention>(
Params& params,
cudaStream_t stream);
template cudaError_t SinglePrefillWithKVCacheDispatched<{head_dim}, {mask_mode}, /*USE_SWA=*/false, StandardAttention>(
Params& params,
cudaStream_t stream);
}}
""".format(
        head_dim=head_dim,
        pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)],
        allow_fp16_qk_reduction=allow_fp16_qk_reduction,
        mask_mode=mask_mode_literal[int(mask_mode)],
        dtype_q=dtype_literal[dtype_q],
        dtype_kv=dtype_literal[dtype_kv],
        dtype_out=dtype_literal[dtype_out],
        use_custom_mask="true" if int(mask_mode) == 2 else "false",
    )
    return content


if __name__ == "__main__":
    pattern = (
        r"single_prefill_head_([0-9]+)_posenc_([0-9]+)_"
        r"fp16qkred_([a-z]+)_mask_([0-9]+)_dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_sm90\.cu"
    )

    compiled_pattern = re.compile(pattern)
    path = Path(sys.argv[1])
    fname = path.name
    match = compiled_pattern.match(fname)
    with open(path, "w") as f:
        f.write(get_cu_file_str(*match.groups()))
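
Unlike the batch generators, the single-prefill filenames carry no `idtype`
field, and the four instantiations (SWA on/off × LogitsSoftCap/StandardAttention)
are written out explicitly rather than through a helper. As a hedged sketch,
the generator can also be driven directly; the literal-map keys ("f16") and the
meaning of mask mode `1` are assumptions about `literal_map`, not taken from
this diff.

```python
# Run from the repository root so that `aot_build_utils` is importable.
from aot_build_utils.generate_single_prefill_sm90_inst import get_cu_file_str

cu_source = get_cu_file_str(
    head_dim="128",
    pos_encoding_mode="0",
    allow_fp16_qk_reduction="false",
    mask_mode="1",  # assumed to mean causal masking in mask_mode_literal
    dtype_q="f16",
    dtype_kv="f16",
    dtype_out="f16",
)
# One translation unit with four SinglePrefillWithKVCacheDispatched instantiations.
print(cu_source)
```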
