Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Fix block merging #102038

Merged
merged 9 commits into from
Aug 7, 2024
Merged

[mlir] Fix block merging #102038

merged 9 commits into from
Aug 7, 2024

Conversation

giuseros
Copy link
Contributor

@giuseros giuseros commented Aug 5, 2024

With this PR I am trying to address: #63230.

What changed:

  • While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same Value. The operations operands in the block will point to the argument we already inserted. This needs to happen to all the arguments we pass to the different successors of the parent block
  • After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument.
  • This last simplification clashed with BufferDeallocationSimplification. The reason, I think, is that the two simplifications are clashing. I.e., BufferDeallocationSimplification contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass.

Note-1: I ran all the integration tests (-DMLIR_INCLUDE_INTEGRATION_TESTS=ON) and they passed.
Note-2: I fixed a bug found by @Dinistro in #97697 . The issue was that, when looking for redundant arguments, I was not considering that the block might have already some arguments. So the index (in the block args list) of the i-th newArgument is i+numOfOldArguments.

@giuseros giuseros requested review from Dinistro and removed request for matthias-springer August 5, 2024 18:59
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:linalg mlir mlir:bufferization Bufferization infrastructure labels Aug 5, 2024
@giuseros giuseros requested review from krzysz00 and Mogball August 5, 2024 18:59
@llvmbot
Copy link
Member

llvmbot commented Aug 5, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-bufferization

@llvm/pr-subscribers-mlir-core

Author: Giuseppe Rossini (giuseros)

Changes

With this PR I am trying to address: #63230.

What changed:

  • While merging identical blocks, don't add a block argument if it is "identical" to another block argument. I.e., if the two block arguments refer to the same Value. The operations operands in the block will point to the argument we already inserted. This needs to happen to all the arguments we pass to the different successors of the parent block
  • After merged the blocks, get rid of "unnecessary" arguments. I.e., if all the predecessors pass the same block argument, there is no need to pass it as an argument.
  • This last simplification clashed with BufferDeallocationSimplification. The reason, I think, is that the two simplifications are clashing. I.e., BufferDeallocationSimplification contains an analysis based on the block structure. If we simplify the block structure (by merging and/or dropping block arguments) the analysis is invalid . The solution I found is to do a more prudent simplification when running that pass.

Note-1: I ran all the integration tests (-DMLIR_INCLUDE_INTEGRATION_TESTS=ON) and they passed.
Note-2: I fixed a bug found by @Dinistro in #97697 . The issue was that, when looking for redundant arguments, I was not considering that the block might have already some arguments. So the index (in the block args list) of the i-th newArgument is i+numOfOldArguments.


Patch is 34.79 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/102038.diff

12 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp (+7-2)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+208-2)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir (+12-8)
  • (modified) mlir/test/Dialect/Linalg/detensorize_entry_block.mlir (+3-3)
  • (modified) mlir/test/Dialect/Linalg/detensorize_if.mlir (+29-38)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir (+6-6)
  • (modified) mlir/test/Dialect/Linalg/detensorize_while_pure_cf.mlir (+2-2)
  • (modified) mlir/test/Transforms/canonicalize-block-merge.mlir (+3-3)
  • (modified) mlir/test/Transforms/canonicalize-dce.mlir (+4-4)
  • (modified) mlir/test/Transforms/make-isolated-from-above.mlir (+9-9)
  • (added) mlir/test/Transforms/test-canonicalize-merge-large-blocks.mlir (+192)
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
index 954485cfede3d..5227b22653eef 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp
@@ -463,10 +463,15 @@ struct BufferDeallocationSimplificationPass
                  SplitDeallocWhenNotAliasingAnyOther,
                  RetainedMemrefAliasingAlwaysDeallocatedMemref>(&getContext(),
                                                                 analysis);
+    // We don't want that the block structure changes invalidating the
+    // `BufferOriginAnalysis` so we apply the rewrites witha `Normal` level of
+    // region simplification
+    GreedyRewriteConfig config;
+    config.enableRegionSimplification = GreedySimplifyRegionLevel::Normal;
     populateDeallocOpCanonicalizationPatterns(patterns, &getContext());
 
-    if (failed(
-            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                            config)))
       signalPassFailure();
   }
 };
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 4c0f15bafbaba..3e15018bdb765 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/IR/Block.h"
+#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -16,11 +17,15 @@
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
 
 #include "llvm/ADT/DepthFirstIterator.h"
 #include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
 
 #include <deque>
+#include <iterator>
 
 using namespace mlir;
 
@@ -674,6 +679,95 @@ static bool ableToUpdatePredOperands(Block *block) {
   return true;
 }
 
+/// Prunes the redundant list of new arguments. E.g., if we are passing an
+/// argument list like [x, y, z, x] this would return [x, y, z] and it would
+/// update the `block` (to whom the argument are passed to) accordingly. The new
+/// arguments are passed as arguments at the back of the block, hence we need to
+/// know how many `numOldArguments` were before, in order to correctly replace
+/// the new arguments in the block
+static SmallVector<SmallVector<Value, 8>, 2> pruneRedundantArguments(
+    const SmallVector<SmallVector<Value, 8>, 2> &newArguments,
+    RewriterBase &rewriter, unsigned numOldArguments, Block *block) {
+
+  SmallVector<SmallVector<Value, 8>, 2> newArgumentsPruned(
+      newArguments.size(), SmallVector<Value, 8>());
+
+  if (newArguments.empty())
+    return newArguments;
+
+  // `newArguments` is a 2D array of size `numLists` x `numArgs`
+  unsigned numLists = newArguments.size();
+  unsigned numArgs = newArguments[0].size();
+
+  // Map that for each arg index contains the index that we can use in place of
+  // the original index. E.g., if we have newArgs = [x, y, z, x], we will have
+  // idxToReplacement[3] = 0
+  llvm::DenseMap<unsigned, unsigned> idxToReplacement;
+
+  // This is a useful data structure to track the first appearance of a Value
+  // on a given list of arguments
+  DenseMap<Value, unsigned> firstValueToIdx;
+  for (unsigned j = 0; j < numArgs; ++j) {
+    Value newArg = newArguments[0][j];
+    if (!firstValueToIdx.contains(newArg))
+      firstValueToIdx[newArg] = j;
+  }
+
+  // Go through the first list of arguments (list 0).
+  for (unsigned j = 0; j < numArgs; ++j) {
+    bool shouldReplaceJ = false;
+    unsigned replacement = 0;
+    // Look back to see if there are possible redundancies in list 0. Please
+    // note that we are using a map to annotate when an argument was seen first
+    // to avoid a O(N^2) algorithm. This has the drawback that if we have two
+    // lists like:
+    // list0: [%a, %a, %a]
+    // list1: [%c, %b, %b]
+    // We cannot simplify it, because firstVlaueToIdx[%a] = 0, but we cannot
+    // point list1[1](==%b) or list1[2](==%b) to list1[0](==%c).  However, since
+    // the number of arguments can be potentially unbounded we cannot afford a
+    // O(N^2) algorithm (to search to all the possible pairs) and we need to
+    // accept the trade-off.
+    unsigned k = firstValueToIdx[newArguments[0][j]];
+    if (k != j) {
+      shouldReplaceJ = true;
+      replacement = k;
+      // If a possible redundancy is found, then scan the other lists: we
+      // can prune the arguments if and only if they are redundant in every
+      // list.
+      for (unsigned i = 1; i < numLists; ++i)
+        shouldReplaceJ =
+            shouldReplaceJ && (newArguments[i][k] == newArguments[i][j]);
+    }
+    // Save the replacement.
+    if (shouldReplaceJ)
+      idxToReplacement[j] = replacement;
+  }
+
+  // Populate the pruned argument list.
+  for (unsigned i = 0; i < numLists; ++i)
+    for (unsigned j = 0; j < numArgs; ++j)
+      if (!idxToReplacement.contains(j))
+        newArgumentsPruned[i].push_back(newArguments[i][j]);
+
+  // Replace the block's redundant arguments.
+  SmallVector<unsigned> toErase;
+  for (auto [idx, arg] : llvm::enumerate(block->getArguments())) {
+    if (idxToReplacement.contains(idx)) {
+      Value oldArg = block->getArgument(numOldArguments + idx);
+      Value newArg =
+          block->getArgument(numOldArguments + idxToReplacement[idx]);
+      rewriter.replaceAllUsesWith(oldArg, newArg);
+      toErase.push_back(numOldArguments + idx);
+    }
+  }
+
+  // Erase the block's redundant arguments.
+  for (unsigned idxToErase : llvm::reverse(toErase))
+    block->eraseArgument(idxToErase);
+  return newArgumentsPruned;
+}
+
 LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
   // Don't consider clusters that don't have blocks to merge.
   if (blocksToMerge.empty())
@@ -703,6 +797,7 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         1 + blocksToMerge.size(),
         SmallVector<Value, 8>(operandsToMerge.size()));
     unsigned curOpIndex = 0;
+    unsigned numOldArguments = leaderBlock->getNumArguments();
     for (const auto &it : llvm::enumerate(operandsToMerge)) {
       unsigned nextOpOffset = it.value().first - curOpIndex;
       curOpIndex = it.value().first;
@@ -722,6 +817,11 @@ LogicalResult BlockMergeCluster::merge(RewriterBase &rewriter) {
         }
       }
     }
+
+    // Prune redundant arguments and update the leader block argument list
+    newArguments = pruneRedundantArguments(newArguments, rewriter,
+                                           numOldArguments, leaderBlock);
+
     // Update the predecessors for each of the blocks.
     auto updatePredecessors = [&](Block *block, unsigned clusterIndex) {
       for (auto predIt = block->pred_begin(), predE = block->pred_end();
@@ -818,6 +918,108 @@ static LogicalResult mergeIdenticalBlocks(RewriterBase &rewriter,
   return success(anyChanged);
 }
 
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            Block &block) {
+  SmallVector<size_t> argsToErase;
+
+  // Go through the arguments of the block.
+  for (auto [argIdx, blockOperand] : llvm::enumerate(block.getArguments())) {
+    bool sameArg = true;
+    Value commonValue;
+
+    // Go through the block predecessor and flag if they pass to the block
+    // different values for the same argument.
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = dyn_cast<BranchOpInterface>((*predIt)->getTerminator());
+      if (!branch) {
+        sameArg = false;
+        break;
+      }
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      auto branchOperands = succOperands.getForwardedOperands();
+      if (!commonValue) {
+        commonValue = branchOperands[argIdx];
+      } else {
+        if (branchOperands[argIdx] != commonValue) {
+          sameArg = false;
+          break;
+        }
+      }
+    }
+
+    // If they are passing the same value, drop the argument.
+    if (commonValue && sameArg) {
+      argsToErase.push_back(argIdx);
+
+      // Remove the argument from the block.
+      rewriter.replaceAllUsesWith(blockOperand, commonValue);
+    }
+  }
+
+  // Remove the arguments.
+  for (auto argIdx : llvm::reverse(argsToErase)) {
+    block.eraseArgument(argIdx);
+
+    // Remove the argument from the branch ops.
+    for (auto predIt = block.pred_begin(), predE = block.pred_end();
+         predIt != predE; ++predIt) {
+      auto branch = cast<BranchOpInterface>((*predIt)->getTerminator());
+      unsigned succIndex = predIt.getSuccessorIndex();
+      SuccessorOperands succOperands = branch.getSuccessorOperands(succIndex);
+      succOperands.erase(argIdx);
+    }
+  }
+  return success(!argsToErase.empty());
+}
+
+/// This optimization drops redundant argument to blocks. I.e., if a given
+/// argument to a block receives the same value from each of the block
+/// predecessors, we can remove the argument from the block and use directly the
+/// original value. This is a simple example:
+///
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val0 : i64, %val1 : i64), ^bb2(%val0 : i64, %val2
+/// : i64)
+///
+/// ^bb1(%arg0 : i64, %arg1 : i64):
+///    llvm.call @foo(%arg0, %arg1)
+///
+/// The previous IR can be rewritten as:
+/// %cond = llvm.call @rand() : () -> i1
+/// %val0 = llvm.mlir.constant(1 : i64) : i64
+/// %val1 = llvm.mlir.constant(2 : i64) : i64
+/// %val2 = llvm.mlir.constant(3 : i64) : i64
+/// llvm.cond_br %cond, ^bb1(%val1 : i64), ^bb2(%val2 : i64)
+///
+/// ^bb1(%arg0 : i64):
+///    llvm.call @foo(%val0, %arg0)
+///
+static LogicalResult dropRedundantArguments(RewriterBase &rewriter,
+                                            MutableArrayRef<Region> regions) {
+  llvm::SmallSetVector<Region *, 1> worklist;
+  for (Region &region : regions)
+    worklist.insert(&region);
+  bool anyChanged = false;
+  while (!worklist.empty()) {
+    Region *region = worklist.pop_back_val();
+
+    // Add any nested regions to the worklist.
+    for (Block &block : *region) {
+      anyChanged = succeeded(dropRedundantArguments(rewriter, block));
+
+      for (Operation &op : block)
+        for (Region &nestedRegion : op.getRegions())
+          worklist.insert(&nestedRegion);
+    }
+  }
+  return success(anyChanged);
+}
+
 //===----------------------------------------------------------------------===//
 // Region Simplification
 //===----------------------------------------------------------------------===//
@@ -832,8 +1034,12 @@ LogicalResult mlir::simplifyRegions(RewriterBase &rewriter,
   bool eliminatedBlocks = succeeded(eraseUnreachableBlocks(rewriter, regions));
   bool eliminatedOpsOrArgs = succeeded(runRegionDCE(rewriter, regions));
   bool mergedIdenticalBlocks = false;
-  if (mergeBlocks)
+  bool droppedRedundantArguments = false;
+  if (mergeBlocks) {
     mergedIdenticalBlocks = succeeded(mergeIdenticalBlocks(rewriter, regions));
+    droppedRedundantArguments =
+        succeeded(dropRedundantArguments(rewriter, regions));
+  }
   return success(eliminatedBlocks || eliminatedOpsOrArgs ||
-                 mergedIdenticalBlocks);
+                 mergedIdenticalBlocks || droppedRedundantArguments);
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
index 5e8104f83cc4d..8e14990502143 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation/dealloc-branchop-interface.mlir
@@ -178,7 +178,7 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: ^bb1
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ARG1]], %false{{[0-9_]*}} :
+//       CHECK: cf.br ^bb6([[ARG1]], %false{{[0-9_]*}} :
 //       CHECK: ^bb2([[IDX:%.*]]:{{.*}})
 //       CHECK: [[ALLOC1:%.*]] = memref.alloc([[IDX]])
 //  CHECK-NEXT: test.buffer_based
@@ -186,20 +186,24 @@ func.func @condBranchDynamicTypeNested(
 //  CHECK-NEXT: [[OWN:%.+]] = arith.select [[ARG0]], [[ARG0]], [[NOT_ARG0]]
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb3
+//       CHECK: cf.cond_br{{.*}}, ^bb3, ^bb4
 //  CHECK-NEXT: ^bb3:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb4([[ALLOC1]], [[OWN]]
-//  CHECK-NEXT: ^bb4([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb4:
 //   CHECK-NOT: bufferization.dealloc
 //   CHECK-NOT: bufferization.clone
-//       CHECK: cf.br ^bb5([[ALLOC2]], [[COND1]]
-//  CHECK-NEXT: ^bb5([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb5([[ALLOC1]], [[OWN]]
+//  CHECK-NEXT: ^bb5([[ALLOC2:%.*]]:{{.*}}, [[COND1:%.+]]:{{.*}})
+//   CHECK-NOT: bufferization.dealloc
+//   CHECK-NOT: bufferization.clone
+//       CHECK: cf.br ^bb6([[ALLOC2]], [[COND1]]
+//  CHECK-NEXT: ^bb6([[ALLOC4:%.*]]:{{.*}}, [[COND2:%.+]]:{{.*}})
 //  CHECK-NEXT: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC4]]
 //  CHECK-NEXT: [[OWN:%.+]]:2 = bufferization.dealloc ([[BASE]] :{{.*}}) if ([[COND2]]) retain ([[ALLOC4]], [[ARG2]] :
-//       CHECK: cf.br ^bb6([[ALLOC4]], [[OWN]]#0
-//  CHECK-NEXT: ^bb6([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
+//       CHECK: cf.br ^bb7([[ALLOC4]], [[OWN]]#0
+//  CHECK-NEXT: ^bb7([[ALLOC5:%.*]]:{{.*}}, [[COND3:%.+]]:{{.*}})
 //       CHECK: test.copy
 //       CHECK: [[BASE:%[a-zA-Z0-9_]+]]{{.*}} = memref.extract_strided_metadata [[ALLOC5]]
 //  CHECK-NEXT: bufferization.dealloc ([[BASE]] : {{.*}}) if ([[COND3]])
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index d1a89226fdb58..50a2d6bf532aa 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -15,7 +15,7 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK-LABEL: @main
 // CHECK-SAME:       (%[[ARG0:.+]]: tensor<f32>) -> tensor<f32>
 // CHECK:   %[[EXTRACTED:.+]] = tensor.extract %[[ARG0]][] : tensor<f32>
-// CHECK: cf.br ^{{.*}}(%[[EXTRACTED]] : f32)
-// CHECK: ^{{.*}}(%[[ARG1:.+]]: f32):
-// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[ARG1]] : tensor<f32>
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
 // CHECK:   return %[[ELEMENTS]] : tensor<f32>
diff --git a/mlir/test/Dialect/Linalg/detensorize_if.mlir b/mlir/test/Dialect/Linalg/detensorize_if.mlir
index 8d17763c04b6c..c728ad21d2209 100644
--- a/mlir/test/Dialect/Linalg/detensorize_if.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_if.mlir
@@ -42,18 +42,15 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br
+// CHECK-NEXT:   ^[[bb1:.*]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT:   ^[[bb2]]
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
 
 // -----
@@ -106,20 +103,17 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb3(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     cf.br ^[[bb4:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb4]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<0>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br ^[[bb1:.*]]
+// CHECK-NEXT:   ^[[bb1:.*]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb3
+// CHECK-NEXT:   ^[[bb2]]:
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]:
+// CHECK-NEXT:     cf.br ^[[bb4:.*]]
+// CHECK-NEXT:   ^[[bb4]]:
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
 
 // -----
@@ -171,16 +165,13 @@ func.func @main() -> (tensor<i32>) attributes {} {
 }
 
 // CHECK-LABEL:  func @main()
-// CHECK-DAG:     arith.constant 0
-// CHECK-DAG:     arith.constant 10
-// CHECK:         cf.br ^[[bb1:.*]](%{{.*}}: i32)
-// CHECK-NEXT:   ^[[bb1]](%{{.*}}: i32):
-// CHECK-NEXT:     arith.cmpi slt, %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^bb2(%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb2]](%{{.*}}: i32)
-// CHECK-NEXT:     arith.addi %{{.*}}, %{{.*}}
-// CHECK-NEXT:     cf.br ^[[bb3:.*]](%{{.*}} : i32)
-// CHECK-NEXT:   ^[[bb3]](%{{.*}}: i32)
-// CHECK-NEXT:     tensor.from_elements %{{.*}} : tensor<i32>
-// CHECK-NEXT:     return %{{.*}}
+// CHECK-DAG:     %[[cst:.*]] = arith.constant dense<10>
+// CHECK-DAG:     arith.constant true
+// CHECK:         cf.br ^[[bb1:.*]]
+// CHECK-NEXT:   ^[[bb1]]:
+// CHECK-NEXT:     cf.cond_br %{{.*}}, ^[[bb2:.*]], ^bb2
+// CHECK-NEXT:   ^[[bb2]]
+// CHECK-NEXT:     cf.br ^[[bb3:.*]]
+// CHECK-NEXT:   ^[[bb3]]
+// CHECK-NEXT:     return %[[cst]]
 // CHECK-NEXT:   }
diff --git a/mlir/test/Dialect/Linalg/detensorize_while.mlir b/mlir/test/Dialect/Linalg/detensorize_while.mlir
index aa30900f76a33..580a97d3a851b 100644
--- a/mlir/test/Dialect/Linalg/detensorize_while.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_while.mlir
@@ -46,11 +46,11 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-ALL:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-ALL:       ^[[bb1]](%{{.*}}: i32)
 // DET-ALL:         arith.cmpi slt, {{.*}}
-// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb2]](%{{.*}}: i32)
+// DET-ALL:         cf.cond_br {{.*}}, ^[[bb2:.*]], ^[[bb3:.*]]
+// DET-ALL:       ^[[bb2]]
 // DET-ALL:         arith.addi {{.*}}
 // DET-ALL:         cf.br ^[[bb1]](%{{.*}} : i32)
-// DET-ALL:       ^[[bb3]](%{{.*}}: i32)
+// DET-ALL:       ^[[bb3]]:
 // DET-ALL:         tensor.from_elements {{.*}}
 // DET-ALL:         return %{{.*}} : tensor<i32>
 
@@ -62,10 +62,10 @@ func.func @main(%farg0: tensor<i32>, %farg1: tensor<i32>) -> tensor<i32> attribu
 // DET-CF:         cf.br ^[[bb1:.*]](%{{.*}} : i32)
 // DET-CF:       ^[[bb1]](%{{.*}}: i32)
 // DET-CF:         arith.cmpi slt, {{.*}}
-// DET-CF:         cf.cond_br {{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32)
-// DET-CF:       ^[[bb2]](%{{.*}}: i32)
+// DET-CF:        ...
[truncated]

@giuseros
Copy link
Contributor Author

giuseros commented Aug 5, 2024

Hi @Dinistro , it was actually a small bug (unfortunately I was on holiday, otherwise I could have sorted that without going through the revert process).

@krzysz00 , @Mogball , this is the same PR you already review (so the logic is almost unchanged). What basically changed is the logic to prune a new argument within the block. The original code was

Value oldArg = block->getArgument(idx);
Value newArg = block->getArgument(idxToReplacement[idx]);
rewriter.replaceAllUsesWith(oldArg, newArg);
toErase.push_back(idx);

Where idx is the index of the argument we want to map to idxToReplacement[idx]. However, those indices had to be offset by the number of block arguments that were already there.

So, the above code, becomes:

Value oldArg =  block->getArgument(numOldArguments + idx);
Value newArg = block->getArgument(numOldArguments + idxToReplacement[idx]);
rewriter.replaceAllUsesWith(oldArg, newArg);
toErase.push_back(numOldArguments + idx);

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved assuming I'm reading correctly and you've added the code that caused the revert to the tests

Copy link
Contributor

@Dinistro Dinistro left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix, LGTM % a few nits.

mlir/lib/Transforms/Utils/RegionUtils.cpp Outdated Show resolved Hide resolved
mlir/lib/Transforms/Utils/RegionUtils.cpp Outdated Show resolved Hide resolved
mlir/lib/Transforms/Utils/RegionUtils.cpp Show resolved Hide resolved
mlir/lib/Transforms/Utils/RegionUtils.cpp Outdated Show resolved Hide resolved
mlir/lib/Transforms/Utils/RegionUtils.cpp Outdated Show resolved Hide resolved
mlir/lib/Transforms/Utils/RegionUtils.cpp Show resolved Hide resolved
mlir/lib/Transforms/Utils/RegionUtils.cpp Outdated Show resolved Hide resolved
@giuseros giuseros force-pushed the improve_block_merging_3 branch from e018af0 to 80951a3 Compare August 6, 2024 10:31
Copy link

github-actions bot commented Aug 6, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@giuseros giuseros force-pushed the improve_block_merging_3 branch from 80951a3 to 074c43e Compare August 6, 2024 10:43
@giuseros giuseros force-pushed the improve_block_merging_3 branch from 074c43e to e576c6b Compare August 6, 2024 10:44
@giuseros
Copy link
Contributor Author

giuseros commented Aug 6, 2024

Hi @Dinistro , I applied the suggested changes! Thanks for the review!

@giuseros giuseros merged commit 441b672 into llvm:main Aug 7, 2024
7 checks passed
@giuseros
Copy link
Contributor Author

giuseros commented Aug 7, 2024

Merging this in, since it's been approved and reviewed

banach-space pushed a commit to banach-space/llvm-project that referenced this pull request Aug 7, 2024
With this PR I am trying to address:
llvm#63230.

What changed:
- While merging identical blocks, don't add a block argument if it is
"identical" to another block argument. I.e., if the two block arguments
refer to the same `Value`. The operations operands in the block will
point to the argument we already inserted. This needs to happen to all
the arguments we pass to the different successors of the parent block
- After merged the blocks, get rid of "unnecessary" arguments. I.e., if
all the predecessors pass the same block argument, there is no need to
pass it as an argument.
- This last simplification clashed with
`BufferDeallocationSimplification`. The reason, I think, is that the two
simplifications are clashing. I.e., `BufferDeallocationSimplification`
contains an analysis based on the block structure. If we simplify the
block structure (by merging and/or dropping block arguments) the
analysis is invalid . The solution I found is to do a more prudent
simplification when running that pass.

**Note-1**: I ran all the integration tests
(`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed.
**Note-2**: I fixed a bug found by @Dinistro in llvm#97697 . The issue was
that, when looking for redundant arguments, I was not considering that
the block might have already some arguments. So the index (in the block
args list) of the i-th `newArgument` is `i+numOfOldArguments`.
TIFitis pushed a commit that referenced this pull request Aug 8, 2024
With this PR I am trying to address:
#63230.

What changed:
- While merging identical blocks, don't add a block argument if it is
"identical" to another block argument. I.e., if the two block arguments
refer to the same `Value`. The operations operands in the block will
point to the argument we already inserted. This needs to happen to all
the arguments we pass to the different successors of the parent block
- After merged the blocks, get rid of "unnecessary" arguments. I.e., if
all the predecessors pass the same block argument, there is no need to
pass it as an argument.
- This last simplification clashed with
`BufferDeallocationSimplification`. The reason, I think, is that the two
simplifications are clashing. I.e., `BufferDeallocationSimplification`
contains an analysis based on the block structure. If we simplify the
block structure (by merging and/or dropping block arguments) the
analysis is invalid . The solution I found is to do a more prudent
simplification when running that pass.

**Note-1**: I ran all the integration tests
(`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed.
**Note-2**: I fixed a bug found by @Dinistro in #97697 . The issue was
that, when looking for redundant arguments, I was not considering that
the block might have already some arguments. So the index (in the block
args list) of the i-th `newArgument` is `i+numOfOldArguments`.

// Add any nested regions to the worklist.
for (Block &block : *region) {
anyChanged = succeeded(dropRedundantArguments(rewriter, block));
Copy link
Contributor

@Hardcode84 Hardcode84 Aug 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

anyChanged is completely overwritten on each iteration.

It should be anyChanged = succeeded(dropRedundantArguments(rewriter, block)) || anyChanged, I believe

giuseros added a commit to giuseros/llvm-project that referenced this pull request Aug 12, 2024
giuseros added a commit to giuseros/llvm-project that referenced this pull request Aug 12, 2024
giuseros added a commit that referenced this pull request Aug 12, 2024
bwendling pushed a commit to bwendling/llvm-project that referenced this pull request Aug 15, 2024
kstoimenov pushed a commit to kstoimenov/llvm-project that referenced this pull request Aug 15, 2024
With this PR I am trying to address:
llvm#63230.

What changed:
- While merging identical blocks, don't add a block argument if it is
"identical" to another block argument. I.e., if the two block arguments
refer to the same `Value`. The operations operands in the block will
point to the argument we already inserted. This needs to happen to all
the arguments we pass to the different successors of the parent block
- After merged the blocks, get rid of "unnecessary" arguments. I.e., if
all the predecessors pass the same block argument, there is no need to
pass it as an argument.
- This last simplification clashed with
`BufferDeallocationSimplification`. The reason, I think, is that the two
simplifications are clashing. I.e., `BufferDeallocationSimplification`
contains an analysis based on the block structure. If we simplify the
block structure (by merging and/or dropping block arguments) the
analysis is invalid . The solution I found is to do a more prudent
simplification when running that pass.

**Note-1**: I ran all the integration tests
(`-DMLIR_INCLUDE_INTEGRATION_TESTS=ON`) and they passed.
**Note-2**: I fixed a bug found by @Dinistro in llvm#97697 . The issue was
that, when looking for redundant arguments, I was not considering that
the block might have already some arguments. So the index (in the block
args list) of the i-th `newArgument` is `i+numOfOldArguments`.
joviliast added a commit to joviliast/triton that referenced this pull request Aug 29, 2024
Updating LLVM repo to get block merging fix introduced here:
llvm/llvm-project#102038

Signed-off-by: Ilya Veselov <[email protected]>
joviliast added a commit to joviliast/triton that referenced this pull request Aug 29, 2024
Updating LLVM repo to get block merging fix introduced here:
llvm/llvm-project#102038

Signed-off-by: Ilya Veselov <[email protected]>
joviliast added a commit to joviliast/triton that referenced this pull request Aug 29, 2024
Updating LLVM repo to ToT to get block merging fix introduced here:
llvm/llvm-project#102038
and masked load support from here:
llvm/llvm-project#104598

Signed-off-by: Ilya Veselov <[email protected]>
antiagainst pushed a commit to triton-lang/triton that referenced this pull request Aug 30, 2024
Updating LLVM repo to ToT to get block merging fix introduced here:
llvm/llvm-project#102038
and masked load support from here:
llvm/llvm-project#104598

Signed-off-by: Ilya Veselov <[email protected]>
antiagainst added a commit to antiagainst/triton that referenced this pull request Aug 30, 2024
This contains the two commit we want to have for AMD backend:
* llvm/llvm-project#102038
* llvm/llvm-project#104598
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:core MLIR Core Infrastructure mlir:linalg mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants