Fa cmake #29
base: main
Changes from 34 commits
CMakeLists.txt
@@ -1,12 +1,19 @@
cmake_minimum_required(VERSION 3.9 FATAL_ERROR)
project(flash-attention LANGUAGES CXX CUDA)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

find_package(Git QUIET REQUIRED)

execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive
                WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
                RESULT_VARIABLE GIT_SUBMOD_RESULT)

# cmake -DWITH_ADVANCED=ON
if (WITH_ADVANCED)
  add_compile_definitions(PADDLE_WITH_ADVANCED)
endif()

add_definitions("-DFLASH_ATTN_WITH_TORCH=0")

set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass)
@@ -83,18 +90,14 @@ target_link_libraries(flashattn flashattn_with_bias_mask)

add_dependencies(flashattn flashattn_with_bias_mask)

set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures")

if (NOT DEFINED NVCC_ARCH_BIN)
  message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.")
endif()

if (NVCC_ARCH_BIN STREQUAL "")
  message(FATAL_ERROR "NVCC_ARCH_BIN is not set.")
endif()
message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}")

STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN})

set(FA_GENCODE_OPTION "SHELL:")

foreach(arch ${FA_NVCC_ARCH_BIN})
  if(${arch} GREATER_EQUAL 80)
    set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}")
@@ -131,7 +134,28 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUD
  "${FA_GENCODE_OPTION}"
  >)

INSTALL(TARGETS flashattn
  LIBRARY DESTINATION "lib")

[Review comment] The name of the generated shared library should preferably differ between the advanced feature being enabled and disabled, so that the Paddle framework can more easily tell the two apart when loading the library.
[Reply] When it is disabled the library is libflashattn.so; when it is enabled it is libflashattn_advanced.so.
||
INSTALL(FILES capi/flash_attn.h DESTINATION "include") | ||
|
||
if (WITH_ADVANCED) | ||
set_target_properties(flashattn PROPERTIES | ||
OUTPUT_NAME libflashattn_advanced | ||
PREFIX "" | ||
) | ||
endif() | ||
|
||
if (WITH_ADVANCED)
[Review comment] Lines 148 and 150 can be removed.
[Reply] Done.
  add_custom_target(build_whl
    COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel
    WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
    DEPENDS flashattn
    COMMENT "Running build wheel"
  )

  add_custom_target(default_target DEPENDS build_whl)

  set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target)
endif()
flash_bwd_launch_template.h
@@ -64,6 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
const bool is_attn_mask = params.attn_mask_ptr != nullptr;
const bool is_deterministic = params.num_splits == 1;
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
#ifdef PADDLE_WITH_ADVANCED
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
  BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {

@@ -82,6 +83,21 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
    });
  });
});
#else
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
  BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
    BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
      auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst, false, false>;
      if (smem_size_dq_dk_dv >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(
            kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
      }
      kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  });
});
#endif

[Review comment] Could this code be simplified to avoid having two branches? For example, for ...

auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
flash_fwd_launch_template.h
@@ -36,6 +36,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
const bool return_softmax = params.p_ptr != nullptr;
const bool is_attn_mask = params.attn_mask_ptr != nullptr;
const bool is_equal_qk = (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask);
#ifdef PADDLE_WITH_ADVANCED
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
  BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
    BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {

@@ -59,6 +60,29 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    });
  });
});
#else
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
[Review comment] Same as above. It would be best to ...
  BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
    BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
[Review comment] Shouldn't this branch be kept?
[Reply] This has been fixed.
      BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] {
        // Will only return softmax if dropout, to reduce compilation time.
        auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout, false, Is_equal_seq_qk>;
        // auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
        if (smem_size >= 48 * 1024) {
          C10_CUDA_CHECK(cudaFuncSetAttribute(
              kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        int ctas_per_sm;
        cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
        // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
        kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
    });
  });
});
#endif
}

template<typename T>
setup.py
@@ -0,0 +1,221 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import ast
import logging
import os
import platform
import re
import shutil
import subprocess
import sys
import warnings
from pathlib import Path

from packaging.version import parse
from setuptools import find_packages, setup
from setuptools.command.install import install as _install
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import paddle
from paddle.utils.cpp_extension.extension_utils import find_cuda_home

version_detail = sys.version_info
python_version = platform.python_version()
version = version_detail[0] + version_detail[1] / 10
env_version = os.getenv("PY_VERSION")

if version < 3.7:
[Review comment] This check is rather crude.
    raise RuntimeError(
        f"Paddle only supports Python version >= 3.7 now,"
        f"you are using Python {python_version}"
    )
elif env_version is None:
    print(f"export PY_VERSION = { python_version }")
    os.environ["PY_VERSION"] = python_version

elif env_version != version:
    warnings.warn(
        f"You set PY_VERSION={env_version}, but"
        f"your current python environment is {version}"
        f"we will use your current python version to execute"
    )
    os.environ["PY_VERSION"] = python_version
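As the reviewer notes, the float-based check above is fragile: for Python 3.10, `version_detail[0] + version_detail[1] / 10` evaluates to 4.0, and `env_version != version` compares a string against a float, so it is always true. A minimal tuple-based sketch (not part of the PR; it reuses the names from the code above):

```python
import os
import platform
import sys
import warnings

if sys.version_info < (3, 7):
    # Tuple comparison is exact and survives Python 3.10+ (no 3.10 -> 4.0 confusion).
    raise RuntimeError(
        f"Paddle only supports Python >= 3.7 now, "
        f"you are using Python {platform.python_version()}"
    )

env_version = os.getenv("PY_VERSION")
current = "{}.{}".format(*sys.version_info[:2])
if env_version is None:
    os.environ["PY_VERSION"] = platform.python_version()
elif not env_version.startswith(current):
    # Compare strings with strings instead of a string with a float.
    warnings.warn(
        f"You set PY_VERSION={env_version}, but the current interpreter "
        f"is {platform.python_version()}; the current version will be used."
    )
    os.environ["PY_VERSION"] = platform.python_version()
```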
paddle_include_path = paddle.sysconfig.get_include()
paddle_lib_path = paddle.sysconfig.get_lib()

[Review comment] Why depend on Paddle's paths here?
[Reply] In order to copy the installed libflash_attn_advanced.so into Paddle's lib path.
[Review comment] I am not sure about this approach; please have @sneaxiy take a look as well. My understanding: ...
[Review comment] Should FA be an external operator on the Paddle side, similar to the relationship between xformers and PyTorch, rather than a new .so that directly replaces the FA .so inside Paddle?

print("Paddle Include Path:", paddle_include_path)
print("Paddle Lib Path:", paddle_lib_path)

# preparing parameters for setup()
paddle_version = paddle.version.full_version
cuda_version = paddle.version.cuda_version


with open("../../README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
CUDA_HOME = find_cuda_home()
PACKAGE_NAME = "paddle_flash_attn"

def get_platform():
    """
    Returns the platform name as used in wheel filenames.
    """
    if sys.platform.startswith('linux'):
        return 'linux_x86_64'
    elif sys.platform == 'darwin':
        mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
        return f'macosx_{mac_version}_x86_64'
    elif sys.platform == 'win32':
        return 'win_amd64'
    else:
        raise ValueError(f'Unsupported platform: {sys.platform}')

def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])

    return raw_output, bare_metal_version

def _is_cuda_available():
    """
    Check whether CUDA is available.
    """
    try:
        assert len(paddle.static.cuda_places()) > 0
[Review comment] Whether a CUDA device is present should preferably not be determined via a Paddle API (a paddle-free alternative is sketched after this function).
[Reply] I have not found another way yet.
        return True
    except Exception as e:
        logging.warning(
            "You are using GPU version PaddlePaddle, but there is no GPU "
            "detected on your machine. Maybe CUDA devices is not set properly."
            f"\n Original Error is {e}"
        )
        return False

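For the reviewer's point about avoiding a Paddle API here, one paddle-free option is to query the CUDA runtime directly via ctypes. A sketch, assuming libcudart.so is discoverable on the loader path (not part of the PR):

```python
import ctypes
import logging


def _is_cuda_available_via_cudart():
    """Check for a usable CUDA device without importing paddle."""
    try:
        # Assumes the CUDA runtime is on LD_LIBRARY_PATH or the default loader path.
        cudart = ctypes.CDLL("libcudart.so")
    except OSError as e:
        logging.warning("Could not load the CUDA runtime: %s", e)
        return False
    count = ctypes.c_int(0)
    # cudaGetDeviceCount returns cudaSuccess (0) when the query succeeds.
    status = cudart.cudaGetDeviceCount(ctypes.byref(count))
    return status == 0 and count.value > 0
```

Another common option is shelling out to nvidia-smi via subprocess, at the cost of depending on the driver tooling being installed.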

check = _is_cuda_available()
cmdclass = {}

def get_package_version():
    with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f:
        version_match = re.search(
            r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE
        )
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    else:
        return str(public_version)

def get_data_files():
    data_files = []
    # source_lib_path = 'libflashattn.so'
    # data_files.append((".", [source_lib_path]))
    data_files.append((".", ['libflashattn_advanced.so']))
    return data_files

class CustomWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
    find an existing wheel (which is currently the case for all flash-attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuit the standard full build pipeline.
    """

    def run(self):
        self.run_command('build_ext')
        super().run()
        # Determine the version numbers that will be used to identify the correct wheel.
        # We're using the CUDA version used to build paddle, not the one currently installed.
        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
        platform_name = get_platform()
        flash_version = get_package_version()
        cxx11_abi = ""  # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper()

        # Determine the wheel URL based on CUDA version, paddle version, python version and OS.
        wheel_filename = f'{PACKAGE_NAME}-{flash_version}-cu{cuda_version}-paddle{paddle_version}-{python_version}-{python_version}-{platform_name}.whl'
[Review comment] The wheel name shouldn't need to include the Paddle version, right? flash-attention itself does not depend on the Paddle version; rather, Paddle depends on the flash-attention version.
[Reply] OK, this will be removed later. The current default name is paddle_flash_attn-2.0.8-cp37-none-any.whl.
        impl_tag, abi_tag, plat_tag = self.get_tag()
        original_wheel_name = (
            f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
        )

        # new_wheel_name = wheel_filename
        new_wheel_name = (
            f"{self.wheel_dist_name}-{python_version}-{abi_tag}-{plat_tag}"
        )
        shutil.move(
            f"{self.dist_dir}/{original_wheel_name}.whl",
            f"{self.dist_dir}/{new_wheel_name}.whl",
        )

class CustomInstallCommand(_install):
    def run(self):
        _install.run(self)
        install_path = self.install_lib
        source_lib_path = os.path.abspath('libflashattn_advanced.so')
        destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn_advanced.so')
        shutil.copy(f"{source_lib_path}", f"{destination_lib_path}")

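If flash-attention were shipped as an external operator instead, as raised in the review discussion above, the shared library could stay inside the wheel and be loaded from the package directory rather than copied into Paddle's lib path. A hypothetical sketch (the loader function and package layout are assumptions, not part of this PR):

```python
import ctypes
import os


def load_flashattn_advanced():
    # Hypothetical loader: resolve the .so relative to this package rather
    # than relying on a copy placed inside Paddle's lib directory.
    pkg_dir = os.path.dirname(os.path.abspath(__file__))
    lib_path = os.path.join(pkg_dir, "libflashattn_advanced.so")
    return ctypes.CDLL(lib_path)
```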

setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(),
    data_files=get_data_files(),
    package_data={PACKAGE_NAME: ['build/libflashattn.so']},
    author_email="[email protected]",
    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/PaddlePaddle/flash-attention",
    classifiers=[
        "Programming Language :: Python :: 37",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    cmdclass={
        'bdist_wheel': CustomWheelsCommand,
        'install': CustomInstallCommand,
    },
    python_requires=">=3.7",
    install_requires=[
        "common",
        "dual",
        "tight>=0.1.0",
        "data",
        "prox",
        "ninja",  # Put ninja before paddle if paddle depends on it
        "einops",
        "packaging",
    ],
)
[Review comment, on set(NVCC_ARCH_BIN 80 CACHE STRING ...) in CMakeLists.txt] With this syntax, if -DNVCC_ARCH_BIN=... is set externally, which value is picked up: 80 or the externally set one?
[Reply] The externally set value; we have verified experimentally that the external value is picked up. (A CACHE set() only supplies a default and does not overwrite a cache entry that already exists, so the -D value wins.)