Skip to content

Commit

Permalink
[Relax] Expose name_hint field for BlockBuilder.match_cast (#16600)
Browse files Browse the repository at this point in the history
* [Relax] Expose name_hint field for BlockBuilder.match_cast

Prior to this commit, while a `relax.VarBinding` created using
`BlockBuilder.emit` could have its name explicitly specified by the
user, a `relax.MatchCast` created using `BlockBuilder.match_cast`
could not.  This commit updates `BlockBuilder.match_cast` to accept an
optional `name_hint` parameter, which is then provided to the C++
`BlockBuilder::EmitMatchCast` method.

* Fix lint error
  • Loading branch information
Lunderberg authored Feb 19, 2024
1 parent 36ebcd0 commit 1e6482c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
12 changes: 10 additions & 2 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def rx_func(x: Tensor((n,), "float32"), y: Tensor(((n + 1),), "float32"))
name_hint = kwargs.pop("name_hint", "")
return self.emit(self.call_te(func, *args, **kwargs), name_hint=name_hint)

def match_cast(self, value: Expr, struct_info: StructInfo) -> Var:
def match_cast(self, value: Expr, struct_info: StructInfo, name_hint: str = "") -> Var:
"""Emit a MatchCast.
Parameters
Expand All @@ -545,12 +545,20 @@ def match_cast(self, value: Expr, struct_info: StructInfo) -> Var:
struct_info : StructInfo
The struct info to be matched.
name_hint : str
The name of the match cast
Returns
-------
ret : tvm.relax.Var
A newly created variable that get bounds to be the casted result.
"""
return _ffi_api.BlockBuilderEmitMatchCast(self, value, struct_info) # type: ignore
return _ffi_api.BlockBuilderEmitMatchCast(
self,
value,
struct_info,
name_hint,
) # type: ignore

def emit_output(self, output: Union[Expr, Tuple, List[Expr]], name_hint: str = "") -> Var:
"""Emit output for the current dataflow block or function.
Expand Down
4 changes: 2 additions & 2 deletions src/relax/ir/block_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1015,8 +1015,8 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderEmit")
});

TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitMatchCast")
.set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info) {
return builder->EmitMatchCast(value, struct_info);
.set_body_typed([](BlockBuilder builder, Expr value, StructInfo struct_info, String name_hint) {
return builder->EmitMatchCast(value, struct_info, name_hint);
});

TVM_REGISTER_GLOBAL("relax.BlockBuilderEmitOutput")
Expand Down
3 changes: 2 additions & 1 deletion tests/python/relax/test_blockbuilder_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def test_emit_match_cast():
assert_structural_equal(lv0.struct_info, rx.TensorStructInfo([m, n], "float32"))

# lv1: Shape = match_cast(shape, rx.ShapeStructInfo([m, n]))
lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n]))
lv1 = bb.match_cast(y, rx.ShapeStructInfo([m, n]), "var_name")
assert lv1.struct_info == rx.ShapeStructInfo([m, n])
gv0 = bb.emit_output(lv1)

Expand All @@ -244,6 +244,7 @@ def test_emit_match_cast():
assert b1.value == y
assert b1.struct_info == rx.ShapeStructInfo([m, n])
assert b1.var == lv1
assert b1.var.name_hint == "var_name"


def test_emit_match_cast_binding_in_dataflow_block():
Expand Down

0 comments on commit 1e6482c

Please sign in to comment.