Skip to content

Commit

Permalink
[Analysis] Allow calls to GlobalVar in @R.function (#16778)
Browse files Browse the repository at this point in the history
* [Analysis] Allow calls to GlobalVar in @R.function

Prior to this commit, the post-parsing well-formed check performed by
TVMScript allowed a call to `GlobalVar` in a `@R.function`, but only
if it occurred within the context of a `@I.ir_module`.  If
`@R.function` appeared on its own, calls to a `GlobalVar` would be
treated as calls to an undefined function.

* Use approrpirate well-formed checks TIR/Relax functions

* Lint fix

* Import order fix
  • Loading branch information
Lunderberg authored Mar 26, 2024
1 parent ae7b8d9 commit 72f0326
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 36 deletions.
6 changes: 3 additions & 3 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -547,15 +547,15 @@ TVM_DLL bool ContainsImpureCall(const Expr& expr,
/*!
* \brief Check if the IRModule is well formed.
*
* \param m the IRModule to check.
* \param obj The IRModule or relax::Function to check.
* \param check_struct_info A boolean flag indicating if the property "every Expr
* must have defined structure info" will be checked.
* \return true if the IRModule is well formed, false if not.
* \return true if the object is well formed, false if not.
* \note By default the structure info is always checked. It is only in test cases
* where `check_struct_info` might be false, so that other well-formed requirements
* will be well tested and will not be blocked by not having structure info.
*/
TVM_DLL bool WellFormed(IRModule m, bool check_struct_info = true);
TVM_DLL bool WellFormed(Variant<IRModule, Function> obj, bool check_struct_info = true);

/*!
* \brief Using the layout transforms on the outputs, suggest layout transformation on the blocks
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/relax/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,13 +434,13 @@ def remove_all_unused(func: Function) -> Function:
return _ffi_api.remove_all_unused(func) # type: ignore


def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool:
def well_formed(obj: Union[IRModule, Function], check_struct_info: bool = True) -> bool:
"""Check if the IRModule is well formed.
Parameters
----------
mod : tvm.IRModule
The input IRModule.
obj : Union[tvm.IRModule, Function]
The input IRModule or relax.Function.
check_struct_info : bool
A boolean flag indicating if the property "every Expr must
Expand All @@ -457,7 +457,7 @@ def well_formed(mod: IRModule, check_struct_info: bool = True) -> bool:
where `check_struct_info` might be false, so that other well-formed requirements
will be well tested and will not be blocked by not having structure info.
"""
return _ffi_api.well_formed(mod, check_struct_info) # type: ignore
return _ffi_api.well_formed(obj, check_struct_info) # type: ignore


def _get_prim_func_default_dtype(func: PrimFunc):
Expand Down
26 changes: 17 additions & 9 deletions python/tvm/script/parser/core/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import inspect
from typing import Any, Dict, Union

import tvm
from ....ir.module import IRModule
from ...ir_builder import IRBuilder
from . import doc
Expand All @@ -34,12 +35,19 @@


def _default_globals() -> Dict[str, Any]:
import tvm # pylint: disable=import-outside-toplevel
from tvm.script.parser import ir # pylint: disable=import-outside-toplevel
from tvm.script.parser import relax # pylint: disable=import-outside-toplevel
from tvm.script.parser import tir # pylint: disable=import-outside-toplevel

extra_vars = {"tvm": tvm, "I": ir, "ir": ir, "T": tir, "tir": tir, "R": relax, "relax": relax}
extra_vars = {
"tvm": tvm,
"I": ir,
"ir": ir,
"T": tir,
"tir": tir,
"R": relax,
"relax": relax,
}
return extra_vars


Expand Down Expand Up @@ -95,19 +103,19 @@ def parse(
ret = builder.get()
# check well-formedness in both Relax and TIR
if check_well_formed:
# (C0415 = import-outside-toplevel. It is necessary here to avoid a circular dependency,
# since importing Relax imports a dependency on the parser)
from ....relax.analysis import well_formed as relax_well_formed # pylint: disable=C0415
from ....tir.analysis import verify_well_formed as tir_well_formed # pylint: disable=C0415

check_ret = ret
if not isinstance(check_ret, IRModule):
check_ret = IRModule.from_expr(ret)

source_ast = source.as_ast()
if not relax_well_formed(check_ret):

if isinstance(ret, (IRModule, tvm.relax.Function)) and not tvm.relax.analysis.well_formed(
ret
):
parser.report_error(source_ast, err=WELL_FORMED_ERROR_MESSAGE)

try:
tir_well_formed(check_ret)
tvm.tir.analysis.verify_well_formed(check_ret)
except Exception as err: # pylint: disable=broad-exception-caught
parser.report_error(
source_ast,
Expand Down
47 changes: 27 additions & 20 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,30 @@ class WellFormedChecker : public relax::ExprVisitor,
public relax::StructInfoVisitor,
public tir::ExprVisitor {
public:
static bool Check(IRModule mod, bool check_struct_info) {
WellFormedChecker well_formed_checker = WellFormedChecker(mod, check_struct_info);

for (const auto& it : mod->functions) {
// visit relax.Function
if (auto* n = it.second.as<FunctionNode>()) {
Function func = GetRef<Function>(n);
well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func);
well_formed_checker.VisitExpr(func);
static bool Check(Variant<IRModule, Function> obj, bool check_struct_info) {
WellFormedChecker well_formed_checker =
WellFormedChecker(obj.as<IRModule>(), check_struct_info);

if (const auto* mod = obj.as<IRModuleNode>()) {
for (const auto& it : mod->functions) {
// visit relax.Function
if (auto* n = it.second.as<FunctionNode>()) {
Function func = GetRef<Function>(n);
well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func);
well_formed_checker.VisitExpr(func);
}
}
} else if (const auto* func = obj.as<FunctionNode>()) {
well_formed_checker.VisitExpr(GetRef<Expr>(func));
} else {
LOG(FATAL) << "Unreachable, "
<< "variant did not contain any of the allowed types";
}
return well_formed_checker.well_formed_;
}

private:
explicit WellFormedChecker(IRModule mod, bool check_struct_info)
WellFormedChecker(Optional<IRModule> mod, bool check_struct_info)
: mod_(std::move(mod)), check_struct_info_(check_struct_info), cur_visited_func_(nullptr) {}

using relax::ExprVisitor::VisitExpr_;
Expand Down Expand Up @@ -147,9 +155,11 @@ class WellFormedChecker : public relax::ExprVisitor,

void VisitExpr_(const GlobalVarNode* op) final {
GlobalVar var = GetRef<GlobalVar>(op);
if (!(mod_->ContainGlobalVar(var->name_hint) &&
mod_->GetGlobalVar(var->name_hint).same_as(var))) {
Malformed(Diagnostic::Error(var) << "GlobalVar " << GetRef<Expr>(op) << " is not defined.");
if (mod_.defined()) {
if (!(mod_.value()->ContainGlobalVar(var->name_hint) &&
mod_.value()->GetGlobalVar(var->name_hint).same_as(var))) {
Malformed(Diagnostic::Error(var) << "GlobalVar " << GetRef<Expr>(op) << " is not defined.");
}
}

if (op->checked_type_.defined()) {
Expand Down Expand Up @@ -556,7 +566,7 @@ class WellFormedChecker : public relax::ExprVisitor,
std::swap(mode_, mode);
}

IRModule mod_;
Optional<IRModule> mod_;
const bool check_struct_info_;
bool well_formed_ = true;
bool is_dataflow_;
Expand All @@ -576,14 +586,11 @@ class WellFormedChecker : public relax::ExprVisitor,
tvm::OpAttrMap<FNormalize> op_map_normalize_ = Op::GetAttrMap<FNormalize>("FNormalize");
};

bool WellFormed(IRModule m, bool check_struct_info) {
return WellFormedChecker::Check(std::move(m), check_struct_info);
bool WellFormed(Variant<IRModule, Function> obj, bool check_struct_info) {
return WellFormedChecker::Check(obj, check_struct_info);
}

TVM_REGISTER_GLOBAL(("relax.analysis.well_formed"))
.set_body_typed([](IRModule m, bool check_struct_info) {
return WellFormed(m, check_struct_info);
});
TVM_REGISTER_GLOBAL(("relax.analysis.well_formed")).set_body_typed(WellFormed);

} // namespace relax
} // namespace tvm
34 changes: 34 additions & 0 deletions tests/python/relax/test_analysis_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tvm import relax as rx
from tvm import tir
from tvm.script import relax as R
from tvm.script import ir as I
from tvm.script import tir as T

m = tir.Var("m", "int64")
Expand Down Expand Up @@ -622,5 +623,38 @@ def test_impure_in_dataflow_block(capfd):
assert "R.print" in stderr


def test_well_formed_function():
"""Relax's well-formed check can be applied on a function"""

@R.function
def func(A: R.Tensor([16, 32], "float32"), B: R.Tensor([32, 64], "float32")):
return R.matmul(A, B)

assert rx.analysis.well_formed(func)


def test_well_formed_function_referencing_global_var():
"""GlobalVar may refer to other functions in the module
If validating that a IRModule is well-formed, the GlobalVar must
have a definition. If validating that a relax.Function is
well-formed, no GlobalVar definitions are available.
"""

@I.ir_module
class Module:
@R.function
def main(A: R.Tensor([16, 32], "float32"), B: R.Tensor([32, 64], "float32")):
return Module.subroutine(A, B)

@R.function(private=True)
def subroutine(A: R.Tensor([16, 32], "float32"), B: R.Tensor([32, 64], "float32")):
return R.matmul(A, B)

assert rx.analysis.well_formed(Module)
assert rx.analysis.well_formed(Module["main"])
assert rx.analysis.well_formed(Module["subroutine"])


if __name__ == "__main__":
tvm.testing.main()
37 changes: 37 additions & 0 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2091,5 +2091,42 @@ def func(a: R.Tensor((10, 10))) -> R.Tensor((10, 10)):
_check(parsed_module, expected)


def test_define_relax_function_using_global_var():
"""A @R.function may call a GlobalVar
When parsing a @R.function, the function's body may reference
GlobalVar instances available in the calling python scope. The
resulting function should pass TVMScript's well-formed check, as
the GlobalVar may be available in the IRModule for which the
function is being defined.
"""

@I.ir_module
class DefinedAllAtOnce:
@R.function
def main(A: R.Tensor, B: R.Tensor):
return DefinedAllAtOnce.subroutine(A, B)

@R.function(private=True)
def subroutine(A: R.Tensor, B: R.Tensor) -> R.Tensor:
return R.matmul(A, B)

@I.ir_module
class MainDefinedLater:
@R.function(private=True)
def subroutine(A: R.Tensor, B: R.Tensor) -> R.Tensor:
return R.matmul(A, B)

subroutine_gvar = MainDefinedLater.get_global_var("subroutine")

@R.function
def main(A: R.Tensor, B: R.Tensor):
return subroutine_gvar(A, B)

MainDefinedLater["main"] = main

tvm.ir.assert_structural_equal(DefinedAllAtOnce, MainDefinedLater)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 72f0326

Please sign in to comment.