From cf2ad02324fc253970c3ab2666e775406405f213 Mon Sep 17 00:00:00 2001 From: Giuseppe Rossini Date: Fri, 21 Jun 2024 16:08:08 +0100 Subject: [PATCH] [AMD] Disable block merging to avoid block argument explosion (#4176) This PR disable block merging when running `convert-builtin-func-to-llvm`. The reason behind this is that for now block merging can double the arguments of the blocks. This means that after a while we can start witnessing a block argument "explosion" which hangs the compiler. I am working on this ticket: https://github.com/llvm/llvm-project/issues/63230 to make block merging better, but in the meantime, we should stop merging blocks to avoid compiler hangs. I added the minimal test to reproduce the explosion. The test for now is checking that we don't try to merge blocks. --- .../amd/amd-convert-builtin-func.mlir | 63 +++++++++++++++++++ .../TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp | 9 ++- 2 files changed, 71 insertions(+), 1 deletion(-) create mode 100644 test/Conversion/amd/amd-convert-builtin-func.mlir diff --git a/test/Conversion/amd/amd-convert-builtin-func.mlir b/test/Conversion/amd/amd-convert-builtin-func.mlir new file mode 100644 index 000000000000..9df0059a52a5 --- /dev/null +++ b/test/Conversion/amd/amd-convert-builtin-func.mlir @@ -0,0 +1,63 @@ +// RUN: triton-opt --convert-builtin-func-to-llvm %s | FileCheck %s + +// Trying to merge those blocks will cause a lot of duplication in the block arguments, which will cause +// an exponential growth of the argument length. Make sure we don't try to merge those blocks. +module { + llvm.func @rand() -> i1 + llvm.func @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(!llvm.ptr<1>, i32, i1) attributes {libname = "", libpath = ""} + + llvm.func @top(%arg0: i64, %1 : !llvm.ptr<1>, %2 : !llvm.ptr<1>, %3 : !llvm.ptr<1>, %4 : !llvm.ptr<1>) { + %0 = llvm.mlir.constant(0 : i64) : i64 + %10 = llvm.icmp "eq" %arg0, %0 : i64 + %true = llvm.mlir.constant(1 : i1) : i1 + %c = llvm.mlir.constant(1 : i32) : i32 + // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}} + llvm.cond_br %10, ^bb1, ^bb14 + ^bb1: // pred: ^bb0 + %11 = llvm.call @rand() : () -> i1 + // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}} + llvm.cond_br %11, ^bb2, ^bb3 + ^bb2: // pred: ^bb1 + llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%1, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () + llvm.br ^bb4 + ^bb3: // pred: ^bb1 + llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%2, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () + llvm.br ^bb4 + ^bb4: // 2 preds: ^bb2, ^bb3 + %14 = llvm.call @rand() : () -> i1 + // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}} + llvm.cond_br %14, ^bb5, ^bb6 + ^bb5: // pred: ^bb4 + llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%3, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () + llvm.br ^bb13 + ^bb6: // pred: ^bb4 + llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%4, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () + llvm.br ^bb13 + ^bb13: // 2 preds: ^bb11, ^bb12 + llvm.br ^bb27 + ^bb14: // pred: ^bb0 + %23 = llvm.call @rand() : () -> i1 + // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}} + llvm.cond_br %23, ^bb15, ^bb16 + ^bb15: // pred: ^bb14 + llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%4, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () + llvm.br ^bb17 + ^bb16: // pred: ^bb14 + llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%3, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () + llvm.br ^bb17 + ^bb17: // 2 preds: ^bb15, ^bb16 + %26 = llvm.call @rand() : () -> i1 + // CHECK: llvm.cond_br {{.*}}, ^bb{{.*}}, ^bb{{.*}} + llvm.cond_br %26, ^bb18, ^bb19 + ^bb18: // pred: ^bb17 + llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%2, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () + llvm.br ^bb26 + ^bb19: // pred: ^bb17 + llvm.call @"__predicated_store_!llvm.void_!llvm.ptr<1>_i32_i1_"(%1, %c, %true) : (!llvm.ptr<1>, i32, i1) -> () + llvm.br ^bb26 + ^bb26: // 2 preds: ^bb24, ^bb25 + llvm.br ^bb27 + ^bb27: // 2 preds: ^bb13, ^bb26 + llvm.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp index 9bf8e89d3ef1..b46b0fc82741 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp @@ -165,10 +165,17 @@ struct ConvertBuiltinFuncToLLVM MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); + // Disable block merging because of: + // https://github.com/llvm/llvm-project/issues/63230 + // TODO(giuseros): enable block merging once the above ticket is completed + GreedyRewriteConfig config; + config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal; + RewritePatternSet patterns(context); patterns.add(context); - if (mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) { + if (mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns), config) + .failed()) { signalPassFailure(); } }