CUDA: generalize FP16 fattn vec kernel #7061
Conversation
This PR:
Main branch, for reference:
There is no ROCm FlashAttention implementation on master. The only reason the test is passing is that …
I get an error if I try …
The issue was that the logic for determining when to compile the kernel was incorrect. Does it work now?
Now it runs, but the output is broken when FlashAttention is enabled with --batch-size 16 or --batch-size 8.
Thank you for the bug report. The issue was that I wrote back the data in the wrong order for batch sizes 2-8 and I didn't notice because those batch sizes were not being used when I tested for correctness. It should now be fixed.
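For illustration, a minimal sketch of what that kind of write-back ordering bug looks like (purely illustrative names, not the actual kernel code):

```cuda
// Hypothetical sketch, not the actual kernel code: one block writes back
// `ncols` output columns of length D. The offset has to advance by whole
// columns (col*D + tid); getting the ordering wrong interleaves the columns
// and produces broken output as soon as ncols > 1.
template <int D, int ncols>
static __device__ void write_back(
        float * __restrict__ dst, const float * __restrict__ vals,
        const int first_col, const int tid) {
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        dst[(first_col + j)*D + tid] = vals[j];        // correct: advance by whole columns
        // dst[tid*ncols + (first_col + j)] = vals[j]; // wrong order -> garbled output
    }
}
```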
Can confirm this PR is quite a bit slower for my workload. I have another question: would there be a workaround for a GPU like the P40, i.e. a kernel for FP32, or is it just not worth it?
This PR works perfectly on my P40.
But it's going to be slow, right? The P40 is pretty slow at FP16 as it only supports fast INT8/FP32.
The current PR will produce correct results, but for usable performance P40s will need a dedicated FP32 kernel, which I will also add.
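A minimal sketch of how such a path could look, assuming my own naming rather than the actual ggml-cuda code: the vec kernel can be templated on the accumulator type so that GPUs without fast FP16, like the P40, do the arithmetic at full FP32 rate.

```cuda
// Hedged sketch, not the actual ggml-cuda dispatch: do the per-element FMA in
// the accumulator type. GPUs with fast FP16 (P100, RDNA) accumulate in half,
// while P40-class GPUs (fast INT8/FP32 only) accumulate in float.
#include <cuda_fp16.h>

static __device__ __forceinline__ half fattn_fma(half acc, half a, half b) {
    return __hfma(a, b, acc);                            // fast FP16 path
}

static __device__ __forceinline__ float fattn_fma(float acc, half a, half b) {
    return fmaf(__half2float(a), __half2float(b), acc);  // full-rate FP32 path, e.g. for the P40
}

// The kernel itself would then be templated on the accumulator type:
// template <int D, int ncols, typename acc_t> __global__ void flash_attn_vec(...);
```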
At that time, I did not try to quantify the difference in quality because comparing the two versions side by side seemed to produce very similar results. If what you said is true for me, it should translate into a higher perplexity, right? Here is what I get. My PR: Same result. This is with a random Q6_K model over wikitext. About performance, here is what I get. My PR:
This PR:
So with FA, it got better PPL consistently.
@jdecourval thank you for the data. I'll still try to write a kernel without any tensor cores that is more optimized for large batch sizes first, but given this data I think that rocWMMA is a viable option. @sorasoras I don't think that enabling/disabling FlashAttention should significantly affect results since the differences should just be due to floating point rounding error.
RDNA3 is going to slow down. Are there any plans for this?
Just don't use FlashAttention until it's faster?
Is there something in this PR that could cause a slowdown in TG?
There are also some warnings:
Thank you for pointing this out, I did not think to test this, but the new kernel is for whatever reason massively slower than the one on master for a batch size of 1. According to NSight Compute the runtime of the kernel I used for testing increased from 70 µs to 92 µs. I really don't understand why this is happening; the compiler should be able to unroll all of the loops. Should we just keep both kernel versions?
Personally I am not interested in the hardware that benefits from this change, and I don't expect to make any changes to this code, so if you want to maintain another copy of the flash attn kernel that's entirely up to you. It would be nice to keep the same performance with tensor cores as on master, but IMO the difference in overall performance is small enough that it could be ignored. More results with the latest commit:
If you start reverting the …
Intuitively I would think the issue has to do with the data types (array vs. regular variables) rather than the loops.
FWIW, I took a quick look at the performance on my 7900 XTX and was able to bring it back to the level of your previous kernel by removing most of the checks for AMD, which meant re-enabling rocWMMA. Compare with above:
It looks like this (pile of hacks): jdecourval@b82649d |
I can confirm that the performance regression is caused by … In any case, I think this is worth asking my contact at NVIDIA about. This could be a legitimate compiler bug.
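As an illustration of the arrays-vs-registers hypothesis (a generic sketch, not code from this PR): a small per-thread array only stays in registers if every access uses a compile-time-constant index, which in practice means the loops indexing it must be fully unrolled; otherwise it is demoted to local memory and a memory-bound kernel slows down noticeably.

```cuda
// Generic sketch of the arrays-vs-registers issue, not code from this PR.
template <int ncols>
static __device__ float sum_columns(const float * __restrict__ x, const int n) {
    float acc[ncols] = {0.0f};   // per-thread accumulators, one per column

#pragma unroll                   // if this loop is not fully unrolled, acc[j] is indexed
    for (int j = 0; j < ncols; ++j) { // dynamically and acc[] can spill to local memory
        for (int i = 0; i < n; ++i) {
            acc[j] += x[j*n + i];
        }
    }

    float total = 0.0f;
#pragma unroll
    for (int j = 0; j < ncols; ++j) {
        total += acc[j];
    }
    return total;
}
```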
From what I can tell the issue is caused specifically by …
I think so.
This PR adds a FlashAttention kernel that only uses regular FP16 arithmetic and not any tensor core operations by generalizing the kernel I wrote for batch size 1 to larger batch sizes. It works reasonably well for batch sizes <= 8. The target hardware is the NVIDIA P100 and AMD RX 5000/6000 GPUs, although the code also makes it possible to run FlashAttention on AMD RX 7000 (with likely much worse performance) and on other Pascal GPUs (with effectively unusable performance). On my RX 6800 the performance changes as follows:
This PR also rearranges the order of some definitions in ggml-cuda/common.cuh. This is because on master NO_DEVICE_CODE is broken on AMD due to the definitions being in the wrong order. I also implemented some warp reduce functions that didn't work on AMD.

@jdecourval when you worked on #6773, did you check that the code actually produces correct results? I encountered an issue where, due to missing implementations in ggml-cuda/common.cuh, the compiler would optimize out part of the kernel, which resulted in significantly faster but useless code.
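For context, a minimal sketch of the kind of warp reduction involved, assuming a 32-wide warp (the real ggml-cuda/common.cuh code additionally has to handle AMD's 64-wide wavefronts and the corresponding HIP shuffle intrinsics):

```cuda
#define WARP_SIZE 32

// Butterfly reduction: after log2(WARP_SIZE) shuffle steps every lane holds the maximum.
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE));
    }
    return x;
}

// Same pattern for the row sum needed by the softmax normalization.
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, WARP_SIZE);
    }
    return x;
}
```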