diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 8d7ed163a197..df896cb690eb 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -36,6 +36,7 @@ namespace relay { // Pattern Matcher bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { + VLOG(1) << "Match " << PrettyPrint(pattern) << " in:" << std::endl << PrettyPrint(expr); memo_.clear(); matched_nodes_.clear(); return VisitDFPattern(pattern, expr); @@ -58,6 +59,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr if (out) { memo_[pattern].push_back(expr); matched_nodes_.push_back(pattern); + VLOG(1) << "Matched " << PrettyPrint(pattern) << " at:" << std::endl << PrettyPrint(expr); } else { ClearMap(watermark); } @@ -124,7 +126,6 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons if (!matches) { return matches; } - VLOG(1) << "considering AttrPatternNode at:\n" << PrettyPrint(expr); auto attributes = attr_pattern->attrs.as()->dict; if (const auto* op_node = expr.as()) { Op op = GetRef(op_node); @@ -299,14 +300,18 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex // Recursively find the Dominator parent along all inputs paths. 
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) { auto call_node = expr.as(); - for (auto node : expr_graph_.node_map_.at(expr)->inputs_) { - if (!(call_node && node->ref_ == call_node->op)) { + auto index_node = expr_to_node(expr); + for (auto node : index_node->inputs_) { + if (!(call_node && node->ref() == call_node->op)) { memoize_ = true; - if (VisitDFPattern(op->parent, node->ref_)) { + if (VisitDFPattern(op->parent, node->ref())) { return true; } else { memoize_ = false; - if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) { + if (!VisitDFPattern(op->path, node->ref())) { + return false; + } + if (!MatchesPath(op, node->ref())) { return false; } } @@ -318,19 +323,19 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e // Iteratively ensure that the parent is dominated somewhere by the child or the path bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) { std::stack stack; - std::unordered_set visited; + std::unordered_set visited; stack.push(expr); while (!stack.empty()) { Expr current = stack.top(); stack.pop(); - for (auto node : expr_graph_.node_map_.at(current)->dominator_children_) { - if (visited.count(node->ref_) == 0) { - if (VisitDFPattern(op->parent, node->ref_)) { + for (auto node : expr_to_node(current)->dominator_children_) { + if (visited.count(node->node_ref_) == 0) { + if (VisitDFPattern(op->parent, node->ref())) { return true; } else { - stack.push(node->ref_); + stack.push(node->ref()); } - visited.insert(node->ref_); + visited.insert(node->node_ref_); } } } @@ -500,7 +505,8 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr } bool MatchPattern(DFPattern pattern, Expr expr) { - return DFPatternMatcher(expr).Match(pattern, expr); + std::unique_ptr> expr_graph = CreateIndexedGraph(expr); + return DFPatternMatcher(expr_graph.get()).Match(pattern, expr); } 
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern); @@ -575,7 +581,8 @@ const std::unordered_map& PatternGrouper::GroupMatch pattern_ = pattern; pattern_graph_ = CreateIndexedGraph(pattern_); - auto matcher = DFPatternMatcher(pre); + std::unique_ptr> expr_graph = CreateIndexedGraph(pre); + DFPatternMatcher matcher(expr_graph.get()); matcher_ = &matcher; this->VisitExprs(); return this->groups_; @@ -583,9 +590,9 @@ const std::unordered_map& PatternGrouper::GroupMatch void PatternGrouper::VisitExprs() { std::unordered_set pre_partitioned; - for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) { - size_t index = i - 1; - Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_; + for (PostDfsIndex i = matcher_->size(); i != 0; --i) { + PostDfsIndex index = i - 1; + const auto current = matcher_->index_to_node(index)->ref(); if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped if (auto op = current.as()) { if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) { @@ -607,9 +614,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) { auto node_map = matcher_->GetMemo(); // Get fuzzy patterns std::unordered_set fuzzy_matches; - for (auto node : pattern_graph_.topological_order_) { + for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) { + auto node = pattern_graph_->index_to_node(index); // Don't treat fuzzy Dominator patterns input variables for partition - if (auto op = node->ref_.as()) { + if (auto op = node->ref().as()) { for (auto fuzzy_op : {op->parent, op->path}) { for (auto match : node_map[fuzzy_op]) { fuzzy_matches.insert(match); @@ -617,12 +625,13 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } // Don't treat Function params or body as input variables for partition - if (node->ref_.as()) { - auto matches = node_map[node->ref_]; + if (node->ref().as()) { + auto matches = node_map[node->ref()]; for 
(auto match : matches) { - auto graph = CreateIndexedGraph(match.as()->body); - for (auto node : graph.topological_order_) { - fuzzy_matches.insert(node->ref_); + auto sub_graph = CreateIndexedGraph(match.as()->body); + for (PostDfsIndex sub_index = 0; sub_index < sub_graph->size(); ++sub_index) { + auto sub_node = sub_graph->index_to_node(sub_index); + fuzzy_matches.insert(sub_node->ref()); } } } @@ -636,10 +645,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) { std::unordered_map inputs; Array params; - for (auto node : pattern_graph_.topological_order_) { + for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) { + auto node = pattern_graph_->index_to_node(index); auto make_input = [&](const Expr& input) { if (fuzzy_matches.count(input) == 0 && input.as() == nullptr && - input.as() == nullptr && !EmbedConst(input, node->ref_)) { + input.as() == nullptr && !EmbedConst(input, node->ref())) { inputs[input] = Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number), NullValue()); @@ -648,11 +658,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) { var_number++; } }; - auto tuple = node->ref_.as(); - auto call = node->ref_.as(); + auto tuple = node->ref().as(); + auto call = node->ref().as(); if (tuple && !tuple->fields.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { for (auto input : match.as()->fields) { make_input(input); @@ -660,8 +670,8 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } } else if (call && !call->args.defined()) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { for (auto input : match.as()->args) { make_input(input); @@ -669,8 +679,8 @@ void PatternGrouper::CreateGroup(const Expr& expr) { } } } else if 
(node->inputs_.size() == 0) { - if (node_map.count(node->ref_)) { - auto matches = node_map[node->ref_]; + if (node_map.count(node->ref())) { + auto matches = node_map[node->ref()]; for (auto match : matches) { make_input(match); } @@ -708,13 +718,17 @@ void PatternGrouper::CreateGroup(const Expr& expr) { return; } else if (kv.second != body) { // if the node isn't the output of the group - auto node = matcher_->expr_graph_.node_map_.at(kv.first); + auto node = matcher_->expr_to_node(kv.first); for (auto* output : node->outputs_) { // and the node is used by nodes outside of the group - if (memo.count(output->ref_) == 0 && - !matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) { - // Exit because nodes in this pattern's body are used outside the pattern - // fusing it would be invalid + if (memo.count(output->ref()) == 0) { + // TODO(mbs): This condition used to also include the following test, which since + // the dominators relation is used back-to-front was always vacuously true. So the + // code is just rejecting the match if a strictly internal node happened to connect + // to an outside node. 
+ ICHECK(!matcher_->expr_to_node(expr)->Dominates(output)); + // Exit because nodes in this pattern's body are used outside the pattern, fusing it + // would be invalid return; } } diff --git a/src/relay/ir/dataflow_matcher_impl.h b/src/relay/ir/dataflow_matcher_impl.h index d993d4720e4e..f04190f72e40 100644 --- a/src/relay/ir/dataflow_matcher_impl.h +++ b/src/relay/ir/dataflow_matcher_impl.h @@ -27,7 +27,9 @@ #include #include #include +#include +#include #include #include #include @@ -39,10 +41,20 @@ namespace relay { class DFPatternMatcher : public DFPatternFunctor { public: - explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {} + explicit DFPatternMatcher(const IndexedGraph* expr_graph) : expr_graph_(expr_graph) {} bool Match(const DFPattern& pattern, const Expr& expr); Map> GetMemo() { return Map>(memo_); } - const IndexedGraph expr_graph_; + + const IndexedGraph::Node* expr_to_node(const Expr& expr) const { + return expr_graph_->item_to_node(expr); + } + const IndexedGraph::Node* index_to_node(size_t index) const { + return expr_graph_->index_to_node(index); + } + size_t size() const { return expr_graph_->size(); } + const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& memo() const { + return memo_; + } protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; @@ -67,6 +79,7 @@ class DFPatternMatcher : public DFPatternFunctor* expr_graph_; std::unordered_map, ObjectPtrHash, ObjectPtrEqual> memo_; std::vector matched_nodes_; bool memoize_ = true; @@ -131,7 +144,7 @@ class PatternGrouper { std::unordered_map groups_; std::unordered_map gid_assignments_; DFPatternMatcher* matcher_ = nullptr; - IndexedGraph pattern_graph_; + std::unique_ptr> pattern_graph_; int gid_ = 0; int graph_number_ = 0; }; diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 4efe57b491db..f39ff4850eae 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ 
-19,195 +19,393 @@ /*! * \file src/relay/ir/indexed_graph.cc - * \brief Utilties for Creating Indexed Graphs. + * \brief A graph representation of the dataflow in a Relay expression or Relay (dataflow) + * pattern. */ #include "indexed_graph.h" #include #include #include -#include +#include + +#include namespace tvm { namespace relay { -// IndexedGraph +std::string RefToSummary(const Expr& expr) { + class Visitor : public ExprFunctor { + std::string VisitExpr_(const VarNode* op) final { return "%" + op->name_hint(); } + std::string VisitExpr_(const GlobalVarNode* op) final { return "@" + op->name_hint; } + std::string VisitExpr_(const ConstantNode* op) final { return "const"; } + std::string VisitExpr_(const TupleNode* op) final { + return "tuple(" + std::to_string(op->fields.size()) + ")"; + } + std::string VisitExpr_(const FunctionNode* op) final { return "fn"; } + std::string VisitExpr_(const CallNode* op) final { + return VisitExpr(op->op) + "(" + std::to_string(op->args.size()) + ")"; + } + std::string VisitExpr_(const LetNode* op) final { return "let"; } + std::string VisitExpr_(const IfNode* op) final { return "if"; } + std::string VisitExpr_(const OpNode* op) final { return op->name; } + std::string VisitExpr_(const TupleGetItemNode* op) final { + return "." + std::to_string(op->index); + } + std::string VisitExpr_(const RefCreateNode* op) final { return "ref_create"; } + std::string VisitExpr_(const RefReadNode* op) final { return "ref_read"; } + std::string VisitExpr_(const RefWriteNode* op) final { return "ref_write"; } + std::string VisitExpr_(const ConstructorNode* op) final { return "ctor"; } + std::string VisitExpr_(const MatchNode* op) final { return "match"; } + }; + return Visitor().VisitExpr(expr); +} + +std::string RefToSummary(const DFPattern& pattern) { + // TODO(mbs): Implement as debugging requires. + return ""; +} -IndexedGraph CreateIndexedGraph(const Expr& expr) { - using NodePtr = std::shared_ptr::Node>; - /*! 
\brief Creator Creates an IndexedGraph and determintes Topological order */ +std::unique_ptr> CreateIndexedGraph(const Expr& expr) { + /*! + * \brief Adds indexed graph nodes in post-dfs order, and discovers which let-bound vars are to + * recursive functions. + */ class Creator : public MixedModeVisitor { public: - IndexedGraph CreateGraph(const Expr& expr) { + std::pair>, + std::unique_ptr>> + CreateGraph(const Expr& expr) { VisitExpr(expr); - graph_.node_map_[expr]->is_external_ = true; - return std::move(graph_); + // Last visited node is implicitly used 'externally'. + graph_->item_to_node(expr)->is_external_ = true; + return {std::move(graph_), std::move(rec_calls_)}; } protected: using MixedModeVisitor::VisitExpr_; + // By default the MixedModeVisitor will place + // - callee and arguments before a call + // - tuple fields before a tuple + // - tuple before a tuple projection void VisitLeaf(const Expr& expr) override { + if (const auto* var_node = expr.as()) { + if (var_node == current_let_bound_var_) { + // Don't visit occurrences of let-rec bound vars in the recursive function body. + // Instead, wait for them to be visited at call sites outside of the function. + VLOG(1) << "Ignore let-rec var '" << var_node->name_hint() << "'"; + return; + } + } + MixedModeVisitor::VisitLeaf(expr); - auto node = std::make_shared::Node>(expr, index_++); - graph_.node_map_[expr] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(expr); + + if (const auto* call_node = expr.as()) { + if (const auto* var_node = call_node->op.as()) { + if (var_node == current_let_bound_var_) { + // Remember this is a recursive call to the let-rec bound function. + // The Annotator functor below will not record any dependency from the let-rec bound + // var to the expression so that the indexed graph is always a DAG. 
+ VLOG(1) << "Remembering recursive call to '" << var_node->name_hint() << "'"; + rec_calls_->emplace(call_node); + } + } + } } - void VisitExpr_(const LetNode* let) override { + void VisitExpr_(const LetNode* let_node) override { auto pre_visit = [&](const LetNode* op) { - this->VisitSpan(op->span); - this->VisitExpr(op->value); - this->VisitExpr(op->var); + // Let-bound values come before their let-bound variable. + const VarNode* prev_let_bound_var = current_let_bound_var_; + current_let_bound_var_ = op->var.get(); + VisitExpr(op->value); + current_let_bound_var_ = prev_let_bound_var; + VisitExpr(op->var); }; auto post_visit = [&](const LetNode* op) { - this->VisitExpr(op->body); - if (let != op) { - Expr expr = GetRef(op); + VisitExpr(op->body); + if (let_node != op) { + // Replicate VisitLeaf, which we are effectively bypassing. visit_counter_[op]++; - auto node = std::make_shared::Node>(expr, index_++); - graph_.node_map_[expr] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(GetRef(op)); } }; - ExpandANormalForm(let, pre_visit, post_visit); + ExpandANormalForm(let_node, pre_visit, post_visit); } - IndexedGraph graph_; - size_t index_ = 0; + class PatternCreator : public PatternVisitor { + public: + explicit PatternCreator(Creator* creator) : creator_(creator) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + creator_->VisitLeaf(pattern_var_node->var); + } + + Creator* creator_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + // Matched data comes before match-bound vars then match rhs, in match order. + VisitExpr(match_node->data); + for (const Clause& c : match_node->clauses) { + PatternCreator pattern_creator(this); + pattern_creator.VisitPattern(c->lhs); + VisitExpr(c->rhs); + } + } + + /*! \brief Graph we are accumulating nodes into. */ + std::unique_ptr> graph_ = std::make_unique>(); + /*! \brief Variable the currently visited expression is to be let-bound to, if any. 
*/ + const VarNode* current_let_bound_var_ = nullptr; + /*! \brief Accumulated calls to recursive functions. */ + std::unique_ptr> rec_calls_ = + std::make_unique>(); }; - /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does dominator tree - * analysis. + + /*! + * \brief Fills in the inputs and outputs for all nodes, then does dominator analysis. * - * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined - * topological order instead of recursing. + * Though we use the ExprFunctor to visit nodes, we never recurse and instead just inspect + * each sub-expression's immediate sub-sub-expressions to accumulate inputs and outputs. */ - class Annotator : public ExprFunctor { + class Annotator : public ExprFunctor { public: - Annotator(const IndexedGraph& graph) : graph_(graph) {} - IndexedGraph Annotate() { + explicit Annotator(std::pair>, + std::unique_ptr>> + args) + : graph_(std::move(args.first)), rec_calls_(std::move(args.second)) {} + + std::unique_ptr> Annotate() { // Visit all of the nodes in topological order to get forward outputs - for (const auto& node : graph_.topological_order_) { - ExprFunctor::VisitExpr(node->ref_, nullptr); + for (PostDfsIndex index = 0; index < graph_->size(); ++index) { + VisitExpr(graph_->index_to_node(index)->ref()); } // do the dominator analysis - graph_.PostDom(); + graph_->PostDom(); return std::move(graph_); } - /*! Default visitation pushes the parent to the child's outputs and the child to the parent's - * inputs*/ - void VisitExpr(const Expr& expr, NodePtr parent) override { - auto current = graph_.node_map_[expr]; - if (parent) { - current->outputs_.push_back(parent.get()); - parent->inputs_.push_back(current.get()); - } + /*! + * \brief Add \p parent as a possible output of the node corresponding to \p expr. 
+ */ + void AddOutput(const Expr& expr, IndexedGraph::Node* parent) { + auto current = graph_->item_to_node(expr); + current->outputs_.push_back(parent); + parent->inputs_.push_back(current); } protected: - IndexedGraph graph_; - void VisitExpr_(const VarNode* op, NodePtr parent) override { - if (op->type_annotation.defined()) { - this->VisitType(op->type_annotation); - } - } + void VisitExpr_(const VarNode* var_node) override {} - void VisitExpr_(const GlobalVarNode* op, NodePtr parent) override {} + void VisitExpr_(const GlobalVarNode* global_var_node) override {} - void VisitExpr_(const ConstantNode* op, NodePtr parent) override {} + void VisitExpr_(const ConstantNode* constant_node) override {} - void VisitExpr_(const TupleNode* op, NodePtr parent) override { - for (auto field : op->fields) { - this->VisitExpr(field, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const TupleNode* tuple_node) override { + auto node = graph_->item_to_node(GetRef(tuple_node)); + for (auto field : tuple_node->fields) { + AddOutput(field, node); } } - void VisitExpr_(const FunctionNode* op, NodePtr parent) override { - for (auto param : op->params) { - this->VisitExpr(param, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const FunctionNode* function_node) override { + auto node = graph_->item_to_node(GetRef(function_node)); + // Nothing to do for parameters -- each use of a parameter will contribute to its outputs. + AddOutput(function_node->body, node); + } + + void VisitExpr_(const CallNode* call_node) override { + auto node = graph_->item_to_node(GetRef(call_node)); + if (rec_calls_->count(call_node)) { + // We want the indexed graph to be a DAG, so don't consider a call to a let-rec bound + // function from inside the function to depend on the let-rec bound var. 
+ VLOG(1) << "Ignoring op in call " << RefToSummary(GetRef(call_node)); + } else { + AddOutput(call_node->op, node); + } + for (auto arg : call_node->args) { + AddOutput(arg, node); } + } + + void VisitExpr_(const LetNode* let_node) override { + auto node = graph_->item_to_node(GetRef(let_node)); + auto let_var_node = graph_->item_to_node(let_node->var); + AddOutput(let_node->value, let_var_node); + // Nothing to do for the let-bound variable -- each use of that variable in the let-body + // will contribute to its outputs. + AddOutput(let_node->body, node); + } - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const IfNode* if_node) override { + auto node = graph_->item_to_node(GetRef(if_node)); + AddOutput(if_node->cond, node); + AddOutput(if_node->true_branch, node); + AddOutput(if_node->false_branch, node); } - void VisitExpr_(const CallNode* op, NodePtr parent) override { - this->VisitExpr(op->op, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const OpNode* op_node) override {} - for (auto ty_arg : op->type_args) { - this->VisitType(ty_arg); + void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) override { + auto node = graph_->item_to_node(GetRef(tuple_get_item_node)); + AddOutput(tuple_get_item_node->tuple, node); + } + + void VisitExpr_(const RefCreateNode* ref_create_node) override { + auto node = graph_->item_to_node(GetRef(ref_create_node)); + AddOutput(ref_create_node->value, node); + } + + void VisitExpr_(const RefReadNode* ref_read_node) override { + auto node = graph_->item_to_node(GetRef(ref_read_node)); + AddOutput(ref_read_node->ref, node); + } + + void VisitExpr_(const RefWriteNode* ref_write_node) override { + auto node = graph_->item_to_node(GetRef(ref_write_node)); + AddOutput(ref_write_node->ref, node); + AddOutput(ref_write_node->value, node); + } + + void VisitExpr_(const ConstructorNode* constructor_node) override {} + + class PatternAnnotator : public PatternVisitor { + public: + 
PatternAnnotator(Annotator* annotator, const ExprNode* adt_node) + : annotator_(annotator), adt_node_(adt_node) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + auto node = annotator_->graph_->item_to_node(pattern_var_node->var); + annotator_->AddOutput(GetRef(adt_node_), node); } - for (auto arg : op->args) { - this->VisitExpr(arg, graph_.node_map_[GetRef(op)]); + Annotator* annotator_; + const ExprNode* adt_node_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + // Data flows from the match data to pattern vars into match arms and out into overall + // match. + auto node = graph_->item_to_node(GetRef(match_node)); + for (const Clause& c : match_node->clauses) { + PatternAnnotator pattern_annotator(this, match_node->data.get()); + pattern_annotator.VisitPattern(c->lhs); + AddOutput(c->rhs, node); } } - void VisitExpr_(const LetNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->var, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->body, graph_.node_map_[GetRef(op)]); - } + std::unique_ptr> graph_; + /*! \brief Accumulated calls to recursive functions. */ + std::unique_ptr> rec_calls_; + }; + + /*! \brief Fills in the basic blocks for all nodes. 
*/ + class Blocker : public MixedModeVisitor { + public: + explicit Blocker(std::unique_ptr> graph) : graph_(std::move(graph)) {} - void VisitExpr_(const IfNode* op, NodePtr parent) override { - this->VisitExpr(op->cond, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->true_branch, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->false_branch, graph_.node_map_[GetRef(op)]); + std::unique_ptr> Scope(const Expr& expr) { + VisitExpr(expr); + return std::move(graph_); } - void VisitExpr_(const OpNode* op, NodePtr parent) override { return; } + private: + using MixedModeVisitor::VisitExpr_; - void VisitExpr_(const TupleGetItemNode* op, NodePtr parent) override { - this->VisitExpr(op->tuple, graph_.node_map_[GetRef(op)]); + void VisitLeaf(const Expr& expr) override { + MixedModeVisitor::VisitLeaf(expr); + SetScope(expr); } - void VisitExpr_(const RefCreateNode* op, NodePtr parent) override { - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const FunctionNode* function_node) override { + auto node = graph_->item_to_node(GetRef(function_node)); + basic_block_stack_.push_back(node); + ExprVisitor::VisitExpr_(function_node); + basic_block_stack_.pop_back(); } - void VisitExpr_(const RefReadNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const IfNode* if_node) override { + VisitExpr(if_node->cond); + auto node = graph_->item_to_node(GetRef(if_node)); + basic_block_stack_.push_back(node); + VisitExpr(if_node->true_branch); + VisitExpr(if_node->false_branch); + basic_block_stack_.pop_back(); } - void VisitExpr_(const RefWriteNode* op, NodePtr parent) override { - this->VisitExpr(op->ref, graph_.node_map_[GetRef(op)]); - this->VisitExpr(op->value, graph_.node_map_[GetRef(op)]); + void VisitExpr_(const LetNode* let_node) override { + auto pre_visit = [&](const LetNode* op) { + VisitExpr(op->value); + VisitExpr(op->var); + }; + auto post_visit = [&](const LetNode* op) { + 
VisitExpr(op->body); + if (let_node != op) { + visit_counter_[op]++; + SetScope(GetRef(op)); + } + }; + ExpandANormalForm(let_node, pre_visit, post_visit); } - void VisitExpr_(const ConstructorNode* op, NodePtr parent) override { - for (const Type& t : op->inputs) { - this->VisitType(t); + class PatternBlocker : public PatternVisitor { + public: + explicit PatternBlocker(Blocker* scoper) : scoper_(scoper) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + scoper_->SetScope(pattern_var_node->var); } - this->VisitType(op->belong_to); - } - void VisitExpr_(const MatchNode* op, NodePtr parent) override { - this->VisitExpr(op->data, graph_.node_map_[GetRef(op)]); - for (const Clause& c : op->clauses) { - this->VisitClause(c, graph_.node_map_[GetRef(op)]); + Blocker* scoper_; + }; + + void VisitExpr_(const MatchNode* match_node) override { + VisitExpr(match_node->data); + auto node = graph_->item_to_node(GetRef(match_node)); + basic_block_stack_.push_back(node); + for (const Clause& c : match_node->clauses) { + PatternBlocker pattern_scoper(this); + pattern_scoper.VisitPattern(c->lhs); + VisitExpr(c->rhs); } + basic_block_stack_.pop_back(); } - void VisitClause(const Clause& op, NodePtr parent) { - this->VisitPattern(op->lhs); - this->VisitExpr(op->rhs, parent); + void SetScope(const Expr& expr) { + auto node = graph_->item_to_node(expr); + if (!basic_block_stack_.empty()) { + node->basic_block_ = basic_block_stack_.back(); + } } - void VisitPattern(const Pattern& p) { return; } - - void VisitType(const Type& t) { return; } + std::unique_ptr> graph_; + std::vector::Node*> basic_block_stack_; }; - return Annotator(Creator().CreateGraph(expr)).Annotate(); + + VLOG(1) << "CreateIndexedGraph:" << std::endl << PrettyPrint(expr); + std::unique_ptr> graph = + Blocker(Annotator(Creator().CreateGraph(expr)).Annotate()).Scope(expr); + VLOG(1) << "graph:" << std::endl << graph->ToString(); +#if TVM_LOG_DEBUG + graph->CheckValid(); +#endif + 
return graph; } -IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { - using NodePtr = std::shared_ptr::Node>; - /*! \brief Creator Creates an IndexedGraph and determintes Toplogical order */ +std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern) { + /*! \brief Creates an IndexedGraph and determines topological order */ class Creator : public DFPatternVisitor { public: - IndexedGraph CreateGraph(const DFPattern& pattern) { + std::unique_ptr> CreateGraph(const DFPattern& pattern) { + graph_ = std::make_unique>(); VisitDFPattern(pattern); - graph_.node_map_[pattern]->is_external_ = true; + graph_->item_to_node(pattern)->is_external_ = true; return std::move(graph_); } @@ -215,121 +413,135 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern(const DFPattern& pattern) override { if (this->visited_.count(pattern.get()) == 0) { DFPatternVisitor::VisitDFPattern(pattern); - auto node = std::make_shared::Node>(pattern, index_++); - graph_.node_map_[pattern] = node; - graph_.topological_order_.push_back(node); + graph_->AddNode(pattern); } } - IndexedGraph graph_; - size_t index_ = 0; + + std::unique_ptr> graph_; }; + /*! \brief Annotator takes an IndexedGraph, fills it's forward outputs, and does domiantor tree * analysis. * * Annotator use ExprFunctor to visit nodes, but iterates over them in pre-determined * topological order instead of recursing. 
*/ - class Annotator : public DFPatternFunctor { + class Annotator : public DFPatternFunctor { public: - Annotator(const IndexedGraph& graph) : graph_(graph) {} - IndexedGraph Annotate() { + Annotator(std::unique_ptr> graph) : graph_(std::move(graph)) {} + + std::unique_ptr> Annotate() { // Visit all of the nodes in topological order to get forward outputs - for (const auto& node : graph_.topological_order_) { - DFPatternFunctor::VisitDFPattern(node->ref_, nullptr); + for (PostDfsIndex index = 0; index < graph_->size(); ++index) { + VisitDFPattern(graph_->index_to_node(index)->ref()); } - graph_.PostDom(); // do the dominator analysis + graph_->PostDom(); return std::move(graph_); } /*! Default visitation pushes the parent to the child's outputs */ - void VisitDFPattern(const DFPattern& pattern, NodePtr parent) override { - auto current = graph_.node_map_[pattern]; + void AddOutput(const DFPattern& pattern, IndexedGraph::Node* parent) { + auto current = graph_->item_to_node(pattern); if (parent) { - current->outputs_.push_back(parent.get()); - parent->inputs_.push_back(current.get()); + current->outputs_.push_back(parent); + parent->inputs_.push_back(current); } } protected: - IndexedGraph graph_; - void VisitDFPattern_(const AltPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->left, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->right, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const AltPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->left, node); + AddOutput(op->right, node); } - void VisitDFPattern_(const AttrPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const AttrPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->op, graph_.node_map_[GetRef(op)]); + void 
VisitDFPattern_(const CallPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->op, node); if (op->args.defined()) { for (auto arg : op->args) { - VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); + AddOutput(arg, node); } } } - void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const ConstantPatternNode* op) override {} - void VisitDFPattern_(const DataTypePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const DataTypePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->parent, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->path, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->child, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const DominatorPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->parent, node); + AddOutput(op->path, node); + AddOutput(op->child, node); } - void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const ExprPatternNode* op) override {} - void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) override { + void VisitDFPattern_(const FunctionPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); if (op->params.defined()) { for (auto param : op->params) { - VisitDFPattern(param, graph_.node_map_[GetRef(op)]); + AddOutput(param, node); } } - VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + AddOutput(op->body, node); } - void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const ShapePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + 
AddOutput(op->pattern, node); } - void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->tuple, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const TupleGetItemPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->tuple, node); } - void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override { + void VisitDFPattern_(const TuplePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); if (op->fields.defined()) { for (auto field : op->fields) { - VisitDFPattern(field, graph_.node_map_[GetRef(op)]); + AddOutput(field, node); } } } - void VisitDFPattern_(const IfPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->cond, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->true_branch, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->false_branch, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const IfPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->cond, node); + AddOutput(op->true_branch, node); + AddOutput(op->false_branch, node); } - void VisitDFPattern_(const LetPatternNode* op, NodePtr parent) override { - VisitDFPattern(op->var, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->value, graph_.node_map_[GetRef(op)]); - VisitDFPattern(op->body, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const LetPatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->var, node); + AddOutput(op->value, node); + AddOutput(op->body, node); } - void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override { - VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + void VisitDFPattern_(const TypePatternNode* op) override { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern, node); } - void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const VarPatternNode* op) override 
{} - void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const WildcardPatternNode* op) override {} + + std::unique_ptr> graph_; }; + return Annotator(Creator().CreateGraph(pattern)).Annotate(); } diff --git a/src/relay/ir/indexed_graph.h b/src/relay/ir/indexed_graph.h index d073bcaeea5c..c1ce53f40da3 100644 --- a/src/relay/ir/indexed_graph.h +++ b/src/relay/ir/indexed_graph.h @@ -19,7 +19,12 @@ /*! * \file src/relay/ir/indexed_graph.h - * \brief A pattern matcher for matching dataflow properties. + * \brief A graph representation of the dataflow in a Relay expression or Relay (dataflow) + * pattern. Each 'indexed graph' node is 1:1 with an expression/pattern 'node', hence the + * term 'IndexedGraph'. Dataflow is captured in a generic representation which is convenient + * for analysis, particularly pattern matching and partitioning. + * + * TODO(mbs): Copied from fuse_ops.cc, consider refactoring to share implementation. */ #ifndef TVM_RELAY_IR_INDEXED_GRAPH_H_ #define TVM_RELAY_IR_INDEXED_GRAPH_H_ @@ -28,6 +33,7 @@ #include #include +#include #include #include #include @@ -36,47 +42,108 @@ namespace tvm { namespace relay { +/*! \brief The index of a node in the post-dfs traversal of the overall expression. */ +using PostDfsIndex = size_t; + +/*! + * \brief Returns a brief summary of the 'reference' expression or pattern. Only used by + * IndexedGraph::ToString() for debugging. + */ +std::string RefToSummary(const Expr& expr); +std::string RefToSummary(const DFPattern& pattern); + /*! - * \brief A Wrapper around a templated graph type - * Holds a forward-backward indexed representation of the graph and a dominator tree representation - * of the graph + * \brief Represents the implied dataflow of an expression or (dataflow) pattern as a DAG whose + * nodes are 1:1 with those in the underlying expression/pattern. + * + * Each indexed graph node captures: + * - Dataflow inputs. 
+ * - Dataflow outputs (or a flag indicating the node is an implied output). + * - Dominator parent (ie closest node at which all outputs of the current node re-combine). + * - Dominator children (inverse of above). + * - Basic block (ie node representing the body of a function, arm of an if, etc). * - * This class is templated and the implementaiton is in the header file so we can analyze both - * DFPattern and Expr with the same infrastructure. + * This class is templated so we can analyze both DFPatterns and Exprs with the same infrastructure. * - * IndexedGraph should be instantiated through the CreateIndexedGraph utilities. + * IndexedGraph should be instantiated through the CreateIndexedGraph utilities below. */ template class IndexedGraph { public: - /*! \brief A Node that wraps the input type and represents the indexed graph and dominator tree */ + using TNode = typename T::ContainerType; + + /*! \brief A Node in the graph. */ struct Node { /*! \brief Node Constructor - * \param ref The input graph node - * \param index The index of the node in toplogical order + * \param ref The expression or dataflow pattern node this indexed graph node is augmenting. + * \param index The index of this node in the topological order */ - Node(const T& ref, const size_t index) : ref_(ref), index_(index) {} + Node(const TNode* ref, PostDfsIndex index) : node_ref_(ref), index_(index) {} + + /*! \brief The underlying expression or pattern node. */ + const TNode* node_ref_; - /*! \brief The input node */ - const T ref_; - /*! \brief The topological order index */ - const size_t index_; + T ref() const { + ICHECK(node_ref_ != nullptr); + return GetRef(node_ref_); + } + + /*! + * \brief The index of this node in post-dfs order. If left.index_ > right.index_ then + * left does not flow into right. If left.index_ = right.index_ then left and right are + * the same node. + */ + const PostDfsIndex index_; - /*! 
\brief A boolean to determine if this node is external to the graph */ + /*! \brief If true this node has implicit outputs, for example as the result of a function. */ bool is_external_ = false; - /*! \brief The forward inputs of the node */ + /*! \brief Immediate dataflow inputs to this node. */ std::vector inputs_; - /*! \brief The forward outputs/users of the node */ + /*! \brief Immediate dataflow outputs of this node -- may be empty if is_external_ is true. */ std::vector outputs_; - /*! \brief The depth of the node in the dominator tree */ + /*! + * \brief The node representing the 'basic block' containing this node: + * - Function bodies start a new basic block for their bodies. + * - The true and false branches of an if start their own blocks. + * - The arms of a match each have their own blocks. + */ + Node* basic_block_ = nullptr; + + /*! \brief The depth of this node in the dominator tree */ size_t depth_ = 0; - /*! \brief The dominator parent/final user of the outputs of this node */ - Node* dominator_parent_; - /*! \brief The nodes this node dominates */ + /*! + * \brief The dominator parent of this node. This is the node N with least index such that + * all possible dataflows from this node pass through N. + */ + Node* dominator_parent_ = nullptr; + /*! \brief The nodes this node dominates. */ std::vector dominator_children_; - bool Dominates(const Node* other) { + /*! + * Add to \p nodes all the nodes which are strictly downstream of \p this, ie can be + * reached by following output paths. + */ + void AccumulateDownstreamNodes(std::unordered_set* nodes) const { + std::stack stack; + stack.push(this); + while (!stack.empty()) { + const Node* current = stack.top(); + stack.pop(); + for (auto node : current->outputs_) { + if (nodes->count(node) == 0) { + stack.push(node); + nodes->insert(node); + } + } + } + } + + /*! + * \brief Returns true if \p this is a dominator of \p other. Ie all dataflow paths from \p + * other pass through \p this. 
+ */ + bool Dominates(const Node* other) const { std::stack stack; std::unordered_set visited; stack.push(this); @@ -97,10 +164,125 @@ class IndexedGraph { return false; } }; + + PostDfsIndex size() const { return topological_order_.size(); } + + Node* item_to_node(const T& item) { return item_to_node(item.get()); } + const Node* item_to_node(const T& item) const { return item_to_node(item.get()); } + + Node* item_to_node(const TNode* item) { + auto itr = node_map_.find(item); + ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef(item)); + return itr->second; + } + + const Node* item_to_node(const TNode* item) const { + auto itr = node_map_.find(item); + ICHECK(itr != node_map_.end()) << PrettyPrint(GetRef(item)); + return itr->second; + } + + Node* index_to_node(PostDfsIndex index) { + ICHECK_LT(index, topological_order_.size()) << index; + return topological_order_[index].get(); + } + + const Node* index_to_node(PostDfsIndex index) const { + ICHECK_LT(index, topological_order_.size()) << index; + return topological_order_[index].get(); + } + + /*! + * \brief (For debugging only) Returns description of indexed graph with hints as to the + * sub-expressions or sub-patterns corresponding to each indexed graph node. 
+ */ + std::string ToString() const { + std::ostringstream os; + os << "IndexedGraph(size = " << topological_order_.size() << ") {" << std::endl; + for (PostDfsIndex index = 0; index < topological_order_.size(); ++index) { + const Node* node = topological_order_[index].get(); + ICHECK_EQ(index, node->index_); + os << " " << index << " (" << RefToSummary(node->ref()) << "): inputs=["; + for (const auto* sub_node : node->inputs_) { + os << sub_node->index_ << ","; + } + os << "], outputs=["; + for (const auto* sub_node : node->outputs_) { + os << sub_node->index_ << ","; + } + os << "]"; + if (node->is_external_) { + os << ", external"; + } + if (node->basic_block_) { + os << ", basic_block=" << node->basic_block_->index_; + } + if (node->depth_ > 0) { + os << ", depth=" << node->depth_; + } + if (node->dominator_parent_) { + os << ", dom_parent=" << node->dominator_parent_->index_; + } + os << ", dom_children=["; + for (const auto* sub_node : node->dominator_children_) { + os << sub_node->index_ << ","; + } + os << "]" << std::endl; + } + os << "}"; + return os.str(); + } + + /*! + * Check-fails if the graph is ill-formed. For debugging only. + */ + void CheckValid() const { + ICHECK_GT(topological_order_.size(), 0); + for (PostDfsIndex index = 0; index < topological_order_.size(); ++index) { + const Node* node = topological_order_[index].get(); + // We have a node. + ICHECK(node); + // Bijections with post-dfs indexes and expressions/patterns are correct. + ICHECK_EQ(node->index_, index); + ICHECK(node->node_ref_); + auto itr = node_map_.find(node->node_ref_); + ICHECK(itr != node_map_.end()); + ICHECK_EQ(itr->second, node) << "at index " << index << " in:" << std::endl << ToString(); + // Inputs come before. 
+ for (size_t i = 0; i < node->inputs_.size(); ++i) { + const Node* input = node->inputs_[i]; + ICHECK(input); + ICHECK_LT(input->index_, index); + ICHECK(std::find(input->outputs_.begin(), input->outputs_.end(), node) != + input->outputs_.end()); + } + // Outputs come after. + for (size_t i = 0; i < node->outputs_.size(); ++i) { + const Node* output = node->outputs_[i]; + ICHECK(output); + ICHECK_GT(output->index_, index); + ICHECK(std::find(output->inputs_.begin(), output->inputs_.end(), node) != + output->inputs_.end()); + } + ICHECK_GT(node->depth_, 0); + // Dominator children come before. + for (size_t i = 0; i < node->dominator_children_.size(); ++i) { + const Node* child = node->dominator_children_[i]; + ICHECK(child); + ICHECK_LT(child->index_, index); + } + if (node->dominator_parent_) { + // Dominator comes after. + ICHECK_GT(node->dominator_parent_->index_, index); + } + } + } + + private: /*! \brief Construct the domination tree inside IndexedGraph */ void PostDom() { - for (size_t i = topological_order_.size(); i != 0; --i) { - size_t index = i - 1; + for (PostDfsIndex i = topological_order_.size(); i != 0; --i) { + PostDfsIndex index = i - 1; auto* current = topological_order_[index].get(); if (current->is_external_) { current->depth_ = 1; @@ -109,16 +291,13 @@ class IndexedGraph { auto parent = LeastCommonAncestor(current->outputs_); current->depth_ = parent ? parent->depth_ + 1 : 1; current->dominator_parent_ = parent; - parent->dominator_children_.push_back(current); + if (parent) { + parent->dominator_children_.push_back(current); + } } } } - /*! \brief Map of input nodes to IndexedGraph Nodes */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> node_map_; - /*! \brief Topological IndexedGraph Nodes */ - std::vector> topological_order_; - protected: /*! 
\brief Find the least common ancestor of all outputs of a node */ Node* LeastCommonAncestor(const std::vector& outputs) { if (outputs.size() == 0) { @@ -136,9 +315,11 @@ class IndexedGraph { if (lhs == nullptr || rhs == nullptr) { return nullptr; } + PostDfsIndex lhs_index = lhs->index_; + PostDfsIndex rhs_index = rhs->index_; while (lhs != rhs) { - ICHECK(lhs); - ICHECK(rhs); + ICHECK(lhs && rhs) << "LCA(" << lhs_index << ", " << rhs_index << ") on graph:" << std::endl + << ToString(); if (lhs->depth_ < rhs->depth_) { rhs = rhs->dominator_parent_; } else if (lhs->depth_ > rhs->depth_) { @@ -150,13 +331,41 @@ class IndexedGraph { } return lhs; } + + /*! + * \brief Appends a node corresponding to \p ref, and maintains the sub-expression/sub-pattern to + * node bijection. The insertion index will be the node's PostDfsIndex. All other node properties + * are accumulated in-place. + */ + void AddNode(const T& ref) { + PostDfsIndex index = topological_order_.size(); + auto node = std::make_unique(ref.get(), index); + node_map_[ref.get()] = node.get(); + topological_order_.emplace_back(std::move(node)); + } + + /*! + * \brief Map from underlying sub-expression or sub-pattern nodes to their indexed graph nodes. + */ + std::unordered_map node_map_; + /*! \brief All nodes in increasing post-dfs index order. This vector owns all the nodes. */ + std::vector> topological_order_; + + friend std::unique_ptr> CreateIndexedGraph(const Expr& expr); + friend std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern); }; -/*! \brief Create an Indexed Graph based on an Expr */ -IndexedGraph CreateIndexedGraph(const Expr& expr); -/*! \brief Create an Indexed Graph based on an DFPattern */ -IndexedGraph CreateIndexedGraph(const DFPattern& pattern); +/*! \brief Returns an Indexed Graph for \p expr, which must outlive the result. */ +std::unique_ptr> CreateIndexedGraph(const Expr& expr); + +/*! + * \brief Returns an Indexed Graph for \p pattern, which must outlive the result. 
+ * The dataflow for a pattern mimics the dataflow for the expression which would match + * that pattern. + */ +std::unique_ptr> CreateIndexedGraph(const DFPattern& pattern); } // namespace relay } // namespace tvm + #endif // TVM_RELAY_IR_INDEXED_GRAPH_H_ diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index f7045305e90d..d5cc6608662b 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -258,6 +258,7 @@ RELAY_REGISTER_OP("dyn.broadcast_to") .describe(R"code(Broadcast the first input to match the shape argument. )code" TVM_ADD_FILELINE) .set_num_inputs(2) + .set_attrs_type() .add_argument("data", "Tensor", "The input tensor.") .add_argument("shape", "Tensor", "Target shape.") .set_support_level(4) diff --git a/tests/cpp/relay/ir/indexed_graph_test.cc b/tests/cpp/relay/ir/indexed_graph_test.cc new file mode 100644 index 000000000000..17ec68261684 --- /dev/null +++ b/tests/cpp/relay/ir/indexed_graph_test.cc @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "../../../src/relay/ir/indexed_graph.h" + +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace { + +// A module stolen from onnx/test_forward.py::test_loop which combines functions, recursion, +// control flow, tuples as well as the usual operator calls. +// We include the known post-dfs indexes in comments to help write the tests. +IRModule TestRecursiveIRModule() { + Device device = {kDLCPU, 0}; + Constant const0(runtime::NDArray::Empty(ShapeTuple({1}), DataType::Int(64), device)); + Constant const1(runtime::NDArray::Empty(ShapeTuple({0, 1}), DataType::Float(32), device)); + Map> metadata; + metadata.Set("relay.Constant", Array({const0, const1})); + constexpr const char* kModel = R"( + #[version = "0.0.5"] + def @main(%trip_count: int64, // 0 + %cond: bool, // 1 + %y: Tensor[(1), float32]) // 2 + -> (Tensor[(1), float32], Tensor[(?, ?), float32]) { + %17 = ( + let %while_loop = fn (%iter_count: int64, // 3 + %max_count: int64, // 4 + %cond_in: bool, // 5 + %y_in: Tensor[(1), float32], // 6 + %scan_out: Tensor[(?, ?), float32]) // 7 + -> (int64, int64, bool, Tensor[(1), float32], Tensor[(?, ?), float32]) { + %0 = equal(%cond_in, True); // 11 + %1 = less(%iter_count, %max_count); // 13 + %2 = logical_and(%0, %1); // 14 + if (%2) { + %3 = cast(%iter_count, dtype="float32"); // 20 + %4 = add(%y_in, %3); // 21 + %5 = less(%4, 5f); // 23 + %6 = squeeze(%5); // 24 + %7 = reshape(%iter_count, newshape=[1]); // 29 + %8 = (%7, meta[relay.Constant][0]); // 31 + %9 = concatenate(%8); // 32 + %10 = copy(%4); // 36 + %11 = dyn.broadcast_to(%scan_out, %9, shape=None); // 33 + %12 = expand_dims(%10, axis=0); // 37 + %13 = (%11, %12); // 38 + %14 = add(%iter_count, 1i64); // 17 + %15 = cast(%6, dtype="bool"); // 25 + %16 = concatenate(%13); // 39 + %while_loop(%14, %max_count, %15, %4, %16) // 40 + } else { + (%iter_count, %max_count, %cond_in, %y_in, %scan_out) // 41 + } // 42 + }; // 43 + %while_loop // 44 + ); // 45 + %18 = 
%17(0i64, %trip_count, %cond, %y, meta[relay.Constant][1]); // 48 + %19 = %18.3; // 49 + %20 = %18.4; // 50 + (%19, %20) // 51 + } // 52 + )"; + return parser::ParseModule("string", kModel, /*init_module=*/{}, metadata); +} + +TEST(IndexedGraph, RecursiveExprRegression) { + IRModule ir_mod = TestRecursiveIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = CreateIndexedGraph(main); + graph->CheckValid(); + + { + // Dataflow node properties for %4 + auto node = graph->index_to_node(21); + const auto* call_node = node->ref().as(); + ASSERT_NE(call_node, nullptr); + const auto* op_node = call_node->op.as(); + ASSERT_NE(op_node, nullptr); + ASSERT_EQ(op_node->name, "add"); + + // 3 inputs (the op itself is an input) + ASSERT_EQ(node->inputs_.size(), 3); + ASSERT_EQ(node->inputs_[0]->index_, 15); // the add op + ASSERT_EQ(node->inputs_[1]->index_, 6); // %y_in + ASSERT_EQ(node->inputs_[2]->index_, 20); // %3 + + // 3 outputs + ASSERT_EQ(node->outputs_.size(), 3); + ASSERT_EQ(node->outputs_[0]->index_, 23); // %5 + ASSERT_EQ(node->outputs_[1]->index_, 36); // %10 + ASSERT_EQ(node->outputs_[2]->index_, 40); // recursive %while_loop call + + // In the 'if' basic block + ASSERT_EQ(node->basic_block_->index_, 42); + + // Dominator 'parent' is recursive call + ASSERT_EQ(node->dominator_parent_->index_, 40); + + // One dominator child from %3 + ASSERT_EQ(node->dominator_children_.size(), 1); + ASSERT_EQ(node->dominator_children_[0]->index_, 20); + } + + { + // The recursive call to %while_loop does not depend on %while_loop + auto node = graph->index_to_node(40); + const auto* call_node = node->ref().as(); + ASSERT_NE(call_node, nullptr); + const auto* var_node = call_node->op.as(); + ASSERT_NE(var_node, nullptr); + ASSERT_EQ(var_node->name_hint(), "while_loop"); + + ASSERT_EQ(node->inputs_.size(), 5); + ASSERT_EQ(node->inputs_[0]->index_, 17); // %14 + ASSERT_EQ(node->inputs_[1]->index_, 4); // %max_count + ASSERT_EQ(node->inputs_[2]->index_, 25); // %15 
+ ASSERT_EQ(node->inputs_[3]->index_, 21); // %4 + ASSERT_EQ(node->inputs_[4]->index_, 39); // %16 + } + + { + // Downstream nodes of %18 + auto node = graph->index_to_node(48); + std::unordered_set::Node*> downstreams; + node->AccumulateDownstreamNodes(&downstreams); + ASSERT_EQ(downstreams.size(), 4); + for (const auto* downstream : downstreams) { + ASSERT_TRUE(downstream->index_ >= 49 && downstream->index_ <= 52); + } + } + + { + // Dominates relation for %4 + auto upstream = graph->index_to_node(21); + // Path 1: 21->23->24->25->40 + // Path 2: 21->36->37->38->39->40 + // Then 40->43 + auto downstream = graph->index_to_node(43); + ASSERT_TRUE(downstream->Dominates(upstream)); + } +} + +// A module with unused let-bound function. The 'add' operator should have no dominator +// since it is used both in the unused function and in the main body. +IRModule TestUnusedLetBoundIRModule() { + constexpr const char* kModel = R"( + #[version = "0.0.5"] + def @main(%x: int64) -> int64 { // 0 + let %f = fn ( // 5 + %y: int64 // 1 + ) { + add(%x, %y) // 3 + }; + if (less(%x, 5i64)) { + add(%x, 3i64) // 10 + } else { + %x + } + } + )"; + return parser::ParseModule("string", kModel); +} + +TEST(IndexedGraph, UnusedLetVars) { + IRModule ir_mod = TestUnusedLetBoundIRModule(); + auto main = Downcast(ir_mod->Lookup("main")); + auto graph = CreateIndexedGraph(main); + graph->CheckValid(); + + { + auto node = graph->index_to_node(2); + const auto* op_node = node->ref().as(); + ICHECK(op_node); + ICHECK_EQ(op_node->name, "add"); + ICHECK_EQ(node->outputs_.size(), 2); + ICHECK_EQ(node->outputs_[0]->index_, 3); + ICHECK_EQ(node->outputs_[1]->index_, 10); + ICHECK(node->dominator_parent_ == nullptr); + } +} + +} // namespace +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 74e03f6a9755..f0474c911273 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ 
b/tests/python/relay/test_dataflow_pattern.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=unused-wildcard-import import numpy as np -import pytest import tvm from tvm import relay @@ -601,6 +600,38 @@ def test_match_fake_diamond(): assert not diamond.match(out) +def test_at_most_one_parent(): + # Pattern + P = is_op("nn.conv2d")(wildcard(), wildcard()) # 'parent' + I = is_op("nn.relu")(wildcard()) # 'intermediate' ('path' in the code) + C = is_op("add")(wildcard(), wildcard()) # 'child' + pattern = dominates(P, I, C) + + # n6(P) + # / \ + # n7 \ + # / \ + # n8(P) n10(I) + # \ / + # n9(I) / + # \ / + # n11(C) + + x = relay.var("x") + w = relay.var("w") + n6 = relay.op.nn.conv2d(x, w) # matches P + n7 = relay.op.tanh(n6) # does not match I + n8 = relay.op.nn.conv2d(n7, w) # matches P + n9 = relay.op.nn.relu(n8) # matches I + n10 = relay.op.nn.relu(n6) # matches I + n11 = relay.add(n9, n10) # matches C + + # Does not match: Can't match the parent pattern P at both 8 and 6. + # Note that if we did allow P to be used twice the implementation would + # need to be changed to not 'jump over' n7. + assert not pattern.match(n11) + + def test_match_dominator(): # Pattern is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()) @@ -1760,4 +1791,4 @@ def callback(self, pre, post, node_map): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main()