Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fa cmake #29

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open

Fa cmake #29

wants to merge 36 commits into from

Conversation

AnnaTrainingG
Copy link

@AnnaTrainingG AnnaTrainingG commented Dec 6, 2023

PR描述:

  1. 将csrc/build编译产生的结果编译成whl 或者 .so

  2. 使用方法:

    1. 进入flash-attention/csrc 创建build文件夹
    2. 执行 cmake .. # 此处有2个选项可以设置
      1. -DNVCC_ARCH_BIN=80-90 表示编译80 和 90 架构,默认值编译 80
      2. -DWITH_ADVANCED=ON 表示是否使用paddle flash_attention的高级功能
        • 高级功能包含:
          1. attn_mask功能支持,
          2. fa反向确定算法支持,用于逐位对齐精度调试)
    3. make -j128
    4. 示例: cmake .. -DNVCC_ARCH_BIN=80-90,-DWITH_ADVANCED=ON
    5. 编译结束后在:
      1. 如果设置-DWITH_ADVANCED=ON :
        1. flash-attention/csrc/build 下会产生 libflashattn_advanced.so
        2. flash-attention/csrc/build/dist 下会生成 paddle_flash_attn-2.0.8-cp37-none-any.whl 【注意这个打包只打包了libflashattn_advanced.so】
        3. 2.0.8表示当前paddle flash_attention的版本, cp37表示python版本为3.7, any 表示任意平台均可
      2. 如果不设置则
        1. flash-attention/csrc/build 下会产生 libflashattn.so 不会对.so进行打包【用于paddle内部源码编译】

paddle_flash_attn-2.0.8-cp37-none-any.whl 使用方法:
直接pip install 即可

说明:安装后后在/usr/local/lib/python3.7/dist-packages/paddle/libs/下新增 libflashattn_advanced.so。
触发条件: 当使用additional_mask功能 | 设置确定算法环境变量 | 设置 FLAGS_flash_attention_with_advanced 环境变量时 会调用 libflashattn_advanced.so 中的实现,如果未安装libflashattn_advanced.so 则会直接报错。
image
image

本PR修改后.so大小:
image

Copy link
Collaborator

@Xreki Xreki left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR描述补充下吧,包括:
(1)本PR的工作
(2)非advance分支动态库的大小
(3)advance分支动态库的大小、打包的whl包名字、打包了哪些内容、打包的命令

PR title也完善下,准确描述该PR的工作内容

@@ -96,7 +101,7 @@ target_compile_options(flashattn PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
--expt-relaxed-constexpr
--expt-extended-lambda
--use_fast_math
"SHELL:-gencode arch=compute_80,code=sm_80"
"SHELL:-gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90"
>)

target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分功能也可以放到advance里面

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

来不及修改了 先合入基础的吧

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个只需修改cmake即可,建议这个PR直接定下来

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经修改

@@ -111,10 +116,23 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$<COMPILE_LANGUAGE:CUD
--expt-relaxed-constexpr
--expt-extended-lambda
--use_fast_math
"SHELL:-gencode arch=compute_80,code=sm_80"
"SHELL:-gencode arch=compute_80,code=sm_80 -gencode arch=compute_90,code=sm_90"
>)

INSTALL(TARGETS flashattn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

最终生成的动态库名称,在关闭、开启advance功能时,最好有所区分,这样Paddle框架中在加载动态库时容易区分些。

Copy link
Author

@AnnaTrainingG AnnaTrainingG Dec 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

关闭是就是libflashattn.so 开启的时候是libflashattn_advanced.so

#else
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个分支是不是应该保留is_equal_qk模板?我理解非advance分支,需要是causal最优的性能版本

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经修改

os.environ["PY_VERSION"] = python_version

paddle_include_path = paddle.sysconfig.get_include()
paddle_lib_path = paddle.sysconfig.get_lib()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为啥要依赖paddle的路径呢

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为了将安装的libflash_attn_advanced.so拷贝到paddle路径下

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种方式我不太确定,请@sneaxiy 也看下。我理解:

  1. FA动态图即使是安装在自己的目录下,应该也是能找到的
  2. FA打包成.whl后,对FA的依赖应该需要写到Paddle的requirements.txt里面,那安装FA的时候很可能是还没有安装Paddle的

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FA应该作为Paddle侧的一个外部算子?类似于xformers跟PyTorch的关系,而不是新的so直接去替换Paddle里的FA的so?

Check whether CUDA is available.
"""
try:
assert len(paddle.static.cuda_places()) > 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是否有cuda设备,最好不要用paddle接口来判断

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

暂未找到其他方法

cxx11_abi = "" # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper()

# Determine wheel URL based on CUDA version, paddle version, python version and OS
wheel_filename = f'{PACKAGE_NAME}-{flash_version}-cu{cuda_version}-paddle{paddle_version}-{python_version}-{python_version}-{platform_name}.whl'
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whl包里面不需要加paddle版本吧?本身flashattn对paddle版本并没有依赖,是paddle对flashattn版本存在依赖

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的这个后续会去掉,现在的是默认版本:paddle_flash_attn-2.0.8-cp37-none-any.whl

csrc/setup.py Outdated
# Determine the version numbers that will be used to determine the correct wheel
# We're using the CUDA version used to build paddle, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
paddle_cuda_version = "234" # parse(paddle.version.cuda)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个干啥用的呢?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个之前获取paddle cuda版本的时候报错了, 临时加的,后续会删除

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经修改

@Xreki Xreki requested a review from sneaxiy December 11, 2023 02:59
)
endif()

if (WITH_ADVANCED)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

L148、150可以去掉

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

os.environ["PY_VERSION"] = python_version

paddle_include_path = paddle.sysconfig.get_include()
paddle_lib_path = paddle.sysconfig.get_lib()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种方式我不太确定,请@sneaxiy 也看下。我理解:

  1. FA动态图即使是安装在自己的目录下,应该也是能找到的
  2. FA打包成.whl后,对FA的依赖应该需要写到Paddle的requirements.txt里面,那安装FA的时候很可能是还没有安装Paddle的

@Xreki
Copy link
Collaborator

Xreki commented Dec 11, 2023

触发条件: 当用户使用了additional_mask功能 | 设置确定算法环境变量 | 设置 FLAGS_flash_attention_with_advanced 环境变量时 会调用 libflashattn_advanced.so 中的实现,如果未安装libflashattn_advanced.so 则会直接报错。

不建议提供这么复杂的配置方式。建议基础版本和扩展版本提供的causal计算能力和性能保持一致;若Paddle中使用到了扩展版本中新增的功能,则Paddle应自动去调用扩展库,用户唯一的感知就是可能需要手动执行下pip install paddle-flash-attn

@AnnaTrainingG
Copy link
Author

FA动态图即使是安装在自己的目录下,paddle是找不到的,paddle只会在自己的libs下找, 可以看Paddle的修改https://github.com/PaddlePaddle/Paddle/pull/59802/files

@AnnaTrainingG
Copy link
Author

建议基础版本和扩展版本提供的causal计算能力和性能保持一致-》 已经是一致的了,描述存在问题,已修改

});
});
});
#endif
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的代码是否可以简化下?以免出现2个分支?比如对BOOL_SWITCH做个改进,WITH_ADVANCED前后走不同的定义,以免以后维护更加困难?类似于:

#ifdef PADDLE_WITH_ADVANCED
#define BOOL_SWITCH(...) ...
#else   
#define BOOL_SWITCH(...) ...
#endif

@@ -59,6 +60,29 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
});
});
});
#else
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上。最好对BOOL_SWITCH进行改进。在WITH_ADVANCED开启前后走不同的定义,避免重复代码。

version = version_detail[0] + version_detail[1] / 10
env_version = os.getenv("PY_VERSION")

if version < 3.7:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个判断比较粗糙。version是浮点数而不是整数,建议改成使用version_detail整数的判断。

os.environ["PY_VERSION"] = python_version

paddle_include_path = paddle.sysconfig.get_include()
paddle_lib_path = paddle.sysconfig.get_lib()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FA应该作为Paddle侧的一个外部算子?类似于xformers跟PyTorch的关系,而不是新的so直接去替换Paddle里的FA的so?

@@ -83,18 +97,14 @@ target_link_libraries(flashattn flashattn_with_bias_mask)

add_dependencies(flashattn flashattn_with_bias_mask)

set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个写法在外部设置了-DNVCC_ARCH_BIN=...的情况下,取值会是多少,是80还是外部设置的值?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

外部设置的值, 已经做过实验会拿到外部设置的

@CLAassistant
Copy link

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants