Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax][Bugfix] Apply FuseOps to nested DataflowBlock #17033

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 14 additions & 25 deletions src/relax/transform/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,16 @@ class GraphCreator : public ExprVisitor {
static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) {
GraphCreator creator(mod, arena);
for (const auto& it : mod->functions) {
// Only visit Relax function without attr kPrimitive.
// Only visit Relax functions with neither attr::kPrimitive nor
// attr::kCodegen. Relax functions with `attr::kPrimitive` are
// previously fused functions, potentially from a previous use
// of `FuseOps` or `FuseOpsByPattern`. Relax functions with
// `attr::kCodegen` are previously fused functions from
// `FuseOpsByPattern`, when the `annotate_codegen` option is
// true.
const auto* func = it.second.as<FunctionNode>();
if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive)) {
if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive) ||
func->GetAttr<String>(attr::kCodegen).defined()) {
continue;
}
creator(GetRef<Function>(func));
Expand Down Expand Up @@ -142,13 +149,6 @@ class GraphCreator : public ExprVisitor {
ExprVisitor::VisitExpr_(func);
}

void VisitBindingBlock(const BindingBlock& block) final {
if (const auto* df_block = block.as<DataflowBlockNode>()) {
VisitBindingBlock_(df_block);
}
// We skip ordinary binding blocks since they might be impure (with side effect or control flow)
}

void VisitBinding_(const MatchCastNode* binding) final {
IndexedForwardGraph::Node* node = CreateNode(binding->var.get());
SetNodePattern(node, OpPatternKind::kOpaque);
Expand Down Expand Up @@ -262,16 +262,11 @@ class GraphCreator : public ExprVisitor {
IndexedForwardGraph::Node* leaf_node = nullptr;
if (it != graph_.node_map.end()) {
leaf_node = it->second;
} else if (leaf_expr->IsInstance<ConstantNode>() || leaf_expr->IsInstance<ShapeExprNode>() ||
leaf_expr->IsInstance<PrimValueNode>() || leaf_expr->IsInstance<StringImmNode>() ||
leaf_expr->IsInstance<DataTypeImmNode>()) {
} else {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen Responding to the comment here, with conversation migrated to the new PR.

Is this intended change? Seems that we should consider sass where constant value(and there may not be a binding pt)

This change is intentional, and maintains the same behavior for constants as before. Prior to this change, the else if condition was entered for five of the six child classes of relax::LeafExprNode, and the else branch with the LOG(FATAL) was only entered for a relax::Var.

With this change, the behavior of all LeafExpr subclasses is the same: If encoutered without being previously assigned to a group, they are treated as a group of size one.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This resolves an issue that results from the use of PostOrderVisit here, in VisitUnsupportedNode. It looks like this function is intended to collect all Var/Constant nodes that are part of the RHS of a binding (e.g. extracting A and B from C = R.add(A,B)). However, it doesn't collect any bindings that may be part of a nested expression. As a result, variable bindings that are part of a conditional's body would erroneously trigger the LOG(FATAL) in VisitLeaf.

leaf_node = CreateNode(leaf_expr.get());
// Since we never fuse constants, the pattern of the constant is set to `kOpaque`.
SetNodePattern(leaf_node, OpPatternKind::kOpaque);
AddToPostDFSOrder(leaf_node, leaf_expr.get());
} else {
LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: " << leaf_expr
<< " used before definition.";
}
AddEdge(leaf_node, binding_var_node, pattern);
}
Expand Down Expand Up @@ -701,8 +696,10 @@ class OperatorFusor : public ExprMutator {
}
for (const auto& gv : entry_functions) {
const auto& func = mod_->Lookup(gv);
// Only visit Relax function without attr kPrimitive.
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
// Only visit Relax functions with neither attr::kPrimitive nor
// attr::kCodegen.
if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive) &&
!func->GetAttr<String>(attr::kCodegen).defined()) {
auto updated_func = Downcast<Function>(VisitExpr(func));
builder_->UpdateFunction(gv, updated_func);
}
Expand Down Expand Up @@ -739,14 +736,6 @@ class OperatorFusor : public ExprMutator {
return false;
}

BindingBlock VisitBindingBlock(const BindingBlock& block) final {
if (const auto* df_block = block.as<DataflowBlockNode>()) {
return VisitBindingBlock_(df_block);
}
// We skip ordinary binding blocks since they might be impure (with side effect or control flow)
return block;
}

BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
group2func_.clear();

Expand Down
101 changes: 101 additions & 0 deletions tests/python/relax/test_transform_fuse_ops_by_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1243,5 +1243,106 @@ def func(
assert "fused_relax_matmul_relax_add_relax_clip" in func_names


def test_dataflow_inside_branch():
"""Fusion may apply within internal dataflow

While relax::DataflowBlock instances may not contain flow control
or impure functions, they may be contained within flow control
structures.

"""

@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([1024, 1024], "float16"),
w: R.Tensor([1024, 1024], "float16"),
transpose_weights: R.Prim("bool"),
):
if transpose_weights:
with R.dataflow():
w_t = R.permute_dims(w)
out = R.matmul(x, w_t)
R.output(out)
else:
with R.dataflow():
out = R.matmul(x, w)
R.output(out)
return out

@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor([1024, 1024], "float16"),
w: R.Tensor([1024, 1024], "float16"),
transpose_weights: R.Prim("bool"),
):
cls = Expected
if transpose_weights:
with R.dataflow():
out_then = cls.fused_relax_permute_dims_relax_matmul_cublas(w, x)
R.output(out_then)
out = out_then
else:
with R.dataflow():
out_else = cls.fused_relax_matmul_cublas(x, w)
R.output(out_else)
out = out_else
return out

@R.function
def fused_relax_permute_dims_relax_matmul_cublas(
w: R.Tensor((1024, 1024), dtype="float16"),
x: R.Tensor((1024, 1024), dtype="float16"),
) -> R.Tensor((1024, 1024), dtype="float16"):
R.func_attr({"Codegen": "cublas"})

@R.function
def local_func(
w_1: R.Tensor((1024, 1024), dtype="float16"),
x_1: R.Tensor((1024, 1024), dtype="float16"),
) -> R.Tensor((1024, 1024), dtype="float16"):
R.func_attr({"Composite": "cublas.matmul_transposed"})
with R.dataflow():
w_t = R.permute_dims(w_1)
out = R.matmul(x_1, w_t)
R.output(out)
return out

output = local_func(w, x)
return output

@R.function
def fused_relax_matmul_cublas(
x: R.Tensor((1024, 1024), dtype="float16"),
w: R.Tensor((1024, 1024), dtype="float16"),
) -> R.Tensor((1024, 1024), dtype="float16"):
R.func_attr({"Codegen": "cublas"})

@R.function
def local_func(
x_1: R.Tensor((1024, 1024), dtype="float16"),
w_1: R.Tensor((1024, 1024), dtype="float16"),
) -> R.Tensor((1024, 1024), dtype="float16"):
R.func_attr({"Composite": "cublas.matmul"})
with R.dataflow():
out = R.matmul(x_1, w_1)
R.output(out)
return out

output = local_func(x, w)
return output

patterns = relax.backend.pattern_registry.get_patterns_with_prefix("cublas.matmul")
After = relax.transform.FuseOpsByPattern(
patterns,
bind_constants=False,
annotate_codegen=True,
)(Before)
tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
pytest.main([__file__])
Loading