From 86a5ceec271f241451b641d10b4c27e0cdeb1e89 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 15 Nov 2022 19:07:20 -0600 Subject: [PATCH] [TVMScript] Use tir::Evaluate if expression is in statement context (#13396) * [TVMScript] Use tir::Evaluate if expression is in statement context For the previous version of the parser, this was special-cased for some intrinsic operators. After the new TVMScript was enabled in https://github.com/apache/tvm/pull/12496, any `PrimExpr` that appears in the body of a statement is silently ignored. This commit updates the parser to instead wrap the bare `PrimExpr` in a `tir::Evaluate` node. This change effectively allows [expression statements](https://docs.python.org/3/reference/simple_stmts.html#expression-statements) in TVMScript, which are converted to `tir::Evaluate` nodes during parsing. * Update to print T.evaluate() for readability, except for CallNode --- python/tvm/script/parser/tir/parser.py | 5 ++++ src/printer/tvmscript_printer.cc | 19 +++++++------- .../unittest/test_tvmscript_roundtrip.py | 10 +++++++ .../unittest/test_tvmscript_syntax_sugar.py | 26 +++++++++++++++++++ 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 1370758f5a5b..0e74114ba29c 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -20,6 +20,7 @@ from functools import partial from typing import Any +import tvm from tvm.ir import PrimType from tvm.tir import Buffer, IterVar, PrimExpr, Var @@ -411,6 +412,10 @@ def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: if isinstance(res, Frame): res.add_callback(partial(res.__exit__, None, None, None)) res.__enter__() + elif isinstance(res, PrimExpr): + T.evaluate(res) + elif isinstance(res, (int, bool)): + T.evaluate(tvm.tir.const(res)) @dispatch.register(token="tir", type_name="If") diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index d7a3a406e352..f1d68ee43845 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1275,16 +1275,17 @@ Doc TVMScriptPrinter::VisitStmt_(const SeqStmtNode* op) { } Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { - if (auto* call = op->value.as()) { - if (call->op.same_as(builtin::assume())) { - Doc doc; - doc << tir_prefix_ << ".assume(" << Print(call->args[0]) << ")"; - return doc; - } - } - + // When parsing TVMScript, a PrimExpr that occurs as a statement is + // automatically wrapped in `tir::Evaluate`. Therefore, when + // printing, it's only necessary to print the value. For + // readability, though, we still print T.evaluate() when the + // expression is something other than a call node. Doc doc; - doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")"; + if (op->value.as()) { + doc << Print(op->value); + } else { + doc << tir_prefix_ << ".evaluate(" << Print(op->value) << ")"; + } return doc; } diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index f22e61e1838d..b8c8379c8a16 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3458,6 +3458,15 @@ def func() -> None: return func +def implicit_evaluate(): + @T.prim_func + def func(A: T.Buffer[1, "int32"]): + T.evaluate(T.assume(A[0] == 5)) + A[0] = 10 + + return func + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3509,6 +3518,7 @@ def func() -> None: bool_primitive, bool_cast, return_none, + implicit_evaluate, ) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 16f1cb04945a..a39354b9552a 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -402,5 +402,31 @@ def int64_grid_expanded( assert_structural_equal(int64_grid, int64_grid_expanded) +def test_implicit_evaluate_assume(): + @T.prim_func + def explicit(A: T.Buffer[1, "int32"]): + T.evaluate(T.assume(A[0] == 5)) + A[0] = 10 + + @T.prim_func + def implicit(A: T.Buffer[1, "int32"]): + T.assume(A[0] == 5) + A[0] = 10 + + assert_structural_equal(implicit, explicit) + + +def test_implicit_evaluate_call_extern(): + @T.prim_func + def explicit(A: T.Buffer[1, "int32"]): + T.evaluate(T.call_extern("extern_func", A.data, dtype="int32")) + + @T.prim_func + def implicit(A: T.Buffer[1, "int32"]): + T.call_extern("extern_func", A.data, dtype="int32") + + assert_structural_equal(implicit, explicit) + + if __name__ == "__main__": tvm.testing.main()