From a49a7fee499b9177701cbda78dfbc0bbede1c3af Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sun, 12 Feb 2023 22:13:30 +0800 Subject: [PATCH] [Relay][Pass] Separate out the graph partitioning code from fuse_ops.cc (#13964) * [Relay][Pass] Separate out the graph partitioning code from fuse_ops.cc The current `fuse_ops.cc` contains the following parts: 1. `IndexedForwardGraph` and `DominatorTree` which are used for graph partitioning 2. A Relay Expr visitor to create the `DominatorTree` 3. A Relay Expr mutator to fuse the ops This PR separates the graph partitioning code from `fuse_ops.cc` and moves it to the analysis folder, for: 1. Better code organization and readability as the graph partitioning code is quite long and not directly related to the fusion mutator 2. Possible reuse opportunities for other fusion passes in Relax NOTE: we won't bring relax fusion in `main` branch for now, but this pr is still reasonable for `main`. * lint --- src/relay/analysis/graph_partitioner.cc | 334 +++++++++++++++ src/relay/analysis/graph_partitioner.h | 269 ++++++++++++ src/relay/transforms/fuse_ops.cc | 516 +----------------------- 3 files changed, 615 insertions(+), 504 deletions(-) create mode 100644 src/relay/analysis/graph_partitioner.cc create mode 100644 src/relay/analysis/graph_partitioner.h diff --git a/src/relay/analysis/graph_partitioner.cc b/src/relay/analysis/graph_partitioner.cc new file mode 100644 index 000000000000..861fd58d9e5c --- /dev/null +++ b/src/relay/analysis/graph_partitioner.cc @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./graph_partitioner.h" + +#include + +namespace tvm { +namespace relay { + +DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { + DominatorTree tree; + tree.nodes.resize(graph.post_dfs_order.size(), nullptr); + // reverse topo order + for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { + size_t index = i - 1; + tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); + } + return tree; +} + +DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs, + OpPatternKind* edge_pattern) { + while (lhs != rhs) { + if (lhs == nullptr) return nullptr; + if (rhs == nullptr) return nullptr; + if (lhs->depth < rhs->depth) { + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); + rhs = rhs->parent; + } else if (rhs->depth < lhs->depth) { + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); + lhs = lhs->parent; + } else { + edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); + edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); + lhs = lhs->parent; + rhs = rhs->parent; + } + } + return lhs; +} + +DominatorTree::Node* DominatorTree::LeastCommonAncestor( + const LinkedList& input_nodes, OpPatternKind* edge_pattern) { + auto link = input_nodes.head; + if (link == nullptr) { + return nullptr; + } + auto get_node = [&](const IndexedForwardGraph::Edge& edge) { + size_t oindex = edge.node->index; + ICHECK_LT(oindex, nodes.size()); + Node* onode = nodes[oindex]; + ICHECK(onode != nullptr); + return onode; + }; + Node* parent = get_node(link->value); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + link = link->next; + for (; link != nullptr; link = link->next) { + parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern); + *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); + } + return parent; +} + +DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena, + IndexedForwardGraph::Node* gnode) { + Node* tnode = arena->make(); + tnode->gnode = gnode; + if (gnode->extern_ref) { + tnode->depth = 1; + tnode->parent = nullptr; + tnode->pattern = kOpaque; + } else { + // find the LCAs of all outputs. + OpPatternKind pattern = kElemWise; + Node* parent = LeastCommonAncestor(gnode->outputs, &pattern); + tnode->depth = parent ? parent->depth + 1 : 1; + tnode->parent = parent; + tnode->pattern = pattern; + } + return tnode; +} + +std::vector GraphPartitioner::Partition( + const IndexedForwardGraph& graph) { + this->InitGroups(graph); + if (opt_level_ == 0) return std::move(groups_); + // get post dominator tree + auto post_dom_tree = DominatorTree::PostDom(arena_, graph); + // run fusion algorithm. + for (int phase = 0; phase < 3; ++phase) { + this->RunFuse(graph, post_dom_tree, phase); + } + return std::move(groups_); +} + +GraphPartitioner::Group* GraphPartitioner::Group::FindRoot() { + // fast path + if (this->parent == nullptr) return this; + // slow path with path compression. + Group* root = this; + while (root->parent != nullptr) { + root = root->parent; + } + for (Group* p = this; p != root;) { + Group* parent = p->parent; + p->parent = root; + p = parent; + } + return root; +} + +template +bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + F fcond) { + if (visited_.count(src)) return true; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + gnode = gnode->FindRoot(); + if (!fcond(gnode->pattern, src == sink)) return false; + if (src == sink) return true; + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + if (!CheckPath_(link->value.node, sink, fcond)) return false; + } + return true; +} + +template +bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + F fcond) { + ICHECK(!src->extern_ref); + visited_.clear(); + ICHECK(src != sink); + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + if (!CheckPath_(link->value.node, sink, fcond)) return false; + } + return true; +} + +OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { + if (lhs > relay::kBroadcast && rhs > relay::kBroadcast) { + LOG(FATAL) << "Cannot merge two complex group together"; + } + if (lhs > rhs) return lhs; + return rhs; +} + +void GraphPartitioner::MergeFromTo(Group* child, Group* parent) { + child = child->FindRoot(); + parent = parent->FindRoot(); + if (child == parent) return; + // update the number of nodes of the parent group + parent->num_nodes += child->num_nodes; + child->parent = parent; + // update anchor ref and pattern + if (child->anchor_ref != nullptr) { + ICHECK(parent->anchor_ref == nullptr); + parent->anchor_ref = child->anchor_ref; + parent->pattern = CombinePattern(child->pattern, parent->pattern); + } +} + +void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, + Group* target) { + if (src == sink) return; + if (visited_.count(src)) return; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + // merge the current group to the parent if possible. + MergeFromTo(gnode, target); + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + CommitFuse_(link->value.node, sink, target); + } +} + +void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { + Group* target = groups_[sink->index]; + visited_.clear(); + ICHECK(src != sink); + CommitFuse_(src, sink, target); +} + +size_t GraphPartitioner::CountNodesUptoSink_(IndexedForwardGraph::Node* src, + IndexedForwardGraph::Node* sink) { + if (src == sink || visited_.count(src)) return 0; + visited_.insert(src); + Group* gnode = groups_[src->index]; + ICHECK(gnode != nullptr); + auto sum = gnode->num_nodes; + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + sum += CountNodesUptoSink_(link->value.node, sink); + } + return sum; +} + +size_t GraphPartitioner::CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, + IndexedForwardGraph::Node* dom_parent) { + Group* target = groups_[dom_parent->index]; + visited_.clear(); + ICHECK(child != dom_parent); + return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); +} + +void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) { + groups_.resize(graph.post_dfs_order.size()); + for (size_t nid = 0; nid < groups_.size(); ++nid) { + const auto* graph_node = graph.post_dfs_order[nid]; + auto* group_node = arena_->make(); + group_node->pattern = graph_node->pattern; + group_node->root_ref = graph_node->ref; + // set anchor ref if necessary. + if (group_node->pattern == relay::kOutEWiseFusable) { + group_node->anchor_ref = graph_node->ref; + } + groups_[nid] = group_node; + } +} + +void GraphPartitioner::RunFuse(const IndexedForwardGraph& graph, // + const DominatorTree& post_dom_tree, // + int phase) { + for (size_t nid = 0; nid < groups_.size(); ++nid) { + // the group of current node has been specified already. + auto* graph_node = graph.post_dfs_order[nid]; + auto* dom_node = post_dom_tree.nodes[nid]; + Group* group_node = groups_[nid]; + ICHECK(group_node != nullptr); + // no actions for opaque nodes + if (group_node->pattern == kOpaque) continue; + // no actions needed if the current node have no dominator + if (dom_node->parent == nullptr) continue; + ICHECK(!graph_node->extern_ref); + size_t dom_parent_gindex = dom_node->parent->gnode->index; + + // refuse the fusion if too many ops are going to be fused together + if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) + continue; + + if (phase == 2) { + // Fuse injective ops into intermediate tuples, if any + if (group_node->pattern > relay::kInjective) continue; + Group* dom_parent_group = groups_[dom_parent_gindex]; + Group* dom_root_group = dom_parent_group->FindRoot(); + // If dom node group has a tuple as its root, we do not fuse tuple fields into it + if (dom_root_group->pattern == relay::kTuple) continue; + if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) { + // Now we know the tuple has been fused into subsequent injective ops + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; + // dom_root_group can also be tuple, as in inception layers + // CheckPath is needed to avoid fusing two intermediate tuples + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + continue; + } + + // Skip if current node is already fused to the parent. + if (groups_[dom_parent_gindex] != nullptr && + group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { + continue; + } + // Do not fuse into tuple for now + if (groups_[dom_parent_gindex]->pattern == kTuple) continue; + // Try to fuse current node to its post-dominator. + if (group_node->pattern == kOutEWiseFusable) { + if (phase != 0) continue; + // Path for OutEWiseFusable: conv2d + // Check if the dominator relation is elemwise. + if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { + ICHECK(dom_node->parent->gnode != nullptr); + // The fuse can be executed if all the intermediate ops are still broadcast. + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + } else if (group_node->pattern <= kBroadcast) { + // Pre-condition: can only be fused to parent which is injective or reduction. + if (dom_node->parent != nullptr && + (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) { + // Check if all the intermediate ops are still broadcast. + // The final terminal node can already be fused to a OutEWiseFusable group. + auto fcond = [](OpPatternKind kind, bool is_sink) { + if (!is_sink) { + // Elemwise, broadcast, and injective ops on the parallel branches + // are allowed be fused to the elemwise/broadcast anchor. + return kind <= kInjective; + } else { + return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective || + kind == kOutEWiseFusable); + } + }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) { + // defer injective fusion to second phase. + // so conv2d always finishes fusing. + if (phase != 1) continue; + // Check if all path are injective. + auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } else { + // do nothing. + ICHECK(group_node->pattern == kCommReduce); + } + } +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/analysis/graph_partitioner.h b/src/relay/analysis/graph_partitioner.h new file mode 100644 index 000000000000..9433aafa119d --- /dev/null +++ b/src/relay/analysis/graph_partitioner.h @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/graph_partitioner.h + * \brief The helper function for op fusion. + */ + +#ifndef TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ +#define TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ + +#include + +#include +#include +#include + +#include "../../support/arena.h" + +namespace tvm { +namespace relay { + +using support::LinkedList; +using support::LinkNode; + +/*! + * \brief Indexed data flow graph in forward direction. + * This is a temporary data structure used for operator fusion analysis. + * + * This data structure only captures the dataflow fragment and + * could ignore blocks like let by simply ordering each dataflow block + * and mark the output node as extern_ref; + */ +class IndexedForwardGraph { + public: + struct Node; + /*! + * The forward edge in the dataflow graph. + */ + struct Edge { + /*! \brief The corresponding node */ + Node* node{nullptr}; + /*! \brief The respective pattern of this op */ + OpPatternKind pattern{kOpaque}; + }; + /*! \brief A node in the graph. */ + struct Node { + /*! \brief weak reference to the corresponding edge. */ + const tvm::Object* ref{nullptr}; + /*! \brief The index of the node in topological order. */ + size_t index{0}; + /*! \brief Whether this node is referenced by external source */ + bool extern_ref{false}; + /*! \brief The general pattern in the node */ + OpPatternKind pattern{kOpaque}; + /*! \brief The outputs of the node. */ + LinkedList outputs; + }; + /*! \brief The node map that maps node to graph */ + std::unordered_map node_map; + /*! \brief All the nodes in post DFS order */ + std::vector post_dfs_order; + + /*! \brief Dump the graph into string. */ + void DebugDump() { + std::ostringstream os; + for (size_t i = 0; i < post_dfs_order.size(); ++i) { + Node* node = post_dfs_order[i]; + os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; + for (auto* link = node->outputs.head; link != nullptr; link = link->next) { + os << link->value.node->index << ", "; + } + os << "]\n"; + } + LOG(INFO) << os.str(); + } +}; + +/*! + * \brief Dominator tree that represent domination or + * post domination relation of the node. + */ +class DominatorTree { + public: + /*! + * \brief A node in the dominator tree. + */ + struct Node { + /*! \brief The node in the tree */ + IndexedForwardGraph::Node* gnode{nullptr}; + /*! \brief parent of the tree */ + Node* parent{nullptr}; + /*! \brief current depth*/ + int depth{0}; + /*! \brief aggregated pattern to parent */ + OpPatternKind pattern{kOpaque}; + }; + // index -> node. + std::vector nodes; + /*! + * \brief compute a post dominator relation for a given dataflow graph. + * \param arena The arena used for node allocation. + * \param graph The graph to be analyzed. + * \return The dominator tree of the graph. + * \note This algorithm makes use of the fact that graph is DAG, + * and runs a single pass algorithm via LCA (Least Common Ancestor) + */ + static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph); + + private: + // Combine pattern together. + inline static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { + if (lhs > rhs) return lhs; + return rhs; + } + /*! + * \brief Find the least common ancestor of the two nodes. + * \param lhs The left node. + * \param rhs The right node. + * \param edge_pattern + * The combined edge pattern across all the parents. + * \return The least common ancestor of the two. + */ + static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern); + /*! + * \brief Find the least common ancestor of a list of nodes. + * \param nodes the nodes. + * \param edge_pattern + * The combined edge pattern across all the parents. + * \return The least common ancestor of all nodes. + */ + Node* LeastCommonAncestor(const LinkedList& input_nodes, + OpPatternKind* edge_pattern); + + /*! + * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node. + * \param arena The Arena. + * \param gnode An IndexedForwardGraph Node. + * \return The DominatorTree Node. + */ + Node* GetNode(support::Arena* arena, IndexedForwardGraph::Node* gnode); +}; + +/*! + * \brief A partition of the graph marked by union find data structure. + */ +class GraphPartitioner { + public: + explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth) + : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {} + /*! + * \brief Group as a union find data structure. + */ + struct Group { + /*! \brief The parent in the union find data structure. */ + Group* parent{nullptr}; + /*! \brief The pattern of the group */ + OpPatternKind pattern; + /*! \brief reference to the root node. */ + const tvm::Object* root_ref{nullptr}; + /*! + * \brief Reference to the anchor node, + * this field is not nullptr only if pattern is kOutEWiseFusable. + */ + const tvm::Object* anchor_ref{nullptr}; + /*! + * \brief The number of nodes belonging to this group + */ + uint32_t num_nodes{1}; + + /*! \brief Optional attributes to annotate the grouped function. */ + runtime::Map attrs; + /*! + * \brief Find the group root, perform path compression + * \return The root type node. + */ + Group* FindRoot(); + }; + /*! + * \brief Partition a graph. + * \return group assignments of each node. + */ + std::vector Partition(const IndexedForwardGraph& graph); + + private: + /*! \brief The internal arena for temporary space. */ + support::Arena* arena_; + /*! \brief optimization level for fuse operation. */ + int opt_level_; + /*! \brief The maximum number of operations in one fused function */ + size_t max_fuse_depth_; + /*! \brief The internal groups. */ + std::vector groups_; + /*! \brief internal field used for deduplication */ + std::unordered_set visited_; + // Internal implementation of CheckPath + template + bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond); + + /*! + * \brief Check all the node and edge pattern + * between src and sink satisfies fcond. + * + * src is not checked. + * + * \param src The source node. + * \param sink The termination node. + * \param fcond The condition to be checked. + * \tparam F the condition function, with signature + * \note sink must be a post-dominator of src. + */ + template + bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond); + + /*! + * \brief Merge the child group to the parent. + * \param child The child group. + * \param parent The parent group. + */ + void MergeFromTo(Group* child, Group* parent); + + // Internal implementation of CommitFuse + void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target); + + /*! + * \brief Commit fusion operation. + * \param src The source node. + * \param sink The termination node. + * \note sink must be a post-dominator of src. + */ + void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); + + size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink); + + // Count the number of nodes in a fused subgraph if child is additionally fused. + // dom_parent is already known to be a part of the subgraph. + // For a diamond structure, there can be multiple paths connecting child and dom_parent. + // All intermediate nodes between child and dom_parent are taken into account. + // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot() + // is important for correct calculation. + size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, + IndexedForwardGraph::Node* dom_parent); + + // Initialize the groups. + void InitGroups(const IndexedForwardGraph& graph); + + // execute the fusion algorithm. + void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase); +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_ANALYSIS_GRAPH_PARTITIONER_H_ diff --git a/src/relay/transforms/fuse_ops.cc b/src/relay/transforms/fuse_ops.cc index afa60f1bb4e5..1fb857cb1cb3 100644 --- a/src/relay/transforms/fuse_ops.cc +++ b/src/relay/transforms/fuse_ops.cc @@ -32,6 +32,7 @@ #include #include "../../support/arena.h" +#include "../analysis/graph_partitioner.h" #include "../op/annotation/annotation.h" #include "./pass_utils.h" #include "./pattern_utils.h" @@ -88,72 +89,16 @@ static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion"); TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.max_depth", Integer); TVM_REGISTER_PASS_CONFIG_OPTION("relay.FuseOps.link_params", Bool); -/*! - * \brief Indexed data flow graph in forward direction. - * This is a temporary data structure used for operator fusion analysis. - * - * This data structure only captures the dataflow fragment and - * could ignore blocks like let by simply ordering each dataflow block - * and mark the output node as extern_ref; - */ -class IndexedForwardGraph { +// Creator of post dominator tree of the dataflow +class IndexedForwardGraphCreator : private ExprVisitor { public: - struct Node; - /*! - * The forward edge in the dataflow graph. - */ - struct Edge { - /*! \brief The corresponding node */ - Node* node{nullptr}; - /*! \brief The respective pattern of this op */ - OpPatternKind pattern{kOpaque}; - }; - /*! \brief A node in the graph. */ - struct Node { - /*! \brief weak reference to the corresponding edge. */ - const tvm::Object* ref{nullptr}; - /*! \brief The index of the node in topological order. */ - size_t index{0}; - /*! \brief Whether this node is referenced by external source */ - bool extern_ref{false}; - /*! \brief The general pattern in the node */ - OpPatternKind pattern{kOpaque}; - /*! \brief The outputs of the node. */ - LinkedList outputs; - }; - /*! \brief The node map that maps node to graph */ - std::unordered_map node_map; - /*! \brief All the nodes in post DFS order */ - std::vector post_dfs_order; - - /*! \brief Dump the graph into string. */ - void DebugDump() { - std::ostringstream os; - for (size_t i = 0; i < post_dfs_order.size(); ++i) { - Node* node = post_dfs_order[i]; - os << "node[" << i << "], " << GetRef(node->ref) << " outputs=["; - for (auto* link = node->outputs.head; link != nullptr; link = link->next) { - os << link->value.node->index << ", "; - } - os << "]\n"; - } - LOG(INFO) << os.str(); + static IndexedForwardGraph Create(support::Arena* arena, const Expr& body) { + IndexedForwardGraphCreator creator(arena); + return creator.Prepare(body); } - /*! - * \brief create a indexed forward graph. - * \param arena The arena used for data allocation. - * \param body The body of the expression to create a graph. - */ - static IndexedForwardGraph Create(support::Arena* arena, const Expr& body); private: - class Creator; -}; - -// Creator of post dominator tree of the dataflow -class IndexedForwardGraph::Creator : private ExprVisitor { - public: - explicit Creator(support::Arena* arena) : arena_(arena) {} + explicit IndexedForwardGraphCreator(support::Arena* arena) : arena_(arena) {} IndexedForwardGraph Prepare(const Expr& body) { this->Update(body, nullptr, kOpaque); @@ -213,7 +158,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const ConstantNode* op) final { this->AddNode(op); - Node* node = graph_.node_map.at(op); + IndexedForwardGraph::Node* node = graph_.node_map.at(op); DataType dtype = DataType(op->data->dtype); // This rule must be consistent with code generator. bool is_simple_const = @@ -230,7 +175,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const CallNode* call) final { ICHECK(graph_.node_map.count(call)); - Node* node = graph_.node_map.at(call); + IndexedForwardGraph::Node* node = graph_.node_map.at(call); static auto fpattern = Op::GetAttrMap("TOpPattern"); // Now we set the pattern of this call. // @@ -274,7 +219,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { void VisitExpr_(const TupleNode* op) final { ICHECK(graph_.node_map.count(op)); - Node* tuple_node = graph_.node_map.at(op); + IndexedForwardGraph::Node* tuple_node = graph_.node_map.at(op); tuple_node->pattern = kTuple; for (const Expr& field : op->fields) { if (field->checked_type().as()) { @@ -306,7 +251,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->Update(op->tuple, nullptr, kOpaque); } else { ICHECK(graph_.node_map.count(op)); - Node* node = graph_.node_map.at(op); + IndexedForwardGraph::Node* node = graph_.node_map.at(op); node->pattern = kInjective; this->Update(op->tuple, node, kInjective); } @@ -372,443 +317,6 @@ class IndexedForwardGraph::Creator : private ExprVisitor { } }; -IndexedForwardGraph IndexedForwardGraph::Create(support::Arena* arena, const Expr& body) { - return Creator(arena).Prepare(body); -} - -/*! - * \brief Dominator tree that represent domination or - * post domination relation of the node. - */ -class DominatorTree { - public: - /*! - * \brief A node in the dominator tree. - */ - struct Node { - /*! \brief The node in the tree */ - IndexedForwardGraph::Node* gnode{nullptr}; - /*! \brief parent of the tree */ - Node* parent{nullptr}; - /*! \brief current depth*/ - int depth{0}; - /*! \brief aggregated pattern to parent */ - OpPatternKind pattern{kOpaque}; - }; - // index -> node. - std::vector nodes; - /*! - * \brief compute a post dominator relation for a given dataflow graph. - * \param arena The arena used for node allocation. - * \param graph The graph to be analyzed. - * \return The dominator tree of the graph. - * \note This algorithm makes use of the fact that graph is DAG, - * and runs a single pass algorithm via LCA (Least Common Ancestor) - */ - static DominatorTree PostDom(support::Arena* arena, const IndexedForwardGraph& graph); - - private: - // Combine pattern together. - static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { - if (lhs > rhs) return lhs; - return rhs; - } - /*! - * \brief Find the least common ancestor of the two nodes. - * \param lhs The left node. - * \param rhs The right node. - * \param edge_pattern - * The combined edge pattern across all the parents. - * \return The least common ancestor of the two. - */ - static Node* LeastCommonAncestor(Node* lhs, Node* rhs, OpPatternKind* edge_pattern) { - while (lhs != rhs) { - if (lhs == nullptr) return nullptr; - if (rhs == nullptr) return nullptr; - if (lhs->depth < rhs->depth) { - edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); - rhs = rhs->parent; - } else if (rhs->depth < lhs->depth) { - edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); - lhs = lhs->parent; - } else { - edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern); - edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern); - lhs = lhs->parent; - rhs = rhs->parent; - } - } - return lhs; - } - /*! - * \brief Find the least common ancestor of a list of nodes. - * \param nodes the nodes. - * \param edge_pattern - * The combined edge pattern across all the parents. - * \return The least common ancestor of all nodes. - */ - Node* LeastCommonAncestor(const LinkedList& input_nodes, - OpPatternKind* edge_pattern) { - auto link = input_nodes.head; - if (link == nullptr) { - return nullptr; - } - auto get_node = [&](const IndexedForwardGraph::Edge& edge) { - size_t oindex = edge.node->index; - ICHECK_LT(oindex, nodes.size()); - Node* onode = nodes[oindex]; - ICHECK(onode != nullptr); - return onode; - }; - Node* parent = get_node(link->value); - *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); - link = link->next; - for (; link != nullptr; link = link->next) { - parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern); - *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern); - } - return parent; - } - /*! - * \brief Convert the Node from an IndexedForwardGraph Node into DomaintorTree Node. - * \param arena The Arena. - * \param gnode An IndexedForwardGraph Node. - * \return The DominatorTree Node. - */ - Node* GetNode(support::Arena* arena, IndexedForwardGraph::Node* gnode) { - Node* tnode = arena->make(); - tnode->gnode = gnode; - if (gnode->extern_ref) { - tnode->depth = 1; - tnode->parent = nullptr; - tnode->pattern = kOpaque; - } else { - // find the LCAs of all outputs. - OpPatternKind pattern = kElemWise; - Node* parent = LeastCommonAncestor(gnode->outputs, &pattern); - tnode->depth = parent ? parent->depth + 1 : 1; - tnode->parent = parent; - tnode->pattern = pattern; - } - return tnode; - } -}; - -DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) { - DominatorTree tree; - tree.nodes.resize(graph.post_dfs_order.size(), nullptr); - // reverse topo order - for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { - size_t index = i - 1; - tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]); - } - return tree; -} - -/*! - * \brief A partition of the graph marked by union find data structure. - */ -class GraphPartitioner { - public: - explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth) - : arena_(arena), opt_level_(opt_level), max_fuse_depth_(max_fuse_depth) {} - /*! - * \brief Group as a union find data structure. - */ - struct Group { - /*! \brief The parent in the union find data structure. */ - Group* parent{nullptr}; - /*! \brief The pattern of the group */ - OpPatternKind pattern; - /*! \brief reference to the root node. */ - const tvm::Object* root_ref{nullptr}; - /*! - * \brief Reference to the anchor node, - * this field is not nullptr only if pattern is kOutEWiseFusable. - */ - const tvm::Object* anchor_ref{nullptr}; - /*! - * \brief Find the group root, perform path compression - * \return The root type node. - */ - Group* FindRoot() { - // fast path - if (this->parent == nullptr) return this; - // slow path with path compression. - Group* root = this; - while (root->parent != nullptr) { - root = root->parent; - } - for (Group* p = this; p != root;) { - Group* parent = p->parent; - p->parent = root; - p = parent; - } - return root; - } - - /*! - * \brief The number of nodes belonging to this group - */ - uint32_t num_nodes{1}; - }; - /*! - * \brief Partition a graph. - * \return group assignments of each node. - */ - std::vector Partition(const IndexedForwardGraph& graph); - - private: - /*! \brief The internal arena for temporary space. */ - support::Arena* arena_; - /*! \brief optimization level for fuse operation. */ - int opt_level_; - /*! \brief The maximum number of operations in one fused function */ - size_t max_fuse_depth_; - /*! \brief The internal groups. */ - std::vector groups_; - /*! \brief internal field used for deduplication */ - std::unordered_set visited_; - // Internal implelementation of CheckPath - template - bool CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { - if (visited_.count(src)) return true; - visited_.insert(src); - Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); - gnode = gnode->FindRoot(); - if (!fcond(gnode->pattern, src == sink)) return false; - if (src == sink) return true; - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - if (!CheckPath_(link->value.node, sink, fcond)) return false; - } - return true; - } - /*! - * \brief Check all the node and edge pattern - * between src and sink satisfies fcond. - * - * src is not checked. - * - * \param src The source node. - * \param sink The termination node. - * \param fcond The condition to be checked. - * \tparam F the condition function, with signature - * \note sink must be a post-dominator of src. - */ - template - bool CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, F fcond) { - ICHECK(!src->extern_ref); - visited_.clear(); - ICHECK(src != sink); - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - if (!CheckPath_(link->value.node, sink, fcond)) return false; - } - return true; - } - // Combine two patterns together. - static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) { - if (lhs > kBroadcast && rhs > kBroadcast) { - LOG(FATAL) << "Cannot merge two complex group together"; - } - if (lhs > rhs) return lhs; - return rhs; - } - /*! - * \brief Merge the child group to the parent. - * \param child The child group. - * \param parent The parent group. - */ - void MergeFromTo(Group* child, Group* parent) { - child = child->FindRoot(); - parent = parent->FindRoot(); - if (child == parent) return; - // update the number of nodes of the parent group - parent->num_nodes += child->num_nodes; - child->parent = parent; - // update anchor ref and pattern - if (child->anchor_ref != nullptr) { - ICHECK(parent->anchor_ref == nullptr); - parent->anchor_ref = child->anchor_ref; - parent->pattern = CombinePattern(child->pattern, parent->pattern); - } - } - // Internal implelementation of CommitFuse - void CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink, Group* target) { - if (src == sink) return; - if (visited_.count(src)) return; - visited_.insert(src); - Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); - // merge the current group to the parent if possible. - MergeFromTo(gnode, target); - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - CommitFuse_(link->value.node, sink, target); - } - } - /*! - * \brief Commit fusion operation. - * \param src The source node. - * \param sink The termination node. - * \note sink must be a post-dominator of src. - */ - void CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { - Group* target = groups_[sink->index]; - visited_.clear(); - ICHECK(src != sink); - CommitFuse_(src, sink, target); - } - - size_t CountNodesUptoSink_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) { - if (src == sink || visited_.count(src)) return 0; - visited_.insert(src); - Group* gnode = groups_[src->index]; - ICHECK(gnode != nullptr); - auto sum = gnode->num_nodes; - for (auto link = src->outputs.head; link != nullptr; link = link->next) { - sum += CountNodesUptoSink_(link->value.node, sink); - } - return sum; - } - - // Count the number of nodes in a fused subgraph if child is additionaly fused. - // dom_parent is already known to be a part of the subgraph. - // For a diamond structure, there can be multiple paths connecting child and dom_parent. - // All intermediate nodes between child and dom_parent are taken into account. - // Since dom_parent can itself be an intermediate node in the subgraph, calling FindRoot() - // is important for correct calculation. - size_t CountFusedNodesWithNewChild(IndexedForwardGraph::Node* child, - IndexedForwardGraph::Node* dom_parent) { - Group* target = groups_[dom_parent->index]; - visited_.clear(); - ICHECK(child != dom_parent); - return target->FindRoot()->num_nodes + CountNodesUptoSink_(child, dom_parent); - } - - // Initialize the groups. - void InitGroups(const IndexedForwardGraph& graph) { - groups_.resize(graph.post_dfs_order.size()); - for (size_t nid = 0; nid < groups_.size(); ++nid) { - const auto* graph_node = graph.post_dfs_order[nid]; - auto* group_node = arena_->make(); - group_node->pattern = graph_node->pattern; - group_node->root_ref = graph_node->ref; - // set anchor ref if necessary. - if (group_node->pattern == kOutEWiseFusable) { - group_node->anchor_ref = graph_node->ref; - } - groups_[nid] = group_node; - } - } - - // execute the fusion algorithm. - void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) { - for (size_t nid = 0; nid < groups_.size(); ++nid) { - // the group of current node has been specified already. - auto* graph_node = graph.post_dfs_order[nid]; - auto* dom_node = post_dom_tree.nodes[nid]; - Group* group_node = groups_[nid]; - ICHECK(group_node != nullptr); - // no actions for opaque nodes - if (group_node->pattern == kOpaque) continue; - // no actions needed if the current node have no dominator - if (dom_node->parent == nullptr) continue; - ICHECK(!graph_node->extern_ref); - size_t dom_parent_gindex = dom_node->parent->gnode->index; - - // refuse the fusion if too many ops are going to be fused together - if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_) - continue; - - if (phase == 2) { - // Fuse injective ops into intermediate tuples, if any - if (group_node->pattern > kInjective) continue; - Group* dom_parent_group = groups_[dom_parent_gindex]; - Group* dom_root_group = dom_parent_group->FindRoot(); - // If dom node group has a tuple as its root, we do not fuse tuple fields into it - if (dom_root_group->pattern == kTuple) continue; - if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= kInjective) { - // Now we know the tuple has been fused into subsequent injective ops - auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; - // dom_root_group can also be tuple, as in inception layers - // CheckPath is needed to avoid fusing two intermediate tuples - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } - continue; - } - - // Skip if current node is already fused to the parent. - if (groups_[dom_parent_gindex] != nullptr && - group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { - continue; - } - // Do not fuse into tuple for now - if (groups_[dom_parent_gindex]->pattern == kTuple) continue; - // Try to fuse current node to its post-dominator. - if (group_node->pattern == kOutEWiseFusable) { - if (phase != 0) continue; - // Path for OutEWiseFusable: conv2d - // Check if the dominator relation is elemwise. - if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { - ICHECK(dom_node->parent->gnode != nullptr); - // The fuse can be executed if all the intermediate ops are still broadcast. - auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; }; - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } - } else if (group_node->pattern <= kBroadcast) { - // Pre-condition: can only be fused to parent which is injective or reduction. - if (dom_node->parent != nullptr && - (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) { - // Check if all the intermediate ops are still broadcast. - // The final terminal node can already be fused to a OutEWiseFusable group. - auto fcond = [](OpPatternKind kind, bool is_sink) { - if (!is_sink) { - // Elemwise, broadcast, and injective ops on the parallel branches - // are allowed be fused to the elemwise/broadcast anchor. - return kind <= kInjective; - } else { - return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective || - kind == kOutEWiseFusable); - } - }; - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } - } else if (group_node->pattern == kInjective || group_node->pattern == kTuple) { - // defer injective fusion to second phase. - // so conv2d always finishes fusing. - if (phase != 1) continue; - // Check if all path are injective. - auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; }; - if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { - CommitFuse(graph_node, dom_node->parent->gnode); - } - } else { - // do nothing. - ICHECK(group_node->pattern == kCommReduce); - } - } - } -}; - -std::vector GraphPartitioner::Partition( - const IndexedForwardGraph& graph) { - this->InitGroups(graph); - if (opt_level_ == 0) return std::move(groups_); - // get post dominator tree - auto post_dom_tree = DominatorTree::PostDom(arena_, graph); - // run fusion algorithm. - for (int phase = 0; phase < 3; ++phase) { - this->RunFuse(graph, post_dom_tree, phase); - } - return std::move(groups_); -} - class FuseMutator : private MixedModeMutator { public: FuseMutator(int fuse_opt_level, size_t max_fuse_depth, bool link_params) @@ -825,7 +333,7 @@ class FuseMutator : private MixedModeMutator { // Run the transform Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) { // setup the group map. - auto graph = IndexedForwardGraph::Create(&arena_, body); + auto graph = IndexedForwardGraphCreator::Create(&arena_, body); auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth).Partition(graph); for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { ICHECK(graph.post_dfs_order[nid]->ref != nullptr);