-
-
Notifications
You must be signed in to change notification settings - Fork 609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
MultiHeadAttention implementation #2146
Conversation
Codecov ReportPatch coverage:
📣 This organization is not using Codecov’s GitHub App Integration. We recommend you install it so Codecov can continue to function properly for your repositories. Learn more Additional details and impacted files@@ Coverage Diff @@
## master #2146 +/- ##
==========================================
+ Coverage 82.63% 86.46% +3.83%
==========================================
Files 23 20 -3
Lines 1578 1537 -41
==========================================
+ Hits 1304 1329 +25
+ Misses 274 208 -66
... and 7 files with indirect coverage changes Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here. ☔ View full report in Codecov by Sentry. |
Thanks for making a start. Some questions, perhaps not all immediate, are:
Re implementations & performance, well I'm obviously pleased how well Tullio does on the CPU, compared to permutedims etc, at least at this size. It is easy to fuse the scaling of Most of the time is spent in On GPU, I presume that it will be very hard to beat cudnnMultiHeadAttnForward . But have timed nothing. (Nor have I looked closely at what cases that does & doesn't handle.) julia> dim, len, batch_size, num_heads = 128, 8, 128, 32;
julia> mha = MultiHeadAttention(dim, num_heads);
julia> x = rand(Float32, (dim, len, batch_size));
julia> mha(x, impl=:native) ≈ mha(x, impl=:tullio)
true
# Timing the layer
julia> @btime mha(x, impl=:native, proj=false);
min 7.193 ms, mean 8.043 ms (152 allocations, 8.26 MiB)
julia> @btime mha(x, impl=:tullio, proj=false);
min 3.317 ms, mean 3.701 ms (161 allocations, 6.26 MiB)
min 3.196 ms, mean 3.721 ms (167 allocations, 5.76 MiB) with scale absorbed
julia> @btime mha(x, impl=:nalib, proj=false);
min 7.070 ms, mean 7.926 ms (139 allocations, 6.76 MiB)
# Timing pieces
julia> @btime copy($x); # so each copy is 1/2 MB
min 13.083 μs, mean 48.775 μs (2 allocations, 512.06 KiB)
julia> mha.qkv_proj # this step is always done, could be fused?
QKVProj(Dense(128 => 128; bias=false), Dense(128 => 128; bias=false), Dense(128 => 128; bias=false))
julia> @btime mha.qkv_proj(x,x,x); # half the memory, but not much time.
min 200.375 μs, mean 500.531 μs (26 allocations, 3.00 MiB)
julia> @btime mha.attn_drop(x); # dropout appears to be cheap
min 132.096 ns, mean 139.195 ns (1 allocation, 48 bytes)
julia> z = rand(Float32, 8, 8, 32, 128); # size of input to softmax
julia> @btime softmax($z);
min 2.754 ms, mean 2.883 ms (12 allocations, 1.25 MiB)
julia> @btime maximum($z; dims=1); # this is a fairly big component
min 767.625 μs, mean 801.111 μs (6 allocations, 128.25 KiB)
julia> @btime @fastmath reduce(max, $z; dims=1, init=0.0f0);
min 78.125 μs, mean 97.439 μs (3 allocations, 128.11 KiB)
julia> @btime _softmax($z); # an attempt...
min 2.785 ms, mean 2.964 ms (9 allocations, 1.25 MiB) # with reduce(max), why not better?
min 659.542 μs, mean 825.112 μs (9 allocations, 1.25 MiB) # with @turbo, in fact wrong
# Without softmax nor qkv_proj
julia> 7.193 - 2.754 - 0.2 # native
4.239
julia> 3.317 - 2.754 - 0.2 # tullio
0.36300000000000016
julia> 7.070 - 2.754 - 0.2 # nalib
4.1160000000000005 |
The paper "Attention is all you need" and all deep learning frameworks are quite consistent in what MultiHeadAttention is, all those components are needed. I can remove the dropout on the outputs, I don't remember where I got that from, it is absent in pytorch, keras, flax. Pytorch goes also much further and implements the whole transformer model
for me, it would be flexibility > readability > performance
Being the overall performance for the layer not so different (thanks to softmax) I think we can go with |
I can keep only the "native" implementation, and leave the tullio contractions as comments that help understand what is going on and also as pointers for implementing more performant similar operations. I will also try to hook cudnn. |
The goal of NeuralAttenionlib is to provide flexible attention variants without sacrificing the performance (though I'm mainly focusing on GPU performance, no CPU).
If anyone is interested in beating cudnnMultiHeadAttnForward, there is a FlashAttention that fuse almost all operation in attention into one cuda kernel, which use the WMMA api. It might be possible to follow their strategy and implement it with CUDA.jl Also, it important to check the performance of the backward function / gradient computation.
IMPO I think Flux really need to setup a paradigm for building model. Like in Pytorch, all the predefined layer object is not needed. The user can always inheritance the |
Re functional things, some discussion here but no PRs yet. For the attention story, I think what you're asking for is more like what we have for For CPU implementations, note that Thanks for the link to FlashAttention, from a very quick skim, interesting they can avoid memory quadratic in sequence length. And surprising how much time they say pytorch spends on mask/softmax/dropout. |
Where do you feel the current approach of |
I'm not sure if we really need a "factory", but I do think we doesn't separate functionality thoroughly enough. There are many stuff that should be defined in NNlib but currently live in Flux (e.g. dropout, normalization, etc.). Ideally (IMPO) every forward definition of a layer should be almost a single call of a function from NNlib, so yes, just like what we have for The attention story is more complicated in NAlib. It is not only designed with the functional style in mind, but also tries to be flexible enough to support different attention variants. For example: generic_multihead_qkv_attention(
weighted_sum_mixing,
normalized_score(softmax) $
masked_score(GenericMaskOp(), mask) $
scalar_relative_position_embedding(t5_bucketed_position_id(n_bucket, max_distance), position_bias) $
dot_product_score,
head, q, k, v
) This is an implementation of a multi-head attention with relative position embedding. By given different kinds of |
100% agreed. FluxML/NNlib.jl#452 is a big step in that direction. I think we can set a de-facto policy that any future Flux layers should follow the criteria you've outlined as well. On the attention op composition example, the ideal would be that a user can mix and match bits like you've shown, but end up with something close to the performance of Flash Attention (e.g. memory efficient, fusion where applicable). It's not clear to me what that would look like, but it certainly seems within the realm of possibility. How to represent this in a layer API is just as interesting. Do we stick with a fixed "default" MHA layer as mentioned above, or can we allow some composition at that level too? |
We can always dispatch to optimized specialization of a composition, but I'm not sure if it's possible to have a composable flash attention since that might require composability at the kernel level.
I store an attention op inside the attention layer and define types for each composition. |
I don't think it's impossible, but it would be tricky. Separate host-side functions should always be the fallback, but I feel we could eventually come up with something like Base's broadcast types which allows for kernel composition (probably would involve some generated functions). |
Another issue is the gradient computation. Even the gradient of the broadcast fusion is hard to compute, I don't think it would easier for attention fusion. I was actually hoping Enzyme could help with this, but we need to have the fused kernel beforehand. |
I'd be pleasantly surprised if Enzyme could synthesize a fast gradient for such complex kernels. Definitely been playing around with it for fused broadcasting however. Speculatively, a halfway solution could be to define something like https://github.com/JuliaGPU/GemmKernels.jl where we create gradient building blocks and have the library combine them together. Anyhow, none of that ought to hold up this PR, just a nice brainstorm :) |
Sure, I have no opinion on this PR since Transformers.jl would be using its own implementation. |
0ede58c
to
a1e8365
Compare
any feedback? This should be essentially ready |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't look like there's much left to say here, shall we do this?
letsdoit! |
For the time being, just porting the rather basic implementation in Metalhead.jl.In the future (this PR or future ones), we will want to rely on NeuralAttentionlib.jl for the heavy lifting.Possibly, we will want to add all the flexibility of Transformers.jl and pytorch implementations.Implementation of a (cross-)attention layer, inspired by pytorch, Transformers.jl and flax
The implementation relies on the recently added
NNlib.dot_product_attention
.cc @chengchingwen @theabhirath @darsnack @mcabbott
Fix FluxML/Metalhead.jl#141
PR Checklist
Old Post
At the moment, there are two implementations:
We will have to pick one. I tested a few cases and they yield the same outputs.
NeuralAttentionlib.jl seems to be more performant, probably thanks to the handwritten rules.
I can add a "native" implementation using permutedims and batched_mul.