Skip to content

Commit

Permalink
[TVMScript] Distinguish LetStmt and Let expression
Browse files Browse the repository at this point in the history
As the legacy behavior of previous generation of TVMScript printer,
LetStmt and Let expression in TIR are printed in the same syntax, i.e.
`T.let`. This could be confusing and misleading at times.

This PR introduces a new printer behavior in a backward compatible way.
While ensuring all legacy `T.let` can be parsed properly, we introduce
different syntax for each of them.

For LetStmt in TIR without concise scoping, the new syntax is:

```python
with T.LetStmt(value) as var:
  ...
```

which was:

```python
var = T.int32()
with T.let(var, value):
  ...
```

For let expression in TIR PrimExpr, the new syntax becomes:

```python
x = T.int32()
T.Let(x + 1, where={x : 1})
```

which was:

```python
x = T.int32()
T.let(x, 1, x + 1)
```
  • Loading branch information
Ubuntu authored and junrushao committed Mar 6, 2023
1 parent befdc4e commit 488eaed
Show file tree
Hide file tree
Showing 14 changed files with 149 additions and 77 deletions.
8 changes: 6 additions & 2 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,15 @@ AssertFrame Assert(PrimExpr condition, String message);

/*!
* \brief The let binding.
* \param var The variable to bind.
* \param value The value to be bound.
* \param type_annotation The type annotation of the let binding.
* Usually it is used for fine-grained var typing,
* particularly, PointerType.
* \param var The variable to be bound. If not specified, a new variable will be created.
* \return The created LetFrame.
*/
LetFrame Let(Var var, PrimExpr value);
LetFrame LetStmt(PrimExpr value, Optional<Type> type_annotation = NullOpt,
Optional<Var> var = NullOpt);

/*!
* \brief The realization.
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ class AssertFrame(TIRFrame):

@_register_object("script.ir_builder.tir.LetFrame")
class LetFrame(TIRFrame):
...
def __enter__(self) -> Var:
super().__enter__()
return self.var


@_register_object("script.ir_builder.tir.RealizeFrame")
Expand Down
60 changes: 56 additions & 4 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import numpy as np # type: ignore

from tvm import tir
from tvm.ir import Range, Type
from tvm.ir.base import deprecated
from tvm.runtime import convert, ndarray
Expand Down Expand Up @@ -61,7 +62,6 @@
FloorMod,
IntImm,
IterVar,
Let,
Load,
Max,
Min,
Expand Down Expand Up @@ -857,6 +857,47 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame: # pylint: d
return _ffi_api.Assert(condition, message) # type: ignore[attr-defined] # pylint: disable=no-member


def LetStmt( # pylint: disable=invalid-name
value: PrimExpr,
type_annotation: Optional[Type] = None, # pylint: disable=redefined-outer-name
*,
var: Optional[Var] = None, # pylint: disable=redefined-outer-name
) -> frame.LetFrame:
"""Create a LetStmt binding
Parameters
----------
value : PrimExpr
The value to be bound.
type_annotation : Optional[Type] = None
The type annotation of the let binding. Usually it is used for fine-grained var typing,
particularly, PointerType.
var : Optional[Var] = None
The variable to bind. If not specified, a new variable will be created.
Returns
-------
let_frame : frame.LetFrame
The result LetFrame.
"""
if type_annotation is not None:
if callable(type_annotation):
type_annotation = type_annotation()
if isinstance(type_annotation, Var):
type_annotation = type_annotation.type_annotation
return _ffi_api.LetStmt(value, type_annotation, var) # type: ignore[attr-defined] # pylint: disable=no-member


def Let( # pylint: disable=invalid-name
expr: PrimExpr,
where: Dict[Var, PrimExpr], # pylint: disable=redefined-outer-name
) -> PrimExpr:
"""Create a Let expression binding"""
assert len(where) == 1, "T.Let only allows `where` to have exactly one element"
var, value = list(where.items())[0] # pylint: disable=redefined-outer-name
return tir.Let(var, value, expr)


def let(
v: Var,
value: PrimExpr,
Expand All @@ -880,9 +921,19 @@ def let(
res : frame.LetFrame
The result LetFrame.
"""

@deprecated("T.let", "T.Let")
def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr:
return tir.Let(v, value, body)

@deprecated("T.let", "T.LetStmt")
def let_stmt(v: Var, value: PrimExpr) -> frame.LetFrame:
return _ffi_api.LegacyLetStmt(v, value) # type: ignore[attr-defined] # pylint: disable=no-member

if body is None:
return _ffi_api.Let(v, value) # type: ignore[attr-defined] # pylint: disable=no-member
return Let(v, value, body)
return let_stmt(v, value)
else:
return let_expr(v, value, body)


def realize(
Expand Down Expand Up @@ -1850,7 +1901,6 @@ def wrapped(*args, **kwargs):
"thread_binding",
"grid",
"Assert",
"let",
"realize",
"allocate",
"allocate_const",
Expand Down Expand Up @@ -2028,6 +2078,8 @@ def wrapped(*args, **kwargs):
"Shuffle",
"Call",
"CallEffectKind",
"let",
"LetStmt",
"Let",
"IterVar",
"CommReducer",
Expand Down
7 changes: 3 additions & 4 deletions python/tvm/script/parser/tir/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,8 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -
IRBuilder.name(var_name, value)
return value
elif isinstance(value, PrimExpr):
var = Var("", value.dtype)
IRBuilder.name(var_name, var)
frame = T.let(var, value)
frame = T.LetStmt(value)
var = frame.var
frame.add_callback(partial(frame.__exit__, None, None, None))
frame.__enter__()
return var
Expand Down Expand Up @@ -294,7 +293,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None:
if not isinstance(ann_var, Var):
self.report_error(node.annotation, "Annotation should be Var")
self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value)
frame = T.let(ann_var, rhs)
frame = T.LetStmt(rhs, var=ann_var)
frame.add_callback(partial(frame.__exit__, None, None, None))
frame.__enter__()

Expand Down
18 changes: 16 additions & 2 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,20 @@ AssertFrame Assert(PrimExpr condition, String message) {
return AssertFrame(n);
}

LetFrame Let(Var var, PrimExpr value) {
LetFrame LetStmt(PrimExpr value, Optional<Type> type_annotation, Optional<Var> var) {
ObjectPtr<LetFrameNode> n = make_object<LetFrameNode>();
if (var.defined()) {
n->var = var.value();
} else if (type_annotation.defined()) {
n->var = Var("v", type_annotation.value());
} else {
n->var = Var("v", value.dtype());
}
n->value = value;
return LetFrame(n);
}

LetFrame LegacyLetStmt(Var var, PrimExpr value) {
ObjectPtr<LetFrameNode> n = make_object<LetFrameNode>();
n->var = var;
n->value = value;
Expand Down Expand Up @@ -634,7 +647,8 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(Thread
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid);

TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let").set_body_typed(Let);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.LetStmt").set_body_typed(LetStmt);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.LegacyLetStmt").set_body_typed(LegacyLetStmt);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize);
Expand Down
9 changes: 4 additions & 5 deletions src/script/printer/tir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -211,11 +211,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Let>("", [](tir::Let let, ObjectPath p, IRDocsifier d) -> Doc {
return TIR(d, "let")->Call({
d->AsDoc<ExprDoc>(let->var, p->Attr("var")),
d->AsDoc<ExprDoc>(let->value, p->Attr("value")),
d->AsDoc<ExprDoc>(let->body, p->Attr("body")),
});
DictDoc where({d->AsDoc<ExprDoc>(let->var, p->Attr("var"))},
{d->AsDoc<ExprDoc>(let->value, p->Attr("value"))});
return TIR(d, "Let")->Call({d->AsDoc<ExprDoc>(let->body, p->Attr("body"))}, //
{"where"}, {where});
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down
45 changes: 25 additions & 20 deletions src/script/printer/tir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,30 +57,35 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::LetStmt>("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
if (concise && !d->IsVarDefined(stmt->var)) {
ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
With<TIRFrame> f(d, stmt);
ExprDoc lhs = DefineVar(stmt->var, *f, d);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
Array<StmtDoc>* stmts = &(*f)->stmts;
Type type = stmt->var->type_annotation;
Optional<ExprDoc> type_doc =
d->AsDoc<ExprDoc>(type, p->Attr("var")->Attr("type_annotation"));
if (const auto* tuple_type = type.as<TupleTypeNode>()) {
if (tuple_type->fields.empty()) {
type_doc = NullOpt;
}
// Step 1. Type annotation
Optional<ExprDoc> type_doc = d->AsDoc<ExprDoc>(stmt->var->type_annotation, //
p->Attr("var")->Attr("type_annotation"));
if (const auto* tuple_type = stmt->var->type_annotation.as<TupleTypeNode>()) {
if (tuple_type->fields.empty()) {
type_doc = NullOpt;
}
}
// Step 2. RHS
ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
// Step 3. LHS and body
With<TIRFrame> f(d, stmt);
Array<StmtDoc>* stmts = &(*f)->stmts;
bool var_defined = d->IsVarDefined(stmt->var);
if (!var_defined) {
DefineVar(stmt->var, *f, d);
}
ExprDoc lhs = d->AsDoc<ExprDoc>(stmt->var, p->Attr("var"));
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
// Step 4. Dispatch
if (var_defined) {
return ScopeDoc(NullOpt, TIR(d, "LetStmt")->Call({rhs}, {"var"}, {lhs}), *stmts);
} else if (concise) {
stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc));
return StmtBlockDoc(*stmts);
} else if (type_doc.defined() && !stmt->var->type_annotation->IsInstance<PrimTypeNode>()) {
return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs, type_doc.value()}), *stmts);
} else {
ExprDoc lhs = d->AsDoc<ExprDoc>(stmt->var, p->Attr("var"));
ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
With<TIRFrame> f(d, stmt);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
Array<StmtDoc>* stmts = &(*f)->stmts;
rhs = TIR(d, "let")->Call({lhs, rhs});
return ScopeDoc(NullOpt, rhs, *stmts);
return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs}), *stmts);
}
});

Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/aot/test_pass_aot_lower_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def func(a: T.handle, output: T.handle) -> None:
tmp_write: T.handle("uint8") = output_buffer.data
tmp_write_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_write)
for i in T.serial(140):
tmp_write_1[i] = T.let(tmp_read, a_buffer.data, tmp_read_1[i])
tmp_write_1[i] = T.Let(tmp_read_1[i], where={tmp_read : a_buffer.data})
# fmt: on

_assert_lowered_main(mod, func, CallType.CPacked)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tvm.ir.module import IRModule
from tvm.script import tir as T


# -----------------------------------------------------
# Basic test for the expected Behavior of the CSE pass
# -----------------------------------------------------
Expand Down Expand Up @@ -359,8 +360,7 @@ def func_distributivity_expected(
i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
) -> None:
B = T.Buffer((50,), "int32")
cse_var_1 = T.int32()
with T.let(cse_var_1, x * y + x * z):
with T.LetStmt(x * y + x * z) as cse_var_1:
B[i1] = cse_var_1
B[i2] = cse_var_1

Expand All @@ -377,8 +377,7 @@ def func_associativity_expected(
i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
) -> None:
B = T.Buffer((50,), "int32")
cse_var_1 = T.int32()
with T.let(cse_var_1, (x + y) + z):
with T.LetStmt((x + y) + z) as cse_var_1:
B[i1] = cse_var_1
B[i2] = cse_var_1

Expand Down
9 changes: 4 additions & 5 deletions tests/python/unittest/test_tir_transform_hoist_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import tir
import tvm.testing

from tvm import tir
from tvm.script import tir as T
from tvm.tir.transform import HoistExpression, HoistedConditionals, HoistedLetBindings
from tvm.tir.transform import HoistedConditionals, HoistedLetBindings, HoistExpression


class BaseBeforeAfter:
Expand Down Expand Up @@ -448,7 +447,7 @@ class TestHoistLetExpr(BaseBeforeAfter):
def before(A: T.Buffer((4, 4), "float32")):
for i, j in T.grid(4, 4):
x = T.float32()
A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32"))
A[i, j] = T.Let(5.0 * x + T.cast(j, "float32"), where={x: T.cast(i + 1, "float32")})

@T.prim_func
def expected(A: T.Buffer((4, 4), "float32")):
Expand All @@ -467,7 +466,7 @@ class TestSuppressHoistLetExpr(BaseBeforeAfter):
def before(A: T.Buffer((4, 4), "float32")):
for i, j in T.grid(4, 4):
x = T.float32()
A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32"))
A[i, j] = T.Let(5.0 * x + T.cast(j, "float32"), where={x: T.cast(i + 1, "float32")})

expected = before

Expand Down
Loading

0 comments on commit 488eaed

Please sign in to comment.