From 523a815dd4f628d421c326bbf32e45b2417fb46a Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 4 Apr 2022 22:04:24 +0800 Subject: [PATCH] [Relax][Pass] FuseOps (#15) * [Relax][Pass] FuseOps * Fix 2 bugs --- include/tvm/relax/transform.h | 7 +- python/tvm/relax/transform/transform.py | 7 +- src/relax/transform/fuse_ops.cc | 1180 ++++++++++++----------- 3 files changed, 642 insertions(+), 552 deletions(-) diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index 5c43b06a8a..a8891033fc 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -139,7 +139,12 @@ TVM_DLL Pass LayoutRewrite(); TVM_DLL Pass FoldConstant(); /*! - * \brief Fuse operators in an expr to a larger operator according to some rules. + * \brief This pass groups bindings in a dataflow block of Relax functions and generate a new grouped + * Relax function for each group, according to the fusion algorithm described in the pass + * implementation. By grouping bindings into new Relax functions, we substitute the bindings in the + * function being manipulated into function calls to the new grouped function. + * + * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. * \param fuse_opt_level The level of fuse optimization. * -1 indicates that the level will be inferred from pass context. * \return The Pass. diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index f5f42fd3b6..fe249cf3f6 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -183,7 +183,12 @@ def LayoutRewrite() -> tvm.ir.transform.Pass: def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass: - """Fuse operators in an expr to a larger operator according to some rules. 
+ """This pass groups bindings in a dataflow block of Relax functions and generate a new grouped + Relax function for each group, according to the fusion algorithm described in the pass + implementation. By grouping bindings into new Relax functions, we substitute the bindings in + the function being manipulated into function calls to the new grouped function. + + A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. Parameters ---------- diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 7939073689..c7bbff0eb9 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -17,8 +17,13 @@ * under the License. */ /*! - * \file src/relax/transform/fma_rewrite.cc - * \brief Perform fused multiply add rewriting in dataflow blocks. + * \file src/relax/transform/fuse_ops.cc + * \brief This file contains a pass which groups bindings in a dataflow block of Relax + * functions and generate a new grouped Relax function for each group, according to the fusion + * algorithm described below. By grouping bindings into new Relax functions, we substitute the + * bindings in the function being manipulated into function calls to the new grouped function. + * + * A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function. 
*/ #include @@ -79,670 +84,745 @@ constexpr uint32_t kMaxFusedOps = 256; TVM_REGISTER_PASS_CONFIG_OPTION("relax.FuseOps.max_depth", Integer); -// Creator of post dominator tree of the dataflow -class IndexedForwardGraphCreator : private ExprVisitor { +class GraphCreator : public ExprVisitor { public: - static IndexedForwardGraph Create(const IRModule& mod, support::Arena* arena) { - GlobalVar main_global_var = mod->GetGlobalVar("main"); - Function body = Downcast(mod->Lookup(main_global_var)); - IndexedForwardGraphCreator creator(arena, mod); - for (const auto& kv : mod->functions) { - const Expr& func = kv.second; - if (func->IsInstance()) { - creator.VisitExpr(body); - } - } - // creator.graph_.DebugDump(); + /*! + * \brief Create a IndexedForwardGraph according to the input module. The graph will be used for + * graph partition and operator fusion. + * \param mod The module which the creation accords to + * \param arena The allocator of all the internal node objects + * \return The created IndexedForwardGraph + */ + static IndexedForwardGraph Create(IRModule mod, support::Arena* arena) { + // Since cross-function call is not supported yet, FuseOps only serves the entry function, whose + // name is "main". + auto relax_func = Downcast(mod->Lookup("main")); + GraphCreator creator(mod, arena); + creator(relax_func); + + // The algorithm of the graph creator ensures that each created node will be added to the + // post-dfs order and will be set its op pattern. Thus we check whether all these containers + // have the same size. 
+ size_t n_nodes = creator.graph_.node_map.size(); + ICHECK_EQ(n_nodes, creator.graph_.post_dfs_order.size()); + ICHECK_EQ(n_nodes, creator.initialized_nodes_.size()); + return creator.graph_; } private: - void VisitExpr_(const ConstantNode* op) final { - this->CreateNode(op); - this->AddNode(op); - IndexedForwardGraph::Node* node = graph_.node_map.at(op); - DataType dtype = DataType(op->data->dtype); - // This rule must be consistent with code generator. - bool is_simple_const = - (dtype == DataType::Int(32) || dtype == DataType::Int(64) || dtype == DataType::Float(32) || - dtype == DataType::Float(64) || dtype == DataType::Bool()); - if (op->is_scalar() && is_simple_const) { - node->pattern = OpPatternKind::kElemWise; - } else { - // for now, mark non-scalar constant - // as opaque, we will not choose to fuse it. - node->pattern = OpPatternKind::kOpaque; + explicit GraphCreator(IRModule mod, support::Arena* arena) + : mod_(std::move(mod)), arena_(arena) {} + + void VisitExpr_(const FunctionNode* func) final { + for (const Var& param : func->params) { + IndexedForwardGraph::Node* param_node = CreateNode(param.get()); + LOG(INFO) << "[2]. Create node for param " << param << ". node is " << param_node; + // The parameter is passed in from the outside, and thus it's marked as an external reference, + // and its pattern is `kOpaque`. 
+ MarkAsExternRef(param_node); + SetNodePattern(param_node, OpPatternKind::kOpaque); + AddToPostDFSOrder(param_node, param.get()); } + ExprVisitor::VisitExpr_(func); } - void VisitExpr_(const FunctionNode* op) final { - for (const Var& param : op->params) { - CreateNode(param.get()); - this->UpdateEdge(param, nullptr, OpPatternKind::kOpaque); + void VisitBindingBlock(const BindingBlock& block) final { + if (const auto* df_block = block.as()) { + VisitBindingBlock_(df_block); } - ExprVisitor::VisitExpr_(op); - } - - void VisitExpr_(const VarNode* op) final { this->AddNode(op); } - - void VisitBindingBlock_(const BindingBlockNode* block) final { - // Skip Binding Block since it's imprue (with side effect or control flow) - return; + // We skip ordinary binding blocks since they might be impure (with side effect or control flow) } - void VisitBinding_(const MatchShapeNode* binding) final { - auto node = CreateNode(binding->var.get()); - this->UpdateEdge(binding->var, node, OpPatternKind::kInjective); - } + // TODO(tvm-team): how to deal with MatchShape binding here void VisitBinding_(const VarBindingNode* binding) final { - // Don't allow recursive var binding. - ICHECK(!cur_binding_var_.defined()); - cur_binding_var_ = binding->var; - CreateNode(binding->var.get()); + CHECK(cur_binding_var_node_ == nullptr) + << "We are visiting a new VarBinding inside an outer binding, which is not allowed"; + cur_binding_var_node_ = CreateNode(binding->var.get()); + LOG(INFO) << "[3]. Create node for binding var " << binding->var->name_hint() << ". 
node is " + << cur_binding_var_node_ << ", type key is " << binding->var->GetTypeKey(); + + // If the variable is not a dataflow variable, it must be the output variable of this dataflow + // block if (!binding->var->IsInstance()) { - this->UpdateEdge(binding->var, nullptr, OpPatternKind::kOpaque); + this->MarkAsExternRef(cur_binding_var_node_); } ExprVisitor::VisitBinding_(binding); - cur_binding_var_ = NullOpt; + AddToPostDFSOrder(cur_binding_var_node_, binding->var.get()); + cur_binding_var_node_ = nullptr; + } + + void VisitExpr(const Expr& expr) final { + if (cur_binding_var_node_ == nullptr) { + // Case 1. The expression is not under a binding. No action is needed. + } else if (expr->IsInstance() || expr->IsInstance()) { + // Case 2. The type of the expression is supported by fusion (as defined below). No action is + // needed again - we will recurse into this expression and let the visitor deal with such + // expressions. + } else if (!IsLeaf(expr)) { + // Case 3. The type of the expression is not fusion-supported and the expression is not a + // leaf. In this case, we set the pattern of the current binding variable to be `kOpaque`. + ICHECK(cur_pattern_ == OpPatternKind::kOpaque); + SetNodePattern(cur_binding_var_node_, OpPatternKind::kOpaque); + } else { + // Case 4. The expression is a leaf expression, which currently is not fusion-supported. + // - Under such circumstances, if the current binding value is exactly the expression + // itself, the pattern of the current binding variable is not set. + // - Otherwise, the pattern of the current binding variable must have been set. + // - Thus, we set the pattern of the current binding variable to `kOpaque` (since leaf + // expressions are not fusion-supported) if it hasn't been set yet. 
+ if (initialized_nodes_.find(cur_binding_var_node_) == initialized_nodes_.end()) { + ICHECK(cur_pattern_ == OpPatternKind::kOpaque); + SetNodePattern(cur_binding_var_node_, OpPatternKind::kOpaque); + } + } + ExprVisitor::VisitExpr(expr); } - void VisitExpr_(const CallNode* op) final { + /********** Non-Leaf Expression Nodes **********/ + + void VisitExpr_(const CallNode* call) final { + // If the function call is not under a binding, there is no need to recurse into it. + if (cur_binding_var_node_ == nullptr) { + return; + } static const Op& call_tir_op_ = Op::Get("relax.call_tir"); - ICHECK(cur_binding_var_.defined()); - const Var& binding_var = cur_binding_var_.value(); - auto it = graph_.node_map.find(binding_var.get()); - ICHECK(it != graph_.node_map.end()); - IndexedForwardGraph::Node* node = it->second; - - // If the pattern is not annotated we will default to opaque. - OpPatternKind op_pattern = OpPatternKind::kOpaque; - ICHECK(op->op->IsInstance()); - if (op->op == call_tir_op_) { - GlobalVar global_var = Downcast(op->args[0]); + // - If the op being called is a TIR PrimFunc, we get the function op pattern directly from the + // function attribute and visit the arguments one by one. + // - Otherwise, the pattern of the current binding variable node is set to `kOpaque`, and we + // recurse into the call expression. 
+ const auto* op = call->op.as(); + if (op == call_tir_op_.get()) { + const GlobalVar& global_var = Downcast(call->args[0]); tir::PrimFunc func = Downcast(mod_->Lookup(global_var)); - const Tuple& args = Downcast(op->args[1]); - const Expr& shape = op->args[2]; - int func_pattern = func->GetAttr("op_pattern").value_or(OpPatternKind::kOpaque); - // TODO(siyuan): Check the integer data is valid - op_pattern = static_cast(func_pattern); + const Tuple& args = Downcast(call->args[1]); + // TODO(tvm-team): handle the shape argument (call->args[2]) + Optional opt_pattern = func->GetAttr("op_pattern"); + OpPatternKind pattern; + if (opt_pattern.defined()) { + pattern = static_cast(Downcast(opt_pattern)->value); + } else { + pattern = OpPatternKind::kOpaque; + } + // The pattern of the current binding variable node is set to the pattern of this operator. + SetNodePattern(cur_binding_var_node_, pattern); for (const Expr& arg : args->fields) { - this->VisitExpr(arg); - this->UpdateEdge(arg, node, op_pattern); + // If the operator pattern was detected to be `kBroadcast`, and meanwhile this argument has + // the same shape as the operator output, the relation between this argument and the output + // is actually element-wise. And in this case we change the pattern to `kElemWise` + // temporarily. + if (pattern == OpPatternKind::kBroadcast && structural_equal_(call->shape_, arg->shape_)) { + LOG(INFO) << "[13]. Change current pattern to element-wise"; + cur_pattern_ = OpPatternKind::kElemWise; + } else { + cur_pattern_ = pattern; + } + VisitExpr(arg); } } else { - LOG(FATAL) << "The call op " << op->op << " is not supported in dataflow block for now."; + SetNodePattern(cur_binding_var_node_, OpPatternKind::kOpaque); + ExprVisitor::VisitExpr_(call); } - node->pattern = op_pattern; - this->AddNode(binding_var.get()); + + // Restore the value of `cur_pattern_`. 
+ cur_pattern_ = OpPatternKind::kOpaque; } - void VisitExpr_(const TupleGetItemNode* op) final { - ICHECK(cur_binding_var_.defined()); - const Var& binding_var = cur_binding_var_.value(); - auto it = graph_.node_map.find(binding_var.get()); - ICHECK(it != graph_.node_map.end()); - IndexedForwardGraph::Node* node = it->second; + void VisitExpr_(const TupleGetItemNode* tuple_item) final { + // If the tuple-get-item node is not under a binding, there is no need to recurse into it. + if (cur_binding_var_node_ == nullptr) { + return; + } - node->pattern = OpPatternKind::kInjective; - this->UpdateEdge(op->tuple, node, OpPatternKind::kInjective); - this->AddNode(binding_var.get()); - } + SetNodePattern(cur_binding_var_node_, OpPatternKind::kInjective); + cur_pattern_ = OpPatternKind::kInjective; + VisitExpr(tuple_item->tuple); - void VisitExpr_(const IfNode* op) final { - LOG(FATAL) << "Dataflow block expects no Control flow inside."; + // Restore the value of `cur_pattern_`. + cur_pattern_ = OpPatternKind::kOpaque; } - private: - explicit IndexedForwardGraphCreator(support::Arena* arena, const IRModule& mod) - : mod_(mod), arena_(arena) {} + /********** Leaf Expression Nodes **********/ - // Helper functions to maintain IndexedForwardGraph - /*! - * \brief The - * \param node The Relax IR nodes - * \param parent The parent node in the IndexedForwardGraph. - * The source is external if the parent is nullptr. - * \param pattern The relation pattern between the node and its parent. 
- */ - void UpdateEdge(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) { - const tvm::Object* key = node.get(); - auto it = graph_.node_map.find(key); - ICHECK(it != graph_.node_map.end()); - IndexedForwardGraph::Node* current = it->second; - if (parent != nullptr) { - auto* link = arena_->make>(); - link->value.node = parent; - link->value.pattern = pattern; - current->outputs.Push(link); + void VisitExpr_(const ConstantNode* constant) final { + // If the constant is not under a binding, there is no need to recurse into it. + if (cur_binding_var_node_ == nullptr) { + return; + } + // TODO(tvm-team): what about constant shape in match-shape? + + // If we're visiting the constant for the first time, we create a node for it. + // Otherwise, we fetch the node from the node map of the graph. + auto it_const = graph_.node_map.find(constant); + IndexedForwardGraph::Node* const_node = nullptr; + if (it_const == graph_.node_map.end()) { + const_node = CreateNode(constant); + LOG(INFO) << "[4]. Create node for constant. node is " << const_node; + // Since we never fuse constants, the pattern of the constant is set to `kOpaque`. + SetNodePattern(const_node, OpPatternKind::kOpaque); + AddToPostDFSOrder(const_node, constant); } else { - current->extern_ref = true; + const_node = it_const->second; } + AddEdge(const_node, cur_binding_var_node_, OpPatternKind::kOpaque); } - void AddNode(const tvm::Object* key) { - auto it = graph_.node_map.find(key); - ICHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef(key); - IndexedForwardGraph::Node* node = it->second; - if (node->ref == nullptr) { - node->ref = key; - node->index = graph_.post_dfs_order.size(); - graph_.post_dfs_order.push_back(node); - } else { - ICHECK(node->ref == key); + void VisitExpr_(const VarNode* var) final { + // If the variable is not under a binding, there is no need to recurse into it. 
+ if (cur_binding_var_node_ == nullptr) { + return; } + auto it_var = graph_.node_map.find(var); + CHECK(it_var != graph_.node_map.end()) << "The variable is supposed to be defined before"; + + // - If the variable is a component of some other binding value (call or tuple-get-item), + // `cur_pattern_` is supposed to be properly set already. + // - Otherwise, `cur_pattern_` is supposed to be `kOpaque` by default. + AddEdge(it_var->second, cur_binding_var_node_, cur_pattern_); } - IndexedForwardGraph::Node* CreateNode(const tvm::Object* key) { - ICHECK(graph_.node_map.find(key) == graph_.node_map.end()); - IndexedForwardGraph::Node* node = arena_->make(); + void VisitExpr_(const DataflowVarNode* var) final { VisitExpr_(GetRef(var).get()); } + + /********** Helper Functions **********/ + + /*! + * \brief Check whether the expression is a leaf expression + * \param expr The expression to be checked + * \return Whether the expression is a leaf expression + * \note In order to avoid too much refactor, this method is a simple copy-paste of the is-leaf + * check in "block_builder.cc". And it should be refactored in the future. + * \sa src/relax/ir/block_builder.cc + */ + static bool IsLeaf(const Expr& expr) { + // NOTE: Tuples are treated as leaf nodes for ergonomics + return expr.as() || expr.as() || expr.as() || + expr.as() || expr.as() || expr.as() || + expr.as(); + } + + /*! + * \brief Create a graph node corresponding to the input key + * \param key The object which is used to create the graph node + * \return The created graph node + * \note The node corresponding to each key is supposed to be created for only once + */ + IndexedForwardGraph::Node* CreateNode(const Object* key) { + ICHECK(graph_.node_map.find(key) == graph_.node_map.end()) + << "The node corresponding to the input key is not supposed to be created before"; + auto* node = arena_->make(); graph_.node_map[key] = node; return node; } - private: - /*! 
\brief The whole IRModule */ - const IRModule& mod_; - /*! \brief Allocator of all the internal node object */ - support::Arena* arena_; - /*! \brief The output graph */ - IndexedForwardGraph graph_; - /*! \brief current binding var */ - Optional cur_binding_var_ = NullOpt; -}; - -class RelaxFuseMutator : public ExprMutator { - public: - // Run the transform - static IRModule Transform(const IRModule& mod, int fuse_opt_level, size_t max_fuse_depth) { - // setup the group map. - RelaxFuseMutator mutator(mod); - auto graph = IndexedForwardGraphCreator::Create(mod, &mutator.arena_); - auto groups = - GraphPartitioner(&mutator.arena_, fuse_opt_level, max_fuse_depth).Partition(graph); - for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { - ICHECK(graph.post_dfs_order[nid]->ref != nullptr); - mutator.gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; - } + /*! + * \brief Append the input node to the post-dfs order of the graph + * \param node The node to be appended + * \param key The key corresponding to the node + * \note Each node is supposed to be appended to the post-dfs order for only once + */ + void AddToPostDFSOrder(IndexedForwardGraph::Node* node, const Object* key) { + auto it = graph_.node_map.find(key); + ICHECK(it != graph_.node_map.end() && it->second == node) + << "The node must have been created before adding to the post-dfs order"; + + // We only set the reference of the node when adding it to the post-dfs order. Thus, if the + // reference of a node is already set, it must have been appended to the post-dfs order. + ICHECK(node->ref == nullptr) + << "The node is not supposed to be added into the post-dfs order before"; + + LOG(INFO) << "[6]. Add node " << node << " to post-dfs order"; + node->ref = key; + node->index = graph_.post_dfs_order.size(); + graph_.post_dfs_order.push_back(node); + } - // The following line can be used for debug. - // GroupDebugDumper::Dump(mod, mutator.gmap_); + /*! 
+ * \brief Add an edge from the input start to the input end in the graph, with specific pattern + * \param start The start of the edge + * \param end The end of the edge + * \param pattern The pattern of this edge + */ + void AddEdge(IndexedForwardGraph::Node* start, IndexedForwardGraph::Node* end, + OpPatternKind pattern) { + LOG(INFO) << "[8]. Edge: " << start << " ---> " << end << ", pattern: " << (int)pattern; + auto* link = arena_->make>(); + link->value.node = end; + link->value.pattern = pattern; + start->outputs.Push(link); + } - for (const auto& kv : mod->functions) { - Expr func = kv.second; - const GlobalVar& global_var = kv.first; - if (func->IsInstance()) { - func = mutator.VisitExpr(func); - mutator.builder_->AddFuncToContext(Downcast(func), global_var->name_hint); - } - } + /*! + * \brief Mark a given node as "external reference", which means the node cannot be fused as an + * intermediate node + * \param node The graph node to be marked + */ + void MarkAsExternRef(IndexedForwardGraph::Node* node) { node->extern_ref = true; } - return mutator.builder_->GetContextIRModule(); + /*! + * \brief Set the pattern of the input node + * \param node The graph node to be set + * \param pattern The pattern of the node + */ + void SetNodePattern(IndexedForwardGraph::Node* node, OpPatternKind pattern) { + ICHECK(initialized_nodes_.find(node) == initialized_nodes_.end()) + << "The input node is supposed to be set pattern for only once"; + LOG(INFO) << "[7]. Set pattern of node " << node << " to " << static_cast(pattern); + initialized_nodes_.insert(node); + node->pattern = pattern; } private: - explicit RelaxFuseMutator(const IRModule& mod) : mod_(mod) {} - - BindingBlock VisitBindingBlock_(const BindingBlockNode* block) final { - // Skip Binding Block since it's imprue (with side effect or control flow) - return GetRef(block); - } + /*! \brief The IRModule from which the indexed forward graph is created */ + IRModule mod_; + /*! 
\brief The allocator of all the internal node objects */ + support::Arena* arena_; + /*! \brief The created indexed forward graph */ + IndexedForwardGraph graph_; + /*! \brief The variable in the current VarBinding */ + IndexedForwardGraph::Node* cur_binding_var_node_ = nullptr; + /*! \brief The op pattern of the current binding */ + OpPatternKind cur_pattern_ = OpPatternKind::kOpaque; + /*! \brief The graph nodes whose patterns are set */ + std::unordered_set initialized_nodes_; + /*! \brief The structural equality checker */ + StructuralEqual structural_equal_; +}; - void VisitBinding_(const VarBindingNode* binding) final { - const Var& var = binding->var; - ICHECK(gmap_.count(var.get())); - cur_group_ = gmap_.at(var.get())->FindRoot(); - - auto it = ginfo_.find(cur_group_); - if (it == ginfo_.end()) { - // This is a new group - if (cur_group_->root_ref == var.get()) { - // Don't create new function if there is only one binding in the new group - ExprMutator::VisitBinding_(binding); - } else { - ginfo_[cur_group_] = GroupInfo(); - builder_->BeginDataflowBlock(); - ExprMutator::VisitBinding_(binding); - } - } else { - Expr new_value = this->VisitExpr(binding->value); - if (cur_group_->root_ref == var.get()) { - builder_->EmitOutput(new_value); - Var new_var = builder_->Emit(MakeNewFunction()); - this->var_remap_[var->vid] = new_var; +/*! + * \brief The ExprMutator used to create a new grouped function + * \details The workflow of this ExprMutator is: + * - The bindings in the function will be added by OperatorFusor via `AppendBinding(...)`. + * - When adding a new binding through `AppendBinding(...)`, we check whether the variables and + * constants used by the binding are defined by some previous added binding. And for the undefined + * variables and constants, we add them to the argument list and created new variables as the + * corresponding parameters. 
+ * - When `CreateFunction()` is called, we go through each binding and update the binding with the + * new parameters. After that we wrap all bindings with a DataflowBlock and a Function. + */ +class FunctionCreator : public ExprMutator { + public: + /*! + * \brief Append a new binding to this function and possibly create new parameters for the + * function accordingly + * \param binding The binding to be appended + * \note Allowed bindings are: + * - VarBinding with value being a call node calling `relax.call_tir`. + * - VarBinding with value being a tuple-get-item node. + * // TODO(tvm-team): handle match shape + */ + void AppendBinding(const Binding& binding) { + ICHECK(!function_.defined()) + << "The `function_` is supposed to be uncreated when adding bindings"; + ICHECK(!has_output_var_) + << "It's not allowed to append more bindings once the function has an output variable"; + + if (const auto* var_binding = binding.as()) { + if (const auto* call = var_binding->value.as()) { + ICHECK(call->op == Op::Get("relax.call_tir")); + const GlobalVar& global_var = Downcast(call->args[0]); + // Update the name of the function. + name_ = name_ + "_" + Downcast(call->args[0])->name_hint; + + const Tuple& args = Downcast(call->args[1]); + for (const Expr& arg : args->fields) { + CheckDefAndUpdateParam(arg); + } + // TODO(tvm-team): handle shape expr } else { - builder_->Emit(new_value); + const auto* tuple_item = var_binding->value.as(); + ICHECK(tuple_item != nullptr); + CheckDefAndUpdateParam(tuple_item->tuple); } - } - } - - Expr VisitExpr_(const CallNode* call) final { - static const Op& call_tir_op_ = Op::Get("relax.call_tir"); - if (call->op.as()) { - if (call->op == call_tir_op_) { - return VisitCallTIR(call); - } else { - LOG(FATAL) << "Unsupported OpNode: " << call->op; - return Expr(); + // Mark the binding variable as defined. + defined_vars_.insert(var_binding->var.get()); + // Set `has_output_var_` to true if the binding variable is an output variable (a.k.a. 
is not + // a dataflow variable). + if (!var_binding->var->IsInstance()) { + has_output_var_ = true; } } else { - return ExprMutator::VisitExpr_(call); + // TODO(tvm-team): handle match_shape } + bindings_.push_back(binding); } - Expr VisitExpr_(const TupleGetItemNode* op) final { - if (ginfo_.find(cur_group_) != ginfo_.end()) { - auto t = ginfo_[cur_group_].GetOrAllocParam(op->tuple); - return TupleGetItem(t, op->index, op->span); + /*! + * \brief Create the grouped function according to the collected bindings and parameters + * \note The created function won't be returned immediately. It's stored in the `function_` field. + */ + void CreateFunction() { + // Step 1. Start constructing a new dataflow block. + builder_->BeginDataflowBlock(); + // Step 2. Visit each binding, except the last one, one by one. + for (int i = 0; i < static_cast(bindings_.size()) - 1; ++i) { + VisitBinding(bindings_[i]); } - return ExprMutator::VisitExpr_(op); + + // Step 3. Since the binding var of the last binding should be an output variable, we deal with + // the last binding separately. + const auto* last_binding = bindings_.back().as(); + ICHECK(last_binding != nullptr) << "The last binding of a group is supposed to be a VarBinding"; + Expr binding_value = VisitExpr(last_binding->value); + Var output_var(last_binding->var->vid, NullOpt, last_binding->var->checked_type_); + output_var->shape_ = last_binding->var->shape_; + builder_->EmitOutput(VarBinding(output_var, binding_value)); + + // Step 4. Finish constructing the new block. + BindingBlock new_block = builder_->EndBlock(); + // Step 5. Create a new global variable and the function. 
+ global_var_ = GlobalVar(name_); + function_ = Function(/*name=*/global_var_, // + /*params=*/params_, // + /*body=*/SeqExpr({new_block}, output_var), // + /*ret_type=*/output_var->checked_type_); + function_->shape_ = output_var->shape_; } - Call VisitCallTIR(const CallNode* call) { - // Update fused func name - GlobalVar gv = Downcast(call->args[0]); - BaseFunc prim_func = mod_->Lookup(gv); - GlobalVar new_gv = this->builder_->AddFuncToContext(prim_func, gv->name_hint); - - // Update fused func arguments - Tuple call_tir_args = Downcast(call->args[1]); - if (ginfo_.find(cur_group_) == ginfo_.end()) { - // No need to make new relax function, direct update call_tir - call_tir_args = Downcast(this->VisitExpr(call_tir_args)); - } else { - call_tir_args = GetNewArguments(call_tir_args); + /*! \brief The original bindings of the function */ + Array bindings_; + /*! \brief The parameters of the function */ + Array params_; + /*! \brief The arguments to call the function on the caller side */ + Array arguments_; + /*! \brief The name for the fused function */ + String name_ = "fused"; + /*! \brief The global variable corresponding to the constructed function */ + GlobalVar global_var_; + /*! \brief The constructed Relax function */ + Function function_{nullptr}; - // Do not move this line outside the branch - // since c++ set will automatically insert a new element when accessing - ginfo_[cur_group_].name_hint = ginfo_[cur_group_].name_hint + "_" + gv->name_hint; + private: + /*! + * \brief Check whether the input expression is defined within this function. If not, create a new + * parameter for the expression. + * \param expr The expression to be checked + */ + void CheckDefAndUpdateParam(const Expr& expr) { + // If the expression has already served as an argument, no need to create another one for it. 
+ auto it = std::find(arguments_.begin(), arguments_.end(), expr); + if (it != arguments_.end()) { + return; } - // Create new call - Array args = {new_gv, call_tir_args, call->args[2]}; - return Call(call->op, args, {}, call->type_args, call->span); - } - - Tuple GetNewArguments(const Tuple& args) { - Array new_args; - for (Expr arg : args->fields) { - ICHECK(gmap_.count(arg.get())); - auto* arg_group = gmap_.at(arg.get())->FindRoot(); - arg = VisitExpr(arg); - if (cur_group_ != arg_group && arg->IsInstance()) { - Var param = ginfo_[cur_group_].GetOrAllocParam(arg); - new_args.push_back(param); + // If the expression is not a variable or is an undefined variable, it should be populated as a + // parameter of the relax function. + const auto* var = expr.as(); + if (var == nullptr || defined_vars_.count(var) == 0) { + String name{nullptr}; + if (var != nullptr) { + name = var->name_hint(); } else { - new_args.push_back(arg); + name = String("param_" + std::to_string(n_param_for_const_++)); } + + Var param(std::move(name), // + /*shape_annotation=*/NullOpt, // + /*type_annotation=*/expr->checked_type_); + param->shape_ = expr->shape_; + arguments_.push_back(expr); + params_.push_back(param); } - return Tuple(new_args); } - Expr MakeNewFunction() { - const GroupInfo& ginfo = ginfo_[cur_group_]; - DataflowBlock block = Downcast(builder_->EndBlock()); - Optional output_body; - - for (const relax::Binding& binding : block->bindings) { - Var var; - if (const relax::VarBindingNode* var_binding = binding.as()) { - var = var_binding->var; - } else if (const relax::MatchShapeNode* shape_binding = binding.as()) { - var = shape_binding->var; - } - if (var.defined() && !var.as()) { - ICHECK(!output_body) << "Only one output is allowed"; - output_body = var; - } + Expr VisitExpr(const Expr& expr) final { + // If the expression serves as an argument, return its corresponding parameter. 
+ auto it = std::find(arguments_.begin(), arguments_.end(), expr); + if (it != arguments_.end()) { + return params_[it - arguments_.begin()]; } - ICHECK(output_body) << "There should be at least one output."; - const Expr& body = output_body.value(); - auto func = Function(NullOpt, ginfo.params, SeqExpr({block}, body), body->checked_type()); - func->shape_ = body->shape_; - GlobalVar gv = builder_->AddFuncToContext(func, ginfo.name_hint); - return Call(gv, ginfo.arguments); + // Otherwise, recurse into this expression. + return ExprMutator::VisitExpr(expr); } private: - // Debug function, dump the group assignment in text. - class GroupDebugDumper : public ExprVisitor { - public: - static void Dump(const IRModule& mod, - const std::unordered_map& gmap) { - GroupDebugDumper dumper(gmap); - for (const auto& kv : mod->functions) { - if (const auto* func = kv.second.as()) { - for (const Var& pram : func->params) { - // Skip function params since they are always a single group - dumper.skip_objects_.insert(pram.get()); - } - dumper(GetRef(func)); - } - } + /*! \brief The variables defined in this function */ + std::unordered_set defined_vars_; + /*! \brief The number of parameters reserved for constants */ + int n_param_for_const_ = 0; + /*! \brief The boolean indicating whether the input bindings have an output variable */ + bool has_output_var_ = false; +}; - LOG(INFO) << "Group partition results:\n" << dumper.os_.str(); +/*! + * \brief The ExprMutator used to fuse the operators in Relax functions + * \details Given the partition results on the indexed-forward graph, for each group whose size is + * larger than one, we create a new grouped function for it, containing all bindings in that group. + * And we substitute the bindings in a group with a single function call to the newly created + * grouped function. The workflow of this ExprMutator is: for each dataflow block, + * - we go through the bindings one by one. 
For each binding, if it is in a group whose size is + * larger than one, we add the binding to the function of the group it is in and update the + * parameters and arguments of that function; + * - then we finalize all the grouped functions by updating their bindings using BlockBuilder; + * - lastly, we go through the bindings again and substitute the bindings in a group with a single + * call to the corresponding grouped function. + * + * After transforming a Relax function, we update the function in the IRModule. Besides, we add all + * newly created grouped functions to the IRModule. + */ +class OperatorFusor : public ExprMutator { + public: + /*! + * \brief Construct a new operator fusor. Given the indexed-forward graph and the graph partition + * result on that graph, the constructor creates a mapping from each leaf AST object + * (e.g. parameters, variables, constants) to the group of the node corresponding to the object + * in the graph. + * \param mod The IRModule to be transformed + * \param graph The indexed-forward graph of the input IRModule + * \param groups The grouped result of the group partition on the input indexed-forward graph.
+ */ + explicit OperatorFusor(IRModule mod, const IndexedForwardGraph& graph, + const std::vector& groups) + : mod_(std::move(mod)) { + for (int nid = 0; nid < static_cast(graph.post_dfs_order.size()); ++nid) { + GraphPartitioner::Group* group_root = groups[nid]->FindRoot(); + ICHECK(group_root != nullptr); + ICHECK(graph.post_dfs_order[nid]->ref != nullptr); + obj2group_[graph.post_dfs_order[nid]->ref] = group_root; } + } - private: - explicit GroupDebugDumper( - const std::unordered_map& gmap) - : gmap_(gmap) {} - - void TryPrintGroup(const ObjectRef& object) { - if (object->IsInstance()) return; - if (skip_objects_.count(object.get())) return; - auto it = gmap_.find(object.get()); - if (it == gmap_.end()) return; - const GraphPartitioner::Group* group = it->second->FindRoot(); - if (const auto* var = object.as()) { - os_ << var->name_hint(); - } else { - os_ << object; - } - os_ << "(" << object.get() << ")"; - auto g_it = group_id_.find(group); - os_ << ": Group #"; - if (g_it == group_id_.end()) { - os_ << (group_id_[group] = group_id_.size()); - } else { - os_ << g_it->second; - } - os_ << "\n"; - - // Prevent showing the same node multiple times - skip_objects_.insert(object.get()); + /*! + * \brief The main transformation on the IRModule + * \return The new IRModule after transformation + */ + IRModule Transform() { + // Step 1. Fetch the main function and apply transformation by recursing into the function. + // - Since cross-function call is not supported yet, FuseOps only serves the entry function, + // whose name is "main". + GlobalVar main_gv = mod_->GetGlobalVar("main"); + auto main_func = Downcast(mod_->Lookup("main")); + auto updated_main_func = Downcast(VisitExpr(main_func)); + + // Step 2. Update the main function in the IRModule. + IRModuleNode* p_mod = mod_.CopyOnWrite(); + p_mod->Update(main_gv, updated_main_func); + + // Step 3. Add the new functions into the IRModule. 
+ for (const auto& kv : new_functions_) { + p_mod->Add(kv.second.first, kv.second.second); } - void VisitExpr(const Expr& expr) final { - TryPrintGroup(expr); - ExprVisitor::VisitExpr(expr); + return GetRef(p_mod); + } + + private: + BindingBlock VisitBindingBlock(const BindingBlock& block) final { + if (const auto* df_block = block.as()) { + return VisitBindingBlock_(df_block); } + // We skip ordinary binding blocks since they might be impure (with side effect or control flow) + return block; + } - private: - std::unordered_map group_id_; - std::ostringstream os_; - const std::unordered_map& gmap_; - std::unordered_set skip_objects_; - }; + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final { + group2func_.clear(); - private: - struct GroupInfo { - // The parameters of the function. - Array params; - // The arguments to call the functions. - Array arguments; - // The name hint for the group - String name_hint = "fused"; - - Var GetOrAllocParam(const Expr& arg) { - // run linear scan as most fused groups contain only a few inputs. - for (size_t i = 0; i < arguments.size(); ++i) { - if (arg.same_as(arguments[i])) return params[i]; - } - // create a new parameter. - if (const auto* arg_var = arg.as()) { - params.push_back(Var(arg_var->name_hint(), arg_var->shape(), arg_var->checked_type_)); - } else { - // TODO(siyuan): need enhance it. - LOG(FATAL) << "ValueError: call args must be a var for now."; - } - arguments.push_back(arg); - return params.back(); + // Step 1. Collect the bindings for each grouped function. + CollectFuncBindings(block->bindings); + + // Step 2. Create the grouped function for each group. + for (auto& kv : group2func_) { + FunctionCreator& creator = kv.second; + creator.CreateFunction(); } - }; - /*! \brief Internal arena. */ - support::Arena arena_; - /*! \brief The group assignment map. */ - std::unordered_map gmap_; - /*! \brief Internal group information map. */ - std::unordered_map ginfo_; - /*! \brief The IRModule. 
*/ - IRModule mod_; - /*! \brief The current group. */ - GraphPartitioner::Group* cur_group_; -}; -class TIRFuseMutator : public ExprMutator { - public: - static IRModule Transform(const IRModule& mod) { - TIRFuseMutator mutator(mod); + // Step 3. Start generating the new binding block. + // - For groups with single binding, we directly recurse into the binding and emit the new one. + // - For groups with multiple bindings, we emit the call to the grouped function only when + // visiting the last binding of the group, because only by doing this we don't break the + // dependencies among the bindings of different groups. And therefore, we will skip all but the + // last binding of the group. + builder_->BeginDataflowBlock(); + for (int i = 0; i < static_cast(block->bindings.size()); ++i) { + const Binding& binding = block->bindings[i]; + + // Case 1. If the binding is the only binding in its group, recurse into it and emit the + // transformed binding as usual. + GraphPartitioner::Group* group = GetGroupFromBinding(binding); + if (group->num_nodes == 1) { + VisitBinding(binding); + continue; + } - BaseFunc main_func = mod->Lookup("main"); - mutator.func_info_ = FuseFuncInfo("main", false); - mutator.builder_->AddFuncToContext(Downcast(mutator.VisitExpr(main_func)), "main"); + const auto& it_creator = group2func_.find(group); + ICHECK(it_creator != group2func_.end()); + const FunctionCreator& func_info = it_creator->second; - return mutator.builder_->GetContextIRModule(); - } + // Case 2. If the binding is not the last binding of the group, we skip it. + if (!func_info.bindings_.back().same_as(binding)) { + continue; + } - private: - explicit TIRFuseMutator(const IRModule& mod) : mod_(mod) {} + // Case 3. The binding is the last binding of the group. + const auto* var_binding = binding.as(); + ICHECK(var_binding != nullptr) << "The last binding of a group whose size is larger than 1 " + "is supposed to be a variable binding"; + + // Step a. 
Add the grouped function of this group to the field `new_functions_` of this fusor + // with deduplication. + std::pair gv_func_pair = + AddFuncWithDeduplication(func_info.global_var_, func_info.function_); + new_functions_[gv_func_pair.first->name_hint] = gv_func_pair; + + // Step b. Create the call to the deduplicated function, and then emit the call. + // - If this binding is the last binding of the current binding block, emit an output + // variable. + // - Otherwise, emit a dataflow variable. + Var new_var{nullptr}; + Call call_to_emit = Call(gv_func_pair.first, UpdateArgs(func_info.arguments_)); + if (i < static_cast(block->bindings.size()) - 1) { + new_var = builder_->Emit(call_to_emit); + } else { + new_var = builder_->EmitOutput(call_to_emit); + } - Expr VisitExpr_(const CallNode* call) final { - static const Op& call_tir_op_ = Op::Get("relax.call_tir"); - Expr e = ExprMutator::VisitExpr_(call); - call = e.as(); - ICHECK(call != nullptr); - - if (call->op->IsInstance()) { - // Emit primitive relax Function - GlobalVar gv = Downcast(call->op); - BaseFunc func = mod_->Lookup(gv); - FuseFuncInfo info = func_info_; - this->func_info_ = FuseFuncInfo(gv->name_hint, true); - GlobalVar new_gv = - this->builder_->AddFuncToContext(Downcast(VisitExpr(func)), gv->name_hint); - func_info_ = info; - return Call(new_gv, call->args, call->attrs, call->type_args, call->span); + // Step c. Update the mapping used for the remapping of the binding variables. + var_remap_[var_binding->var->vid] = new_var; } + // Step 4. Finish the binding block generation. + return builder_->EndBlock(); + } - if (call->op != call_tir_op_) { - return e; + /*! 
+ * \brief Collect the bindings for each grouped function and update the information of the grouped + * function + * \param bindings The bindings to be collected + * \note The function update is done by `AppendBinding(...)` + */ + void CollectFuncBindings(const Array& bindings) { + for (const Binding& binding : bindings) { + // If the binding is the only binding in its group, there is no need to create a new function. + GraphPartitioner::Group* group = GetGroupFromBinding(binding); + if (group->num_nodes == 1) { + continue; + } + // Add the binding to the grouped function it's in, and update the function information + // accordingly. + FunctionCreator& func_info = group2func_[group]; + func_info.AppendBinding(binding); } + } - GlobalVar gv = Downcast(call->args[0]); - tir::PrimFunc func = Downcast(mod_->Lookup(gv)); - - if (func_info_.is_primitive) { - func_info_.prim_funcs.push_back(func); - // update func_info_.param_map - const Array call_tir_args = Downcast(call->args[1])->fields; - for (size_t i = 0; i < call_tir_args.size(); ++i) { - if (call_tir_args[i]->IsInstance()) { - func_info_.arguments.push_back(call_tir_args[i]); - } else if (call_tir_args[i]->IsInstance()) { - Var arg_var = Downcast(call_tir_args[i]); - auto it = func_info_.var2param.find(arg_var); - if (it == func_info_.var2param.end()) { - // add it to the arg list if the arg is not the result of previous call_tir - func_info_.arguments.push_back(arg_var); - } else { - const tir::Var& producer_param = Downcast((*it).second); - const tir::Var& consumer_param = func->params[i]; - func_info_.param_map.Set(consumer_param, producer_param); - } - } else { - ICHECK(false) << "Only var and constant are allowed"; - } - } - return e; + /*! 
+ * \brief Get the group which the input binding is in + * \param binding The binding to be queried + * \return The pointer to the group which the input binding is in + */ + GraphPartitioner::Group* GetGroupFromBinding(const Binding& binding) { + Var var{nullptr}; + if (const auto* var_binding = binding.as()) { + var = var_binding->var; } else { - GlobalVar new_gv = this->builder_->AddFuncToContext(func, gv->name_hint); - return Call(call->op, {new_gv, call->args[1], call->args[2]}, call->attrs, call->type_args, - call->span); + const auto* match_shape = binding.as(); + ICHECK(match_shape != nullptr); + var = match_shape->var; } + + const auto& it_group = obj2group_.find(var.get()); + ICHECK(it_group != obj2group_.end()); + GraphPartitioner::Group* group = it_group->second; + ICHECK(group->FindRoot() == group); + return group; } - void VisitBinding_(const VarBindingNode* binding) final { - static const Op& call_tir_op_ = Op::Get("relax.call_tir"); - if (!func_info_.is_primitive) { - return ExprMutator::VisitBinding_(binding); + /*! 
+ * \brief Update the pre-stored arguments according to the variable remapping of the fusor, by + * recursing into each argument + * \param args The arguments to be updated + * \return The updated arguments + */ + Array UpdateArgs(const Array& args) { + Array new_args; + new_args.reserve(args.size()); + for (const Expr& arg : args) { + new_args.push_back(VisitExpr(arg)); } - Expr value = this->VisitExpr(binding->value); - if (const auto* call = value.as()) { - if (!binding->var->IsInstance()) { - // Emit call_tir func if it's the output call - tir::PrimFunc func = tir::FusePrimFuncs(func_info_.prim_funcs, func_info_.param_map); - GlobalVar gv = this->builder_->AddFuncToContext(func, func_info_.name_hint); - Array call_args = {gv, Tuple(func_info_.arguments), call->args[2]}; - Call new_call_tir(call_tir_op_, call_args, call->attrs, call->type_args); - Var output = this->builder_->EmitOutput(new_call_tir); - this->var_remap_[binding->var->vid] = output; - } else { - // Update func_info_.var2param - GlobalVar gv = Downcast(call->args[0]); - tir::PrimFunc func = Downcast(mod_->Lookup(gv)); - const Expr& output_shapes = call->args[2]; - if (const auto* tuple_output_shapes = output_shapes.as()) { - // set var2param to a array if there is more than one output. - size_t output_size = tuple_output_shapes->fields.size(); - Array output_param(func->params.end() - output_size, func->params.end()); - func_info_.var2param.Set(binding->var, output_param); - } else { - func_info_.var2param.Set(binding->var, func->params.back()); + return new_args; + } + + /*! + * \brief Add the input global variable and function to the new function list of the fusor, and + * meanwhile resolve the name deduplication. We also discard the input function if some function + * with the same name as the input function structurally equals to it. 
+ * \param gv The global variable corresponding to the new function to be added + * \param func The new function to be added + * \return The pair of the new global variable and the new function. Or a previously added pair if + * the input function structurally equals to the previously added function. + */ + std::pair AddFuncWithDeduplication(GlobalVar gv, Function func) { + std::string name = gv->name_hint; + int suffix = 0; + + while (true) { + auto it = new_functions_.find(name); + if (it == new_functions_.end()) { + if (gv->name_hint != name) { + gv = GlobalVar(name); + FunctionNode* p_func = func.CopyOnWrite(); + p_func->name = gv; } + return std::make_pair(std::move(gv), std::move(func)); } - } else if (const auto* tuple_get_item = value.as()) { - ICHECK(binding->var->IsInstance()) - << "Currently TupleGetItem outputs are not allowed"; - const Var& tuple_var = Downcast(tuple_get_item->tuple); - auto it = func_info_.var2param.find(tuple_var); - if (it == func_info_.var2param.end()) { - // Directly emit tuple if it's extern input - Var lv = this->builder_->Emit(value); - this->var_remap_[binding->var->vid] = lv; - } else { - // Update var2param if the input is local var - Array params = Downcast>((*it).second); - func_info_.var2param.Set(binding->var, params[tuple_get_item->index]); + + Function existing_func = (*it).second.second; + if (structural_equal_(func, existing_func)) { + LOG(INFO) << "[2]. deduplicate function " << gv->name_hint; + return std::make_pair((*it).second.first, (*it).second.second); } - } else { - LOG(FATAL) << "Unsupported binding value: " << value; + + std::ostringstream os; + os << gv->name_hint << "_" << ++suffix; + name = os.str(); } } private: - struct FuseFuncInfo { - FuseFuncInfo() = default; - FuseFuncInfo(const String& name_hint, bool is_primitive) - : name_hint(name_hint), is_primitive(is_primitive) {} - - /*! \brief The relax function name hint */ - String name_hint = ""; - /*! 
\brief An boolean indicate if the function if to be fused */ - bool is_primitive = false; - /*! \brief The prim_funcs to be fused. */ - Array prim_funcs; - /*! - * \brief The mapping from relax to prim_func param - * \note The rhs can be a tir::Var or Array (for tuple return) - */ - Map var2param; - /*! - * \brief A map indicate how data exchange between functions. - * The map is from consumer params to the producer params. - */ - Map param_map; - /*! \brief The arguments for calling prim_func */ - Array arguments; - }; - - /*! \brief The IRModule */ - const IRModule& mod_; - /*! \brief The IRModule */ - FuseFuncInfo func_info_; + /*! \brief The IRModule. */ + IRModule mod_; + /*! \brief Internal arena. */ + support::Arena arena_; + /*! \brief The group assignment map. */ + std::unordered_map obj2group_; + /*! \brief Internal function information map. */ + std::unordered_map group2func_; + /*! \brief The new global variables and functions to be added to the module */ + std::unordered_map> new_functions_; + /*! \brief The structural equality checker */ + StructuralEqual structural_equal_; }; -class Inliner : public ExprMutator { - public: - static IRModule Transform(const IRModule& mod) { - Inliner inliner(mod); - for (const auto& kv : mod->functions) { - BaseFunc func = kv.second; - const GlobalVar& global_var = kv.first; - // We won't add current PrimFunc to the context - if (global_var->name_hint == "main") { - Expr new_func = inliner.VisitExpr(func); - inliner.builder_->AddFuncToContext(Downcast(new_func), global_var->name_hint); - } else if (func->IsInstance()) { - inliner.builder_->AddFuncToContext(func, global_var->name_hint); - } - } +IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { + support::Arena arena; - return inliner.builder_->GetContextIRModule(); - } + // Step 1. Create the indexed-forward graph according to the input IRModule. 
+ IndexedForwardGraph graph = GraphCreator::Create(mod, &arena); - private: - explicit Inliner(const IRModule& mod) : mod_(mod) {} + // Step 2. Partition the graph by applying the fusion algorithm. + std::vector groups = + GraphPartitioner(&arena, opt_level, max_fuse_depth).Partition(graph); - void VisitBinding_(const VarBindingNode* binding) final { - Optional _relax_func = get_relax_call(binding); - if (!_relax_func.defined()) { - return ExprMutator::VisitBinding_(binding); + LOG(INFO) << "number of groups: " << groups.size(); + for (int i = 0; i < static_cast(groups.size()); ++i) { + if (groups[i]->FindRoot() == groups[i]) { + LOG(INFO) << "group[" << i << "] has " << groups[i]->num_nodes; } - Function relax_func = _relax_func.value(); - Call call = Downcast(binding->value); - VisitPrimitiveFunc(relax_func, call->args, binding->var); } - private: - void VisitPrimitiveFunc(const Function& func, const Array& args, const Var& binding_var) { - // update var_remap_ via function params - ICHECK_EQ(func->params.size(), args.size()); - for (size_t i = 0; i < func->params.size(); ++i) { - const Var& param = func->params[i]; - const Expr& arg = args[i]; - this->var_remap_[param->vid] = Downcast(arg); - } + // Step 3. Transform the IRModule by fusing the operators in accordance with the graph partition + // results. 
+ mod = OperatorFusor(mod, graph, groups).Transform(); - const auto* seq = func->body.as(); - ICHECK(seq != nullptr); - ICHECK_EQ(seq->blocks.size(), 1); - const BindingBlock& block = seq->blocks[0]; - ICHECK(block->IsInstance()); - for (const Binding& binding : block->bindings) { - if (const auto* var_binding = binding.as()) { - Var var = this->builder_->Emit(VisitExpr(var_binding->value)); - this->var_remap_[var_binding->var->vid] = var; - } else { - ICHECK(false) << "Unsupported binding"; - } - } - // Update return value remap - Var return_var = Downcast(seq->body); - this->var_remap_[binding_var->vid] = this->var_remap_[return_var->vid]; - } + const auto* f = runtime::Registry::Get("script.AsRelaxScript"); + String s = (*f)(mod, false); + LOG(INFO) << "After FuseOps:\n" << s; - Optional get_relax_call(const VarBindingNode* binding) { - const auto* call = binding->value.as(); - // Cond 1. binding value is a call node. - if (call == nullptr) return NullOpt; - // Cond 2. Call node op must be GlobalVar - if (!call->op->IsInstance()) return NullOpt; - GlobalVar gv = Downcast(call->op); - // Cond 3. The GlobalVar must be in the IRModule - auto it = mod_->functions.find(gv); - if (it == mod_->functions.end()) return NullOpt; - // Cond 4. The function must be a relax function - const BaseFunc& func = (*it).second; - if (!func->IsInstance()) return NullOpt; - return Downcast(func); - } + // TODO(ruihang): unit tests: 1. name duplication - private: - const IRModule& mod_; -}; - -IRModule FuseOps(IRModule mod, int opt_level, size_t max_fuse_depth) { - mod = RelaxFuseMutator::Transform(mod, opt_level, max_fuse_depth); - // const auto* f = runtime::Registry::Get("script.AsRelaxScript"); - // String s = (*f)(mod, false); - // std::cout << s << std::endl; - mod = TIRFuseMutator::Transform(mod); - mod = Inliner::Transform(mod); return mod; }