Batched multiplication support for ndims > 3 #391
The first question is what rules it should obey.

```julia
julia> rand(3,4,7) ⊠ rand(4,5,7) |> size
(3, 5, 7)

julia> rand(3,4,1) ⊠ rand(4,5,7) |> size
(3, 5, 7)

julia> rand(3,4,7) ⊠ rand(4,5,1) |> size
(3, 5, 7)

julia> rand(3,4,7) ⊠ rand(4,5) |> size  # same, as size(B,3)==1
(3, 5, 7)
```

The easy extension would be this one, but should anything else be allowed?

```julia
rand(3,4,7,9) ⊠ rand(4,5,7,9) |> size  # (3, 5, 7, 9)
```

I think the 3D cases match what batched_mul already does.
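A minimal sketch of that easy extension, assuming all trailing batch dimensions agree; the helper name `batched_mul_nd` is made up here for illustration, not NNlib API:

```julia
using NNlib  # for batched_mul / ⊠ on 3-dimensional arrays

# Fold all trailing batch dimensions into one, call the existing 3D routine,
# then restore the original batch shape. Only covers the "trivial reshape" case.
function batched_mul_nd(A::AbstractArray, B::AbstractArray)
    size(A)[3:end] == size(B)[3:end] || throw(DimensionMismatch("batch dimensions must agree"))
    batch = size(A)[3:end]
    A3 = reshape(A, size(A,1), size(A,2), :)
    B3 = reshape(B, size(B,1), size(B,2), :)
    C3 = batched_mul(A3, B3)
    return reshape(C3, size(C3,1), size(C3,2), batch...)
end

batched_mul_nd(rand(3,4,7,9), rand(4,5,7,9)) |> size  # (3, 5, 7, 9)
```

Since `reshape` of a dense array is essentially free, this costs no more than the 3D case; it just doesn't cover any of the size-1 broadcasting rules discussed below.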
This is what torch and numpy have to say:

> If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.

which is exactly what you've suggested (of course Julia has the matrices in the first two dimensions). I think this is a straightforward enough way to go about it; given that other frameworks do it the same way, it is expected behaviour in a sense.
The rule you suggest is, I think, this:

```julia
rand(3,4,7,9) ⊠ rand(4,5,7,9) |> size  # (3, 5, 7, 9)  # trivial reshape
rand(3,4,7,1) ⊠ rand(4,5,7,9) |> size  # (3, 5, 7, 9)
rand(3,4,1,9) ⊠ rand(4,5,7,9) |> size  # (3, 5, 7, 9)
rand(3,4,1,9) ⊠ rand(4,5,7,1) |> size  # (3, 5, 7, 9)

rand(3,4,7,9,11) ⊠ rand(4,5,7,9,11) |> size  # (3, 5, 7, 9, 11)  # trivial reshape
rand(3,4,1,9,1)  ⊠ rand(4,5,7,1,11) |> size  # (3, 5, 7, 9, 11)
```

Notice that there are two levels of difficulty here: all but the last still have a regular stride across batches. Whether that's sufficient for the fused strided_batched routine can be seen from the strides:

```julia
julia> rand(3,4,1,9) |> strides  # batch stride 12
(1, 3, 12, 12)

julia> rand(3,4,7,1) |> strides  # batch stride 12 only
(1, 3, 12, 84)

julia> rand(3,4,7,9,11) |> strides  # batch stride 12 only
(1, 3, 12, 84, 756)

julia> rand(3,4,7,1,11) |> strides  # irregular batch stride, 12 and 84
(1, 3, 12, 84, 84)
```

The completely general case is of course not so hard to write as a loop over slices. But on the GPU, at least for ndims=3, the fused strided_batched routine is much quicker than such a loop. So ndims>3 cases which can't be written as one call probably want to be written as a loop over strided_batched calls.

Edit: But there is also the non-strided batched routine, which takes a vector of pointers rather than one fixed stride. Possibly using that for cases which don't fit strided_batched is better than inventing something? Some discussion of these options here:
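To make the "loop over slices" fallback concrete, here is a minimal CPU sketch of the fully general rule, assuming the size-1 broadcasting semantics above; `batched_mul_general` is a made-up name, not NNlib's implementation, and for brevity it does not check that non-1 batch sizes actually agree:

```julia
using LinearAlgebra  # for mul!

# Loop over all batch indices, indexing 1 into any batch dimension of size 1
# (the broadcasting rule), and multiply the 2D slices in place.
function batched_mul_general(A::AbstractArray{T,N}, B::AbstractArray{T,N}) where {T,N}
    size(A,2) == size(B,1) || throw(DimensionMismatch("inner matrix sizes must agree"))
    batch = ntuple(i -> max(size(A, i+2), size(B, i+2)), N-2)
    C = similar(A, size(A,1), size(B,2), batch...)
    for I in CartesianIndices(batch)
        IA = CartesianIndex(ntuple(i -> size(A, i+2) == 1 ? 1 : I[i], N-2))
        IB = CartesianIndex(ntuple(i -> size(B, i+2) == 1 ? 1 : I[i], N-2))
        mul!(view(C, :, :, I), view(A, :, :, IA), view(B, :, :, IB))
    end
    return C
end

batched_mul_general(rand(3,4,1,9), rand(4,5,7,1)) |> size  # (3, 5, 7, 9)
```

On the GPU one would presumably want every regularly-strided chunk to go through the fused strided_batched call, with a loop like this only over whatever remains.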
Currently NNlib.batched_mul works with arrays of up to 3 dimensions. It would be nice if this could be upgraded to function similarly to numpy.matmul or torch.matmul - this would help in a lot of models, especially some of the attention-based ones I'm working on 😅
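As a rough illustration of what the requested behaviour would enable (shapes and variable names here are purely illustrative), an attention-style score computation could then batch over both the head and batch dimensions in one call:

```julia
using NNlib

# Attention-style shapes, matrices in the first two dims, (heads, batch) trailing:
Q = rand(Float32, 10, 64, 8, 32)   # (query length, feature dim, heads, batch)
K = rand(Float32, 64, 12, 8, 32)   # (feature dim, key length, heads, batch)

# Current workaround: fold (heads, batch) into one dim, multiply, unfold again.
scores = reshape(batched_mul(reshape(Q, 10, 64, :), reshape(K, 64, 12, :)), 10, 12, 8, 32)
size(scores)  # (10, 12, 8, 32)

# With the proposed extension this would simply be:  scores = Q ⊠ K
```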