Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Make overloads of SymbolTable::replaceAllSymbolUses consistent. #68320

Merged

Conversation

ingomueller-net
Copy link
Contributor

@ingomueller-net ingomueller-net commented Oct 5, 2023

This function has several overloads that allow to specify the symbol that should be renamed and the scope for that renaming in different ways. The overloads were inconsistent in the following way (quoted strings are StringAttrs, other variables are Operation *):

  • replaceAllSymbolUses(symbolOp, "new_symbol", scopeOp) would traverse into the nested regions of scopeOp and hence rename the symbol inside of scopeOp.
  • replaceAllSymbolUses("symbol", "new_symbol", scopeOp) would not traverse into the nested regions of scopeOp and hence not rename the symbol.

The underlying behavior was spread over different places and is somewhat hard to understand. The two overloads above mainly differed by what collectSymbolScopes computed, which is itself overloaded. If scopeOp is a top-level module, then the overload on (Operation *, Operation *), which is used in the first of the above cases, computes a scope where the body region of the module is the limit; however, the overload on (StringAttr, Operation *) computed the module op itself as the limit. Later, walkSymbolTable would walk the body of the module if it was given as a region but it would not enter the regions of the module op because that op has a symbol table (which was assumed to be a different scope).

The fix in this commit is change the behavior of collectSymbolScopes such that the (StringAttr, Operation *) overload returns a scope for each region in the limit argument.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Oct 5, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 5, 2023

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Changes

This function has several overloads that allow to specify the symbol that should be renamed and the scope for that renaming in different ways. The overloads were inconsistent in the following way (quoted strings are StringAttrs, other variables are Operation *):

  • replaceAllSymbolUses(symbolOp, "new_symbol", scopeOp) would traverse into the symbol table of scopeOp.
  • replaceAllSymbolUses("symbol", "new_symbol", scopeOp) would not traverse into the symbol table of scopeOp.

The underlying behavior was spread over different places and is somewhat hard to understand. The two overloads above mainly differed by what collectSymbolScopes computed, which is itself overloaded. If scopeOp is a top-level module, then the overload on
(Operation *, Operation *), which is used in the first of the above cases, computes a scope where the body region of the module is the limit; however, the overload on (StringAttr, Operation *) computed the module op itself as the limit. Later, walkSymbolTable would walk the body of the module if it was given as a region but it would not enter the regions of the module op because that op has a symbol table (which was assumed to be a different scope).

The fix in this commit is change the behavior of collectSymbolScopes such that the (StringAttr, Operation *) overload returns a scope for each region in the limit argument.


Full diff: https://github.com/llvm/llvm-project/pull/68320.diff

2 Files Affected:

  • (modified) mlir/lib/IR/SymbolTable.cpp (+10-2)
  • (modified) mlir/test/python/ir/symbol_table.py (+13-2)
diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp
index 2494cb7086f0d7d..b69f230f5108f62 100644
--- a/mlir/lib/IR/SymbolTable.cpp
+++ b/mlir/lib/IR/SymbolTable.cpp
@@ -655,12 +655,20 @@ static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
     scopes.back().limit = limit;
   return scopes;
 }
-template <typename IRUnit>
 static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
-                                                       IRUnit *limit) {
+                                                       Region *limit) {
   return {{SymbolRefAttr::get(symbol), limit}};
 }
 
+static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
+                                                       Operation *limit) {
+  SmallVector<SymbolScope, 1> scopes;
+  auto symbolRef = SymbolRefAttr::get(symbol);
+  for (auto &region : limit->getRegions())
+    scopes.push_back({symbolRef, &region});
+  return scopes;
+}
+
 /// Returns true if the given reference 'SubRef' is a sub reference of the
 /// reference 'ref', i.e. 'ref' is a further qualified reference.
 static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py
index 3264cfcf9a10495..577721ab2111f55 100644
--- a/mlir/test/python/ir/symbol_table.py
+++ b/mlir/test/python/ir/symbol_table.py
@@ -106,9 +106,9 @@ def testSymbolTableRAUW():
       """
         )
         foo, bar = list(m.operation.regions[0].blocks[0].operations)[0:2]
+
+        # Do renaming just within `foo`.
         SymbolTable.set_symbol_name(bar, "bam")
-        # Note that module.operation counts as a "nested symbol table" which won't
-        # be traversed into, so it is necessary to traverse its children.
         SymbolTable.replace_all_symbol_uses("bar", "bam", foo)
         # CHECK: call @bam()
         # CHECK: func private @bam
@@ -118,6 +118,17 @@ def testSymbolTableRAUW():
         print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
         print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
 
+        # Do renaming within the module.
+        SymbolTable.set_symbol_name(bar, "baz")
+        SymbolTable.replace_all_symbol_uses("bam", "baz", m.operation)
+        # CHECK: call @baz()
+        # CHECK: func private @baz
+        print(m)
+        # CHECK: Foo symbol: StringAttr("foo")
+        # CHECK: Bar symbol: StringAttr("baz")
+        print(f"Foo symbol: {repr(SymbolTable.get_symbol_name(foo))}")
+        print(f"Bar symbol: {repr(SymbolTable.get_symbol_name(bar))}")
+
 
 # CHECK-LABEL: testSymbolTableVisibility
 @run

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's fun how the only place directly testing this is Python bindings.

@ingomueller-net
Copy link
Contributor Author

It's fun how the only place directly testing this is Python bindings.

Indeed! And only one of the overloads...

@ftynse
Copy link
Member

ftynse commented Oct 6, 2023

It's fun how the only place directly testing this is Python bindings.

Indeed! And only one of the overloads...

If you feel motivated, you can add a unittest for the API in mlir/unittests.

This function has several overloads that allow to specify the symbol
that should be renamed and the scope for that renaming in different
ways. The overloads were inconsistent in the following way (quoted
strings are `StringAttr`s, other variables are `Operation *`):

* `replaceAllSymbolUses(symbolOp, "new_symbol", scopeOp)` would traverse
  into the symbol table of `scopeOp`.
* `replaceAllSymbolUses("symbol", "new_symbol", scopeOp)` would *not*
  traverse into the symbol table of `scopeOp`.

The underlying behavior was spread over different places and is somewhat
hard to understand. The two overloads above mainly differed by what
`collectSymbolScopes` computed, which is itself overloaded. If `scopeOp`
is a top-level module, then the overload on
`(Operation *, Operation *)`, which is used in the first of the above
cases, computes a scope where the body region of the module is the
`limit`; however, the overload on `(StringAttr, Operation *)` computed
the module op itself as the `limit`. Later, `walkSymbolTable` would walk
the body of the module if it was given as a region but it would *not*
enter the regions of the module op because that op has a symbol table
(which was assumed to be a *different* scope).

The fix in this commit is change the behavior of `collectSymbolScopes`
such that the `(StringAttr, Operation *)` overload returns a scope for
each region in the `limit` argument.
@ingomueller-net ingomueller-net force-pushed the fix-replace-all-symbol-uses branch from 6370e1a to 24a6403 Compare October 6, 2023 09:04
@ingomueller-net ingomueller-net requested a review from ftynse October 6, 2023 09:07
@ingomueller-net
Copy link
Contributor Author

If you feel motivated, you can add a unittest for the API in mlir/unittests.

Done. Please check.

mlir/unittests/IR/SymbolTableTest.cpp Outdated Show resolved Hide resolved
mlir/unittests/IR/SymbolTableTest.cpp Outdated Show resolved Hide resolved
mlir/unittests/IR/SymbolTableTest.cpp Outdated Show resolved Hide resolved
mlir/unittests/IR/SymbolTableTest.cpp Outdated Show resolved Hide resolved
mlir/unittests/IR/SymbolTableTest.cpp Outdated Show resolved Hide resolved
@ingomueller-net ingomueller-net requested a review from ftynse October 9, 2023 08:16
@ingomueller-net ingomueller-net merged commit 4790578 into llvm:main Oct 10, 2023
@ingomueller-net ingomueller-net deleted the fix-replace-all-symbol-uses branch October 10, 2023 05:47
ingomueller-net added a commit to ingomueller-net/llvm-project that referenced this pull request Oct 11, 2023
This is a follow-up commit for 4790578@llvm/llvm-project (llvm#68320)
that adds more tests. In particular, the tests now check that the
`limit` op itself is not traversed, i.e., symbols in attributes in of
the `limit` op are not renamed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants