Skip to content

Commit

Permalink
[IR] Expose ReplaceGlobalVars utility in the Python API (#17361)
Browse files Browse the repository at this point in the history
* [IR] Expose ReplaceGlobalVars utility in the Python API

This is a follow-up PR to #17202,
which added a general utility to replace `GlobalVar` instances across
all TVM IR types.  This PR exposes this new utility through the Python
API, and explicitly tests its functionality.

* Lint fix
  • Loading branch information
Lunderberg authored Sep 12, 2024
1 parent bd11e19 commit b8b5fb6
Show file tree
Hide file tree
Showing 7 changed files with 418 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
*/

/*!
* \file tvm/ir/replace_global_var.h
* \file tvm/ir/replace_global_vars.h
*
* \brief A utility to replace GlobalVar instances across all TVM IR
* types in an IRMdoule.
*/
#ifndef TVM_IR_REPLACE_GLOBAL_VAR_H_
#define TVM_IR_REPLACE_GLOBAL_VAR_H_
#ifndef TVM_IR_REPLACE_GLOBAL_VARS_H_
#define TVM_IR_REPLACE_GLOBAL_VARS_H_

#include <tvm/ir/module.h>

Expand All @@ -41,7 +41,7 @@ namespace transform {
*
* \return The updated IRModule
*/
TVM_DLL IRModule ReplaceGlobalVar(IRModule mod, Map<GlobalVar, GlobalVar> replacements);
TVM_DLL IRModule ReplaceGlobalVars(IRModule mod, Map<GlobalVar, GlobalVar> replacements);

struct GlobalVarReplacer {
using FType = NodeFunctor<BaseFunc(const ObjectRef&, Map<GlobalVar, GlobalVar>)>;
Expand All @@ -54,4 +54,4 @@ struct GlobalVarReplacer {
} // namespace transform
} // namespace tvm

#endif // TVM_IR_REPLACE_GLOBAL_VAR_H_
#endif // TVM_IR_REPLACE_GLOBAL_VARS_H_
28 changes: 28 additions & 0 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""IRModule that holds the functions and type definitions."""

from __future__ import annotations

from typing import Dict, Union
Expand Down Expand Up @@ -216,6 +217,33 @@ def get_global_vars(self):
"""
return _ffi_api.Module_GetGlobalVars(self)

def replace_global_vars(
self,
replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]],
) -> "IRModule":
"""Replace GlobalVar instances within the module
Replace GlobalVars within the IRModule. Since the IRModule
may contain internal references to a GlobalVar, either in TIR
or in Relax, this method should be used whenever replacing or
renaming a GlobalVar.
Parameters
----------
replacements: Dict[Union[str, _expr.GlobalVar], Union[str, _expr.GlobalVar]]
A dictionary where each key is a GlobalVar to be replaced,
and the corresponding value is the GlobalVar with which to
replace it.
Returns
-------
IRModule
The updated module
"""
return _ffi_api.Module_ReplaceGlobalVars(self, replacements)

def get_global_type_vars(self):
"""Collect all global type vars defined in this module.
Expand Down
43 changes: 39 additions & 4 deletions src/ir/replace_global_var.cc → src/ir/replace_global_vars.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,22 @@
*/

/*!
* \file src/ir/replace_global_var.cc
* \file src/ir/replace_global_vars.cc
* \brief IRModule transform to replace GlobalVar instances across any IR type.
*/

#include <tvm/ir/replace_global_var.h>
#include <tvm/ir/replace_global_vars.h>

#include <vector>

namespace tvm {
namespace transform {

IRModule ReplaceGlobalVar(IRModule mod, Map<GlobalVar, GlobalVar> replacements) {
IRModule ReplaceGlobalVars(IRModule mod, Map<GlobalVar, GlobalVar> replacements) {
if (replacements.empty()) {
return mod;
}

std::vector<GlobalVar> to_remove;
IRModule updates;

Expand Down Expand Up @@ -57,7 +61,38 @@ IRModule ReplaceGlobalVar(IRModule mod, Map<GlobalVar, GlobalVar> replacements)
return mod;
}

TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVar").set_body_typed(ReplaceGlobalVar);
TVM_REGISTER_GLOBAL("transform.ReplaceGlobalVars").set_body_typed(ReplaceGlobalVars);

IRModule ModuleReplaceGlobalVars(
IRModule mod, Map<Variant<String, GlobalVar>, Variant<String, GlobalVar>> replacements) {
Map<GlobalVar, GlobalVar> gvar_replacements;
for (const auto& [before, after] : replacements) {
GlobalVar gvar_before;
if (auto gvar = before.as<GlobalVar>()) {
gvar_before = gvar.value();
} else if (auto str = before.as<String>()) {
gvar_before = mod->GetGlobalVar(str.value());
} else {
LOG(FATAL) << "Variant<String,GlobalVar> must contain either String or GlobalVar";
}

GlobalVar gvar_after;
if (auto gvar = after.as<GlobalVar>()) {
gvar_after = gvar.value();
} else if (auto str = after.as<String>()) {
gvar_after = gvar_before;
gvar_after.CopyOnWrite()->name_hint = str.value();
} else {
LOG(FATAL) << "Variant<String,GlobalVar> must contain either String or GlobalVar";
}

gvar_replacements.Set(gvar_before, gvar_after);
}

return ReplaceGlobalVars(mod, gvar_replacements);
}

TVM_REGISTER_GLOBAL("ir.Module_ReplaceGlobalVars").set_body_typed(ModuleReplaceGlobalVars);

} // namespace transform
} // namespace tvm
4 changes: 2 additions & 2 deletions src/relax/transform/attach_global_symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
*/

#include <tvm/ir/module.h>
#include <tvm/ir/replace_global_var.h>
#include <tvm/ir/replace_global_vars.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/function.h>
Expand Down Expand Up @@ -72,7 +72,7 @@ Pass AttachGlobalSymbol() {
mod.CopyOnWrite()->Update(updates);

if (gvar_updates.size()) {
mod = tvm::transform::ReplaceGlobalVar(mod, gvar_updates);
mod = tvm::transform::ReplaceGlobalVars(mod, gvar_updates);
}
}
return mod;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@

/*!
*
* \file src/relax/transform/replace_global_var.cc
* \file src/relax/transform/replace_global_vars.cc
*
* \brief GlobalVar replacement across IR types
*/

#include <tvm/ir/analysis.h>
#include <tvm/ir/replace_global_var.h>
#include <tvm/ir/replace_global_vars.h>
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/tir/expr_functor.h>
Expand Down Expand Up @@ -53,7 +53,24 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)
.set_dispatch<relax::FunctionNode>([](const ObjectRef& func,
Map<GlobalVar, GlobalVar> replacements) -> BaseFunc {
Mutator mutator(replacements);
return Downcast<BaseFunc>(mutator(Downcast<Function>(func)));
auto new_func = Downcast<Function>(mutator(Downcast<Function>(func)));

// If the function is externally exposed, and is being replaced
// by a GlobalVar with a new name, then the function's
// kGlobalSymbol must be updated to match.
if (auto opt = new_func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
auto name = opt.value();
for (const auto& [before, after] : replacements) {
if (before->name_hint == name) {
if (after->name_hint != name) {
new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, after->name_hint);
}
break;
}
}
}

return new_func;
});

TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

/*!
*
* \file src/tir/transforms/replace_global_var.cc
* \file src/tir/transforms/replace_global_vars.cc
*
* \brief GlobalVar replacement across IR types
*/

#include <tvm/ir/replace_global_var.h>
#include <tvm/ir/replace_global_vars.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>

Expand Down Expand Up @@ -61,6 +61,22 @@ TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)
if (!new_body.same_as(func->body)) {
func.CopyOnWrite()->body = new_body;
}

// If the function is externally exposed, and is being replaced
// by a GlobalVar with a new name, then the function's
// kGlobalSymbol must be updated to match.
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
auto name = opt.value();
for (const auto& [before, after] : replacements) {
if (before->name_hint == name) {
if (after->name_hint != name) {
func = WithAttr(func, tvm::attr::kGlobalSymbol, after->name_hint);
}
break;
}
}
}

return func;
});

Expand Down
Loading

0 comments on commit b8b5fb6

Please sign in to comment.