cherry-pick from (PaddlePaddle#56040): Make flash attn v1 available
sneaxiy authored and wentaoyu committed Oct 26, 2023
1 parent e86c435 commit 8e90e92
Showing 21 changed files with 1,127 additions and 37 deletions.
123 changes: 123 additions & 0 deletions cmake/external/flashattn_v1.cmake
@@ -0,0 +1,123 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

include(ExternalProject)

add_definitions(-DPADDLE_WITH_FLASHATTN)

set(FLASHATTN_V1_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn_v1)
set(FLASHATTN_V1_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_V1_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn_v1)
set(FLASHATTN_V1_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_V1_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)

set(FLASHATTN_V1_INCLUDE_DIR
"${FLASHATTN_V1_INSTALL_DIR}/include"
CACHE PATH "flash-attn v1 Directory" FORCE)
set(FLASHATTN_V1_LIB_DIR
"${FLASHATTN_V1_INSTALL_DIR}/lib"
CACHE PATH "flash-attn v1 Library Directory" FORCE)

if(WIN32)
set(FLASHATTN_V1_OLD_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
set(FLASHATTN_V1_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/bin/flashattn_v1${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
else()
set(FLASHATTN_V1_OLD_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
set(FLASHATTN_V1_LIBRARIES
"${FLASHATTN_V1_INSTALL_DIR}/lib/libflashattn_v1${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "flash-attn v1 Library" FORCE)
endif()

if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
OR WIN32)
set(USE_OMP OFF)
else()
set(USE_OMP ON)
endif()

if(WIN32)
set(FLASHATTN_V1_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_C_FLAGS_DEBUG
$<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_C_FLAGS_RELEASE
$<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_CXX_FLAGS_RELEASE
$<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
set(FLASHATTN_V1_CXX_FLAGS_DEBUG
$<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
set(FLASHATTN_V1_C_FLAGS ${CMAKE_C_FLAGS})
set(FLASHATTN_V1_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
set(FLASHATTN_V1_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
set(FLASHATTN_V1_CXX_FLAGS ${CMAKE_CXX_FLAGS})
set(FLASHATTN_V1_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
set(FLASHATTN_V1_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()

ExternalProject_Add(
extern_flashattn_v1
${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
GIT_REPOSITORY ${FLASHATTN_V1_REPOSITORY}
GIT_TAG ${FLASHATTN_V1_TAG}
PREFIX ${FLASHATTN_V1_PREFIX_DIR}
SOURCE_SUBDIR ${FLASHATTN_V1_SOURCE_SUBDIR}
UPDATE_COMMAND ""
PATCH_COMMAND ""
#BUILD_ALWAYS 1
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_C_FLAGS=${FLASHATTN_V1_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_V1_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_V1_C_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS=${FLASHATTN_V1_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_V1_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_V1_CXX_FLAGS_DEBUG}
-DCMAKE_INSTALL_PREFIX=${FLASHATTN_V1_INSTALL_DIR}
-DWITH_GPU=${WITH_GPU}
-DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
-DWITH_ROCM=${WITH_ROCM}
-DWITH_OMP=${USE_OMP}
-DBUILD_SHARED=ON
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_V1_INSTALL_DIR}
BUILD_BYPRODUCTS ${FLASHATTN_V1_LIBRARIES})

add_custom_target(
extern_flashattn_v1_move_lib
COMMAND ${CMAKE_COMMAND} -E copy ${FLASHATTN_V1_OLD_LIBRARIES}
${FLASHATTN_V1_LIBRARIES})

add_dependencies(extern_flashattn_v1_move_lib extern_flashattn_v1)

message(STATUS "flash-attn v1 library: ${FLASHATTN_V1_LIBRARIES}")
get_filename_component(FLASHATTN_V1_LIBRARY_PATH ${FLASHATTN_V1_LIBRARIES}
DIRECTORY)
include_directories(${FLASHATTN_V1_INCLUDE_DIR})

add_library(flashattn_v1 INTERFACE)
#set_property(TARGET flashattn_v1 PROPERTY IMPORTED_LOCATION ${FLASHATTN_V1_LIBRARIES})
add_dependencies(flashattn_v1 extern_flashattn_v1 extern_flashattn_v1_move_lib)
3 changes: 2 additions & 1 deletion cmake/third_party.cmake
@@ -564,7 +564,8 @@ if(WITH_GPU
foreach(arch ${NVCC_ARCH_BIN})
if(${arch} GREATER_EQUAL 80)
include(external/flashattn)
- list(APPEND third_party_deps extern_flashattn)
+ include(external/flashattn_v1)
+ list(APPEND third_party_deps extern_flashattn extern_flashattn_v1)
set(WITH_FLASHATTN ON)
break()
endif()
22 changes: 22 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
@@ -841,6 +841,28 @@
func : flash_attn_unpadded_grad
data_type: q

- backward_op : flash_attn_v1_grad
forward : flash_attn_v1 (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnV1GradInferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1_grad
data_type: q

- backward_op : flash_attn_v1_unpadded_grad
forward : flash_attn_v1_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad)
infer_meta :
func : FlashAttnV1GradInferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1_unpadded_grad
data_type: q

- backward_op : flatten_grad
forward : flatten(Tensor x, int start_axis = 1, int stop_axis = 1) -> Tensor(out), Tensor(xshape)
args : (Tensor xshape, Tensor out_grad)
24 changes: 24 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -941,6 +941,24 @@
intermediate : softmax_lse, seed_offset
backward : flash_attn_unpadded_grad

- op : flash_attn_v1
args : (Tensor q, Tensor k, Tensor v, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
infer_meta :
func : FlashAttnV1InferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1
data_type : q
intermediate : softmax_lse, seed_offset
backward : flash_attn_v1_grad

- op : flash_attn_v1_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false)
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
infer_meta :
func : FlashAttnV1InferMeta
param : [q, k, v]
kernel :
func : flash_attn_v1_unpadded
data_type : q
intermediate : softmax_lse, seed_offset
backward : flash_attn_v1_unpadded_grad

- op : flatten
args : (Tensor x, int start_axis = 1, int stop_axis = 1)
output : Tensor(out), Tensor(xshape)
2 changes: 1 addition & 1 deletion paddle/phi/backends/dynload/CMakeLists.txt
@@ -81,7 +81,7 @@ if(WITH_XPU)
endif()

if(WITH_FLASHATTN)
- list(APPEND DYNLOAD_COMMON_SRCS flashattn.cc)
+ list(APPEND DYNLOAD_COMMON_SRCS flashattn.cc flashattn_v1.cc)
endif()

if(MKL_FOUND AND WITH_ONEMKL)
14 changes: 14 additions & 0 deletions paddle/phi/backends/dynload/dynamic_loader.cc
@@ -507,6 +507,20 @@ void* GetFlashAttnDsoHandle() {
#endif
}

void* GetFlashAttnV1DsoHandle() {
std::string flashattn_dir = "";
if (!s_py_site_pkg_path.path.empty()) {
flashattn_dir = s_py_site_pkg_path.path;
}
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn_v1.dylib");
#elif defined(_WIN32)
return GetDsoHandleFromSearchPath(flashattn_dir, "flashattn_v1.dll");
#else
return GetDsoHandleFromSearchPath(flashattn_dir, "libflashattn_v1.so");
#endif
}

void* GetNCCLDsoHandle() {
#ifdef PADDLE_WITH_HIP
std::string warning_msg(
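Note: GetFlashAttnV1DsoHandle above resolves the renamed v1 library through Paddle's usual search-path helper. For reference only, the mechanism underneath is the standard POSIX dlopen/dlsym pattern; the following standalone sketch is not part of this commit (Linux-only, hypothetical file name probe_flashattn_v1.cc).

// Minimal standalone sketch (not from this commit) of the dlopen/dlsym
// pattern that GetFlashAttnV1DsoHandle and the dynload wrappers build on.
// Linux-only; compile with: g++ probe_flashattn_v1.cc -ldl
#include <dlfcn.h>
#include <cstdio>

int main() {
  // Open the renamed v1 library produced by the flashattn_v1.cmake copy step.
  void* handle = dlopen("libflashattn_v1.so", RTLD_LAZY | RTLD_LOCAL);
  if (handle == nullptr) {
    std::fprintf(stderr, "dlopen failed: %s\n", dlerror());
    return 1;
  }
  // The exported symbol keeps its original name; only the Paddle-side
  // wrappers carry the "_v1" suffix.
  void* sym = dlsym(handle, "flash_attn_fwd");
  std::printf("flash_attn_fwd %s\n", sym != nullptr ? "resolved" : "not found");
  dlclose(handle);
  return 0;
}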
1 change: 1 addition & 0 deletions paddle/phi/backends/dynload/dynamic_loader.h
@@ -37,6 +37,7 @@ void* GetCUDADsoHandle();
void* GetWarpCTCDsoHandle();
void* GetWarpRNNTDsoHandle();
void* GetFlashAttnDsoHandle();
void* GetFlashAttnV1DsoHandle();
void* GetNCCLDsoHandle();
void* GetTensorRtDsoHandle();
void* GetMKLMLDsoHandle();
28 changes: 28 additions & 0 deletions paddle/phi/backends/dynload/flashattn_v1.cc
@@ -0,0 +1,28 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/dynload/flashattn_v1.h"

namespace phi {
namespace dynload {

std::once_flag flashattn_v1_dso_flag;
void* flashattn_v1_dso_handle = nullptr;

#define DEFINE_WRAP(__name) DynLoad__##__name##__v1 __name##_v1

FLASHATTN_V1_ROUTINE_EACH(DEFINE_WRAP);

} // namespace dynload
} // namespace phi
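For readability, the FLASHATTN_V1_ROUTINE_EACH(DEFINE_WRAP) line above amounts to the following expansion (illustrative only, not part of the diff): the three routine names each become a wrapper object that the accompanying header, flashattn_v1.h below, declares extern.

// Illustrative expansion of FLASHATTN_V1_ROUTINE_EACH(DEFINE_WRAP) in this
// translation unit; each object is the "_v1"-suffixed dynload wrapper.
DynLoad__flash_attn_fwd__v1 flash_attn_fwd_v1;
DynLoad__flash_attn_bwd__v1 flash_attn_bwd_v1;
DynLoad__flash_attn_error__v1 flash_attn_error_v1;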
133 changes: 133 additions & 0 deletions paddle/phi/backends/dynload/flashattn_v1.h
@@ -0,0 +1,133 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <mutex> // NOLINT

#include "cuda_runtime.h" // NOLINT
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"

namespace phi {
namespace dynload {

extern std::once_flag flashattn_v1_dso_flag;
extern void *flashattn_v1_dso_handle;

using flash_attn_fwd_v1_func_t = bool (*)(
const void * /*q*/, // total_q x num_heads x head_size, total_q :=
// \sum_{i=0}^{b} s_i
const void * /*k*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*v*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
void * /*out*/, // total_q x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*cu_seqlens_q*/, // int32, batch_size+1, starting offset of
// each sequence
const void * /*cu_seqlens_k*/, // int32, batch_size+1, starting offset of
// each sequence
const int /*total_q*/,
const int /*total_k*/,
const int /*batch_size*/,
const int /*num_heads*/,
const int /*head_size*/,
const int /*max_seqlen_q_*/,
const int /*max_seqlen_k_*/,
const float /*p_dropout*/,
const float /*softmax_scale*/,
const bool /*zero_tensors*/,
const bool /*is_causal*/,
const bool /*is_bf16*/,
const int /*num_splits*/, // SMs per attention matrix, can be 1
void * /*softmax_lse_ptr*/, // softmax log_sum_exp
void * /*softmax_ptr*/,
void * /*workspace_ptr*/,
uint64_t * /*workspace_size*/,
cudaStream_t /*stream*/,
uint64_t /*seed*/,
uint64_t /*offset*/
);

using flash_attn_bwd_v1_func_t = bool (*)(
const void * /*q*/, // total_q x num_heads x head_size, total_q :=
// \sum_{i=0}^{b} s_i
const void * /*k*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*v*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
void * /*dq*/, // total_q x num_heads x head_size, total_q :=
// \sum_{i=0}^{b} s_i
void * /*dk*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
void * /*dv*/, // total_k x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*out*/, // total_q x num_heads x head_size, total_k :=
// \sum_{i=0}^{b} s_i
const void * /*dout*/, // total_q x num_heads, x head_size
const void * /*cu_seqlens_q*/, // int32, batch_size+1
const void * /*cu_seqlens_k*/, // int32, batch_size+1
const int /*total_q*/,
const int /*total_k*/,
const int /*batch_size*/,
const int /*num_heads*/,
const int /*head_size*/,
const int /*max_seqlen_q_*/,
const int /*max_seqlen_k_*/,
const float /*p_dropout*/,
const float /*softmax_scale*/,
const bool /*zero_tensors*/,
const bool /*is_causal*/,
const bool /*is_bf16*/,
const int /*num_splits*/,
void * /*softmax_lse_ptr*/,
void * /*dsoftmax_ptr*/,
void * /*workspace_ptr*/,
uint64_t * /*workspace_size*/,
cudaStream_t /*stream*/,
uint64_t /*seed*/,
uint64_t /*offset*/
);

using flash_attn_error_v1_func_t = const char *(*)();

#define DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name) \
struct DynLoad__##__name##__v1 { \
template <typename... Args> \
auto operator()(Args... args) { \
using flashattnFunc = ::phi::dynload::__name##_v1_func_t; \
std::call_once(flashattn_v1_dso_flag, []() { \
flashattn_v1_dso_handle = phi::dynload::GetFlashAttnV1DsoHandle(); \
}); \
static void *p_##__name = dlsym(flashattn_v1_dso_handle, #__name); \
return reinterpret_cast<flashattnFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name##__v1 __name##_v1

#define DECLARE_DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name) \
DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name)

#define FLASHATTN_V1_ROUTINE_EACH(__macro) \
__macro(flash_attn_fwd); \
__macro(flash_attn_bwd); \
__macro(flash_attn_error);

FLASHATTN_V1_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_V1_WRAP);

#undef DYNAMIC_LOAD_FLASHATTN_V1_WRAP

} // namespace dynload
} // namespace phi
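Taken together, this header gives phi code a lazily-resolved entry point into the v1 library. A minimal usage sketch follows; it is not part of this commit, and RunFlashAttnV1Fwd together with its raw-pointer arguments are placeholders chosen only to show the call shape implied by flash_attn_fwd_v1_func_t.

// A minimal sketch (not part of this commit) of calling the lazily-loaded v1
// forward routine through the wrapper objects declared above.
#include <cstdint>
#include "paddle/phi/backends/dynload/flashattn_v1.h"

bool RunFlashAttnV1Fwd(const void* q, const void* k, const void* v, void* out,
                       const void* cu_seqlens_q, const void* cu_seqlens_k,
                       int total_q, int total_k, int batch_size, int num_heads,
                       int head_size, int max_seqlen_q, int max_seqlen_k,
                       float dropout, float softmax_scale, bool causal,
                       bool is_bf16, void* softmax_lse, void* workspace,
                       uint64_t* workspace_size, cudaStream_t stream,
                       uint64_t seed, uint64_t offset) {
  // First use triggers GetFlashAttnV1DsoHandle() and dlsym("flash_attn_fwd").
  bool ok = phi::dynload::flash_attn_fwd_v1(
      q, k, v, out, cu_seqlens_q, cu_seqlens_k, total_q, total_k, batch_size,
      num_heads, head_size, max_seqlen_q, max_seqlen_k, dropout, softmax_scale,
      /*zero_tensors=*/false, causal, is_bf16, /*num_splits=*/1, softmax_lse,
      /*softmax_ptr=*/nullptr, workspace, workspace_size, stream, seed, offset);
  if (!ok) {
    // The library reports its last error as a C string.
    const char* err = phi::dynload::flash_attn_error_v1();
    (void)err;  // a real caller would surface this as an exception/log
  }
  return ok;
}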