From 569a00a59a0f6a8a93dbb3be231a01e38567dc54 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Thu, 23 May 2024 10:15:18 -0500
Subject: [PATCH] [Relax][Bugfix] Apply FuseOps to nested DataflowBlock

While it is ill-formed for control flow to occur within a
`DataflowBlock`, it is legal for a `DataflowBlock` to be contained
within a control-flow construct.

Prior to this commit, the `FuseOps` and `FuseOpsByPattern` transforms
erroneously skipped `DataflowBlock` instances that were contained
within a `relax::If` node.

This commit updates `FuseOps` to apply operator fusion to any
dataflow block, regardless of whether it is found at the top level of
a Relax function or nested within a control-flow construct.

Co-authored-by: Chris Sullivan
---
 src/relax/transform/fuse_ops.cc               |  39 +++----
 .../test_transform_fuse_ops_by_pattern.py     | 101 ++++++++++++++++++
 2 files changed, 115 insertions(+), 25 deletions(-)

diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index e89c5e44454f..c4bd52eff18e 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -108,9 +108,16 @@ class GraphCreator : public ExprVisitor {
   static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) {
     GraphCreator creator(mod, arena);
     for (const auto& it : mod->functions) {
-      // Only visit Relax function without attr kPrimitive.
+      // Only visit Relax functions with neither attr::kPrimitive nor
+      // attr::kCodegen. Relax functions with `attr::kPrimitive` are
+      // previously fused functions, potentially from a previous use
+      // of `FuseOps` or `FuseOpsByPattern`. Relax functions with
+      // `attr::kCodegen` are previously fused functions from
+      // `FuseOpsByPattern`, when the `annotate_codegen` option is
+      // true.
       const auto* func = it.second.as<FunctionNode>();
-      if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive)) {
+      if (func == nullptr || func->HasNonzeroAttr(attr::kPrimitive) ||
+          func->GetAttr<String>(attr::kCodegen).defined()) {
         continue;
       }
       creator(GetRef<Function>(func));
@@ -142,13 +149,6 @@ class GraphCreator : public ExprVisitor {
     ExprVisitor::VisitExpr_(func);
   }
 
-  void VisitBindingBlock(const BindingBlock& block) final {
-    if (const auto* df_block = block.as<DataflowBlockNode>()) {
-      VisitBindingBlock_(df_block);
-    }
-    // We skip ordinary binding blocks since they might be impure (with side effect or control flow)
-  }
-
   void VisitBinding_(const MatchCastNode* binding) final {
     IndexedForwardGraph::Node* node = CreateNode(binding->var.get());
     SetNodePattern(node, OpPatternKind::kOpaque);
@@ -262,16 +262,11 @@
     IndexedForwardGraph::Node* leaf_node = nullptr;
     if (it != graph_.node_map.end()) {
       leaf_node = it->second;
-    } else if (leaf_expr->IsInstance<ConstantNode>() || leaf_expr->IsInstance<ShapeExprNode>() ||
-               leaf_expr->IsInstance<PrimValueNode>() || leaf_expr->IsInstance<StringImmNode>() ||
-               leaf_expr->IsInstance<DataTypeImmNode>()) {
+    } else {
       leaf_node = CreateNode(leaf_expr.get());
       // Since we never fuse constants, the pattern of the constant is set to `kOpaque`.
       SetNodePattern(leaf_node, OpPatternKind::kOpaque);
       AddToPostDFSOrder(leaf_node, leaf_expr.get());
-    } else {
-      LOG(FATAL) << "The leaf Expr is supposed to be defined before, but got: " << leaf_expr
-                 << " used before definition.";
     }
     AddEdge(leaf_node, binding_var_node, pattern);
   }
@@ -701,8 +696,10 @@ class OperatorFusor : public ExprMutator {
     }
     for (const auto& gv : entry_functions) {
       const auto& func = mod_->Lookup(gv);
-      // Only visit Relax function without attr kPrimitive.
-      if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive)) {
+      // Only visit Relax functions with neither attr::kPrimitive nor
+      // attr::kCodegen.
+      if (func->IsInstance<relax::FunctionNode>() && !func->HasNonzeroAttr(attr::kPrimitive) &&
+          !func->GetAttr<String>(attr::kCodegen).defined()) {
         auto updated_func = Downcast<Function>(VisitExpr(func));
         builder_->UpdateFunction(gv, updated_func);
       }
@@ -739,14 +736,6 @@ class OperatorFusor : public ExprMutator {
     return false;
   }
 
-  BindingBlock VisitBindingBlock(const BindingBlock& block) final {
-    if (const auto* df_block = block.as<DataflowBlockNode>()) {
-      return VisitBindingBlock_(df_block);
-    }
-    // We skip ordinary binding blocks since they might be impure (with side effect or control flow)
-    return block;
-  }
-
   BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
     group2func_.clear();
 
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index f5905f764351..1582526042f1 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -1243,5 +1243,106 @@ def func(
     assert "fused_relax_matmul_relax_add_relax_clip" in func_names
 
 
+def test_dataflow_inside_branch():
+    """Fusion may apply within internal dataflow
+
+    While relax::DataflowBlock instances may not contain flow control
+    or impure functions, they may be contained within flow control
+    structures.
+
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1024, 1024], "float16"),
+            w: R.Tensor([1024, 1024], "float16"),
+            transpose_weights: R.Prim("bool"),
+        ):
+            if transpose_weights:
+                with R.dataflow():
+                    w_t = R.permute_dims(w)
+                    out = R.matmul(x, w_t)
+                    R.output(out)
+            else:
+                with R.dataflow():
+                    out = R.matmul(x, w)
+                    R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1024, 1024], "float16"),
+            w: R.Tensor([1024, 1024], "float16"),
+            transpose_weights: R.Prim("bool"),
+        ):
+            cls = Expected
+            if transpose_weights:
+                with R.dataflow():
+                    out_then = cls.fused_relax_permute_dims_relax_matmul_cublas(w, x)
+                    R.output(out_then)
+                out = out_then
+            else:
+                with R.dataflow():
+                    out_else = cls.fused_relax_matmul_cublas(x, w)
+                    R.output(out_else)
+                out = out_else
+            return out
+
+        @R.function
+        def fused_relax_permute_dims_relax_matmul_cublas(
+            w: R.Tensor((1024, 1024), dtype="float16"),
+            x: R.Tensor((1024, 1024), dtype="float16"),
+        ) -> R.Tensor((1024, 1024), dtype="float16"):
+            R.func_attr({"Codegen": "cublas"})
+
+            @R.function
+            def local_func(
+                w_1: R.Tensor((1024, 1024), dtype="float16"),
+                x_1: R.Tensor((1024, 1024), dtype="float16"),
+            ) -> R.Tensor((1024, 1024), dtype="float16"):
+                R.func_attr({"Composite": "cublas.matmul_transposed"})
+                with R.dataflow():
+                    w_t = R.permute_dims(w_1)
+                    out = R.matmul(x_1, w_t)
+                    R.output(out)
+                return out
+
+            output = local_func(w, x)
+            return output
+
+        @R.function
+        def fused_relax_matmul_cublas(
+            x: R.Tensor((1024, 1024), dtype="float16"),
+            w: R.Tensor((1024, 1024), dtype="float16"),
+        ) -> R.Tensor((1024, 1024), dtype="float16"):
+            R.func_attr({"Codegen": "cublas"})
+
+            @R.function
+            def local_func(
+                x_1: R.Tensor((1024, 1024), dtype="float16"),
+                w_1: R.Tensor((1024, 1024), dtype="float16"),
+            ) -> R.Tensor((1024, 1024), dtype="float16"):
+                R.func_attr({"Composite": "cublas.matmul"})
+                with R.dataflow():
+                    out = R.matmul(x_1, w_1)
+                    R.output(out)
+                return out
+
+            output = local_func(x, w)
+            return output
+
+    patterns = relax.backend.pattern_registry.get_patterns_with_prefix("cublas.matmul")
+    After = relax.transform.FuseOpsByPattern(
+        patterns,
+        bind_constants=False,
+        annotate_codegen=True,
+    )(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
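
As a usage sketch (not part of the patch): the fix also lets the stock
`relax.transform.FuseOps` pass reach dataflow blocks nested under a
`relax::If`. The module below mirrors the `Before` module from the test
above; the module name and tensor shapes are illustrative only.

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R


@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor([16, 16], "float16"),
        w: R.Tensor([16, 16], "float16"),
        transpose_weights: R.Prim("bool"),
    ):
        # Each branch holds its own DataflowBlock; the `if` itself sits
        # outside any dataflow scope, which is well-formed Relax.
        if transpose_weights:
            with R.dataflow():
                w_t = R.permute_dims(w)
                out = R.matmul(x, w_t)
                R.output(out)
        else:
            with R.dataflow():
                out = R.matmul(x, w)
                R.output(out)
        return out


# Prior to this patch, FuseOps left the module unchanged because the
# DataflowBlocks inside the branches were never visited; with the patch,
# both blocks are traversed and eligible operators are fused.
fused_mod = relax.transform.FuseOps()(Module)
print(fused_mod)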