Skip to content

Commit

Permalink
[Relax][Transform] Handle tuple return in RemoveUnusedOutputs (#17253)
Browse files Browse the repository at this point in the history
* [Relax][Transform] Handle tuple return in RemoveUnusedOutputs

Prior to this commit, the `relax.transform.RemoveUnusedOutputs` pass
only marked a tuple element as used if it occurred in a `TupleGetItem`
node.  This ignored use cases where a tuple is used as an aggregate
object, such as returning a tuple from a function.  This would collect
incorrect results for a Relax function that calls a subroutine,
receives a tuple as the return value of the subroutine, then returns
that tuple.

This commit updates `RemoveUnusedOutputs` to look for usage of a tuple
object, not just for usage in `TupleGetItem`.

Closes #17247
  • Loading branch information
Lunderberg authored Sep 6, 2024
1 parent ec28b67 commit ff884b6
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 20 deletions.
59 changes: 39 additions & 20 deletions src/relax/transform/remove_unused_outputs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,29 +92,48 @@ class PartialTupleUsageCollector : ExprVisitor {
}

void VisitExpr_(const TupleGetItemNode* op) override {
Expr tuple = UnwrapBindings(op->tuple);

if (auto call = tuple.as<CallNode>()) {
if (auto opt_callee = call->op.as<GlobalVar>()) {
auto callee = opt_callee.value();
if (auto it = output_usage_mask_.find(callee); it != output_usage_mask_.end()) {
auto& used_indices = it->second;

CHECK_GE(op->index, 0) << "IndexError: "
<< "Indices for TupleGetItem must be non-negative, "
<< "but expression " << GetRef<Expr>(op)
<< " uses a tuple index of " << op->index;
size_t index = op->index;

CHECK_LT(index, used_indices.size())
<< "IndexError: "
<< "Indices for TupleGetItem must be less than the size of the tuple, "
<< "but expression " << GetRef<Expr>(op) << " uses a tuple index of " << op->index
<< " for a tuple of size " << used_indices.size();
used_indices[index] = true;
if (auto* usage_mask_ptr = GetCalleeUsageMask(op->tuple)) {
auto& used_indices = *usage_mask_ptr;

CHECK_GE(op->index, 0) << "IndexError: "
<< "Indices for TupleGetItem must be non-negative, "
<< "but expression " << GetRef<Expr>(op) << " uses a tuple index of "
<< op->index;
size_t index = op->index;

CHECK_LT(index, used_indices.size())
<< "IndexError: "
<< "Indices for TupleGetItem must be less than the size of the tuple, "
<< "but expression " << GetRef<Expr>(op) << " uses a tuple index of " << op->index
<< " for a tuple of size " << used_indices.size();
used_indices[index] = true;
}
}

void VisitExpr_(const VarNode* op) override {
if (auto* usage_mask_ptr = GetCalleeUsageMask(GetRef<Var>(op))) {
auto& usage_mask = *usage_mask_ptr;
for (size_t i = 0; i < usage_mask.size(); i++) {
usage_mask[i] = true;
}
}
}

std::vector<bool>* GetCalleeUsageMask(Expr expr) {
if (!expr->struct_info_.as<TupleStructInfoNode>()) {
return nullptr;
}

expr = UnwrapBindings(expr);
if (auto call = expr.as<CallNode>()) {
if (auto callee = call->op.as<GlobalVar>()) {
if (auto it = output_usage_mask_.find(callee.value()); it != output_usage_mask_.end()) {
return &it->second;
}
}
}

return nullptr;
}

Expr UnwrapBindings(Expr expr) const {
Expand Down
20 changes: 20 additions & 0 deletions tests/python/relax/test_transform_remove_unused_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,25 @@ def func() -> R.Tuple([R.Tensor([16, 16], "int32"), R.Tensor([32, 32], "int32")]
return (A, C)


class TestReturnTuple(BaseCompare):
@I.ir_module
class Before:
@R.function
def main(A: R.Tensor([16, 16], "int32")):
B = R.add(A, A)
out_tuple = Before.func(B)
return out_tuple

@R.function(private=True)
def func(
B: R.Tensor([16, 16], "int32")
) -> R.Tuple(R.Tensor([16, 16], "int32"), R.Tensor([16, 16], "int32")):
C = R.multiply(B, B)
D = R.add(B, B)
return (C, D)

Expected = Before


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit ff884b6

Please sign in to comment.