Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Make overloads of SymbolTable::replaceAllSymbolUses consistent. #68320

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions mlir/lib/IR/SymbolTable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -655,12 +655,20 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
scopes.back().limit = limit;
return scopes;
}
template <typename IRUnit>
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
IRUnit *limit) {
Region *limit) {
return {{SymbolRefAttr::get(symbol), limit}};
}

static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
Operation *limit) {
SmallVector<SymbolScope, 1> scopes;
auto symbolRef = SymbolRefAttr::get(symbol);
for (auto &region : limit->getRegions())
scopes.push_back({symbolRef, &region});
return scopes;
}

/// Returns true if the given reference 'SubRef' is a sub reference of the
/// reference 'ref', i.e. 'ref' is a further qualified reference.
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
Expand Down
15 changes: 13 additions & 2 deletions mlir/test/python/ir/symbol_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def testSymbolTableRAUW():
"""
)
foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]

# Do renaming just within `foo`.
SymbolTable.set_symbol_name(bar, "bam")
# Note that module.operation counts as a "nested symbol table" which won't
# be traversed into, so it is necessary to traverse its children.
SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
# CHECK: call @bam()
# CHECK: func private @bam
Expand All @@ -118,6 +118,17 @@ def testSymbolTableRAUW():
print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")

# Do renaming within the module.
SymbolTable.set_symbol_name(bar, "baz")
SymbolTable.replace_all_symbol_uses("bam", "baz", m.operation)
# CHECK: call @baz()
# CHECK: func private @baz
print(m)
# CHECK: Foo symbol: StringAttr("foo")
# CHECK: Bar symbol: StringAttr("baz")
print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")


# CHECK-LABEL: testSymbolTableVisibility
@run
Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_unittest(MLIRIRTests
OperationSupportTest.cpp
PatternMatchTest.cpp
ShapedTypeTest.cpp
SymbolTableTest.cpp
TypeTest.cpp
OpPropertiesTest.cpp

Expand Down
136 changes: 136 additions & 0 deletions mlir/unittests/IR/SymbolTableTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
//===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Parser/Parser.h"

#include "gtest/gtest.h"

using namespace mlir;

namespace test {
void registerTestDialect(DialectRegistry &);
} // namespace test

class ReplaceAllSymbolUsesTest : public ::testing::Test {
protected:
using ReplaceFnType = llvm::function_ref<LogicalResult(
SymbolTable, ModuleOp, Operation *, Operation *)>;

void SetUp() override {
::test::registerTestDialect(registry);
context = std::make_unique<MLIRContext>(registry);
}

void testReplaceAllSymbolUses(ReplaceFnType replaceFn) {
// Set up IR and find func ops.
OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(kInput, context.get());
SymbolTable symbolTable(module.get());
auto opIterator = module->getBody(0)->getOperations().begin();
auto fooOp = cast<FunctionOpInterface>(opIterator++);
auto barOp = cast<FunctionOpInterface>(opIterator++);
ASSERT_EQ(fooOp.getNameAttr(), "foo");
ASSERT_EQ(barOp.getNameAttr(), "bar");

// Call test function that does symbol replacement.
LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
ASSERT_TRUE(succeeded(res));
ASSERT_TRUE(succeeded(verify(module.get())));

// Check that it got renamed.
bool calleeFound = false;
fooOp->walk([&](CallOpInterface callOp) {
StringAttr callee = callOp.getCallableForCallee()
.dyn_cast<SymbolRefAttr>()
.getLeafReference();
EXPECT_EQ(callee, "baz");
calleeFound = true;
});
EXPECT_TRUE(calleeFound);
}

std::unique_ptr<MLIRContext> context;

private:
constexpr static llvm::StringLiteral kInput = R"MLIR(
module {
test.conversion_func_op private @foo() {
"test.conversion_call_op"() { callee=@bar } : () -> ()
"test.return"() : () -> ()
}
test.conversion_func_op private @bar()
}
)MLIR";

DialectRegistry registry;
};

namespace {

TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleOp) {
// Symbol as `Operation *`, rename within module.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
barOp, StringAttr::get(context.get(), "baz"), module);
});
}

TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleOp) {
// Symbol as `StringAttr`, rename within module.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
StringAttr::get(context.get(), "bar"),
StringAttr::get(context.get(), "baz"), module);
});
}

TEST_F(ReplaceAllSymbolUsesTest, OperationInModuleBody) {
// Symbol as `Operation *`, rename within module body.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
barOp, StringAttr::get(context.get(), "baz"), &module->getRegion(0));
});
}

TEST_F(ReplaceAllSymbolUsesTest, StringAttrInModuleBody) {
// Symbol as `StringAttr`, rename within module body.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
StringAttr::get(context.get(), "bar"),
StringAttr::get(context.get(), "baz"), &module->getRegion(0));
});
}

TEST_F(ReplaceAllSymbolUsesTest, OperationInFuncOp) {
// Symbol as `Operation *`, rename within function.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
barOp, StringAttr::get(context.get(), "baz"), fooOp);
});
}

TEST_F(ReplaceAllSymbolUsesTest, StringAttrInFuncOp) {
// Symbol as `StringAttr`, rename within function.
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
auto barOp) -> LogicalResult {
return symbolTable.replaceAllSymbolUses(
StringAttr::get(context.get(), "bar"),
StringAttr::get(context.get(), "baz"), fooOp);
});
}

} // namespace