diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h
index d5cc1de5c675..2b89d0e736e8 100644
--- a/include/tvm/script/ir_builder/tir/ir.h
+++ b/include/tvm/script/ir_builder/tir/ir.h
@@ -283,11 +283,15 @@ AssertFrame Assert(PrimExpr condition, String message);
 
 /*!
  * \brief The let binding.
- * \param var The variable to bind.
  * \param value The value to be bound.
+ * \param type_annotation The type annotation of the let binding.
+ * Usually it is used for fine-grained var typing,
+ * particularly, PointerType.
+ * \param var The variable to be bound. If not specified, a new variable will be created.
  * \return The created LetFrame.
 */
-LetFrame Let(Var var, PrimExpr value);
+LetFrame LetStmt(PrimExpr value, Optional<Type> type_annotation = NullOpt,
+                 Optional<Var> var = NullOpt);
 
 /*!
 * \brief The realization.
diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py
index a57c878bd929..3e453f2e5183 100644
--- a/python/tvm/script/ir_builder/tir/frame.py
+++ b/python/tvm/script/ir_builder/tir/frame.py
@@ -57,7 +57,9 @@ class AssertFrame(TIRFrame):
 
 @_register_object("script.ir_builder.tir.LetFrame")
 class LetFrame(TIRFrame):
-    ...
+    def __enter__(self) -> Var:
+        super().__enter__()
+        return self.var
 
 
 @_register_object("script.ir_builder.tir.RealizeFrame")
diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py
index 5f4e9d4f2cf0..62a0aa8f32f7 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -28,6 +28,7 @@
 import numpy as np  # type: ignore
 
+from tvm import tir
 from tvm.ir import Range, Type
 from tvm.ir.base import deprecated
 from tvm.runtime import convert, ndarray
@@ -61,7 +62,6 @@
     FloorMod,
     IntImm,
     IterVar,
-    Let,
     Load,
     Max,
     Min,
@@ -857,6 +857,47 @@ def Assert(condition: PrimExpr, message: str) -> frame.AssertFrame:  # pylint: d
     return _ffi_api.Assert(condition, message)  # type: ignore[attr-defined] # pylint: disable=no-member
 
 
+def LetStmt(  # pylint: disable=invalid-name
+    value: PrimExpr,
+    type_annotation: Optional[Type] = None,  # pylint: disable=redefined-outer-name
+    *,
+    var: Optional[Var] = None,  # pylint: disable=redefined-outer-name
+) -> frame.LetFrame:
+    """Create a LetStmt binding
+
+    Parameters
+    ----------
+    value : PrimExpr
+        The value to be bound.
+    type_annotation : Optional[Type] = None
+        The type annotation of the let binding. Usually it is used for fine-grained var typing,
+        particularly, PointerType.
+    var : Optional[Var] = None
+        The variable to bind. If not specified, a new variable will be created.
+
+    Returns
+    -------
+    let_frame : frame.LetFrame
+        The result LetFrame.
+    """
+    if type_annotation is not None:
+        if callable(type_annotation):
+            type_annotation = type_annotation()
+        if isinstance(type_annotation, Var):
+            type_annotation = type_annotation.type_annotation
+    return _ffi_api.LetStmt(value, type_annotation, var)  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def Let(  # pylint: disable=invalid-name
+    expr: PrimExpr,
+    where: Dict[Var, PrimExpr],  # pylint: disable=redefined-outer-name
+) -> PrimExpr:
+    """Create a Let expression binding"""
+    assert len(where) == 1, "T.Let only allows `where` to have exactly one element"
+    var, value = list(where.items())[0]  # pylint: disable=redefined-outer-name
+    return tir.Let(var, value, expr)
+
+
 def let(
     v: Var,
     value: PrimExpr,
@@ -880,9 +921,19 @@ def let(
     res : frame.LetFrame
         The result LetFrame.
""" + + @deprecated("T.let", "T.Let") + def let_expr(v: Var, value: PrimExpr, body: PrimExpr) -> PrimExpr: + return tir.Let(v, value, body) + + @deprecated("T.let", "T.LetStmt") + def let_stmt(v: Var, value: PrimExpr) -> frame.LetFrame: + return _ffi_api.LegacyLetStmt(v, value) # type: ignore[attr-defined] # pylint: disable=no-member + if body is None: - return _ffi_api.Let(v, value) # type: ignore[attr-defined] # pylint: disable=no-member - return Let(v, value, body) + return let_stmt(v, value) + else: + return let_expr(v, value, body) def realize( @@ -1850,7 +1901,6 @@ def wrapped(*args, **kwargs): "thread_binding", "grid", "Assert", - "let", "realize", "allocate", "allocate_const", @@ -2028,6 +2078,8 @@ def wrapped(*args, **kwargs): "Shuffle", "Call", "CallEffectKind", + "let", + "LetStmt", "Let", "IterVar", "CommReducer", diff --git a/python/tvm/script/parser/tir/parser.py b/python/tvm/script/parser/tir/parser.py index fbef1a969179..5796db40ec06 100644 --- a/python/tvm/script/parser/tir/parser.py +++ b/python/tvm/script/parser/tir/parser.py @@ -143,9 +143,8 @@ def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) - IRBuilder.name(var_name, value) return value elif isinstance(value, PrimExpr): - var = Var("", value.dtype) - IRBuilder.name(var_name, var) - frame = T.let(var, value) + frame = T.LetStmt(value) + var = frame.var frame.add_callback(partial(frame.__exit__, None, None, None)) frame.__enter__() return var @@ -294,7 +293,7 @@ def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: if not isinstance(ann_var, Var): self.report_error(node.annotation, "Annotation should be Var") self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) - frame = T.let(ann_var, rhs) + frame = T.LetStmt(rhs, var=ann_var) frame.add_callback(partial(frame.__exit__, None, None, None)) frame.__enter__() diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 30102b687722..a54f3d926fc9 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -381,7 +381,20 @@ AssertFrame Assert(PrimExpr condition, String message) { return AssertFrame(n); } -LetFrame Let(Var var, PrimExpr value) { +LetFrame LetStmt(PrimExpr value, Optional type_annotation, Optional var) { + ObjectPtr n = make_object(); + if (var.defined()) { + n->var = var.value(); + } else if (type_annotation.defined()) { + n->var = Var("v", type_annotation.value()); + } else { + n->var = Var("v", value.dtype()); + } + n->value = value; + return LetFrame(n); +} + +LetFrame LegacyLetStmt(Var var, PrimExpr value) { ObjectPtr n = make_object(); n->var = var; n->value = value; @@ -634,7 +647,8 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.ThreadBinding").set_body_typed(Thread TVM_REGISTER_GLOBAL("script.ir_builder.tir.Grid").set_body_typed(Grid); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Assert").set_body_typed(Assert); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.Let").set_body_typed(Let); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.LetStmt").set_body_typed(LetStmt); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.LegacyLetStmt").set_body_typed(LegacyLetStmt); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Allocate").set_body_typed(Allocate); TVM_REGISTER_GLOBAL("script.ir_builder.tir.AllocateConst").set_body_typed(AllocateConst); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Realize").set_body_typed(Realize); diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index f1435c487044..02ec269b0e73 100644 --- a/src/script/printer/tir/expr.cc 
+++ b/src/script/printer/tir/expr.cc
@@ -211,11 +211,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tir::Let>("", [](tir::Let let, ObjectPath p, IRDocsifier d) -> Doc {
-      return TIR(d, "let")->Call({
-          d->AsDoc<ExprDoc>(let->var, p->Attr("var")),
-          d->AsDoc<ExprDoc>(let->value, p->Attr("value")),
-          d->AsDoc<ExprDoc>(let->body, p->Attr("body")),
-      });
+      DictDoc where({d->AsDoc<ExprDoc>(let->var, p->Attr("var"))},
+                    {d->AsDoc<ExprDoc>(let->value, p->Attr("value"))});
+      return TIR(d, "Let")->Call({d->AsDoc<ExprDoc>(let->body, p->Attr("body"))},  //
+                                 {"where"}, {where});
     });
 
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc
index b730dd5606ba..92ad41edc9d5 100644
--- a/src/script/printer/tir/stmt.cc
+++ b/src/script/printer/tir/stmt.cc
@@ -57,30 +57,35 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<tir::LetStmt>("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
       bool concise = AllowConciseScoping(d);
-      if (concise && !d->IsVarDefined(stmt->var)) {
-        ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
-        With<TIRFrame> f(d, stmt);
-        ExprDoc lhs = DefineVar(stmt->var, *f, d);
-        AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
-        Array<StmtDoc>* stmts = &(*f)->stmts;
-        Type type = stmt->var->type_annotation;
-        Optional<ExprDoc> type_doc =
-            d->AsDoc<ExprDoc>(type, p->Attr("var")->Attr("type_annotation"));
-        if (const auto* tuple_type = type.as<TupleTypeNode>()) {
-          if (tuple_type->fields.empty()) {
-            type_doc = NullOpt;
-          }
+      // Step 1. Type annotation
+      Optional<ExprDoc> type_doc = d->AsDoc<ExprDoc>(stmt->var->type_annotation,  //
+                                                     p->Attr("var")->Attr("type_annotation"));
+      if (const auto* tuple_type = stmt->var->type_annotation.as<TupleTypeNode>()) {
+        if (tuple_type->fields.empty()) {
+          type_doc = NullOpt;
         }
+      }
+      // Step 2. RHS
+      ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
+      // Step 3. LHS and body
+      With<TIRFrame> f(d, stmt);
+      Array<StmtDoc>* stmts = &(*f)->stmts;
+      bool var_defined = d->IsVarDefined(stmt->var);
+      if (!var_defined) {
+        DefineVar(stmt->var, *f, d);
+      }
+      ExprDoc lhs = d->AsDoc<ExprDoc>(stmt->var, p->Attr("var"));
+      AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
+      // Step 4. Dispatch
+      if (var_defined) {
+        return ScopeDoc(NullOpt, TIR(d, "LetStmt")->Call({rhs}, {"var"}, {lhs}), *stmts);
+      } else if (concise) {
         stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc));
         return StmtBlockDoc(*stmts);
+      } else if (type_doc.defined() && !stmt->var->type_annotation->IsInstance<PointerTypeNode>()) {
+        return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs, type_doc.value()}), *stmts);
       } else {
-        ExprDoc lhs = d->AsDoc<ExprDoc>(stmt->var, p->Attr("var"));
-        ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
-        With<TIRFrame> f(d, stmt);
-        AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
-        Array<StmtDoc>* stmts = &(*f)->stmts;
-        rhs = TIR(d, "let")->Call({lhs, rhs});
-        return ScopeDoc(NullOpt, rhs, *stmts);
+        return ScopeDoc(lhs, TIR(d, "LetStmt")->Call({rhs}), *stmts);
       }
     });
diff --git a/tests/python/relay/aot/test_pass_aot_lower_main.py b/tests/python/relay/aot/test_pass_aot_lower_main.py
index bc58812cd67c..9667d2093757 100644
--- a/tests/python/relay/aot/test_pass_aot_lower_main.py
+++ b/tests/python/relay/aot/test_pass_aot_lower_main.py
@@ -187,7 +187,7 @@ def func(a: T.handle, output: T.handle) -> None:
         tmp_write: T.handle("uint8") = output_buffer.data
         tmp_write_1 = T.Buffer([T.uint64(140)], dtype="uint8", data=tmp_write)
         for i in T.serial(140):
-            tmp_write_1[i] = T.let(tmp_read, a_buffer.data, tmp_read_1[i])
+            tmp_write_1[i] = T.Let(tmp_read_1[i], where={tmp_read : a_buffer.data})
     # fmt: on
 
     _assert_lowered_main(mod, func, CallType.CPacked)
diff --git a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
index 1755a66ec9fb..5ba2824e74dd 100644
--- a/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/unittest/test_tir_transform_common_subexpr_elim.py
@@ -22,6 +22,7 @@ from tvm.ir.module import IRModule
 from tvm.script import tir as T
 
+
 # -----------------------------------------------------
 # Basic test for the expected Behavior of the CSE pass
 # -----------------------------------------------------
@@ -359,8 +360,7 @@ def func_distributivity_expected(
         i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
     ) -> None:
         B = T.Buffer((50,), "int32")
-        cse_var_1 = T.int32()
-        with T.let(cse_var_1, x * y + x * z):
+        with T.LetStmt(x * y + x * z) as cse_var_1:
             B[i1] = cse_var_1
             B[i2] = cse_var_1
 
@@ -377,8 +377,7 @@ def func_associativity_expected(
         i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32
     ) -> None:
         B = T.Buffer((50,), "int32")
-        cse_var_1 = T.int32()
-        with T.let(cse_var_1, (x + y) + z):
+        with T.LetStmt((x + y) + z) as cse_var_1:
             B[i1] = cse_var_1
             B[i2] = cse_var_1
 
diff --git a/tests/python/unittest/test_tir_transform_hoist_expression.py b/tests/python/unittest/test_tir_transform_hoist_expression.py
index ca37915597a5..a0b624a15c31 100644
--- a/tests/python/unittest/test_tir_transform_hoist_expression.py
+++ b/tests/python/unittest/test_tir_transform_hoist_expression.py
@@ -15,11 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
import tvm -from tvm import tir import tvm.testing - +from tvm import tir from tvm.script import tir as T -from tvm.tir.transform import HoistExpression, HoistedConditionals, HoistedLetBindings +from tvm.tir.transform import HoistedConditionals, HoistedLetBindings, HoistExpression class BaseBeforeAfter: @@ -448,7 +447,7 @@ class TestHoistLetExpr(BaseBeforeAfter): def before(A: T.Buffer((4, 4), "float32")): for i, j in T.grid(4, 4): x = T.float32() - A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32")) + A[i, j] = T.Let(5.0 * x + T.cast(j, "float32"), where={x: T.cast(i + 1, "float32")}) @T.prim_func def expected(A: T.Buffer((4, 4), "float32")): @@ -467,7 +466,7 @@ class TestSuppressHoistLetExpr(BaseBeforeAfter): def before(A: T.Buffer((4, 4), "float32")): for i, j in T.grid(4, 4): x = T.float32() - A[i, j] = T.Let(x, T.cast(i + 1, "float32"), 5.0 * x + T.cast(j, "float32")) + A[i, j] = T.Let(5.0 * x + T.cast(j, "float32"), where={x: T.cast(i + 1, "float32")}) expected = before diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index 58f37f04967d..d0403fcae938 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -164,7 +164,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body tensor_2_let = T.Buffer([200704], dtype="uint8") - with T.let(tensor_2_let.data, T.address_of(fast_memory_6_buffer_var[0], dtype="handle")): + with T.LetStmt(T.address_of(fast_memory_6_buffer_var[0], dtype="handle"), var=tensor_2_let.data): for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56): for ax3_init in T.serial(0, 64): tensor_2_let[ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init] = T.uint8(0) @@ -194,12 +194,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholde slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_7_let = T.Buffer([157323], "int16") - with T.let(PaddedInput_7_let.data, T.address_of(slow_memory_5_buffer_var[802816], dtype="handle")): + with T.LetStmt(T.address_of(slow_memory_5_buffer_var[802816], dtype="handle"), var=PaddedInput_7_let.data): for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3): PaddedInput_7_let[i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7] = T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, placeholder_65[i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): Conv2dOutput_7_let = T.Buffer([64], "int32") - with T.let(Conv2dOutput_7_let.data, T.address_of(fast_memory_4_buffer_var[0], dtype="handle")): + with T.LetStmt(T.address_of(fast_memory_4_buffer_var[0], dtype="handle"), var=Conv2dOutput_7_let.data): for ff_3 in T.serial(0, 64): Conv2dOutput_7_let[ff_3] = 0 for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): @@ -399,12 +399,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_3_let = T.Buffer([360000], 'int16') - with 
T.let(PaddedInput_3_let.data, T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_5_buffer_var[6480000], dtype="handle"), var=PaddedInput_3_let.data): for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): PaddedInput_3_let[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] = placeholder_29[i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3] for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): Conv2dOutput_3_let = T.Buffer([64], 'int32') - with T.let(Conv2dOutput_3_let.data, T.address_of(global_workspace_5_buffer_var[7200000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_5_buffer_var[7200000], dtype="handle"), var=Conv2dOutput_3_let.data): for ax3_outer_2 in T.serial(0, 4): for ff_3 in T.serial(0, 64): Conv2dOutput_3_let[ff_3] = 0 @@ -422,12 +422,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_s global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_2_let = T.Buffer([360000], "int16") - with T.let(PaddedInput_2_let.data, T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_4_buffer_var[7200000], dtype="handle"), var=PaddedInput_2_let.data): for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): PaddedInput_2_let[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] = placeholder_19[i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2] for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): Conv2dOutput_2_let = T.Buffer([64], 'int32') - with T.let(Conv2dOutput_2_let.data, T.address_of(global_workspace_4_buffer_var[7920000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_4_buffer_var[7920000], dtype="handle"), var=Conv2dOutput_2_let.data): for ax3_outer_1 in T.serial(0, 4): for ff_2 in T.serial(0, 64): Conv2dOutput_2_let[ff_2] = 0 @@ -445,12 +445,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_let = T.Buffer([360000], "int16") - with T.let(PaddedInput_let.data, T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_2_buffer_var[7200000], dtype="handle"), var=PaddedInput_let.data): for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): PaddedInput_let[i0_i1_fused * 4800 + i2 * 64 + i3] = placeholder_7[i0_i1_fused * 4800 + i2 * 64 + i3] for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): Conv2dOutput_let = T.Buffer([64], "int32") - with T.let(Conv2dOutput_let.data, T.address_of(global_workspace_2_buffer_var[7920000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_2_buffer_var[7920000], dtype="handle"), var=Conv2dOutput_let.data): for ff in T.serial(0, 64): Conv2dOutput_let[ff] = 0 for rc in T.serial(0, 64): @@ -467,12 +467,12 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(pla global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=0, align=16) # body PaddedInput_1_let = T.Buffer([379456], "int16") - with T.let(PaddedInput_1_let.data, T.address_of(global_workspace_3_buffer_var[0], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_3_buffer_var[0], dtype="handle"), var=PaddedInput_1_let.data): for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): PaddedInput_1_let[i0_i1_fused_1 * 
4928 + i2_1 * 64 + i3_1] = T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, placeholder_13[i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864], T.int16(0), dtype="int16") for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): Conv2dOutput_1_let = T.Buffer([64], "int32") - with T.let(Conv2dOutput_1_let.data, T.address_of(global_workspace_3_buffer_var[7200000], dtype="handle")): + with T.LetStmt(T.address_of(global_workspace_3_buffer_var[7200000], dtype="handle"), var=Conv2dOutput_1_let.data): for ff_1 in T.serial(0, 64): Conv2dOutput_1_let[ff_1] = 0 for ry, rx, rc_1 in T.grid(3, 3, 64): @@ -562,7 +562,9 @@ def tensor_intrin_primfunc(global_workspace_1_var: T.handle("uint8")) -> None: global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 ) dense_let = T.Buffer([10], "int32") - with T.let(dense_let.data, T.address_of(global_workspace_1_buffer_var[0], dtype="handle")): + with T.LetStmt( + T.address_of(global_workspace_1_buffer_var[0], dtype="handle"), var=dense_let.data + ): T.evaluate( T.call_extern( "intrin_function", diff --git a/tests/python/unittest/test_tvmscript_ir_builder_tir.py b/tests/python/unittest/test_tvmscript_ir_builder_tir.py index 889f0c9eda33..5599d2f7c69a 100644 --- a/tests/python/unittest/test_tvmscript_ir_builder_tir.py +++ b/tests/python/unittest/test_tvmscript_ir_builder_tir.py @@ -283,7 +283,7 @@ def test_ir_builder_tir_assert(): def test_ir_builder_tir_let(): with IRBuilder() as ib: - with T.let(T.int32(), tir.IntImm("int32", 2)): + with T.LetStmt(tir.IntImm("int32", 2)) as v: T.evaluate(0) # the let binding generated by IRBuilder let_actual = ib.get() diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 13aaacb3b758..e74f69dcae8b 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -248,14 +248,14 @@ def test_for(): def test_let_stmt(): with IRBuilder() as ib: - with T.let(T.float32(), T.float32(10)): + with T.LetStmt(T.float32(10)) as v: + ib.name("v", v) T.evaluate(0) obj = ib.get() _assert_print( obj, """ -v = T.float32() -with T.let(v, T.float32(10)): +with T.LetStmt(T.float32(10)) as v: T.evaluate(0) """, ) @@ -602,7 +602,7 @@ def test_let_expr(): obj, """ x = T.int32() -T.let(x, 1, x + 1) +T.Let(x + 1, where={x: 1}) """, ) diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 48a59994690b..c91b733751d1 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -335,9 +335,9 @@ def mmult( T.attr(0, "compute_scope", "mmult_compute_") T.attr(packedB.data, "storage_scope", "global") T.attr(packedB.data, "storage_alignment", 128) - with T.let( - packedB.data, + with T.LetStmt( T.TVMBackendAllocWorkspace(1, dev_id, T.uint64(4194304), 2, 32, dtype="handle"), + var=packedB.data, ): if T.isnullptr(packedB.data, dtype="bool"): T.evaluate(T.tvm_throw_last_error(dtype="int32")) @@ -349,11 +349,11 @@ def mmult( for x_outer in T.parallel(0, 32): T.attr(C_global.data, "storage_scope", "global") T.attr(C_global.data, "storage_alignment", 128) - with T.let( - C_global.data, + with T.LetStmt( T.TVMBackendAllocWorkspace( 1, dev_id, T.uint64(4096), 2, 32, dtype="handle" ), + var=C_global.data, ): if T.isnullptr(C_global.data, dtype="bool"): T.evaluate(T.tvm_throw_last_error(dtype="int32")) @@ -3317,7 +3317,7 @@ def let_expression(): @T.prim_func def 
func(): x = T.int32() - T.evaluate(T.let(x, 1, x + 1)) + T.evaluate(T.Let(x + 1, where={x: 1})) return func @@ -3542,10 +3542,8 @@ def func(): def let_stmt_var(): @T.prim_func def func(): - x = T.int32() - y = T.int32() - with T.let(x, 0): - with T.let(y, 0): + with T.LetStmt(0) as x: + with T.LetStmt(0) as y: T.evaluate(0) T.evaluate(0) @@ -3555,10 +3553,9 @@ def func(): def let_stmt_value(): @T.prim_func def func(): - x = T.int32() y = T.int32() - with T.let(x, y): - with T.let(y, 0): + with T.LetStmt(y) as x: + with T.LetStmt(0, var=y): T.evaluate(0) T.evaluate(0)
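
Note (not part of the diff): a minimal usage sketch of the API surface changed above, pieced together from the updated tests; the prim_func and buffer names (`_example`, `A`) are illustrative only, not taken from the patch. `T.LetStmt` is the with-statement form whose frame `__enter__` now returns the bound variable, and an existing variable can be supplied through the keyword-only `var=` argument; `T.Let` is the pure-expression form that takes the body first and a single-entry `where` dict; the old `T.let` remains available but is deprecated in favor of these two.

from tvm.script import tir as T


@T.prim_func
def _example(A: T.Buffer((4,), "float32")):
    # Statement form: a fresh var is created and returned by the frame's __enter__.
    with T.LetStmt(A[0] * 2.0) as v:
        A[1] = v
    # Statement form bound to a pre-declared var via the keyword-only `var=` argument.
    x = T.float32()
    with T.LetStmt(A[1] + 1.0, var=x):
        A[2] = x
    # Expression form: body first, binding supplied through a single-entry `where` dict.
    y = T.float32()
    A[3] = T.Let(y * y, where={y: A[2]})

The `type_annotation` argument of `T.LetStmt` (typically a `PointerType`) covers cases like the pool-offset tests above, where the bound variable is a buffer's `.data` handle rather than a plain scalar.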