Fix the compile error for the non-flash-attn case.
Xreki committed Aug 7, 2023
1 parent 28a40ab commit ec4bcc4
Showing 3 changed files with 19 additions and 18 deletions.
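
In short, the commit wraps the bodies of the flash-attn kernel implementations in #ifdef PADDLE_WITH_FLASHATTN and adds an #else branch that raises a clear runtime error, while the error helper itself is moved outside the guard so it is always defined. The following is a minimal sketch of that pattern, abbreviated from the diff below (the real functions take many more parameters):

// Defined unconditionally so the fallback branch always compiles and links.
static void RaiseNotSupportedError() {
  PADDLE_THROW(
      phi::errors::Unimplemented("FlashAttention is unsupported, please check "
                                 "the GPU compability and CUDA Version."));
}

template <typename T, typename Context>
void FlashAttnUnpaddedGradImpl(const Context& ctx, /* ..., */ DenseTensor* dq) {
#ifdef PADDLE_WITH_FLASHATTN
  // Actual flash-attn backward computation (see the diff below).
#else
  // Builds without flash-attn take this branch instead of failing to compile.
  RaiseNotSupportedError();
#endif
}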
13 changes: 6 additions & 7 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
@@ -21,12 +21,8 @@
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"

#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif
#include "paddle/phi/kernels/reshape_kernel.h"

DECLARE_bool(cudnn_deterministic);

@@ -52,6 +48,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
DenseTensor* dq,
DenseTensor* dk,
DenseTensor* dv) {
#ifdef PADDLE_WITH_FLASHATTN
const cudaStream_t stream = ctx.stream();

auto dims = q.dims();
@@ -77,8 +74,7 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,
const int64_t* seed_offset_data = seed_offset.data<int64_t>();
uint64_t seed = static_cast<uint64_t>(seed_offset_data[0]);
uint64_t offset = static_cast<uint64_t>(seed_offset_data[1]);
VLOG(10) << "FlashAttn bwd seed: " << seed << ", offset: " << offset
<< ", num_splits:" << num_splits;
VLOG(10) << "FlashAttn bwd seed: " << seed << ", offset: " << offset;

int64_t seqlen_q = ((max_seqlen_q + 16 - 1) / 16) * 16;
DenseTensor dsoftmax = Empty<float>(ctx, {batch_size, num_heads, seqlen_q});
@@ -174,6 +170,9 @@ void FlashAttnUnpaddedGradImpl(const Context& ctx,

int64_t q_size = total_q * num_heads * head_size;
ComputeScaleQ(ctx, q_size, scale, dq->data<T>(), dq->data<T>());
#else
RaiseNotSupportedError();
#endif
}

template <typename T, typename Context>
10 changes: 5 additions & 5 deletions paddle/phi/kernels/gpu/flash_attn_kernel.cu
@@ -21,12 +21,8 @@
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"

#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
#include "paddle/phi/kernels/gpu/flash_attn_utils.h"
#endif
#include "paddle/phi/kernels/reshape_kernel.h"

DECLARE_bool(cudnn_deterministic);

@@ -54,6 +50,7 @@ void FlashAttnWithMaskUnpaddedImpl(
DenseTensor* softmax,
DenseTensor* softmax_lse,
DenseTensor* seed_offset) {
#ifdef PADDLE_WITH_FLASHATTN
cudaStream_t stream = ctx.stream();

auto dims = q.dims();
@@ -189,6 +186,9 @@ void FlashAttnWithMaskUnpaddedImpl(
mask_dims.data() ? mask_dims.data() : nullptr,
nullptr);
CheckFlashAttnStatus(succ);
#else
RaiseNotSupportedError();
#endif
}

template <typename T, typename Context>
14 changes: 8 additions & 6 deletions paddle/phi/kernels/gpu/flash_attn_utils.h
@@ -24,6 +24,7 @@

namespace phi {

#ifdef PADDLE_WITH_FLASHATTN
static std::pair<uint64_t, uint64_t> GenerateRNGState(
const GPUContext& ctx,
const paddle::optional<DenseTensor>& fixed_seed_offset,
@@ -208,12 +209,6 @@ static void CheckFlashAttnStatus(const bool status) {
phi::dynload::flash_attn_error()));
}

static void RaiseNotSupportedError() {
PADDLE_THROW(
phi::errors::Unimplemented("FlashAttention is unsupported, please check "
"the GPU compability and CUDA Version."));
}

template <typename T>
__global__ void SimleScaleKernel(const T* input,
int64_t numel,
@@ -259,5 +254,12 @@ static std::vector<int64_t> GetAttnMaskDims(const DenseTensor* attn_mask) {
}
return mask_dim_4d;
}
#endif

static void RaiseNotSupportedError() {
PADDLE_THROW(
phi::errors::Unimplemented("FlashAttention is unsupported, please check "
"the GPU compability and CUDA Version."));
}

} // namespace phi
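
Note on this last file: RaiseNotSupportedError() is moved out of, and below, the #ifdef PADDLE_WITH_FLASHATTN region in flash_attn_utils.h. Because the flash-attn-specific helpers in this header are now compiled only when the flag is set, the error helper must remain unconditionally defined so the new #else branches in the two kernel files can call it; together these guards are what fix the compile error in non-flash-attn builds.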
