Skip to content

Commit

Permalink
fixup! CUDA: add FP32 FlashAttention vector kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed May 11, 2024
1 parent bbeb952 commit 41f5f3a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 8 deletions.
33 changes: 29 additions & 4 deletions ggml-cuda/fattn-vec-f16.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ static __global__ void flash_attn_vec_ext_f16(
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
Expand Down Expand Up @@ -49,6 +53,18 @@ static __global__ void flash_attn_vec_ext_f16(
const int stride_KV = nb11 / sizeof(half);
const int stride_KV2 = nb11 / sizeof(half2);

half slopeh = __float2half(1.0f);

// ALiBi
if (max_bias > 0.0f) {
const int h = blockIdx.y;

const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

slopeh = __float2half(powf(base, exph));
}

static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
Expand Down Expand Up @@ -132,7 +148,7 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j = 0; j < ncols; ++j) {
sum2[j] = warp_reduce_sum(sum2[j]);
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);

if (ncols == 1) {
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
Expand Down Expand Up @@ -244,8 +260,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;

float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
float scale = 1.0f;
float max_bias = 0.0f;

memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));

const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
Expand All @@ -254,7 +279,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Expand Down
33 changes: 29 additions & 4 deletions ggml-cuda/fattn-vec-f32.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ static __global__ void flash_attn_vec_ext_f32(
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
Expand Down Expand Up @@ -48,6 +52,18 @@ static __global__ void flash_attn_vec_ext_f32(
const int stride_KV = nb11 / sizeof(half);
const int stride_KV2 = nb11 / sizeof(half2);

float slope = 1.0f;

// ALiBi
if (max_bias > 0.0f) {
const int h = blockIdx.y;

const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

slope = powf(base, exph);
}

static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
Expand Down Expand Up @@ -127,7 +143,7 @@ static __global__ void flash_attn_vec_ext_f32(
#pragma unroll
for (int j = 0; j < ncols; ++j) {
sum[j] = warp_reduce_sum(sum[j]);
sum[j] += mask ? __half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;

kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]);

Expand Down Expand Up @@ -230,8 +246,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;

float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
float scale = 1.0f;
float max_bias = 0.0f;

memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));

const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
Expand All @@ -240,7 +265,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
Expand Down

0 comments on commit 41f5f3a

Please sign in to comment.