From b5b20b814f60f6435b73b75653c1a8a7244acefa Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 02:50:56 +0000 Subject: [PATCH 01/36] update --- csrc/CMakeLists.txt | 93 ++++++++++++++++++++++++++------------------- csrc/setup.py | 15 ++++++++ csrc/tp.py | 75 ++++++++++++++++++++++++++++++++++++ csrc/yest | 3 ++ yes.py | 12 ++++++ 5 files changed, 159 insertions(+), 39 deletions(-) create mode 100644 csrc/setup.py create mode 100644 csrc/tp.py create mode 100644 csrc/yest create mode 100644 yes.py diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 5b305fa98..04597cf3e 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -13,38 +13,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -57,13 +57,13 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu - flash_attn_with_bias_and_mask/src/cuda_utils.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/cuda_utils.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC @@ -131,7 +131,22 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$) +# INSTALL(TARGETS flashattn LIBRARY DESTINATION "lib") INSTALL(FILES capi/flash_attn.h DESTINATION "include") + +add_custom_target(run_my_executable + COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + DEPENDS flashattn + COMMENT "Running my_executable" +) + +# 创建一个伪目标作为默认构建目标 +add_custom_target(default_target DEPENDS run_my_executable) + +# 设置 'default_target' 为默认构建目标 +set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target) + diff --git a/csrc/setup.py b/csrc/setup.py new file mode 100644 index 000000000..92a9a3f16 --- /dev/null +++ b/csrc/setup.py @@ -0,0 +1,15 @@ +from setuptools import setup, find_packages +from setuptools import setup, find_namespace_packages + +setup( + packages=find_packages(where="src"), + package_dir={"": "src"}, + package_data={"": ["*.so"]}, + exclude_package_data={"flash_attn_with_bias_and_mask": ["*"]}, + include_package_data=True, + #packages=find_namespace_packages(where="src"), + #package_dir={"": "src"}, + #package_data={ + # "": ["*.so"], + #} +) diff --git a/csrc/tp.py b/csrc/tp.py new file mode 100644 index 000000000..5b5d9d95d --- /dev/null +++ b/csrc/tp.py @@ -0,0 +1,75 @@ +import paddle +from setuptools import setup, find_packages +import sys +import os + +python_version = sys.version +print("Installing your_package...") + +# Get the CUDA version from PaddlePaddle +cuda_version = paddle.version.cuda() +fa_version = f"1.0.0.post{cuda_version}" +package_name = 'flash_attention_paddle_gpu' + +def get_data_files(): + data_files = [] + + # Assuming 'libflashattn.so' is located in the same directory as setup.py + source_lib_path = 'libflashattn.so' + + # Specify the destination directory within the package + destination_lib_path = os.path.join(package_name, 'libflashattn.so') + + data_files.append((os.path.join(package_name, 'libflashattn.so'), [source_lib_path])) + print(destination_lib_path, "asdf ****************") + print(data_files) + return data_files + +setup( + name=package_name, + version=fa_version, + data_files=get_data_files(), + description='Flash attention in paddlepaddle', + packages=find_packages(), + package_data={package_name: ['src/libflashattn.so']}, +) +# +#import paddle +#import os +#from setuptools import setup +#import sys +# +#python_version = sys.version +#print("Installing your_package...") +# +## Get the CUDA version from PaddlePaddle +#cuda_version = paddle.version.cuda() +#fa_version = f"1.0.0.post{cuda_version}" +#package_name = 'flash_attention_paddle_gpu' # Adjusted package name +# +#def get_data_files(): +# data_files = [] +# +# # Assuming 'libflashattn.so' is located in the same directory as setup.py +# source_lib_path = os.path.abspath('libflashattn.so') +# +# # Specify the destination directory within the package +# destination_lib_path = os.path.join(package_name, 'libflashattn.so') +# +# data_files.append((os.path.join(package_name, 'libflashattn.so'), [source_lib_path])) +# print(destination_lib_path, "asdf ****************") +# print(data_files) +# return data_files +# +## Create an empty __init__.py file in the package directory +#init_file_path = os.path.join(package_name, '__init__.py') +#with open(init_file_path, 'w') as f: +# pass +# +#setup( +# name=package_name, +# version=fa_version, +# description='Flash attention in paddlepaddle', +# packages=[package_name], +# package_data={package_name: ['libflashattn.so']}, +#) diff --git a/csrc/yest b/csrc/yest new file mode 100644 index 000000000..b3d4d3cd0 --- /dev/null +++ b/csrc/yest @@ -0,0 +1,3 @@ +include build/libflashattn.so +include src/libflashattn.so +include ./libflashattn.so diff --git a/yes.py b/yes.py new file mode 100644 index 000000000..29917c43e --- /dev/null +++ b/yes.py @@ -0,0 +1,12 @@ +from setuptools import setup + +package_name = '' #flash-attention-paddle-gpu' +setup( + name=package_name, + version='1.0.0', + description='Flash attention in PaddlePaddle', + packages=[package_name], + include_package_data=True, + package_data={package_name: ['csrc/build/libflashattn.so']}, +) + From 199b9d68ad1d89e724987bbd9de16a92a7c6cf38 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 03:22:28 +0000 Subject: [PATCH 02/36] has data --- csrc/CMakeLists.txt | 2 +- csrc/tp.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 04597cf3e..639d8b1eb 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -138,7 +138,7 @@ INSTALL(TARGETS flashattn INSTALL(FILES capi/flash_attn.h DESTINATION "include") add_custom_target(run_my_executable - COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist + COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/tp.py sdist bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS flashattn COMMENT "Running my_executable" diff --git a/csrc/tp.py b/csrc/tp.py index 5b5d9d95d..56aebee07 100644 --- a/csrc/tp.py +++ b/csrc/tp.py @@ -2,7 +2,9 @@ from setuptools import setup, find_packages import sys import os - +import paddle +paddle_path = paddle.sysconfig.get_lib +print(paddle_path) python_version = sys.version print("Installing your_package...") @@ -31,7 +33,7 @@ def get_data_files(): data_files=get_data_files(), description='Flash attention in paddlepaddle', packages=find_packages(), - package_data={package_name: ['src/libflashattn.so']}, + package_data={package_name: ['build/libflashattn.so']}, ) # #import paddle From a582b3a9e571bea2f65284c0626756ad7ff589a3 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 04:47:09 +0000 Subject: [PATCH 03/36] update --- csrc/README.md | 234 ++++++++++++++++++++++++++++++++++ csrc/mp.py | 336 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 570 insertions(+) create mode 100644 csrc/README.md create mode 100644 csrc/mp.py diff --git a/csrc/README.md b/csrc/README.md new file mode 100644 index 000000000..79d334530 --- /dev/null +++ b/csrc/README.md @@ -0,0 +1,234 @@ +# FlashAttention +This repository provides the official implementation of FlashAttention and +FlashAttention-2 from the +following papers. + +**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness** +Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré +Paper: https://arxiv.org/abs/2205.14135 +IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention. +![FlashAttention](assets/flashattn_banner.jpg) + +**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning** +Tri Dao + +Paper: https://tridao.me/publications/flash2/flash2.pdf + +![FlashAttention-2](assets/flashattention_logo.png) + + +## Usage + +We've been very happy to see FlashAttention being widely adopted in such a short +time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md) +contains a partial list of places where FlashAttention is being used. + +FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). +Please cite and credit FlashAttention if you use it. + +## Installation and features + +Requirements: +- CUDA 11.4 and above. +- PyTorch 1.12 and above. + +We recommend the +[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) +container from Nvidia, which has all the required tools to install FlashAttention. + +To install: +1. Make sure that PyTorch is installed. +2. Make sure that `packaging` is installed (`pip install packaging`) +3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja +--version` then `echo $?` should return exit code 0). If not (sometimes `ninja +--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall +`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`, +compiling can take a very long time (2h) since it does not use multiple CPU +cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine. +4. Then: +```sh +pip install flash-attn --no-build-isolation +``` +Alternatively you can compile from source: +```sh +python setup.py install +``` + +If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might +run too many parallel compilation jobs that could exhaust the amount of RAM. To +limit the number of parallel compilation jobs, you can set the environment +variable `MAX_JOBS`: +```sh +MAX_JOBS=4 pip install flash-attn --no-build-isolation +``` + +Interface: `src/flash_attention_interface.py` + +FlashAttention-2 currently supports: +1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing + GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing + GPUs for now. +2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs). +3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800. + + +## How to use FlashAttention + +The main functions implement scaled dot product attention (softmax(Q @ K^T * +softmax_scale) @ V): +```python +from flash_attn import flash_attn_qkvpacked_func, flash_attn_func +``` + +```python +flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False): +"""dropout_p should be set to 0.0 during evaluation +If Q, K, V are already stacked into 1 tensor, this function will be faster than +calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation +of the gradients of Q, K, V. +Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). +Return: + out: (batch_size, seqlen, nheads, headdim). +""" +``` + +```python +flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False): +"""dropout_p should be set to 0.0 during evaluation +Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads +than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. +For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head +0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + +Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). +Return: + out: (batch_size, seqlen, nheads, headdim). +""" +``` + +To see how these functions are used in a multi-head attention layer (which +includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py). + +## Upgrading from FlashAttention (1.x) to FlashAttention-2 + +These functions have been renamed: +- `flash_attn_unpadded_func` -> `flash_attn_varlen_func` +- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func` +- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func` + +If the inputs have the same sequence lengths in the same batch, it is simpler +and faster to use these functions: +```python +flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False) +``` +```python +flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False) +``` + +## Performance + +We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory). + +We currently have benchmarks for these GPUs: +* [A100](#a100) +* [H100](#h100) + + + +### A100 + +We display FlashAttention speedup using these parameters: +* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads). +* Sequence length 512, 1k, 2k, 4k, 8k, 16k. +* Batch size set to 16k / seqlen. + +#### Speedup + +![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png) + +#### Memory + +![FlashAttention memory](assets/flashattn_memory.jpg) + +We show memory savings in this graph (note that memory footprint is the same no matter if you use dropout or masking). +Memory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length. +We see 10X memory savings at sequence length 2K, and 20X at 4K. +As a result, FlashAttention can scale to much longer sequence lengths. + +### H100 + +![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png) + +## Full model code and training script + +We have released the full GPT model +[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py). +We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, +cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x +compared to the baseline implementation from Huggingface, reaching up to 225 +TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need +any activation checkpointing). + +We also include a training +[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to +train GPT2 on Openwebtext and GPT3 on The Pile. + +## Triton implementation of FlashAttention + +Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton: +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +As Triton is a higher-level language than CUDA, it might be easier to understand +and experiment with. The notations in the Triton implementation are also closer +to what's used in our paper. + +We also have an experimental implementation in Triton that support attention +bias (e.g. ALiBi): +https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py + + +## Tests +We test that FlashAttention produces the same output and gradient as a reference +implementation, up to some numerical tolerance. In particular, we check that the +maximum numerical error of FlashAttention is at most twice the numerical error +of a baseline implementation in Pytorch (for different head dimensions, input +dtype, sequence length, causal / non-causal). + +To run the tests: +```sh +pytest -q -s tests/test_flash_attn.py +``` +## When you encounter issues + +This new release of FlashAttention-2 has been tested on several GPT-style +models, mostly on A100 GPUs. + +If you encounter bugs, please open a GitHub Issue! + +## Citation +If you use this codebase, or otherwise found our work valuable, please cite: +``` +@inproceedings{dao2022flashattention, + title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, + author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + booktitle={Advances in Neural Information Processing Systems}, + year={2022} +} +@article{dao2023flashattention2, + title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning, + author={Dao, Tri}, + year={2023} +} +``` diff --git a/csrc/mp.py b/csrc/mp.py new file mode 100644 index 000000000..c379c9c4e --- /dev/null +++ b/csrc/mp.py @@ -0,0 +1,336 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import sys +import warnings +import os +import re +import ast +from pathlib import Path +from packaging.version import parse, Version +import platform + +from setuptools import setup, find_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import paddle +from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension +from paddle.utils.cpp_extension.extension_utils import find_cuda_home + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) +CUDA_HOME = find_cuda_home() +PACKAGE_NAME = "flash_attn" + +BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" +# For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM +FORCE_SINGLE_THREAD = os.getenv("FLASH_ATTENTION_FORCE_SINGLE_THREAD", "FALSE") == "TRUE" + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith('linux'): + return 'linux_x86_64' + elif sys.platform == 'darwin': + mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2]) + return f'macosx_{mac_version}_x86_64' + elif sys.platform == 'win32': + return 'win_amd64' + else: + raise ValueError('Unsupported platform: {}'.format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_cuda_paddle_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) + paddle_binary_version = parse(paddle.version.cuda()) + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_version != paddle_binary_version): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pypaddle binaries. " + "Pypaddle binaries were compiled with Cuda {}.\n".format(paddle.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pypaddle/pypaddle, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + if not FORCE_SINGLE_THREAD: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + +def _is_cuda_available(): + """ + Check whether CUDA is available. + """ + try: + assert len(paddle.static.cuda_places()) > 0 + return True + except Exception as e: + logging.warning( + "You are using GPU version PaddlePaddle, but there is no GPU " + "detected on your machine. Maybe CUDA devices is not set properly." + f"\n Original Error is {e}" + ) + return False + +if paddle.is_compiled_with_cuda() and _is_cuda_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pypaddle/pypaddle/pull/23408 attempt to query paddle.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). + print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, FlashAttention will cross-compile for Ampere (compute capability 8.0, 8.6, " + "8.9), and, if the CUDA version is >= 11.8, Hopper (compute capability 9.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.8"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;9.0" + elif bare_metal_version >= Version("11.4"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6" + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6" + +cmdclass = {} +ext_modules = [] + +# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp +# files included in the source distribution, in case the user compiles from source. +subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) + +if not SKIP_CUDA_BUILD: + print("\n\npaddle.__version__ = {}\n\n".format(paddle.__version__)) + TORCH_MAJOR = int(paddle.__version__.split(".")[0]) + TORCH_MINOR = int(paddle.__version__.split(".")[1]) + + # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h + # See https://github.com/pypaddle/pypaddle/pull/70650 + generator_flag = [] + paddle_dir = paddle.__path__[0] + if os.path.exists(os.path.join(paddle_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + + raise_if_cuda_home_none("flash_attn") + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.4"): + raise RuntimeError("FlashAttention is only supported on CUDA 11.4 and above") + # cc_flag.append("-gencode") + # cc_flag.append("arch=compute_75,code=sm_75") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,ggcode=sm_80") + if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # paddle._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pypaddle/pypaddle/blob/8472c24e3b5b60150096486616d98b7bea01500b/paddle/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + paddle._C._GLIBCXX_USE_CXX11_ABI = True + ext_modules.append( + CUDAExtension( + sources=[ + "csrc/flash_attn/flash_api.cpp", + "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + # "--ptxas-options=-O2", + "-lineinfo" + ] + + generator_flag + + cc_flag + ), + }, + include_dirs=[ + Path(this_dir) / 'csrc' / 'flash_attn', + Path(this_dir) / 'csrc' / 'flash_attn' / 'src', + Path(this_dir) / 'csrc' / 'cutlass' / 'include', + ], + ) + ) + + +def get_package_version(): + with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f: + version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") + if local_version: + return f"{public_version}+{local_version}" + else: + return str(public_version) + + +class CachedWheelsCommand(_bdist_wheel): + """ + The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot + find an existing wheel (which is currently the case for all flash attention installs). We use + the environment parameters to detect whether there is already a pre-built version of a compatible + wheel available and short-circuits the standard full build pipeline. + """ + def run(self): + if FORCE_BUILD: + return super().run() + + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build paddle, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + paddle_cuda_version = parse(paddle.version.cuda) + paddle_version_raw = parse(paddle.__version__) + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + flash_version = get_package_version() + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{paddle_cuda_version.major}{paddle_cuda_version.minor}" + paddle_version = f"{paddle_version_raw.major}.{paddle_version_raw.minor}" + cxx11_abi = str(paddle._C._GLIBCXX_USE_CXX11_ABI).upper() + + # Determine wheel URL based on CUDA version, paddle version, python version and OS + wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}paddle{paddle_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl' + wheel_url = BASE_WHEEL_URL.format( + tag_name=f"v{flash_version}", + wheel_name=wheel_filename + ) + print("Guessing wheel URL: ", wheel_url) + + try: + urllib.request.urlretrieve(wheel_url, wheel_filename) + + # Make the archive + # Lifted from the root wheel processing command + # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + print("Raw wheel path", wheel_path) + os.rename(wheel_filename, wheel_path) + except urllib.error.HTTPError: + print("Precompiled wheel not found. Building from source...") + # If the wheel could not be downloaded, build from source + super().run() + + +setup( + name=PACKAGE_NAME, + version=get_package_version(), + packages=find_packages( + exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) + ), + author="Tri Dao", + author_email="trid@cs.stanford.edu", + description="Flash Attention: Fast and Memory-Efficient Exact Attention", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/Dao-AILab/flash-attention", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + ], + #ext_modules=ext_modules, + cmdclass={ + 'bdist_wheel': CachedWheelsCommand, + "build_ext": BuildExtension + } if ext_modules else { + 'bdist_wheel': CachedWheelsCommand, + }, + python_requires=">=3.7", + install_requires=[ + "paddle", + "einops", + "packaging", + "ninja", + ], +) From 78080ddb5dfdec3a0154698ea81ccc3a45826c39 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 13:02:53 +0800 Subject: [PATCH 04/36] update --- csrc/mp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/mp.py b/csrc/mp.py index c379c9c4e..ab8435e8e 100644 --- a/csrc/mp.py +++ b/csrc/mp.py @@ -239,7 +239,7 @@ def _is_cuda_available(): def get_package_version(): - with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f: + with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f: version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) public_version = ast.literal_eval(version_match.group(1)) local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") From 0d6766ef6c134f8a550db884409238c949c32ef8 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 14:19:48 +0800 Subject: [PATCH 05/36] update --- csrc/setup.py | 255 +++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 244 insertions(+), 11 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 92a9a3f16..33bc82050 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -1,15 +1,248 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import sys +import warnings +import os +import re +import ast +from pathlib import Path +from packaging.version import parse, Version +import platform + from setuptools import setup, find_packages -from setuptools import setup, find_namespace_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import paddle +from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension +from paddle.utils.cpp_extension.extension_utils import find_cuda_home + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) +CUDA_HOME = find_cuda_home() +PACKAGE_NAME = "flash_attn" + +BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" +# For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM +FORCE_SINGLE_THREAD = os.getenv("FLASH_ATTENTION_FORCE_SINGLE_THREAD", "FALSE") == "TRUE" + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith('linux'): + return 'linux_x86_64' + elif sys.platform == 'darwin': + mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2]) + return f'macosx_{mac_version}_x86_64' + elif sys.platform == 'win32': + return 'win_amd64' + else: + raise ValueError('Unsupported platform: {}'.format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_cuda_paddle_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) + paddle_binary_version = parse(paddle.version.cuda()) + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_version != paddle_binary_version): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pypaddle binaries. " + "Pypaddle binaries were compiled with Cuda {}.\n".format(paddle.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pypaddle/pypaddle, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + if not FORCE_SINGLE_THREAD: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + +def _is_cuda_available(): + """ + Check whether CUDA is available. + """ + try: + assert len(paddle.static.cuda_places()) > 0 + return True + except Exception as e: + logging.warning( + "You are using GPU version PaddlePaddle, but there is no GPU " + "detected on your machine. Maybe CUDA devices is not set properly." + f"\n Original Error is {e}" + ) + return False + +if paddle.is_compiled_with_cuda() and _is_cuda_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pypaddle/pypaddle/pull/23408 attempt to query paddle.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). + print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, FlashAttention will cross-compile for Ampere (compute capability 8.0, 8.6, " + "8.9), and, if the CUDA version is >= 11.8, Hopper (compute capability 9.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export PADDLE_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("PADDLE_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.8"): + os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6;9.0" + elif bare_metal_version >= Version("11.4"): + os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6" + else: + os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6" + +cmdclass = {} +ext_modules = [] + +def get_package_version(): + with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f: + version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") + if local_version: + return f"{public_version}+{local_version}" + else: + return str(public_version) + +def get_data_files(): + data_files = [] + + # Assuming 'libflashattn.so' is located in the same directory as setup.py + source_lib_path = 'libflashattn.so' + + # Specify the destination directory within the package + destination_lib_path = os.path.join(PACKAGE_NAME, 'libflashattn.so') + + data_files.append((os.path.join(PACKAGE_NAME, 'libflashattn.so'), [source_lib_path])) + return data_files + + +class CachedWheelsCommand(_bdist_wheel): + """ + The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot + find an existing wheel (which is currently the case for all flash attention installs). We use + the environment parameters to detect whether there is already a pre-built version of a compatible + wheel available and short-circuits the standard full build pipeline. + """ + def run(self): + if FORCE_BUILD: + return super().run() + + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build paddle, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + paddle_cuda_version = parse(paddle.version.cuda) + paddle_version_raw = parse(paddle.__version__) + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + flash_version = get_package_version() + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{paddle_cuda_version.major}{paddle_cuda_version.minor}" + paddle_version = f"{paddle_version_raw.major}.{paddle_version_raw.minor}" + cxx11_abi = str(paddle._C._GLIBCXX_USE_CXX11_ABI).upper() + + # Determine wheel URL based on CUDA version, paddle version, python version and OS + wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}paddle{paddle_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl' + wheel_url = BASE_WHEEL_URL.format( + tag_name=f"v{flash_version}", + wheel_name=wheel_filename + ) + print("Guessing wheel URL: ", wheel_url) + + try: + urllib.request.urlretrieve(wheel_url, wheel_filename) + + # Make the archive + # Lifted from the root wheel processing command + # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + print("Raw wheel path", wheel_path) + os.rename(wheel_filename, wheel_path) + except urllib.error.HTTPError: + print("Precompiled wheel not found. Building from source...") + # If the wheel could not be downloaded, build from source + super().run() + setup( - packages=find_packages(where="src"), - package_dir={"": "src"}, - package_data={"": ["*.so"]}, - exclude_package_data={"flash_attn_with_bias_and_mask": ["*"]}, - include_package_data=True, - #packages=find_namespace_packages(where="src"), - #package_dir={"": "src"}, - #package_data={ - # "": ["*.so"], - #} + name=PACKAGE_NAME, + version=get_package_version(), + packages=find_packages( + #exclude=("build") + #, "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) + ), + data_files=get_data_files(), + package_data={PACKAGE_NAME: ['build/libflashattn.so']}, + author_email="Paddle-better@baidu.com", + description="Flash Attention: Fast and Memory-Efficient Exact Attention", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/PaddlePaddle/flash-attention", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + ], + ext_modules=ext_modules, + cmdclass={ + 'bdist_wheel': CachedWheelsCommand, + "build_ext": BuildExtension + } if ext_modules else { + 'bdist_wheel': CachedWheelsCommand, + }, + python_requires=">=3.7", + install_requires=[ + "paddle", + "einops", + "packaging", + "ninja", + ], ) From 17b89c9047017af7cb127bf5c1eec349e9045b5a Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 16:18:14 +0800 Subject: [PATCH 06/36] updat --- csrc/CMakeLists.txt | 2 +- csrc/setup.py | 287 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 281 insertions(+), 8 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 639d8b1eb..3b84b5388 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -138,7 +138,7 @@ INSTALL(TARGETS flashattn INSTALL(FILES capi/flash_attn.h DESTINATION "include") add_custom_target(run_my_executable - COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/tp.py sdist bdist_wheel + COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py sdist bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS flashattn COMMENT "Running my_executable" diff --git a/csrc/setup.py b/csrc/setup.py index 33bc82050..47d6a9555 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -8,8 +8,135 @@ from packaging.version import parse, Version import platform +from setuptools import Command, Extension, setup +from setuptools.command.develop import develop as DevelopCommandBase +from setuptools.command.egg_info import egg_info +from setuptools.command.install import install as InstallCommandBase +from setuptools.command.install_lib import install_lib +from setuptools.dist import Distribution from setuptools import setup, find_packages import subprocess +python_version = platform.python_version() +version_detail = sys.version_info +version = version_detail[0] + version_detail[1] / 10 +env_version = os.getenv("PY_VERSION") + +if version < 3.7: + raise RuntimeError( + f"Paddle only supports Python version >= 3.7 now," + f"you are using Python {python_version}" + ) +elif env_version is None: + print(f"export PY_VERSION = { python_version }") + os.environ["PY_VERSION"] = python_version + +elif env_version != version: + warnings.warn( + f"You set PY_VERSION={env_version}, but" + f"your current python environment is {version}" + f"we will use your current python version to execute" + ) + os.environ["PY_VERSION"] = python_version + + +global env_dict # noqa: F811 +env_dict={ + 'PADDLE_SOURCE_DIR':'@PADDLE_SOURCE_DIR@', + 'PADDLE_VERSION':'@PADDLE_VERSION@', + 'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@', + 'TAG_VERSION_REGEX':'@TAG_VERSION_REGEX@', + 'WITH_GPU':'@WITH_GPU@', + 'CUDNN_MAJOR_VERSION':'@CUDNN_MAJOR_VERSION@', + 'CUDNN_MINOR_VERSION':'@CUDNN_MINOR_VERSION@', + 'CUDNN_PATCHLEVEL_VERSION':'@CUDNN_PATCHLEVEL_VERSION@', + 'CUDA_VERSION':'@CUDA_VERSION@', + 'WITH_PSLI':'@WITH_PSLI@', + 'FLUID_CORE_NAME':'@FLUID_CORE_NAME@', + 'PHI_LIB':'@PHI_LIB@', + 'PHI_NAME':'@PHI_NAME@', + 'WITH_SHARED_PHI':'@WITH_SHARED_PHI@', + 'IR_LIB':'@IR_LIB@', + 'IR_NAME':'@IR_NAME@', + 'WITH_SHARED_IR':'@WITH_SHARED_IR@', + 'COMMON_LIB':'@COMMON_LIB@', + 'COMMON_NAME':'@COMMON_NAME@', + 'WARPCTC_LIBRARIES':'@WARPCTC_LIBRARIES@', + 'WARPRNNT_LIBRARIES':'@WARPRNNT_LIBRARIES@', + 'FLASHATTN_LIBRARIES':'@FLASHATTN_LIBRARIES@', + 'LAPACK_LIB':'@LAPACK_LIB@', + 'GFORTRAN_LIB':'@GFORTRAN_LIB@', + 'GNU_RT_LIB_1':'@GNU_RT_LIB_1@', + 'WITH_CUDNN_DSO':'@WITH_CUDNN_DSO@', + 'CUDNN_LIBRARY':'@CUDNN_LIBRARY@', + 'GNU_RT_LIB_2':'@GNU_RT_LIB_2@', + 'WITH_MKL':'@WITH_MKL@', + 'MKLML_SHARED_LIB':'@MKLML_SHARED_LIB@', + 'MKLML_SHARED_IOMP_LIB':'@MKLML_SHARED_IOMP_LIB@', + 'OPENBLAS_SHARED_LIB':'@OPENBLAS_SHARED_LIB@', + 'OPENBLAS_LIB':'@OPENBLAS_LIB@', + 'BLAS_LIB':'@BLAS_LIB@', + 'WITH_LITE':'@WITH_LITE@', + 'LITE_SHARED_LIB':'@LITE_SHARED_LIB@', + 'LITE_WITH_NNADAPTER':'@LITE_WITH_NNADAPTER@', + 'LITE_NNADAPTER_LIB':'@LITE_NNADAPTER_LIB@', + 'NNADAPTER_WITH_HUAWEI_ASCEND_NPU':'@NNADAPTER_WITH_HUAWEI_ASCEND_NPU@', + 'LITE_NNADAPTER_NPU_LIB':'@LITE_NNADAPTER_NPU_LIB@', + 'WITH_CINN':'@WITH_CINN@', + 'CINN_LIB_LOCATION':'@CINN_LIB_LOCATION@', + 'CINN_LIB_NAME':'@CINN_LIB_NAME@', + 'CINN_INCLUDE_DIR':'@CINN_INCLUDE_DIR@', + 'CMAKE_BUILD_TYPE':'@CMAKE_BUILD_TYPE@', + 'PSLIB_LIB':'@PSLIB_LIB@', + 'JVM_LIB':'@JVM_LIB@', + 'PSLIB_VERSION_PY':'@PSLIB_VERSION_PY@', + 'WITH_MKLDNN':'@WITH_MKLDNN@', + 'MKLDNN_SHARED_LIB':'@MKLDNN_SHARED_LIB@', + 'MKLDNN_INSTALL_DIR':'@MKLDNN_INSTALL_DIR@', + 'WITH_ONNXRUNTIME':'@WITH_ONNXRUNTIME@', + 'ONNXRUNTIME_SHARED_LIB':'@ONNXRUNTIME_SHARED_LIB@', + 'PADDLE2ONNX_LIB':'@PADDLE2ONNX_LIB@', + 'PADDLE2ONNX_LIB_NAME':'@PADDLE2ONNX_LIB_NAME@', + 'ONNXRUNTIME_LIB_NAME':'@ONNXRUNTIME_LIB_NAME@', + 'WITH_XPU':'@WITH_XPU@', + 'XPU_API_LIB':'@XPU_API_LIB@', + 'XPU_API_LIB_NAME':'@XPU_API_LIB_NAME@', + 'XPU_RT_LIB':'@XPU_RT_LIB@', + 'XPU_RT_LIB_NAME':'@XPU_RT_LIB_NAME@', + 'WITH_XPU_BKCL':'@WITH_XPU_BKCL@', + 'XPU_BKCL_LIB':'@XPU_BKCL_LIB@', + 'XPU_BKCL_LIB_NAME':'@XPU_BKCL_LIB_NAME@', + 'WITH_XPU_XFT':'@WITH_XPU_XFT@', + 'XPU_XFT_LIB':'@XPU_XFT_LIB@', + 'XPU_XFT_LIB_NAME':'@XPU_XFT_LIB_NAME@', + 'THIRD_PARTY_PATH':'@THIRD_PARTY_PATH@', + 'SETUP_LOG_FILE':'@SETUP_LOG_FILE@', + 'WITH_STRIP':'@WITH_STRIP@', + 'PACKAGE_NAME':'@PACKAGE_NAME@', + 'PADDLE_VERSION':'@PADDLE_VERSION@', + 'APPLE':'@APPLE@', + 'externalError_INCLUDE_DIR':'@externalError_INCLUDE_DIR@', + 'WITH_ROCM':'@WITH_ROCM@', + 'ORIGIN':'@ORIGIN@', + 'WIN32':'@WIN32@', + 'JIT_RELEASE_WHL':'@JIT_RELEASE_WHL@', + 'WITH_PSLIB':'@WITH_PSLIB@', + 'PYBIND_INCLUDE_DIR':'@PYBIND_INCLUDE_DIR@', + 'WITH_PYTHON':'@WITH_PYTHON@', + 'WITH_CINN':'@WITH_CINN@', + 'CINN_SOURCE_DIR':'@CINN_SOURCE_DIR@', + 'WITH_CPP_DIST':'@WITH_CPP_DIST@', + 'PADDLE_INSTALL_DIR':'@PADDLE_INSTALL_DIR@', + 'PADDLE_LIB_TEST_DIR':'@PADDLE_LIB_TEST_DIR@' +} + +global paddle_binary_dir, paddle_source_dir + +paddle_binary_dir = env_dict.get("PADDLE_BINARY_DIR") +paddle_source_dir = env_dict.get("PADDLE_SOURCE_DIR") + +# preparing parameters for setup() +paddle_version = env_dict.get("PADDLE_VERSION") +package_name = env_dict.get("PACKAGE_NAME") import urllib.request import urllib.error @@ -19,7 +146,7 @@ from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension from paddle.utils.cpp_extension.extension_utils import find_cuda_home -with open("README.md", "r", encoding="utf-8") as fh: +with open("../README.md", "r", encoding="utf-8") as fh: long_description = fh.read() @@ -153,7 +280,7 @@ def get_data_files(): source_lib_path = 'libflashattn.so' # Specify the destination directory within the package - destination_lib_path = os.path.join(PACKAGE_NAME, 'libflashattn.so') + destination_lib_path = os.path.join(PACKAGE_NAME, 'build/libflashattn.so') data_files.append((os.path.join(PACKAGE_NAME, 'libflashattn.so'), [source_lib_path])) return data_files @@ -211,6 +338,145 @@ def run(self): # If the wheel could not be downloaded, build from source super().run() +class InstallHeaders(Command): + """Override how headers are copied.""" + + description = 'install C/C++ header files' + + user_options = [ + ('install-dir=', 'd', 'directory to install header files to'), + ('force', 'f', 'force installation (overwrite existing files)'), + ] + + boolean_options = ['force'] + + def initialize_options(self): + self.install_dir = None + self.force = 0 + self.outfiles = [] + + def finalize_options(self): + self.set_undefined_options( + 'install', ('install_headers', 'install_dir'), ('force', 'force') + ) + + def run(self): + hdrs = self.distribution.headers + if not hdrs: + return + self.mkpath(self.install_dir) + for header in hdrs: + install_dir = get_header_install_dir(header) + install_dir = os.path.join( + self.install_dir, os.path.dirname(install_dir) + ) + if not os.path.exists(install_dir): + self.mkpath(install_dir) + (out, _) = self.copy_file(header, install_dir) + self.outfiles.append(out) + # (out, _) = self.mkdir_and_copy_file(header) + # self.outfiles.append(out) + + def get_inputs(self): + return self.distribution.headers or [] + + def get_outputs(self): + return self.outfiles + + +class InstallCommand(InstallCommandBase): + def finalize_options(self): + ret = InstallCommandBase.finalize_options(self) + self.install_lib = self.install_platlib + + self.install_headers = os.path.join( + self.install_platlib, 'paddle', 'include' + ) + return ret + + +class DevelopCommand(DevelopCommandBase): + def run(self): + # copy proto and .so to python_source_dir + fluid_proto_binary_path = ( + paddle_binary_dir + '/python/paddle/base/proto/' + ) + fluid_proto_source_path = ( + paddle_source_dir + '/python/paddle/base/proto/' + ) + distributed_proto_binary_path = ( + paddle_binary_dir + '/python/paddle/distributed/fleet/proto/' + ) + distributed_proto_source_path = ( + paddle_source_dir + '/python/paddle/distributed/fleet/proto/' + ) + os.system(f"rm -rf {fluid_proto_source_path}") + shutil.copytree(fluid_proto_binary_path, fluid_proto_source_path) + os.system(f"rm -rf {distributed_proto_source_path}") + shutil.copytree( + distributed_proto_binary_path, distributed_proto_source_path + ) + shutil.copy( + paddle_binary_dir + '/python/paddle/base/libpaddle.so', + paddle_source_dir + '/python/paddle/base/', + ) + dynamic_library_binary_path = paddle_binary_dir + '/python/paddle/libs/' + dynamic_library_source_path = paddle_source_dir + '/python/paddle/libs/' + for lib_so in os.listdir(dynamic_library_binary_path): + shutil.copy( + dynamic_library_binary_path + lib_so, + dynamic_library_source_path, + ) + # write version.py and cuda_env_config_py to python_source_dir + write_version_py( + filename=f'{paddle_source_dir}/python/paddle/version/__init__.py' + ) + write_cuda_env_config_py( + filename=f'{paddle_source_dir}/python/paddle/cuda_env.py' + ) + write_parameter_server_version_py( + filename='{}/python/paddle/incubate/distributed/fleet/parameter_server/version.py'.format( + paddle_source_dir + ) + ) + DevelopCommandBase.run(self) + + +class EggInfo(egg_info): + """Copy license file into `.dist-info` folder.""" + + def run(self): + # don't duplicate license into `.dist-info` when building a distribution + if not self.distribution.have_run.get('install', True): + self.mkpath(self.egg_info) + #self.copy_file( + # env_dict.get("PADDLE_SOURCE_DIR") + "/LICENSE", self.egg_info + #) + + egg_info.run(self) + + +# class Installlib is rewritten to add header files to .egg/paddle +class InstallLib(install_lib): + def run(self): + self.build() + outfiles = self.install() + hrds = self.distribution.headers + if not hrds: + return + for header in hrds: + install_dir = get_header_install_dir(header) + install_dir = os.path.join( + self.install_dir, 'paddle/include', os.path.dirname(install_dir) + ) + if not os.path.exists(install_dir): + self.mkpath(install_dir) + self.copy_file(header, install_dir) + if outfiles is not None: + # always compile, in case we have any extension stubs to deal with + self.byte_compile(outfiles) + + setup( name=PACKAGE_NAME, @@ -227,17 +493,24 @@ def run(self): long_description_content_type="text/markdown", url="https://github.com/PaddlePaddle/flash-attention", classifiers=[ - "Programming Language :: Python :: 3", + "Programming Language :: Python :: 37", "License :: OSI Approved :: BSD License", "Operating System :: Unix", ], ext_modules=ext_modules, cmdclass={ - 'bdist_wheel': CachedWheelsCommand, - "build_ext": BuildExtension - } if ext_modules else { - 'bdist_wheel': CachedWheelsCommand, + 'install_headers': InstallHeaders, + 'install': InstallCommand, + 'egg_info': EggInfo, + 'install_lib': InstallLib, + 'develop': DevelopCommand, }, + #cmdclass={ + # "bdist_wheel": CachedWheelsCommand, + # "build_ext": BuildExtension + #} if ext_modules else { + # "bdist_wheel": CachedWheelsCommand, + #}, python_requires=">=3.7", install_requires=[ "paddle", From 1355060bb29ad6c703f003d8022fb786cc158a52 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 17:10:56 +0800 Subject: [PATCH 07/36] update --- csrc/setup.py | 198 +++++++++----------------------------------------- 1 file changed, 35 insertions(+), 163 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 47d6a9555..f2d7357ce 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -7,7 +7,7 @@ from pathlib import Path from packaging.version import parse, Version import platform - +import shutil from setuptools import Command, Extension, setup from setuptools.command.develop import develop as DevelopCommandBase from setuptools.command.egg_info import egg_info @@ -15,9 +15,17 @@ from setuptools.command.install_lib import install_lib from setuptools.dist import Distribution from setuptools import setup, find_packages +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import paddle +from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension +from paddle.utils.cpp_extension.extension_utils import find_cuda_home import subprocess -python_version = platform.python_version() + version_detail = sys.version_info +python_version = platform.python_version() version = version_detail[0] + version_detail[1] / 10 env_version = os.getenv("PY_VERSION") @@ -38,113 +46,16 @@ ) os.environ["PY_VERSION"] = python_version +paddle_include_path = paddle.sysconfig.get_include() +paddle_lib_path = paddle.sysconfig.get_lib() -global env_dict # noqa: F811 -env_dict={ - 'PADDLE_SOURCE_DIR':'@PADDLE_SOURCE_DIR@', - 'PADDLE_VERSION':'@PADDLE_VERSION@', - 'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@', - 'TAG_VERSION_REGEX':'@TAG_VERSION_REGEX@', - 'WITH_GPU':'@WITH_GPU@', - 'CUDNN_MAJOR_VERSION':'@CUDNN_MAJOR_VERSION@', - 'CUDNN_MINOR_VERSION':'@CUDNN_MINOR_VERSION@', - 'CUDNN_PATCHLEVEL_VERSION':'@CUDNN_PATCHLEVEL_VERSION@', - 'CUDA_VERSION':'@CUDA_VERSION@', - 'WITH_PSLI':'@WITH_PSLI@', - 'FLUID_CORE_NAME':'@FLUID_CORE_NAME@', - 'PHI_LIB':'@PHI_LIB@', - 'PHI_NAME':'@PHI_NAME@', - 'WITH_SHARED_PHI':'@WITH_SHARED_PHI@', - 'IR_LIB':'@IR_LIB@', - 'IR_NAME':'@IR_NAME@', - 'WITH_SHARED_IR':'@WITH_SHARED_IR@', - 'COMMON_LIB':'@COMMON_LIB@', - 'COMMON_NAME':'@COMMON_NAME@', - 'WARPCTC_LIBRARIES':'@WARPCTC_LIBRARIES@', - 'WARPRNNT_LIBRARIES':'@WARPRNNT_LIBRARIES@', - 'FLASHATTN_LIBRARIES':'@FLASHATTN_LIBRARIES@', - 'LAPACK_LIB':'@LAPACK_LIB@', - 'GFORTRAN_LIB':'@GFORTRAN_LIB@', - 'GNU_RT_LIB_1':'@GNU_RT_LIB_1@', - 'WITH_CUDNN_DSO':'@WITH_CUDNN_DSO@', - 'CUDNN_LIBRARY':'@CUDNN_LIBRARY@', - 'GNU_RT_LIB_2':'@GNU_RT_LIB_2@', - 'WITH_MKL':'@WITH_MKL@', - 'MKLML_SHARED_LIB':'@MKLML_SHARED_LIB@', - 'MKLML_SHARED_IOMP_LIB':'@MKLML_SHARED_IOMP_LIB@', - 'OPENBLAS_SHARED_LIB':'@OPENBLAS_SHARED_LIB@', - 'OPENBLAS_LIB':'@OPENBLAS_LIB@', - 'BLAS_LIB':'@BLAS_LIB@', - 'WITH_LITE':'@WITH_LITE@', - 'LITE_SHARED_LIB':'@LITE_SHARED_LIB@', - 'LITE_WITH_NNADAPTER':'@LITE_WITH_NNADAPTER@', - 'LITE_NNADAPTER_LIB':'@LITE_NNADAPTER_LIB@', - 'NNADAPTER_WITH_HUAWEI_ASCEND_NPU':'@NNADAPTER_WITH_HUAWEI_ASCEND_NPU@', - 'LITE_NNADAPTER_NPU_LIB':'@LITE_NNADAPTER_NPU_LIB@', - 'WITH_CINN':'@WITH_CINN@', - 'CINN_LIB_LOCATION':'@CINN_LIB_LOCATION@', - 'CINN_LIB_NAME':'@CINN_LIB_NAME@', - 'CINN_INCLUDE_DIR':'@CINN_INCLUDE_DIR@', - 'CMAKE_BUILD_TYPE':'@CMAKE_BUILD_TYPE@', - 'PSLIB_LIB':'@PSLIB_LIB@', - 'JVM_LIB':'@JVM_LIB@', - 'PSLIB_VERSION_PY':'@PSLIB_VERSION_PY@', - 'WITH_MKLDNN':'@WITH_MKLDNN@', - 'MKLDNN_SHARED_LIB':'@MKLDNN_SHARED_LIB@', - 'MKLDNN_INSTALL_DIR':'@MKLDNN_INSTALL_DIR@', - 'WITH_ONNXRUNTIME':'@WITH_ONNXRUNTIME@', - 'ONNXRUNTIME_SHARED_LIB':'@ONNXRUNTIME_SHARED_LIB@', - 'PADDLE2ONNX_LIB':'@PADDLE2ONNX_LIB@', - 'PADDLE2ONNX_LIB_NAME':'@PADDLE2ONNX_LIB_NAME@', - 'ONNXRUNTIME_LIB_NAME':'@ONNXRUNTIME_LIB_NAME@', - 'WITH_XPU':'@WITH_XPU@', - 'XPU_API_LIB':'@XPU_API_LIB@', - 'XPU_API_LIB_NAME':'@XPU_API_LIB_NAME@', - 'XPU_RT_LIB':'@XPU_RT_LIB@', - 'XPU_RT_LIB_NAME':'@XPU_RT_LIB_NAME@', - 'WITH_XPU_BKCL':'@WITH_XPU_BKCL@', - 'XPU_BKCL_LIB':'@XPU_BKCL_LIB@', - 'XPU_BKCL_LIB_NAME':'@XPU_BKCL_LIB_NAME@', - 'WITH_XPU_XFT':'@WITH_XPU_XFT@', - 'XPU_XFT_LIB':'@XPU_XFT_LIB@', - 'XPU_XFT_LIB_NAME':'@XPU_XFT_LIB_NAME@', - 'THIRD_PARTY_PATH':'@THIRD_PARTY_PATH@', - 'SETUP_LOG_FILE':'@SETUP_LOG_FILE@', - 'WITH_STRIP':'@WITH_STRIP@', - 'PACKAGE_NAME':'@PACKAGE_NAME@', - 'PADDLE_VERSION':'@PADDLE_VERSION@', - 'APPLE':'@APPLE@', - 'externalError_INCLUDE_DIR':'@externalError_INCLUDE_DIR@', - 'WITH_ROCM':'@WITH_ROCM@', - 'ORIGIN':'@ORIGIN@', - 'WIN32':'@WIN32@', - 'JIT_RELEASE_WHL':'@JIT_RELEASE_WHL@', - 'WITH_PSLIB':'@WITH_PSLIB@', - 'PYBIND_INCLUDE_DIR':'@PYBIND_INCLUDE_DIR@', - 'WITH_PYTHON':'@WITH_PYTHON@', - 'WITH_CINN':'@WITH_CINN@', - 'CINN_SOURCE_DIR':'@CINN_SOURCE_DIR@', - 'WITH_CPP_DIST':'@WITH_CPP_DIST@', - 'PADDLE_INSTALL_DIR':'@PADDLE_INSTALL_DIR@', - 'PADDLE_LIB_TEST_DIR':'@PADDLE_LIB_TEST_DIR@' -} - -global paddle_binary_dir, paddle_source_dir - -paddle_binary_dir = env_dict.get("PADDLE_BINARY_DIR") -paddle_source_dir = env_dict.get("PADDLE_SOURCE_DIR") +print("Paddle Include Path:", paddle_include_path) +print("Paddle Lib Path:", paddle_lib_path) # preparing parameters for setup() -paddle_version = env_dict.get("PADDLE_VERSION") -package_name = env_dict.get("PACKAGE_NAME") +paddle_version = paddle.version.full_version +cuda_version= paddle.version.cuda_version -import urllib.request -import urllib.error -from wheel.bdist_wheel import bdist_wheel as _bdist_wheel - -import paddle -from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension -from paddle.utils.cpp_extension.extension_utils import find_cuda_home with open("../README.md", "r", encoding="utf-8") as fh: long_description = fh.read() @@ -153,14 +64,11 @@ # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) CUDA_HOME = find_cuda_home() -PACKAGE_NAME = "flash_attn" +PACKAGE_NAME = "paddle_flash_attn" -BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" +BASE_WHEEL_URL = "https://github.com/PaddlePaddle/flash-attention/releases/download/{tag_name}/{wheel_name}" -# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels -# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" -SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" # For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM @@ -218,12 +126,6 @@ def raise_if_cuda_home_none(global_option: str) -> None: "only images whose names contain 'devel' will provide nvcc." ) - -def append_nvcc_threads(nvcc_extra_args): - if not FORCE_SINGLE_THREAD: - return nvcc_extra_args + ["--threads", "4"] - return nvcc_extra_args - def _is_cuda_available(): """ Check whether CUDA is available. @@ -261,7 +163,6 @@ def _is_cuda_available(): os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6" cmdclass = {} -ext_modules = [] def get_package_version(): with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f: @@ -294,49 +195,32 @@ class CachedWheelsCommand(_bdist_wheel): wheel available and short-circuits the standard full build pipeline. """ def run(self): - if FORCE_BUILD: - return super().run() - + print("88888888888888888888888888888") + # if FORCE_BUILD: + # return super().run() + self.run_command('build_ext') + super().run() # Determine the version numbers that will be used to determine the correct wheel # We're using the CUDA version used to build paddle, not the one currently installed # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - paddle_cuda_version = parse(paddle.version.cuda) + paddle_cuda_version = "234" #parse(paddle.version.cuda) paddle_version_raw = parse(paddle.__version__) python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() flash_version = get_package_version() - # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" - cuda_version = f"{paddle_cuda_version.major}{paddle_cuda_version.minor}" - paddle_version = f"{paddle_version_raw.major}.{paddle_version_raw.minor}" - cxx11_abi = str(paddle._C._GLIBCXX_USE_CXX11_ABI).upper() + cxx11_abi ="" # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper() # Determine wheel URL based on CUDA version, paddle version, python version and OS - wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}paddle{paddle_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl' - wheel_url = BASE_WHEEL_URL.format( - tag_name=f"v{flash_version}", - wheel_name=wheel_filename - ) - print("Guessing wheel URL: ", wheel_url) - - try: - urllib.request.urlretrieve(wheel_url, wheel_filename) - - # Make the archive - # Lifted from the root wheel processing command - # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 - if not os.path.exists(self.dist_dir): - os.makedirs(self.dist_dir) - - impl_tag, abi_tag, plat_tag = self.get_tag() - archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + wheel_filename = f'{PACKAGE_NAME}-{flash_version}-cu{cuda_version}-paddle{paddle_version}-{python_version}-{python_version}-{platform_name}.whl' + impl_tag, abi_tag, plat_tag = self.get_tag() + original_wheel_name = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" - wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") - print("Raw wheel path", wheel_path) - os.rename(wheel_filename, wheel_path) - except urllib.error.HTTPError: - print("Precompiled wheel not found. Building from source...") - # If the wheel could not be downloaded, build from source - super().run() + new_wheel_name = wheel_filename + print("self.asdfasdfsdfasdfasdfasdf", self.get_tag()) + shutil.move( + f"{self.dist_dir}/{original_wheel_name}.whl", + f"{self.dist_dir}/{new_wheel_name}" + ) class InstallHeaders(Command): """Override how headers are copied.""" @@ -497,20 +381,8 @@ def run(self): "License :: OSI Approved :: BSD License", "Operating System :: Unix", ], - ext_modules=ext_modules, cmdclass={ - 'install_headers': InstallHeaders, - 'install': InstallCommand, - 'egg_info': EggInfo, - 'install_lib': InstallLib, - 'develop': DevelopCommand, - }, - #cmdclass={ - # "bdist_wheel": CachedWheelsCommand, - # "build_ext": BuildExtension - #} if ext_modules else { - # "bdist_wheel": CachedWheelsCommand, - #}, + "bdist_wheel": CachedWheelsCommand,}, python_requires=">=3.7", install_requires=[ "paddle", From d536119e240ac62eb6b8b71c872115a49200c11e Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 18:53:36 +0800 Subject: [PATCH 08/36] update --- csrc/setup.py | 178 ++------------------------------------------------ 1 file changed, 7 insertions(+), 171 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index f2d7357ce..dcfc4e691 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -141,27 +141,6 @@ def _is_cuda_available(): ) return False -if paddle.is_compiled_with_cuda() and _is_cuda_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pypaddle/pypaddle/pull/23408 attempt to query paddle.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, FlashAttention will cross-compile for Ampere (compute capability 8.0, 8.6, " - "8.9), and, if the CUDA version is >= 11.8, Hopper (compute capability 9.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export PADDLE_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("PADDLE_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6;9.0" - elif bare_metal_version >= Version("11.4"): - os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6" - else: - os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6" - cmdclass = {} def get_package_version(): @@ -183,7 +162,7 @@ def get_data_files(): # Specify the destination directory within the package destination_lib_path = os.path.join(PACKAGE_NAME, 'build/libflashattn.so') - data_files.append((os.path.join(PACKAGE_NAME, 'libflashattn.so'), [source_lib_path])) + data_files.append((paddle_lib_path, [source_lib_path])) return data_files @@ -215,160 +194,17 @@ def run(self): impl_tag, abi_tag, plat_tag = self.get_tag() original_wheel_name = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" - new_wheel_name = wheel_filename - print("self.asdfasdfsdfasdfasdfasdf", self.get_tag()) - shutil.move( - f"{self.dist_dir}/{original_wheel_name}.whl", - f"{self.dist_dir}/{new_wheel_name}" - ) - -class InstallHeaders(Command): - """Override how headers are copied.""" - - description = 'install C/C++ header files' - - user_options = [ - ('install-dir=', 'd', 'directory to install header files to'), - ('force', 'f', 'force installation (overwrite existing files)'), - ] - - boolean_options = ['force'] - - def initialize_options(self): - self.install_dir = None - self.force = 0 - self.outfiles = [] - - def finalize_options(self): - self.set_undefined_options( - 'install', ('install_headers', 'install_dir'), ('force', 'force') - ) - - def run(self): - hdrs = self.distribution.headers - if not hdrs: - return - self.mkpath(self.install_dir) - for header in hdrs: - install_dir = get_header_install_dir(header) - install_dir = os.path.join( - self.install_dir, os.path.dirname(install_dir) - ) - if not os.path.exists(install_dir): - self.mkpath(install_dir) - (out, _) = self.copy_file(header, install_dir) - self.outfiles.append(out) - # (out, _) = self.mkdir_and_copy_file(header) - # self.outfiles.append(out) - - def get_inputs(self): - return self.distribution.headers or [] - - def get_outputs(self): - return self.outfiles - - -class InstallCommand(InstallCommandBase): - def finalize_options(self): - ret = InstallCommandBase.finalize_options(self) - self.install_lib = self.install_platlib - - self.install_headers = os.path.join( - self.install_platlib, 'paddle', 'include' - ) - return ret - - -class DevelopCommand(DevelopCommandBase): - def run(self): - # copy proto and .so to python_source_dir - fluid_proto_binary_path = ( - paddle_binary_dir + '/python/paddle/base/proto/' - ) - fluid_proto_source_path = ( - paddle_source_dir + '/python/paddle/base/proto/' - ) - distributed_proto_binary_path = ( - paddle_binary_dir + '/python/paddle/distributed/fleet/proto/' - ) - distributed_proto_source_path = ( - paddle_source_dir + '/python/paddle/distributed/fleet/proto/' - ) - os.system(f"rm -rf {fluid_proto_source_path}") - shutil.copytree(fluid_proto_binary_path, fluid_proto_source_path) - os.system(f"rm -rf {distributed_proto_source_path}") - shutil.copytree( - distributed_proto_binary_path, distributed_proto_source_path - ) - shutil.copy( - paddle_binary_dir + '/python/paddle/base/libpaddle.so', - paddle_source_dir + '/python/paddle/base/', - ) - dynamic_library_binary_path = paddle_binary_dir + '/python/paddle/libs/' - dynamic_library_source_path = paddle_source_dir + '/python/paddle/libs/' - for lib_so in os.listdir(dynamic_library_binary_path): - shutil.copy( - dynamic_library_binary_path + lib_so, - dynamic_library_source_path, - ) - # write version.py and cuda_env_config_py to python_source_dir - write_version_py( - filename=f'{paddle_source_dir}/python/paddle/version/__init__.py' - ) - write_cuda_env_config_py( - filename=f'{paddle_source_dir}/python/paddle/cuda_env.py' - ) - write_parameter_server_version_py( - filename='{}/python/paddle/incubate/distributed/fleet/parameter_server/version.py'.format( - paddle_source_dir - ) - ) - DevelopCommandBase.run(self) - - -class EggInfo(egg_info): - """Copy license file into `.dist-info` folder.""" - - def run(self): - # don't duplicate license into `.dist-info` when building a distribution - if not self.distribution.have_run.get('install', True): - self.mkpath(self.egg_info) - #self.copy_file( - # env_dict.get("PADDLE_SOURCE_DIR") + "/LICENSE", self.egg_info - #) - - egg_info.run(self) - - -# class Installlib is rewritten to add header files to .egg/paddle -class InstallLib(install_lib): - def run(self): - self.build() - outfiles = self.install() - hrds = self.distribution.headers - if not hrds: - return - for header in hrds: - install_dir = get_header_install_dir(header) - install_dir = os.path.join( - self.install_dir, 'paddle/include', os.path.dirname(install_dir) - ) - if not os.path.exists(install_dir): - self.mkpath(install_dir) - self.copy_file(header, install_dir) - if outfiles is not None: - # always compile, in case we have any extension stubs to deal with - self.byte_compile(outfiles) - + new_wheel_name ='asdfsdf.whl' # wheel_filename + #shutil.move( + # f"{self.dist_dir}/{original_wheel_name}.whl", + # f"{self.dist_dir}/{new_wheel_name}" + #) setup( name=PACKAGE_NAME, version=get_package_version(), - packages=find_packages( - #exclude=("build") - #, "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) - ), + packages=find_packages(), data_files=get_data_files(), package_data={PACKAGE_NAME: ['build/libflashattn.so']}, author_email="Paddle-better@baidu.com", From 66fc8a754ceee446a05964bbf99746536fd82c3f Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:00:56 +0800 Subject: [PATCH 09/36] update --- csrc/CMakeLists.txt | 5 +- csrc/README.md | 234 ------------ .../src/flash_bwd_launch_template.h | 17 + .../src/flash_fwd_launch_template.h | 22 ++ csrc/mp.py | 336 ------------------ csrc/tp.py | 77 ---- 6 files changed, 43 insertions(+), 648 deletions(-) delete mode 100644 csrc/README.md delete mode 100644 csrc/mp.py delete mode 100644 csrc/tp.py diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 3b84b5388..9195ca565 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -6,7 +6,10 @@ find_package(Git QUIET REQUIRED) execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} RESULT_VARIABLE GIT_SUBMOD_RESULT) - +#cmake -DWITH_ADVANCED +if (WITH_ADVANCED) + add_compile_definitions(PADDLE_WITH_ADVANCED)cu +endif() add_definitions("-DFLASH_ATTN_WITH_TORCH=0") set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) diff --git a/csrc/README.md b/csrc/README.md deleted file mode 100644 index 79d334530..000000000 --- a/csrc/README.md +++ /dev/null @@ -1,234 +0,0 @@ -# FlashAttention -This repository provides the official implementation of FlashAttention and -FlashAttention-2 from the -following papers. - -**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness** -Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré -Paper: https://arxiv.org/abs/2205.14135 -IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention. -![FlashAttention](assets/flashattn_banner.jpg) - -**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning** -Tri Dao - -Paper: https://tridao.me/publications/flash2/flash2.pdf - -![FlashAttention-2](assets/flashattention_logo.png) - - -## Usage - -We've been very happy to see FlashAttention being widely adopted in such a short -time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md) -contains a partial list of places where FlashAttention is being used. - -FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). -Please cite and credit FlashAttention if you use it. - -## Installation and features - -Requirements: -- CUDA 11.4 and above. -- PyTorch 1.12 and above. - -We recommend the -[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) -container from Nvidia, which has all the required tools to install FlashAttention. - -To install: -1. Make sure that PyTorch is installed. -2. Make sure that `packaging` is installed (`pip install packaging`) -3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja ---version` then `echo $?` should return exit code 0). If not (sometimes `ninja ---version` then `echo $?` returns a nonzero exit code), uninstall then reinstall -`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`, -compiling can take a very long time (2h) since it does not use multiple CPU -cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine. -4. Then: -```sh -pip install flash-attn --no-build-isolation -``` -Alternatively you can compile from source: -```sh -python setup.py install -``` - -If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might -run too many parallel compilation jobs that could exhaust the amount of RAM. To -limit the number of parallel compilation jobs, you can set the environment -variable `MAX_JOBS`: -```sh -MAX_JOBS=4 pip install flash-attn --no-build-isolation -``` - -Interface: `src/flash_attention_interface.py` - -FlashAttention-2 currently supports: -1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing - GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing - GPUs for now. -2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs). -3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800. - - -## How to use FlashAttention - -The main functions implement scaled dot product attention (softmax(Q @ K^T * -softmax_scale) @ V): -```python -from flash_attn import flash_attn_qkvpacked_func, flash_attn_func -``` - -```python -flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False): -"""dropout_p should be set to 0.0 during evaluation -If Q, K, V are already stacked into 1 tensor, this function will be faster than -calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation -of the gradients of Q, K, V. -Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). -Return: - out: (batch_size, seqlen, nheads, headdim). -""" -``` - -```python -flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False): -"""dropout_p should be set to 0.0 during evaluation -Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads -than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. -For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head -0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - -Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). -Return: - out: (batch_size, seqlen, nheads, headdim). -""" -``` - -To see how these functions are used in a multi-head attention layer (which -includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py). - -## Upgrading from FlashAttention (1.x) to FlashAttention-2 - -These functions have been renamed: -- `flash_attn_unpadded_func` -> `flash_attn_varlen_func` -- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func` -- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func` - -If the inputs have the same sequence lengths in the same batch, it is simpler -and faster to use these functions: -```python -flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False) -``` -```python -flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False) -``` - -## Performance - -We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory). - -We currently have benchmarks for these GPUs: -* [A100](#a100) -* [H100](#h100) - - - -### A100 - -We display FlashAttention speedup using these parameters: -* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads). -* Sequence length 512, 1k, 2k, 4k, 8k, 16k. -* Batch size set to 16k / seqlen. - -#### Speedup - -![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png) - -#### Memory - -![FlashAttention memory](assets/flashattn_memory.jpg) - -We show memory savings in this graph (note that memory footprint is the same no matter if you use dropout or masking). -Memory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length. -We see 10X memory savings at sequence length 2K, and 20X at 4K. -As a result, FlashAttention can scale to much longer sequence lengths. - -### H100 - -![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png) - -## Full model code and training script - -We have released the full GPT model -[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py). -We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, -cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x -compared to the baseline implementation from Huggingface, reaching up to 225 -TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need -any activation checkpointing). - -We also include a training -[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to -train GPT2 on Openwebtext and GPT3 on The Pile. - -## Triton implementation of FlashAttention - -Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton: -https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py - -As Triton is a higher-level language than CUDA, it might be easier to understand -and experiment with. The notations in the Triton implementation are also closer -to what's used in our paper. - -We also have an experimental implementation in Triton that support attention -bias (e.g. ALiBi): -https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py - - -## Tests -We test that FlashAttention produces the same output and gradient as a reference -implementation, up to some numerical tolerance. In particular, we check that the -maximum numerical error of FlashAttention is at most twice the numerical error -of a baseline implementation in Pytorch (for different head dimensions, input -dtype, sequence length, causal / non-causal). - -To run the tests: -```sh -pytest -q -s tests/test_flash_attn.py -``` -## When you encounter issues - -This new release of FlashAttention-2 has been tested on several GPT-style -models, mostly on A100 GPUs. - -If you encounter bugs, please open a GitHub Issue! - -## Citation -If you use this codebase, or otherwise found our work valuable, please cite: -``` -@inproceedings{dao2022flashattention, - title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, - author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, - booktitle={Advances in Neural Information Processing Systems}, - year={2022} -} -@article{dao2023flashattention2, - title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning, - author={Dao, Tri}, - year={2023} -} -``` diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 2c62e6c57..f3eb8850f 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -64,6 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool is_attn_mask = params.attn_mask_ptr != nullptr; const bool is_deterministic = params.num_splits == 1; // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); +#ifdef PADDLE_WITH_ADVANCED BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { @@ -82,6 +83,22 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, }); }); }); +#else + BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +#endif auto kernel_dq = &flash_bwd_convert_dq_kernel; if (Kernel_traits::kSmemdQSize >= 48 * 1024) { diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 6c6382617..5090605cb 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -36,6 +36,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const bool return_softmax = params.p_ptr != nullptr; const bool is_attn_mask = params.attn_mask_ptr != nullptr; const bool is_equal_qk = (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask); +#ifdef PADDLE_WITH_ADVANCED BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { @@ -59,6 +60,27 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); }); }); +#else + BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + // Will only return softmax if dropout, to reduce compilation time. + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +#endif } template diff --git a/csrc/mp.py b/csrc/mp.py deleted file mode 100644 index ab8435e8e..000000000 --- a/csrc/mp.py +++ /dev/null @@ -1,336 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -import re -import ast -from pathlib import Path -from packaging.version import parse, Version -import platform - -from setuptools import setup, find_packages -import subprocess - -import urllib.request -import urllib.error -from wheel.bdist_wheel import bdist_wheel as _bdist_wheel - -import paddle -from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension -from paddle.utils.cpp_extension.extension_utils import find_cuda_home - -with open("README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() - - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) -CUDA_HOME = find_cuda_home() -PACKAGE_NAME = "flash_attn" - -BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" - -# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels -# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation -FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" -SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" -# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI -FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" -# For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM -FORCE_SINGLE_THREAD = os.getenv("FLASH_ATTENTION_FORCE_SINGLE_THREAD", "FALSE") == "TRUE" - - -def get_platform(): - """ - Returns the platform name as used in wheel filenames. - """ - if sys.platform.startswith('linux'): - return 'linux_x86_64' - elif sys.platform == 'darwin': - mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2]) - return f'macosx_{mac_version}_x86_64' - elif sys.platform == 'win32': - return 'win_amd64' - else: - raise ValueError('Unsupported platform: {}'.format(sys.platform)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_paddle_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - paddle_binary_version = parse(paddle.version.cuda()) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != paddle_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pypaddle binaries. " - "Pypaddle binaries were compiled with Cuda {}.\n".format(paddle.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pypaddle/pypaddle, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - if not FORCE_SINGLE_THREAD: - return nvcc_extra_args + ["--threads", "4"] - return nvcc_extra_args - -def _is_cuda_available(): - """ - Check whether CUDA is available. - """ - try: - assert len(paddle.static.cuda_places()) > 0 - return True - except Exception as e: - logging.warning( - "You are using GPU version PaddlePaddle, but there is no GPU " - "detected on your machine. Maybe CUDA devices is not set properly." - f"\n Original Error is {e}" - ) - return False - -if paddle.is_compiled_with_cuda() and _is_cuda_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pypaddle/pypaddle/pull/23408 attempt to query paddle.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, FlashAttention will cross-compile for Ampere (compute capability 8.0, 8.6, " - "8.9), and, if the CUDA version is >= 11.8, Hopper (compute capability 9.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;9.0" - elif bare_metal_version >= Version("11.4"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6" - -cmdclass = {} -ext_modules = [] - -# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp -# files included in the source distribution, in case the user compiles from source. -subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) - -if not SKIP_CUDA_BUILD: - print("\n\npaddle.__version__ = {}\n\n".format(paddle.__version__)) - TORCH_MAJOR = int(paddle.__version__.split(".")[0]) - TORCH_MINOR = int(paddle.__version__.split(".")[1]) - - # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h - # See https://github.com/pypaddle/pypaddle/pull/70650 - generator_flag = [] - paddle_dir = paddle.__path__[0] - if os.path.exists(os.path.join(paddle_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - - raise_if_cuda_home_none("flash_attn") - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version < Version("11.4"): - raise RuntimeError("FlashAttention is only supported on CUDA 11.4 and above") - # cc_flag.append("-gencode") - # cc_flag.append("arch=compute_75,code=sm_75") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,ggcode=sm_80") - if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - - # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as - # paddle._C._GLIBCXX_USE_CXX11_ABI - # https://github.com/pypaddle/pypaddle/blob/8472c24e3b5b60150096486616d98b7bea01500b/paddle/utils/cpp_extension.py#L920 - if FORCE_CXX11_ABI: - paddle._C._GLIBCXX_USE_CXX11_ABI = True - ext_modules.append( - CUDAExtension( - sources=[ - "csrc/flash_attn/flash_api.cpp", - "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v", - # "--ptxas-options=-O2", - "-lineinfo" - ] - + generator_flag - + cc_flag - ), - }, - include_dirs=[ - Path(this_dir) / 'csrc' / 'flash_attn', - Path(this_dir) / 'csrc' / 'flash_attn' / 'src', - Path(this_dir) / 'csrc' / 'cutlass' / 'include', - ], - ) - ) - - -def get_package_version(): - with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f: - version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) - public_version = ast.literal_eval(version_match.group(1)) - local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") - if local_version: - return f"{public_version}+{local_version}" - else: - return str(public_version) - - -class CachedWheelsCommand(_bdist_wheel): - """ - The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot - find an existing wheel (which is currently the case for all flash attention installs). We use - the environment parameters to detect whether there is already a pre-built version of a compatible - wheel available and short-circuits the standard full build pipeline. - """ - def run(self): - if FORCE_BUILD: - return super().run() - - # Determine the version numbers that will be used to determine the correct wheel - # We're using the CUDA version used to build paddle, not the one currently installed - # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - paddle_cuda_version = parse(paddle.version.cuda) - paddle_version_raw = parse(paddle.__version__) - python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" - platform_name = get_platform() - flash_version = get_package_version() - # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" - cuda_version = f"{paddle_cuda_version.major}{paddle_cuda_version.minor}" - paddle_version = f"{paddle_version_raw.major}.{paddle_version_raw.minor}" - cxx11_abi = str(paddle._C._GLIBCXX_USE_CXX11_ABI).upper() - - # Determine wheel URL based on CUDA version, paddle version, python version and OS - wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}paddle{paddle_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl' - wheel_url = BASE_WHEEL_URL.format( - tag_name=f"v{flash_version}", - wheel_name=wheel_filename - ) - print("Guessing wheel URL: ", wheel_url) - - try: - urllib.request.urlretrieve(wheel_url, wheel_filename) - - # Make the archive - # Lifted from the root wheel processing command - # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 - if not os.path.exists(self.dist_dir): - os.makedirs(self.dist_dir) - - impl_tag, abi_tag, plat_tag = self.get_tag() - archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" - - wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") - print("Raw wheel path", wheel_path) - os.rename(wheel_filename, wheel_path) - except urllib.error.HTTPError: - print("Precompiled wheel not found. Building from source...") - # If the wheel could not be downloaded, build from source - super().run() - - -setup( - name=PACKAGE_NAME, - version=get_package_version(), - packages=find_packages( - exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) - ), - author="Tri Dao", - author_email="trid@cs.stanford.edu", - description="Flash Attention: Fast and Memory-Efficient Exact Attention", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/Dao-AILab/flash-attention", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: BSD License", - "Operating System :: Unix", - ], - #ext_modules=ext_modules, - cmdclass={ - 'bdist_wheel': CachedWheelsCommand, - "build_ext": BuildExtension - } if ext_modules else { - 'bdist_wheel': CachedWheelsCommand, - }, - python_requires=">=3.7", - install_requires=[ - "paddle", - "einops", - "packaging", - "ninja", - ], -) diff --git a/csrc/tp.py b/csrc/tp.py deleted file mode 100644 index 56aebee07..000000000 --- a/csrc/tp.py +++ /dev/null @@ -1,77 +0,0 @@ -import paddle -from setuptools import setup, find_packages -import sys -import os -import paddle -paddle_path = paddle.sysconfig.get_lib -print(paddle_path) -python_version = sys.version -print("Installing your_package...") - -# Get the CUDA version from PaddlePaddle -cuda_version = paddle.version.cuda() -fa_version = f"1.0.0.post{cuda_version}" -package_name = 'flash_attention_paddle_gpu' - -def get_data_files(): - data_files = [] - - # Assuming 'libflashattn.so' is located in the same directory as setup.py - source_lib_path = 'libflashattn.so' - - # Specify the destination directory within the package - destination_lib_path = os.path.join(package_name, 'libflashattn.so') - - data_files.append((os.path.join(package_name, 'libflashattn.so'), [source_lib_path])) - print(destination_lib_path, "asdf ****************") - print(data_files) - return data_files - -setup( - name=package_name, - version=fa_version, - data_files=get_data_files(), - description='Flash attention in paddlepaddle', - packages=find_packages(), - package_data={package_name: ['build/libflashattn.so']}, -) -# -#import paddle -#import os -#from setuptools import setup -#import sys -# -#python_version = sys.version -#print("Installing your_package...") -# -## Get the CUDA version from PaddlePaddle -#cuda_version = paddle.version.cuda() -#fa_version = f"1.0.0.post{cuda_version}" -#package_name = 'flash_attention_paddle_gpu' # Adjusted package name -# -#def get_data_files(): -# data_files = [] -# -# # Assuming 'libflashattn.so' is located in the same directory as setup.py -# source_lib_path = os.path.abspath('libflashattn.so') -# -# # Specify the destination directory within the package -# destination_lib_path = os.path.join(package_name, 'libflashattn.so') -# -# data_files.append((os.path.join(package_name, 'libflashattn.so'), [source_lib_path])) -# print(destination_lib_path, "asdf ****************") -# print(data_files) -# return data_files -# -## Create an empty __init__.py file in the package directory -#init_file_path = os.path.join(package_name, '__init__.py') -#with open(init_file_path, 'w') as f: -# pass -# -#setup( -# name=package_name, -# version=fa_version, -# description='Flash attention in paddlepaddle', -# packages=[package_name], -# package_data={package_name: ['libflashattn.so']}, -#) From 41ebd074fb8942a2cefe7f1f318945ab8176475a Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:05:05 +0800 Subject: [PATCH 10/36] all --- csrc/CMakeLists.txt | 80 ++++++++++++++++++++++----------------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 9195ca565..a051f0552 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -16,38 +16,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -60,13 +60,13 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu - #flash_attn_with_bias_and_mask/src/cuda_utils.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/cuda_utils.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC @@ -141,7 +141,7 @@ INSTALL(TARGETS flashattn INSTALL(FILES capi/flash_attn.h DESTINATION "include") add_custom_target(run_my_executable - COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py sdist bdist_wheel + COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS flashattn COMMENT "Running my_executable" From ad614e0dfcaa8862eedb59506d65b2ed8f6a5a39 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:06:10 +0800 Subject: [PATCH 11/36] update --- csrc/yest | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 csrc/yest diff --git a/csrc/yest b/csrc/yest deleted file mode 100644 index b3d4d3cd0..000000000 --- a/csrc/yest +++ /dev/null @@ -1,3 +0,0 @@ -include build/libflashattn.so -include src/libflashattn.so -include ./libflashattn.so From c8d003ab709aec89ada749677d2c3e4c014d860b Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:07:16 +0800 Subject: [PATCH 12/36] update --- yes.py | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100644 yes.py diff --git a/yes.py b/yes.py deleted file mode 100644 index 29917c43e..000000000 --- a/yes.py +++ /dev/null @@ -1,12 +0,0 @@ -from setuptools import setup - -package_name = '' #flash-attention-paddle-gpu' -setup( - name=package_name, - version='1.0.0', - description='Flash attention in PaddlePaddle', - packages=[package_name], - include_package_data=True, - package_data={package_name: ['csrc/build/libflashattn.so']}, -) - From 559a47952de47cfd5aeecb66cdc306208757dec8 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:08:29 +0800 Subject: [PATCH 13/36] update --- csrc/CMakeLists.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index a051f0552..bac78e5b4 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -134,7 +134,6 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$) -# INSTALL(TARGETS flashattn LIBRARY DESTINATION "lib") @@ -147,9 +146,8 @@ add_custom_target(run_my_executable COMMENT "Running my_executable" ) -# 创建一个伪目标作为默认构建目标 add_custom_target(default_target DEPENDS run_my_executable) -# 设置 'default_target' 为默认构建目标 +# set 'default_target' set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target) From bd670ae650df79b3b60b98d4de6be8fd69f00d41 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:10:16 +0800 Subject: [PATCH 14/36] 80 90 --- csrc/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index bac78e5b4..27ac1fdb6 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -6,10 +6,12 @@ find_package(Git QUIET REQUIRED) execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} RESULT_VARIABLE GIT_SUBMOD_RESULT) + #cmake -DWITH_ADVANCED if (WITH_ADVANCED) add_compile_definitions(PADDLE_WITH_ADVANCED)cu endif() + add_definitions("-DFLASH_ATTN_WITH_TORCH=0") set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) From e4b500675954bf1a9f48189e512e1c00b73c902a Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:22:11 +0800 Subject: [PATCH 15/36] error --- csrc/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 27ac1fdb6..c98cdbb45 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -9,7 +9,7 @@ execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive #cmake -DWITH_ADVANCED if (WITH_ADVANCED) - add_compile_definitions(PADDLE_WITH_ADVANCED)cu + add_compile_definitions(PADDLE_WITH_ADVANCED) endif() add_definitions("-DFLASH_ATTN_WITH_TORCH=0") From 4f7f1f0dfa0793fc24fe495bf55a0efa75986442 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:28:27 +0800 Subject: [PATCH 16/36] update build ok --- csrc/CMakeLists.txt | 79 +++++++++++++++++++++++---------------------- csrc/setup.py | 2 +- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index c98cdbb45..fbf48d492 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -18,38 +18,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -62,13 +62,13 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu - flash_attn_with_bias_and_mask/src/cuda_utils.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/cuda_utils.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC @@ -100,6 +100,7 @@ endif() STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN}) set(FA_GENCODE_OPTION "SHELL:") + foreach(arch ${FA_NVCC_ARCH_BIN}) if(${arch} GREATER_EQUAL 80) set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}") diff --git a/csrc/setup.py b/csrc/setup.py index dcfc4e691..d26d6136f 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -57,7 +57,7 @@ cuda_version= paddle.version.cuda_version -with open("../README.md", "r", encoding="utf-8") as fh: +with open("../../README.md", "r", encoding="utf-8") as fh: long_description = fh.read() From 8c12f7266657a1627f4331053de776b843e610ba Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:39:47 +0800 Subject: [PATCH 17/36] update --- csrc/setup.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index d26d6136f..b40d72c9a 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -194,11 +194,12 @@ def run(self): impl_tag, abi_tag, plat_tag = self.get_tag() original_wheel_name = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" - new_wheel_name ='asdfsdf.whl' # wheel_filename - #shutil.move( - # f"{self.dist_dir}/{original_wheel_name}.whl", - # f"{self.dist_dir}/{new_wheel_name}" - #) + #new_wheel_name = wheel_filename + new_wheel_name = f"{self.wheel_dist_name}-{python_version}-{abi_tag}-{plat_tag}" + shutil.move( + f"{self.dist_dir}/{original_wheel_name}.whl", + f"{self.dist_dir}/{new_wheel_name}.whl" + ) setup( From 7b257e8ad308b30a5a98bd861b7d14e1f3a30a74 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:02:19 +0800 Subject: [PATCH 18/36] update --- csrc/setup.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index b40d72c9a..39a5e338a 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -18,6 +18,7 @@ import urllib.request import urllib.error from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +from setuptools.command.install import install import paddle from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension @@ -159,14 +160,11 @@ def get_data_files(): # Assuming 'libflashattn.so' is located in the same directory as setup.py source_lib_path = 'libflashattn.so' - # Specify the destination directory within the package - destination_lib_path = os.path.join(PACKAGE_NAME, 'build/libflashattn.so') - - data_files.append((paddle_lib_path, [source_lib_path])) + data_files.append((".", [source_lib_path])) return data_files -class CachedWheelsCommand(_bdist_wheel): +class CustomWheelsCommand(_bdist_wheel): """ The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot find an existing wheel (which is currently the case for all flash attention installs). We use @@ -174,9 +172,6 @@ class CachedWheelsCommand(_bdist_wheel): wheel available and short-circuits the standard full build pipeline. """ def run(self): - print("88888888888888888888888888888") - # if FORCE_BUILD: - # return super().run() self.run_command('build_ext') super().run() # Determine the version numbers that will be used to determine the correct wheel @@ -202,6 +197,21 @@ def run(self): ) +class CustomInstallCommand(install): + def run(self): + install.run(self) + install_path = self.install_lib + # src + source_lib_path = os.path.abspath('libflashattn.so') + + # 目标链接路径 + destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn.so') + + # 创建软链接 + shutil.move(f"{source_lib_path}", f"{destination_lib_path}") + #os.symlink(source_lib_path, destination_lib_path) + + setup( name=PACKAGE_NAME, version=get_package_version(), @@ -219,7 +229,8 @@ def run(self): "Operating System :: Unix", ], cmdclass={ - "bdist_wheel": CachedWheelsCommand,}, + 'bdist_wheel': CustomWheelsCommand, + 'install': CustomInstallCommand}, python_requires=">=3.7", install_requires=[ "paddle", From 4fd33eaf567f5a837f63bd5b1d2e7383c006d76e Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:16:49 +0800 Subject: [PATCH 19/36] updaet --- csrc/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index fbf48d492..fba57f0e2 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -142,14 +142,14 @@ INSTALL(TARGETS flashattn INSTALL(FILES capi/flash_attn.h DESTINATION "include") -add_custom_target(run_my_executable +add_custom_target(build_whl COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS flashattn COMMENT "Running my_executable" ) -add_custom_target(default_target DEPENDS run_my_executable) +add_custom_target(default_target DEPENDS build_whl) # set 'default_target' set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target) From f03a1dffd0f3661c80c7d4098f5769a04b3febae Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:17:35 +0800 Subject: [PATCH 20/36] updaet --- csrc/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index fba57f0e2..486797127 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -146,7 +146,7 @@ add_custom_target(build_whl COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS flashattn - COMMENT "Running my_executable" + COMMENT "Running build wheel" ) add_custom_target(default_target DEPENDS build_whl) From d8101084517132fa4f43cd657edd5fc3e28d792f Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:21:27 +0800 Subject: [PATCH 21/36] upate --- csrc/setup.py | 41 +---------------------------------------- 1 file changed, 1 insertion(+), 40 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 39a5e338a..39e224df6 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -67,15 +67,6 @@ CUDA_HOME = find_cuda_home() PACKAGE_NAME = "paddle_flash_attn" -BASE_WHEEL_URL = "https://github.com/PaddlePaddle/flash-attention/releases/download/{tag_name}/{wheel_name}" - -FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" -# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI -FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" -# For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM -FORCE_SINGLE_THREAD = os.getenv("FLASH_ATTENTION_FORCE_SINGLE_THREAD", "FALSE") == "TRUE" - - def get_platform(): """ Returns the platform name as used in wheel filenames. @@ -100,33 +91,6 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_version -def check_cuda_paddle_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - paddle_binary_version = parse(paddle.version.cuda()) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != paddle_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pypaddle binaries. " - "Pypaddle binaries were compiled with Cuda {}.\n".format(paddle.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pypaddle/pypaddle, " - "only images whose names contain 'devel' will provide nvcc." - ) - def _is_cuda_available(): """ Check whether CUDA is available. @@ -141,7 +105,7 @@ def _is_cuda_available(): f"\n Original Error is {e}" ) return False - +check = _is_cuda_available() cmdclass = {} def get_package_version(): @@ -156,10 +120,7 @@ def get_package_version(): def get_data_files(): data_files = [] - - # Assuming 'libflashattn.so' is located in the same directory as setup.py source_lib_path = 'libflashattn.so' - data_files.append((".", [source_lib_path])) return data_files From 48eb6479fd7163bd20164f0e36ad9a3b5f7ec510 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:38:28 +0800 Subject: [PATCH 22/36] update --- csrc/setup.py | 90 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 56 insertions(+), 34 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 39e224df6..1e7a63293 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -1,29 +1,36 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -import re import ast -from pathlib import Path -from packaging.version import parse, Version +import logging +import os import platform +import re import shutil -from setuptools import Command, Extension, setup -from setuptools.command.develop import develop as DevelopCommandBase -from setuptools.command.egg_info import egg_info -from setuptools.command.install import install as InstallCommandBase -from setuptools.command.install_lib import install_lib -from setuptools.dist import Distribution -from setuptools import setup, find_packages -import urllib.request -import urllib.error +import subprocess +import sys +import warnings +from pathlib import Path + +from packaging.version import parse +from setuptools import find_packages, setup +from setuptools.command.install import install as _install from wheel.bdist_wheel import bdist_wheel as _bdist_wheel -from setuptools.command.install import install import paddle -from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension from paddle.utils.cpp_extension.extension_utils import find_cuda_home -import subprocess version_detail = sys.version_info python_version = platform.python_version() @@ -55,7 +62,7 @@ # preparing parameters for setup() paddle_version = paddle.version.full_version -cuda_version= paddle.version.cuda_version +cuda_version = paddle.version.cuda_version with open("../../README.md", "r", encoding="utf-8") as fh: @@ -67,6 +74,7 @@ CUDA_HOME = find_cuda_home() PACKAGE_NAME = "paddle_flash_attn" + def get_platform(): """ Returns the platform name as used in wheel filenames. @@ -79,11 +87,13 @@ def get_platform(): elif sys.platform == 'win32': return 'win_amd64' else: - raise ValueError('Unsupported platform: {}'.format(sys.platform)) + raise ValueError(f'Unsupported platform: {sys.platform}') def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) output = raw_output.split() release_idx = output.index("release") + 1 bare_metal_version = parse(output[release_idx].split(",")[0]) @@ -105,12 +115,17 @@ def _is_cuda_available(): f"\n Original Error is {e}" ) return False + + check = _is_cuda_available() cmdclass = {} + def get_package_version(): with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f: - version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + version_match = re.search( + r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE + ) public_version = ast.literal_eval(version_match.group(1)) local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") if local_version: @@ -118,6 +133,7 @@ def get_package_version(): else: return str(public_version) + def get_data_files(): data_files = [] source_lib_path = 'libflashattn.so' @@ -132,35 +148,40 @@ class CustomWheelsCommand(_bdist_wheel): the environment parameters to detect whether there is already a pre-built version of a compatible wheel available and short-circuits the standard full build pipeline. """ + def run(self): self.run_command('build_ext') super().run() # Determine the version numbers that will be used to determine the correct wheel # We're using the CUDA version used to build paddle, not the one currently installed # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - paddle_cuda_version = "234" #parse(paddle.version.cuda) + paddle_cuda_version = "234" # parse(paddle.version.cuda) paddle_version_raw = parse(paddle.__version__) python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() flash_version = get_package_version() - cxx11_abi ="" # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper() + cxx11_abi = "" # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper() # Determine wheel URL based on CUDA version, paddle version, python version and OS wheel_filename = f'{PACKAGE_NAME}-{flash_version}-cu{cuda_version}-paddle{paddle_version}-{python_version}-{python_version}-{platform_name}.whl' impl_tag, abi_tag, plat_tag = self.get_tag() - original_wheel_name = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + original_wheel_name = ( + f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + ) - #new_wheel_name = wheel_filename - new_wheel_name = f"{self.wheel_dist_name}-{python_version}-{abi_tag}-{plat_tag}" + # new_wheel_name = wheel_filename + new_wheel_name = ( + f"{self.wheel_dist_name}-{python_version}-{abi_tag}-{plat_tag}" + ) shutil.move( f"{self.dist_dir}/{original_wheel_name}.whl", - f"{self.dist_dir}/{new_wheel_name}.whl" - ) + f"{self.dist_dir}/{new_wheel_name}.whl", + ) -class CustomInstallCommand(install): +class CustomInstallCommand(_install): def run(self): - install.run(self) + super().run(self) install_path = self.install_lib # src source_lib_path = os.path.abspath('libflashattn.so') @@ -170,7 +191,7 @@ def run(self): # 创建软链接 shutil.move(f"{source_lib_path}", f"{destination_lib_path}") - #os.symlink(source_lib_path, destination_lib_path) + # os.symlink(source_lib_path, destination_lib_path) setup( @@ -190,8 +211,9 @@ def run(self): "Operating System :: Unix", ], cmdclass={ - 'bdist_wheel': CustomWheelsCommand, - 'install': CustomInstallCommand}, + 'bdist_wheel': CustomWheelsCommand, + 'install': CustomInstallCommand, + }, python_requires=">=3.7", install_requires=[ "paddle", From 7bb6f314f0995d2a1bc67aa95e2a0edd573dd906 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:40:18 +0800 Subject: [PATCH 23/36] update --- csrc/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/setup.py b/csrc/setup.py index 1e7a63293..004c89133 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -181,7 +181,7 @@ def run(self): class CustomInstallCommand(_install): def run(self): - super().run(self) + _install.run(self) install_path = self.install_lib # src source_lib_path = os.path.abspath('libflashattn.so') From 58563ba8c8172dad5439041fc396d61dc120d449 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:41:52 +0800 Subject: [PATCH 24/36] udpate --- csrc/setup.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 004c89133..11cbb4dab 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -186,11 +186,9 @@ def run(self): # src source_lib_path = os.path.abspath('libflashattn.so') - # 目标链接路径 destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn.so') - # 创建软链接 - shutil.move(f"{source_lib_path}", f"{destination_lib_path}") + # shutil.move(f"{source_lib_path}", f"{destination_lib_path}") # os.symlink(source_lib_path, destination_lib_path) From e856a057707bb1d101e1e49bf2558646fb11cfd4 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:44:23 +0800 Subject: [PATCH 25/36] update --- csrc/CMakeLists.txt | 78 ++++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 486797127..b571791a0 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -18,38 +18,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -62,13 +62,13 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu - #flash_attn_with_bias_and_mask/src/cuda_utils.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/cuda_utils.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC From 256a3c6b8a9b13923cb1c7f4e9cae6dd2a726361 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Thu, 7 Dec 2023 16:36:06 +0800 Subject: [PATCH 26/36] update --- csrc/CMakeLists.txt | 107 +++++++++++++++++++++++--------------------- csrc/setup.py | 7 ++- 2 files changed, 60 insertions(+), 54 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index b571791a0..a6b9e6e00 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -7,7 +7,7 @@ execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} RESULT_VARIABLE GIT_SUBMOD_RESULT) -#cmake -DWITH_ADVANCED +#cmake -DWITH_ADVANCED=ON if (WITH_ADVANCED) add_compile_definitions(PADDLE_WITH_ADVANCED) endif() @@ -18,38 +18,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -63,12 +63,12 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu flash_attn_with_bias_and_mask/src/cuda_utils.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC @@ -142,15 +142,22 @@ INSTALL(TARGETS flashattn INSTALL(FILES capi/flash_attn.h DESTINATION "include") -add_custom_target(build_whl - COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel - WORKING_DIRECTORY ${CMAKE_BINARY_DIR} - DEPENDS flashattn - COMMENT "Running build wheel" -) - -add_custom_target(default_target DEPENDS build_whl) - -# set 'default_target' -set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target) +if (WITH_ADVANCED) + set_target_properties(flashattn PROPERTIES + OUTPUT_NAME libflashattn_advanced + PREFIX "" + ) +endif() +if (WITH_ADVANCED) + add_custom_target(build_whl + COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + DEPENDS flashattn + COMMENT "Running build wheel" + ) + + add_custom_target(default_target DEPENDS build_whl) + + set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target) +endif() diff --git a/csrc/setup.py b/csrc/setup.py index 11cbb4dab..d6e136003 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -136,8 +136,9 @@ def get_package_version(): def get_data_files(): data_files = [] - source_lib_path = 'libflashattn.so' - data_files.append((".", [source_lib_path])) + #source_lib_path = 'libflashattn.so' + #data_files.append((".", [source_lib_path])) + data_files.append((".", ['flashattn_advanced.so'])) return data_files @@ -155,8 +156,6 @@ def run(self): # Determine the version numbers that will be used to determine the correct wheel # We're using the CUDA version used to build paddle, not the one currently installed # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - paddle_cuda_version = "234" # parse(paddle.version.cuda) - paddle_version_raw = parse(paddle.__version__) python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() flash_version = get_package_version() From af386bf4e232948e1c713339e72eb2b05943272b Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Thu, 7 Dec 2023 16:40:07 +0800 Subject: [PATCH 27/36] update --- csrc/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index a6b9e6e00..4c1865911 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -137,6 +137,9 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + INSTALL(TARGETS flashattn LIBRARY DESTINATION "lib") From 3aca223f32b4117dc44177c9932eb44712fcae34 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Thu, 7 Dec 2023 18:17:05 +0800 Subject: [PATCH 28/36] Update --- csrc/CMakeLists.txt | 76 ++++++++++++++++++++++----------------------- csrc/setup.py | 16 ++++++---- 2 files changed, 48 insertions(+), 44 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 4c1865911..9261d7d05 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -18,38 +18,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -63,12 +63,12 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu flash_attn_with_bias_and_mask/src/cuda_utils.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC diff --git a/csrc/setup.py b/csrc/setup.py index d6e136003..77c127e46 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -138,7 +138,7 @@ def get_data_files(): data_files = [] #source_lib_path = 'libflashattn.so' #data_files.append((".", [source_lib_path])) - data_files.append((".", ['flashattn_advanced.so'])) + data_files.append((".", ['libflashattn_advanced.so'])) return data_files @@ -213,9 +213,13 @@ def run(self): }, python_requires=">=3.7", install_requires=[ - "paddle", - "einops", - "packaging", - "ninja", - ], + "common", + "dual", + "tight>=0.1.0", + "data", + "prox", + "ninja", # Put ninja before paddle if paddle depends on it + "einops", + "packaging", +], ) From 06edc27972b0a8a08a62b42ea410058679fc25bb Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 8 Dec 2023 10:23:43 +0800 Subject: [PATCH 29/36] update --- csrc/flash_attn/src/flash_bwd_launch_template.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index f3eb8850f..d611b5dea 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -87,8 +87,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; if (smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); From 45fcc53e73b3ea9b69098ccfe9878fd26467427d Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 8 Dec 2023 14:55:17 +0800 Subject: [PATCH 30/36] update --- csrc/CMakeLists.txt | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 9261d7d05..0d5e65ee3 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -88,14 +88,9 @@ target_link_libraries(flashattn flashattn_with_bias_mask) add_dependencies(flashattn flashattn_with_bias_mask) +option(NVCC_ARCH_BIN "Set default compute arch to 80" "80") -if (NOT DEFINED NVCC_ARCH_BIN) - message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.") -endif() - -if (NVCC_ARCH_BIN STREQUAL "") - message(FATAL_ERROR "NVCC_ARCH_BIN is not set.") -endif() +message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}") STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN}) From 940a8ae76bdc72fe79d38f9d083c7d629531c6c0 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 8 Dec 2023 15:12:50 +0800 Subject: [PATCH 31/36] default --- csrc/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 0d5e65ee3..e52955036 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -88,7 +88,7 @@ target_link_libraries(flashattn flashattn_with_bias_mask) add_dependencies(flashattn flashattn_with_bias_mask) -option(NVCC_ARCH_BIN "Set default compute arch to 80" "80") +set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures") message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}") From 6b6c7a88e490056fd47698b83402fb4c471ab109 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 8 Dec 2023 15:30:20 +0800 Subject: [PATCH 32/36] update --- csrc/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index e52955036..c8c1f142b 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -1,5 +1,7 @@ cmake_minimum_required(VERSION 3.9 FATAL_ERROR) project(flash-attention LANGUAGES CXX CUDA) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Git QUIET REQUIRED) @@ -132,8 +134,6 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) INSTALL(TARGETS flashattn LIBRARY DESTINATION "lib") From 18ae75621289868c096ea63477271bd227ab5805 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 8 Dec 2023 17:04:01 +0800 Subject: [PATCH 33/36] update equal --- .../src/flash_fwd_launch_template.h | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 5090605cb..ae707d0e9 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -64,19 +64,21 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] { + // Will only return softmax if dropout, to reduce compilation time. + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); From d926c09ca157063b1a8ed673d24fcb358c9a11d4 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Sun, 10 Dec 2023 21:44:29 +0800 Subject: [PATCH 34/36] for so --- csrc/setup.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 77c127e46..060268cca 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -182,13 +182,9 @@ class CustomInstallCommand(_install): def run(self): _install.run(self) install_path = self.install_lib - # src - source_lib_path = os.path.abspath('libflashattn.so') - - destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn.so') - - # shutil.move(f"{source_lib_path}", f"{destination_lib_path}") - # os.symlink(source_lib_path, destination_lib_path) + source_lib_path = os.path.abspath('libflashattn_advanced.so') + destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn_advanced.so') + shutil.copy(f"{source_lib_path}", f"{destination_lib_path}") setup( From a2714ebebc2113548bd5d47d6f44fde0b4cc4872 Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Mon, 11 Dec 2023 12:27:15 +0800 Subject: [PATCH 35/36] Update CMakeLists.txt --- csrc/CMakeLists.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index c8c1f142b..9aa83b4fc 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -145,9 +145,6 @@ if (WITH_ADVANCED) OUTPUT_NAME libflashattn_advanced PREFIX "" ) -endif() - -if (WITH_ADVANCED) add_custom_target(build_whl COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} From a61e35b2bb4195f98811ed90c06ddabf2390c75e Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Mon, 11 Dec 2023 15:18:05 +0800 Subject: [PATCH 36/36] update fa1 mask --- csrc/CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 9aa83b4fc..a19b92744 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -62,6 +62,7 @@ target_include_directories(flashattn PRIVATE flash_attn ${CUTLASS_3_DIR}/include) +if (WITH_ADVANCED) set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu flash_attn_with_bias_and_mask/src/cuda_utils.cu @@ -72,6 +73,12 @@ set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) +else() +set(FA1_SOURCES_CU + flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu + flash_attn_with_bias_and_mask/src/cuda_utils.cu + flash_attn_with_bias_and_mask/src/utils.cu) +endif() add_library(flashattn_with_bias_mask STATIC flash_attn_with_bias_and_mask/