From 36f4805a92f3110f05022c5d98ad346420d68d35 Mon Sep 17 00:00:00 2001 From: Eldritch Cheese Date: Mon, 30 Jan 2023 15:43:38 -0600 Subject: [PATCH 1/2] [TVMScript] Preserve LetStmt of constants Prior to this commit, the `bind_assign_value` implementation for TIR would treat assignment of constants (e.g. `j = 42`) as meta-variables to be inserted at their point of use. This commit updates the parsing to treat `j = 42` as a shorthand for the TIR `LetStmt`, similar to the update made in https://github.com/apache/tvm/pull/14320. This behavior is more consistent with the other uses of `bind_assign_value` as assignment, with `j = T.meta_var(42)` used to represent meta-variables. --- python/tvm/script/parser/tir/parser.py | 4 ++-- .../unittest/test_tvmscript_syntax_sugar.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 8a067267a352..cb26399d2efb 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -142,14 +142,14 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - ): IRBuilder.name(var_name, value) return value - elif isinstance(value, PrimExpr): + else: + value = tvm.runtime.convert(value) frame = T.LetStmt(value) var = frame.var IRBuilder.name(var_name, var) frame.add_callback(partial(frame.__exit__, None, None, None)) frame.__enter__() return var - return value @dispatch.register(token="tir", type_name="For") diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index ac1262b9b517..26c9b85a5172 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -414,6 +414,21 @@ def implicit(i: T.int32): assert_structural_equal(implicit, explicit) +def test_preserve_trivial_let_binding_of_value(): + @T.prim_func + def explicit(i: T.int32): + j = T.int32() + T.LetStmt(42, var=j) + T.evaluate(j) + + @T.prim_func + def implicit(i: T.int32): + j = 42 + T.evaluate(j) + + assert_structural_equal(implicit, explicit) + + def test_preserve_parameter_name(): @T.prim_func def func(i: T.int32): From 62e25bd51e2b0656815fac5dcc3de200a243680b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 7 Apr 2023 14:16:11 -0500 Subject: [PATCH 2/2] Updated tests that relied on implicit meta_var --- tests/python/unittest/test_arith_domain_touched.py | 2 +- tests/python/unittest/test_tvmscript_syntax_sugar.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index e19991b3b83a..1553aabd4e4c 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -22,7 +22,7 @@ @T.prim_func def scalar_func(a: T.handle, b: T.handle): m = T.int32() - n = 100 + n = T.meta_var(100) A = T.match_buffer(a, (n, m)) B = T.match_buffer(b, (n, m)) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index 26c9b85a5172..1ff5be80cabc 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -253,8 +253,8 @@ def func_without_type_annotation(A: T.Buffer((1,), "int32")): def test_letstmt_bind_with_constant(): @T.prim_func def constant_binds(): - x = 1 - y = 42.0 + x = T.meta_var(1) + y = T.meta_var(42.0) T.evaluate(T.cast(x, "float32") + y) @T.prim_func