Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix][TVMScript] Handle R.match_cast as last binding in if/else #16562

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[Bugfix][TVMScript] Handle R.match_cast as last binding in if/else
Prior to this commit, using `R.match_cast` as the last binding of an
if/else branch would produce a segfault, as `var_binding->value` was
used instead of `match_cast->value`.  In addition, because the last
binding of each branch was removed, any changes to the struct info
resulting from the match cast were silently discarded.

This commit updates the TVMScript parsing of if/else statements to
remove the segfault and maintain the struct info changes produced by
the `R.match_cast`.
Lunderberg committed Feb 14, 2024
commit af56392d7852ab1eadc4d80cd60866060eca1a25
4 changes: 3 additions & 1 deletion src/script/ir_builder/relax/frame.cc
Original file line number Diff line number Diff line change
@@ -263,7 +263,9 @@ void ElseFrameNode::ExitWithScope() {
IfFrame frame = FindIfFrame("R.Else");
frame->else_expr = output;
CHECK(frame->var_name == var_name)
<< "This last binding of both branches must have the same variable.";
<< "This last binding of both branches must provide the same variable. "
<< "However, the R.Then branch provides variable " << frame->var_name
<< ", while the R.Else branch provides variable " << var_name;
}

TVM_REGISTER_NODE_TYPE(FunctionFrameNode);
52 changes: 36 additions & 16 deletions src/script/ir_builder/relax/utils.h
Original file line number Diff line number Diff line change
@@ -70,10 +70,13 @@ inline BlockFrame CheckBlockFrameExistAndUnended() {
inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String* var_name) {
// Step 0. Check frame type
std::string method;
std::string output_var_suffix;
if (frame->IsInstance<ThenFrameNode>()) {
method = "R.Then";
output_var_suffix = "_then";
} else if (frame->IsInstance<ElseFrameNode>()) {
method = "R.Else";
output_var_suffix = "_else";
} else {
ICHECK(false) << "TypeError: Unsupported frame type: " << frame->GetTypeKey();
}
@@ -84,29 +87,46 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String
const tvm::relax::BindingBlock& last_block = frame->binding_blocks.back();
CHECK(!last_block->bindings.empty()) << "Blocks are expected to be non-empty.";

// Step 2. Collect body from the last binding.
// Step 2. Update the last binding of each branch. While we could
// use the last bound value of each branch as a SeqExpr body, the
// Normalizer would pull it back out into a `gv#` binding anyways.
// Generating a new variable in each branch provides a more readable
// variable name.

tvm::relax::Binding last_binding = last_block->bindings.back();
CHECK(!last_binding->var->IsInstance<tvm::relax::DataflowVarNode>())
<< "A non-dataflow var is expected in the last binding of '" << method << "'.";

*var_name = last_binding->var->name_hint();

// Step 3. Re-collect binding blocks to replace the last binding.
Array<tvm::relax::BindingBlock> new_blocks(frame->binding_blocks.begin(),
frame->binding_blocks.end() - 1);
Array<tvm::relax::Binding> last_block_bindings(last_block->bindings.begin(),
last_block->bindings.end() - 1);

tvm::relax::Var new_var = tvm::relax::Var(last_binding->var->name_hint() + output_var_suffix,
GetStructInfo(last_binding->var));
tvm::relax::Expr body;
const tvm::relax::Binding& last_binding = last_block->bindings.back();
if (const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>()) {
CHECK(!var_binding->var->IsInstance<tvm::relax::DataflowVarNode>())
<< "A non-dataflow var is expected in the last binding of '" << method << "'.";

if (const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>();
var_binding && var_binding->value->IsInstance<tvm::relax::VarNode>()) {
body = var_binding->value;
*var_name = var_binding->var->name_hint();
} else if (const auto* var_binding = last_binding.as<tvm::relax::VarBindingNode>()) {
last_block_bindings.push_back(last_binding =
tvm::relax::VarBinding(new_var, var_binding->value));
body = new_var;
} else if (const auto* match_cast = last_binding.as<tvm::relax::MatchCastNode>()) {
CHECK(!match_cast->var->IsInstance<tvm::relax::DataflowVarNode>())
<< "A non-dataflow var is expected in the last binding of '" << method << "'.";
body = var_binding->value;
*var_name = match_cast->var->name_hint();
last_block_bindings.push_back(
tvm::relax::MatchCast(new_var, match_cast->value, match_cast->struct_info));
body = new_var;
} else {
ICHECK(false) << "TypeError: Unsupported binding type: " << last_binding->GetTypeKey();
}

// Step 3. Re-collect binding blocks to remove the last binding.
Array<tvm::relax::BindingBlock> new_blocks(frame->binding_blocks.begin(),
frame->binding_blocks.end() - 1);
Array<tvm::relax::Binding> last_block_bindings(last_block->bindings.begin(),
last_block->bindings.end() - 1);
new_blocks.push_back(tvm::relax::BindingBlock(last_block_bindings));
new_blocks.push_back(last_block->IsInstance<tvm::relax::DataflowBlockNode>()
? tvm::relax::DataflowBlock(last_block_bindings)
: tvm::relax::BindingBlock(last_block_bindings));

return tvm::relax::SeqExpr(new_blocks, body);
}
41 changes: 41 additions & 0 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
@@ -1176,6 +1176,47 @@ def check_call(call, op, args):
check_call(y_bind.value, "relax.add", [w_bind.var, w_bind.var])


def test_if_branch_with_match_cast():
    """The last binding of a relax::If branch may be a MatchCast

    This is a regression test.  In previous implementations, using
    R.match_cast as the last binding would cause a segfault while
    parsing.
    """

    @R.function
    def func(A: R.Tensor([16, 16]), is_bfloat16: R.Prim("bool")):
        if is_bfloat16:
            A = R.match_cast(A, R.Tensor([16, 16], "bfloat16"))
            B = A.astype("float16")
        else:
            B = R.match_cast(A, R.Tensor([16, 16], "float16"))
        return B

    # The tuple unpackings below double as structural checks: each one
    # raises unless the sequence has exactly the expected number of
    # elements (two params, one block, one binding in that block).
    A, is_bfloat16 = func.params
    (block,) = func.body.blocks
    (B_binding,) = block.bindings

    # The single outer binding must define a variable named `B`.
    B_var = B_binding.var
    assert isinstance(B_var, relax.Var)
    assert B_var.name_hint == "B"

    # `B` is bound to the if/else expression, and both of its branches
    # are SeqExpr nodes.
    if_then_else = B_binding.value
    assert isinstance(if_then_else, relax.If)
    assert isinstance(if_then_else.true_branch, relax.SeqExpr)
    assert isinstance(if_then_else.false_branch, relax.SeqExpr)

    # The else-branch must consist of exactly one binding block.
    else_branch = if_then_else.false_branch
    (else_block,) = else_branch.blocks

    # The final binding of the else-branch must still be a MatchCast.
    assert isinstance(else_block.bindings[-1], relax.MatchCast)

    # If the `R.match_cast` were removed, the function would infer the
    # return value as `R.Tensor([16,16])`, with an unknown dtype.
    # With the `R.match_cast` retained, the output dtype is known.
    tvm.ir.assert_structural_equal(func.ret_struct_info, R.Tensor([16, 16], "float16"))


def test_if_inside_dataflow():
with pytest.raises(tvm.error.DiagnosticError):