
Batched multiplication support for ndims > 3 #391

Open
theabhirath opened this issue Feb 20, 2022 · 3 comments · Fixed by #455

@theabhirath (Member)

Currently NNlib.batched_mul works with arrays of up to 3 dimensions. It would be nice if this could be upgraded to behave like numpy.matmul or torch.matmul - this would help in a lot of models, especially some of the attention-based ones I'm working on 😅

@mcabbott (Member)

The first question is what rules it should obey. batched_mul is always matrix-matrix multiplication, and does this:

julia> rand(3,4,7) ⊠ rand(4,5,7) |> size
(3, 5, 7)

julia> rand(3,4,1) ⊠ rand(4,5,7) |> size
(3, 5, 7)

julia> rand(3,4,7) ⊠ rand(4,5,1) |> size
(3, 5, 7)

julia> rand(3,4,7) ⊠ rand(4,5) |> size  # same, as size(B,3)==1
(3, 5, 7)

The easy extension would be this one, but should anything else be allowed?

rand(3,4,7,9) ⊠ rand(4,5,7,9) |> size # (3, 5, 7, 9)

I think the 3D cases match what CUBLAS.gemm_strided_batched! handles. But more exotic things could be done with a loop of course.
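
For that easy extension, a minimal sketch (batched_mul4 is a hypothetical name, not NNlib API) is to fold the trailing batch dimensions into one and reuse the existing 3D routine:

using NNlib

# Sketch only: handle the equal-batch-shape 4D case by reshaping the two
# trailing batch dimensions into one and calling the existing 3D batched_mul.
function batched_mul4(A::AbstractArray{T,4}, B::AbstractArray{T,4}) where {T}
    size(A)[3:4] == size(B)[3:4] ||
        throw(DimensionMismatch("this sketch only handles equal batch dims"))
    A3 = reshape(A, size(A, 1), size(A, 2), :)
    B3 = reshape(B, size(B, 1), size(B, 2), :)
    C3 = batched_mul(A3, B3)        # existing 3D routine
    reshape(C3, size(C3, 1), size(C3, 2), size(A, 3), size(A, 4))
end

batched_mul4(rand(3,4,7,9), rand(4,5,7,9)) |> size  # (3, 5, 7, 9)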

@theabhirath (Member, Author)

This is what torch and numpy have to say:

If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.

which is exactly what you've suggested (except that Julia keeps the matrices in the first two dimensions). I think this is a straightforward way to go about it: since other frameworks do it the same way, it is expected behaviour in a sense.
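
In Julia terms (matrices in the first two dimensions), that rule could be written roughly like this; batch_output_size is a hypothetical helper, just to pin down the shape semantics:

# Sketch of the proposed broadcasting rule: batch dims beyond the first two
# must match or be 1, and the output takes the larger of the two.
function batch_output_size(A::AbstractArray, B::AbstractArray)
    nd = max(ndims(A), ndims(B))
    ntuple(nd - 2) do k
        a, b = size(A, k + 2), size(B, k + 2)
        a == b || a == 1 || b == 1 ||
            throw(DimensionMismatch("batch dim $k: sizes $a and $b"))
        max(a, b)
    end
end

batch_output_size(rand(3,4,1,9), rand(4,5,7,1))  # (7, 9)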

@mcabbott (Member) commented Feb 20, 2022

The rule you suggest is, I think, this:

rand(3,4,7,9) ⊠ rand(4,5,7,9) |> size # (3, 5, 7, 9)  # trivial reshape
rand(3,4,7,1) ⊠ rand(4,5,7,9) |> size # (3, 5, 7, 9)
rand(3,4,1,9) ⊠ rand(4,5,7,9) |> size # (3, 5, 7, 9)
rand(3,4,1,9) ⊠ rand(4,5,7,1) |> size # (3, 5, 7, 9)

rand(3,4,7,9,11) ⊠ rand(4,5,7,9,11) |> size # (3, 5, 7, 9, 11)  # trivial reshape
rand(3,4,1,9,1) ⊠ rand(4,5,7,1,11) |> size  # (3, 5, 7, 9, 11)

Notice that there are two levels of difficulty here: all but the last still have a regular stride across batches. Whether that's sufficient for gemm_strided_batched! I don't recall:

julia> rand(3,4,1,9) |> strides  # batch stride 12
(1, 3, 12, 12)

julia> rand(3,4,7,1) |> strides  # batch stride 12 only
(1, 3, 12, 84)

julia> rand(3,4,7,9,11) |> strides  # batch stride 12 only
(1, 3, 12, 84, 756)

julia> rand(3,4,7,1,11) |> strides  # irregular batch stride, 12 and 84
(1, 3, 12, 84, 84)
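
To make that concrete, one could look at the stride each operand would need along each output batch dimension (effective_batch_strides is a made-up helper; 0 marks a broadcast dimension, and a single gemm_strided_batched! call needs these to collapse to one constant stride):

# Hypothetical helper: effective per-batch-dim strides of one operand when
# iterated over the full output batch shape. Assumes ndims(A) == length(batch) + 2.
function effective_batch_strides(A::AbstractArray, batch::Tuple)
    st = strides(A)
    ntuple(length(batch)) do k
        size(A, k + 2) == 1 && batch[k] != 1 ? 0 : st[k + 2]
    end
end

effective_batch_strides(rand(3,4,1,9), (7, 9))        # (0, 12): one real stride
effective_batch_strides(rand(3,4,7,1,11), (7, 9, 11)) # (12, 0, 84): mixes 12 and 84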

The completely general case is of course not so hard to write as a loop over gemm!, which is all the CPU implementation is anyway. Probably you can hack the broadcast machinery to do the indexing for you.
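
For the record, that general loop might look like this on the CPU (batched_mul_nd is a hypothetical name; the sketch assumes both arrays already have the same ndims):

using LinearAlgebra

# Sketch of a general fallback, not NNlib's implementation: loop over the
# broadcast batch indices and call mul! (i.e. gemm!) on each matrix slice.
function batched_mul_nd(A::AbstractArray{T}, B::AbstractArray{T}) where {T}
    batch = map(max, size(A)[3:end], size(B)[3:end])   # assumes ndims(A) == ndims(B)
    C = similar(A, size(A, 1), size(B, 2), batch...)
    for I in CartesianIndices(batch)
        iA = min.(Tuple(I), size(A)[3:end])             # size-1 dims broadcast
        iB = min.(Tuple(I), size(B)[3:end])
        mul!(view(C, :, :, Tuple(I)...), view(A, :, :, iA...), view(B, :, :, iB...))
    end
    C
end

batched_mul_nd(rand(3,4,1,9,1), rand(4,5,7,1,11)) |> size  # (3, 5, 7, 9, 11)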

But on the GPU, at least for ndims=3, the fused strided_batched routine is much quicker than a loop. So ndims>3 cases which can't be written as one call probably want to be written as a loop over gemm_strided_batched! calls, not over gemm! calls?

Edit: gemm_strided_batched! is here:
https://github.com/JuliaGPU/CUDA.jl/blob/f81cdf7484842889f5adfcaeb60436f5ebfb513a/lib/cublas/wrappers.jl#L1036

But there is also gemm_batched! here:
https://github.com/JuliaGPU/CUDA.jl/blob/f81cdf7484842889f5adfcaeb60436f5ebfb513a/lib/cublas/wrappers.jl#L974

Possibly using that for cases which don't fit the strided_batched pattern is better than inventing something? Some discussion of these options here:
https://developer.nvidia.com/blog/cublas-strided-batched-matrix-multiply/
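
For reference, a minimal usage sketch of the fused call, assuming the gemm_strided_batched! signature in the wrappers.jl linked above (worth checking against the installed CUDA.jl version):

using CUDA

# Usage sketch of the fused batched call: one launch multiplies all 7 pairs.
A = CUDA.rand(Float32, 3, 4, 7)
B = CUDA.rand(Float32, 4, 5, 7)
C = CUDA.zeros(Float32, 3, 5, 7)
CUDA.CUBLAS.gemm_strided_batched!('N', 'N', 1f0, A, B, 0f0, C)
size(C)  # (3, 5, 7)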
