Skip to content

Commit

Permalink
address comments and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
SamGinzburg committed Nov 6, 2024
1 parent f1b904c commit 70d72e8
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 2 deletions.
2 changes: 2 additions & 0 deletions test/TritonGPU/amd/optimize-lds-usage.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-war
}
}

// -----

// Checks that the optimization does not crash on a 1D tensor
// CHECK-LABEL: convert_1d
// CHECK: triton_gpu.local_alloc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class OptimizeAMDLDSUsage

// Special case for rank == 1
if (rank == 1) {
threadsPerWarp[0] = warpSize;
threadsPerWarp[0] = warpSize / 8;
} else {
assert(rank > 1);
threadsPerWarp[rank - 1] = warpSize / 8;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ Attribute createTmpLayout(Attribute layout, ArrayRef<unsigned> warpsPerCTA) {
src.getKWidth());
}
if (auto src = dyn_cast<triton::gpu::SliceEncodingAttr>(layout)) {
// TODO: think of a way to construct slice layouts based on warpsPerCTA argument
// TODO: think of a way to construct slice layouts based on warpsPerCTA
// argument
auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(src.getParent());
return triton::gpu::SliceEncodingAttr::get(
ctx, src.getDim(), createTmpLayout(src.getParent(), parentWarpsPerCTA));
Expand Down

0 comments on commit 70d72e8

Please sign in to comment.