Skip to content

Commit

Permalink
fix performance regression
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesGaessler committed May 9, 2024
1 parent fa81c3a commit 2272765
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,15 @@ static __global__ void flash_attn_vec_ext_f16(
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
half kqmax_new[ncols];

// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
// Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
half kqmax_new = kqmax[0];
half kqmax_new_arr[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax_new[j] = kqmax[j];
kqmax_new_arr[j] = kqmax[j];
}

#pragma unroll
Expand Down Expand Up @@ -137,7 +142,13 @@ static __global__ void flash_attn_vec_ext_f16(
sum2[j] = warp_reduce_sum(sum2[j]);
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
kqmax_new[j] = ggml_cuda_hmax(kqmax_new[j], sum);

if (ncols == 1) {
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
} else {
kqmax_new_arr[j] = ggml_cuda_hmax(kqmax_new_arr[j], sum);
}

if (threadIdx.x == 0) {
KQ[j*D + i_KQ] = sum;
}
Expand All @@ -146,21 +157,23 @@ static __global__ void flash_attn_vec_ext_f16(

#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];

kqmax_new_j = warp_reduce_max(kqmax_new_j);
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = kqmax_new[j];
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
}
}

__syncthreads();

#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax_new[j] = kqmax_shared[j][threadIdx.x];
kqmax_new[j] = warp_reduce_max(kqmax_new[j]);
half kqmax_new_j = kqmax_shared[j][threadIdx.x];
kqmax_new_j = warp_reduce_max(kqmax_new_j);

const half KQ_max_scale = hexp(kqmax[j] - kqmax_new[j]);
kqmax[j] = kqmax_new[j];
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;

const half val = hexp(KQ[j*D + tid] - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale + val;
Expand Down

0 comments on commit 2272765

Please sign in to comment.