[AMD] Fix issue with rank=1 in tryFitCvtIntoLDS #5084
Conversation
More specifically, this is the ttgir that appears to trigger this assert:

[ttgir snippet omitted]
```diff
@@ -68,9 +68,11 @@ Attribute createTmpLayout(Attribute layout, ArrayRef<unsigned> warpsPerCTA) {
         ctx, src.getOpIdx(), createTmpLayout(src.getParent(), warpsPerCTA),
         src.getKWidth());
   }
-  if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout))
+  if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
+    auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent());
```
I think that for sliced layouts we should use the parent warpsPerCTA (since the rank of the parent is used in later layout conversions), but I would like to confirm this.
This fixes the pass for rank==1. Without this we hit asserts in the LL (LinearLayout) conversions.
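For concreteness, a sketch of what the fixed branch looks like after this change; the return statement is inferred from the diff above rather than copied verbatim from the patch:

```cpp
// Sketch of the fixed slice-layout branch in createTmpLayout (the return
// statement is inferred, not verbatim). The parent encoding has rank + 1
// dimensions, so the rank-sized `warpsPerCTA` argument would be one element
// short for it; reuse the parent's own warp distribution instead.
if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
  auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent());
  return triton::gpu::SliceEncodingAttr::get(
      ctx, src.getDim(),
      createTmpLayout(src.getParent(), parentWarpsPerCTA));
}
```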
Right, this code was wrong because the parent layout constructor expects rank + 1 dimensions.
On the other hand, using the parent's warpsPerCTA defies the idea of the algorithm above, which tries different temporary tensors: using the parent layout's warps here will return exactly the same layout no matter what is passed in the warpsPerCTA argument.
I cannot think of an elegant solution right now, so I think this is a good workaround to make things work.
I would suggest adding a TODO here, like `// TODO: think of a way to construct slice layouts based on warpsPerCTA argument`.
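To make the dilemma concrete, here is an illustrative sketch of the kind of search I mean; `candidateWarpsPerCTA`, `fitsInLDS`, and the surrounding loop are made-up names for this example, not the actual pass code:

```cpp
// Illustrative only (hypothetical names): the surrounding algorithm tries
// several candidate warp distributions and keeps the first temporary layout
// whose conversion fits into LDS. If the slice branch ignores `candidate`
// and always uses the parent's warpsPerCTA, every iteration produces the
// same temporary layout for sliced inputs, so the search degenerates.
for (const SmallVector<unsigned> &candidate : candidateWarpsPerCTA) {
  Attribute tmpLayout = createTmpLayout(srcLayout, candidate);
  if (fitsInLDS(tmpLayout, dstLayout))
    return tmpLayout;
}
```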
Thanks for the review!
Yeah I considered a few alternate approaches and understand the dilemma here. I couldn't figure out an obvious solution either.
Will update the PR with the feedback in a sec after local testing.
Force-pushed from d16e296 to 3a8f97b.
@SamGinzburg hi! Thanks for the patch. Can you also add the following test to the bottom of the optimize-lds-usage.mlir test suite?
```mlir
// Checks that the optimization does not crash on a 1D tensor
// CHECK-LABEL: convert_1d
// CHECK: triton_gpu.local_alloc
// CHECK-NEXT: triton_gpu.convert_layout
// CHECK-NEXT: triton_gpu.local_load
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @convert_1d(%arg0: tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} {
    %alloc = triton_gpu.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory>
    %1 = triton_gpu.convert_layout %arg0 : tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked>
    %load = triton_gpu.local_load %alloc : !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma>
    tt.return
  }
}
```
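For context on why this test exercises the rank-1 path (my reading of the encodings above, not something stated elsewhere in the thread): the tensor is 1D, but its slice layout's parent `#mma` is 2D, so any warp vector sized to the tensor's rank is one entry short for the parent's constructor. A self-contained toy model of that mismatch, using stand-in types rather than Triton's real classes:

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Toy stand-in for a layout constructor that, like Triton's encodings,
// requires one warpsPerCTA entry per dimension.
struct ToyParentLayout {
  std::vector<unsigned> warpsPerCTA;
};

ToyParentLayout makeParent(std::vector<unsigned> warps, unsigned rank) {
  assert(warps.size() == rank && "warpsPerCTA must match parent rank");
  return {std::move(warps)};
}

int main() {
  unsigned sliceRank = 1;              // tensor<128xf32, #slice<...>> is 1D
  unsigned parentRank = sliceRank + 1; // the #mma parent is 2D
  // After the fix, the parent's own {4, 1} is reused: correct entry count.
  makeParent({4, 1}, parentRank);
  // Before the fix, a rank-sized candidate like {4} reached the rank-2
  // parent and tripped the assert:
  // makeParent({4}, parentRank); // would assert
  return 0;
}
```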
P.S. Related closed PR: #4651
Force-pushed from 70d72e8 to 6e49d28.
LGTM!
Co-authored-by: Sam Ginzburg <[email protected]>

Cross-references:
- Cherry-picked from commit 4af6cf5.
- Cherry-picked from commit f9cdf58 via triton-lang/triton#5453.
- Pulled into PyTorch (pytorch/pytorch#143302) together with triton-lang/triton#5277 (ghstack-source-id: 1ea811d7562f0cb8d8dd1f65049518db0e61e39a).
This is a draft of a patch that fixes an assertion hit when the wrong number of pipeline stages is set.
More discussion at: pytorch/pytorch#139621
============================================================
The core Triton is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, if you are a new contributor (less than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.

Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them.

- [ ] I am not making a trivial change, such as fixing a typo in a comment.
- [ ] I have written a PR description following these rules.
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [ ] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)