diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index 63171f672289..ea26c4740a46 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -143,14 +143,14 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - ): IRBuilder.name(var_name, value) return value - elif isinstance(value, PrimExpr): + else: + value = tvm.runtime.convert(value) frame = T.LetStmt(value) var = frame.var IRBuilder.name(var_name, var) frame.add_callback(partial(frame.__exit__, None, None, None)) frame.__enter__() return var - return value @dispatch.register(token="tir", type_name="For") diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index e19991b3b83a..1553aabd4e4c 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -22,7 +22,7 @@ @T.prim_func def scalar_func(a: T.handle, b: T.handle): m = T.int32() - n = 100 + n = T.meta_var(100) A = T.match_buffer(a, (n, m)) B = T.match_buffer(b, (n, m)) diff --git a/tests/python/unittest/test_tvmscript_syntax_sugar.py b/tests/python/unittest/test_tvmscript_syntax_sugar.py index ac1262b9b517..1ff5be80cabc 100644 --- a/tests/python/unittest/test_tvmscript_syntax_sugar.py +++ b/tests/python/unittest/test_tvmscript_syntax_sugar.py @@ -253,8 +253,8 @@ def func_without_type_annotation(A: T.Buffer((1,), "int32")): def test_letstmt_bind_with_constant(): @T.prim_func def constant_binds(): - x = 1 - y = 42.0 + x = T.meta_var(1) + y = T.meta_var(42.0) T.evaluate(T.cast(x, "float32") + y) @T.prim_func @@ -414,6 +414,21 @@ def implicit(i: T.int32): assert_structural_equal(implicit, explicit) +def test_preserve_trivial_let_binding_of_value(): + @T.prim_func + def explicit(i: T.int32): + j = T.int32() + T.LetStmt(42, var=j) + T.evaluate(j) + + @T.prim_func + def implicit(i: T.int32): + j = 42 + T.evaluate(j) + + assert_structural_equal(implicit, explicit) + + def test_preserve_parameter_name(): @T.prim_func def func(i: T.int32):