Skip to content

Commit

Permalink
[TVMScript] Use tir::Evaluate if expression is in statement context (#…
Browse files Browse the repository at this point in the history
…13396)

* [TVMScript] Use tir::Evaluate if expression is in statement context

For the previous version of the parser, this was special-cased for
some intrinsic operators.  After the new TVMScript was enabled in
#12496, any `PrimExpr` that appears
in the body of a statement is silently ignored.  This commit updates
the parser to instead wrap the bare `PrimExpr` in a `tir::Evaluate`
node.

This change effectively allows [expression
statements](https://docs.python.org/3/reference/simple_stmts.html#expression-statements)
in TVMScript, which are converted to `tir::Evaluate` nodes during
parsing.

* Update to print T.evaluate() for readability, except for CallNode
  • Loading branch information
Lunderberg authored Nov 16, 2022
1 parent 2bb3382 commit 86a5cee
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 9 deletions.
5 changes: 5 additions & 0 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from functools import partial
from typing import Any

import tvm
from tvm.ir import PrimType
from tvm.tir import Buffer, IterVar, PrimExpr, Var

Expand Down Expand Up @@ -411,6 +412,10 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
if isinstance(res, Frame):
res.add_callback(partial(res.__exit__, None, None, None))
res.__enter__()
elif isinstance(res, PrimExpr):
T.evaluate(res)
elif isinstance(res, (int, bool)):
T.evaluate(tvm.tir.const(res))


@dispatch.register(token="tir", type_name="If")
Expand Down
19 changes: 10 additions & 9 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1275,16 +1275,17 @@ Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) {
}

Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
if (auto* call = op->value.as<CallNode>()) {
if (call->op.same_as(builtin::assume())) {
Doc doc;
doc << tir_prefix_ << ".assume(" << Print(call->args[0]) << ")";
return doc;
}
}

// When parsing TVMScript, a PrimExpr that occurs as a statement is
// automatically wrapped in `tir::Evaluate`. Therefore, when
// printing, it's only necessary to print the value. For
// readability, though, we still print T.evaluate() when the
// expression is something other than a call node.
Doc doc;
doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")";
if (op->value.as<CallNode>()) {
doc << Print(op->value);
} else {
doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")";
}
return doc;
}

Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3458,6 +3458,15 @@ def func() -> None:
return func


def implicit_evaluate():
@T.prim_func
def func(A: T.Buffer[1, "int32"]):
T.evaluate(T.assume(A[0] == 5))
A[0] = 10

return func


ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
Expand Down Expand Up @@ -3509,6 +3518,7 @@ def func() -> None:
bool_primitive,
bool_cast,
return_none,
implicit_evaluate,
)


Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,5 +402,31 @@ def int64_grid_expanded(
assert_structural_equal(int64_grid, int64_grid_expanded)


def test_implicit_evaluate_assume():
@T.prim_func
def explicit(A: T.Buffer[1, "int32"]):
T.evaluate(T.assume(A[0] == 5))
A[0] = 10

@T.prim_func
def implicit(A: T.Buffer[1, "int32"]):
T.assume(A[0] == 5)
A[0] = 10

assert_structural_equal(implicit, explicit)


def test_implicit_evaluate_call_extern():
@T.prim_func
def explicit(A: T.Buffer[1, "int32"]):
T.evaluate(T.call_extern("extern_func", A.data, dtype="int32"))

@T.prim_func
def implicit(A: T.Buffer[1, "int32"]):
T.call_extern("extern_func", A.data, dtype="int32")

assert_structural_equal(implicit, explicit)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 86a5cee

Please sign in to comment.