Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Refactor PatternRewriter into separate Block/Expr mutators #16730

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/tvm/relax/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
* \param f The function to rewrite
* \return The rewritten or the input function, depending on the pattern matching result.
*/
TVM_DLL Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f);
TVM_DLL Function RewriteBindings(
const PatternContext& ctx,
TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter, Function f);

/**
* \brief Rewrite a function with the given pattern and the rewriter function.
Expand Down
238 changes: 137 additions & 101 deletions src/relax/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -973,102 +973,33 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_dfb")
});

/*!
* \brief Apply pattern matching to each call node and dataflow block, and replace matching ones
* \brief Apply pattern matching to each dataflow block, replacing matches
* with the output of a user-provided rewriter function.
*/
class PatternRewriter : ExprMutator {
class BlockPatternRewriter : ExprMutator {
public:
using ExprMutator::VisitBindingBlock_;
using ExprMutator::VisitExpr_;

PatternRewriter(DFPattern pat, PackedFunc rewriter_func,
const std::unordered_set<const VarNode*>& params)
: pattern_(pat), rewriter_func_(rewriter_func), params_(params) {}

PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func,
const std::unordered_set<const VarNode*>& params)
: ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}
BlockPatternRewriter(
const PatternContext& ctx,
TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter_func)
: ctx_(ctx), rewriter_func_(rewriter_func) {}

template <typename PatternType>
static Function Run(PatternType pat, PackedFunc rewriter_func, Function f) {
std::unordered_set<const VarNode*> params;
for (const auto& p : f->params) {
params.insert(p.get());
}
PatternRewriter rewriter(pat, rewriter_func, params);
return Downcast<Function>(RemoveAllUnused(rewriter.VisitExpr(f)));
}

Expr VisitExpr_(const SeqExprNode* seq) override {
if (ctx_) {
return ExprMutator::VisitExpr_(seq);
}

auto cache = bindings_;
SeqExpr prev = GetRef<SeqExpr>(seq);

StructuralEqual struct_equal;

while (true) {
SeqExpr next = Downcast<SeqExpr>(builder_->Normalize(ExprMutator::VisitExpr_(prev.get())));
if (struct_equal(prev, next)) {
return std::move(next);
}

// Canonicalization may result in two previously-different
// expressions being recognized as identical. Elimination of
// common subexpressions may result in trival var-to-var
// bindings that can be canonicalized. Therefore, iterate the
// simplification steps until converged.
while (true) {
auto start_of_loop = next;
next = Downcast<SeqExpr>(CanonicalizeBindings(next));
next = Downcast<SeqExpr>(EliminateCommonSubexpr(next));
next = Downcast<SeqExpr>(RemoveAllUnused(next));
if (struct_equal(start_of_loop, next)) {
break;
}
}

if (struct_equal(prev, next)) {
return std::move(next);
}

// Reset all knowledge of bindings that were collected from
// this DataflowBlock. The collected bindings are only after
// the point where they were collected, and we are repeating
// the mutation of this DataflowBlock.
bindings_ = cache;
prev = next;
}
static Function Run(
PatternType pat,
TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter_func,
Function func) {
BlockPatternRewriter rewriter(pat, rewriter_func);

func = Downcast<Function>(rewriter(func));
func = Downcast<Function>(RemoveAllUnused(func));
return func;
}

BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override {
if (ctx_) {
return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
} else {
return ExprMutator::VisitBindingBlock_(block_node);
}
}

void VisitBinding_(const VarBindingNode* binding) override {
auto expr = VisitExpr(binding->value);
bindings_.Set(binding->var, expr);
ReEmitBinding(binding, expr);
}

Expr VisitExpr(const Expr& expr) override {
auto node = ExprMutator::VisitExpr(expr);

if (pattern_) {
if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), node, bindings_)) {
Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
if (!rewritten_expr.same_as(node)) {
return builder_->Normalize(rewritten_expr);
}
}
}
return node;
return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
}

private:
Expand Down Expand Up @@ -1106,7 +1037,7 @@ class PatternRewriter : ExprMutator {
BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) {
auto df_block = Downcast<DataflowBlock>(block);
Map<Var, Expr> bindings = AnalyzeVar2Value(df_block);
if (auto matches = MatchGraph(ctx_.value(), df_block, bindings)) {
if (auto matches = MatchGraph(ctx_, df_block, bindings)) {
builder_->BeginDataflowBlock();
Map<Var, Expr> replacements = rewriter_func_(matches.value(), bindings);

Expand Down Expand Up @@ -1140,34 +1071,139 @@ class PatternRewriter : ExprMutator {
return block;
}

/*! \brief The pattern for rewriting call nodes */
Optional<DFPattern> pattern_;
/*! \brief The pattern constraint contexts for rewriting dataflow blocks */
Optional<PatternContext> ctx_;
PatternContext ctx_;
/*!
* \brief The user-provided rewriter function. Its signature and semantics are:
* - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the matched
* call node and the map of patterns and matched expressions, it should return a new call node
* to replace the original one or the original matched call node as is.
* - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr> for dataflow block rewriting.
* Given the map of patterns and corresponding variables (bound variables or parameters),
* it should return a map that specifies new values for matched bound variables. It can refer
*
* - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr>
*
* Given the map of patterns and corresponding variables (bound
* variables or parameters), it should return a map that
* specifies new values for matched bound variables. It can refer
* to the passed bindings to create the replacement expressions.
*/
PackedFunc rewriter_func_;
std::unordered_set<const VarNode*> params_;
TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter_func_;
};

/*!
* \brief Apply pattern matching to each expression, replacing
* matches with the output of a user-provided rewriter function.
*/
class ExprPatternRewriter : ExprMutator {
public:
using ExprMutator::VisitBindingBlock_;
using ExprMutator::VisitExpr_;

ExprPatternRewriter(DFPattern pat,
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func)
: pattern_(pat), rewriter_func_(rewriter_func) {}

template <typename PatternType>
static Function Run(PatternType pat,
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func,
Function func) {
ExprPatternRewriter rewriter(pat, rewriter_func);
func = Downcast<Function>(rewriter(func));
func = Downcast<Function>(RemoveAllUnused(func));
return func;
}

Expr VisitExpr_(const SeqExprNode* seq) override {
auto cache = bindings_;
SeqExpr prev = GetRef<SeqExpr>(seq);

StructuralEqual struct_equal;

while (true) {
SeqExpr next = Downcast<SeqExpr>(builder_->Normalize(ExprMutator::VisitExpr_(prev.get())));
if (struct_equal(prev, next)) {
return std::move(next);
}

// Canonicalization may result in two previously-different
// expressions being recognized as identical. Elimination of
// common subexpressions may result in trival var-to-var
// bindings that can be canonicalized. Therefore, iterate the
// simplification steps until converged.
while (true) {
auto start_of_loop = next;
next = Downcast<SeqExpr>(CanonicalizeBindings(next));
next = Downcast<SeqExpr>(EliminateCommonSubexpr(next));
next = Downcast<SeqExpr>(RemoveAllUnused(next));
if (struct_equal(start_of_loop, next)) {
break;
}
}

if (struct_equal(prev, next)) {
return std::move(next);
}

// Reset all knowledge of bindings that were collected from
// this SeqExpr. The collected bindings are only after
// the point where they were collected, and we are repeating
// the mutation of this SeqExpr.
bindings_ = cache;
prev = next;
}
}

void VisitBinding_(const VarBindingNode* binding) override {
auto expr = VisitExpr(binding->value);
bindings_.Set(binding->var, expr);
ReEmitBinding(binding, expr);
}

Expr VisitExpr(const Expr& expr) override {
auto node = ExprMutator::VisitExpr(expr);

if (auto matches_opt = ExtractMatchedExpr(pattern_, node, bindings_)) {
Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
if (!rewritten_expr.same_as(node)) {
return builder_->Normalize(rewritten_expr);
}
}

return node;
}

private:
/*! \brief The pattern for rewriting call nodes */
DFPattern pattern_;
/*!
* \brief The user-provided rewriter function. Its signature and semantics are:
*
* - (Call, Map<DFPattern, Expr>) -> Call
*
* Given the matched call node and the map of patterns and
* matched expressions, it should return a new call node to
* replace the original one or the original matched call node as
* is.
*/
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func_;

/*! \brief The known variable bindings
*
* The variable bindings whose value is known. This must be tracked
* separately from the block builder, so that it can be reset after
* each iteration of the mutate-until-converged loop applied to
* `SeqExpr`.
*/
Map<Var, Expr> bindings_;
};

Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter, Function f) {
return PatternRewriter::Run(ctx, rewriter, f);
Function RewriteBindings(
const PatternContext& ctx,
TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)> rewriter, Function func) {
return BlockPatternRewriter::Run(ctx, rewriter, func);
}

TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings);

Function RewriteCall(const DFPattern& pat,
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter, Function f) {
return PatternRewriter::Run(pat, rewriter, f);
TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter, Function func) {
return ExprPatternRewriter::Run(pat, rewriter, func);
}

TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall);
Expand Down
Loading