[Relax][Bugfix] Apply FuseOps to nested DataflowBlock #17022

Closed
39 changes: 14 additions & 25 deletions src/relax/transform/fuse_ops.cc
@@ -108,9 +108,16 @@ class GraphCreator : public ExprVisitor {
static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) {
GraphCreator creator(mod, arena);
for (const auto& it : mod->functions) {
// Only visit Relax function without attr kPrimitive.
// Only visit Relax functions with neither attr::kPrimitive nor
// attr::kCodegen. Relax functions with `attr::kPrimitive` are
// previously fused functions, potentially from a previous use
// of `FuseOps` or `FuseOpsByPattern`. Relax functions with
// `attr::kCodegen` are previously fused functions from
// `FuseOpsByPattern`, when the `annotate_codegen` option is
// true.
const auto* func = it.second.as<FunctionNode>();
if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive)) {
if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive) ||
func->GetAttr<String>(attr::kCodegen).defined()) {
continue;
}
creator(GetRef<Function>(func));
@@ -142,13 +149,6 @@ class GraphCreator : public ExprVisitor {
ExprVisitor::VisitExpr_(func);
}

void VisitBindingBlock(const BindingBlock& block) final {
if (const auto* df_block = block.as<DataflowBlockNode>()) {
VisitBindingBlock_(df_block);
}
// We skip ordinary binding blocks since they might be impure (with side effect or control flow)
}

void VisitBinding_(const MatchCastNode* binding) final {
IndexedForwardGraph::Node* node = CreateNode(binding->var.get());
SetNodePattern(node, OpPatternKind::kOpaque);
@@ -262,16 +262,11 @@ class GraphCreator : public ExprVisitor {
IndexedForwardGraph::Node* leaf_node = nullptr;
if (it != graph_.node_map.end()) {
leaf_node = it->second;
} else if (leaf_expr->IsInstance<ConstantNode>() || leaf_expr->IsInstance<ShapeExprNode>() ||
leaf_expr->IsInstance<PrimValueNode>() || leaf_expr->IsInstance<StringImmNode>() ||

Review comment (Member): Is this an intended change? It seems we should consider the case where the leaf is a constant value (and there may not be a binding point for it).

leaf_expr->IsInstance<DataTypeImmNode>()) {
} else {
leaf_node = CreateNode(leaf_expr.get());
// Since we never fuse constants, the pattern of the constant is set to `kOpaque`.
SetNodePattern(leaf_node, OpPatternKind::kOpaque);
AddToPostDFSOrder(leaf_node, leaf_expr.get());
} else {
LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: " << leaf_expr
<< " used before definition.";
}
AddEdge(leaf_node, binding_var_node, pattern);
}
@@ -701,8 +696,10 @@ class OperatorFusor : public ExprMutator {
}
for (const auto& gv : entry_functions) {
const auto& func = mod_->Lookup(gv);
// Only visit Relax function without attr kPrimitive.
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
// Only visit Relax functions with neither attr::kPrimitive nor
// attr::kCodegen.
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive) &&
!func->GetAttr<String>(attr::kCodegen).defined()) {
auto updated_func = Downcast<Function>(VisitExpr(func));
builder_->UpdateFunction(gv, updated_func);
}
@@ -739,14 +736,6 @@ class OperatorFusor : public ExprMutator {
return false;
}

BindingBlock VisitBindingBlock(const BindingBlock& block) final {
if (const auto* df_block = block.as<DataflowBlockNode>()) {
return VisitBindingBlock_(df_block);
}
// We skip ordinary binding blocks since they might be impure (with side effect or control flow)
return block;
}

BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
group2func_.clear();

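Note: the two removed VisitBindingBlock overrides above were what limited GraphCreator and OperatorFusor to top-level DataflowBlocks; with them gone, the default visitor recursion also reaches DataflowBlocks nested inside control flow, while the new attr::kCodegen checks keep already-offloaded functions from being fused again. A minimal sketch of the kind of module this affects, using the standard Relax Python API (the module below is illustrative, not taken from this PR):

from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor([1024, 1024], "float16"),
        w: R.Tensor([1024, 1024], "float16"),
        transpose_weights: R.Prim("bool"),
    ):
        # Each branch holds its own DataflowBlock: the blocks themselves are
        # pure, but they sit inside an `if`, so the previous implementation
        # never visited their bindings.
        if transpose_weights:
            with R.dataflow():
                w_t = R.permute_dims(w)
                out = R.matmul(x, w_t)
                R.output(out)
        else:
            with R.dataflow():
                out = R.matmul(x, w)
                R.output(out)
        return out


# With this change, FuseOps (and FuseOpsByPattern, as exercised by the test
# below) also considers the bindings inside the branch-local DataflowBlocks.
fused = relax.transform.FuseOps()(Module)
fused.show()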
101 changes: 101 additions & 0 deletions tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -1243,5 +1243,106 @@ def func(
assert "fused_relax_matmul_relax_add_relax_clip" in func_names


def test_dataflow_inside_branch():
"""Fusion may apply within internal dataflow

While relax::DataflowBlock instances may not contain flow control
or impure functions, they may be contained within flow control
structures.

"""

@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([1024, 1024], "float16"),
w: R.Tensor([1024, 1024], "float16"),
transpose_weights: R.Prim("bool"),
):
if transpose_weights:
with R.dataflow():
w_t = R.permute_dims(w)
out = R.matmul(x, w_t)
R.output(out)
else:
with R.dataflow():
out = R.matmul(x, w)
R.output(out)
return out

@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor([1024, 1024], "float16"),
w: R.Tensor([1024, 1024], "float16"),
transpose_weights: R.Prim("bool"),
):
cls = Expected
if transpose_weights:
with R.dataflow():
out_then = cls.fused_relax_permute_dims_relax_matmul_cublas(w, x)
R.output(out_then)
out = out_then
else:
with R.dataflow():
out_else = cls.fused_relax_matmul_cublas(x, w)
R.output(out_else)
out = out_else
return out

@R.function
def fused_relax_permute_dims_relax_matmul_cublas(
w: R.Tensor((1024, 1024), dtype="float16"),
x: R.Tensor((1024, 1024), dtype="float16"),
) -> R.Tensor((1024, 1024), dtype="float16"):
R.func_attr({"Codegen": "cublas"})

@R.function
def local_func(
w_1: R.Tensor((1024, 1024), dtype="float16"),
x_1: R.Tensor((1024, 1024), dtype="float16"),
) -> R.Tensor((1024, 1024), dtype="float16"):
R.func_attr({"Composite": "cublas.matmul_transposed"})
with R.dataflow():
w_t = R.permute_dims(w_1)
out = R.matmul(x_1, w_t)
R.output(out)
return out

output = local_func(w, x)
return output

@R.function
def fused_relax_matmul_cublas(
x: R.Tensor((1024, 1024), dtype="float16"),
w: R.Tensor((1024, 1024), dtype="float16"),
) -> R.Tensor((1024, 1024), dtype="float16"):
R.func_attr({"Codegen": "cublas"})

@R.function
def local_func(
x_1: R.Tensor((1024, 1024), dtype="float16"),
w_1: R.Tensor((1024, 1024), dtype="float16"),
) -> R.Tensor((1024, 1024), dtype="float16"):
R.func_attr({"Composite": "cublas.matmul"})
with R.dataflow():
out = R.matmul(x_1, w_1)
R.output(out)
return out

output = local_func(x, w)
return output

patterns = relax.backend.pattern_registry.get_patterns_with_prefix("cublas.matmul")
After = relax.transform.FuseOpsByPattern(
patterns,
bind_constants=False,
annotate_codegen=True,
)(Before)
tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
pytest.main([__file__])
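
The test above exercises FuseOpsByPattern with annotate_codegen=True, which produces exactly the "Codegen"/"Composite" nesting that the new attr::kCodegen checks in fuse_ops.cc skip on later pass invocations. To run just this case, mirroring the file's own pytest.main entry point (invocation from the repository root is assumed):

import pytest

pytest.main(
    ["tests/python/relax/test_transform_fuse_ops_by_pattern.py::test_dataflow_inside_branch"]
)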