[Relax] Handle presence of R.call_tir in MergeCompositeFunctions #17220

Merged
22 changes: 19 additions & 3 deletions src/relax/transform/merge_composite_functions.cc
@@ -234,19 +234,35 @@ class CompositeGroupsBuilder : public MemoizedExprTranslator<Group*> {
   void UpdateGroupDependencies(Group* group, const Array<Expr>& args) {
     Group* group_root = group->FindRoot();
 
-    for (const auto& arg : args) {
-      auto arg_group_root = memo_[arg]->FindRoot();
+    std::function<void(Expr)> visit_expr = [&](Expr expr) {
+      if (expr.as<GlobalVarNode>()) return;
+      if (auto tuple = expr.as<TupleNode>()) {
+        for (const auto& field : tuple->fields) {
+          visit_expr(field);
+        }
+        return;
+      }
+
+      ICHECK(memo_.count(expr)) << "Could not find memo-ized group for expression of type "
+                                << expr->GetTypeKey();
+      auto arg_group_root = memo_[expr]->FindRoot();
+
       if (arg_group_root == group_root) {
         // If arg and the current node are in the same group,
         // there is nothing to update.
-        continue;
+        return;
       }
 
       // Add the group of arg as dependency
       group_deps_[group_root].insert(arg_group_root);
       // Propagate dependencies of arg
       for (auto dep : group_deps_[arg_group_root]) {
         group_deps_[group_root].insert(dep);
       }
+    };
+
+    for (const auto& arg : args) {
+      visit_expr(arg);
+    }
   }
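
For context on the change: the arguments of a `R.call_tir` call are not plain relax variables. They include the callee's `GlobalVar` and a relax `Tuple` wrapping the tensor arguments, neither of which has a memo-ized group, which is why `UpdateGroupDependencies` now skips `GlobalVar`s and recurses into `Tuple` fields before looking up `memo_`. The following is a minimal sketch for illustration only (it is not part of this PR; the module and function names `Example`, `main`, and `relu` are hypothetical) showing how those argument nodes appear on the call node, assuming a TVM build with Relax available:

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T


@I.ir_module
class Example:
    @T.prim_func(private=True)
    def relu(
        A: T.Buffer(T.int64(10), "float32"),
        B: T.Buffer(T.int64(10), "float32"),
    ):
        for i in range(T.int64(10)):
            with T.block("compute"):
                vi = T.axis.remap("S", [i])
                B[vi] = T.max(A[vi], T.float32(0))

    @R.function
    def main(x: R.Tensor([10], dtype="float32")) -> R.Tensor([10], dtype="float32"):
        cls = Example
        with R.dataflow():
            y = R.call_tir(cls.relu, (x,), out_sinfo=R.Tensor([10], dtype="float32"))
            R.output(y)
        return y


# The call node's args hold a GlobalVar and a Tuple, not plain Vars.
call = Example["main"].body.blocks[0].bindings[0].value
print(type(call.args[0]))  # GlobalVar, referring to `relu`
print(type(call.args[1]))  # relax Tuple, wrapping (x,)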

119 changes: 119 additions & 0 deletions tests/python/relax/test_transform_merge_composite_functions.py
@@ -20,6 +20,7 @@
from tvm import relax
from tvm.script import relax as R
from tvm.script import ir as I
from tvm.script import tir as T


@tvm.script.ir_module
@@ -1106,5 +1107,123 @@ def main(
    check(Module, Expected)


def test_handle_existence_of_call_tir():
    """MergeCompositeFunctions should accept R.call_tir as input

    No merging is required in this case, since the two composite
    functions have `R.call_tir` between them.  This is a regression
    test, as previously the `Tuple` used to express the arguments of
    `R.call_tir` caused a segfault.

    """

    @I.ir_module
    class Before:
        @R.function
        def main(A: R.Tensor([10], dtype="float32")) -> R.Tensor([10], dtype="float32"):
            cls = Before
            with R.dataflow():
                B = cls.fused_relax_nn_relu(A)
                C = R.call_tir(cls.relu, (B,), out_sinfo=R.Tensor([10], dtype="float32"))
                D = cls.fused_relax_nn_gelu(C)
                R.output(D)
            return D

        @R.function(private=True)
        def fused_relax_nn_relu(
            Input: R.Tensor([10], dtype="float32")
        ) -> R.Tensor([10], dtype="float32"):
            R.func_attr({"Composite": "compiler_A.relu", "Primitive": 1})
            with R.dataflow():
                Output = R.nn.relu(Input)
                R.output(Output)
            return Output

        @T.prim_func(private=True)
        def relu(
            Input: T.Buffer(T.int64(10), "float32"),
            Output: T.Buffer(T.int64(10), "float32"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            for i in range(T.int64(10)):
                with T.block("compute"):
                    vi = T.axis.remap("S", [i])
                    Output[vi] = T.max(Input[vi], T.float32(0))

        @R.function(private=True)
        def fused_relax_nn_gelu(
            Input: R.Tensor([10], dtype="float32")
        ) -> R.Tensor([10], dtype="float32"):
            R.func_attr({"Composite": "compiler_A.gelu", "Primitive": 1})
            with R.dataflow():
                Output = R.nn.gelu(Input)
                R.output(Output)
            return Output

    @I.ir_module
    class Expected:
        @R.function
        def main(A: R.Tensor([10], dtype="float32")) -> R.Tensor([10], dtype="float32"):
            cls = Expected
            with R.dataflow():
                B = cls.fused_relax_nn_relu1_compiler_A(A)
                C = R.call_tir(cls.relu, (B,), out_sinfo=R.Tensor([10], dtype="float32"))
                D = cls.fused_relax_nn_gelu1_compiler_A(C)
                R.output(D)
            return D

        @R.function
        def fused_relax_nn_relu1_compiler_A(
            Input: R.Tensor([10], dtype="float32")
        ) -> R.Tensor([10], dtype="float32"):
            R.func_attr({"Codegen": "compiler_A"})

            @R.function
            def composite_lambda(
                Input: R.Tensor([10], dtype="float32")
            ) -> R.Tensor([10], dtype="float32"):
                R.func_attr({"Composite": "compiler_A.relu"})
                with R.dataflow():
                    Output = R.nn.relu(Input)
                    R.output(Output)
                return Output

            Output = composite_lambda(Input)
            return Output

        @T.prim_func(private=True)
        def relu(
            Input: T.Buffer(T.int64(10), "float32"),
            Output: T.Buffer(T.int64(10), "float32"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            for i in range(T.int64(10)):
                with T.block("compute"):
                    vi = T.axis.remap("S", [i])
                    Output[vi] = T.max(Input[vi], T.float32(0))

        @R.function
        def fused_relax_nn_gelu1_compiler_A(
            Input: R.Tensor([10], dtype="float32")
        ) -> R.Tensor([10], dtype="float32"):
            R.func_attr({"Codegen": "compiler_A"})

            @R.function
            def composite_lambda(
                Input: R.Tensor([10], dtype="float32")
            ) -> R.Tensor([10], dtype="float32"):
                R.func_attr({"Composite": "compiler_A.gelu"})
                with R.dataflow():
                    Output = R.nn.gelu(Input)
                    R.output(Output)
                return Output

            Output = composite_lambda(Input)
            return Output

    After = relax.transform.MergeCompositeFunctions()(Before)
    tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
    pytest.main([__file__])