From af56392d7852ab1eadc4d80cd60866060eca1a25 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 12 Feb 2024 17:03:47 +0000 Subject: [PATCH] [Bugfix][TVMScript] Handle R.match_cast as last binding in if/else Prior to this commit, using `R.match_cast` as the last binding would produce a segfault, as `var_binding->value` was used instead of `match_cast->value`. In addition, because the last binding of each branch was removed, any changes to the struct info resulting from the match cast were silently discarded. This commit updates the TVMScript parsing of if/else statements to remove the segfault and maintain the struct info changes produced by the `R.match_cast`. --- src/script/ir_builder/relax/frame.cc | 4 +- src/script/ir_builder/relax/utils.h | 52 ++++++++++++++------- tests/python/relax/test_tvmscript_parser.py | 41 ++++++++++++++++ 3 files changed, 80 insertions(+), 17 deletions(-) diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 966af809c9b4..b95db57a881b 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -263,7 +263,9 @@ void ElseFrameNode::ExitWithScope() { IfFrame frame = FindIfFrame("R.Else"); frame->else_expr = output; CHECK(frame->var_name == var_name) - << "This last binding of both branches must have the same variable."; + << "This last binding of both branches must provide the same variable. " + << "However, the R.Then branch provides variable " << frame->var_name + << ", while the R.Else branch provides variable " << var_name; } TVM_REGISTER_NODE_TYPE(FunctionFrameNode); diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h index ae91d05769bd..395e027bce57 100644 --- a/src/script/ir_builder/relax/utils.h +++ b/src/script/ir_builder/relax/utils.h @@ -70,10 +70,13 @@ inline BlockFrame CheckBlockFrameExistAndUnended() { inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String* var_name) { // Step 0. 
Check frame type std::string method; + std::string output_var_suffix; if (frame->IsInstance()) { method = "R.Then"; + output_var_suffix = "_then"; } else if (frame->IsInstance()) { method = "R.Else"; + output_var_suffix = "_else"; } else { ICHECK(false) << "TypeError: Unsupported frame type: " << frame->GetTypeKey(); } @@ -84,29 +87,46 @@ inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String const tvm::relax::BindingBlock& last_block = frame->binding_blocks.back(); CHECK(!last_block->bindings.empty()) << "Blocks are expected to be non-empty."; - // Step 2. Collect body from the last binding. + // Step 2. Update the last binding of each branch. While we could + // use the last bound value of each branch as a SeqExpr body, the + // Normalizer would pull it back out into a `gv#` binding anyways. + // Generating a new variable in each branch provides a more readable + // variable name. + + tvm::relax::Binding last_binding = last_block->bindings.back(); + CHECK(!last_binding->var->IsInstance()) + << "A non-dataflow var is expected in the last binding of '" << method << "'."; + + *var_name = last_binding->var->name_hint(); + + // Step 3. Re-collect binding blocks to replace the last binding. 
+ Array new_blocks(frame->binding_blocks.begin(), + frame->binding_blocks.end() - 1); + Array last_block_bindings(last_block->bindings.begin(), + last_block->bindings.end() - 1); + + tvm::relax::Var new_var = tvm::relax::Var(last_binding->var->name_hint() + output_var_suffix, + GetStructInfo(last_binding->var)); tvm::relax::Expr body; - const tvm::relax::Binding& last_binding = last_block->bindings.back(); - if (const auto* var_binding = last_binding.as()) { - CHECK(!var_binding->var->IsInstance()) - << "A non-dataflow var is expected in the last binding of '" << method << "'."; + + if (const auto* var_binding = last_binding.as(); + var_binding && var_binding->value->IsInstance()) { body = var_binding->value; - *var_name = var_binding->var->name_hint(); + } else if (const auto* var_binding = last_binding.as()) { + last_block_bindings.push_back(last_binding = + tvm::relax::VarBinding(new_var, var_binding->value)); + body = new_var; } else if (const auto* match_cast = last_binding.as()) { - CHECK(!match_cast->var->IsInstance()) - << "A non-dataflow var is expected in the last binding of '" << method << "'."; - body = var_binding->value; - *var_name = match_cast->var->name_hint(); + last_block_bindings.push_back( + tvm::relax::MatchCast(new_var, match_cast->value, match_cast->struct_info)); + body = new_var; } else { ICHECK(false) << "TypeError: Unsupported binding type: " << last_binding->GetTypeKey(); } - // Step 3. Re-collect binding blocks to remove the last binding. - Array new_blocks(frame->binding_blocks.begin(), - frame->binding_blocks.end() - 1); - Array last_block_bindings(last_block->bindings.begin(), - last_block->bindings.end() - 1); - new_blocks.push_back(tvm::relax::BindingBlock(last_block_bindings)); + new_blocks.push_back(last_block->IsInstance() + ? 
tvm::relax::DataflowBlock(last_block_bindings) + : tvm::relax::BindingBlock(last_block_bindings)); return tvm::relax::SeqExpr(new_blocks, body); } diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 71970ad965b6..fa047dbc3a3d 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1176,6 +1176,47 @@ def check_call(call, op, args): check_call(y_bind.value, "relax.add", [w_bind.var, w_bind.var]) +def test_if_branch_with_match_cast(): + """The last binding of a relax::If branch may be a MatchCast + + This is a regression test. In previous implementations, using + R.match_cast as the last binding would cause a segfault while + parsing. + """ + + @R.function + def func(A: R.Tensor([16, 16]), is_bfloat16: R.Prim("bool")): + if is_bfloat16: + A = R.match_cast(A, R.Tensor([16, 16], "bfloat16")) + B = A.astype("float16") + else: + B = R.match_cast(A, R.Tensor([16, 16], "float16")) + return B + + A, is_bfloat16 = func.params + (block,) = func.body.blocks + (B_binding,) = block.bindings + + B_var = B_binding.var + assert isinstance(B_var, relax.Var) + assert B_var.name_hint == "B" + + if_then_else = B_binding.value + assert isinstance(if_then_else, relax.If) + assert isinstance(if_then_else.true_branch, relax.SeqExpr) + assert isinstance(if_then_else.false_branch, relax.SeqExpr) + + else_branch = if_then_else.false_branch + (else_block,) = else_branch.blocks + + assert isinstance(else_block.bindings[-1], relax.MatchCast) + + # If the `R.match_cast` were removed, the function would infer the + # return value as `R.Tensor([16,16])`, with an unknown dtype. + # With the `R.match_cast` retained, the output dtype is known. + tvm.ir.assert_structural_equal(func.ret_struct_info, R.Tensor([16, 16], "float16")) + + def test_if_inside_dataflow(): with pytest.raises(tvm.error.DiagnosticError):