Skip to content

Commit

Permalink
[TVMScript] Preserve LetStmt of constants (#14531)
Browse files Browse the repository at this point in the history
* [TVMScript] Preserve LetStmt of constants

Prior to this commit, the `bind_assign_value` implementation for TIR
would treat assignment of constants (e.g. `j = 42`) as meta-variables
to be inserted at their point of use.  This commit updates the parsing
to treat `j = 42` as a shorthand for the TIR `LetStmt`, similar to the
update made in #14320.  This
behavior is more consistent with the other uses of `bind_assign_value`
as assignment, with `j = T.meta_var(42)` used to represent
meta-variables.

* Updated tests that relied on implicit meta_var
  • Loading branch information
Lunderberg authored May 5, 2023
1 parent 1294926 commit ddd2e81
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
4 changes: 2 additions & 2 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,14 +143,14 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
):
IRBuilder.name(var_name, value)
return value
elif isinstance(value, PrimExpr):
else:
value = tvm.runtime.convert(value)
frame = T.LetStmt(value)
var = frame.var
IRBuilder.name(var_name, var)
frame.add_callback(partial(frame.__exit__, None, None, None))
frame.__enter__()
return var
return value


@dispatch.register(token="tir", type_name="For")
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_arith_domain_touched.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
@T.prim_func
def scalar_func(a: T.handle, b: T.handle):
m = T.int32()
n = 100
n = T.meta_var(100)
A = T.match_buffer(a, (n, m))
B = T.match_buffer(b, (n, m))

Expand Down
19 changes: 17 additions & 2 deletions tests/python/unittest/test_tvmscript_syntax_sugar.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ def func_without_type_annotation(A: T.Buffer((1,), "int32")):
def test_letstmt_bind_with_constant():
@T.prim_func
def constant_binds():
x = 1
y = 42.0
x = T.meta_var(1)
y = T.meta_var(42.0)
T.evaluate(T.cast(x, "float32") + y)

@T.prim_func
Expand Down Expand Up @@ -414,6 +414,21 @@ def implicit(i: T.int32):
assert_structural_equal(implicit, explicit)


def test_preserve_trivial_let_binding_of_value():
@T.prim_func
def explicit(i: T.int32):
j = T.int32()
T.LetStmt(42, var=j)
T.evaluate(j)

@T.prim_func
def implicit(i: T.int32):
j = 42
T.evaluate(j)

assert_structural_equal(implicit, explicit)


def test_preserve_parameter_name():
@T.prim_func
def func(i: T.int32):
Expand Down

0 comments on commit ddd2e81

Please sign in to comment.