Skip to content

Commit

Permalink
[AMD] Disable block merging to avoid block argument explosion (triton…
Browse files Browse the repository at this point in the history
…-lang#4176)

This PR disables block merging when running
`convert-builtin-func-to-llvm`.

The reason behind this is that for now block merging can double the
arguments of the blocks. This means that after a while we can start
witnessing a block argument "explosion" which hangs the compiler.

I am working on this ticket:
llvm/llvm-project#63230 to make block merging
better, but in the meantime, we should stop merging blocks to avoid
compiler hangs.

I added a minimal test that reproduces the explosion. For now, the test
only checks that we don't try to merge blocks.
  • Loading branch information
giuseros authored Jun 21, 2024
1 parent 7ca6d12 commit cf2ad02
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
63 changes: 63 additions & 0 deletions test/Conversion/amd/amd-convert-builtin-func.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// RUN: triton-opt --convert-builtin-func-to-llvm %s | FileCheck %s

// Trying to merge those blocks will cause a lot of duplication in the block arguments, which will cause
// an exponential growth of the argument length. Make sure we don't try to merge those blocks.
module {
// External declaration returning an i1; its result is used below as the
// condition of each conditional branch.
llvm.func @rand() -> i1
// Declaration of the predicated-store builtin called in every diamond arm.
// NOTE(review): presumably the operands are (address, value, predicate) and
// this is the call the pass under test rewrites - confirm against the pass.
llvm.func @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(!llvm.ptr<1>, i32, i1) attributes {libname = "", libpath = ""}

// The function body is two symmetric halves (starting at ^bb1 and ^bb14),
// each made of two consecutive if/else diamonds. The two arms of every
// diamond perform the same call and differ only in the pointer operand
// (%1..%4), which is the shape that makes the arms candidates for merging.
// All five conditional branches are expected to still be present in the
// output, each with two distinct successors.
llvm.func @top(%arg0: i64, %1 : !llvm.ptr<1>, %2 : !llvm.ptr<1>, %3 : !llvm.ptr<1>, %4 : !llvm.ptr<1>) {
%0 = llvm.mlir.constant(0 : i64) : i64
%10 = llvm.icmp "eq" %arg0, %0 : i64
%true = llvm.mlir.constant(1 : i1) : i1
%c = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}}
llvm.cond_br %10, ^bb1, ^bb14
^bb1: // pred: ^bb0
%11 = llvm.call @rand() : () -> i1
// CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}}
llvm.cond_br %11, ^bb2, ^bb3
^bb2: // pred: ^bb1
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%1, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
llvm.br ^bb4
^bb3: // pred: ^bb1
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%2, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
llvm.br ^bb4
^bb4: // 2 preds: ^bb2, ^bb3
%14 = llvm.call @rand() : () -> i1
// CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}}
llvm.cond_br %14, ^bb5, ^bb6
^bb5: // pred: ^bb4
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%3, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
llvm.br ^bb13
^bb6: // pred: ^bb4
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%4, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
llvm.br ^bb13
^bb13: // 2 preds: ^bb5, ^bb6
llvm.br ^bb27
^bb14: // pred: ^bb0
%23 = llvm.call @rand() : () -> i1
// CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}}
llvm.cond_br %23, ^bb15, ^bb16
^bb15: // pred: ^bb14
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%4, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
llvm.br ^bb17
^bb16: // pred: ^bb14
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%3, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
llvm.br ^bb17
^bb17: // 2 preds: ^bb15, ^bb16
%26 = llvm.call @rand() : () -> i1
// CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}}
llvm.cond_br %26, ^bb18, ^bb19
^bb18: // pred: ^bb17
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%2, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
llvm.br ^bb26
^bb19: // pred: ^bb17
llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%1, %c, %true) : (!llvm.ptr<1>, i32, i1) -> ()
llvm.br ^bb26
^bb26: // 2 preds: ^bb18, ^bb19
llvm.br ^bb27
^bb27: // 2 preds: ^bb13, ^bb26
llvm.return
}
}
9 changes: 8 additions & 1 deletion third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,17 @@ struct ConvertBuiltinFuncToLLVM
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();

// Disable block merging because of:
// https://github.com/llvm/llvm-project/issues/63230
// TODO(giuseros): enable block merging once the above ticket is completed
GreedyRewriteConfig config;
config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;

RewritePatternSet patterns(context);
patterns.add<CallOpConversion>(context);

if (mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) {
if (mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns), config)
.failed()) {
signalPassFailure();
}
}
Expand Down

0 comments on commit cf2ad02

Please sign in to comment.