Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MultiHeadAttention implementation #2146

Merged
merged 21 commits into from
Mar 11, 2023
Merged

MultiHeadAttention implementation #2146

merged 21 commits into from
Mar 11, 2023

Conversation

CarloLucibello
Copy link
Member

@CarloLucibello CarloLucibello commented Dec 29, 2022

For the time being, just porting the rather basic implementation in Metalhead.jl.
In the future (this PR or future ones), we will want to rely on NeuralAttentionlib.jl for the heavy lifting.
Possibly, we will want to add all the flexibility of Transformers.jl and pytorch implementations.

Implementation of a (cross-)attention layer, inspired by pytorch, Transformers.jl and flax

The implementation relies on the recently added NNlib.dot_product_attention.

cc @chengchingwen @theabhirath @darsnack @mcabbott

Fix FluxML/Metalhead.jl#141

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

Old Post

At the moment, there are two implementations:

  • one based on Tullio.jl
  • one offloading the computation to NeuralAttentionlib.jl

We will have to pick one. I tested a few cases and they yield the same outputs.
NeuralAttentionlib.jl seems to be more performant, probably thanks to the handwritten rules.

I can add a "native" implementation using permutedims and batched_mul.

@CarloLucibello CarloLucibello marked this pull request as draft December 29, 2022 14:35
@codecov-commenter
Copy link

codecov-commenter commented Dec 29, 2022

Codecov Report

Patch coverage: 93.10% and project coverage change: +3.83 🎉

Comparison is base (484796c) 82.63% compared to head (29afec7) 86.46%.

📣 This organization is not using Codecov’s GitHub App Integration. We recommend you install it so Codecov can continue to function properly for your repositories. Learn more

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2146      +/-   ##
==========================================
+ Coverage   82.63%   86.46%   +3.83%     
==========================================
  Files          23       20       -3     
  Lines        1578     1537      -41     
==========================================
+ Hits         1304     1329      +25     
+ Misses        274      208      -66     
Impacted Files Coverage Δ
src/Flux.jl 0.00% <ø> (ø)
src/layers/attention.jl 93.10% <93.10%> (ø)

... and 7 files with indirect coverage changes

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report in Codecov by Sentry.
📢 Do you have feedback about the report comment? Let us know in this issue.

@CarloLucibello CarloLucibello mentioned this pull request Dec 29, 2022
92 tasks
@mcabbott
Copy link
Member

mcabbott commented Jan 2, 2023

Thanks for making a start. Some questions, perhaps not all immediate, are:

  1. How complicated a combined thing should be called a layer here? This one is a much bigger block than anything else, where you stack dropout etc. by hand.
  2. Is the goal to provide the basic idea, in a readable way, or the fastest and most flexible possible implementation?
  3. What deps should Flux (or perhaps NNlib) accept? We finally extricated ourselves from some of them...

Re implementations & performance, well I'm obviously pleased how well Tullio does on the CPU, compared to permutedims etc, at least at this size. It is easy to fuse the scaling of q into multiplication & save a copy. It might be possible to fuse more. It might be optimal to provide gradients rather than use its own. Although, see point 3.

Most of the time is spent in softmax. Partly because maximum is surprisingly slow (due to careful NaN handling?). @turbo exp may help a lot (but needs care as it doesn't like size-1 dimensions).

On GPU, I presume that it will be very hard to beat cudnnMultiHeadAttnForward . But have timed nothing. (Nor have I looked closely at what cases that does & doesn't handle.)

julia> dim, len, batch_size, num_heads = 128, 8, 128, 32;
julia> mha = MultiHeadAttention(dim, num_heads);
julia> x = rand(Float32, (dim, len, batch_size));
julia> mha(x, impl=:native)  mha(x, impl=:tullio)
true

# Timing the layer

julia> @btime mha(x, impl=:native, proj=false);
  min 7.193 ms, mean 8.043 ms (152 allocations, 8.26 MiB)

julia> @btime mha(x, impl=:tullio, proj=false);
  min 3.317 ms, mean 3.701 ms (161 allocations, 6.26 MiB)
  min 3.196 ms, mean 3.721 ms (167 allocations, 5.76 MiB)  with scale absorbed

julia> @btime mha(x, impl=:nalib, proj=false);
  min 7.070 ms, mean 7.926 ms (139 allocations, 6.76 MiB)

# Timing pieces

julia> @btime copy($x);  # so each copy is 1/2 MB
  min 13.083 μs, mean 48.775 μs (2 allocations, 512.06 KiB)

julia> mha.qkv_proj   # this step is always done, could be fused?
QKVProj(Dense(128 => 128; bias=false), Dense(128 => 128; bias=false), Dense(128 => 128; bias=false))

julia> @btime mha.qkv_proj(x,x,x);  # half the memory, but not much time.
  min 200.375 μs, mean 500.531 μs (26 allocations, 3.00 MiB)

julia> @btime mha.attn_drop(x);  # dropout appears to be cheap
  min 132.096 ns, mean 139.195 ns (1 allocation, 48 bytes)

julia> z = rand(Float32, 8, 8, 32, 128);  # size of input to softmax

julia> @btime softmax($z);
  min 2.754 ms, mean 2.883 ms (12 allocations, 1.25 MiB)

julia> @btime maximum($z; dims=1);  # this is a fairly big component
  min 767.625 μs, mean 801.111 μs (6 allocations, 128.25 KiB)

julia> @btime @fastmath reduce(max, $z; dims=1, init=0.0f0);
  min 78.125 μs, mean 97.439 μs (3 allocations, 128.11 KiB)

julia> @btime _softmax($z);  # an attempt...
  min 2.785 ms, mean 2.964 ms (9 allocations, 1.25 MiB)  # with reduce(max), why not better?
  min 659.542 μs, mean 825.112 μs (9 allocations, 1.25 MiB)  # with @turbo, in fact wrong

# Without softmax nor qkv_proj

julia> 7.193 - 2.754 - 0.2  # native
4.239

julia> 3.317 - 2.754 - 0.2 # tullio
0.36300000000000016

julia> 7.070 - 2.754 - 0.2  # nalib
4.1160000000000005

@CarloLucibello
Copy link
Member Author

CarloLucibello commented Jan 2, 2023

  1. How complicated a combined thing should be called a layer here? This one is a much bigger block than anything else, where you stack dropout etc. by hand.

The paper "Attention is all you need" and all deep learning frameworks are quite consistent in what MultiHeadAttention is, all those components are needed. I can remove the dropout on the outputs, I don't remember where I got that from, it is absent in pytorch, keras, flax.

Pytorch goes also much further and implements the whole transformer model

  1. Is the goal to provide the basic idea, in a readable way, or the fastest and most flexible possible implementation?

for me, it would be flexibility > readability > performance

  1. What deps should Flux (or perhaps NNlib) accept? We finally extricated ourselves from some of them...

Being the overall performance for the layer not so different (thanks to softmax) I think we can go with native and avoid carrying on additional dependencies

@CarloLucibello
Copy link
Member Author

I can keep only the "native" implementation, and leave the tullio contractions as comments that help understand what is going on and also as pointers for implementing more performant similar operations.

I will also try to hook cudnn.

@chengchingwen
Copy link
Member

chengchingwen commented Jan 2, 2023

The goal of NeuralAttenionlib is to provide flexible attention variants without sacrificing the performance (though I'm mainly focusing on GPU performance, no CPU).

On GPU, I presume that it will be very hard to beat cudnnMultiHeadAttnForward . But have timed nothing. (Nor have I looked closely at what cases that does & doesn't handle.)

If anyone is interested in beating cudnnMultiHeadAttnForward, there is a FlashAttention that fuse almost all operation in attention into one cuda kernel, which use the WMMA api. It might be possible to follow their strategy and implement it with CUDA.jl

Also, it important to check the performance of the backward function / gradient computation.

How complicated a combined thing should be called a layer here? This one is a much bigger block than anything else, where you stack dropout etc. by hand.

IMPO I think Flux really need to setup a paradigm for building model. Like in Pytorch, all the predefined layer object is not needed. The user can always inheritance the torch.nn.Module and use torch.functional to define the forward method. While in Flux, we are more restrict to the predefined Layers and hence the question (what should be viewed as a layer).

@mcabbott
Copy link
Member

mcabbott commented Jan 2, 2023

Re functional things, some discussion here but no PRs yet.

For the attention story, I think what you're asking for is more like what we have for conv right? i.e. some selfattention function which lives in NNlib and does the work (and is overloaded to call the CUDA implementation) while the Flux layer handles initialisation etc. (We should do that to batchnorm too.)

For CPU implementations, note that softmax is the majority of the time only in the Tullio case (and I did not time gradients). FluxML/NNlib.jl#450 makes it a little faster. Now that we have package extensions, having using LoopVectorization load a faster path for softmax!(::Array, ...) is probably something we should do. Perhaps it should similarly load a faster overload for the functional selfattention(::Array...).

Thanks for the link to FlashAttention, from a very quick skim, interesting they can avoid memory quadratic in sequence length. And surprising how much time they say pytorch spends on mask/softmax/dropout.

@ToucheSir
Copy link
Member

The user can always inheritance the torch.nn.Module and use torch.functional to define the forward method.

Where do you feel the current approach of @functortorch.nn.Module and NNlib ≈ torch.nn.functional falls short here? If anything I would've thought we do better than PyTorch because one doesn't have to rely on the entirety of Flux to write custom layers, but maybe there are gaps in functionality/documentation I'm missing.

@chengchingwen
Copy link
Member

I'm not sure if we really need a "factory", but I do think we doesn't separate functionality thoroughly enough. There are many stuff that should be defined in NNlib but currently live in Flux (e.g. dropout, normalization, etc.). Ideally (IMPO) every forward definition of a layer should be almost a single call of a function from NNlib, so yes, just like what we have for conv.

The attention story is more complicated in NAlib. It is not only designed with the functional style in mind, but also tries to be flexible enough to support different attention variants. For example:

generic_multihead_qkv_attention(
    weighted_sum_mixing,
    normalized_score(softmax) $
    masked_score(GenericMaskOp(), mask) $
    scalar_relative_position_embedding(t5_bucketed_position_id(n_bucket, max_distance), position_bias) $
    dot_product_score,
    head, q, k, v
)

This is an implementation of a multi-head attention with relative position embedding. By given different kinds of mask, it can also become other variant like a causal or local one.

@ToucheSir
Copy link
Member

ToucheSir commented Jan 3, 2023

There are many stuff that should be defined in NNlib but currently live in Flux (e.g. dropout, normalization, etc.). Ideally (IMPO) every forward definition of a layer should be almost a single call of a function from NNlib, so yes, just like what we have for conv.

100% agreed. FluxML/NNlib.jl#452 is a big step in that direction. I think we can set a de-facto policy that any future Flux layers should follow the criteria you've outlined as well.

On the attention op composition example, the ideal would be that a user can mix and match bits like you've shown, but end up with something close to the performance of Flash Attention (e.g. memory efficient, fusion where applicable). It's not clear to me what that would look like, but it certainly seems within the realm of possibility.

How to represent this in a layer API is just as interesting. Do we stick with a fixed "default" MHA layer as mentioned above, or can we allow some composition at that level too?

@chengchingwen
Copy link
Member

On the attention op composition example, the ideal would be that a user can mix and match bits like you've shown, but end up with something close to the performance of Flash Attention (e.g. memory efficient, fusion where applicable). It's not clear to me what that would look like, but it certainly seems within the realm of possibility.

We can always dispatch to optimized specialization of a composition, but I'm not sure if it's possible to have a composable flash attention since that might require composability at the kernel level.

How to represent this in a layer API is just as interesting. Do we stick with a fixed "default" MHA layer as mentioned above, or can we allow some composition at that level too?

I store an attention op inside the attention layer and define types for each composition.

@ToucheSir
Copy link
Member

I don't think it's impossible, but it would be tricky. Separate host-side functions should always be the fallback, but I feel we could eventually come up with something like Base's broadcast types which allows for kernel composition (probably would involve some generated functions).

@chengchingwen
Copy link
Member

Another issue is the gradient computation. Even the gradient of the broadcast fusion is hard to compute, I don't think it would easier for attention fusion. I was actually hoping Enzyme could help with this, but we need to have the fused kernel beforehand.

@ToucheSir
Copy link
Member

I'd be pleasantly surprised if Enzyme could synthesize a fast gradient for such complex kernels. Definitely been playing around with it for fused broadcasting however. Speculatively, a halfway solution could be to define something like https://github.com/JuliaGPU/GemmKernels.jl where we create gradient building blocks and have the library combine them together. Anyhow, none of that ought to hold up this PR, just a nice brainstorm :)

@chengchingwen
Copy link
Member

Anyhow, none of that ought to hold up this PR, just a nice brainstorm :)

Sure, I have no opinion on this PR since Transformers.jl would be using its own implementation.

@CarloLucibello CarloLucibello marked this pull request as ready for review March 5, 2023 11:38
@CarloLucibello
Copy link
Member Author

any feedback? This should be essentially ready

src/layers/attention.jl Outdated Show resolved Hide resolved
test/test_utils.jl Outdated Show resolved Hide resolved
Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't look like there's much left to say here, shall we do this?

@CarloLucibello
Copy link
Member Author

letsdoit!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Move MHAttention layer to Flux
5 participants