From bd0240689e7fd18d71d59229f15a63e61881f17c Mon Sep 17 00:00:00 2001 From: Andrey Malyshev Date: Wed, 8 Jan 2025 15:19:57 +0200 Subject: [PATCH] Disable InferType if it was done and no changes after previous pass This optimizatin allows to speedup PatternRewriter transformations by reusing of preious type inferred expression instead of perform InferType multiple times --- src/relay/ir/dataflow_matcher.cc | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 3e86e1c8eaf9..9d117adbbcaf 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -851,24 +851,32 @@ Expr PatternRewriter::Rewrite(const Array& callbacks, const E std::unordered_map done; do { last = post; + // We don't have to call InferType if previous pass has not modified anything + // We can just take previous typed state of the expression + bool types_invalidated = true; for (auto callback : callbacks) { if (!done[callback]) { auto before = post; + auto post_typed = post; callback_ = callback; - if (callback_->require_type) { - post = InferTypeWithModule(post, mod_); + if (callback_->require_type && types_invalidated) { + post_typed = InferTypeWithModule(post, mod_); } auto grouper = PatternGrouper(); - groups_ = grouper.GroupMatches(callback_->pattern, post); + groups_ = grouper.GroupMatches(callback_->pattern, post_typed); gid_assignments_ = grouper.GetGIDAssignments(); memo_.clear(); VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre); - post = this->VisitExpr(post); + post = this->VisitExpr(post_typed); VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post); count++; - if (callback_->rewrite_once) { - bool current_equal = (*structural_equal)(before, post, false, true); - if (!current_equal) { + bool current_equal = (*structural_equal)(before, post, false, true); + if (callback_->require_type && current_equal) { + types_invalidated = false; + post = post_typed; + } else { + types_invalidated = true; + if (callback_->rewrite_once) { done[callback] = true; } }