From 3dae97cc011ca097bd457bbfa5855da86290f631 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 5 Dec 2023 11:46:30 +0900 Subject: [PATCH] [mlir][bufferization] Fix op dominance bug in rewrite pattern (#74159) Fixes a bug in `SplitDeallocWhenNotAliasingAnyOther`. This pattern used to generate invalid IR (op dominance error). We never noticed this bug in existing test cases because other patterns and/or foldings were applied afterwards and those rewrites "fixed up" the IR again. (The bug is visible when running `mlir-opt -debug`.) Also add additional comments to the implementation and simplify the code a bit. Apart from the fixed dominance error, this change is NFC. Without this change, buffer deallocation tests will fail when running with #74270. --- .../BufferDeallocationSimplification.cpp | 55 +++++++++++-------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp index 7bbdeab3ea1a87..42653517249d66 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -314,44 +314,51 @@ struct SplitDeallocWhenNotAliasingAnyOther LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { + Location loc = deallocOp.getLoc(); if (deallocOp.getMemrefs().size() <= 1) return failure(); - SmallVector newMemrefs, newConditions, replacements; - DenseSet exceptedUsers; - replacements = deallocOp.getUpdatedConditions(); + SmallVector remainingMemrefs, remainingConditions; + SmallVector> updatedConditions; for (auto [memref, cond] : llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) { + // Check if `memref` can split off into a separate bufferization.dealloc. if (potentiallyAliasesMemref(aliasAnalysis, deallocOp.getMemrefs(), memref, true)) { - newMemrefs.push_back(memref); - newConditions.push_back(cond); + // `memref` alias with other memrefs, do not split off. + remainingMemrefs.push_back(memref); + remainingConditions.push_back(cond); continue; } - auto newDeallocOp = rewriter.create( - deallocOp.getLoc(), memref, cond, deallocOp.getRetained()); - replacements = SmallVector(llvm::map_range( - llvm::zip(replacements, newDeallocOp.getUpdatedConditions()), - [&](auto replAndNew) -> Value { - auto orOp = rewriter.create(deallocOp.getLoc(), - std::get<0>(replAndNew), - std::get<1>(replAndNew)); - exceptedUsers.insert(orOp); - return orOp.getResult(); - })); + // Create new bufferization.dealloc op for `memref`. + auto newDeallocOp = rewriter.create(loc, memref, cond, + deallocOp.getRetained()); + updatedConditions.push_back( + llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()))); } - if (newMemrefs.size() == deallocOp.getMemrefs().size()) + // Fail if no memref was split off. + if (remainingMemrefs.size() == deallocOp.getMemrefs().size()) return failure(); - rewriter.replaceUsesWithIf(deallocOp.getUpdatedConditions(), replacements, - [&](OpOperand &operand) { - return !exceptedUsers.contains( - operand.getOwner()); - }); - return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions, - rewriter); + // Create bufferization.dealloc op for all remaining memrefs. + auto newDeallocOp = rewriter.create( + loc, remainingMemrefs, remainingConditions, deallocOp.getRetained()); + + // Bit-or all conditions. + SmallVector replacements = + llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions())); + for (auto additionalConditions : updatedConditions) { + assert(replacements.size() == additionalConditions.size() && + "expected same number of updated conditions"); + for (int64_t i = 0, e = replacements.size(); i < e; ++i) { + replacements[i] = rewriter.create( + loc, replacements[i], additionalConditions[i]); + } + } + rewriter.replaceOp(deallocOp, replacements); + return success(); } private: