From 9f668e6f3602de36d57772ecfe7ceaf0ad716b44 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 28 Jun 2024 14:59:47 -0500 Subject: [PATCH] [Relax][Refactor] Reorganize pattern-matching A follow-up to https://github.com/apache/tvm/pull/16730. Now that the implementations for `rewrite_call` and `rewrite_bindings` are in separate classes, they can be further split out into separate files. --- src/relax/ir/dataflow_block_rewriter.cc | 489 +++++++++++++ src/relax/ir/dataflow_expr_rewriter.cc | 222 ++++++ src/relax/ir/dataflow_matcher.cc | 663 +----------------- ...flow_matcher_impl.h => dataflow_matcher.h} | 3 + 4 files changed, 733 insertions(+), 644 deletions(-) create mode 100644 src/relax/ir/dataflow_block_rewriter.cc create mode 100644 src/relax/ir/dataflow_expr_rewriter.cc rename src/relax/ir/{dataflow_matcher_impl.h => dataflow_matcher.h} (97%) diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc new file mode 100644 index 0000000000000..c5af0493a54c6 --- /dev/null +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_block_rewriter.cc + * \brief A transform to match a Relax DataflowBlock and rewrite + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "dataflow_matcher.h" + +namespace tvm { +namespace relax { + +class MatcherUseDefAnalysis : public relax::ExprVisitor { + public: + std::vector vars; + std::map> def2use; + // caller -> callee table. + std::map> caller2callees; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { + if (nullptr == cur_user_) return; + + auto check_and_push = [](std::vector& vec, const VarNode* var) { + if (std::find(vec.begin(), vec.end(), var) == vec.end()) { + vec.push_back(var); + } + }; + + check_and_push(def2use[op], cur_user_); + check_and_push(vars, op); + + caller2callees[cur_user_].push_back(op); + } +}; + +struct PNode { + const DFPatternNode* ptr; + std::vector&>> children; + std::vector&>> parents; +}; + +struct RNode { + const VarNode* ptr; + std::vector children; + std::vector parents; +}; + +struct MatchState { + void add(const PNode* p, const RNode* r) { + match_p_r[p] = r; + match_r_p[r] = p; + } + + void add(const DFConstraintNode* constraint) { validated_constraints_.insert(constraint); } + + void add(MatchState&& other) { + match_p_r.merge(std::move(other.match_p_r)); + match_r_p.merge(std::move(other.match_r_p)); + validated_constraints_.merge(other.validated_constraints_); + } + + const VarNode* matched(const PNode* p) const { + if (auto it = match_p_r.find(p); it != match_p_r.end()) { + return it->second->ptr; + } + return nullptr; + } + + const DFPatternNode* matched(const RNode* r) const { + if (auto it = match_r_p.find(r); it != match_r_p.end()) { + return it->second->ptr; + } + return nullptr; + } + + const VarNode* matched(const PNode& p) const { return matched(&p); } + const DFPatternNode* matched(const RNode& r) const { return matched(&r); } + + bool is_validated(const DFConstraintNode* constraint) const { + return validated_constraints_.count(constraint); + } + + private: + std::unordered_map match_p_r; + std::unordered_map match_r_p; + std::unordered_set validated_constraints_; +}; + +/** + * \brief This method try to match a real node and a pattern node along with its neighbors. + */ +static std::optional TryMatch(const PNode& p, const RNode& r, + const MatchState& current_match, DFPatternMatcher* m, + const MatcherUseDefAnalysis& ud_analysis) { + if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt; + + MatchState new_match; + + new_match.add(&p, &r); + + // forward matching; + for (const auto& [pchild, constraints] : p.children) { + bool any_cons_sat = false; + for (const auto& rchild : r.children) { + if (new_match.matched(rchild)) { + // The child variable is already matched to other child pattern in a previous iteration. + continue; + } + if (auto v = current_match.matched(pchild); v && v != rchild->ptr) { + // The child pattern is already matched to other variable in a earlier call to TryMatch. + continue; + } + + const auto& uses = ud_analysis.def2use.at(r.ptr); + + // check edge constraints. + bool all_cons_pass = true; + for (const auto& cons : constraints) { + if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { + all_cons_pass = false; + break; + } + + if (cons.index != -1) { + const auto& callees = ud_analysis.caller2callees.at(rchild->ptr); + if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r.ptr) { + all_cons_pass = false; + break; + } + } + } + if (!all_cons_pass || new_match.matched(pchild)) continue; + any_cons_sat = true; + + if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) { + new_match.add(pchild, rchild); + new_match.add(std::move(*match_rec)); + } + } + if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt; + } + + return new_match; +} + +static std::optional TryValidate( + const MatchState& current_match, + const std::unordered_map& pattern2node, + const std::vector& validation_constraints, arith::Analyzer* analyzer) { + MatchState new_match; + + std::function(const DFPatternNode*)> query_match_state = + [&pattern2node, ¤t_match](const DFPatternNode* pattern) -> Optional { + auto it = pattern2node.find(pattern); + ICHECK(it != pattern2node.end()) + << "DFConstraint attempted to access DFPattern " << GetRef(pattern) + << ", which does not appear in the PatternContext"; + const auto& p_node = it->second; + if (auto ptr = current_match.matched(p_node)) { + return GetRef(ptr); + } else { + return NullOpt; + } + }; + + for (const auto& constraint : validation_constraints) { + if (!current_match.is_validated(constraint.get())) { + auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state); + + necessary_condition = analyzer->Simplify(necessary_condition); + const auto* known = tir::as_const_int(necessary_condition); + + if (known && *known && is_sufficient) { + // The condition passes, and the expression provided is both + // necessary and sufficient for the constraint to pass. Mark + // the constraint as passing, to avoid re-checking it unless + // we backtrack. + new_match.add(constraint.get()); + } else if (known && !*known) { + // The condition fails. Even if additional information would + // be required to pass a constraint, it may bail out early as + // a failure (e.g. shape mismatch in the first two items out + // of N shapes that must all match). + return std::nullopt; + } else if (is_sufficient) { + // The condition depends on dynamic parameters. In the + // future, this may be exposed to the user as a condition for + // optimization, or can be combined with the conditions + // provided from other constraints. + return std::nullopt; + } + } + } + + return new_match; +} + +static std::optional MatchTree( + const MatchState& current_match, size_t current_root_idx, + const std::unordered_map& pattern2node, + const std::unordered_map& var2node, DFPatternMatcher* matcher, + const std::vector& roots, const std::vector& validation_constraints, + const MatcherUseDefAnalysis& ud_analysis, arith::Analyzer* analyzer) { + auto get_next_root = [&](size_t root_idx) -> const PNode* { + // Look for the next unmatched root node. + for (; root_idx < roots.size(); ++root_idx) { + const auto& root = pattern2node.at(roots[root_idx].get()); + if (!current_match.matched(root)) { + return &root; + } + } + return nullptr; + }; + + const auto root = get_next_root(current_root_idx); + + if (!root) { + // All root nodes have been matched + return current_match; + } + + MatchState new_match = current_match; + + for (const auto& var : ud_analysis.vars) { + const RNode& r_node = var2node.at(var); + if (new_match.matched(r_node)) continue; + if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) { + // Recursively try to match the next subtree. + new_match.add(std::move(*match)); + if (auto validation = + TryValidate(new_match, pattern2node, validation_constraints, analyzer)) { + new_match.add(std::move(*validation)); + if (auto match_rec = + MatchTree(new_match, current_root_idx + 1, pattern2node, var2node, matcher, roots, + validation_constraints, ud_analysis, analyzer)) { + new_match.add(std::move(*match_rec)); + return new_match; + } + } + // Recursive matching has failed, backtrack. + new_match = current_match; + continue; + } + } + + return std::nullopt; +} + +Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, + const Map& bindings) { + // TODO(@ganler): Handle non-may external use. + ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; + DFPatternMatcher matcher(bindings); + + MatcherUseDefAnalysis ud_analysis; + ud_analysis.VisitBindingBlock_(dfb.get()); + + // First construct a graph of PNode and RNode. + std::unordered_map var2node; + var2node.reserve(dfb->bindings.size()); + + for (const VarNode* cur_var : ud_analysis.vars) { + const auto& uses = ud_analysis.def2use.at(cur_var); + RNode& cur_node = var2node[cur_var]; + cur_node.ptr = cur_var; + for (const VarNode* use : uses) { + auto& use_node = var2node[use]; + use_node.ptr = use; + cur_node.children.push_back(&use_node); + use_node.parents.push_back(&cur_node); + } + } + + std::unordered_map pattern2node; + pattern2node.reserve(ctx->edge_constraints.size()); + + for (const auto& def_pattern : ctx->src_ordered) { + PNode& def_node = pattern2node[def_pattern.get()]; + const auto& uses = ctx->edge_constraints.at(def_pattern); + def_node.ptr = def_pattern.get(); + def_node.children.reserve(uses.size()); + for (const auto& [use_pattern, cons] : uses) { + PNode& use_node = pattern2node[use_pattern.get()]; + use_node.ptr = use_pattern.get(); + use_node.parents.emplace_back(&def_node, std::ref(cons)); + def_node.children.emplace_back(&use_node, std::ref(cons)); + } + } + + std::vector roots; + for (const auto& pat : ctx->src_ordered) { + if (pattern2node[pat.get()].parents.empty()) { + roots.push_back(pat); + } + } + + if (roots.empty()) { + return NullOpt; + } + + arith::Analyzer analyzer; + auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, + ctx->validation_constraints, ud_analysis, &analyzer); + if (!match) { + return NullOpt; + } + + Map ret; + for (const auto& [pat, p_node] : pattern2node) { + ICHECK(match->matched(p_node)); + ret.Set(GetRef(pat), GetRef(match->matched(p_node))); + } + return ret; +} + +Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { + return MatchGraph(ctx, dfb, AnalyzeVar2Value(dfb)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") + .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) { + return MatchGraph(ctx, dfb); + }); + +/*! + * \brief Apply pattern matching to each dataflow block, replacing matches + * with the output of a user-provided rewriter function. + */ +class BlockPatternRewriter : ExprMutator { + public: + using ExprMutator::VisitBindingBlock_; + using ExprMutator::VisitExpr_; + + BlockPatternRewriter( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter_func) + : ctx_(ctx), rewriter_func_(rewriter_func) {} + + template + static Function Run( + PatternType pat, + TypedPackedFunc(Map, Map)> rewriter_func, + Function func) { + BlockPatternRewriter rewriter(pat, rewriter_func); + + func = Downcast(rewriter(func)); + func = Downcast(RemoveAllUnused(func)); + return func; + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override { + return RewriteDataflowBlockFixedPoint(GetRef(block_node)); + } + + private: + void EmitUsedVars(Expr val, const Array& pending_bindings, + std::unordered_set* emitted_vars) { + std::unordered_set unemitted_vars; + PostOrderVisit(val, [=, &unemitted_vars](Expr e) { + if (auto v = e.as(); v && !emitted_vars->count(v)) { + unemitted_vars.insert(v); + } + }); + + if (unemitted_vars.empty()) { + return; + } + + size_t num_unemitted = unemitted_vars.size(); + for (size_t i = 0; i < pending_bindings.size(); ++i) { + const auto& binding = pending_bindings[i]; + if (auto var_bind = binding.as(); + var_bind && unemitted_vars.count(var_bind->var.get())) { + // var_bind->value may also depend on other unemitted vars in this range + Array prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); + EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); + this->VisitBinding(binding); + emitted_vars->insert(var_bind->var.get()); + if (--num_unemitted == 0) { + return; + } + } + } + } + + // Repeat until all matchable subsets of bindings are rewritten. + BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { + auto df_block = Downcast(block); + Map bindings = AnalyzeVar2Value(df_block); + if (auto matches = MatchGraph(ctx_, df_block, bindings)) { + builder_->BeginDataflowBlock(); + Map replacements = rewriter_func_(matches.value(), bindings); + + std::unordered_set emitted_vars; + + bool changed = false; + for (size_t i = 0; i < block->bindings.size(); ++i) { + const auto& binding = block->bindings[i]; + if (auto var_bind = binding.as()) { + if (auto new_val = replacements.Get(var_bind->var).value_or(var_bind->value); + !StructuralEqual()(var_bind->value, new_val)) { + Array pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); + // Make sure there is no unbound variable used in the new value before it is emitted + EmitUsedVars(new_val, pending_bindings, &emitted_vars); + this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); + changed = true; + } else if (!emitted_vars.count(var_bind->var.get())) { + this->VisitBinding(binding); + emitted_vars.insert(var_bind->var.get()); + } + } else { + this->VisitBinding(binding); + } + } + + auto new_block = builder_->EndBlock(); + + if (!changed) return new_block; + return RewriteDataflowBlockFixedPoint(new_block); + } + return block; + } + + /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ + PatternContext ctx_; + /*! + * \brief The user-provided rewriter function. Its signature and semantics are: + * + * - (Map, Map) -> Map + * + * Given the map of patterns and corresponding variables (bound + * variables or parameters), it should return a map that + * specifies new values for matched bound variables. It can refer + * to the passed bindings to create the replacement expressions. + */ + TypedPackedFunc(Map, Map)> rewriter_func_; +}; + +Function RewriteBindings( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter, Function func) { + return BlockPatternRewriter::Run(ctx, rewriter, func); +} + +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc new file mode 100644 index 0000000000000..4793d1d75a300 --- /dev/null +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_expr_rewriter.cc + * \brief A transform to match a Relax Expr and rewrite + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../transform/utils.h" +#include "dataflow_matcher.h" + +namespace tvm { +namespace relax { + +Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, + Optional> bindings_opt) { + auto bindings = bindings_opt.value_or({}); + DFPatternMatcher matcher(bindings); + + if (!matcher.Match(pattern, expr)) { + return NullOpt; + } + + Map matching; + for (const auto& [pat, matches] : matcher.GetMemo()) { + ICHECK_EQ(matches.size(), 1) << "More than one match for the pattern " << pat; + matching.Set(pat, matches[0]); + } + return matching; +} + +TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); + +bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { + return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); + +/*! + * \brief Apply pattern matching to each expression, replacing + * matches with the output of a user-provided rewriter function. + */ +class ExprPatternRewriter : ExprMutator { + public: + using ExprMutator::VisitExpr_; + + ExprPatternRewriter(DFPattern pat, + TypedPackedFunc)> rewriter_func) + : pattern_(pat), rewriter_func_(rewriter_func) {} + + template + static Function Run(PatternType pat, + TypedPackedFunc)> rewriter_func, + Function func) { + ExprPatternRewriter rewriter(pat, rewriter_func); + func = Downcast(rewriter(func)); + func = Downcast(RemoveAllUnused(func)); + return func; + } + + Expr VisitExpr_(const SeqExprNode* seq) override { + auto cache = bindings_; + SeqExpr prev = GetRef(seq); + + StructuralEqual struct_equal; + + while (true) { + SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); + if (struct_equal(prev, next)) { + return std::move(next); + } + + // Canonicalization may result in two previously-different + // expressions being recognized as identical. Elimination of + // common subexpressions may result in trival var-to-var + // bindings that can be canonicalized. Therefore, iterate the + // simplification steps until converged. + while (true) { + auto start_of_loop = next; + next = Downcast(CanonicalizeBindings(next)); + next = Downcast(EliminateCommonSubexpr(next)); + next = Downcast(RemoveAllUnused(next)); + if (struct_equal(start_of_loop, next)) { + break; + } + } + + if (struct_equal(prev, next)) { + return std::move(next); + } + + // Reset all knowledge of bindings that were collected from + // this SeqExpr. The collected bindings are only after + // the point where they were collected, and we are repeating + // the mutation of this SeqExpr. + bindings_ = cache; + prev = next; + } + } + + void VisitBinding_(const VarBindingNode* binding) override { + auto expr = VisitExpr(binding->value); + bindings_.Set(binding->var, expr); + ReEmitBinding(binding, expr); + } + + Expr VisitExpr(const Expr& expr) override { + auto node = ExprMutator::VisitExpr(expr); + + std::vector matches_top_level; + if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) { + return builder_->Normalize(rewritten.value()); + } + + return node; + } + + private: + Optional TryRewrite(const Expr& expr, const DFPattern& pattern, + std::vector* matches_top_level) { + ICHECK(matches_top_level); + + // Special handling if the user-supplied pattern is a `OrPattern`. + // While the `ExtractMatchedExpr` can handle matching the + // `OrPattern`, it will return on the first match, even if the + // `rewriter_func_` doesn't apply a replacement. Unpacking the + // `OrPattern` here allows the match to be resumed if + // `rewriter_func_` returns the original function unmodified. + // This is only valid for a top-level match. + if (auto or_pattern = pattern.as()) { + matches_top_level->push_back(pattern); + Optional output = TryRewrite(expr, or_pattern->left, matches_top_level); + if (!output.defined()) { + output = TryRewrite(expr, or_pattern->right, matches_top_level); + } + matches_top_level->pop_back(); + return output; + } + + if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) { + auto matches = opt_matches.value(); + + // Append any additional matches that from the unwrapped + // `OrPattern`. When matching against `pat = pat_lhs | + // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and + // `pat_rhs` separately. The top-level `pat` is never seen by + // `ExtractMatchedExpr`, and must be re-added afterward. + if (matches_top_level->size()) { + auto matched_expr = DFPatternMatcher::UnwrapBindings(expr, bindings_); + for (const auto& pat : *matches_top_level) { + matches.Set(pat, matched_expr); + } + } + + Expr rewritten_expr = rewriter_func_(expr, matches); + if (!rewritten_expr.same_as(expr)) { + return builder_->Normalize(rewritten_expr); + } + } + + return NullOpt; + } + + /*! \brief The pattern for rewriting call nodes */ + DFPattern pattern_; + /*! + * \brief The user-provided rewriter function. Its signature and semantics are: + * + * - (Call, Map) -> Call + * + * Given the matched call node and the map of patterns and + * matched expressions, it should return a new call node to + * replace the original one or the original matched call node as + * is. + */ + TypedPackedFunc)> rewriter_func_; + + /*! \brief The known variable bindings + * + * The variable bindings whose value is known. This must be tracked + * separately from the block builder, so that it can be reset after + * each iteration of the mutate-until-converged loop applied to + * `SeqExpr`. + */ + Map bindings_; +}; + +Function RewriteCall(const DFPattern& pat, + TypedPackedFunc)> rewriter, Function func) { + return ExprPatternRewriter::Run(pat, rewriter, func); +} + +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index c0b8d1e1df08b..989c1174f41da 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -22,6 +22,8 @@ * \brief The dataflow pattern matcher for Relax. */ +#include "dataflow_matcher.h" + #include #include #include @@ -45,7 +47,6 @@ #include "../../arith/constraint_extract.h" #include "../transform/utils.h" -#include "dataflow_matcher_impl.h" namespace tvm { namespace relax { @@ -59,7 +60,7 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { return VisitDFPattern(pattern, expr); } -static Expr TryGetValOfVar(Expr expr, const Map& var2val) { +Expr DFPatternMatcher::UnwrapBindings(Expr expr, const Map& var2val) { auto unwrap = [&](Expr expr) -> Optional { // Unwrap variables into the value to which they are bound. if (var2val.size()) { @@ -98,7 +99,7 @@ void DFPatternMatcher::ClearMap(size_t watermark) { bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr0) { CHECK(pattern.defined()) << "Null pattern found when matching against " << expr0; - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (memoize_ && memo_.count(pattern)) { ICHECK_EQ(memo_[pattern].size(), 1); return expr.same_as(memo_[pattern][0]); @@ -118,17 +119,17 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr } bool DFPatternMatcher::VisitDFPattern_(const OrPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); } bool DFPatternMatcher::VisitDFPattern_(const AndPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return VisitDFPattern(op->left, expr) && VisitDFPattern(op->right, expr); } bool DFPatternMatcher::VisitDFPattern_(const NotPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return !VisitDFPattern(op->reject, expr); } @@ -183,7 +184,7 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { } bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = VisitDFPattern(attr_pattern->pattern, expr); if (!matches) return matches; VLOG(1) << "considering AttrPatternNode at:\n" << expr; @@ -241,7 +242,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons } bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); // utilities auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { if (op) { @@ -351,12 +352,12 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return StructuralEqual()(op->expr, expr); } bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = false; if (const auto* func = expr.as()) { matches = true; @@ -379,7 +380,7 @@ bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr } bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* tuple_get_item_node = expr.as()) { return (op->index == -1 || op->index == tuple_get_item_node->index) && VisitDFPattern(op->tuple, tuple_get_item_node->tuple); @@ -388,7 +389,7 @@ bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const } bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = false; if (const auto* tuple_node = expr.as()) { matches = true; @@ -429,7 +430,7 @@ bool DFPatternMatcher::TryUnorderedMatch(size_t idx, const tvm::Array } bool DFPatternMatcher::VisitDFPattern_(const UnorderedTuplePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* tuple_node = expr.as()) { if (op->fields.size() == tuple_node->fields.size()) { @@ -449,7 +450,7 @@ bool DFPatternMatcher::VisitDFPattern_(const StructInfoPatternNode* op, const Ex return false; } - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); auto expr_struct_info = GetStructInfo(expr); PrimExpr new_constraint = StructInfoBaseCheckPrecondition(op->struct_info, expr_struct_info); @@ -497,7 +498,7 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { } bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); auto expr_type = expr.as()->checked_type(); return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); } @@ -584,7 +585,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( } bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const ShapeExprNode* shape_expr = expr.as()) return ShapeEqual(&analyzer_, op->fields, shape_expr->values); return false; @@ -609,7 +610,7 @@ bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& exp } bool DFPatternMatcher::VisitDFPattern_(const ExternFuncPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* extern_fn = expr.as()) { return "" == op->global_symbol() || op->global_symbol() == extern_fn->global_symbol; } @@ -618,7 +619,7 @@ bool DFPatternMatcher::VisitDFPattern_(const ExternFuncPatternNode* op, const Ex bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr0) { // constants can be binded to relax.Var as well. - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return expr.as() != nullptr; } @@ -642,631 +643,5 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr return true; } -Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, - Optional> bindings_opt) { - auto bindings = bindings_opt.value_or({}); - DFPatternMatcher matcher(bindings); - - if (!matcher.Match(pattern, expr)) { - return NullOpt; - } - - Map matching; - for (const auto& [pat, matches] : matcher.GetMemo()) { - ICHECK_EQ(matches.size(), 1) << "More than one match for the pattern " << pat; - matching.Set(pat, matches[0]); - } - return matching; -} - -TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); - -bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { - return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); -} - -TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); - -class MatcherUseDefAnalysis : public relax::ExprVisitor { - public: - std::vector vars; - std::map> def2use; - // caller -> callee table. - std::map> caller2callees; - - const VarNode* cur_user_; - - void VisitBinding_(const VarBindingNode* binding) override { - // init - cur_user_ = binding->var.get(); - this->VisitVarDef(binding->var); - this->VisitExpr(binding->value); - cur_user_ = nullptr; - } - - void VisitExpr_(const VarNode* op) override { - if (nullptr == cur_user_) return; - - auto check_and_push = [](std::vector& vec, const VarNode* var) { - if (std::find(vec.begin(), vec.end(), var) == vec.end()) { - vec.push_back(var); - } - }; - - check_and_push(def2use[op], cur_user_); - check_and_push(vars, op); - - caller2callees[cur_user_].push_back(op); - } -}; - -struct PNode { - const DFPatternNode* ptr; - std::vector&>> children; - std::vector&>> parents; -}; - -struct RNode { - const VarNode* ptr; - std::vector children; - std::vector parents; -}; - -struct MatchState { - void add(const PNode* p, const RNode* r) { - match_p_r[p] = r; - match_r_p[r] = p; - } - - void add(const DFConstraintNode* constraint) { validated_constraints_.insert(constraint); } - - void add(MatchState&& other) { - match_p_r.merge(std::move(other.match_p_r)); - match_r_p.merge(std::move(other.match_r_p)); - validated_constraints_.merge(other.validated_constraints_); - } - - const VarNode* matched(const PNode* p) const { - if (auto it = match_p_r.find(p); it != match_p_r.end()) { - return it->second->ptr; - } - return nullptr; - } - - const DFPatternNode* matched(const RNode* r) const { - if (auto it = match_r_p.find(r); it != match_r_p.end()) { - return it->second->ptr; - } - return nullptr; - } - - const VarNode* matched(const PNode& p) const { return matched(&p); } - const DFPatternNode* matched(const RNode& r) const { return matched(&r); } - - bool is_validated(const DFConstraintNode* constraint) const { - return validated_constraints_.count(constraint); - } - - private: - std::unordered_map match_p_r; - std::unordered_map match_r_p; - std::unordered_set validated_constraints_; -}; - -/** - * \brief This method try to match a real node and a pattern node along with its neighbors. - */ -static std::optional TryMatch(const PNode& p, const RNode& r, - const MatchState& current_match, DFPatternMatcher* m, - const MatcherUseDefAnalysis& ud_analysis) { - if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt; - - MatchState new_match; - - new_match.add(&p, &r); - - // forward matching; - for (const auto& [pchild, constraints] : p.children) { - bool any_cons_sat = false; - for (const auto& rchild : r.children) { - if (new_match.matched(rchild)) { - // The child variable is already matched to other child pattern in a previous iteration. - continue; - } - if (auto v = current_match.matched(pchild); v && v != rchild->ptr) { - // The child pattern is already matched to other variable in a earlier call to TryMatch. - continue; - } - - const auto& uses = ud_analysis.def2use.at(r.ptr); - - // check edge constraints. - bool all_cons_pass = true; - for (const auto& cons : constraints) { - if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { - all_cons_pass = false; - break; - } - - if (cons.index != -1) { - const auto& callees = ud_analysis.caller2callees.at(rchild->ptr); - if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r.ptr) { - all_cons_pass = false; - break; - } - } - } - if (!all_cons_pass || new_match.matched(pchild)) continue; - any_cons_sat = true; - - if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) { - new_match.add(pchild, rchild); - new_match.add(std::move(*match_rec)); - } - } - if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt; - } - - return new_match; -} - -static std::optional TryValidate( - const MatchState& current_match, - const std::unordered_map& pattern2node, - const std::vector& validation_constraints, arith::Analyzer* analyzer) { - MatchState new_match; - - std::function(const DFPatternNode*)> query_match_state = - [&pattern2node, ¤t_match](const DFPatternNode* pattern) -> Optional { - auto it = pattern2node.find(pattern); - ICHECK(it != pattern2node.end()) - << "DFConstraint attempted to access DFPattern " << GetRef(pattern) - << ", which does not appear in the PatternContext"; - const auto& p_node = it->second; - if (auto ptr = current_match.matched(p_node)) { - return GetRef(ptr); - } else { - return NullOpt; - } - }; - - for (const auto& constraint : validation_constraints) { - if (!current_match.is_validated(constraint.get())) { - auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state); - - necessary_condition = analyzer->Simplify(necessary_condition); - const auto* known = tir::as_const_int(necessary_condition); - - if (known && *known && is_sufficient) { - // The condition passes, and the expression provided is both - // necessary and sufficient for the constraint to pass. Mark - // the constraint as passing, to avoid re-checking it unless - // we backtrack. - new_match.add(constraint.get()); - } else if (known && !*known) { - // The condition fails. Even if additional information would - // be required to pass a constraint, it may bail out early as - // a failure (e.g. shape mismatch in the first two items out - // of N shapes that must all match). - return std::nullopt; - } else if (is_sufficient) { - // The condition depends on dynamic parameters. In the - // future, this may be exposed to the user as a condition for - // optimization, or can be combined with the conditions - // provided from other constraints. - return std::nullopt; - } - } - } - - return new_match; -} - -static std::optional MatchTree( - const MatchState& current_match, size_t current_root_idx, - const std::unordered_map& pattern2node, - const std::unordered_map& var2node, DFPatternMatcher* matcher, - const std::vector& roots, const std::vector& validation_constraints, - const MatcherUseDefAnalysis& ud_analysis, arith::Analyzer* analyzer) { - auto get_next_root = [&](size_t root_idx) -> const PNode* { - // Look for the next unmatched root node. - for (; root_idx < roots.size(); ++root_idx) { - const auto& root = pattern2node.at(roots[root_idx].get()); - if (!current_match.matched(root)) { - return &root; - } - } - return nullptr; - }; - - const auto root = get_next_root(current_root_idx); - - if (!root) { - // All root nodes have been matched - return current_match; - } - - MatchState new_match = current_match; - - for (const auto& var : ud_analysis.vars) { - const RNode& r_node = var2node.at(var); - if (new_match.matched(r_node)) continue; - if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) { - // Recursively try to match the next subtree. - new_match.add(std::move(*match)); - if (auto validation = - TryValidate(new_match, pattern2node, validation_constraints, analyzer)) { - new_match.add(std::move(*validation)); - if (auto match_rec = - MatchTree(new_match, current_root_idx + 1, pattern2node, var2node, matcher, roots, - validation_constraints, ud_analysis, analyzer)) { - new_match.add(std::move(*match_rec)); - return new_match; - } - } - // Recursive matching has failed, backtrack. - new_match = current_match; - continue; - } - } - - return std::nullopt; -} - -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, - const Map& bindings) { - // TODO(@ganler): Handle non-may external use. - ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; - DFPatternMatcher matcher(bindings); - - MatcherUseDefAnalysis ud_analysis; - ud_analysis.VisitBindingBlock_(dfb.get()); - - // First construct a graph of PNode and RNode. - std::unordered_map var2node; - var2node.reserve(dfb->bindings.size()); - - for (const VarNode* cur_var : ud_analysis.vars) { - const auto& uses = ud_analysis.def2use.at(cur_var); - RNode& cur_node = var2node[cur_var]; - cur_node.ptr = cur_var; - for (const VarNode* use : uses) { - auto& use_node = var2node[use]; - use_node.ptr = use; - cur_node.children.push_back(&use_node); - use_node.parents.push_back(&cur_node); - } - } - - std::unordered_map pattern2node; - pattern2node.reserve(ctx->edge_constraints.size()); - - for (const auto& def_pattern : ctx->src_ordered) { - PNode& def_node = pattern2node[def_pattern.get()]; - const auto& uses = ctx->edge_constraints.at(def_pattern); - def_node.ptr = def_pattern.get(); - def_node.children.reserve(uses.size()); - for (const auto& [use_pattern, cons] : uses) { - PNode& use_node = pattern2node[use_pattern.get()]; - use_node.ptr = use_pattern.get(); - use_node.parents.emplace_back(&def_node, std::ref(cons)); - def_node.children.emplace_back(&use_node, std::ref(cons)); - } - } - - std::vector roots; - for (const auto& pat : ctx->src_ordered) { - if (pattern2node[pat.get()].parents.empty()) { - roots.push_back(pat); - } - } - - if (roots.empty()) { - return NullOpt; - } - - arith::Analyzer analyzer; - auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, - ctx->validation_constraints, ud_analysis, &analyzer); - if (!match) { - return NullOpt; - } - - Map ret; - for (const auto& [pat, p_node] : pattern2node) { - ICHECK(match->matched(p_node)); - ret.Set(GetRef(pat), GetRef(match->matched(p_node))); - } - return ret; -} - -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { - return MatchGraph(ctx, dfb, AnalyzeVar2Value(dfb)); -} - -TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") - .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) { - return MatchGraph(ctx, dfb); - }); - -/*! - * \brief Apply pattern matching to each dataflow block, replacing matches - * with the output of a user-provided rewriter function. - */ -class BlockPatternRewriter : ExprMutator { - public: - using ExprMutator::VisitBindingBlock_; - using ExprMutator::VisitExpr_; - - BlockPatternRewriter( - const PatternContext& ctx, - TypedPackedFunc(Map, Map)> rewriter_func) - : ctx_(ctx), rewriter_func_(rewriter_func) {} - - template - static Function Run( - PatternType pat, - TypedPackedFunc(Map, Map)> rewriter_func, - Function func) { - BlockPatternRewriter rewriter(pat, rewriter_func); - - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } - - BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override { - return RewriteDataflowBlockFixedPoint(GetRef(block_node)); - } - - private: - void EmitUsedVars(Expr val, const Array& pending_bindings, - std::unordered_set* emitted_vars) { - std::unordered_set unemitted_vars; - PostOrderVisit(val, [=, &unemitted_vars](Expr e) { - if (auto v = e.as(); v && !emitted_vars->count(v)) { - unemitted_vars.insert(v); - } - }); - - if (unemitted_vars.empty()) { - return; - } - - size_t num_unemitted = unemitted_vars.size(); - for (size_t i = 0; i < pending_bindings.size(); ++i) { - const auto& binding = pending_bindings[i]; - if (auto var_bind = binding.as(); - var_bind && unemitted_vars.count(var_bind->var.get())) { - // var_bind->value may also depend on other unemitted vars in this range - Array prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); - EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); - this->VisitBinding(binding); - emitted_vars->insert(var_bind->var.get()); - if (--num_unemitted == 0) { - return; - } - } - } - } - - // Repeat until all matchable subsets of bindings are rewritten. - BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { - auto df_block = Downcast(block); - Map bindings = AnalyzeVar2Value(df_block); - if (auto matches = MatchGraph(ctx_, df_block, bindings)) { - builder_->BeginDataflowBlock(); - Map replacements = rewriter_func_(matches.value(), bindings); - - std::unordered_set emitted_vars; - - bool changed = false; - for (size_t i = 0; i < block->bindings.size(); ++i) { - const auto& binding = block->bindings[i]; - if (auto var_bind = binding.as()) { - if (auto new_val = replacements.Get(var_bind->var).value_or(var_bind->value); - !StructuralEqual()(var_bind->value, new_val)) { - Array pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); - // Make sure there is no unbound variable used in the new value before it is emitted - EmitUsedVars(new_val, pending_bindings, &emitted_vars); - this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); - changed = true; - } else if (!emitted_vars.count(var_bind->var.get())) { - this->VisitBinding(binding); - emitted_vars.insert(var_bind->var.get()); - } - } else { - this->VisitBinding(binding); - } - } - - auto new_block = builder_->EndBlock(); - - if (!changed) return new_block; - return RewriteDataflowBlockFixedPoint(new_block); - } - return block; - } - - /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ - PatternContext ctx_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Map, Map) -> Map - * - * Given the map of patterns and corresponding variables (bound - * variables or parameters), it should return a map that - * specifies new values for matched bound variables. It can refer - * to the passed bindings to create the replacement expressions. - */ - TypedPackedFunc(Map, Map)> rewriter_func_; -}; - -/*! - * \brief Apply pattern matching to each expression, replacing - * matches with the output of a user-provided rewriter function. - */ -class ExprPatternRewriter : ExprMutator { - public: - using ExprMutator::VisitBindingBlock_; - using ExprMutator::VisitExpr_; - - ExprPatternRewriter(DFPattern pat, - TypedPackedFunc)> rewriter_func) - : pattern_(pat), rewriter_func_(rewriter_func) {} - - template - static Function Run(PatternType pat, - TypedPackedFunc)> rewriter_func, - Function func) { - ExprPatternRewriter rewriter(pat, rewriter_func); - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } - - Expr VisitExpr_(const SeqExprNode* seq) override { - auto cache = bindings_; - SeqExpr prev = GetRef(seq); - - StructuralEqual struct_equal; - - while (true) { - SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Canonicalization may result in two previously-different - // expressions being recognized as identical. Elimination of - // common subexpressions may result in trival var-to-var - // bindings that can be canonicalized. Therefore, iterate the - // simplification steps until converged. - while (true) { - auto start_of_loop = next; - next = Downcast(CanonicalizeBindings(next)); - next = Downcast(EliminateCommonSubexpr(next)); - next = Downcast(RemoveAllUnused(next)); - if (struct_equal(start_of_loop, next)) { - break; - } - } - - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Reset all knowledge of bindings that were collected from - // this SeqExpr. The collected bindings are only after - // the point where they were collected, and we are repeating - // the mutation of this SeqExpr. - bindings_ = cache; - prev = next; - } - } - - void VisitBinding_(const VarBindingNode* binding) override { - auto expr = VisitExpr(binding->value); - bindings_.Set(binding->var, expr); - ReEmitBinding(binding, expr); - } - - Expr VisitExpr(const Expr& expr) override { - auto node = ExprMutator::VisitExpr(expr); - - std::vector matches_top_level; - if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) { - return builder_->Normalize(rewritten.value()); - } - - return node; - } - - private: - Optional TryRewrite(const Expr& expr, const DFPattern& pattern, - std::vector* matches_top_level) { - ICHECK(matches_top_level); - - // Special handling if the user-supplied pattern is a `OrPattern`. - // While the `ExtractMatchedExpr` can handle matching the - // `OrPattern`, it will return on the first match, even if the - // `rewriter_func_` doesn't apply a replacement. Unpacking the - // `OrPattern` here allows the match to be resumed if - // `rewriter_func_` returns the original function unmodified. - // This is only valid for a top-level match. - if (auto or_pattern = pattern.as()) { - matches_top_level->push_back(pattern); - Optional output = TryRewrite(expr, or_pattern->left, matches_top_level); - if (!output.defined()) { - output = TryRewrite(expr, or_pattern->right, matches_top_level); - } - matches_top_level->pop_back(); - return output; - } - - if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) { - auto matches = opt_matches.value(); - - // Append any additional matches that from the unwrapped - // `OrPattern`. When matching against `pat = pat_lhs | - // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and - // `pat_rhs` separately. The top-level `pat` is never seen by - // `ExtractMatchedExpr`, and must be re-added afterward. - if (matches_top_level->size()) { - auto matched_expr = TryGetValOfVar(expr, bindings_); - for (const auto& pat : *matches_top_level) { - matches.Set(pat, matched_expr); - } - } - - Expr rewritten_expr = rewriter_func_(expr, matches); - if (!rewritten_expr.same_as(expr)) { - return builder_->Normalize(rewritten_expr); - } - } - - return NullOpt; - } - - /*! \brief The pattern for rewriting call nodes */ - DFPattern pattern_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Call, Map) -> Call - * - * Given the matched call node and the map of patterns and - * matched expressions, it should return a new call node to - * replace the original one or the original matched call node as - * is. - */ - TypedPackedFunc)> rewriter_func_; - - /*! \brief The known variable bindings - * - * The variable bindings whose value is known. This must be tracked - * separately from the block builder, so that it can be reset after - * each iteration of the mutate-until-converged loop applied to - * `SeqExpr`. - */ - Map bindings_; -}; - -Function RewriteBindings( - const PatternContext& ctx, - TypedPackedFunc(Map, Map)> rewriter, Function func) { - return BlockPatternRewriter::Run(ctx, rewriter, func); -} - -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); - -Function RewriteCall(const DFPattern& pat, - TypedPackedFunc)> rewriter, Function func) { - return ExprPatternRewriter::Run(pat, rewriter, func); -} - -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); - } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_matcher_impl.h b/src/relax/ir/dataflow_matcher.h similarity index 97% rename from src/relax/ir/dataflow_matcher_impl.h rename to src/relax/ir/dataflow_matcher.h index a0c35ac0deada..9036c7630a548 100644 --- a/src/relax/ir/dataflow_matcher_impl.h +++ b/src/relax/ir/dataflow_matcher.h @@ -45,6 +45,9 @@ class DFPatternMatcher : public DFPatternFunctor> GetMemo() { return Map>(memo_); } + /* \brief Unwrap trivial expressions/bindings */ + static Expr UnwrapBindings(Expr expr, const Map& bindings); + protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; bool VisitDFPattern_(const OrPatternNode* op, const Expr& expr) override;