Fix flash-attn for AMD #6773
Conversation
I don't think it works for gfx1100, at least not yet.
Sorry, I forgot: you need to install rocWMMA (https://github.com/ROCm/rocWMMA).
I just tried this PR (ac6ae5d), but could not compile. I installed rocWMMA from source and I'm using ROCm 6.0.2. I tried to compile with:
Well, this looks like it would be non-trivial to fix. I was hoping it would be possible to just use rocWMMA as a drop-in replacement. But as I said, I don't have an AMD GPU with tensor cores to debug this with. And quite honestly, I don't want to invest a lot of time into this either way, because as far as I am concerned none of the current AMD GPUs are worth buying anyways. So unless another dev wants to take over, the current llama.cpp FlashAttention implementation using tensor cores will be NVIDIA-only.
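For context, rocWMMA is meant to mirror the nvcuda::wmma fragment API, which is what makes the drop-in idea plausible in the first place: a single kernel body can in principle target both backends behind a namespace alias, presumably something along these lines in fattn.cu. The sketch below is illustrative only; the kernel name, tile shapes, and the exact preprocessor condition are assumptions, not code from this PR.

// Illustrative sketch (not code from this PR): the same 16x16x16 matrix multiply
// written against the wmma fragment API, compiled either with rocWMMA on HIP/AMD
// or with nvcuda::wmma on CUDA, via a namespace alias.
#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;
#else
#include <cuda_fp16.h>
#include <mma.h>
namespace wmma = nvcuda::wmma;
#endif

// Multiplies one 16x16 half tile by another and stores the float result.
__global__ void mma_16x16_demo(const half * a, const half * b, float * c) {
    wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> frag_a;
    wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::col_major> frag_b;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float>                 frag_c;

    wmma::fill_fragment(frag_c, 0.0f);
    wmma::load_matrix_sync(frag_a, a, 16);   // leading dimension 16
    wmma::load_matrix_sync(frag_b, b, 16);
    wmma::mma_sync(frag_c, frag_a, frag_b, frag_c);
    wmma::store_matrix_sync(c, frag_c, 16, wmma::mem_row_major);
}

rocWMMA covers both the RDNA3 WMMA instructions (e.g. gfx1100) and the CDNA MFMA instructions, which is why it is the dependency under discussion here.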
I just tried to compile with C++17,
It could be something with my setup, as there should be a fill_fragment function in rocWMMA.
nvm, there are still rocWMMA compilation issues with C++17:
Something is not right with the types. @JohannesGaessler, this indeed does not look like an easy fix.
I managed to compile the branch.
According to https://rocm.docs.amd.com/projects/HIPIFY/en/amd-staging/tables/CUDA_Device_API_supported_by_HIP.html,
I know next to nothing about HIP/CUDA, so this is probably wrong.
For this error, here is my workaround:

--- a/ggml-cuda/fattn.cu (revision ac6ae5daca029e554af08281a3fd839169725c8c)
+++ b/ggml-cuda/fattn.cu (revision 21b0bf477a56122d8302b218579956b034deaa36)
@@ -339,7 +351,7 @@
frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
- wmma::fill_fragment(KQ_c[j], 0.0f);
+ wmma::fill_fragment(KQ_c[j], KQ_acc_t{0.0f});
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
@@ -470,7 +482,7 @@
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
- wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+ wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], __half{0.0f});
}
#pragma unroll

However, the resulting binary crashes, and it doesn't seem directly related to flash-attn:
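One possible explanation for why the explicit conversions help (an assumption, not confirmed in this thread): rocWMMA's fill_fragment template ends up requiring the fill value to have exactly the fragment's element type, so a bare float literal either fails template deduction or fails to convert for a half-typed fragment under HIP, while nvcc's nvcuda::wmma accepts the implicit float-to-half conversion. A minimal sketch of that suspected failure mode, with the fragment shape and kernel name invented:

// Illustrative only; the types and names here are not the ones from fattn.cu.
#include <hip/hip_fp16.h>
#include <rocwmma/rocwmma.hpp>

__global__ void fill_demo() {
    rocwmma::fragment<rocwmma::accumulator, 16, 16, 16, __half> acc;
    // rocwmma::fill_fragment(acc, 0.0f);      // float literal: can fail to compile with rocWMMA
    rocwmma::fill_fragment(acc, __half{0.0f}); // element type spelled out, as in the diff above
}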
@jdecourval I just tried your suggestions (I also know nothing about HIP/CUDA), but ChatGPT has a similar proposal for hmax2, and for me it compiles and works. I tried:
with the output:
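For illustration, a portable __hmax2 fallback could be built from half conversions that exist in both CUDA's and HIP's fp16 headers; the sketch below is an assumption about the shape of such a workaround (the helper name is invented), not the code that was actually used here.

#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_fp16.h>
#else
#include <cuda_fp16.h>
#endif

// Hypothetical fallback for a missing __hmax2: compare the low and high halves
// in float and repack the results into a half2.
static __device__ __forceinline__ half2 hmax2_fallback(const half2 a, const half2 b) {
    const float lo = fmaxf(__half2float(__low2half(a)),  __half2float(__low2half(b)));
    const float hi = fmaxf(__half2float(__high2half(a)), __half2float(__high2half(b)));
    return __halves2half2(__float2half(lo), __float2half(hi));
}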
Thanks @tbocek, it was indeed something on my side, probably a leftover from another test. Having the code compile is a nice improvement, but it still doesn't work here if I add the
It comes from the fact that
However, simply removing the check does make it work!

--- a/ggml-cuda/common.cuh (revision 21b0bf477a56122d8302b218579956b034deaa36)
+++ b/ggml-cuda/common.cuh (date 1713632170170)
@@ -364,16 +364,11 @@
}
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
}
return a;
-#else
- GGML_UNUSED(a);
- NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
static __device__ __forceinline__ float warp_reduce_max(float x) {

@JohannesGaessler do you know why that check is there?
That check is there because it is the old check for FlashAttention and I forgot to change it.
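For illustration, updating rather than deleting the check might look roughly like the sketch below. This is an assumption about the intended fix, not the code that was eventually merged; it relies on the surrounding common.cuh macros (CC_PASCAL, GGML_UNUSED, NO_DEVICE_CODE) and on llama.cpp's existing HIP shim for __shfl_xor_sync.

// Sketch only: let the HIP/AMD build take the fast half2 path instead of hitting
// NO_DEVICE_CODE, while keeping the Pascal requirement for NVIDIA.
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
    }
    return a;
#else
    GGML_UNUSED(a);
    NO_DEVICE_CODE;
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
}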
Here are some results. Token generation is about 2-3% faster.
I pushed all my fixes to this branch: https://github.com/jdecourval/llama.cpp/tree/fixflashattn
If it's only token generation that is faster, then this PR is pretty much pointless, because the FlashAttention kernel for batch size 1 does not use tensor cores at all (except for Phi-2). Also, can you check llama-bench with
@jdecourval I applied your patches. Here are my results with -p 4096:
And here are mine
Updated with more models.
Force-pushed from 5ac4690 to f725ca9.
@JohannesGaessler What do you think given the numbers so far? Is it worth it to merge AMD support?
Maybe in the future there can be more progress. I don't know how difficult it would be for people to install the rocWMMA dependency. We could also put it behind a compile flag?
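As a rough sketch of the compile-flag idea (GGML_HIP_ROCWMMA and GGML_CUDA_NO_WMMA_FATTN are invented names used only for illustration, not actual llama.cpp options): the rocWMMA include and namespace alias could be gated so that AMD users opt in explicitly, and the tensor-core FlashAttention path stays disabled on HIP by default.

// Hypothetical gating of the rocWMMA dependency behind an opt-in flag.
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
#  if defined(GGML_HIP_ROCWMMA)
#    include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;              // opt-in: reuse the tensor-core kernels via rocWMMA
#  else
#    define GGML_CUDA_NO_WMMA_FATTN 1  // default: skip the WMMA FlashAttention kernels on HIP
#  endif
#else
#  include <mma.h>
namespace wmma = nvcuda::wmma;
#endif

The flag would then just need to be surfaced through the build system so that users who already have rocWMMA installed can turn it on.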
My current stance is that I don't think a 2% speedup is large enough to justify adding a dependency, especially when there is no dev with the ability to test and support the implementation themselves. I think right now it makes more sense to just disable FlashAttention for AMD. One of my next goals is to write kernels that don't use tensor cores at all. It may turn out that those are faster on AMD than rocWMMA anyways.
Alright. Thank you very much for the help. I will update the target branch to disable FlashAttention when HIP is enabled for now.
I agree that a 2% performance improvement is not much, but the 500-600 MB VRAM reduction may be significant: #5021 (comment)
On Arch Linux at least, I had to manually install the rocWMMA dependency from GitHub, as there is no current package (also not in the AUR).
There's this one: https://aur.archlinux.org/packages/rocwmma
Force-pushed from 82b282c to ce281b9.
Obsolete now that #5021 has been merged.
This PR attempts to make the llama.cpp FlashAttention code run on AMD via HIP. I do not have an AMD GPU with tensor cores, so I cannot test the code myself.
Edit: this PR needs rocWMMA (https://github.com/ROCm/rocWMMA) to be installed.
This PR also changes the order of some definitions in common.cuh, because otherwise NO_DEVICE_CODE is not functional on AMD.