diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index df896cb690eb..b2776a41c50c 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -609,6 +609,8 @@ void PatternGrouper::VisitExprs() { } void PatternGrouper::CreateGroup(const Expr& expr) { + VLOG(1) << "Creating group for:" << std::endl << PrettyPrint(expr); + int var_number = 0; auto node_map = matcher_->GetMemo(); @@ -696,6 +698,7 @@ void PatternGrouper::CreateGroup(const Expr& expr) { auto body = extractor.Mutate(expr); group.function = Function(params, body, NullValue(), Array()); + VLOG(1) << "Candidate extracted function:" << std::endl << PrettyPrint(group.function); group.name = extractor.GetName(); // Check to make sure we aren't overlapping with another group or creating an invalid fusion // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the @@ -708,6 +711,10 @@ void PatternGrouper::CreateGroup(const Expr& expr) { // Similiarly, if interior nodes in a group are used outside of the group fusing to a single // output would create an invalid graph tranformation, so we block the creation of such groups. auto memo = extractor.GetMemo(); + for (auto kv : memo) { + VLOG(1) << "matched index " << matcher_->expr_to_node(kv.first)->index_; + } + for (auto kv : memo) { // Check to ensure that this node isn't an input or a global if (inputs.count(kv.first) == 0 && kv.first.as() == nullptr && @@ -720,16 +727,19 @@ void PatternGrouper::CreateGroup(const Expr& expr) { // if the node isn't the output of the group auto node = matcher_->expr_to_node(kv.first); for (auto* output : node->outputs_) { - // and the node is used by nodes outside of the group if (memo.count(output->ref()) == 0) { - // TODO(mbs): This condition used to also include the following test, which since - // the dominators relation is used back-to-front was always vacuously true. So the - // code is just rejecting the match if a strictly internal node happened to connect - // to an outside node. - ICHECK(!matcher_->expr_to_node(expr)->Dominates(output)); - // Exit because nodes in this pattern's body are used outside the pattern, fusing it - // would be invalid - return; + // A node inside the matched group contributes an output to nodes outside of the matched + // group... + auto root = matcher_->expr_to_node(expr); + if (!root->Dominates(output)) { + // ...and the outside dataflow does not come back to the root of the matched group. + // So reject the match since it would create a cycle. + VLOG(1) << "Rejecting group since would create a cycle with output " << output->index_ + << " for root " << root->index_ << " in graph:" << std::endl + << matcher_->expr_graph().ToString(); + return; + } + // else: We'll allow the output to be included in the matched group. } } } diff --git a/src/relay/ir/dataflow_matcher_impl.h b/src/relay/ir/dataflow_matcher_impl.h index f04190f72e40..a174d8e34eb7 100644 --- a/src/relay/ir/dataflow_matcher_impl.h +++ b/src/relay/ir/dataflow_matcher_impl.h @@ -55,6 +55,7 @@ class DFPatternMatcher : public DFPatternFunctor, ObjectPtrHash, ObjectPtrEqual>& memo() const { return memo_; } + const IndexedGraph& expr_graph() const { return *expr_graph_; } protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;