Skip to content

Commit

Permalink
[mlir] Make overloads of SymbolTable::replaceAllSymbolUses consistent. (
Browse files Browse the repository at this point in the history
llvm#68320)

This function has several overloads that allow to specify the symbol
that should be renamed and the scope for that renaming in different
ways. The overloads were inconsistent in the following way (quoted
strings are `StringAttr`s, other variables are `Operation *`):

* `replaceAllSymbolUses(symbolOp, "new_symbol", scopeOp)` would traverse
into the nested regions of `scopeOp` and hence rename the symbol inside
of `scopeOp`.
* `replaceAllSymbolUses("symbol", "new_symbol", scopeOp)` would *not*
traverse into the nested regions of `scopeOp` and hence *not* rename the
symbol.

The underlying behavior was spread over different places and is somewhat
hard to understand. The two overloads above mainly differed by what
`collectSymbolScopes` computed, which is itself overloaded. If `scopeOp`
is a top-level module, then the overload on `(Operation *, Operation
*)`, which is used in the first of the above cases, computes a scope
where the body region of the module is the `limit`; however, the
overload on `(StringAttr, Operation *)` computed the module op itself as
the `limit`. Later, `walkSymbolTable` would walk the body of the module
if it was given as a region but it would *not* enter the regions of the
module op because that op has a symbol table (which was assumed to be a
*different* scope).

The fix in this commit is change the behavior of `collectSymbolScopes`
such that the `(StringAttr, Operation *)` overload returns a scope for
each region in the `limit` argument.
  • Loading branch information
ingomueller-net authored Oct 10, 2023
1 parent 0d0f219 commit 4790578
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 4 deletions.
12 changes: 10 additions & 2 deletions mlir/lib/IR/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,12 +729,20 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
scopes.back().limit = limit;
return scopes;
}
template <typename IRUnit>
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
IRUnit *limit) {
Region *limit) {
return {{SymbolRefAttr::get(symbol), limit}};
}

static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
Operation *limit) {
SmallVector<SymbolScope, 1> scopes;
auto symbolRef = SymbolRefAttr::get(symbol);
for (auto &region : limit->getRegions())
scopes.push_back({symbolRef, &region});
return scopes;
}

/// Returns true if the given reference 'SubRef' is a sub reference of the
/// reference 'ref', i.e. 'ref' is a further qualified reference.
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
Expand Down
15 changes: 13 additions & 2 deletions mlir/test/python/ir/symbol_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def testSymbolTableRAUW():
"""
)
foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]

# Do renaming just within `foo`.
SymbolTable.set_symbol_name(bar, "bam")
# Note that module.operation counts as a "nested symbol table" which won't
# be traversed into, so it is necessary to traverse its children.
SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
# CHECK: call @bam()
# CHECK: func private @bam
Expand All @@ -118,6 +118,17 @@ def testSymbolTableRAUW():
print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")

# Do renaming within the module.
SymbolTable.set_symbol_name(bar, "baz")
SymbolTable.replace_all_symbol_uses("bam", "baz", m.operation)
# CHECK: call @baz()
# CHECK: func private @baz
print(m)
# CHECK: Foo symbol: StringAttr("foo")
# CHECK: Bar symbol: StringAttr("baz")
print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")


# CHECK-LABEL: testSymbolTableVisibility
@run
Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_unittest(MLIRIRTests
OperationSupportTest.cpp
PatternMatchTest.cpp
ShapedTypeTest.cpp
SymbolTableTest.cpp
TypeTest.cpp
OpPropertiesTest.cpp

Expand Down
136 changes: 136 additions & 0 deletions mlir/unittests/IR/SymbolTableTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
//===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Parser/Parser.h"

#include "gtest/gtest.h"

using namespace mlir;

namespace test {
void registerTestDialect(DialectRegistry &);
} // namespace test

class ReplaceAllSymbolUsesTest : public ::testing::Test {
protected:
using ReplaceFnType = llvm::function_ref<LogicalResult(
SymbolTable, ModuleOp, Operation *, Operation *)>;

void SetUp() override {
::test::registerTestDialect(registry);
context = std::make_unique<MLIRContext>(registry);
}

void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
// Set up IR and find func ops.
OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(kInput, context.get());
SymbolTable symbolTable(module.get());
auto opIterator = module->getBody(0)->getOperations().begin();
auto fooOp = cast<FunctionOpInterface>(opIterator++);
auto barOp = cast<FunctionOpInterface>(opIterator++);
ASSERT_EQ(fooOp.getNameAttr(), "foo");
ASSERT_EQ(barOp.getNameAttr(), "bar");

// Call test function that does symbol replacement.
LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
ASSERT_TRUE(succeeded(res));
ASSERT_TRUE(succeeded(verify(module.get())));

// Check that it got renamed.
bool calleeFound = false;
fooOp->walk([&](CallOpInterface callOp) {
StringAttr callee = callOp.getCallableForCallee()
.dyn_cast<SymbolRefAttr>()
.getLeafReference();
EXPECT_EQ(callee, "baz");
calleeFound = true;
});
EXPECT_TRUE(calleeFound);
}

std::unique_ptr<MLIRContext> context;

private:
constexpr static llvm::StringLiteral kInput = R"MLIR(
module {
test.conversion_func_op private @foo() {
"test.conversion_call_op"() { callee=@bar } : () -> ()
"test.return"() : () -> ()
}
test.conversion_func_op private @bar()
}
)MLIR";

DialectRegistry registry;
};

namespace {

TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
// Symbol as `Operation *`, rename within module.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
barOp, StringAttr::get(context.get(), "baz"), module);
});
}

TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
// Symbol as `StringAttr`, rename within module.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
StringAttr::get(context.get(), "bar"),
StringAttr::get(context.get(), "baz"), module);
});
}

TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
// Symbol as `Operation *`, rename within module body.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
});
}

TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
// Symbol as `StringAttr`, rename within module body.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
StringAttr::get(context.get(), "bar"),
StringAttr::get(context.get(), "baz"), &module->getRegion(0));
});
}

TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
// Symbol as `Operation *`, rename within function.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
barOp, StringAttr::get(context.get(), "baz"), fooOp);
});
}

TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
// Symbol as `StringAttr`, rename within function.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
StringAttr::get(context.get(), "bar"),
StringAttr::get(context.get(), "baz"), fooOp);
});
}

} // namespace

0 comments on commit 4790578

Please sign in to comment.