forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[BACKEND][NVIDIA] Add DotOp Hoisting Pass for WGMMA and Add Lowering …
…for SMEM-to-MMAv3 DotOp Copy (triton-lang#5003) Hopper has two kinds of WGMMAs, "SS" (both operands in shmem) and "RS" (LHS operand A in registers). In cases where we apply elementwise operations on A before WGMMA, Triton previously will copy A from global memory (GMEM) into registers (RF), perform the elementwise ops, and then copy to shared memory (SMEM) to perform SS WGMMA. This PR adds an optimization for the case above to use RS GEMM. This requires the following changes: - In TritonGPU OptimizeDotOperands pass, add optimizations to change SS GEMM into RS GEMM. - Add TritonGPU -> LLVM lowering for copying from SMEM to RF in MMA v3 dotOperand layout. NOTE: This may not see perf gain, and may even see perf loss, for certain shapes (e.g. small-K), and additional optimizations are in a separate [PR](openxla#19) (still more optimizations are WIP). Please advise on the merging strategy.
- Loading branch information
1 parent
bd483c5
commit 6b092ae
Showing
9 changed files
with
512 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.