[AMD] Fix issue with rank=1 in tryFitCvtIntoLDS #5084

Merged (4 commits) on Nov 7, 2024
Changes from all commits
19 changes: 19 additions & 0 deletions test/TritonGPU/amd/optimize-lds-usage.mlir
@@ -118,3 +118,22 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
tt.return
}
}

// -----

// Checks that the optimization does not crash on a 1d tensor
// CHECK-LABEL: convert_1d
// CHECK: triton_gpu.local_alloc
// CHECK-NEXT: triton_gpu.convert_layout
// CHECK-NEXT: triton_gpu.local_load
#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @convert_1d(%arg0: tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>>, %arg1: tensor<128x128xf32, #mma>) attributes {noinline = false} {
%alloc = triton_gpu.local_alloc %arg1 : (tensor<128x128xf32, #mma>) -> !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory>
%1 = triton_gpu.convert_layout %arg0 : tensor<128xf32, #triton_gpu.slice<{dim = 0, parent = #mma}>> -> tensor<128xf32, #blocked>
%load = triton_gpu.local_load %alloc : !tt.memdesc<128x128xf32, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma>
tt.return
}
}
16 changes: 13 additions & 3 deletions third_party/amd/lib/TritonAMDGPUToLLVM/OptimizeLDSUsage.cpp
@@ -96,7 +96,8 @@ class OptimizeAMDLDSUsage
auto dstEnc = dstType.getEncoding();

auto ctx = srcEnc.getContext();
auto rank = srcType.getShape().size();
auto rank = srcType.getRank();

unsigned numWarps = triton::gpu::getNumWarpsPerCTA(srcEnc);
auto warpSize = triton::gpu::getWarpSize(srcEnc);

@@ -109,11 +110,20 @@
// Create a list of temporary layouts
SmallVector<unsigned> elemsPerThread(rank, 1);
SmallVector<unsigned> threadsPerWarp(rank, 1);
threadsPerWarp[rank - 1] = warpSize / 8;
threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];

// Special case for rank == 1
if (rank == 1) {
threadsPerWarp[0] = warpSize;
} else {
assert(rank > 1);
threadsPerWarp[rank - 1] = warpSize / 8;
threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
}

auto layoutCTA = triton::gpu::getCTALayout(srcEnc);
auto order = triton::gpu::getOrder(srcEnc);
SmallVector<unsigned> dummyWarpsPerCTA(rank, 1);

auto baseFallbackLayout = triton::gpu::BlockedEncodingAttr::get(
ctx, elemsPerThread, threadsPerWarp, dummyWarpsPerCTA, order,
layoutCTA);
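For reference, here is a minimal standalone sketch of the candidate threadsPerWarp selection above, written as plain C++ for illustration only (the pickThreadsPerWarp helper and the use of std::vector instead of llvm::SmallVector are assumptions, not Triton code). It shows why rank 1 needed a special case: the unconditional rank - 2 index is out of bounds for a rank-1 tensor, while the fixed logic yields [64] for rank 1 and, for example, [8, 8] for rank 2 with a warp size of 64.

#include <cassert>
#include <cstdio>
#include <vector>

// Sketch only: mirrors the rank handling added in tryFitCvtIntoLDS above.
std::vector<unsigned> pickThreadsPerWarp(unsigned rank, unsigned warpSize) {
  std::vector<unsigned> threadsPerWarp(rank, 1);
  // The old code did this unconditionally:
  //   threadsPerWarp[rank - 1] = warpSize / 8;
  //   threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
  // For rank == 1 the index rank - 2 is out of bounds, which is the failure
  // mode this PR guards against.
  if (rank == 1) {
    threadsPerWarp[0] = warpSize;
  } else {
    assert(rank > 1);
    threadsPerWarp[rank - 1] = warpSize / 8;
    threadsPerWarp[rank - 2] = warpSize / threadsPerWarp[rank - 1];
  }
  return threadsPerWarp;
}

int main() {
  auto r1 = pickThreadsPerWarp(1, 64); // AMD wavefront size 64 -> [64]
  auto r2 = pickThreadsPerWarp(2, 64); // -> [8, 8]
  std::printf("rank 1: [%u]\nrank 2: [%u, %u]\n", r1[0], r2[0], r2[1]);
  return 0;
}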
@@ -68,9 +68,13 @@ Attribute createTmpLayout(Attribute layout, ArrayRef<unsigned> warpsPerCTA) {
ctx, src.getOpIdx(), createTmpLayout(src.getParent(), warpsPerCTA),
src.getKWidth());
}
if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout))
if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
// TODO: think of a way to construct slice layouts based on warpsPerCTA
// argument
auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent());
Contributor Author:

I think that for sliced layouts we should use the parent warpsPerCTA (since the rank of the parent is used in later layout conversions), but I would like to confirm this.

This fixes the pass for rank == 1. Without this we hit asserts in the LL conversions.

@binarman (Contributor), Nov 6, 2024:

Right, this code was wrong, because the parent layout constructor expects rank + 1 dimensions.

On the other hand, using the parent warpsPerCTA defies the idea of the algorithm above, which tries different temporary tensors: using the parent layout's warps here will return exactly the same layout no matter what is passed in the warpsPerCTA argument.

I cannot think of an elegant solution right now, so I think this is a good workaround to make things work. I would suggest adding a TODO here, like // TODO: think of a way to construct slice layouts based on warpsPerCTA argument.

Contributor Author:

Thanks for the review!

Yeah, I considered a few alternate approaches and understand the dilemma here. I couldn't figure out an obvious solution either.

Will update the PR with the feedback in a sec after local testing.

return triton::gpu::SliceEncodingAttr::get(
ctx, src.getDim(), createTmpLayout(src.getParent(), warpsPerCTA));
ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA));
}
assert("Encountered unsupported layout");
return Attribute();
}
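To make the review discussion concrete, here is a toy sketch of the rank mismatch, using invented stand-in types (ToyBlocked, ToySlice, and the rebuild helpers are assumptions for illustration, not Triton's classes). A slice layout has rank N while its parent has rank N + 1, so forwarding the slice-sized warpsPerCTA to the parent constructor produces a layout of the wrong rank; reusing the parent's own warpsPerCTA fixes the rank but effectively ignores the warpsPerCTA argument, which is the trade-off the suggested TODO records.

#include <cassert>
#include <cstddef>
#include <vector>

struct ToyBlocked {
  std::vector<unsigned> warpsPerCTA; // one entry per dimension
  std::size_t rank() const { return warpsPerCTA.size(); }
};

struct ToySlice {
  unsigned dim;      // dimension removed from the parent
  ToyBlocked parent; // parent.rank() == slice rank + 1
  std::size_t rank() const { return parent.rank() - 1; }
};

// Old behaviour: forward the slice-rank warpsPerCTA to the parent constructor.
ToyBlocked rebuildParentBroken(const ToySlice &s,
                               const std::vector<unsigned> &warpsPerCTA) {
  // warpsPerCTA.size() == s.rank(), but the parent needs s.rank() + 1 entries;
  // this is the mismatch that tripped asserts in later layout conversions.
  return ToyBlocked{warpsPerCTA};
}

// Workaround from this PR: reuse the parent's own warpsPerCTA, accepting that
// the warpsPerCTA argument is effectively ignored for slice layouts.
ToyBlocked rebuildParentWorkaround(const ToySlice &s,
                                   const std::vector<unsigned> & /*warpsPerCTA*/) {
  return ToyBlocked{s.parent.warpsPerCTA};
}

int main() {
  ToySlice slice{/*dim=*/0, ToyBlocked{{4, 1}}}; // rank-1 slice of a rank-2 parent
  std::vector<unsigned> sliceWarps(slice.rank(), 1); // sized for the slice: 1 entry

  ToyBlocked bad = rebuildParentBroken(slice, sliceWarps);
  assert(bad.rank() != slice.parent.rank()); // rank 1 vs. the expected rank 2

  ToyBlocked ok = rebuildParentWorkaround(slice, sliceWarps);
  assert(ok.rank() == slice.parent.rank()); // rank matches again
  return 0;
}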