Skip to content

Commit

Permalink
[Relay] IndexedGraph improvements in preparation for Collage (#11481)
Browse files Browse the repository at this point in the history
* [Relay] Odd's 'n ends changes to help Collage.
 - Complete the implementation of WithFields.
   (Unfortunately they appear to be without unit tests and I continue this tradition...)
 - InferTypeExpr for InferTypeLocal but return the expression rather than the type.
 - Remove python binding of InlineComposites since C++ impl was removed some time ago.
 - Make IndexedGraph<Expr/DFPattern> more robust as stand-alone datastructure, and avoid unnecessary copies.
   This will become a fundamental datastructure in Collage rather than just a helper for DFPatternMatcher.
 - Extend IndexedGraph with a notion of 'basic block' on every dataflow node. Needed by Collage to
   avoid impossible partitions.

* - Revert non IndexedGraph changes.

* - Stick to 'Indexed graph' terminology
- More tests

* - Stick to 'Indexed graph' terminology
- More tests

* - Remove silly unit test
  • Loading branch information
mbs-octoml authored Jun 7, 2022
1 parent 8170219 commit d8f57ed
Show file tree
Hide file tree
Showing 7 changed files with 922 additions and 237 deletions.
90 changes: 52 additions & 38 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace relay {

// Pattern Matcher
bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
VLOG(1) << "Match " << PrettyPrint(pattern) << " in:" << std::endl << PrettyPrint(expr);
memo_.clear();
matched_nodes_.clear();
return VisitDFPattern(pattern, expr);
Expand All @@ -58,6 +59,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr
if (out) {
memo_[pattern].push_back(expr);
matched_nodes_.push_back(pattern);
VLOG(1) << "Matched " << PrettyPrint(pattern) << " at:" << std::endl << PrettyPrint(expr);
} else {
ClearMap(watermark);
}
Expand Down Expand Up @@ -124,7 +126,6 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
if (!matches) {
return matches;
}
VLOG(1) << "considering AttrPatternNode at:\n" << PrettyPrint(expr);
auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
if (const auto* op_node = expr.as<OpNode>()) {
Op op = GetRef<Op>(op_node);
Expand Down Expand Up @@ -299,14 +300,18 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
// Recursively find the Dominator parent along all inputs paths.
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
auto call_node = expr.as<CallNode>();
for (auto node : expr_graph_.node_map_.at(expr)->inputs_) {
if (!(call_node && node->ref_ == call_node->op)) {
auto index_node = expr_to_node(expr);
for (auto node : index_node->inputs_) {
if (!(call_node && node->ref() == call_node->op)) {
memoize_ = true;
if (VisitDFPattern(op->parent, node->ref_)) {
if (VisitDFPattern(op->parent, node->ref())) {
return true;
} else {
memoize_ = false;
if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) {
if (!VisitDFPattern(op->path, node->ref())) {
return false;
}
if (!MatchesPath(op, node->ref())) {
return false;
}
}
Expand All @@ -318,19 +323,19 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e
// Iteratively ensure that the parent is dominated somewhere by the child or the path
bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
std::stack<Expr> stack;
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visited;
std::unordered_set<const ExprNode*> visited;
stack.push(expr);
while (!stack.empty()) {
Expr current = stack.top();
stack.pop();
for (auto node : expr_graph_.node_map_.at(current)->dominator_children_) {
if (visited.count(node->ref_) == 0) {
if (VisitDFPattern(op->parent, node->ref_)) {
for (auto node : expr_to_node(current)->dominator_children_) {
if (visited.count(node->node_ref_) == 0) {
if (VisitDFPattern(op->parent, node->ref())) {
return true;
} else {
stack.push(node->ref_);
stack.push(node->ref());
}
visited.insert(node->ref_);
visited.insert(node->node_ref_);
}
}
}
Expand Down Expand Up @@ -500,7 +505,8 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr
}

bool MatchPattern(DFPattern pattern, Expr expr) {
return DFPatternMatcher(expr).Match(pattern, expr);
std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(expr);
return DFPatternMatcher(expr_graph.get()).Match(pattern, expr);
}

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern);
Expand Down Expand Up @@ -575,17 +581,18 @@ const std::unordered_map<int, PatternGrouper::Group>& PatternGrouper::GroupMatch

pattern_ = pattern;
pattern_graph_ = CreateIndexedGraph(pattern_);
auto matcher = DFPatternMatcher(pre);
std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(pre);
DFPatternMatcher matcher(expr_graph.get());
matcher_ = &matcher;
this->VisitExprs();
return this->groups_;
}

void PatternGrouper::VisitExprs() {
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> pre_partitioned;
for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) {
size_t index = i - 1;
Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_;
for (PostDfsIndex i = matcher_->size(); i != 0; --i) {
PostDfsIndex index = i - 1;
const auto current = matcher_->index_to_node(index)->ref();
if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped
if (auto op = current.as<FunctionNode>()) {
if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
Expand All @@ -607,22 +614,24 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
auto node_map = matcher_->GetMemo();
// Get fuzzy patterns
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
for (auto node : pattern_graph_.topological_order_) {
for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) {
auto node = pattern_graph_->index_to_node(index);
// Don't treat fuzzy Dominator patterns input variables for partition
if (auto op = node->ref_.as<DominatorPatternNode>()) {
if (auto op = node->ref().as<DominatorPatternNode>()) {
for (auto fuzzy_op : {op->parent, op->path}) {
for (auto match : node_map[fuzzy_op]) {
fuzzy_matches.insert(match);
}
}
}
// Don't treat Function params or body as input variables for partition
if (node->ref_.as<FunctionPatternNode>()) {
auto matches = node_map[node->ref_];
if (node->ref().as<FunctionPatternNode>()) {
auto matches = node_map[node->ref()];
for (auto match : matches) {
auto graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
for (auto node : graph.topological_order_) {
fuzzy_matches.insert(node->ref_);
auto sub_graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
for (PostDfsIndex sub_index = 0; sub_index < sub_graph->size(); ++sub_index) {
auto sub_node = sub_graph->index_to_node(sub_index);
fuzzy_matches.insert(sub_node->ref());
}
}
}
Expand All @@ -636,10 +645,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
Array<Var> params;

for (auto node : pattern_graph_.topological_order_) {
for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) {
auto node = pattern_graph_->index_to_node(index);
auto make_input = [&](const Expr& input) {
if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
input.as<FunctionNode>() == nullptr && !EmbedConst(input, node->ref_)) {
input.as<FunctionNode>() == nullptr && !EmbedConst(input, node->ref())) {
inputs[input] =
Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
NullValue<Type>());
Expand All @@ -648,29 +658,29 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
var_number++;
}
};
auto tuple = node->ref_.as<TuplePatternNode>();
auto call = node->ref_.as<CallPatternNode>();
auto tuple = node->ref().as<TuplePatternNode>();
auto call = node->ref().as<CallPatternNode>();
if (tuple && !tuple->fields.defined()) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
if (node_map.count(node->ref())) {
auto matches = node_map[node->ref()];
for (auto match : matches) {
for (auto input : match.as<TupleNode>()->fields) {
make_input(input);
}
}
}
} else if (call && !call->args.defined()) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
if (node_map.count(node->ref())) {
auto matches = node_map[node->ref()];
for (auto match : matches) {
for (auto input : match.as<CallNode>()->args) {
make_input(input);
}
}
}
} else if (node->inputs_.size() == 0) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
if (node_map.count(node->ref())) {
auto matches = node_map[node->ref()];
for (auto match : matches) {
make_input(match);
}
Expand Down Expand Up @@ -708,13 +718,17 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
return;
} else if (kv.second != body) {
// if the node isn't the output of the group
auto node = matcher_->expr_graph_.node_map_.at(kv.first);
auto node = matcher_->expr_to_node(kv.first);
for (auto* output : node->outputs_) {
// and the node is used by nodes outside of the group
if (memo.count(output->ref_) == 0 &&
!matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
// Exit because nodes in this pattern's body are used outside the pattern
// fusing it would be invalid
if (memo.count(output->ref()) == 0) {
// TODO(mbs): This condition used to also include the following test, which since
// the dominators relation is used back-to-front was always vacuously true. So the
// code is just rejecting the match if a strictly internal node happened to connect
// to an outside node.
ICHECK(!matcher_->expr_to_node(expr)->Dominates(output));
// Exit because nodes in this pattern's body are used outside the pattern, fusing it
// would be invalid
return;
}
}
Expand Down
19 changes: 16 additions & 3 deletions src/relay/ir/dataflow_matcher_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/dataflow_pattern.h>
#include <tvm/relay/dataflow_pattern_functor.h>
#include <tvm/relay/expr_functor.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
Expand All @@ -39,10 +41,20 @@ namespace relay {

class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
public:
explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
explicit DFPatternMatcher(const IndexedGraph<Expr>* expr_graph) : expr_graph_(expr_graph) {}
bool Match(const DFPattern& pattern, const Expr& expr);
Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
const IndexedGraph<Expr> expr_graph_;

const IndexedGraph<Expr>::Node* expr_to_node(const Expr& expr) const {
return expr_graph_->item_to_node(expr);
}
const IndexedGraph<Expr>::Node* index_to_node(size_t index) const {
return expr_graph_->index_to_node(index);
}
size_t size() const { return expr_graph_->size(); }
const std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual>& memo() const {
return memo_;
}

protected:
bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
Expand All @@ -67,6 +79,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);

const IndexedGraph<Expr>* expr_graph_;
std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> memo_;
std::vector<DFPattern> matched_nodes_;
bool memoize_ = true;
Expand Down Expand Up @@ -131,7 +144,7 @@ class PatternGrouper {
std::unordered_map<int, Group> groups_;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
DFPatternMatcher* matcher_ = nullptr;
IndexedGraph<DFPattern> pattern_graph_;
std::unique_ptr<IndexedGraph<DFPattern>> pattern_graph_;
int gid_ = 0;
int graph_number_ = 0;
};
Expand Down
Loading

0 comments on commit d8f57ed

Please sign in to comment.