Skip to content

Commit

Permalink
[Arc] Dedup: Fix use after free (#6568)
Browse files Browse the repository at this point in the history
  • Loading branch information
hovind authored Jan 12, 2024
1 parent c07347a commit 7d0bc38
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions lib/Dialect/Arc/Transforms/Dedup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,11 @@ struct StructuralEquivalence {
} // namespace

static void addCallSiteOperands(
MutableArrayRef<mlir::CallOpInterface> callSites,
SmallSetVector<mlir::CallOpInterface, 1> &callSites,
ArrayRef<std::variant<Operation *, unsigned>> operandMappings) {
SmallDenseMap<Operation *, Operation *> clonedOps;
SmallVector<Value> newOperands;
for (auto &callOp : callSites) {
for (auto callOp : callSites) {
OpBuilder builder(callOp);
newOperands.clear();
clonedOps.clear();
Expand Down Expand Up @@ -362,12 +362,13 @@ static bool isOutlinable(OpOperand &operand) {
namespace {
struct DedupPass : public arc::impl::DedupBase<DedupPass> {
void runOnOperation() override;
void replaceArcWith(DefineOp oldArc, DefineOp newArc);
void replaceArcWith(DefineOp oldArc, DefineOp newArc,
SymbolTableCollection &symbolTable);

/// A mapping from arc names to arc definitions.
DenseMap<StringAttr, DefineOp> arcByName;
/// A mapping from arc definitions to call sites.
DenseMap<DefineOp, SmallVector<mlir::CallOpInterface, 1>> callSites;
DenseMap<DefineOp, SmallSetVector<mlir::CallOpInterface, 1>> callSites;
};

struct ArcHash {
Expand Down Expand Up @@ -396,10 +397,7 @@ void DedupPass::runOnOperation() {
getOperation().walk([&](mlir::CallOpInterface callOp) {
if (auto defOp =
dyn_cast_or_null<DefineOp>(callOp.resolveCallable(&symbolTable)))
callSites[arcByName.lookup(callOp.getCallableForCallee()
.get<mlir::SymbolRefAttr>()
.getLeafReference())]
.push_back(callOp);
callSites[defOp].insert(callOp);
});

// Sort the arcs by hash such that arcs with the same hash are next to each
Expand Down Expand Up @@ -436,7 +434,7 @@ void DedupPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs()
<< "- Merge " << defineOp.getSymNameAttr() << " <- "
<< otherDefineOp.getSymNameAttr() << "\n");
replaceArcWith(otherDefineOp, defineOp);
replaceArcWith(otherDefineOp, defineOp, symbolTable);
arcHashes[otherIdx].defineOp = {};
}
}
Expand Down Expand Up @@ -713,13 +711,14 @@ void DedupPass::runOnOperation() {
<< " - Merged " << defineOp.getSymNameAttr() << " <- "
<< otherDefineOp.getSymNameAttr() << "\n");
addCallSiteOperands(callSites[otherDefineOp], newOperands);
replaceArcWith(otherDefineOp, defineOp);
replaceArcWith(otherDefineOp, defineOp, symbolTable);
arcHashes[otherIdx].defineOp = {};
}
}
}

void DedupPass::replaceArcWith(DefineOp oldArc, DefineOp newArc) {
void DedupPass::replaceArcWith(DefineOp oldArc, DefineOp newArc,
SymbolTableCollection &symbolTable) {
++dedupPassNumArcsDeduped;
auto oldArcOps = oldArc.getOps();
dedupPassTotalOps += std::distance(oldArcOps.begin(), oldArcOps.end());
Expand All @@ -728,8 +727,14 @@ void DedupPass::replaceArcWith(DefineOp oldArc, DefineOp newArc) {
auto newArcName = SymbolRefAttr::get(newArc.getSymNameAttr());
for (auto callOp : oldUses) {
callOp.setCalleeFromCallable(newArcName);
newUses.push_back(callOp);
newUses.insert(callOp);
}

oldArc.walk([&](mlir::CallOpInterface callOp) {
if (auto defOp =
dyn_cast_or_null<DefineOp>(callOp.resolveCallable(&symbolTable)))
callSites[defOp].remove(callOp);
});
callSites.erase(oldArc);
arcByName.erase(oldArc.getSymNameAttr());
oldArc->erase();
Expand Down

0 comments on commit 7d0bc38

Please sign in to comment.