[AMD] Fix issue with rank=1 in tryFitCvtIntoLDS #5084

Merged: 4 commits merged into triton-lang:main on Nov 7, 2024

Conversation

@SamGinzburg (Contributor) commented on Nov 6, 2024:

This is a draft patch that fixes an assertion being hit when the wrong number of pipeline stages is set.

More discussion at: pytorch/pytorch#139621

============================================================
The core Triton team is a small number of people, and we receive many PRs (thank you!). To help us review your code more quickly, if you are a new contributor (fewer than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.

Complete the following tasks before sending your PR, and replace [ ] with
[x] to indicate you have done them.

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • [x] I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • [x] I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@SamGinzburg (Contributor, Author) commented:

More specifically, this is the ttgir that appears to trigger this assert:

#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>

%162 = triton_gpu.convert_layout %161 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128xf32, #blocked3>
tt.store %159, %162 : tensor<128x!tt.ptr<f32>, #blocked3>

@@ -68,9 +68,11 @@ Attribute createTmpLayout(Attribute layout, ArrayRef<unsigned> warpsPerCTA) {
         ctx, src.getOpIdx(), createTmpLayout(src.getParent(), warpsPerCTA),
         src.getKWidth());
   }
-  if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout))
+  if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
+    auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent());

@SamGinzburg (Contributor, Author) commented on this change:

I think that for sliced layouts we should use the parent warpsPerCTA (since the rank of the parent is used in later layout conversions), but I would like to confirm this.

This fixes the pass for rank==1. Without it, we hit asserts in the LinearLayout (LL) conversions.
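
For context, here is a minimal sketch of the branch this comment describes in createTmpLayout. Only the first two lines of the branch are visible in the diff above, so the SliceEncodingAttr::get call and the recursion back into createTmpLayout are assumptions inferred from the surrounding code, not the verbatim patch:

```cpp
// Sketch, not the verbatim patch: rebuild the slice layout around a temporary
// parent layout, using the parent's own warpsPerCTA so that the rank matches.
if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
  // The sliced tensor has rank N, but its parent layout has rank N + 1, so the
  // parent's warpsPerCTA (not the rank-N warpsPerCTA argument) is what the
  // parent layout construction expects.
  auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent());
  return triton::gpu::SliceEncodingAttr::get(
      ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA));
}
```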

@binarman (Contributor) replied on Nov 6, 2024:

Right, this code was wrong, because the parent layout constructor expects rank + 1 dimensions.

On the other hand, using the parent warpsPerCTA defeats the idea of the algorithm above, which tries different temporary tensors: using the parent layout's warps here will return exactly the same layout no matter what is passed in the warpsPerCTA argument.

I cannot think of an elegant solution right now, so I think this is a good workaround to make things work. I would suggest adding a TODO here, like // TODO: think of a way to construct slice layouts based on warpsPerCTA argument.
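
If that suggestion is taken, the branch sketched earlier would presumably just gain the marker at its top (the placement and the surrounding lines are assumptions; only the TODO text itself comes from the comment above):

```cpp
if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
  // TODO: think of a way to construct slice layouts based on warpsPerCTA argument
  auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent());
  return triton::gpu::SliceEncodingAttr::get(
      ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA));
}
```

For the reproducer above, the mismatch is concrete: the sliced tensor is rank 1 (warpsPerCTA candidates like [4]), while its #mma parent is rank 2 (warpsPerCTA = [4, 1]), which is why the old code tripped the rank assertion.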

@SamGinzburg (Contributor, Author) replied:

Thanks for the review!

Yeah I considered a few alternate approaches and understand the dilemma here. I couldn't figure out an obvious solution either.

Will update the PR with the feedback in a sec after local testing.

@SamGinzburg force-pushed the PR-AMDLdsOptRankCheck branch from d16e296 to 3a8f97b on November 6, 2024 at 17:02

@binarman (Contributor) left a comment:

@SamGinzburg hi! Thanks for the patch, can you also add the following test to the bottom of the optimize-lds-usage.mlir test suite?

// Checks that the optimization does not crash on a 1d tensor
// CHECK-LABEL: convert_1d
// CHECK: triton_gpu.local_alloc
// CHECK-NEXT: triton_gpu.convert_layout
// CHECK-NEXT: triton_gpu.local_load
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
  tt.func public @convert_1d(%arg0: tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} {
    %alloc = triton_gpu.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory>
    %1 = triton_gpu.convert_layout %arg0 : tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked>
    %load = triton_gpu.local_load %alloc : !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma>
    tt.return
  }
}

P.s. Related closed PR: #4651

@SamGinzburg force-pushed the PR-AMDLdsOptRankCheck branch from 70d72e8 to 6e49d28 on November 7, 2024 at 00:08
@SamGinzburg marked this pull request as ready for review on November 7, 2024 at 18:35

@binarman (Contributor) left a comment:

LGTM!

@antiagainst merged commit 4af6cf5 into triton-lang:main on Nov 7, 2024. 7 checks passed.
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
jataylo pushed a commit to jataylo/triton that referenced this pull request Nov 18, 2024
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 12, 2024
bertmaher pushed a commit that referenced this pull request Dec 16, 2024
bertmaher added a commit to pytorch/pytorch that referenced this pull request Dec 16, 2024
bertmaher added a commit to pytorch/pytorch that referenced this pull request Dec 16, 2024
jataylo pushed a commit to jataylo/triton that referenced this pull request Dec 18, 2024
bertmaher pushed a commit that referenced this pull request Dec 19, 2024
atalman pushed a commit that referenced this pull request Dec 23, 2024