diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index 727bd6b1839618..7180ea432ea057 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -729,12 +729,20 @@ static SmallVector collectSymbolScopes(Operation *symbol, scopes.back().limit = limit; return scopes; } -template static SmallVector collectSymbolScopes(StringAttr symbol, - IRUnit *limit) { + Region *limit) { return {{SymbolRefAttr::get(symbol), limit}}; } +static SmallVector collectSymbolScopes(StringAttr symbol, + Operation *limit) { + SmallVector scopes; + auto symbolRef = SymbolRefAttr::get(symbol); + for (auto ®ion : limit->getRegions()) + scopes.push_back({symbolRef, ®ion}); + return scopes; +} + /// Returns true if the given reference 'SubRef' is a sub reference of the /// reference 'ref', i.e. 'ref' is a further qualified reference. static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) { diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py index 3264cfcf9a1049..577721ab2111f5 100644 --- a/mlir/test/python/ir/symbol_table.py +++ b/mlir/test/python/ir/symbol_table.py @@ -106,9 +106,9 @@ def testSymbolTableRAUW(): """ ) foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2] + + # Do renaming just within `foo`. SymbolTable.set_symbol_name(bar, "bam") - # Note that module.operation counts as a "nested symbol table" which won't - # be traversed into, so it is necessary to traverse its children. SymbolTable.replace_all_symbol_uses("bar", "bam", foo) # CHECK: call @bam() # CHECK: func private @bam @@ -118,6 +118,17 @@ def testSymbolTableRAUW(): print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}") print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}") + # Do renaming within the module. + SymbolTable.set_symbol_name(bar, "baz") + SymbolTable.replace_all_symbol_uses("bam", "baz", m.operation) + # CHECK: call @baz() + # CHECK: func private @baz + print(m) + # CHECK: Foo symbol: StringAttr("foo") + # CHECK: Bar symbol: StringAttr("baz") + print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}") + print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}") + # CHECK-LABEL: testSymbolTableVisibility @run diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt index 8a74a590962892..6d05af193dfae0 100644 --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_unittest(MLIRIRTests OperationSupportTest.cpp PatternMatchTest.cpp ShapedTypeTest.cpp + SymbolTableTest.cpp TypeTest.cpp OpPropertiesTest.cpp diff --git a/mlir/unittests/IR/SymbolTableTest.cpp b/mlir/unittests/IR/SymbolTableTest.cpp new file mode 100644 index 00000000000000..5dcec749f0f425 --- /dev/null +++ b/mlir/unittests/IR/SymbolTableTest.cpp @@ -0,0 +1,136 @@ +//===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/IR/SymbolTable.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "mlir/Parser/Parser.h" + +#include "gtest/gtest.h" + +using namespace mlir; + +namespace test { +void registerTestDialect(DialectRegistry &); +} // namespace test + +class ReplaceAllSymbolUsesTest : public ::testing::Test { +protected: + using ReplaceFnType = llvm::function_ref; + + void SetUp() override { + ::test::registerTestDialect(registry); + context = std::make_unique(registry); + } + + void testReplaceAllSymbolUses(ReplaceFnType replaceFn) { + // Set up IR and find func ops. + OwningOpRef module = + parseSourceString(kInput, context.get()); + SymbolTable symbolTable(module.get()); + auto opIterator = module->getBody(0)->getOperations().begin(); + auto fooOp = cast(opIterator++); + auto barOp = cast(opIterator++); + ASSERT_EQ(fooOp.getNameAttr(), "foo"); + ASSERT_EQ(barOp.getNameAttr(), "bar"); + + // Call test function that does symbol replacement. + LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp); + ASSERT_TRUE(succeeded(res)); + ASSERT_TRUE(succeeded(verify(module.get()))); + + // Check that it got renamed. + bool calleeFound = false; + fooOp->walk([&](CallOpInterface callOp) { + StringAttr callee = callOp.getCallableForCallee() + .dyn_cast() + .getLeafReference(); + EXPECT_EQ(callee, "baz"); + calleeFound = true; + }); + EXPECT_TRUE(calleeFound); + } + + std::unique_ptr context; + +private: + constexpr static llvm::StringLiteral kInput = R"MLIR( + module { + test.conversion_func_op private @foo() { + "test.conversion_call_op"() { callee=@bar } : () -> () + "test.return"() : () -> () + } + test.conversion_func_op private @bar() + } + )MLIR"; + + DialectRegistry registry; +}; + +namespace { + +TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) { + // Symbol as `Operation *`, rename within module. + testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + auto barOp) -> LogicalResult { + return symbolTable.replaceAllSymbolUses( + barOp, StringAttr::get(context.get(), "baz"), module); + }); +} + +TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) { + // Symbol as `StringAttr`, rename within module. + testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + auto barOp) -> LogicalResult { + return symbolTable.replaceAllSymbolUses( + StringAttr::get(context.get(), "bar"), + StringAttr::get(context.get(), "baz"), module); + }); +} + +TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) { + // Symbol as `Operation *`, rename within module body. + testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + auto barOp) -> LogicalResult { + return symbolTable.replaceAllSymbolUses( + barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0)); + }); +} + +TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) { + // Symbol as `StringAttr`, rename within module body. + testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + auto barOp) -> LogicalResult { + return symbolTable.replaceAllSymbolUses( + StringAttr::get(context.get(), "bar"), + StringAttr::get(context.get(), "baz"), &module->getRegion(0)); + }); +} + +TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) { + // Symbol as `Operation *`, rename within function. + testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + auto barOp) -> LogicalResult { + return symbolTable.replaceAllSymbolUses( + barOp, StringAttr::get(context.get(), "baz"), fooOp); + }); +} + +TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) { + // Symbol as `StringAttr`, rename within function. + testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp, + auto barOp) -> LogicalResult { + return symbolTable.replaceAllSymbolUses( + StringAttr::get(context.get(), "bar"), + StringAttr::get(context.get(), "baz"), fooOp); + }); +} + +} // namespace