[AMD] Reland sinking the 2nd tt.load after local_load's #4935
Merged
Conversation
zhanglx13 force-pushed the robust_sched_load branch 3 times, most recently from c341a5f to 748b853 on October 29, 2024 02:13
zhanglx13 force-pushed the robust_sched_load branch 2 times, most recently from 2d6cd3d to d0485f1 on October 29, 2024 03:47
antiagainst requested changes on Oct 29, 2024
third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp (two review threads; outdated, resolved)
antiagainst approved these changes on Oct 29, 2024
antiagainst changed the title from [AMD] Robust sched-2nd-load to [AMD] Reland sinking the 2nd tt.load after local_load's on Oct 29, 2024
This helps the backend interleave global load and MFMA instructions, which can reduce global-load issue latency.
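For reference, the sinking itself is a one-op move in MLIR; a minimal sketch with hypothetical variable names, not the pass's exact code:

```cpp
#include "mlir/IR/Operation.h"

// secondLoad and lastLocalLoad are hypothetical mlir::Operation* handles to
// the 2nd tt.load and the last local_load in the loop body.
void sinkSecondLoad(mlir::Operation *secondLoad, mlir::Operation *lastLocalLoad) {
  // Moving the global load below the local_loads lets the backend interleave
  // it with MFMA instructions, hiding global-load issue latency.
  secondLoad->moveAfter(lastLocalLoad);
}
```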
There are two issues with this function: 1. It returns true when there is no for loop. 2. It cannot detect for loops when there are some. The 2nd bullet is because "getOps<OpTy>() is useful to iterate on some Operations immediately listed inside a single block (or a single region)"; therefore moduleOp.getOps<scf::ForOp>() will always return nothing. Instead, we use a walker here to find scf.ForOps in a nested fashion.
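A quick illustration of the difference, assuming `moduleOp` is an `mlir::ModuleOp` (a sketch, not the pass's code):

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"

bool containsForLoop(mlir::ModuleOp moduleOp) {
  // getOps<scf::ForOp>() iterates only ops immediately inside the module
  // body; the loops are nested inside functions, so this finds nothing:
  //   for (auto forOp : moduleOp.getOps<mlir::scf::ForOp>()) { ... }

  // walk() recurses into nested regions, so it does see the loops.
  bool found = false;
  moduleOp.walk([&](mlir::scf::ForOp forOp) { found = true; });
  return found;
}
```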
This test expects scheduleGlobalLoadLocalStore to move the load earlier. However, we only apply scheduleGlobalLoadLocalStore for pureMatmulProblem, and this test is too long. If we decide to apply the optimization to attention-like kernels in the future, we will add a smaller test.
zhanglx13 force-pushed the robust_sched_load branch from f9501b7 to ca22491 on October 30, 2024 16:42
Luosuu pushed a commit to Luosuu/triton that referenced this pull request on Nov 13, 2024
guacamoleo pushed a commit to guacamoleo/triton that referenced this pull request on Nov 14, 2024
jataylo pushed a commit to jataylo/triton that referenced this pull request on Nov 18, 2024 (twice; cherry picked from commit 4f6f768)
antiagainst added a commit that referenced this pull request on Dec 4, 2024

Cherry pick list:
- #4925
- #5053
- #5019
- #5002
- #4935 (required additional cherry picks #4991 and #4951)
- #4998
- #4925
- #5281
- #5308
- All previous LLVM hash PRs before #5308

Co-authored-by: Ilya V, Lei Zhang, Lixun Zhang, Keren Zhou, Alexander Efimov, Kyle Wang, Jungwook Park, peterbell10, Hongtao Yu
jataylo pushed a commit to jataylo/triton that referenced this pull request on Dec 12, 2024 (cherry picked from commit 4f6f768)
jataylo pushed a commit to jataylo/triton that referenced this pull request on Dec 13, 2024 (cherry picked from commit 4f6f768)
jataylo added a commit to ROCm/triton that referenced this pull request on Dec 13, 2024

* [AMD] Emit vectorized 16-bit float LLVM atomic ops (triton-lang#4925): for 16-bit float operands of tt::AtomicRMWOp, construct only one LLVM::AtomicRMWOp but use a vector of elements. This approach allows generating packed intrinsics and processing 2 elements at once. Added a lit test for the f16 vectorized case. (cherry picked from commit 78c8054)
* [AMD] Restructure ReorderInstructions pass (triton-lang#4998). (cherry picked from commit 86a2ac7)
* [AMD] Support warp-level reduction with DPP (triton-lang#5019): adds support for warp-level reduction with DPP instructions, which can improve performance. See https://gpuopen.com/learn/amd-gcn-assembly-cross-lane-operations/ (cherry picked from commit 21119e3)
* [AMD] Add missing dependency to TritonAMDGPUIR (triton-lang#5053): TritonAMDGPUTransforms now depends on it. (cherry picked from commit 0b443ce)
* [AMD] Use DPP to accelerate 16-bit floats (triton-lang#5072): for unpaired f16 elements, utilize DPP instructions to accelerate atomics. The lowering of `tt::atomicRmwOp(%ptr, %val, %mask)` works as follows: 0. group threads in pairs; the master thread is (tid % 2 == 0); 1. all threads send `%val` to the (tid - 1) thread via `dppUpdateOp shl`, so every master receives the value from its secondary thread; 2. account for parity in the `%mask` value and build control flow accordingly; 3. generate `llvm::atomicRmwOp` in the threads enabled by `%mask`; 4. all threads send the result of the generated operation to the (tid + 1) thread via `dppUpdateOp shl`, so the secondary threads also receive their result. The DPP approach gives a ~5% perf improvement, so it is used when the target arch supports DPP. Signed-off-by: Ilya Veselov. (cherry picked from commit bab3470)
* [AMD] Reland sinking the 2nd tt.load after local_load's (triton-lang#4935): see the PR description below. (cherry picked from commit 4f6f768)

Co-authored-by: Ilya V, Lei Zhang, Kyle Wang, Lixun Zhang
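For intuition, here is a conceptual HIP-flavored sketch of the paired-f16 atomic idea from the triton-lang#5072 entry above. This is an illustration only: the pass emits LLVM IR with DPP intrinsics rather than HIP source, the mask-parity handling is simplified away, and the packed `atomicAdd(__half2*, __half2)` overload is assumed to be available on the target architecture.

```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

// Pair even/odd lanes so one packed atomic covers two f16 elements.
__device__ void atomicAddF16Paired(__half *base, __half val, bool enabled) {
  const unsigned lane = threadIdx.x % warpSize;
  const bool isMaster = (lane % 2u) == 0u; // step 0: group threads in pairs
  // Step 1: each master pulls its partner's value from lane + 1 (the DPP
  // shift in the real lowering; __shfl_down is the portable analogue).
  int partnerBits = __shfl_down(static_cast<int>(__half_as_ushort(val)), 1);
  if (isMaster && enabled) { // steps 2-3: mask check, then one packed atomic
    __half2 packed = __halves2half2(
        val, __ushort_as_half(static_cast<unsigned short>(partnerBits)));
    atomicAdd(reinterpret_cast<__half2 *>(base), packed);
  }
  // Step 4, shifting the returned old value back to the partner lane,
  // is omitted for brevity.
}
```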
jataylo added a commit to jataylo/triton that referenced this pull request on Dec 18, 2024, with the same cherry-pick list as above (cherry picked from commit 2d8093c)
bertmaher pushed a commit that referenced this pull request on Dec 19, 2024, with the same cherry-pick list as above
This PR adds more restrictions on when we apply the sched-load optimization and un-reverts #4823.
We only apply the optimization when all of the following are satisfied (a minimal sketch of these checks follows the list):
1. pureMatmulProblem, i.e. one `tt.dot` in the main loop
2. two `tt.load`s in the main loop
3. the 2nd `tt.load` is ahead of the `tt.dot`
4. the 1st user of the 2nd `tt.load` is after the `tt.dot`
5. the tile size is large enough, i.e. nonKDim >= 128 and kDim >= 64
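Here is a minimal sketch of these gating checks, assuming the loop body is a single block; the helper name is hypothetical, and the tile-size check (condition 5) is elided since the real logic lives in ReorderInstructions.cpp:

```cpp
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

using namespace mlir;

// Returns true when the loop matches conditions 1-4 above (the tile-size
// condition would additionally inspect the tt.dot operand shapes).
static bool shouldSinkSecondLoad(scf::ForOp forOp) {
  SmallVector<triton::LoadOp> loads;
  SmallVector<triton::DotOp> dots;
  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (auto load = dyn_cast<triton::LoadOp>(&op))
      loads.push_back(load);
    else if (auto dot = dyn_cast<triton::DotOp>(&op))
      dots.push_back(dot);
  }
  // Conditions 1 and 2: exactly one tt.dot and two tt.loads in the loop.
  if (dots.size() != 1 || loads.size() != 2)
    return false;
  Operation *dot = dots.front();
  Operation *secondLoad = loads[1];
  // Condition 3: the 2nd tt.load comes before the tt.dot.
  if (!secondLoad->isBeforeInBlock(dot))
    return false;
  // Condition 4: no user of the 2nd tt.load precedes the tt.dot,
  // i.e. its first user is after the dot.
  for (Operation *user : secondLoad->getUsers())
    if (user->getBlock() == dot->getBlock() && user->isBeforeInBlock(dot))
      return false;
  return true;
}
```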