Fa cmake #29
base: main
Conversation
Please expand the PR description to cover:
(1) the work done in this PR;
(2) the size of the shared library on the non-advance branch;
(3) the size of the shared library on the advance branch, the name of the packaged .whl, what it packages, and the packaging command.
Please also improve the PR title so that it accurately describes the work done in this PR.
csrc/CMakeLists.txt
Outdated
@@ -96,7 +101,7 @@ target_compile_options(flashattn PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
    --expt-relaxed-constexpr
    --expt-extended-lambda
    --use_fast_math
-   "SHELL:-gencode arch=compute_80,code=sm_80"
+   "SHELL:-gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90"
>)

target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
This part of the functionality could also be moved into the advance build.
There isn't time to change this now; let's merge the basic version first.
This only requires a CMake change, so I suggest settling it directly in this PR.
Done, updated.
@@ -111,10 +116,23 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
    --expt-relaxed-constexpr
    --expt-extended-lambda
    --use_fast_math
-   "SHELL:-gencode arch=compute_80,code=sm_80"
+   "SHELL:-gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90"
>)

INSTALL(TARGETS flashattn
The name of the final shared library should ideally differ depending on whether the advance feature is enabled or disabled, so the Paddle framework can easily tell them apart when loading the library.
When disabled it is libflashattn.so; when enabled it is libflashattn_advanced.so.
#else
    BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
        BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
            BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
Shouldn't this branch keep the is_equal_qk template? My understanding is that the non-advance branch needs the best-performing causal version.
Done, updated.
os.environ["PY_VERSION"] = python_version

paddle_include_path = paddle.sysconfig.get_include()
paddle_lib_path = paddle.sysconfig.get_lib()
Why does this depend on paddle's paths?
So that the installed libflash_attn_advanced.so can be copied into paddle's directory.
I'm not sure about this approach; @sneaxiy please take a look as well. My understanding is:
- Even if the FA shared library is installed in its own directory, it should still be discoverable.
- Once FA is packaged as a .whl, Paddle's dependency on FA should be written into Paddle's requirements.txt, so Paddle may well not be installed yet at the time FA is installed.
Shouldn't FA be an external operator on the Paddle side, similar to the relationship between xformers and PyTorch, rather than having the new .so directly replace the FA .so inside Paddle?
Check whether CUDA is available.
"""
try:
    assert len(paddle.static.cuda_places()) > 0
It would be better not to use a paddle API to check whether a CUDA device is present.
I haven't found another way yet.
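One paddle-free possibility, as a sketch: query the CUDA runtime directly through ctypes. The runtime library names probed below and the fall-back-to-False behavior are illustrative assumptions, not part of this PR.

```python
import ctypes

def cuda_device_available() -> bool:
    """Return True if at least one CUDA device is visible, without importing paddle."""
    # Probe a few common CUDA runtime library names; if none loads, assume no CUDA.
    for name in ("libcudart.so", "libcudart.so.12", "libcudart.so.11.0"):
        try:
            cudart = ctypes.CDLL(name)
            break
        except OSError:
            continue
    else:
        return False
    count = ctypes.c_int(0)
    # cudaGetDeviceCount returns cudaSuccess (0) and writes the device count.
    status = cudart.cudaGetDeviceCount(ctypes.byref(count))
    return status == 0 and count.value > 0
```

On a machine without CUDA the loop exhausts the library names and the function returns False instead of raising, which matches what the `try`/`assert` above seems to want.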
cxx11_abi = ""  # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper()

# Determine wheel URL based on CUDA version, paddle version, python version and OS
wheel_filename = f'{PACKAGE_NAME}-{flash_version}-cu{cuda_version}-paddle{paddle_version}-{python_version}-{python_version}-{platform_name}.whl'
There's no need to include the paddle version in the .whl name, right? flashattn itself has no dependency on the paddle version; it is paddle that depends on the flashattn version.
OK, this will be removed later. The current default is: paddle_flash_attn-2.0.8-cp37-none-any.whl
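For reference, the tag layout of the default wheel name quoted above can be sketched as follows. This is a hedged reconstruction: the variable names and values come only from the wheel name in this thread, not from the actual setup.py.

```python
# Hypothetical reconstruction of the default wheel-name layout.
PACKAGE_NAME = "paddle_flash_attn"
flash_version = "2.0.8"   # paddle flash_attention version
python_tag = "cp37"       # CPython 3.7
abi_tag = "none"          # no ABI dependency
platform_tag = "any"      # installable on any platform

wheel_filename = f"{PACKAGE_NAME}-{flash_version}-{python_tag}-{abi_tag}-{platform_tag}.whl"
assert wheel_filename == "paddle_flash_attn-2.0.8-cp37-none-any.whl"
```

Dropping the paddle/CUDA components, as suggested, leaves exactly the standard `name-version-python-abi-platform` wheel tag layout.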
csrc/setup.py
Outdated
# Determine the version numbers that will be used to determine the correct wheel
# We're using the CUDA version used to build paddle, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
paddle_cuda_version = "234"  # parse(paddle.version.cuda)
What is this for?
Fetching the paddle CUDA version errored out earlier, so this was added as a temporary workaround; it will be removed later.
Done, updated.
Force-pushed 4a63f7c to 3aca223
csrc/CMakeLists.txt
Outdated
)
endif()

if (WITH_ADVANCED)
Lines 148 and 150 can be removed.
done
os.environ["PY_VERSION"] = python_version

paddle_include_path = paddle.sysconfig.get_include()
paddle_lib_path = paddle.sysconfig.get_lib()
I'm not sure about this approach; @sneaxiy please take a look as well. My understanding is:
- Even if the FA shared library is installed in its own directory, it should still be discoverable.
- Once FA is packaged as a .whl, Paddle's dependency on FA should be written into Paddle's requirements.txt, so Paddle may well not be installed yet at the time FA is installed.

I don't recommend offering such a complex configuration scheme. The basic and extended versions should provide the same causal functionality and performance; if Paddle uses a feature added in the extended version, Paddle should call the extended library automatically, and the only thing the user should notice is that they may need to run an extra install step manually.

Re "even if the FA shared library is installed in its own directory, it should still be discoverable": paddle cannot find it there; paddle only searches its own libs directory. See the corresponding Paddle change: https://github.com/PaddlePaddle/Paddle/pull/59802/files

Re "the basic and extended versions should provide the same causal functionality and performance": they are already the same; the description was inaccurate and has been fixed.
            });
        });
    });
#endif
Can the code here be simplified to avoid having two branches? For example, improve BOOL_SWITCH so that it has different definitions depending on WITH_ADVANCED, so maintenance doesn't become harder later. Something like:
#ifdef PADDLE_WITH_ADVANCED
#define BOOL_SWITCH(...) ...
#else
#define BOOL_SWITCH(...) ...
#endif
@@ -59,6 +60,29 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
        });
    });
});
#else
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
Same as above. It would be best to improve BOOL_SWITCH so that it has different definitions depending on whether WITH_ADVANCED is enabled, to avoid duplicated code.
version = version_detail[0] + version_detail[1] / 10
env_version = os.getenv("PY_VERSION")

if version < 3.7:
This check is rather crude. version is a float, not an integer; it would be better to compare the integers in version_detail directly.
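A minimal sketch of the suggested integer comparison, reusing the names from the snippet above (`version_detail` is assumed to hold `(major, minor)` integers, as `sys.version_info` does):

```python
import sys

# Compare (major, minor) as an integer tuple instead of the float
# major + minor / 10, which collapses Python 3.10 into 4.0 (3 + 10/10 == 4.0).
version_detail = (sys.version_info.major, sys.version_info.minor)

if version_detail < (3, 7):
    raise RuntimeError(f"Python >= 3.7 is required, got {version_detail}")

# Tuple comparison orders versions correctly:
assert (3, 10) < (4, 0)
# The float scheme cannot distinguish 3.10 from 4.0:
assert 3 + 10 / 10 == 4.0
```

Tuple comparison is lexicographic over the integer fields, so it stays correct for any future minor version.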
os.environ["PY_VERSION"] = python_version

paddle_include_path = paddle.sysconfig.get_include()
paddle_lib_path = paddle.sysconfig.get_lib()
Shouldn't FA be an external operator on the Paddle side, similar to the relationship between xformers and PyTorch, rather than having the new .so directly replace the FA .so inside Paddle?
@@ -83,18 +97,14 @@ target_link_libraries(flashattn flashattn_with_bias_mask)

add_dependencies(flashattn flashattn_with_bias_mask)

set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures")
With this form, if -DNVCC_ARCH_BIN=... is set externally, which value wins: 80 or the externally set one?
The externally set value; I've verified experimentally that the external value is picked up.
PR description:
Package the build artifacts from csrc/build into a .whl or a .so.

Features added in the advanced build:
1. attn_mask support.
2. A deterministic FA backward algorithm, for bit-wise precision alignment and debugging.

With WITH_ADVANCED enabled:
1. flash-attention/csrc/build produces libflashattn_advanced.so.
2. flash-attention/csrc/build/dist produces paddle_flash_attn-2.0.8-cp37-none-any.whl (note: this wheel packages only libflashattn_advanced.so).
3. In the wheel name, 2.0.8 is the current paddle flash_attention version, cp37 means Python 3.7, and any means any platform.

With WITH_ADVANCED disabled:
1. flash-attention/csrc/build produces libflashattn.so; the .so is not packaged into a wheel (this mode is used when building paddle itself from source).

How to use paddle_flash_attn-2.0.8-cp37-none-any.whl:
Just pip install it.

Note: after installation, libflashattn_advanced.so is added under /usr/local/lib/python3.7/dist-packages/paddle/libs/.
Trigger conditions: the implementation in libflashattn_advanced.so is called when the additional_mask feature is used, the deterministic-algorithm environment variable is set, or the FLAGS_flash_attention_with_advanced environment variable is set; if libflashattn_advanced.so is not installed, an error is raised.

Size of the .so after this PR: