diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
index f7be0d625e2f2..de479cf9adfd2 100644
--- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -21,12 +21,8 @@
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/reshape_kernel.h"
-
-#ifdef PADDLE_WITH_FLASHATTN
-#include "paddle/phi/backends/dynload/flashattn.h"
 #include "paddle/phi/kernels/gpu/flash_attn_utils.h"
-#endif
+#include "paddle/phi/kernels/reshape_kernel.h"
 
 DECLARE_bool(cudnn_deterministic);
 
@@ -52,6 +48,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
                                DenseTensor* dq,
                                DenseTensor* dk,
                                DenseTensor* dv) {
+#ifdef PADDLE_WITH_FLASHATTN
   const cudaStream_t stream = ctx.stream();
 
   auto dims = q.dims();
@@ -77,8 +74,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
   const int64_t* seed_offset_data = seed_offset.data<int64_t>();
   uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
   uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
-  VLOG(10) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
-           << ", num_splits:" << num_splits;
+  VLOG(10) << "FlashAttn bwd seed: " << seed << ", offset: " << offset;
 
   int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
   DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seqlen_q});
@@ -174,6 +170,9 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
 
   int64_t q_size = total_q * num_heads * head_size;
   ComputeScaleQ(ctx, q_size, scale, dq->data<T>(), dq->data<T>());
+#else
+  RaiseNotSupportedError();
+#endif
 }
 
 template <typename T, typename Context>
diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
index 2f3922093eac3..bcf8791d3c17f 100644
--- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu
+++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu
@@ -21,12 +21,8 @@
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/arange_kernel.h"
 #include "paddle/phi/kernels/empty_kernel.h"
-#include "paddle/phi/kernels/reshape_kernel.h"
-
-#ifdef PADDLE_WITH_FLASHATTN
-#include "paddle/phi/backends/dynload/flashattn.h"
 #include "paddle/phi/kernels/gpu/flash_attn_utils.h"
-#endif
+#include "paddle/phi/kernels/reshape_kernel.h"
 
 DECLARE_bool(cudnn_deterministic);
 
@@ -54,6 +50,7 @@ void FlashAttnWithMaskUnpaddedImpl(
     DenseTensor* softmax,
     DenseTensor* softmax_lse,
     DenseTensor* seed_offset) {
+#ifdef PADDLE_WITH_FLASHATTN
  cudaStream_t stream = ctx.stream();
 
  auto dims = q.dims();
@@ -189,6 +186,9 @@ void FlashAttnWithMaskUnpaddedImpl(
       mask_dims.data() ? mask_dims.data() : nullptr,
       nullptr);
   CheckFlashAttnStatus(succ);
+#else
+  RaiseNotSupportedError();
+#endif
 }
 
 template <typename T, typename Context>
diff --git a/paddle/phi/kernels/gpu/flash_attn_utils.h b/paddle/phi/kernels/gpu/flash_attn_utils.h
index e3988658db51f..00ba036df09ba 100644
--- a/paddle/phi/kernels/gpu/flash_attn_utils.h
+++ b/paddle/phi/kernels/gpu/flash_attn_utils.h
@@ -24,6 +24,7 @@
 
 namespace phi {
 
+#ifdef PADDLE_WITH_FLASHATTN
 static std::pair<uint64_t, uint64_t> GenerateRNGState(
     const GPUContext& ctx,
     const paddle::optional<DenseTensor>& fixed_seed_offset,
@@ -208,12 +209,6 @@ static void CheckFlashAttnStatus(const bool status) {
                         phi::dynload::flash_attn_error()));
 }
 
-static void RaiseNotSupportedError() {
-  PADDLE_THROW(
-      phi::errors::Unimplemented("FlashAttention is unsupported, please check "
-                                 "the GPU compability and CUDA Version."));
-}
-
 template <typename T>
 __global__ void SimleScaleKernel(const T* input,
                                  int64_t numel,
@@ -259,5 +254,12 @@ static std::vector<int64_t> GetAttnMaskDims(const DenseTensor* attn_mask) {
   }
   return mask_dim_4d;
 }
+#endif
+
+static void RaiseNotSupportedError() {
+  PADDLE_THROW(
+      phi::errors::Unimplemented("FlashAttention is unsupported, please check "
+                                 "the GPU compability and CUDA Version."));
+}
 
 }  // namespace phi
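
The diff above applies one pattern in all three files: the FlashAttention-specific bodies are wrapped in #ifdef PADDLE_WITH_FLASHATTN, while RaiseNotSupportedError() is moved outside the guard in flash_attn_utils.h so that builds without FlashAttention still compile and fail with a clear runtime error. Below is a minimal standalone sketch of that guard pattern; the names DEMO_WITH_FLASHATTN and FlashAttnImpl are hypothetical, and a plain C++ exception stands in for Paddle's PADDLE_THROW.

// Standalone sketch of the compile-time guard pattern (not Paddle code).
#include <iostream>
#include <stdexcept>

// Kept outside any #ifdef so it is always defined, mirroring where the
// diff moves RaiseNotSupportedError() in flash_attn_utils.h.
static void RaiseNotSupportedError() {
  throw std::runtime_error(
      "FlashAttention is unsupported, please check the GPU compatibility "
      "and CUDA version.");
}

void FlashAttnImpl() {
#ifdef DEMO_WITH_FLASHATTN
  // The real kernel work would be compiled in only when the flag is set.
  std::cout << "running flash-attention path\n";
#else
  // Without the flag the function still links, but fails loudly at run time.
  RaiseNotSupportedError();
#endif
}

int main() {
  try {
    FlashAttnImpl();
  } catch (const std::exception& e) {
    std::cerr << e.what() << "\n";
  }
  return 0;
}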