forked from PaddlePaddle/Paddle
cherry-pick from (PaddlePaddle#56040): Make flash attn v1 available
Showing 21 changed files with 1,127 additions and 37 deletions.
New file (123 lines): CMake external-project build script for flash-attn v1.
```cmake
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

include(ExternalProject)

add_definitions(-DPADDLE_WITH_FLASHATTN)

set(FLASHATTN_V1_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn_v1)
set(FLASHATTN_V1_SOURCE_SUBDIR csrc/flash_attn)
set(FLASHATTN_V1_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn_v1)
set(FLASHATTN_V1_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_V1_TAG 5ff4bbf56ad066750407c4aef16ac740ebda0717)

set(FLASHATTN_V1_INCLUDE_DIR
    "${FLASHATTN_V1_INSTALL_DIR}/include"
    CACHE PATH "flash-attn v1 Directory" FORCE)
set(FLASHATTN_V1_LIB_DIR
    "${FLASHATTN_V1_INSTALL_DIR}/lib"
    CACHE PATH "flash-attn v1 Library Directory" FORCE)

# Upstream installs "flashattn"; the *_v1 name is what Paddle links against,
# so the library cannot clash with the default libflashattn name.
if(WIN32)
  set(FLASHATTN_V1_OLD_LIBRARIES
      "${FLASHATTN_V1_INSTALL_DIR}/bin/flashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
      CACHE FILEPATH "flash-attn v1 Library" FORCE)
  set(FLASHATTN_V1_LIBRARIES
      "${FLASHATTN_V1_INSTALL_DIR}/bin/flashattn_v1${CMAKE_SHARED_LIBRARY_SUFFIX}"
      CACHE FILEPATH "flash-attn v1 Library" FORCE)
else()
  set(FLASHATTN_V1_OLD_LIBRARIES
      "${FLASHATTN_V1_INSTALL_DIR}/lib/libflashattn${CMAKE_SHARED_LIBRARY_SUFFIX}"
      CACHE FILEPATH "flash-attn v1 Library" FORCE)
  set(FLASHATTN_V1_LIBRARIES
      "${FLASHATTN_V1_INSTALL_DIR}/lib/libflashattn_v1${CMAKE_SHARED_LIBRARY_SUFFIX}"
      CACHE FILEPATH "flash-attn v1 Library" FORCE)
endif()

# OpenMP is not used with Clang/AppleClang or on Windows.
if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang"
   OR CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang"
   OR WIN32)
  set(USE_OMP OFF)
else()
  set(USE_OMP ON)
endif()

# On Windows, filter MSVC's /Zc:inline out of the flags passed to the
# external build; elsewhere, forward the host flags unchanged.
if(WIN32)
  set(FLASHATTN_V1_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
  set(FLASHATTN_V1_C_FLAGS_DEBUG
      $<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
  set(FLASHATTN_V1_C_FLAGS_RELEASE
      $<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
  set(FLASHATTN_V1_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
  set(FLASHATTN_V1_CXX_FLAGS_RELEASE
      $<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
  set(FLASHATTN_V1_CXX_FLAGS_DEBUG
      $<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
else()
  set(FLASHATTN_V1_C_FLAGS ${CMAKE_C_FLAGS})
  set(FLASHATTN_V1_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
  set(FLASHATTN_V1_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
  set(FLASHATTN_V1_CXX_FLAGS ${CMAKE_CXX_FLAGS})
  set(FLASHATTN_V1_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
  set(FLASHATTN_V1_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()

ExternalProject_Add(
  extern_flashattn_v1
  ${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
  GIT_REPOSITORY ${FLASHATTN_V1_REPOSITORY}
  GIT_TAG ${FLASHATTN_V1_TAG}
  PREFIX ${FLASHATTN_V1_PREFIX_DIR}
  SOURCE_SUBDIR ${FLASHATTN_V1_SOURCE_SUBDIR}
  UPDATE_COMMAND ""
  PATCH_COMMAND ""
  #BUILD_ALWAYS 1
  CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
             -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
             -DCMAKE_C_FLAGS=${FLASHATTN_V1_C_FLAGS}
             -DCMAKE_C_FLAGS_DEBUG=${FLASHATTN_V1_C_FLAGS_DEBUG}
             -DCMAKE_C_FLAGS_RELEASE=${FLASHATTN_V1_C_FLAGS_RELEASE}
             -DCMAKE_CXX_FLAGS=${FLASHATTN_V1_CXX_FLAGS}
             -DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_V1_CXX_FLAGS_RELEASE}
             -DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_V1_CXX_FLAGS_DEBUG}
             -DCMAKE_INSTALL_PREFIX=${FLASHATTN_V1_INSTALL_DIR}
             -DWITH_GPU=${WITH_GPU}
             -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER}
             -DWITH_ROCM=${WITH_ROCM}
             -DWITH_OMP=${USE_OMP}
             -DBUILD_SHARED=ON
             -DCMAKE_POSITION_INDEPENDENT_CODE=ON
             -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
             ${EXTERNAL_OPTIONAL_ARGS}
  CMAKE_CACHE_ARGS
    -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
    -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
    -DCMAKE_INSTALL_PREFIX:PATH=${FLASHATTN_V1_INSTALL_DIR}
  BUILD_BYPRODUCTS ${FLASHATTN_V1_LIBRARIES})

# Copy the installed libflashattn to the *_v1 name declared above.
add_custom_target(
  extern_flashattn_v1_move_lib
  COMMAND ${CMAKE_COMMAND} -E copy ${FLASHATTN_V1_OLD_LIBRARIES}
          ${FLASHATTN_V1_LIBRARIES})

add_dependencies(extern_flashattn_v1_move_lib extern_flashattn_v1)

message(STATUS "flash-attn v1 library: ${FLASHATTN_V1_LIBRARIES}")
get_filename_component(FLASHATTN_V1_LIBRARY_PATH ${FLASHATTN_V1_LIBRARIES}
                       DIRECTORY)
include_directories(${FLASHATTN_V1_INCLUDE_DIR})

add_library(flashattn_v1 INTERFACE)
#set_property(TARGET flashattn_v1 PROPERTY IMPORTED_LOCATION ${FLASHATTN_V1_LIBRARIES})
add_dependencies(flashattn_v1 extern_flashattn_v1 extern_flashattn_v1_move_lib)
```
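
The script's `add_definitions(-DPADDLE_WITH_FLASHATTN)` makes that macro visible to every translation unit, so C++ code can select the flash-attn path at compile time. A minimal sketch of that gating pattern; `RunAttention` is a hypothetical function for illustration, not Paddle's API:

```cpp
#include <cstdio>

// Hypothetical entry point showing how a compile definition added by
// CMake (-DPADDLE_WITH_FLASHATTN) can gate a fast attention path.
void RunAttention(bool causal) {
#ifdef PADDLE_WITH_FLASHATTN
  std::printf("flash-attn path (causal=%d)\n", causal);
#else
  std::printf("fallback attention path (causal=%d)\n", causal);
#endif
}

int main() {
  RunAttention(true);  // the branch was fixed when this file was compiled
  return 0;
}
```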
(Diffs for several of the other changed files were not rendered on this page.)
New file (28 lines): dynload implementation defining the flash-attn v1 wrapper objects.
```cpp
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/dynload/flashattn_v1.h"

namespace phi {
namespace dynload {

std::once_flag flashattn_v1_dso_flag;
void* flashattn_v1_dso_handle = nullptr;

// Define one global wrapper object per routine, e.g.
// DynLoad__flash_attn_fwd__v1 flash_attn_fwd_v1;
#define DEFINE_WRAP(__name) DynLoad__##__name##__v1 __name##_v1

FLASHATTN_V1_ROUTINE_EACH(DEFINE_WRAP);

}  // namespace dynload
}  // namespace phi
```
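
The objects defined here are instances of the `DynLoad__*__v1` structs declared in the header below: each loads the shared library once and resolves its symbol lazily on first call. A minimal, self-contained sketch of that `std::call_once` + `dlsym` pattern, using libc's `strlen` in place of a flash-attn symbol so it actually runs (Linux; purely illustrative):

```cpp
#include <dlfcn.h>  // dlopen, dlsym; may need -ldl on older glibc
#include <cstddef>
#include <cstdio>
#include <mutex>

static std::once_flag g_dso_flag;
static void* g_dso_handle = nullptr;

// Stand-in for a DynLoad__*__v1 wrapper: operator() resolves the symbol
// once, caches it in a function-local static, then forwards the call.
struct DynLoadStrlen {
  template <typename... Args>
  auto operator()(Args... args) {
    using func_t = size_t (*)(const char*);
    std::call_once(g_dso_flag, [] {
      // A real loader would open the flash-attn v1 library here, as
      // phi::dynload::GetFlashAttnV1DsoHandle() does; nullptr means
      // "search the current process image".
      g_dso_handle = dlopen(nullptr, RTLD_LAZY);
    });
    static void* p = dlsym(g_dso_handle, "strlen");
    return reinterpret_cast<func_t>(p)(args...);
  }
};

int main() {
  DynLoadStrlen strlen_v1;
  std::printf("%zu\n", strlen_v1("flash"));  // prints 5
  return 0;
}
```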
New file (133 lines): dynload header declaring the flash-attn v1 function-pointer types and lazy-loading wrappers.
```cpp
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <mutex>  // NOLINT

#include "cuda_runtime.h"  // NOLINT
#include "paddle/phi/backends/dynload/dynamic_loader.h"
#include "paddle/phi/backends/dynload/port.h"

namespace phi {
namespace dynload {

extern std::once_flag flashattn_v1_dso_flag;
extern void *flashattn_v1_dso_handle;

using flash_attn_fwd_v1_func_t = bool (*)(
    const void * /*q*/,  // total_q x num_heads x head_size, total_q :=
                         // \sum_{i=0}^{b} s_i
    const void * /*k*/,  // total_k x num_heads x head_size, total_k :=
                         // \sum_{i=0}^{b} s_i
    const void * /*v*/,  // total_k x num_heads x head_size, total_k :=
                         // \sum_{i=0}^{b} s_i
    void * /*out*/,      // total_q x num_heads x head_size, total_q :=
                         // \sum_{i=0}^{b} s_i
    const void * /*cu_seqlens_q*/,  // int32, batch_size+1, starting offset of
                                    // each sequence
    const void * /*cu_seqlens_k*/,  // int32, batch_size+1, starting offset of
                                    // each sequence
    const int /*total_q*/,
    const int /*total_k*/,
    const int /*batch_size*/,
    const int /*num_heads*/,
    const int /*head_size*/,
    const int /*max_seqlen_q_*/,
    const int /*max_seqlen_k_*/,
    const float /*p_dropout*/,
    const float /*softmax_scale*/,
    const bool /*zero_tensors*/,
    const bool /*is_causal*/,
    const bool /*is_bf16*/,
    const int /*num_splits*/,    // SMs per attention matrix, can be 1
    void * /*softmax_lse_ptr*/,  // softmax log_sum_exp
    void * /*softmax_ptr*/,
    void * /*workspace_ptr*/,
    uint64_t * /*workspace_size*/,
    cudaStream_t /*stream*/,
    uint64_t /*seed*/,
    uint64_t /*offset*/);

using flash_attn_bwd_v1_func_t = bool (*)(
    const void * /*q*/,  // total_q x num_heads x head_size, total_q :=
                         // \sum_{i=0}^{b} s_i
    const void * /*k*/,  // total_k x num_heads x head_size, total_k :=
                         // \sum_{i=0}^{b} s_i
    const void * /*v*/,  // total_k x num_heads x head_size, total_k :=
                         // \sum_{i=0}^{b} s_i
    void * /*dq*/,       // total_q x num_heads x head_size, total_q :=
                         // \sum_{i=0}^{b} s_i
    void * /*dk*/,       // total_k x num_heads x head_size, total_k :=
                         // \sum_{i=0}^{b} s_i
    void * /*dv*/,       // total_k x num_heads x head_size, total_k :=
                         // \sum_{i=0}^{b} s_i
    const void * /*out*/,   // total_q x num_heads x head_size, total_q :=
                            // \sum_{i=0}^{b} s_i
    const void * /*dout*/,  // total_q x num_heads x head_size
    const void * /*cu_seqlens_q*/,  // int32, batch_size+1
    const void * /*cu_seqlens_k*/,  // int32, batch_size+1
    const int /*total_q*/,
    const int /*total_k*/,
    const int /*batch_size*/,
    const int /*num_heads*/,
    const int /*head_size*/,
    const int /*max_seqlen_q_*/,
    const int /*max_seqlen_k_*/,
    const float /*p_dropout*/,
    const float /*softmax_scale*/,
    const bool /*zero_tensors*/,
    const bool /*is_causal*/,
    const bool /*is_bf16*/,
    const int /*num_splits*/,
    void * /*softmax_lse_ptr*/,
    void * /*dsoftmax_ptr*/,
    void * /*workspace_ptr*/,
    uint64_t * /*workspace_size*/,
    cudaStream_t /*stream*/,
    uint64_t /*seed*/,
    uint64_t /*offset*/);

using flash_attn_error_v1_func_t = const char *(*)();

// Generate a wrapper struct per routine: the first call loads the DSO
// exactly once, then resolves and caches the symbol named #__name.
#define DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name)                             \
  struct DynLoad__##__name##__v1 {                                         \
    template <typename... Args>                                            \
    auto operator()(Args... args) {                                        \
      using flashattnFunc = ::phi::dynload::__name##_v1_func_t;            \
      std::call_once(flashattn_v1_dso_flag, []() {                         \
        flashattn_v1_dso_handle = phi::dynload::GetFlashAttnV1DsoHandle(); \
      });                                                                  \
      static void *p_##__name = dlsym(flashattn_v1_dso_handle, #__name);   \
      return reinterpret_cast<flashattnFunc>(p_##__name)(args...);         \
    }                                                                      \
  };                                                                       \
  extern DynLoad__##__name##__v1 __name##_v1

#define DECLARE_DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name) \
  DYNAMIC_LOAD_FLASHATTN_V1_WRAP(__name)

#define FLASHATTN_V1_ROUTINE_EACH(__macro) \
  __macro(flash_attn_fwd);                 \
  __macro(flash_attn_bwd);                 \
  __macro(flash_attn_error);

FLASHATTN_V1_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_FLASHATTN_V1_WRAP);

#undef DYNAMIC_LOAD_FLASHATTN_V1_WRAP

}  // namespace dynload
}  // namespace phi
```
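
Both function-pointer types take varlen inputs: q, k, and v are packed as total_q (or total_k) x num_heads x head_size, and cu_seqlens_q / cu_seqlens_k are int32 arrays of batch_size + 1 cumulative offsets marking where each sequence starts. A hedged sketch of building such offsets; the helper name is illustrative and not part of this header:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative helper (not in flashattn_v1.h): turn per-sequence lengths
// into the batch_size+1 prefix sums expected for cu_seqlens_q/cu_seqlens_k.
std::vector<int32_t> BuildCuSeqlens(const std::vector<int32_t>& seqlens) {
  std::vector<int32_t> cu(seqlens.size() + 1, 0);
  for (size_t i = 0; i < seqlens.size(); ++i) {
    cu[i + 1] = cu[i] + seqlens[i];  // cu[b] = total length of sequences 0..b-1
  }
  return cu;
}

int main() {
  // Batch of 3 sequences with lengths 5, 2, 7 -> total_q = 14.
  auto cu = BuildCuSeqlens({5, 2, 7});
  for (int32_t v : cu) std::printf("%d ", v);  // prints: 0 5 7 14
  std::printf("\n");
  return 0;
}
```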