[mlir][nvgpu] Improve `WarpgroupAccumulator` type to simplify IR (#68728)

`WarpgroupAccumulator` (or `!nvgpu.warpgroup.accumulator`) is a type that holds the accumulator matrix used by warp-group level matrix multiplication. A dedicated type is handy here because the matrix is distributed among the threads of the warp-group. However, the current transformations require creating and using multiple `WarpgroupAccumulator` values whenever the GEMM shape is larger than the shape supported by a single `wgmma.mma_async` instruction, which makes the IR dense. This PR improves the handling of the `WarpgroupAccumulator` type in every nvgpu Op that uses it, so a single value can carry the whole accumulator, as the examples below show.

**Example: Current GEMM in NVGPU-IR**

```
// Init
%m1, %m2 = nvgpu.warpgroup.mma.init.accumulator ->
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>

// GEMM
%r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %m1, %m2 {transposeB}:
    !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
    !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
    ->
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>

// Epilogue
nvgpu.warpgroup.mma.store [%r1, %r2] to %sharedMemoryBuffer :
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
    into memref<128x128xf32,3>
```

**Example: This PR simplifies the IR as below:**

```
// Init
%m = nvgpu.warpgroup.mma.init.accumulator ->
    !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>

// GEMM
%r1 = nvgpu.warpgroup.mma %descA, %descB, %m {transposeB}:
    !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
    !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
    !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
    ->
    !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>

// Epilogue
nvgpu.warpgroup.mma.store %r1 to %sharedMemoryBuffer :
    !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
    into memref<128x128xf32,3>
```
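For additional context (not part of this commit), below is a minimal sketch of how the merged accumulator type could be threaded through a K-loop via `scf.for` iter_args. The loop bounds (`%c0`, `%cK`, `%c1`), the per-iteration descriptors `%dA`/`%dB`, and the surrounding setup are hypothetical assumptions for illustration; only the `nvgpu` op forms follow the examples above.

```
// Illustrative sketch (assumed setup): %dA, %dB, %sharedMemoryBuffer,
// %c0, %cK, and %c1 are assumed to be defined elsewhere.
%init = nvgpu.warpgroup.mma.init.accumulator ->
    !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>

// The single accumulator value is carried across iterations as an iter_arg.
%acc = scf.for %k = %c0 to %cK step %c1 iter_args(%iter = %init)
    -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>) {
  // %dA and %dB stand for the shared-memory tile descriptors of step %k
  // (hypothetical; how they are produced is outside this sketch).
  %d = nvgpu.warpgroup.mma %dA, %dB, %iter {transposeB}:
      !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
      !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
      -> !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
  scf.yield %d : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
}

// Epilogue: store the accumulated tile back to shared memory.
nvgpu.warpgroup.mma.store %acc to %sharedMemoryBuffer :
    !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
    into memref<128x128xf32,3>
```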
7 changed files with 177 additions and 158 deletions.