Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unity][Analysis] Include impure call in VerifyWellFormed errors #16585

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,21 @@ TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func);
*/
TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func);

/*!
* \brief Check if the given expression (likely a function body) contains any impure calls.
* \param expr The expression to be examined. If expr is a function, we check the body.
* \param own_name (Optional.) If we are checking a recursive function body,
* the caller can pass the function's name so recursive calls
* can be ignored in the check (must be a Var or GlobalVar).
* \return The impure expression, if one exists within the given
* expression. Otherwise, NullOpt.
* \note Relies on StructInfo annotations, so ensure that the module has been normalized first.
* Also, an impure call in a *nested* function does *not* mean that the outer expression contains
* an impure call--it only does if the nested function is *later called*.
*/
TVM_DLL Optional<Expr> FindImpureCall(const Expr& expr,
const Optional<Expr>& own_name = Optional<Expr>(nullptr));

/*!
* \brief Check if the given expression (likely a function body) contains any impure calls.
* \param expr The expression to be examined. If expr is a function, we check the body.
Expand Down
44 changes: 29 additions & 15 deletions src/relax/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,23 @@ tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }

tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); }

bool ContainsImpureCall(const Expr& expr, const Optional<Expr>& own_name) {
Optional<Expr> FindImpureCall(const Expr& expr, const Optional<Expr>& own_name) {
class ImpureCallChecker : public ExprVisitor {
public:
static Optional<Expr> Check(const Expr& expr, const Optional<Expr>& own_name) {
ImpureCallChecker visitor(own_name);
visitor.VisitExpr(expr);
return visitor.impure_expr_;
}

private:
explicit ImpureCallChecker(const Optional<Expr>& own_name) : own_name_(own_name) {}

bool Check(const Expr& expr) {
contains_impure_ = false;
VisitExpr(expr);
return contains_impure_;
void VisitExpr(const Expr& expr) override {
// Early bail-out if we found an impure expression
if (!impure_expr_) {
ExprVisitor::VisitExpr(expr);
}
}

void VisitExpr_(const FunctionNode* func) override {
Expand All @@ -159,28 +167,34 @@ bool ContainsImpureCall(const Expr& expr, const Optional<Expr>& own_name) {

void VisitExpr_(const CallNode* call) override {
// ignore recursive calls if we find one
if (!(own_name_ && own_name_.value().same_as(call->op))) {
if (IsImpureCall(GetRef<Call>(call))) {
contains_impure_ = true;
}
bool is_recursive = (own_name_ && own_name_.value().same_as(call->op));
auto expr = GetRef<Call>(call);
if (!is_recursive && IsImpureCall(expr)) {
impure_expr_ = expr;
} else {
ExprVisitor::VisitExpr_(call);
}
ExprVisitor::VisitExpr_(call);
}

private:
const Optional<Expr>& own_name_;
bool contains_impure_ = false;
Optional<Expr> impure_expr_ = NullOpt;
};

if (own_name) {
ICHECK(own_name.value().as<VarNode>() || own_name.value().as<GlobalVarNode>())
<< "Must pass a Var or GlobalVar for own_name";
}
ImpureCallChecker checker(own_name);
if (auto func = expr.as<FunctionNode>()) {
return checker.Check(func->body);

Expr to_check = expr;
if (auto func = to_check.as<FunctionNode>()) {
to_check = func->body;
}
return checker.Check(expr);
return ImpureCallChecker::Check(to_check, own_name);
}

bool ContainsImpureCall(const Expr& expr, const Optional<Expr>& own_name) {
return FindImpureCall(expr, own_name).defined();
}

TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars);
Expand Down
23 changes: 13 additions & 10 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,13 +261,14 @@ class WellFormedChecker : public relax::ExprVisitor,
// if we are not forcing purity and the function is annotated as pure, it must not contain an
// impure call
if (check_struct_info_ &&
!op->GetAttr<Bool>(relax::attr::kForcePure).value_or(Bool(false))->value && op->is_pure &&
ContainsImpureCall(op->body)) {
Malformed(Diagnostic::Error(op)
<< "Function " << op << " is annotated as pure but contains an impure call; "
<< "please set " << relax::attr::kForcePure << " to true "
<< "or use a pure operator variant (e.g., call_pure_packed) "
<< "if it is necessary to override this judgment.");
!op->GetAttr<Bool>(relax::attr::kForcePure).value_or(Bool(false))->value && op->is_pure) {
if (auto impure = FindImpureCall(op->body)) {
Malformed(Diagnostic::Error(op)
<< "Function " << op << " is annotated as pure but contains an impure call: "
<< impure << ". Please set " << relax::attr::kForcePure << " to true "
<< "or use a pure operator variant (e.g., call_pure_packed) "
<< "if it is necessary to override this judgment.");
}
}

if (auto seq = op->body.as<SeqExprNode>()) {
Expand Down Expand Up @@ -310,9 +311,11 @@ class WellFormedChecker : public relax::ExprVisitor,
}

CheckStructInfo(call);
if (is_dataflow_ && check_struct_info_ && IsImpureCall(GetRef<Call>(call))) {
Malformed(Diagnostic::Error(call)
<< "There cannot be an impure call inside a dataflow block.");
if (is_dataflow_ && check_struct_info_) {
if (auto impure = FindImpureCall(GetRef<Call>(call))) {
Malformed(Diagnostic::Error(call)
<< "Impure function call " << impure << " occurs within a dataflow block.");
}
}

// If the operation has defined a custom normalization function
Expand Down
5 changes: 4 additions & 1 deletion tests/python/relax/test_analysis_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ def test_force_pure_improper():
assert not rx.analysis.well_formed(mod)


def test_impure_in_dataflow_block():
def test_impure_in_dataflow_block(capfd):
# even if force_pure is set, an impure operation cannot appear in a dataflow block
x = rx.Var("x", R.Tensor((), dtype="int32"))
y = rx.DataflowVar("y")
Expand All @@ -618,6 +618,9 @@ def test_impure_in_dataflow_block():
mod = rx.transform.Normalize()(tvm.IRModule.from_expr(func))
assert not rx.analysis.well_formed(mod)

_stdout, stderr = capfd.readouterr()
assert "R.print" in stderr


if __name__ == "__main__":
tvm.testing.main()
Loading