From 7423474ccb2939cad5ddde8ccc841e9288455149 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 5 Jul 2024 13:41:22 -0500 Subject: [PATCH 01/17] [TVMScript][Bugfix] Normalize relax::If with function's TIR var Prior to this commit, the branches of `relax::If` were normalized using `EraseToWellDefinedInScope`, using a fresh variable scope. While this had the intended behavior of preventing variables defined in a single branch from being usable outside of the conditional, it also caused the conditional's branches to treat function-scope symbolic variables as if they were undefined. This commit updates the `tvm::relax::Normalizer` so that `relax::If` is normalized within an inherited scope. This preserves the previous behavior for symbolic variables defined within a branch, but allows shapes within a branch to use symbolic variables defined outside of the branch. --- include/tvm/relax/block_builder.h | 35 ++++++++- include/tvm/relax/expr_functor.h | 21 ++++- include/tvm/script/ir_builder/relax/frame.h | 1 + src/relax/ir/block_builder.cc | 85 ++++++++++++++------- src/script/ir_builder/relax/frame.cc | 7 +- src/script/ir_builder/relax/ir.cc | 10 +-- tests/python/relax/test_tvmscript_parser.py | 46 +++++++++++ 7 files changed, 163 insertions(+), 42 deletions(-) diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 7ca9aab6d5aa..17347fcfd84d 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -133,16 +133,47 @@ class BlockBuilderNode : public Object { * \brief Begin a new scope, with optional parameters that * are visible within the scope. * + * Symbolic variables from the parent scope are not available. + * * \param params Parameters that are visible within the scope. * * \note This function should be called when new scope is introduced - * (function, seq) to properly track the variable availability - * and help the best effort deduction. + * (e.g. 
function bodies) to properly track the variable + * availability and help the best effort deduction. * * \sa EndScope */ virtual void BeginScope(Optional> params) = 0; + /*! + * \brief Begin a new scope, which inherits visible parameters from + * its parent scope. + * + * Symbolic variables from the parent scope are available. + * + * \note This function should be called when an inner scope is + * introduced (e.g. conditional branches) to properly track + * the variable availability and help the best effort + * deduction. + * + * \sa EndScope + */ + virtual void BeginInnerScope() = 0; + + /*! + * \brief Append a definition to the cuurrent scope. + * + * \param Var A variable within the current scope. + * + * \note This function should be called when a new variable is + * defined that may impact struct inference (e.g. MatchCast) + * to properly track the variable availability and help the + * best effort deduction. + * + * \sa EndScope + */ + virtual void AddDefinitionToScope(Var var) = 0; + /*! \brief End the previously defined scope. */ virtual void EndScope() = 0; diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index ce209ccd460f..c3aea24dcb50 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -494,7 +494,10 @@ class ExprMutator : public ExprMutatorBase { void ReEmitBinding(const VarBindingNode* binding, Expr new_value); /*! - * \brief Rewrite the expr with a new scope, used in a Function's body and the branches of If. + * \brief Rewrite the expr with a new scope, used in a Function's body. + * + * Visit an expression that may neither access variables from the + * current scope, nor may export definitions into the current scope. * * \param body_expr The body to be visited. * \param params Optional parameters that are visible within the scope. @@ -504,6 +507,22 @@ class ExprMutator : public ExprMutatorBase { */ Expr VisitWithNewScope(const Expr& body_expr, Optional> params = NullOpt); + /*! 
+ * \brief Rewrite the expr with a new scope, used in the branches of If. + * + * Visit an expression that may access variables from the current + * scope, but may not export definitions into the current scope. + * + * \param body_expr The body to be visited. + * + * \return The expr after visiting. + * + * \sa VisitWithNewScope + * + * \note The body_expr must be an SeqExpr in the normal form. + */ + Expr VisitWithInnerScope(const Expr& body_expr); + /*! * \brief Look up the value bound to a variable. * \param var The var to be looked up. diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 1ad681388912..0ee144f03e77 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -122,6 +122,7 @@ class FunctionFrameNode : public SeqExprFrameNode { TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode); public: + void EnterWithScope() final; void ExitWithScope() final; }; diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index e9a513c317d6..9d8cea6352fa 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -179,29 +179,51 @@ class BlockBuilderImpl : public BlockBuilderNode { // but can be further improved. // // TODO(relax-team): Add support for relax Var in struct info annotations. - Map shape_var_map; - for (const Var& var : params.value_or(Array())) { - const Map& var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); - for (const auto& kv : var_map) { - const tir::Var& shape_var = kv.first; - const PrimExpr& shape_expr = kv.second; - auto it = shape_var_map.find(shape_var); - if (it == shape_var_map.end()) { - shape_var_map.Set(shape_var, shape_expr); - // Expose the shape variable as non-negative, for purposes - // of shape inference. In many cases, knowning that the - // shape variable is non-negative allows for simpler - // expressions for dynamic shapes. 
- analyzer_.MarkGlobalNonNegValue(shape_var); - } else { - const PrimExpr& old_shape_expr = (*it).second; - CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr)) - << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " - << shape_expr; - } + + scope_stack_.emplace_back(ScopeFrame()); + if (params.defined()) { + for (const auto& param : params.value()) { + AddDefinitionToScope(param); + } + } + } + + void BeginInnerScope() final { + if (scope_stack_.size()) { + scope_stack_.emplace_back(scope_stack_.back()); + } else { + scope_stack_.emplace_back(ScopeFrame()); + } + } + + void AddDefinitionToScope(Var var) final { + ICHECK(scope_stack_.size()) << "Cannot add definition of " << var << " to current scope, " + << "because there is no current scope."; + auto& shape_var_map = CurrentScopeFrame()->shape_var_map; + + // The current implementation handles the collection of shape var + // defined in parameter struct info annotations. The implementation + // is correct (since we will simply erase all relax Vars in EraseToWellDefined), + // but can be further improved. + Map var_map = StructInfoVarCollector::Collect(GetStructInfo(var)); + for (const auto& kv : var_map) { + const tir::Var& shape_var = kv.first; + const PrimExpr& shape_expr = kv.second; + auto it = shape_var_map.find(shape_var); + if (it == shape_var_map.end()) { + shape_var_map.Set(shape_var, shape_expr); + // Expose the shape variable as non-negative, for purposes + // of shape inference. In many cases, knowning that the + // shape variable is non-negative allows for simpler + // expressions for dynamic shapes. 
+ analyzer_.MarkGlobalNonNegValue(shape_var); + } else { + const PrimExpr& old_shape_expr = (*it).second; + CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr)) + << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " + << shape_expr; } } - scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)})); } void EndScope() final { scope_stack_.pop_back(); } @@ -844,15 +866,18 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor> params = NullOpt) { + if (params.defined()) { + this->BeginScope(params.value()); + } else { + this->BeginInnerScope(); + } + + Expr ret; + // SeqExpr do not need to prepare for normalization. if (expr.as()) { - this->BeginScope(params); - Expr ret = this->VisitExpr(expr); - this->EndScope(); - return ret; + ret = this->VisitExpr(expr); } else { - this->BeginScope(params); - this->BeginBindingBlock(); Expr post = this->NormalizeArgument(expr); BindingBlock prologue = this->EndBlock(); @@ -869,9 +894,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctorbody))); - this->EndScope(); - return seq; + ret = seq; } + + this->EndScope(); + return ret; } Array FlattenBlocks(const Array& blocks) { diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index 792331dda4c0..3153c0770e38 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -46,6 +46,11 @@ void SeqExprFrameNode::EnterWithScope() { BindingBlock()->EnterWithScope(); } +void FunctionFrameNode::EnterWithScope() { + this->block_builder->BeginScope(params); + SeqExprFrameNode::EnterWithScope(); +} + void FunctionFrameNode::ExitWithScope() { using ir::IRModuleFrame; using tvm::relax::Expr; @@ -54,7 +59,7 @@ void FunctionFrameNode::ExitWithScope() { // Step 1: Create the function. CHECK(output.defined()) << "ValueError: A Relax function must have a return value. 
Please use " "`return` to return an Expr"; - this->block_builder->BeginScope(params); + Expr body = this->block_builder->Normalize(tvm::relax::SeqExpr(binding_blocks, output.value())); // if the function is not private, add a global symbol to its attributes if (!is_private.value_or(Bool(false))->value && name.defined() && diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc index 2e94ae420a97..453c7fdb5522 100644 --- a/src/script/ir_builder/relax/ir.cc +++ b/src/script/ir_builder/relax/ir.cc @@ -70,15 +70,7 @@ tvm::relax::Var Arg(const String& name, const tvm::relax::StructInfo& struct_inf FunctionFrame frame = FindFunctionFrame("R.Arg"); tvm::relax::Var var(name, struct_info); frame->params.push_back(var); - - // This constraint would normally be provided as part of - // `BlockBuilder::BeginScope`. However, because the frame and its - // scope are initialized before the arguments are known, the scope - // doesn't have access to these constraints. - auto* analyzer = frame->block_builder->GetAnalyzer(); - for (const auto& tir_var : DefinableTIRVarsInStructInfo(struct_info)) { - analyzer->MarkGlobalNonNegValue(tir_var); - } + frame->block_builder->AddDefinitionToScope(var); return var; } diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 64014d1c49be..4f41b662caf2 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -2317,5 +2317,51 @@ def expected(A: R.Tensor(["extent"])) -> R.Tensor(["extent-1"]): tvm.ir.assert_structural_equal(inferred_sinfo, expected) +def test_conditional_may_use_symbolic_variables_from_function_scope(): + """Symbolic variables from function scope may be used in branch + + This is a regression test. In earlier implementations, the + branches of `relax::If` were normalized with + `EraseToWellDefinedInScope`, using a fresh variable scope. 
While + this had the intended behavior of preventing variables defined in + a single branch from being usable outside of the conditional, it + also caused the conditional's branches to treat function-scope + symbolic variables as if they were undefined. + + """ + + @R.function(private=True) + def explicit_sinfo( + A: R.Tensor(["N"], "float32"), + B: R.Tensor(["N"], "float32"), + cond: R.Prim("bool"), + ) -> R.Tensor(["N"], "float32"): + + N = T.int64() + + if cond: + out: R.Tensor([N], "float32") = A + B + else: + out: R.Tensor([N], "float32") = A * B + + return out + + @R.function(private=True) + def inferred_sinfo( + A: R.Tensor(["N"], "float32"), + B: R.Tensor(["N"], "float32"), + cond: R.Prim("bool"), + ): + N = T.int64() + if cond: + out = A + B + else: + out = A * B + + return out + + tvm.ir.assert_structural_equal(explicit_sinfo, inferred_sinfo) + + if __name__ == "__main__": tvm.testing.main() From 23a5959bb813e5e2e1725704a2904410347331ee Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 5 Jul 2024 14:54:34 -0500 Subject: [PATCH 02/17] [Relax] Canonicalize known symbolic shapes in Relax expressions Prior to this commit, known constants in Relax functions would be inlined by the `CanonicalizeBindings` pass, but only if they appeared as Relax expressions (e.g. `R.const` or `R.prim_value`). Known constants that appeared as TIR variables (e.g. symbolic shapes) would be kept as dynamic parameters, even if they were known at compile time. This commit updates the `CanonicalizeBindings` pass to identify known values of symbolic shapes, and to use these known values in shape expressions. 
--- src/relax/ir/block_builder.cc | 10 +- src/relax/ir/expr_functor.cc | 53 +++- src/relax/transform/canonicalize_bindings.cc | 142 +++++++++- src/relax/transform/utils.h | 2 +- .../test_transform_canonicalize_bindings.py | 261 +++++++++++++++++- 5 files changed, 442 insertions(+), 26 deletions(-) diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 9d8cea6352fa..95d16c1abadf 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -197,8 +197,10 @@ class BlockBuilderImpl : public BlockBuilderNode { } void AddDefinitionToScope(Var var) final { - ICHECK(scope_stack_.size()) << "Cannot add definition of " << var << " to current scope, " - << "because there is no current scope."; + if (scope_stack_.empty()) { + return; + } + auto& shape_var_map = CurrentScopeFrame()->shape_var_map; // The current implementation handles the collection of shape var @@ -854,7 +856,9 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor Optional { diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index 63c74db7e33e..c2320de62a75 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -606,8 +606,8 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op) { Expr ExprMutator::VisitExpr_(const IfNode* op) { Expr guard = this->VisitExpr(op->cond); - Expr true_b = this->VisitWithNewScope(op->true_branch); - Expr false_b = this->VisitWithNewScope(op->false_branch); + Expr true_b = this->VisitWithInnerScope(op->true_branch); + Expr false_b = this->VisitWithInnerScope(op->false_branch); if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && op->false_branch.same_as(false_b) && VisitAndCheckStructInfoFieldUnchanged(op->struct_info_)) { @@ -696,20 +696,24 @@ void ExprMutator::VisitBinding_(const MatchCastNode* binding) { Var new_var = this->VisitVarDef(binding->var); - if (new_var.same_as(binding->var) && new_value.same_as(binding->value) && - 
new_struct_info.same_as(binding->struct_info)) { - // re-emit old binding if nothing changes - builder_->EmitNormalized(GetRef(binding)); - return; - } + MatchCast new_binding = [&]() -> MatchCast { + if (new_var.same_as(binding->var) && new_value.same_as(binding->value) && + new_struct_info.same_as(binding->struct_info)) { + // re-emit old binding if nothing changes + return GetRef(binding); + } else { + new_value = builder_->NormalizeArgument(new_value); + new_var = WithStructInfo(new_var, new_struct_info); - new_value = builder_->NormalizeArgument(new_value); - new_var = WithStructInfo(new_var, new_struct_info); + var_remap_[binding->var->vid] = new_var; + var_remap_[new_var->vid] = new_var; - var_remap_[binding->var->vid] = new_var; - var_remap_[new_var->vid] = new_var; + return MatchCast(new_var, new_value, new_struct_info, binding->span); + } + }(); - builder_->EmitNormalized(MatchCast(new_var, new_value, new_struct_info, binding->span)); + builder_->EmitNormalized(new_binding); + builder_->AddDefinitionToScope(new_binding->var); } BindingBlock ExprMutator::VisitBindingBlock_(const BindingBlockNode* block) { @@ -800,7 +804,30 @@ Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional> param } builder_->BeginScope(params); + // Outer scope only includes TIR variables that can be inferred from + // the function parameters. With context(builder_->GetAnalyzer(), constraint); + builder_->BeginInnerScope(); + // Inner scope also includes any TIR variables that are defined by + // MatchCast nodes, and are internal to the scope. + Expr ret = ExprFunctor::VisitExpr(expr); + builder_->EndScope(); + + // Normalization (and the resulting StructInfo inference) of the + // expr occurs outside of the body's parameters, but inside the + // function signature's scope. This keeps variables that are + // inferable based on the function signature, to allow callers to + // propagate StructInfo across the function. 
+ ret = builder_->Normalize(ret); + builder_->EndScope(); + return ret; +} + +Expr ExprMutator::VisitWithInnerScope(const Expr& expr) { + ICHECK(expr->IsInstance()) + << "Normal form requires all new scope is stored as SeqExpr"; + + builder_->BeginInnerScope(); Expr ret = this->VisitExpr(expr); builder_->EndScope(); return ret; diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 12eb81ac675d..d1a9f97337de 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -29,12 +29,119 @@ #include #include #include +#include namespace tvm { namespace relax { namespace { +class SymbolicVarCanonicalizer : public ExprMutator { + public: + Expr VisitExpr_(const FunctionNode* func) override { + auto cached = known_values_; + auto output = ExprMutator::VisitExpr_(func); + known_values_ = cached; + return output; + } + + void VisitBinding_(const MatchCastNode* binding) override { + auto tir_var_map = + InferSymbolicVarMap({{binding->var, binding->value}}, builder_->GetAnalyzer()); + for (const auto& [tir_var, prim_expr] : tir_var_map) { + if (auto it = known_values_.find(tir_var); it != known_values_.end()) { + CHECK(!builder_->GetAnalyzer()->CanProve(it->second.expr != prim_expr)) + << "ValueError: " + << "MatchCast statements must be consistent. 
" + << "However, the definition of Relax variable " << it->second.source->var + << " implies that TIR variable " << tir_var << " is " << it->second.expr + << ", while the later definition of Relax variable " << binding->var + << " instead implies that TIR variable " << tir_var << " is " << prim_expr; + } else { + known_values_[tir_var] = KnownValue{prim_expr, GetRef(binding)}; + } + } + ExprMutator::VisitBinding_(binding); + } + + Expr VisitExpr_(const IfNode* op) override { + Expr guard = this->VisitExpr(op->cond); + + auto cached = known_values_; + Expr true_b = this->VisitWithInnerScope(op->true_branch); + known_values_ = cached; + Expr false_b = this->VisitWithInnerScope(op->false_branch); + known_values_ = cached; + + if (op->cond.same_as(guard) && op->true_branch.same_as(true_b) && + op->false_branch.same_as(false_b)) { + return GetRef(op); + } + + // The two branches may have had different TIR variables inlined. + // For example, one branch has a dynamic implementation and + // produces `R.Tensor([M,N])`, while the other branch checks if + // `N==16` and produces `R.Tensor([M,16])`. After the branch, the + // output is `R.Tensor([M,N])`. However, the `GetStructLCA` would + // correctly return `R.Tensor(ndim=2)`, removing all shape + // information. + // + // Since we know the StructInfo prior to replacing TIR variables, + // this pass can provide a better StructInfo than the generic + // handling in ExprMutator, by restoring the symbolic variables + // within each branch. 
+ auto new_sinfo = VisitExprDepStructInfoField(Downcast(op->struct_info_)); + + StructuralEqual struct_equal; + if (!struct_equal(new_sinfo, GetStructInfo(true_b))) { + auto output_var = Var("then_branch_with_dyn", new_sinfo); + + true_b = SeqExpr({BindingBlock({ + MatchCast(output_var, true_b, new_sinfo), + })}, + output_var); + } + + if (!struct_equal(new_sinfo, GetStructInfo(false_b))) { + auto output_var = Var("else_branch_with_dyn", new_sinfo); + + false_b = SeqExpr({BindingBlock({ + MatchCast(output_var, false_b, new_sinfo), + })}, + output_var); + } + + return If(guard, true_b, false_b, op->span); + } + + PrimExpr VisitPrimExpr(const PrimExpr& expr) override { + if (known_values_.empty()) { + return expr; + } + PrimExpr output = tir::Substitute(expr, [this](const tir::Var& var) -> Optional { + if (auto it = known_values_.find(var); it != known_values_.end()) { + return it->second.expr; + } else { + return NullOpt; + } + }); + if (output.same_as(expr)) { + return expr; + } + + output = builder_->GetAnalyzer()->Simplify(output); + return output; + } + + private: + struct KnownValue { + PrimExpr expr; + MatchCast source; + }; + + std::unordered_map known_values_; +}; + struct CanonicalizationPlan { Map replace_usage; Map replace_binding; @@ -377,16 +484,39 @@ class BindingCanonicalizer : public ExprMutator { }; } // namespace -Expr CanonicalizeBindings(const Expr& expr) { return BindingCanonicalizer::Apply(expr); } +Expr CanonicalizeTIRVariables(Expr expr) { return SymbolicVarCanonicalizer()(std::move(expr)); } + +Expr CanonicalizeRelaxBindings(Expr expr) { return BindingCanonicalizer::Apply(std::move(expr)); } + +Expr CanonicalizeBindings(Expr expr) { + expr = CanonicalizeTIRVariables(std::move(expr)); + expr = CanonicalizeRelaxBindings(std::move(expr)); + return expr; +} namespace transform { +Pass CanonicalizeTIRVariables() { + auto pass_func = [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeTIRVariables(f)); + }; + return 
CreateFunctionPass(pass_func, 1, "CanonicalizeTIRVariables", {}); +} + +Pass CanonicalizeRelaxBindings() { + auto pass_func = [=](Function f, IRModule m, PassContext pc) { + return Downcast(CanonicalizeBindings(f)); + }; + return CreateFunctionPass(pass_func, 1, "CanonicalizeRelaxBindings", {}); +} + Pass CanonicalizeBindings() { - runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(CanonicalizeBindings(f)); - }; - return CreateFunctionPass(pass_func, 1, "CanonicalizeBindings", {}); + return tvm::transform::Sequential( + { + CanonicalizeTIRVariables(), + CanonicalizeRelaxBindings(), + }, + "CanonicalizeBindings"); } TVM_REGISTER_GLOBAL("relax.transform.CanonicalizeBindings").set_body_typed(CanonicalizeBindings); diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h index 5755e118541f..932dca30a110 100644 --- a/src/relax/transform/utils.h +++ b/src/relax/transform/utils.h @@ -420,7 +420,7 @@ Expr EliminateCommonSubexpr(const Expr& expr, bool call_only = false); * * \ret The canonicalized expression */ -Expr CanonicalizeBindings(const Expr& expr); +Expr CanonicalizeBindings(Expr expr); /* \brief Remove use of trivial bindings * diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index d513c0cf6c6d..3255889960d4 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -198,9 +198,13 @@ def test_change_shape(): @I.ir_module class TestChangeShape: @R.function - def main(x: R.Tensor(("m", "n"))): + def main(x: R.Tensor(ndim=2)): y = x - # not trivial: introduces new shape vars + # The MatchCast is non-trivial, as it introduces new shape + # vars. Because the input tensor has an unknown shape + # rather than a symbolic shape, these new shape vars + # cannot be expressed in terms of previous variables. 
+ # Therefore, the match cast must be retained. o, p = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((o, p))) w = z @@ -210,7 +214,7 @@ def main(x: R.Tensor(("m", "n"))): @I.ir_module class Expected: @R.function - def main(x: R.Tensor(("m", "n"))): + def main(x: R.Tensor(ndim=2)): o, p = T.int64(), T.int64() z = R.match_cast(x, R.Tensor((o, p))) # the struct_info field on q will need to be updated @@ -220,6 +224,35 @@ def main(x: R.Tensor(("m", "n"))): verify(TestChangeShape, Expected) +def test_replace_symbolic_variable_and_remove_match_cast(): + @I.ir_module + class TestChangeShape: + @R.function + def main(x: R.Tensor(("m", "n"))): + y = x + # The MatchCast is non-trivial, as it introduces new shape + # vars. However, the new shape vars are redundant, and + # are replaced by canonicalization. After replacing the + # new shape vars, the MatchCast is trivial and may be + # removed. + o, p = T.int64(), T.int64() + z = R.match_cast(x, R.Tensor((o, p))) + w = z + q = R.add(w, y) + return R.add(q, w) + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor(("m", "n"))): + m = T.int64() + n = T.int64() + q: R.Tensor([m, n]) = R.add(x, x) + return R.add(q, x) + + verify(TestChangeShape, Expected) + + def test_unwrap_tuple(): @I.ir_module class Before: @@ -289,6 +322,228 @@ def main() -> R.Tensor((), "int32"): verify(Input, Expected) +def test_fold_variables_from_match_cast(): + """Symbolic variables in R.match_cast may be inferred + + If the argument to `R.match_cast` has known shape parameters, they + may be used to infer symbolic shape parameters. + + """ + + @I.ir_module + class Before: + @R.function + def main( + state: R.Tensor([16], dtype="float32"), + A: R.Tensor([16, 16], dtype="float32"), + B: R.Tensor([16, 16], dtype="float32"), + ): + N1 = T.int64() + M = T.int64() + N2 = T.int64() + + # The symbolic variables `N1`, `N2` and `M` are defined by + # these `R.match_cast` statements. 
Since the inputs have + # a known shape, the values of these symbolic variables + # may be inferred. + lhs_A = R.match_cast(A, R.Tensor([N1, M], dtype="float32")) + lhs_B = R.match_cast(B, R.Tensor([N2, M], dtype="float32")) + rhs = R.match_cast(state, R.Tensor([M], dtype="float32")) + + # The symbolic shapes propagate downstream. + lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0) + proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul( + lhs, rhs, out_dtype="void" + ) + proj_A = R.strided_slice( + proj_concat, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(N1),), + assume_inbound=False, + ) + proj_B = R.strided_slice( + proj_concat, + [R.prim_value(0)], + [R.prim_value(N1)], + [R.prim_value(N1 + N2)], + assume_inbound=False, + ) + return (proj_A, proj_B) + + @I.ir_module + class Expected: + @R.function + def main( + state: R.Tensor([16], dtype="float32"), + A: R.Tensor([16, 16], dtype="float32"), + B: R.Tensor([16, 16], dtype="float32"), + ): + # The function no longer depends on symbolic variables. + # Shape inference is now propagated using the + # statically-known shapes. + + lhs: R.Tensor([32, 16], dtype="float32") = R.concat((A, B), axis=0) + proj_concat: R.Tensor([32], dtype="float32") = R.matmul( + lhs, state, out_dtype="void" + ) + proj_A: R.Tensor([16], dtype="float32") = R.strided_slice( + proj_concat, + [R.prim_value(0)], + [R.prim_value(0)], + [R.prim_value(16)], + assume_inbound=False, + ) + proj_B: R.Tensor([16], dtype="float32") = R.strided_slice( + proj_concat, + [R.prim_value(0)], + [R.prim_value(16)], + [R.prim_value(32)], + assume_inbound=False, + ) + return (proj_A, proj_B) + + verify(Before, Expected) + + +def test_inconsistent_match_cast_raises_error(): + """Symbolic variables from R.match_cast must be consistent + + All match cast statements must provide consistent definitions for + symbolic variables. 
In this test, the value of `M` would be + inferred as 16 from either `state` or `A`, but would be inferred + as 32 from `B`. + + """ + + @I.ir_module + class Before: + @R.function + def main( + state: R.Tensor([16], dtype="float32"), + A: R.Tensor([16, 16], dtype="float32"), + B: R.Tensor([32, 32], dtype="float32"), + ): + N1 = T.int64() + M = T.int64() + N2 = T.int64() + + # These R.match_cast statements define inconsistent values + # for the symbolic shape parameters. + lhs_A = R.match_cast(A, R.Tensor([N1, M], dtype="float32")) + lhs_B = R.match_cast(B, R.Tensor([N2, M], dtype="float32")) + rhs = R.match_cast(state, R.Tensor([M], dtype="float32")) + + lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0) + proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul( + lhs, rhs, out_dtype="void" + ) + proj_A = R.strided_slice( + proj_concat, + (R.prim_value(0),), + (R.prim_value(0),), + (R.prim_value(N1),), + assume_inbound=False, + ) + proj_B = R.strided_slice( + proj_concat, + [R.prim_value(0)], + [R.prim_value(N1)], + [R.prim_value(N1 + N2)], + assume_inbound=False, + ) + return (proj_A, proj_B) + + with pytest.raises(ValueError, match="MatchCast statements must be consistent"): + CanonicalizeBindings()(Before) + + +def test_match_cast_may_have_distinct_values_in_branches(): + """Conditional branches may have different values of symbolic variables + + Here, the value of `N` can be inferred as 16 within the `if` + branch and as 32 within the `else` branch. 
+ + """ + + @I.ir_module + class Before: + @R.function + def main( + state: R.Tensor(["N"], dtype="float32"), + A: R.Tensor(["M", 16], dtype="float32"), + B: R.Tensor(["M", 32], dtype="float32"), + scale: R.Prim("float32"), + ): + N = T.int64() + M = T.int64() + + if N == 16: + weights: R.Tensor([M, 16], "float32") = A * scale + weights: R.Tensor([M, N], "float32") = R.match_cast( + weights, R.Tensor([M, N], "float32") + ) + weights: R.Tensor([M, N], "float32") = weights * scale + else: + weights: R.Tensor([M, 32], "float32") = B * scale + weights: R.Tensor([M, N], "float32") = R.match_cast( + weights, R.Tensor([M, N], "float32") + ) + weights: R.Tensor([M, N], "float32") = weights * scale + + weights: R.Tensor([M, N], "float32") = weights * scale + + out: R.Tensor([M], "float32") = R.matmul(weights, state) + + return out + + @I.ir_module + class Expected: + @R.function + def main( + state: R.Tensor(["N"], dtype="float32"), + A: R.Tensor(["M", 16], dtype="float32"), + B: R.Tensor(["M", 32], dtype="float32"), + scale: R.Prim("float32"), + ): + N = T.int64() + M = T.int64() + + if N == 16: + # Prior to the R.match_cast, the + weights: R.Tensor([M, 16], "float32") = A * scale + # The scaled weights within the branch may perform + # shape inference knowing that N==16. + weights: R.Tensor([M, 16], "float32") = weights * scale + # The match cast on exiting the if branch restores the + weights = R.match_cast(weights, R.Tensor([M, N], "float32")) + else: + # Prior to the R.match_cast, the + weights: R.Tensor([M, 32], "float32") = B * scale + # Within the else-branch, the R.match_cast implies + # that N==32. While this conflicts with the earlier + # definition, the two occur in separate branches, so + # this is legal. + # The scaled weights within the branch may perform + # shape inference knowing that N==32. 
+ weights: R.Tensor([M, 32], "float32") = weights * scale + weights = R.match_cast(weights, R.Tensor([M, N], "float32")) + + # Outside of the conditional, we no longer have a known + # value for N, so this shape inference must be done using + # a dynamic shape for `N`. + weights: R.Tensor([M, N], "float32") = weights * scale + + # After the conditional branch, we no longer have a known + # value of N, so this shape inference must use the dynamic + # shape. + out: R.Tensor([M], "float32") = R.matmul(weights, state) + + return out + + verify(Before, Expected) + + def test_multiple_outputs(): @I.ir_module class Input: From 486bfca9836ec8a452e5c3848eafd443b4c98330 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 28 Jun 2024 14:59:47 -0500 Subject: [PATCH 03/17] [Relax][Refactor] Reorganize pattern-matching A follow-up to https://github.com/apache/tvm/pull/16730. Now that the implementations for `rewrite_call` and `rewrite_bindings` are in separate classes, they can be further split out into separate files. --- src/relax/ir/dataflow_block_rewriter.cc | 489 +++++++++++++ src/relax/ir/dataflow_expr_rewriter.cc | 222 ++++++ src/relax/ir/dataflow_matcher.cc | 663 +----------------- ...flow_matcher_impl.h => dataflow_matcher.h} | 3 + 4 files changed, 733 insertions(+), 644 deletions(-) create mode 100644 src/relax/ir/dataflow_block_rewriter.cc create mode 100644 src/relax/ir/dataflow_expr_rewriter.cc rename src/relax/ir/{dataflow_matcher_impl.h => dataflow_matcher.h} (97%) diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc new file mode 100644 index 000000000000..c5af0493a54c --- /dev/null +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -0,0 +1,489 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/ir/dataflow_block_rewriter.cc + * \brief A transform to match a Relax DataflowBlock and rewrite + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "dataflow_matcher.h" + +namespace tvm { +namespace relax { + +class MatcherUseDefAnalysis : public relax::ExprVisitor { + public: + std::vector vars; + std::map> def2use; + // caller -> callee table. 
+ std::map> caller2callees; + + const VarNode* cur_user_; + + void VisitBinding_(const VarBindingNode* binding) override { + // init + cur_user_ = binding->var.get(); + this->VisitVarDef(binding->var); + this->VisitExpr(binding->value); + cur_user_ = nullptr; + } + + void VisitExpr_(const VarNode* op) override { + if (nullptr == cur_user_) return; + + auto check_and_push = [](std::vector& vec, const VarNode* var) { + if (std::find(vec.begin(), vec.end(), var) == vec.end()) { + vec.push_back(var); + } + }; + + check_and_push(def2use[op], cur_user_); + check_and_push(vars, op); + + caller2callees[cur_user_].push_back(op); + } +}; + +struct PNode { + const DFPatternNode* ptr; + std::vector&>> children; + std::vector&>> parents; +}; + +struct RNode { + const VarNode* ptr; + std::vector children; + std::vector parents; +}; + +struct MatchState { + void add(const PNode* p, const RNode* r) { + match_p_r[p] = r; + match_r_p[r] = p; + } + + void add(const DFConstraintNode* constraint) { validated_constraints_.insert(constraint); } + + void add(MatchState&& other) { + match_p_r.merge(std::move(other.match_p_r)); + match_r_p.merge(std::move(other.match_r_p)); + validated_constraints_.merge(other.validated_constraints_); + } + + const VarNode* matched(const PNode* p) const { + if (auto it = match_p_r.find(p); it != match_p_r.end()) { + return it->second->ptr; + } + return nullptr; + } + + const DFPatternNode* matched(const RNode* r) const { + if (auto it = match_r_p.find(r); it != match_r_p.end()) { + return it->second->ptr; + } + return nullptr; + } + + const VarNode* matched(const PNode& p) const { return matched(&p); } + const DFPatternNode* matched(const RNode& r) const { return matched(&r); } + + bool is_validated(const DFConstraintNode* constraint) const { + return validated_constraints_.count(constraint); + } + + private: + std::unordered_map match_p_r; + std::unordered_map match_r_p; + std::unordered_set validated_constraints_; +}; + +/** + * \brief This method try to 
match a real node and a pattern node along with its neighbors. + */ +static std::optional TryMatch(const PNode& p, const RNode& r, + const MatchState& current_match, DFPatternMatcher* m, + const MatcherUseDefAnalysis& ud_analysis) { + if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt; + + MatchState new_match; + + new_match.add(&p, &r); + + // forward matching; + for (const auto& [pchild, constraints] : p.children) { + bool any_cons_sat = false; + for (const auto& rchild : r.children) { + if (new_match.matched(rchild)) { + // The child variable is already matched to other child pattern in a previous iteration. + continue; + } + if (auto v = current_match.matched(pchild); v && v != rchild->ptr) { + // The child pattern is already matched to other variable in a earlier call to TryMatch. + continue; + } + + const auto& uses = ud_analysis.def2use.at(r.ptr); + + // check edge constraints. + bool all_cons_pass = true; + for (const auto& cons : constraints) { + if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { + all_cons_pass = false; + break; + } + + if (cons.index != -1) { + const auto& callees = ud_analysis.caller2callees.at(rchild->ptr); + if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r.ptr) { + all_cons_pass = false; + break; + } + } + } + if (!all_cons_pass || new_match.matched(pchild)) continue; + any_cons_sat = true; + + if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) { + new_match.add(pchild, rchild); + new_match.add(std::move(*match_rec)); + } + } + if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt; + } + + return new_match; +} + +static std::optional TryValidate( + const MatchState& current_match, + const std::unordered_map& pattern2node, + const std::vector& validation_constraints, arith::Analyzer* analyzer) { + MatchState new_match; + + std::function(const DFPatternNode*)> query_match_state = + [&pattern2node, ¤t_match](const DFPatternNode* pattern) -> 
Optional { + auto it = pattern2node.find(pattern); + ICHECK(it != pattern2node.end()) + << "DFConstraint attempted to access DFPattern " << GetRef(pattern) + << ", which does not appear in the PatternContext"; + const auto& p_node = it->second; + if (auto ptr = current_match.matched(p_node)) { + return GetRef(ptr); + } else { + return NullOpt; + } + }; + + for (const auto& constraint : validation_constraints) { + if (!current_match.is_validated(constraint.get())) { + auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state); + + necessary_condition = analyzer->Simplify(necessary_condition); + const auto* known = tir::as_const_int(necessary_condition); + + if (known && *known && is_sufficient) { + // The condition passes, and the expression provided is both + // necessary and sufficient for the constraint to pass. Mark + // the constraint as passing, to avoid re-checking it unless + // we backtrack. + new_match.add(constraint.get()); + } else if (known && !*known) { + // The condition fails. Even if additional information would + // be required to pass a constraint, it may bail out early as + // a failure (e.g. shape mismatch in the first two items out + // of N shapes that must all match). + return std::nullopt; + } else if (is_sufficient) { + // The condition depends on dynamic parameters. In the + // future, this may be exposed to the user as a condition for + // optimization, or can be combined with the conditions + // provided from other constraints. 
+ return std::nullopt; + } + } + } + + return new_match; +} + +static std::optional MatchTree( + const MatchState& current_match, size_t current_root_idx, + const std::unordered_map& pattern2node, + const std::unordered_map& var2node, DFPatternMatcher* matcher, + const std::vector& roots, const std::vector& validation_constraints, + const MatcherUseDefAnalysis& ud_analysis, arith::Analyzer* analyzer) { + auto get_next_root = [&](size_t root_idx) -> const PNode* { + // Look for the next unmatched root node. + for (; root_idx < roots.size(); ++root_idx) { + const auto& root = pattern2node.at(roots[root_idx].get()); + if (!current_match.matched(root)) { + return &root; + } + } + return nullptr; + }; + + const auto root = get_next_root(current_root_idx); + + if (!root) { + // All root nodes have been matched + return current_match; + } + + MatchState new_match = current_match; + + for (const auto& var : ud_analysis.vars) { + const RNode& r_node = var2node.at(var); + if (new_match.matched(r_node)) continue; + if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) { + // Recursively try to match the next subtree. + new_match.add(std::move(*match)); + if (auto validation = + TryValidate(new_match, pattern2node, validation_constraints, analyzer)) { + new_match.add(std::move(*validation)); + if (auto match_rec = + MatchTree(new_match, current_root_idx + 1, pattern2node, var2node, matcher, roots, + validation_constraints, ud_analysis, analyzer)) { + new_match.add(std::move(*match_rec)); + return new_match; + } + } + // Recursive matching has failed, backtrack. + new_match = current_match; + continue; + } + } + + return std::nullopt; +} + +Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, + const Map& bindings) { + // TODO(@ganler): Handle non-may external use. 
+ ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; + DFPatternMatcher matcher(bindings); + + MatcherUseDefAnalysis ud_analysis; + ud_analysis.VisitBindingBlock_(dfb.get()); + + // First construct a graph of PNode and RNode. + std::unordered_map var2node; + var2node.reserve(dfb->bindings.size()); + + for (const VarNode* cur_var : ud_analysis.vars) { + const auto& uses = ud_analysis.def2use.at(cur_var); + RNode& cur_node = var2node[cur_var]; + cur_node.ptr = cur_var; + for (const VarNode* use : uses) { + auto& use_node = var2node[use]; + use_node.ptr = use; + cur_node.children.push_back(&use_node); + use_node.parents.push_back(&cur_node); + } + } + + std::unordered_map pattern2node; + pattern2node.reserve(ctx->edge_constraints.size()); + + for (const auto& def_pattern : ctx->src_ordered) { + PNode& def_node = pattern2node[def_pattern.get()]; + const auto& uses = ctx->edge_constraints.at(def_pattern); + def_node.ptr = def_pattern.get(); + def_node.children.reserve(uses.size()); + for (const auto& [use_pattern, cons] : uses) { + PNode& use_node = pattern2node[use_pattern.get()]; + use_node.ptr = use_pattern.get(); + use_node.parents.emplace_back(&def_node, std::ref(cons)); + def_node.children.emplace_back(&use_node, std::ref(cons)); + } + } + + std::vector roots; + for (const auto& pat : ctx->src_ordered) { + if (pattern2node[pat.get()].parents.empty()) { + roots.push_back(pat); + } + } + + if (roots.empty()) { + return NullOpt; + } + + arith::Analyzer analyzer; + auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, + ctx->validation_constraints, ud_analysis, &analyzer); + if (!match) { + return NullOpt; + } + + Map ret; + for (const auto& [pat, p_node] : pattern2node) { + ICHECK(match->matched(p_node)); + ret.Set(GetRef(pat), GetRef(match->matched(p_node))); + } + return ret; +} + +Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { + return MatchGraph(ctx, dfb, 
AnalyzeVar2Value(dfb)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") + .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) { + return MatchGraph(ctx, dfb); + }); + +/*! + * \brief Apply pattern matching to each dataflow block, replacing matches + * with the output of a user-provided rewriter function. + */ +class BlockPatternRewriter : ExprMutator { + public: + using ExprMutator::VisitBindingBlock_; + using ExprMutator::VisitExpr_; + + BlockPatternRewriter( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter_func) + : ctx_(ctx), rewriter_func_(rewriter_func) {} + + template + static Function Run( + PatternType pat, + TypedPackedFunc(Map, Map)> rewriter_func, + Function func) { + BlockPatternRewriter rewriter(pat, rewriter_func); + + func = Downcast(rewriter(func)); + func = Downcast(RemoveAllUnused(func)); + return func; + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override { + return RewriteDataflowBlockFixedPoint(GetRef(block_node)); + } + + private: + void EmitUsedVars(Expr val, const Array& pending_bindings, + std::unordered_set* emitted_vars) { + std::unordered_set unemitted_vars; + PostOrderVisit(val, [=, &unemitted_vars](Expr e) { + if (auto v = e.as(); v && !emitted_vars->count(v)) { + unemitted_vars.insert(v); + } + }); + + if (unemitted_vars.empty()) { + return; + } + + size_t num_unemitted = unemitted_vars.size(); + for (size_t i = 0; i < pending_bindings.size(); ++i) { + const auto& binding = pending_bindings[i]; + if (auto var_bind = binding.as(); + var_bind && unemitted_vars.count(var_bind->var.get())) { + // var_bind->value may also depend on other unemitted vars in this range + Array prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); + EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); + this->VisitBinding(binding); + emitted_vars->insert(var_bind->var.get()); + if (--num_unemitted == 0) { + return; + } + } + } + } + + // Repeat until all matchable 
subsets of bindings are rewritten. + BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { + auto df_block = Downcast(block); + Map bindings = AnalyzeVar2Value(df_block); + if (auto matches = MatchGraph(ctx_, df_block, bindings)) { + builder_->BeginDataflowBlock(); + Map replacements = rewriter_func_(matches.value(), bindings); + + std::unordered_set emitted_vars; + + bool changed = false; + for (size_t i = 0; i < block->bindings.size(); ++i) { + const auto& binding = block->bindings[i]; + if (auto var_bind = binding.as()) { + if (auto new_val = replacements.Get(var_bind->var).value_or(var_bind->value); + !StructuralEqual()(var_bind->value, new_val)) { + Array pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); + // Make sure there is no unbound variable used in the new value before it is emitted + EmitUsedVars(new_val, pending_bindings, &emitted_vars); + this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); + changed = true; + } else if (!emitted_vars.count(var_bind->var.get())) { + this->VisitBinding(binding); + emitted_vars.insert(var_bind->var.get()); + } + } else { + this->VisitBinding(binding); + } + } + + auto new_block = builder_->EndBlock(); + + if (!changed) return new_block; + return RewriteDataflowBlockFixedPoint(new_block); + } + return block; + } + + /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ + PatternContext ctx_; + /*! + * \brief The user-provided rewriter function. Its signature and semantics are: + * + * - (Map, Map) -> Map + * + * Given the map of patterns and corresponding variables (bound + * variables or parameters), it should return a map that + * specifies new values for matched bound variables. It can refer + * to the passed bindings to create the replacement expressions. 
+ */ + TypedPackedFunc(Map, Map)> rewriter_func_; +}; + +Function RewriteBindings( + const PatternContext& ctx, + TypedPackedFunc(Map, Map)> rewriter, Function func) { + return BlockPatternRewriter::Run(ctx, rewriter, func); +} + +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc new file mode 100644 index 000000000000..4793d1d75a30 --- /dev/null +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/relax/ir/dataflow_expr_rewriter.cc + * \brief A transform to match a Relax Expr and rewrite + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../transform/utils.h" +#include "dataflow_matcher.h" + +namespace tvm { +namespace relax { + +Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, + Optional> bindings_opt) { + auto bindings = bindings_opt.value_or({}); + DFPatternMatcher matcher(bindings); + + if (!matcher.Match(pattern, expr)) { + return NullOpt; + } + + Map matching; + for (const auto& [pat, matches] : matcher.GetMemo()) { + ICHECK_EQ(matches.size(), 1) << "More than one match for the pattern " << pat; + matching.Set(pat, matches[0]); + } + return matching; +} + +TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); + +bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { + return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); +} + +TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); + +/*! + * \brief Apply pattern matching to each expression, replacing + * matches with the output of a user-provided rewriter function. 
+ */
+class ExprPatternRewriter : ExprMutator {
+ public:
+  using ExprMutator::VisitExpr_;
+
+  ExprPatternRewriter(DFPattern pat,
+                      TypedPackedFunc)> rewriter_func)
+      : pattern_(pat), rewriter_func_(rewriter_func) {}
+
+  template 
+  static Function Run(PatternType pat,
+                      TypedPackedFunc)> rewriter_func,
+                      Function func) {
+    ExprPatternRewriter rewriter(pat, rewriter_func);
+    func = Downcast(rewriter(func));
+    func = Downcast(RemoveAllUnused(func));
+    return func;
+  }
+
+  Expr VisitExpr_(const SeqExprNode* seq) override {
+    auto cache = bindings_;
+    SeqExpr prev = GetRef(seq);
+
+    StructuralEqual struct_equal;
+
+    while (true) {
+      SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get())));
+      if (struct_equal(prev, next)) {
+        return std::move(next);
+      }
+
+      // Canonicalization may result in two previously-different
+      // expressions being recognized as identical. Elimination of
+      // common subexpressions may result in trivial var-to-var
+      // bindings that can be canonicalized. Therefore, iterate the
+      // simplification steps until converged.
+      while (true) {
+        auto start_of_loop = next;
+        next = Downcast(CanonicalizeBindings(next));
+        next = Downcast(EliminateCommonSubexpr(next));
+        next = Downcast(RemoveAllUnused(next));
+        if (struct_equal(start_of_loop, next)) {
+          break;
+        }
+      }
+
+      if (struct_equal(prev, next)) {
+        return std::move(next);
+      }
+
+      // Reset all knowledge of bindings that were collected from
+      // this SeqExpr. The collected bindings are only valid after
+      // the point where they were collected, and we are repeating
+      // the mutation of this SeqExpr.
+ bindings_ = cache; + prev = next; + } + } + + void VisitBinding_(const VarBindingNode* binding) override { + auto expr = VisitExpr(binding->value); + bindings_.Set(binding->var, expr); + ReEmitBinding(binding, expr); + } + + Expr VisitExpr(const Expr& expr) override { + auto node = ExprMutator::VisitExpr(expr); + + std::vector matches_top_level; + if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) { + return builder_->Normalize(rewritten.value()); + } + + return node; + } + + private: + Optional TryRewrite(const Expr& expr, const DFPattern& pattern, + std::vector* matches_top_level) { + ICHECK(matches_top_level); + + // Special handling if the user-supplied pattern is a `OrPattern`. + // While the `ExtractMatchedExpr` can handle matching the + // `OrPattern`, it will return on the first match, even if the + // `rewriter_func_` doesn't apply a replacement. Unpacking the + // `OrPattern` here allows the match to be resumed if + // `rewriter_func_` returns the original function unmodified. + // This is only valid for a top-level match. + if (auto or_pattern = pattern.as()) { + matches_top_level->push_back(pattern); + Optional output = TryRewrite(expr, or_pattern->left, matches_top_level); + if (!output.defined()) { + output = TryRewrite(expr, or_pattern->right, matches_top_level); + } + matches_top_level->pop_back(); + return output; + } + + if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) { + auto matches = opt_matches.value(); + + // Append any additional matches that from the unwrapped + // `OrPattern`. When matching against `pat = pat_lhs | + // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and + // `pat_rhs` separately. The top-level `pat` is never seen by + // `ExtractMatchedExpr`, and must be re-added afterward. 
+ if (matches_top_level->size()) { + auto matched_expr = DFPatternMatcher::UnwrapBindings(expr, bindings_); + for (const auto& pat : *matches_top_level) { + matches.Set(pat, matched_expr); + } + } + + Expr rewritten_expr = rewriter_func_(expr, matches); + if (!rewritten_expr.same_as(expr)) { + return builder_->Normalize(rewritten_expr); + } + } + + return NullOpt; + } + + /*! \brief The pattern for rewriting call nodes */ + DFPattern pattern_; + /*! + * \brief The user-provided rewriter function. Its signature and semantics are: + * + * - (Call, Map) -> Call + * + * Given the matched call node and the map of patterns and + * matched expressions, it should return a new call node to + * replace the original one or the original matched call node as + * is. + */ + TypedPackedFunc)> rewriter_func_; + + /*! \brief The known variable bindings + * + * The variable bindings whose value is known. This must be tracked + * separately from the block builder, so that it can be reset after + * each iteration of the mutate-until-converged loop applied to + * `SeqExpr`. + */ + Map bindings_; +}; + +Function RewriteCall(const DFPattern& pat, + TypedPackedFunc)> rewriter, Function func) { + return ExprPatternRewriter::Run(pat, rewriter, func); +} + +TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index c0b8d1e1df08..989c1174f41d 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -22,6 +22,8 @@ * \brief The dataflow pattern matcher for Relax. 
*/ +#include "dataflow_matcher.h" + #include #include #include @@ -45,7 +47,6 @@ #include "../../arith/constraint_extract.h" #include "../transform/utils.h" -#include "dataflow_matcher_impl.h" namespace tvm { namespace relax { @@ -59,7 +60,7 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) { return VisitDFPattern(pattern, expr); } -static Expr TryGetValOfVar(Expr expr, const Map& var2val) { +Expr DFPatternMatcher::UnwrapBindings(Expr expr, const Map& var2val) { auto unwrap = [&](Expr expr) -> Optional { // Unwrap variables into the value to which they are bound. if (var2val.size()) { @@ -98,7 +99,7 @@ void DFPatternMatcher::ClearMap(size_t watermark) { bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr0) { CHECK(pattern.defined()) << "Null pattern found when matching against " << expr0; - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (memoize_ && memo_.count(pattern)) { ICHECK_EQ(memo_[pattern].size(), 1); return expr.same_as(memo_[pattern][0]); @@ -118,17 +119,17 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr } bool DFPatternMatcher::VisitDFPattern_(const OrPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); } bool DFPatternMatcher::VisitDFPattern_(const AndPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return VisitDFPattern(op->left, expr) && VisitDFPattern(op->right, expr); } bool DFPatternMatcher::VisitDFPattern_(const NotPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return !VisitDFPattern(op->reject, expr); } @@ -183,7 +184,7 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { } bool 
DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = VisitDFPattern(attr_pattern->pattern, expr); if (!matches) return matches; VLOG(1) << "considering AttrPatternNode at:\n" << expr; @@ -241,7 +242,7 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons } bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); // utilities auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* { if (op) { @@ -351,12 +352,12 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex } bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return StructuralEqual()(op->expr, expr); } bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = false; if (const auto* func = expr.as()) { matches = true; @@ -379,7 +380,7 @@ bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const Expr } bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* tuple_get_item_node = expr.as()) { return (op->index == -1 || op->index == tuple_get_item_node->index) && VisitDFPattern(op->tuple, tuple_get_item_node->tuple); @@ -388,7 +389,7 @@ bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const } bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, 
var2val_); + auto expr = UnwrapBindings(expr0, var2val_); bool matches = false; if (const auto* tuple_node = expr.as()) { matches = true; @@ -429,7 +430,7 @@ bool DFPatternMatcher::TryUnorderedMatch(size_t idx, const tvm::Array } bool DFPatternMatcher::VisitDFPattern_(const UnorderedTuplePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* tuple_node = expr.as()) { if (op->fields.size() == tuple_node->fields.size()) { @@ -449,7 +450,7 @@ bool DFPatternMatcher::VisitDFPattern_(const StructInfoPatternNode* op, const Ex return false; } - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); auto expr_struct_info = GetStructInfo(expr); PrimExpr new_constraint = StructInfoBaseCheckPrecondition(op->struct_info, expr_struct_info); @@ -497,7 +498,7 @@ PrimExpr DFPatternMatcher::SimplifyCondition(PrimExpr condition) { } bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); auto expr_type = expr.as()->checked_type(); return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); } @@ -584,7 +585,7 @@ std::tuple SameShapeConstraintNode::AsPrimExpr( } bool DFPatternMatcher::VisitDFPattern_(const PrimArrPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const ShapeExprNode* shape_expr = expr.as()) return ShapeEqual(&analyzer_, op->fields, shape_expr->values); return false; @@ -609,7 +610,7 @@ bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& exp } bool DFPatternMatcher::VisitDFPattern_(const ExternFuncPatternNode* op, const Expr& expr0) { - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); if (const auto* extern_fn = expr.as()) { return "" == 
op->global_symbol() || op->global_symbol() == extern_fn->global_symbol; } @@ -618,7 +619,7 @@ bool DFPatternMatcher::VisitDFPattern_(const ExternFuncPatternNode* op, const Ex bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr& expr0) { // constants can be binded to relax.Var as well. - auto expr = TryGetValOfVar(expr0, var2val_); + auto expr = UnwrapBindings(expr0, var2val_); return expr.as() != nullptr; } @@ -642,631 +643,5 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr return true; } -Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, - Optional> bindings_opt) { - auto bindings = bindings_opt.value_or({}); - DFPatternMatcher matcher(bindings); - - if (!matcher.Match(pattern, expr)) { - return NullOpt; - } - - Map matching; - for (const auto& [pat, matches] : matcher.GetMemo()) { - ICHECK_EQ(matches.size(), 1) << "More than one match for the pattern " << pat; - matching.Set(pat, matches[0]); - } - return matching; -} - -TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); - -bool MatchExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { - return static_cast(ExtractMatchedExpr(pattern, expr, bindings_opt)); -} - -TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); - -class MatcherUseDefAnalysis : public relax::ExprVisitor { - public: - std::vector vars; - std::map> def2use; - // caller -> callee table. 
- std::map> caller2callees; - - const VarNode* cur_user_; - - void VisitBinding_(const VarBindingNode* binding) override { - // init - cur_user_ = binding->var.get(); - this->VisitVarDef(binding->var); - this->VisitExpr(binding->value); - cur_user_ = nullptr; - } - - void VisitExpr_(const VarNode* op) override { - if (nullptr == cur_user_) return; - - auto check_and_push = [](std::vector& vec, const VarNode* var) { - if (std::find(vec.begin(), vec.end(), var) == vec.end()) { - vec.push_back(var); - } - }; - - check_and_push(def2use[op], cur_user_); - check_and_push(vars, op); - - caller2callees[cur_user_].push_back(op); - } -}; - -struct PNode { - const DFPatternNode* ptr; - std::vector&>> children; - std::vector&>> parents; -}; - -struct RNode { - const VarNode* ptr; - std::vector children; - std::vector parents; -}; - -struct MatchState { - void add(const PNode* p, const RNode* r) { - match_p_r[p] = r; - match_r_p[r] = p; - } - - void add(const DFConstraintNode* constraint) { validated_constraints_.insert(constraint); } - - void add(MatchState&& other) { - match_p_r.merge(std::move(other.match_p_r)); - match_r_p.merge(std::move(other.match_r_p)); - validated_constraints_.merge(other.validated_constraints_); - } - - const VarNode* matched(const PNode* p) const { - if (auto it = match_p_r.find(p); it != match_p_r.end()) { - return it->second->ptr; - } - return nullptr; - } - - const DFPatternNode* matched(const RNode* r) const { - if (auto it = match_r_p.find(r); it != match_r_p.end()) { - return it->second->ptr; - } - return nullptr; - } - - const VarNode* matched(const PNode& p) const { return matched(&p); } - const DFPatternNode* matched(const RNode& r) const { return matched(&r); } - - bool is_validated(const DFConstraintNode* constraint) const { - return validated_constraints_.count(constraint); - } - - private: - std::unordered_map match_p_r; - std::unordered_map match_r_p; - std::unordered_set validated_constraints_; -}; - -/** - * \brief This method try to 
match a real node and a pattern node along with its neighbors. - */ -static std::optional TryMatch(const PNode& p, const RNode& r, - const MatchState& current_match, DFPatternMatcher* m, - const MatcherUseDefAnalysis& ud_analysis) { - if (!m->Match(GetRef(p.ptr), GetRef(r.ptr))) return std::nullopt; - - MatchState new_match; - - new_match.add(&p, &r); - - // forward matching; - for (const auto& [pchild, constraints] : p.children) { - bool any_cons_sat = false; - for (const auto& rchild : r.children) { - if (new_match.matched(rchild)) { - // The child variable is already matched to other child pattern in a previous iteration. - continue; - } - if (auto v = current_match.matched(pchild); v && v != rchild->ptr) { - // The child pattern is already matched to other variable in a earlier call to TryMatch. - continue; - } - - const auto& uses = ud_analysis.def2use.at(r.ptr); - - // check edge constraints. - bool all_cons_pass = true; - for (const auto& cons : constraints) { - if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) { - all_cons_pass = false; - break; - } - - if (cons.index != -1) { - const auto& callees = ud_analysis.caller2callees.at(rchild->ptr); - if (callees.size() <= static_cast(cons.index) || callees[cons.index] != r.ptr) { - all_cons_pass = false; - break; - } - } - } - if (!all_cons_pass || new_match.matched(pchild)) continue; - any_cons_sat = true; - - if (auto match_rec = TryMatch(*pchild, *rchild, current_match, m, ud_analysis)) { - new_match.add(pchild, rchild); - new_match.add(std::move(*match_rec)); - } - } - if (!new_match.matched(pchild) || !any_cons_sat) return std::nullopt; - } - - return new_match; -} - -static std::optional TryValidate( - const MatchState& current_match, - const std::unordered_map& pattern2node, - const std::vector& validation_constraints, arith::Analyzer* analyzer) { - MatchState new_match; - - std::function(const DFPatternNode*)> query_match_state = - [&pattern2node, ¤t_match](const DFPatternNode* pattern) -> 
Optional { - auto it = pattern2node.find(pattern); - ICHECK(it != pattern2node.end()) - << "DFConstraint attempted to access DFPattern " << GetRef(pattern) - << ", which does not appear in the PatternContext"; - const auto& p_node = it->second; - if (auto ptr = current_match.matched(p_node)) { - return GetRef(ptr); - } else { - return NullOpt; - } - }; - - for (const auto& constraint : validation_constraints) { - if (!current_match.is_validated(constraint.get())) { - auto [necessary_condition, is_sufficient] = constraint->AsPrimExpr(query_match_state); - - necessary_condition = analyzer->Simplify(necessary_condition); - const auto* known = tir::as_const_int(necessary_condition); - - if (known && *known && is_sufficient) { - // The condition passes, and the expression provided is both - // necessary and sufficient for the constraint to pass. Mark - // the constraint as passing, to avoid re-checking it unless - // we backtrack. - new_match.add(constraint.get()); - } else if (known && !*known) { - // The condition fails. Even if additional information would - // be required to pass a constraint, it may bail out early as - // a failure (e.g. shape mismatch in the first two items out - // of N shapes that must all match). - return std::nullopt; - } else if (is_sufficient) { - // The condition depends on dynamic parameters. In the - // future, this may be exposed to the user as a condition for - // optimization, or can be combined with the conditions - // provided from other constraints. 
- return std::nullopt; - } - } - } - - return new_match; -} - -static std::optional MatchTree( - const MatchState& current_match, size_t current_root_idx, - const std::unordered_map& pattern2node, - const std::unordered_map& var2node, DFPatternMatcher* matcher, - const std::vector& roots, const std::vector& validation_constraints, - const MatcherUseDefAnalysis& ud_analysis, arith::Analyzer* analyzer) { - auto get_next_root = [&](size_t root_idx) -> const PNode* { - // Look for the next unmatched root node. - for (; root_idx < roots.size(); ++root_idx) { - const auto& root = pattern2node.at(roots[root_idx].get()); - if (!current_match.matched(root)) { - return &root; - } - } - return nullptr; - }; - - const auto root = get_next_root(current_root_idx); - - if (!root) { - // All root nodes have been matched - return current_match; - } - - MatchState new_match = current_match; - - for (const auto& var : ud_analysis.vars) { - const RNode& r_node = var2node.at(var); - if (new_match.matched(r_node)) continue; - if (auto match = TryMatch(*root, r_node, new_match, matcher, ud_analysis)) { - // Recursively try to match the next subtree. - new_match.add(std::move(*match)); - if (auto validation = - TryValidate(new_match, pattern2node, validation_constraints, analyzer)) { - new_match.add(std::move(*validation)); - if (auto match_rec = - MatchTree(new_match, current_root_idx + 1, pattern2node, var2node, matcher, roots, - validation_constraints, ud_analysis, analyzer)) { - new_match.add(std::move(*match_rec)); - return new_match; - } - } - // Recursive matching has failed, backtrack. - new_match = current_match; - continue; - } - } - - return std::nullopt; -} - -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, - const Map& bindings) { - // TODO(@ganler): Handle non-may external use. 
- ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; - DFPatternMatcher matcher(bindings); - - MatcherUseDefAnalysis ud_analysis; - ud_analysis.VisitBindingBlock_(dfb.get()); - - // First construct a graph of PNode and RNode. - std::unordered_map var2node; - var2node.reserve(dfb->bindings.size()); - - for (const VarNode* cur_var : ud_analysis.vars) { - const auto& uses = ud_analysis.def2use.at(cur_var); - RNode& cur_node = var2node[cur_var]; - cur_node.ptr = cur_var; - for (const VarNode* use : uses) { - auto& use_node = var2node[use]; - use_node.ptr = use; - cur_node.children.push_back(&use_node); - use_node.parents.push_back(&cur_node); - } - } - - std::unordered_map pattern2node; - pattern2node.reserve(ctx->edge_constraints.size()); - - for (const auto& def_pattern : ctx->src_ordered) { - PNode& def_node = pattern2node[def_pattern.get()]; - const auto& uses = ctx->edge_constraints.at(def_pattern); - def_node.ptr = def_pattern.get(); - def_node.children.reserve(uses.size()); - for (const auto& [use_pattern, cons] : uses) { - PNode& use_node = pattern2node[use_pattern.get()]; - use_node.ptr = use_pattern.get(); - use_node.parents.emplace_back(&def_node, std::ref(cons)); - def_node.children.emplace_back(&use_node, std::ref(cons)); - } - } - - std::vector roots; - for (const auto& pat : ctx->src_ordered) { - if (pattern2node[pat.get()].parents.empty()) { - roots.push_back(pat); - } - } - - if (roots.empty()) { - return NullOpt; - } - - arith::Analyzer analyzer; - auto match = MatchTree({}, 0, pattern2node, var2node, &matcher, roots, - ctx->validation_constraints, ud_analysis, &analyzer); - if (!match) { - return NullOpt; - } - - Map ret; - for (const auto& [pat, p_node] : pattern2node) { - ICHECK(match->matched(p_node)); - ret.Set(GetRef(pat), GetRef(match->matched(p_node))); - } - return ret; -} - -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { - return MatchGraph(ctx, dfb, 
AnalyzeVar2Value(dfb)); -} - -TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") - .set_body_typed([](const PatternContext& ctx, const DataflowBlock& dfb) { - return MatchGraph(ctx, dfb); - }); - -/*! - * \brief Apply pattern matching to each dataflow block, replacing matches - * with the output of a user-provided rewriter function. - */ -class BlockPatternRewriter : ExprMutator { - public: - using ExprMutator::VisitBindingBlock_; - using ExprMutator::VisitExpr_; - - BlockPatternRewriter( - const PatternContext& ctx, - TypedPackedFunc(Map, Map)> rewriter_func) - : ctx_(ctx), rewriter_func_(rewriter_func) {} - - template - static Function Run( - PatternType pat, - TypedPackedFunc(Map, Map)> rewriter_func, - Function func) { - BlockPatternRewriter rewriter(pat, rewriter_func); - - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } - - BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override { - return RewriteDataflowBlockFixedPoint(GetRef(block_node)); - } - - private: - void EmitUsedVars(Expr val, const Array& pending_bindings, - std::unordered_set* emitted_vars) { - std::unordered_set unemitted_vars; - PostOrderVisit(val, [=, &unemitted_vars](Expr e) { - if (auto v = e.as(); v && !emitted_vars->count(v)) { - unemitted_vars.insert(v); - } - }); - - if (unemitted_vars.empty()) { - return; - } - - size_t num_unemitted = unemitted_vars.size(); - for (size_t i = 0; i < pending_bindings.size(); ++i) { - const auto& binding = pending_bindings[i]; - if (auto var_bind = binding.as(); - var_bind && unemitted_vars.count(var_bind->var.get())) { - // var_bind->value may also depend on other unemitted vars in this range - Array prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); - EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); - this->VisitBinding(binding); - emitted_vars->insert(var_bind->var.get()); - if (--num_unemitted == 0) { - return; - } - } - } - } - - // Repeat until all matchable 
subsets of bindings are rewritten. - BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { - auto df_block = Downcast(block); - Map bindings = AnalyzeVar2Value(df_block); - if (auto matches = MatchGraph(ctx_, df_block, bindings)) { - builder_->BeginDataflowBlock(); - Map replacements = rewriter_func_(matches.value(), bindings); - - std::unordered_set emitted_vars; - - bool changed = false; - for (size_t i = 0; i < block->bindings.size(); ++i) { - const auto& binding = block->bindings[i]; - if (auto var_bind = binding.as()) { - if (auto new_val = replacements.Get(var_bind->var).value_or(var_bind->value); - !StructuralEqual()(var_bind->value, new_val)) { - Array pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); - // Make sure there is no unbound variable used in the new value before it is emitted - EmitUsedVars(new_val, pending_bindings, &emitted_vars); - this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); - changed = true; - } else if (!emitted_vars.count(var_bind->var.get())) { - this->VisitBinding(binding); - emitted_vars.insert(var_bind->var.get()); - } - } else { - this->VisitBinding(binding); - } - } - - auto new_block = builder_->EndBlock(); - - if (!changed) return new_block; - return RewriteDataflowBlockFixedPoint(new_block); - } - return block; - } - - /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ - PatternContext ctx_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Map, Map) -> Map - * - * Given the map of patterns and corresponding variables (bound - * variables or parameters), it should return a map that - * specifies new values for matched bound variables. It can refer - * to the passed bindings to create the replacement expressions. - */ - TypedPackedFunc(Map, Map)> rewriter_func_; -}; - -/*! - * \brief Apply pattern matching to each expression, replacing - * matches with the output of a user-provided rewriter function. 
- */ -class ExprPatternRewriter : ExprMutator { - public: - using ExprMutator::VisitBindingBlock_; - using ExprMutator::VisitExpr_; - - ExprPatternRewriter(DFPattern pat, - TypedPackedFunc)> rewriter_func) - : pattern_(pat), rewriter_func_(rewriter_func) {} - - template - static Function Run(PatternType pat, - TypedPackedFunc)> rewriter_func, - Function func) { - ExprPatternRewriter rewriter(pat, rewriter_func); - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } - - Expr VisitExpr_(const SeqExprNode* seq) override { - auto cache = bindings_; - SeqExpr prev = GetRef(seq); - - StructuralEqual struct_equal; - - while (true) { - SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Canonicalization may result in two previously-different - // expressions being recognized as identical. Elimination of - // common subexpressions may result in trival var-to-var - // bindings that can be canonicalized. Therefore, iterate the - // simplification steps until converged. - while (true) { - auto start_of_loop = next; - next = Downcast(CanonicalizeBindings(next)); - next = Downcast(EliminateCommonSubexpr(next)); - next = Downcast(RemoveAllUnused(next)); - if (struct_equal(start_of_loop, next)) { - break; - } - } - - if (struct_equal(prev, next)) { - return std::move(next); - } - - // Reset all knowledge of bindings that were collected from - // this SeqExpr. The collected bindings are only after - // the point where they were collected, and we are repeating - // the mutation of this SeqExpr. 
- bindings_ = cache; - prev = next; - } - } - - void VisitBinding_(const VarBindingNode* binding) override { - auto expr = VisitExpr(binding->value); - bindings_.Set(binding->var, expr); - ReEmitBinding(binding, expr); - } - - Expr VisitExpr(const Expr& expr) override { - auto node = ExprMutator::VisitExpr(expr); - - std::vector matches_top_level; - if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) { - return builder_->Normalize(rewritten.value()); - } - - return node; - } - - private: - Optional TryRewrite(const Expr& expr, const DFPattern& pattern, - std::vector* matches_top_level) { - ICHECK(matches_top_level); - - // Special handling if the user-supplied pattern is a `OrPattern`. - // While the `ExtractMatchedExpr` can handle matching the - // `OrPattern`, it will return on the first match, even if the - // `rewriter_func_` doesn't apply a replacement. Unpacking the - // `OrPattern` here allows the match to be resumed if - // `rewriter_func_` returns the original function unmodified. - // This is only valid for a top-level match. - if (auto or_pattern = pattern.as()) { - matches_top_level->push_back(pattern); - Optional output = TryRewrite(expr, or_pattern->left, matches_top_level); - if (!output.defined()) { - output = TryRewrite(expr, or_pattern->right, matches_top_level); - } - matches_top_level->pop_back(); - return output; - } - - if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) { - auto matches = opt_matches.value(); - - // Append any additional matches that from the unwrapped - // `OrPattern`. When matching against `pat = pat_lhs | - // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and - // `pat_rhs` separately. The top-level `pat` is never seen by - // `ExtractMatchedExpr`, and must be re-added afterward. 
- if (matches_top_level->size()) { - auto matched_expr = TryGetValOfVar(expr, bindings_); - for (const auto& pat : *matches_top_level) { - matches.Set(pat, matched_expr); - } - } - - Expr rewritten_expr = rewriter_func_(expr, matches); - if (!rewritten_expr.same_as(expr)) { - return builder_->Normalize(rewritten_expr); - } - } - - return NullOpt; - } - - /*! \brief The pattern for rewriting call nodes */ - DFPattern pattern_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Call, Map) -> Call - * - * Given the matched call node and the map of patterns and - * matched expressions, it should return a new call node to - * replace the original one or the original matched call node as - * is. - */ - TypedPackedFunc)> rewriter_func_; - - /*! \brief The known variable bindings - * - * The variable bindings whose value is known. This must be tracked - * separately from the block builder, so that it can be reset after - * each iteration of the mutate-until-converged loop applied to - * `SeqExpr`. 
- */ - Map bindings_; -}; - -Function RewriteBindings( - const PatternContext& ctx, - TypedPackedFunc(Map, Map)> rewriter, Function func) { - return BlockPatternRewriter::Run(ctx, rewriter, func); -} - -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); - -Function RewriteCall(const DFPattern& pat, - TypedPackedFunc)> rewriter, Function func) { - return ExprPatternRewriter::Run(pat, rewriter, func); -} - -TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); - } // namespace relax } // namespace tvm diff --git a/src/relax/ir/dataflow_matcher_impl.h b/src/relax/ir/dataflow_matcher.h similarity index 97% rename from src/relax/ir/dataflow_matcher_impl.h rename to src/relax/ir/dataflow_matcher.h index a0c35ac0dead..9036c7630a54 100644 --- a/src/relax/ir/dataflow_matcher_impl.h +++ b/src/relax/ir/dataflow_matcher.h @@ -45,6 +45,9 @@ class DFPatternMatcher : public DFPatternFunctor> GetMemo() { return Map>(memo_); } + /* \brief Unwrap trivial expressions/bindings */ + static Expr UnwrapBindings(Expr expr, const Map& bindings); + protected: bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override; bool VisitDFPattern_(const OrPatternNode* op, const Expr& expr) override; From 7488ea146c606e6bc549e8f8d5dfd379ff673d07 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 29 Jun 2024 07:35:13 -0500 Subject: [PATCH 04/17] [Relax][Refactor] Implement Rewriter class for pattern-rewrite Prior to this commit, the pattern to be matched and the rewrite to be performed were provided as separate arguments. This commit introduces a new class `ExprRewriter`, which contains both parts. This abstraction will make it easier to combine multiple different rewrite rules, applying them in a single pass. 
--- python/tvm/relax/dpl/__init__.py | 2 +- python/tvm/relax/dpl/rewrite.py | 63 +- python/tvm/script/ir_builder/relax/ir.py | 48 +- python/tvm/script/parser/core/entry.py | 3 +- python/tvm/script/parser/core/utils.py | 14 +- src/relax/ir/dataflow_block_rewriter.cc | 166 +-- src/relax/ir/dataflow_expr_rewriter.cc | 1058 +++++++++++-- src/relax/ir/dataflow_matcher.cc | 5 +- src/relax/ir/dataflow_matcher.h | 4 +- src/relax/ir/dataflow_rewriter.h | 178 +++ tests/python/relax/test_dataflow_rewriter.py | 1388 ++++++++++++++++++ 11 files changed, 2712 insertions(+), 217 deletions(-) create mode 100644 src/relax/ir/dataflow_rewriter.h create mode 100644 tests/python/relax/test_dataflow_rewriter.py diff --git a/python/tvm/relax/dpl/__init__.py b/python/tvm/relax/dpl/__init__.py index 6451238428c2..cda84424e5ab 100644 --- a/python/tvm/relax/dpl/__init__.py +++ b/python/tvm/relax/dpl/__init__.py @@ -19,4 +19,4 @@ from .pattern import * from .context import * -from .rewrite import rewrite_call, rewrite_bindings +from .rewrite import rewrite_call, rewrite_bindings, ExprRewriter, PatternRewriter, OrRewriter diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index 291061090fc2..f124c11f7077 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -15,16 +15,75 @@ # specific language governing permissions and limitations # under the License. """APIs for pattern-based rewriting.""" -from typing import Dict, Callable + +from typing import Dict, Callable, Union from .pattern import DFPattern from .context import PatternContext +from tvm.ir import IRModule +from tvm.runtime import Object +from tvm._ffi import register_object from ..expr import Expr, Function, Var from . 
import _ffi as ffi +@register_object("relax.dpl.ExprRewriter") +class ExprRewriter(Object): + @staticmethod + def from_pattern( + pattern: DFPattern, + func: Callable[[Expr, Dict[DFPattern, Expr]], Expr], + ) -> "ExprRewriter": + return ffi.ExprRewriterFromPattern( + pattern, + func, + ) # type: ignore + + @staticmethod + def from_module(mod: IRModule) -> "ExprRewriter": + return ffi.ExprRewriterFromModule(mod) # type: ignore + + def __call__(self, obj: Union[Expr, IRModule]) -> Union[Expr, IRModule]: + return ffi.ExprRewriterApply(self, obj) + + def __or__(self, other: "ExprRewriter") -> "ExprRewriter": + return OrRewriter(self, other) + + +@register_object("relax.dpl.PatternRewriter") +class PatternRewriter(ExprRewriter): + def __init__(self, pattern, func): + self.__init_handle_by_constructor__( + ffi.PatternRewriter, + pattern, + func, + ) # type: ignore + + +@register_object("relax.dpl.OrRewriter") +class OrRewriter(ExprRewriter): + def __init__(self, lhs, rhs): + self.__init_handle_by_constructor__( + ffi.OrRewriter, + lhs, + rhs, + ) # type: ignore + + +@register_object("relax.dpl.TupleRewriter") +class TupleRewriter(ExprRewriter): + def __init__(self, patterns, func): + self.__init_handle_by_constructor__( + ffi.TupleRewriter, + patterns, + func, + ) # type: ignore + + def rewrite_call( - pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function + pattern: DFPattern, + rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], + func: Function, ) -> Function: """ Rewrite a function with the given pattern and the rewriter function. 
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index ef9ae775450b..e0beaeb9aade 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -20,11 +20,11 @@ import builtins import functools import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type import tvm from tvm import DataType, relax -from tvm.ir import PrimExpr, VDevice +from tvm.ir import PrimExpr, VDevice, IRModule from tvm.relax import ( Call, Expr, @@ -35,6 +35,7 @@ VarBinding, const, ) +from tvm.relax.dpl import ExprRewriter ############################### Operators ############################### from tvm.relax.op import ( @@ -306,6 +307,48 @@ def func_ret_value(value: Expr) -> None: return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member +def rewriter(rewriter_mod: Union[IRModule, Type]) -> ExprRewriter: + """Define a pattern-rewrite rule + + The IRModule must have two publicly-exposed functions, `pattern` + and `replacement`, where `pattern` and `replacement` have the same + function signature. + + .. code-block:: python + + @R.rewriter + class RewriteAddIntoMultiply: + @R.function + def pattern(A: R.Tensor): + B = A + A + return B + + @R.function + def replacement(A: R.Tensor): + B = A * 2 + return B + + Parameters + ---------- + rewriter_mod: Union[IRModule, Type] + + Either an IRModule that defines a rewrite pattern, or a + TVMScript class that can be parsed into an IRModule. + + Returns + ------- + rewriter: ExprRewriter + + A rewriter object, which can be applied either to a Relax + function or to an entire IRModule. 
+ + """ + if not isinstance(rewriter_mod, IRModule): + rewriter_mod = tvm.script.ir_module(rewriter_mod) + + return ExprRewriter.from_module(rewriter_mod) + + ############################# BindingBlock ############################## @@ -765,6 +808,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "dequantize", "repeat", "reshape", + "rewriter", "tensor_to_shape", "shape_to_tensor", "rocm", diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index e7a7f98b7651..3d35416d941a 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -83,7 +83,8 @@ def parse( The parsed TVMScript program. """ if extra_vars is None: - extra_vars = _default_globals() + extra_vars = {} + extra_vars = {**extra_vars, **_default_globals()} ann = {} if inspect.isfunction(program): diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py index 3edae3f25a33..8ad64f5dbc68 100644 --- a/python/tvm/script/parser/core/utils.py +++ b/python/tvm/script/parser/core/utils.py @@ -100,19 +100,29 @@ def is_defined_in_class(frames: List[FrameType], obj: Any) -> bool: res : bool The result if the object is defined in a class scope. """ + + def _is_tvmscript_class_annotator(line: str) -> bool: + """Checks if the line contains a TVMScript annotator for a class + + These match either `@I.ir_module` or `@R.rewriter`, or their + imported names `@ir_module` or `@rewriter`. 
+ """ + + return line.startswith("@") and ("ir_module" in line or "rewriter" in line) + if len(frames) > 2: frame_info = frames[2] code_context = frame_info.code_context if code_context is None: return False line = code_context[0].strip() - if line.startswith("@") and "ir_module" in line: + if _is_tvmscript_class_annotator(line): return True if line.startswith("class"): lineno = frame_info.lineno if lineno >= 2: source, _ = findsource(obj) line = source[lineno - 2].strip() - if line.startswith("@") and "ir_module" in line: + if _is_tvmscript_class_annotator(line): return True return False diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index c5af0493a54c..d07fedd29715 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -37,6 +37,7 @@ #include #include "dataflow_matcher.h" +#include "dataflow_rewriter.h" namespace tvm { namespace relax { @@ -287,18 +288,21 @@ static std::optional MatchTree( return std::nullopt; } -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, +Optional> MatchGraph(const PatternContext& ctx, + const Array& binding_arr, const Map& bindings) { // TODO(@ganler): Handle non-may external use. ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; DFPatternMatcher matcher(bindings); MatcherUseDefAnalysis ud_analysis; - ud_analysis.VisitBindingBlock_(dfb.get()); + for (const auto& binding : binding_arr) { + ud_analysis.VisitBinding(binding); + } // First construct a graph of PNode and RNode. 
std::unordered_map var2node; - var2node.reserve(dfb->bindings.size()); + var2node.reserve(bindings.size()); for (const VarNode* cur_var : ud_analysis.vars) { const auto& uses = ud_analysis.def2use.at(cur_var); @@ -355,7 +359,7 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl } Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { - return MatchGraph(ctx, dfb, AnalyzeVar2Value(dfb)); + return MatchGraph(ctx, dfb->bindings, AnalyzeVar2Value(dfb)); } TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") @@ -363,124 +367,82 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") return MatchGraph(ctx, dfb); }); -/*! - * \brief Apply pattern matching to each dataflow block, replacing matches - * with the output of a user-provided rewriter function. - */ -class BlockPatternRewriter : ExprMutator { +class PatternContextRewriterNode : public ExprRewriterNode { public: - using ExprMutator::VisitBindingBlock_; - using ExprMutator::VisitExpr_; - - BlockPatternRewriter( - const PatternContext& ctx, - TypedPackedFunc(Map, Map)> rewriter_func) - : ctx_(ctx), rewriter_func_(rewriter_func) {} - - template - static Function Run( - PatternType pat, - TypedPackedFunc(Map, Map)> rewriter_func, - Function func) { - BlockPatternRewriter rewriter(pat, rewriter_func); - - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } + PatternContext pattern; + TypedPackedFunc(Map, Map)> rewriter_func; + + RewriteSpec RewriteBindings(const Array& bindings) const override; - BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override { - return RewriteDataflowBlockFixedPoint(GetRef(block_node)); + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("pattern", &pattern); + PackedFunc untyped_func = rewriter_func; + visitor->Visit("rewriter_func", &untyped_func); } - private: - void EmitUsedVars(Expr val, const Array& pending_bindings, - std::unordered_set* emitted_vars) { - std::unordered_set unemitted_vars; - 
PostOrderVisit(val, [=, &unemitted_vars](Expr e) { - if (auto v = e.as(); v && !emitted_vars->count(v)) { - unemitted_vars.insert(v); - } - }); + static constexpr const char* _type_key = "relax.dpl.PatternContextRewriter"; + TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextRewriterNode, ExprRewriterNode); - if (unemitted_vars.empty()) { - return; + private: + Optional> MatchBindings(const Array& bindings) const { + Map var_lookup; + for (const auto& binding : bindings) { + var_lookup.Set(binding->var, GetBoundValue(binding)); } - size_t num_unemitted = unemitted_vars.size(); - for (size_t i = 0; i < pending_bindings.size(); ++i) { - const auto& binding = pending_bindings[i]; - if (auto var_bind = binding.as(); - var_bind && unemitted_vars.count(var_bind->var.get())) { - // var_bind->value may also depend on other unemitted vars in this range - Array prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); - EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); - this->VisitBinding(binding); - emitted_vars->insert(var_bind->var.get()); - if (--num_unemitted == 0) { - return; - } + if (auto matches = MatchGraph(pattern, bindings, var_lookup)) { + Map replacements = rewriter_func(matches.value(), var_lookup); + if (replacements.size()) { + return replacements; } } + + return NullOpt; } +}; - // Repeat until all matchable subsets of bindings are rewritten. 
- BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { - auto df_block = Downcast(block); - Map bindings = AnalyzeVar2Value(df_block); - if (auto matches = MatchGraph(ctx_, df_block, bindings)) { - builder_->BeginDataflowBlock(); - Map replacements = rewriter_func_(matches.value(), bindings); - - std::unordered_set emitted_vars; - - bool changed = false; - for (size_t i = 0; i < block->bindings.size(); ++i) { - const auto& binding = block->bindings[i]; - if (auto var_bind = binding.as()) { - if (auto new_val = replacements.Get(var_bind->var).value_or(var_bind->value); - !StructuralEqual()(var_bind->value, new_val)) { - Array pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); - // Make sure there is no unbound variable used in the new value before it is emitted - EmitUsedVars(new_val, pending_bindings, &emitted_vars); - this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); - changed = true; - } else if (!emitted_vars.count(var_bind->var.get())) { - this->VisitBinding(binding); - emitted_vars.insert(var_bind->var.get()); - } - } else { - this->VisitBinding(binding); - } - } +class PatternContextRewriter : public ExprRewriter { + public: + PatternContextRewriter( + PatternContext pattern, + TypedPackedFunc(Map, Map)> rewriter_func); - auto new_block = builder_->EndBlock(); + TVM_DEFINE_OBJECT_REF_METHODS(PatternContextRewriter, ExprRewriter, PatternContextRewriterNode); +}; - if (!changed) return new_block; - return RewriteDataflowBlockFixedPoint(new_block); +RewriteSpec PatternContextRewriterNode::RewriteBindings(const Array& bindings) const { + std::vector remaining_bindings{bindings.begin(), bindings.end()}; + + Map variable_rewrites; + while (auto opt = MatchBindings(remaining_bindings)) { + auto new_rewrites = opt.value(); + remaining_bindings.erase(std::remove_if(remaining_bindings.begin(), remaining_bindings.end(), + [&new_rewrites](const Binding& binding) { + return new_rewrites.count(binding->var); + }), + 
remaining_bindings.end()); + for (const auto& [var, expr] : new_rewrites) { + variable_rewrites.Set(var, expr); } - return block; } - /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ - PatternContext ctx_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Map, Map) -> Map - * - * Given the map of patterns and corresponding variables (bound - * variables or parameters), it should return a map that - * specifies new values for matched bound variables. It can refer - * to the passed bindings to create the replacement expressions. - */ - TypedPackedFunc(Map, Map)> rewriter_func_; -}; + return RewriteSpec{variable_rewrites, {}}; +} + +PatternContextRewriter::PatternContextRewriter( + PatternContext pattern, + TypedPackedFunc(Map, Map)> rewriter_func) { + auto node = make_object(); + node->pattern = std::move(pattern); + node->rewriter_func = std::move(rewriter_func); + data_ = std::move(node); +} Function RewriteBindings( const PatternContext& ctx, TypedPackedFunc(Map, Map)> rewriter, Function func) { - return BlockPatternRewriter::Run(ctx, rewriter, func); + // return BlockPatternRewriter::Run(ctx, rewriter, func); + return Downcast(PatternContextRewriter(ctx, rewriter)(func)); } TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 4793d1d75a30..8acaec60c356 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -22,6 +22,7 @@ * \brief A transform to match a Relax Expr and rewrite */ +#include #include #include #include @@ -31,12 +32,754 @@ #include #include +#include + #include "../transform/utils.h" #include "dataflow_matcher.h" +#include "dataflow_rewriter.h" namespace tvm { namespace relax { +namespace { +class GlobalVarReplacer : public ExprMutator { + public: + GlobalVarReplacer(Map gvar_map) : gvar_map_(gvar_map) {} + 
+ using ExprMutator::VisitExpr_; + Expr VisitExpr_(const GlobalVarNode* op) override { + auto gvar = GetRef(op); + if (auto opt = gvar_map_.Get(gvar)) { + gvar = opt.value(); + } + return gvar; + } + + private: + Map gvar_map_; +}; + +Array TopologicalSort(const Array& bindings) { + std::unordered_set remaining_bindings; + for (const auto& binding : bindings) { + remaining_bindings.insert(binding->var); + } + + // Utility structure used to track bindings that are moved later in + // the list. + struct DelayedBinding { + Binding binding; + std::unordered_set unmet_requirements; + bool emitted; + }; + std::vector delayed_bindings; + Array sorted_bindings; + + // Utility function to append the + auto push_sorted_binding = [&](Binding binding) { + sorted_bindings.push_back(binding); + remaining_bindings.erase(binding->var); + for (auto& delayed_binding : delayed_bindings) { + delayed_binding.unmet_requirements.erase(binding->var); + } + }; + + bool required_sorting = false; + for (const auto& binding : bindings) { + // Collect any variables used by this binding, but are emitted by + // a later binding. + std::unordered_set unmet_requirements; + for (auto free_var : FreeVars(GetBoundValue(binding))) { + if (remaining_bindings.count(free_var)) { + unmet_requirements.insert(free_var); + } + } + + if (unmet_requirements.empty()) { + push_sorted_binding(binding); + } else { + required_sorting = true; + delayed_bindings.push_back(DelayedBinding{binding, unmet_requirements, false}); + } + + bool requires_delayed_binding_check = true; + while (requires_delayed_binding_check) { + requires_delayed_binding_check = false; + for (auto& delayed_binding : delayed_bindings) { + if (!delayed_binding.emitted && delayed_binding.unmet_requirements.empty()) { + // If we find a delayed binding that can be emitted, mark it + // as emitted and push to the sorted list. 
This may + delayed_binding.emitted = true; + requires_delayed_binding_check = true; + push_sorted_binding(delayed_binding.binding); + + // The break is not necessary for a topological sort, but is + // necessary to minimize the amount of re-ordering that is + // performed. With this break, the next binding is always + // the earliest binding that is legal to emit at this point. + break; + } + } + } + + // Remove any delayed bindings that have been emitted, now that we + // are done iterating over the delayed bindings. + delayed_bindings.erase( + std::remove_if(delayed_bindings.begin(), delayed_bindings.end(), + [](const auto& delayed_binding) { return delayed_binding.emitted; }), + delayed_bindings.end()); + } + + // All bindings should be emitted by this point. If any remain, + // then there exists a circular dependency somewhere in the + // remaining bindings. + CHECK(delayed_bindings.empty()) << "ValueError: " + << "Bindings contain circular dependency"; + + if (required_sorting) { + return sorted_bindings; + } else { + return bindings; + } +} +} // namespace + +void RewriteSpec::Append(RewriteSpec other) { + if (variable_rewrites.empty()) { + *this = std::move(other); + return; + } + if (other.variable_rewrites.empty()) { + return; + } + + NameSupply gvar_name_supply(""); + for (const auto& [gvar, func] : new_subroutines) { + gvar_name_supply->ReserveName(gvar->name_hint); + } + + Map gvar_rewrites; + for (auto [gvar, func] : other.new_subroutines) { + if (auto it = new_subroutines.find(gvar); it != new_subroutines.end()) { + // The two rewrites provide the same GlobalVar. + // (e.g. Multiple rewrites of the same pattern.) Ensure that + // they are referring to the same underlying BaseFunc. + CHECK(func.same_as((*it).second)); + } else if (auto new_name = gvar_name_supply->FreshName(gvar->name_hint); + new_name != gvar->name_hint) { + // The two rewrites provide distinct GlobalVar subroutines, + // but with conflicting names. 
Because an IRModule must have + // enough names for each GlobalVar, even if they are not + // publicly exposed, one of the GlobalVars must be replaced. + // Replacing the GlobalVar here, when the conflict is first + // identified, minimizes the size of the `relax::Expr` that + // must be updated with `GlobalVarReplacer`. + GlobalVar new_gvar = gvar; + new_gvar.CopyOnWrite()->name_hint = new_name; + gvar_rewrites.Set(gvar, new_gvar); + new_subroutines.Set(new_gvar, func); + } else { + new_subroutines.Set(gvar, func); + } + } + + for (auto [var, expr] : other.variable_rewrites) { + if (gvar_rewrites.size()) { + expr = GlobalVarReplacer(gvar_rewrites)(expr); + } + variable_rewrites.Set(var, expr); + } +} + +TVM_REGISTER_NODE_TYPE(ExprRewriterNode); + +TVM_REGISTER_GLOBAL("relax.dpl.ExprRewriterFromPattern") + .set_body_typed([](DFPattern pattern, + TypedPackedFunc(Expr, Map)> func) { + return ExprRewriter::FromPattern(pattern, func); + }); + +TVM_REGISTER_GLOBAL("relax.dpl.ExprRewriterFromModule").set_body_typed([](IRModule mod) { + return ExprRewriter::FromModule(mod); +}); + +TVM_REGISTER_GLOBAL("relax.dpl.ExprRewriterApply") + .set_body_typed([](ExprRewriter rewriter, + Variant obj) -> Variant { + if (auto expr = obj.as()) { + return rewriter(expr.value()); + } else if (auto mod = obj.as()) { + return rewriter(mod.value()); + } else { + LOG(FATAL) << "Unreachable: object does not contain either variant type"; + } + }); + +TVM_REGISTER_NODE_TYPE(PatternRewriterNode); + +RewriteSpec PatternRewriterNode::RewriteBindings(const Array& bindings) const { + Map variable_rewrites; + Map binding_lookup; + for (const auto& binding : bindings) { + auto bound_value = GetBoundValue(binding); + if (auto new_expr = RewriteExpr(bound_value, binding_lookup)) { + variable_rewrites.Set(binding->var, new_expr.value()); + } else { + binding_lookup.Set(binding->var, bound_value); + } + } + if (variable_rewrites.size()) { + return RewriteSpec{variable_rewrites, new_subroutines}; + } else 
{ + return RewriteSpec(); + } +} + +Optional PatternRewriterNode::RewriteExpr(const Expr& expr, + const Map& bindings) const { + if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings)) { + auto matches = opt_matches.value(); + if (additional_bindings) { + // Append any additional matches that from the unwrapped + // `OrPattern`. When matching against `pat = pat_lhs | + // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and + // `pat_rhs` separately. The top-level `pat` is never seen by + // `ExtractMatchedExpr`, and must be re-added afterward. + auto matched_expr = DFPatternMatcher::UnwrapBindings(expr, bindings); + for (const auto& pat : additional_bindings.value()) { + matches.Set(pat, matched_expr); + } + } + + Optional rewritten_expr = func(expr, matches); + if (rewritten_expr.defined() && !rewritten_expr.same_as(expr)) { + return rewritten_expr.value(); + } + } + return NullOpt; +} + +TVM_REGISTER_GLOBAL("relax.dpl.PatternRewriter") + .set_body_typed([](DFPattern pattern, + TypedPackedFunc(Expr, Map)> func) { + return PatternRewriter(pattern, func); + }); + +PatternRewriter::PatternRewriter(DFPattern pattern, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, + Map new_subroutines) { + auto node = make_object(); + node->pattern = std::move(pattern); + node->func = std::move(func); + node->additional_bindings = std::move(additional_bindings); + node->new_subroutines = std::move(new_subroutines); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(OrRewriterNode); + +RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) const { + auto lhs_match = lhs->RewriteBindings(bindings); + if (!lhs_match) { + // If no rewrites found on LHS, RHS is allowed to modify any + // variable binding. + return rhs->RewriteBindings(bindings); + } + + // The LHS matched some subset of the bindings. 
These + // replacements may not be normalized expressions, so the RHS may + // only replace variable bindings that haven't been modified by + // the LHS. Variable replacements from the RHS may still occur, + // but will need to wait for the next round of + // iterate-until-converged. + Array remaining_bindings; + for (const auto& binding : bindings) { + if (!lhs_match.variable_rewrites.count(binding->var)) { + remaining_bindings.push_back(binding); + } + } + + if (remaining_bindings.empty()) { + // Early bail-out, the RHS has no bindings available to rewrite. + return lhs_match; + } + + lhs_match.Append(rhs->RewriteBindings(remaining_bindings)); + return lhs_match; +} + +TVM_REGISTER_GLOBAL("relax.dpl.OrRewriter").set_body_typed([](ExprRewriter lhs, ExprRewriter rhs) { + return OrRewriter(lhs, rhs); +}); + +OrRewriter::OrRewriter(ExprRewriter lhs, ExprRewriter rhs) { + auto node = make_object(); + node->lhs = std::move(lhs); + node->rhs = std::move(rhs); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(TupleRewriterNode); + +RewriteSpec TupleRewriterNode::RewriteBindings(const Array& bindings) const { + CHECK_LE(patterns.size(), 3) << "For performance reasons, " + << "matching of implicit tuple patterns is currently limited" + << " to tuples with 3 elements or fewer."; + Map variable_rewrites = GenerateVariableRewrites(bindings); + + if (variable_rewrites.size()) { + return RewriteSpec{variable_rewrites, new_subroutines}; + } else { + return RewriteSpec(); + } +} + +Map TupleRewriterNode::GenerateVariableRewrites(const Array& bindings) const { + Map rewrites; + + Map binding_lookup; + + std::vector info_vec; + + std::unordered_map binding_index_lookup; + + // Initialize a vector of indices, each of which corresponds to a + // potential match for a tuple element. + // + // \param tuple_index_of_current_expr The index for the most recent + // binding. + // + // \param indices An output vector, into which indices will be + // generated. 
+ // + // \returns bool True if the indices could be initialized to a + // potential match. False, otherwise. + auto initialize_indices = [&](size_t tuple_index_of_current_expr, + std::vector& indices) -> bool { + if (!info_vec.back().matches[tuple_index_of_current_expr]) { + return false; + } + + indices = std::vector(patterns.size(), info_vec.size()); + + indices[tuple_index_of_current_expr] = info_vec.size() - 1; + + for (size_t i_rev = 0; i_rev < indices.size(); i_rev++) { + size_t i = indices.size() - i_rev - 1; + if (indices[i] == info_vec.size() - 1) { + continue; + } + + auto binding_index = [&]() -> std::optional { + if (indices[i] == info_vec.size() - 1) { + return info_vec.size() - 1; + } + + for (size_t j_rev = 1; j_rev < info_vec.size(); j_rev++) { + size_t j = info_vec.size() - j_rev - 1; + if (info_vec[j].matches[i] && !info_vec[j].used && + std::all_of(indices.begin() + (j + 1), indices.end(), + [j](size_t prev_binding_index) { return j != prev_binding_index; })) { + return j; + } + } + + return std::nullopt; + }(); + + if (binding_index.has_value()) { + indices[i] = binding_index.value(); + } else { + return false; + } + } + + return true; + }; + + auto decrement_indices = [&](std::vector& indices) -> bool { + ICHECK_EQ(indices.size(), patterns.size()); + + // Step 1, find the first index that can be decremented, while + // still generating a valid set of indices. 
+ size_t i_forward; + for (i_forward = 0; i_forward < indices.size(); i_forward++) { + if (indices[i_forward] == info_vec.size() - 1) { + continue; + } + + bool found_valid = false; + size_t& index = indices[i_forward]; + while (index) { + index--; + if (info_vec[index].matches[i_forward] && !info_vec[index].used && + std::all_of( + indices.begin() + (i_forward + 1), indices.end(), + [index](size_t later_binding_index) { return index != later_binding_index; })) { + found_valid = true; + break; + } + } + if (found_valid) { + break; + } + } + + // Step 2, if we reached the end, then all indices were + // decremented to zero without finding anything. Return false to + // indicate that we've reached the end. + if (i_forward == indices.size()) { + return false; + } + + // Step 3, refill all indices that were decremented to zero before from 0 to + for (size_t i = 0; i < i_forward; i++) { + size_t i_backward = i_forward - (i + 1); + if (indices[i_backward] == info_vec.size() - 1) { + continue; + } + + auto binding_index = [&]() -> std::optional { + for (size_t j_rev = 1; j_rev < info_vec.size(); j_rev++) { + size_t j = info_vec.size() - j_rev - 1; + if (info_vec[j].matches[i_backward] && !info_vec[j].used && + std::all_of(indices.begin() + (j + 1), indices.end(), + [j](size_t prev_binding_index) { return j != prev_binding_index; })) { + return j; + } + } + + return std::nullopt; + }(); + + if (binding_index.has_value()) { + indices[i_backward] = binding_index.value(); + } else { + return false; + } + } + + return true; + }; + + for (size_t i_binding = 0; i_binding < bindings.size(); i_binding++) { + const auto& binding = bindings[i_binding]; + + auto expr = GetBoundValue(binding); + + binding_index_lookup[binding->var] = i_binding; + + info_vec.push_back(VarInfo{ + binding->var, + expr, + patterns.Map( + [&](const DFPattern& pat) { return ExtractMatchedExpr(pat, expr, binding_lookup); }), + std::unordered_set(), + false, + }); + + auto new_match = [&]() -> std::optional, 
std::vector>> { + std::vector indices; + for (size_t i = 0; i < patterns.size(); i++) { + if (initialize_indices(patterns.size() - i - 1, indices)) { + do { + if (auto match = TryMatchByBindingIndex(info_vec, indices)) { + return std::pair{indices, match.value()}; + } + } while (decrement_indices(indices)); + } + } + return std::nullopt; + }(); + + if (new_match) { + const auto& [indices, exprs] = new_match.value(); + ICHECK_EQ(indices.size(), exprs.size()); + for (size_t i = 0; i < indices.size(); i++) { + ICHECK_LT(indices[i], info_vec.size()); + auto& info = info_vec[indices[i]]; + + ICHECK(!info.used) << "InternalError: " + << "Produced multiple replacements for variable " << info.var; + + rewrites.Set(info.var, exprs[i]); + binding_lookup.erase(info.var); + info.used = true; + } + } else { + binding_lookup.Set(binding->var, expr); + } + + for (const auto& prev_var : FreeVars(expr)) { + if (auto it = binding_index_lookup.find(prev_var); it != binding_index_lookup.end()) { + info_vec[it->second].downstream_usage.insert(binding->var); + } + } + } + + return rewrites; +} + +std::optional> TupleRewriterNode::TryMatchByBindingIndex( + const std::vector& info_vec, const std::vector& indices) const { + ICHECK_GE(indices.size(), 1); + + ICHECK_EQ(indices.size(), patterns.size()); + for (size_t i = 0; i < indices.size(); i++) { + const auto& info = info_vec[indices[i]]; + if (info.used || !info.matches[i]) { + return std::nullopt; + } + } + + Map merged_matches = info_vec[indices[0]].matches[0].value(); + for (size_t i = 1; i < indices.size(); i++) { + for (const auto& [pat, expr] : info_vec[indices[i]].matches[i].value()) { + if (auto it = merged_matches.find(pat); it != merged_matches.end()) { + if (!StructuralEqual()(expr, (*it).second)) { + return std::nullopt; + } + } else { + merged_matches.Set(pat, expr); + } + } + } + + bool tuple_element_is_already_used_outside_of_matched_tuple = [&]() -> bool { + std::unordered_set matched_vars; + for (const auto& [pat, expr] 
: merged_matches) { + if (auto opt = expr.as()) { + matched_vars.insert(opt.value()); + } + } + + for (size_t index : indices) { + const auto& downstream_of_rewritten_var = info_vec[index].downstream_usage; + + for (const auto& uses_matched_var : downstream_of_rewritten_var) { + if (!matched_vars.count(uses_matched_var)) { + return true; + } + } + } + + return false; + }(); + if (tuple_element_is_already_used_outside_of_matched_tuple) { + return std::nullopt; + } + + auto full_tuple = [&]() -> relax::Expr { + Array fields; + for (size_t index : indices) { + fields.push_back(info_vec[index].expr); + } + return relax::Tuple(fields); + }(); + + auto opt_rewritten = func(full_tuple, merged_matches); + if (!opt_rewritten) { + return std::nullopt; + } + auto rewritten = opt_rewritten.value(); + + if (rewritten.same_as(full_tuple)) { + return std::nullopt; + } + + std::vector rewrites; + if (auto inline_tuple = rewritten.as()) { + const auto& fields = inline_tuple->fields; + CHECK_EQ(fields.size(), indices.size()) + << "Expected to receive " << indices.size() << " values to replace TuplePattern with " + << indices.size() << " fields, but received " << fields.size() << " values"; + rewrites = {fields.begin(), fields.end()}; + } else { + for (size_t i = 0; i < indices.size(); i++) { + rewrites.push_back(TupleGetItem(rewritten, i)); + } + } + return rewrites; +} + +TVM_REGISTER_GLOBAL("relax.dpl.TupleRewriter") + .set_body_typed([](Array patterns, + TypedPackedFunc(Expr, Map)> func) { + return TupleRewriter(patterns, func); + }); + +TupleRewriter::TupleRewriter(Array patterns, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, + Map new_subroutines) { + auto node = make_object(); + node->patterns = std::move(patterns); + node->func = std::move(func); + node->additional_bindings = std::move(additional_bindings); + node->new_subroutines = std::move(new_subroutines); + data_ = std::move(node); +} + +ExprRewriter ExprRewriter::FromPattern( + DFPattern pattern, 
TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, Map new_subroutines) { + if (auto or_pattern = pattern.as()) { + auto new_additional_bindings = additional_bindings.value_or({}); + new_additional_bindings.push_back(pattern); + return OrRewriter( + ExprRewriter::FromPattern(or_pattern->left, func, new_additional_bindings, new_subroutines), + ExprRewriter::FromPattern(or_pattern->right, func, new_additional_bindings, + new_subroutines)); + } else if (auto tuple_pattern = pattern.as()) { + auto new_additional_bindings = additional_bindings.value_or({}); + new_additional_bindings.push_back(pattern); + // If the Tuple appears as a Relax binding, apply it first. As a + // fallback, also check for implicit tuples. + return OrRewriter( + PatternRewriter(pattern, func, additional_bindings, new_subroutines), + TupleRewriter(tuple_pattern->fields, func, new_additional_bindings, new_subroutines)); + } else { + return PatternRewriter(pattern, func, additional_bindings, new_subroutines); + } +} + +ExprRewriter ExprRewriter::FromModule(IRModule mod) { + Function func_pattern = [&]() { + CHECK(mod->ContainGlobalVar("pattern")) + << "KeyError: " + << "Expected module to contain 'pattern', " + << "a Relax function defining the pattern to be matched, " + << "but the module did not contain a 'pattern' function."; + auto base_func = mod->Lookup("pattern"); + CHECK(base_func->IsInstance()) + << "TypeError: " + << "Expected module to contain 'pattern', " + << "a Relax function defining the pattern to be matched, " + << "but the 'pattern' function was of type " << base_func->GetTypeKey() << "."; + return Downcast(base_func); + }(); + Function func_replacement = [&]() { + CHECK(mod->ContainGlobalVar("replacement")) + << "KeyError: " + + << "Expected module to contain 'replacement', " + << "a Relax function defining the replacement to be matched, " + << "but the module did not contain a 'replacement' function."; + auto base_func = mod->Lookup("replacement"); + 
CHECK(base_func->IsInstance()) + << "TypeError: " + << "Expected module to contain 'replacement', " + << "a Relax function defining the replacement to be made on a successful match, " + << "but the 'replacement' function was of type " << base_func->GetTypeKey() << "."; + return Downcast(base_func); + }(); + + Map new_subroutines; + for (const auto& [gvar, func] : mod->functions) { + if (gvar->name_hint != "pattern" && gvar->name_hint != "replacement") { + bool is_public = func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + CHECK(!is_public) << "ValueError: " + << "Expected module to have no publicly-exposed functions " + << "other than 'pattern' and 'replacement'. " + << "However, function '" << gvar->name_hint << "' of type " + << func->GetTypeKey() << " is publicly exposed."; + new_subroutines.Set(gvar, func); + } + } + + auto sinfo_pattern = GetStructInfo(func_pattern); + auto sinfo_replacement = GetStructInfo(func_replacement); + CHECK(StructuralEqual()(sinfo_pattern, sinfo_replacement)) + << "ValueError: " + << "The pattern and replacement must have the same signature, " + << "but the pattern has struct info " << sinfo_pattern + << ", while the replacement has struct info " << sinfo_replacement; + + Array param_wildcards; + Map pattern_lookup; + for (const auto& param : func_pattern->params) { + WildcardPattern wildcard; + param_wildcards.push_back(wildcard); + pattern_lookup.Set(param, StructInfoPattern(wildcard, GetStructInfo(param))); + } + + std::function make_pattern = [&](Expr expr) -> DFPattern { + if (auto var = expr.as()) { + return pattern_lookup[var.value()]; + + } else if (auto call = expr.as()) { + auto op = make_pattern(call->op); + auto args = call->args.Map(make_pattern); + return CallPattern(op, args); + + } else if (auto tuple = expr.as()) { + auto fields = tuple->fields.Map(make_pattern); + return TuplePattern(fields); + + } else if (auto tuple_get_item = expr.as()) { + auto tuple = make_pattern(tuple_get_item->tuple); + return 
TupleGetItemPattern(tuple, tuple_get_item->index); + + } else if (auto op = expr.as()) { + return ExprPattern(op.value()); + + } else if (auto func = expr.as()) { + return ExternFuncPattern(func->global_symbol); + + } else { + LOG(FATAL) << "TypeError: " + << "Cannot convert Relax expression of type " << expr->GetTypeKey() + << " into pattern-matching rule."; + } + }; + + for (const auto& block : func_pattern->body->blocks) { + for (const auto& binding : block->bindings) { + auto value_pattern = make_pattern(GetBoundValue(binding)); + if (auto match_cast = binding.as()) { + value_pattern = StructInfoPattern(value_pattern, match_cast->struct_info); + } + pattern_lookup.Set(binding->var, value_pattern); + } + } + + DFPattern top_pattern = make_pattern(func_pattern->body->body); + + TypedPackedFunc(Expr, Map)> rewriter_func = + [param_wildcards = std::move(param_wildcards), + orig_func_replacement = std::move(func_replacement)]( + Expr expr, Map matches) -> Optional { + auto func_replacement = CopyWithNewVars(orig_func_replacement); + + Array new_blocks; + + Array wildcard_bindings; + ICHECK_EQ(param_wildcards.size(), func_replacement->params.size()); + for (size_t i = 0; i < param_wildcards.size(); i++) { + Expr matched_expr = matches[param_wildcards[i]]; + + // Introduce an intermediate variable, to ensure that the + // MatchCast's target will be a Var, even for expressions that + // wouldn't normally be normalized into a variable. 
+ Var intermediate_var("intermediate_var", GetStructInfo(matched_expr)); + wildcard_bindings.push_back(VarBinding(intermediate_var, matched_expr)); + wildcard_bindings.push_back( + MatchCast(func_replacement->params[i], intermediate_var, GetStructInfo(matched_expr))); + } + + new_blocks.push_back(DataflowBlock(wildcard_bindings)); + + for (const auto& block : func_replacement->body->blocks) { + new_blocks.push_back(block); + } + + return SeqExpr(new_blocks, func_replacement->body->body); + }; + + return ExprRewriter::FromPattern(top_pattern, rewriter_func, NullOpt, new_subroutines); +} + Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { auto bindings = bindings_opt.value_or({}); @@ -46,12 +789,7 @@ Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, return NullOpt; } - Map matching; - for (const auto& [pat, matches] : matcher.GetMemo()) { - ICHECK_EQ(matches.size(), 1) << "More than one match for the pattern " << pat; - matching.Set(pat, matches[0]); - } - return matching; + return matcher.GetMemo(); } TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); @@ -66,34 +804,23 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); * \brief Apply pattern matching to each expression, replacing * matches with the output of a user-provided rewriter function. 
*/ -class ExprPatternRewriter : ExprMutator { +class ExprPatternRewriter : public ExprMutator { public: using ExprMutator::VisitExpr_; - ExprPatternRewriter(DFPattern pat, - TypedPackedFunc)> rewriter_func) - : pattern_(pat), rewriter_func_(rewriter_func) {} + ExprPatternRewriter(const ExprRewriterNode* rewriter) : rewriter_(rewriter) {} - template - static Function Run(PatternType pat, - TypedPackedFunc)> rewriter_func, - Function func) { - ExprPatternRewriter rewriter(pat, rewriter_func); - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } + Map GetNewSubroutines() const { return new_subroutines_; } Expr VisitExpr_(const SeqExprNode* seq) override { - auto cache = bindings_; - SeqExpr prev = GetRef(seq); + SeqExpr prev = Downcast(ExprMutator::VisitExpr_(seq)); StructuralEqual struct_equal; - while (true) { - SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); + while (auto opt = TryRewriteSeqExpr(prev)) { + SeqExpr next = Downcast(builder_->Normalize(opt.value())); if (struct_equal(prev, next)) { - return std::move(next); + break; } // Canonicalization may result in two previously-different @@ -112,108 +839,235 @@ class ExprPatternRewriter : ExprMutator { } if (struct_equal(prev, next)) { - return std::move(next); + break; } - // Reset all knowledge of bindings that were collected from - // this SeqExpr. The collected bindings are only after - // the point where they were collected, and we are repeating - // the mutation of this SeqExpr. 
- bindings_ = cache; prev = next; } - } - void VisitBinding_(const VarBindingNode* binding) override { - auto expr = VisitExpr(binding->value); - bindings_.Set(binding->var, expr); - ReEmitBinding(binding, expr); + return prev; } - Expr VisitExpr(const Expr& expr) override { - auto node = ExprMutator::VisitExpr(expr); + Optional TryRewriteSeqExpr(const SeqExpr& seq) { + Array old_blocks = seq->blocks; + + // If the SeqExpr's output is not a variable, treat it as if it + // were the last variable binding of the last block. This + // simplifies the special handling of the SeqExpr's body. + Optional dummy_output_var = NullOpt; + if (!seq->body->IsInstance()) { + dummy_output_var = Var("dummy_output_var", GetStructInfo(seq->body)); + VarBinding dummy_binding(dummy_output_var.value(), seq->body); + + auto last_block = [&]() { + if (seq->blocks.size()) { + auto last_block = old_blocks.back(); + old_blocks.pop_back(); + return last_block; + } else { + return BindingBlock(Array{}); + } + }(); - std::vector matches_top_level; - if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) { - return builder_->Normalize(rewritten.value()); + last_block.CopyOnWrite()->bindings.push_back(dummy_binding); + old_blocks.push_back(last_block); } - return node; - } + auto rewrite_block = [&](Array orig_bindings) -> Array { + auto rewrites = rewriter_->RewriteBindings(orig_bindings); + if (!rewrites) return orig_bindings; - private: - Optional TryRewrite(const Expr& expr, const DFPattern& pattern, - std::vector* matches_top_level) { - ICHECK(matches_top_level); - - // Special handling if the user-supplied pattern is a `OrPattern`. - // While the `ExtractMatchedExpr` can handle matching the - // `OrPattern`, it will return on the first match, even if the - // `rewriter_func_` doesn't apply a replacement. Unpacking the - // `OrPattern` here allows the match to be resumed if - // `rewriter_func_` returns the original function unmodified. 
- // This is only valid for a top-level match. - if (auto or_pattern = pattern.as()) { - matches_top_level->push_back(pattern); - Optional output = TryRewrite(expr, or_pattern->left, matches_top_level); - if (!output.defined()) { - output = TryRewrite(expr, or_pattern->right, matches_top_level); - } - matches_top_level->pop_back(); - return output; - } - - if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) { - auto matches = opt_matches.value(); + for (auto [gvar, func] : rewrites.new_subroutines) { + new_subroutines_.Set(gvar, func); + } - // Append any additional matches that from the unwrapped - // `OrPattern`. When matching against `pat = pat_lhs | - // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and - // `pat_rhs` separately. The top-level `pat` is never seen by - // `ExtractMatchedExpr`, and must be re-added afterward. - if (matches_top_level->size()) { - auto matched_expr = DFPatternMatcher::UnwrapBindings(expr, bindings_); - for (const auto& pat : *matches_top_level) { - matches.Set(pat, matched_expr); + auto bindings = orig_bindings.Map([&](Binding binding) -> Binding { + if (auto new_expr = rewrites.variable_rewrites.Get(binding->var)) { + if (auto match_cast = binding.as()) { + return MatchCast(binding->var, new_expr.value(), match_cast->struct_info); + } else { + return VarBinding(binding->var, new_expr.value()); + } + } else { + return binding; } + }); + + if (bindings.same_as(orig_bindings)) { + return orig_bindings; } - Expr rewritten_expr = rewriter_func_(expr, matches); - if (!rewritten_expr.same_as(expr)) { - return builder_->Normalize(rewritten_expr); + // The rewriter may have introduced additional dependencies + // between computations. Since pattern-matching only occurs + // within blocks that may be re-ordered, these can be resolved + // by performing a topological sort. 
+ bindings = TopologicalSort(bindings); + + return bindings; + }; + + // Utility function to return the rewrites that should be applied + // to a given block. + auto get_rewrites = [&](BindingBlock block) -> Array { + if (block.as()) { + // Early return for DataflowBlock. Since neither control flow + // nor impure functions are allowed within the dataflow block, + // all bindings may be considered at the same time. + return rewrite_block(block->bindings); } + + RewriteSpec rewrites; + + Array collected_bindings; + Array finalized_bindings; + + auto handle_collected_rewrites = [&]() { + if (collected_bindings.size()) { + auto bindings = rewrite_block(collected_bindings); + if (finalized_bindings.empty()) { + finalized_bindings = bindings; + } else { + for (const auto& binding : bindings) { + finalized_bindings.push_back(binding); + } + } + collected_bindings.clear(); + } + }; + + for (const auto& binding : block->bindings) { + auto value = GetBoundValue(binding); + bool is_dataflow = (!value.as()) && + (!(value.as() && IsImpureCall(Downcast(value)))); + if (is_dataflow) { + // This binding satisfies the dataflow constraints. + collected_bindings.push_back(binding); + } else { + // This binding does not satisfy the dataflow constraints. + // Any operations prior to this binding should be checked + // for pattern-match replacements. + handle_collected_rewrites(); + finalized_bindings.push_back(binding); + } + } + + // Check for rewrites in dataflow operations after the last + // non-dataflow segment. + handle_collected_rewrites(); + + return finalized_bindings; + }; + + // Utility function, check for and apply rewrites to a single + // block. 
+ auto visit_block = [&](BindingBlock old_block) -> BindingBlock { + auto new_bindings = get_rewrites(old_block); + if (new_bindings.same_as(old_block->bindings)) { + return old_block; + } + + if (old_block.as()) { + builder_->BeginDataflowBlock(); + } else { + builder_->BeginBindingBlock(); + } + + for (const auto& binding : new_bindings) { + auto value = builder_->Normalize(GetBoundValue(binding)); + + if (binding.as()) { + builder_->EmitNormalized(VarBinding(binding->var, value)); + } else if (auto match_cast = binding.as()) { + builder_->EmitNormalized(MatchCast(binding->var, value, match_cast->struct_info)); + } else { + LOG(FATAL) << "Binding must be either VarBinding or MatchCast"; + } + } + return builder_->EndBlock(); + }; + + auto new_blocks = old_blocks.Map(visit_block); + if (old_blocks.same_as(new_blocks)) { + return NullOpt; } - return NullOpt; + // Restore the body of the SeqExpr, if needed. + auto new_body = [&]() -> Expr { + if (dummy_output_var) { + auto last_block = new_blocks.back(); + new_blocks.pop_back(); + + auto last_binding = last_block->bindings.back(); + last_block.CopyOnWrite()->bindings.pop_back(); + ICHECK(last_binding->var.same_as(dummy_output_var)); + + if (last_block->bindings.size()) { + new_blocks.push_back(last_block); + } + + return GetBoundValue(last_binding); + } else { + return seq->body; + } + }(); + + return SeqExpr(new_blocks, new_body); } - /*! \brief The pattern for rewriting call nodes */ - DFPattern pattern_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Call, Map) -> Call - * - * Given the matched call node and the map of patterns and - * matched expressions, it should return a new call node to - * replace the original one or the original matched call node as - * is. - */ - TypedPackedFunc)> rewriter_func_; - - /*! \brief The known variable bindings - * - * The variable bindings whose value is known. 
This must be tracked - * separately from the block builder, so that it can be reset after - * each iteration of the mutate-until-converged loop applied to - * `SeqExpr`. - */ - Map bindings_; + private: + const ExprRewriterNode* rewriter_; + Map new_subroutines_; }; +Expr ExprRewriter::operator()(Expr expr) { + ExprPatternRewriter mutator(get()); + auto new_expr = mutator(expr); + auto new_subroutines = mutator.GetNewSubroutines(); + CHECK_EQ(new_subroutines.size(), 0) + << "If ExprRewriter provides subroutines, " + << "then it must be applied to an entire IRModule. " + << "However, ExprRewriter produced subroutines " << [&]() -> Array { + std::vector vec; + for (const auto& [gvar, func] : new_subroutines) { + vec.push_back(gvar); + } + std::sort(vec.begin(), vec.end(), + [](const GlobalVar& a, const GlobalVar& b) { return a->name_hint < b->name_hint; }); + return vec; + }() << "when applied to " + << "Relax expression of type " << expr->GetTypeKey(); + return new_expr; +} + +IRModule ExprRewriterNode::operator()(IRModule mod, + const tvm::transform::PassContext& pass_ctx) const { + ExprPatternRewriter mutator(this); + + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + auto rewritten = Downcast(mutator(func.value())); + if (!rewritten.same_as(base_func)) { + updates->Add(gvar, rewritten); + } + } + } + + if (updates->functions.size()) { + auto write_ptr = mod.CopyOnWrite(); + write_ptr->Update(updates); + write_ptr->Update(IRModule(mutator.GetNewSubroutines())); + } + + return mod; +} +tvm::transform::PassInfo ExprRewriterNode::Info() const { + return tvm::transform::PassInfo(0, "ExprRewriter", {}, false); +} + Function RewriteCall(const DFPattern& pat, TypedPackedFunc)> rewriter, Function func) { - return ExprPatternRewriter::Run(pat, rewriter, func); + return Downcast(ExprRewriter::FromPattern(pat, rewriter)(func)); } TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); diff 
--git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 989c1174f41d..b6994f017466 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -101,14 +101,13 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr auto expr = UnwrapBindings(expr0, var2val_); if (memoize_ && memo_.count(pattern)) { - ICHECK_EQ(memo_[pattern].size(), 1); - return expr.same_as(memo_[pattern][0]); + return expr.same_as(memo_[pattern]); } else { PrimExpr cached_condition = symbolic_expr_condition_; size_t watermark = matched_nodes_.size(); bool out = DFPatternFunctor::VisitDFPattern(pattern, expr); if (out) { - memo_[pattern].push_back(expr); + memo_[pattern] = expr; matched_nodes_.push_back(pattern); } else { ClearMap(watermark); diff --git a/src/relax/ir/dataflow_matcher.h b/src/relax/ir/dataflow_matcher.h index 9036c7630a54..93141af81c7c 100644 --- a/src/relax/ir/dataflow_matcher.h +++ b/src/relax/ir/dataflow_matcher.h @@ -43,7 +43,7 @@ class DFPatternMatcher : public DFPatternFunctor> GetMemo() { return Map>(memo_); } + Map GetMemo() { return memo_; } /* \brief Unwrap trivial expressions/bindings */ static Expr UnwrapBindings(Expr expr, const Map& bindings); @@ -91,7 +91,7 @@ class DFPatternMatcher : public DFPatternFunctor, ObjectPtrHash, ObjectPtrEqual> memo_; + std::unordered_map memo_; var2val_t var2val_; std::vector matched_nodes_; PrimExpr symbolic_expr_condition_{Bool(true)}; diff --git a/src/relax/ir/dataflow_rewriter.h b/src/relax/ir/dataflow_rewriter.h new file mode 100644 index 000000000000..d26695a7ce52 --- /dev/null +++ b/src/relax/ir/dataflow_rewriter.h @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/dataflow_rewriter.h + * \brief Pattern match/rewriters for Relax + */ +#ifndef TVM_RELAX_IR_DATAFLOW_REWRITER_H_ +#define TVM_RELAX_IR_DATAFLOW_REWRITER_H_ + +#include +#include +#include +#include + +#include + +#include "dataflow_matcher.h" + +namespace tvm { +namespace relax { + +struct RewriteSpec { + Map variable_rewrites; + Map new_subroutines; + + explicit operator bool() const { return variable_rewrites.size(); } + + void Append(RewriteSpec other); +}; + +class ExprRewriterNode : public tvm::transform::PassNode { + public: + virtual RewriteSpec RewriteBindings(const Array& bindings) const { + return RewriteSpec(); + } + + void VisitAttrs(AttrVisitor* visitor) {} + + IRModule operator()(IRModule mod, const tvm::transform::PassContext& pass_ctx) const override; + tvm::transform::PassInfo Info() const override; + + static constexpr const char* _type_key = "relax.dpl.ExprRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(ExprRewriterNode, PassNode); +}; + +class ExprRewriter : public tvm::transform::Pass { + public: + static ExprRewriter FromPattern(DFPattern pattern, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); + + static ExprRewriter FromModule(IRModule mod); + + Expr operator()(Expr expr); + using Pass::operator(); + + TVM_DEFINE_OBJECT_REF_METHODS(ExprRewriter, Pass, ExprRewriterNode); +}; + 
+class PatternRewriterNode : public ExprRewriterNode { + public: + DFPattern pattern; + TypedPackedFunc(Expr, Map)> func; + Optional> additional_bindings; + Map new_subroutines; + + RewriteSpec RewriteBindings(const Array& bindings) const final; + + Optional RewriteExpr(const Expr& expr, const Map& bindings) const; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("pattern", &pattern); + PackedFunc untyped_func = func; + visitor->Visit("func", &untyped_func); + } + + static constexpr const char* _type_key = "relax.dpl.PatternRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(PatternRewriterNode, ExprRewriterNode); +}; + +class PatternRewriter : public ExprRewriter { + public: + PatternRewriter(DFPattern pattern, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); + + TVM_DEFINE_OBJECT_REF_METHODS(PatternRewriter, ExprRewriter, PatternRewriterNode); +}; + +class OrRewriterNode : public ExprRewriterNode { + public: + ExprRewriter lhs; + ExprRewriter rhs; + + RewriteSpec RewriteBindings(const Array& bindings) const override; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("lhs", &lhs); + visitor->Visit("rhs", &rhs); + } + + static constexpr const char* _type_key = "relax.dpl.OrRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(OrRewriterNode, ExprRewriterNode); +}; + +class OrRewriter : public ExprRewriter { + public: + OrRewriter(ExprRewriter lhs, ExprRewriter rhs); + + TVM_DEFINE_OBJECT_REF_METHODS(OrRewriter, ExprRewriter, OrRewriterNode); +}; + +class TupleRewriterNode : public ExprRewriterNode { + public: + Array patterns; + TypedPackedFunc(Expr, Map)> func; + Optional> additional_bindings; + Map new_subroutines; + + RewriteSpec RewriteBindings(const Array& bindings) const override; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("patterns", &patterns); + PackedFunc untyped_func = func; + visitor->Visit("func", &untyped_func); + } + + static constexpr const char* _type_key = 
"relax.dpl.TupleRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(TupleRewriterNode, ExprRewriterNode); + + private: + struct VarInfo { + Var var; + Expr expr; + Array>> matches; + std::unordered_set downstream_usage; + bool used = false; + }; + + Map GenerateVariableRewrites(const Array& bindings) const; + + std::optional> TryMatchByBindingIndex(const std::vector& info_vec, + const std::vector& indices) const; +}; + +class TupleRewriter : public ExprRewriter { + public: + TupleRewriter(Array patterns, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); + + TVM_DEFINE_OBJECT_REF_METHODS(TupleRewriter, ExprRewriter, TupleRewriterNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_IR_DATAFLOW_REWRITER_H_ diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py new file mode 100644 index 000000000000..1d917c59523b --- /dev/null +++ b/tests/python/relax/test_dataflow_rewriter.py @@ -0,0 +1,1388 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +import tvm.testing +from tvm.relax.dpl import ExprRewriter +from tvm.script import ir as I, relax as R, tir as T + +import pytest + + +def test_rewrite_defined_by_ir_module(): + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function + def before(x: R.Tensor([32], "float32")): + R.func_attr({"global_symbol": "main"}) + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = lhs + rhs + return out + + @R.function + def expected(x: R.Tensor([32], "float32")): + R.func_attr({"global_symbol": "main"}) + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = R.call_pure_packed( + "my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32") + ) + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_missing_pattern_raises_error(): + """The rewriter must define a pattern to be matched""" + + with pytest.raises(KeyError, match="pattern"): + + @R.rewriter + class Rewriter: + @R.function + def replacement(): + return R.tuple() + + +def test_incorrect_function_type_of_pattern_raises_error(): + """The rewriter's pattern must be a Relax function""" + + with pytest.raises(TypeError, match="pattern"): + + @R.rewriter + class Rewriter: + @T.prim_func + def pattern(): + pass + + @R.function + def replacement(): + return R.tuple() + + +def test_missing_replacement_raises_error(): + """The rewriter must define a replacement""" + + with pytest.raises(KeyError, match="replacement"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(): + return R.tuple() + + +def test_incorrect_function_type_of_replacement_raises_error(): + """The rewriter's replacement must be a Relax function""" + + 
with pytest.raises(TypeError, match="replacement"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(): + return R.tuple() + + @T.prim_func + def replacement(): + pass + + +def test_mismatch_of_static_shapes_raises_error(): + """The pattern and replacement must accept the same shapes""" + + with pytest.raises(ValueError, match="must have the same signature"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([32])): + return A + + @R.function + def replacement(A: R.Tensor([16])): + return A + + +def test_rewriter_may_be_applied_to_ir_module(): + """A rewriter may mutate an IRModule + + The `ExprRewriter.__call__` implementation may accept either a + single Relax function, or an entire IRModule. If it is passed an + IRModule, then all functions in the `IRModule` are updated. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @I.ir_module + class Before: + @R.function + def func_a(x: R.Tensor([32], "float32")): + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = lhs + rhs + return out + + @R.function + def func_b(x: R.Tensor([16], "float32")): + out = x + x + return out + + @I.ir_module + class Expected: + @R.function + def func_a(x: R.Tensor([32], "float32")): + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = R.call_pure_packed( + "my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32") + ) + return out + + @R.function + def func_b(x: R.Tensor([16], "float32")): + out = R.call_pure_packed( + "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + ) + return out + + After = Rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def 
test_rewriter_may_be_used_as_ir_transform(): + """A rewriter may be used as a tvm.ir.transform.Pass""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor([16], "float32")): + y = x + x + return y + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor([16], "float32")): + out = R.call_pure_packed( + "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + ) + return out + + After = tvm.ir.transform.Sequential([Rewriter])(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_same_pattern_applied_multiple_times(): + """The pattern-match may apply multiple times""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before(x: R.Tensor([16], "float32")): + y = x + x + z = y + y + return z + + @R.function(private=True) + def expected(x: R.Tensor([16], "float32")): + y = R.call_pure_packed( + "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + ) + z = R.call_pure_packed( + "my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32") + ) + return z + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_composition_of_rewrite_rules(): + """Rewrite rules may be composed together""" + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: 
R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = A + B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.rewriter + class RewriteMultiply: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = A * B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + D = A + B + E = C * D + return E + + @R.function(private=True) + def expected( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + D = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + E = R.call_pure_packed( + "my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32") + ) + return E + + rewriter = RewriteAdd | RewriteMultiply + + after = rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_recursive_rewrite_rules(): + """Rewrite rules are applied until convergence + + In this test, both the `RewriteAdd` and `RewriteMultiply` patterns + must be applied in order to produce the expected output. However, + the `RewriteMultiply` pattern relies on the expression produced by + the `RewriteAdd` pass. 
+ + """ + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMultiply: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")): + C = A * B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")): + C = R.call_pure_packed( + "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before(A: R.Tensor([16], "float32")): + B = A + A + return B + + @R.function(private=True) + def expected(A: R.Tensor([16], "float32")): + B = R.call_pure_packed( + "my_optimized_mul_impl", + A, + R.const(2.0, "float32"), + sinfo_args=R.Tensor([16], "float32"), + ) + return B + + rewriter = RewriteAdd | RewriteMultiply + + after = rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_may_introduce_private_relax_subroutines(): + """The replacement may contain subroutines""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return Rewriter.subroutine(A) + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B + B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine(A) + C = Expected.subroutine(B) + return C + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + After = Rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def 
test_rewrite_only_introduces_private_subroutines_when_required(): + """Only subroutines that are used will be added to the module + + Like `test_rewrite_may_introduce_private_relax_subroutines`, but + the rewritten function only requires some of the subroutines + provided by the rewriter. + + """ + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return RewriteAdd.subroutine_add(A) + + @R.function(private=True) + def subroutine_add(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMul: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A * A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return R.call_tir( + RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32") + ) + + @T.prim_func(private=True) + def subroutine_mul(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B + B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine_add(A) + C = Expected.subroutine_add(B) + return C + + @R.function(private=True) + def subroutine_add(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + rewriter = RewriteAdd | RewriteMul + + After = rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewriter_may_not_introduce_public_subroutines(): + """The rewriter may only introduce private functions""" + + with pytest.raises(ValueError, match="is publicly exposed"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + 
return Rewriter.subroutine(A) + + @R.function + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + +def test_rewrite_branches_may_reuse_subroutine_name(): + """Each rewriter is independent, and may reuse subroutine names""" + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return RewriteAdd.subroutine(A) + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMul: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A * A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return R.call_tir( + RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32") + ) + + @T.prim_func(private=True) + def subroutine(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B * B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine(A) + C = R.call_tir( + Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @T.prim_func(private=True) + def subroutine_1(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + rewriter = RewriteAdd | RewriteMul + + After = rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewrite_of_explicit_relax_tuple(): + """The rewriter function may return a tuple + + When it occurs explicitly within the Relax function, the tuple + pattern matches against 
the Relax tuple, and the Relax tuple is + replaced. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + proj_tuple = (proj_A, proj_B) + out = proj_tuple[0] + proj_tuple[1] + return out + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + out = proj_tuple[0] + proj_tuple[1] + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_output_relax_tuple(): + """The rewriter may update a tuple being returned + + Unlike most relax expressions, tuples may appear as nested + expressions. Pattern-matching should be aware of this option. + + Like `test_rewrite_of_explicit_relax_tuple`, but the tuple appears + as the return value in the function being modified. 
+ + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + return (proj_A, proj_B) + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple(): + """The rewriter function may return a tuple + + The tuple being replaced does not need to explicitly exist within + the updated Relax function. So long as each element of the tuple + pattern matches a Relax expression, the pattern match can apply. + + This rule ensures that pattern-matching is never broken when + `CanonicalizeBindings` is applied. + + This test is identical to `test_rewrite_of_explicit_relax_tuple`, + except that the function does not contain the round trip of + packing `proj_A` and `proj_B` into a tuple, then immediately + unpacking them from the tuple. 
+ + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + out = proj_A + proj_B + return out + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + out = proj_tuple[0] + proj_tuple[1] + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple_with_shared_wildcard(): + """Tuple elements may depend on the same input + + Here, both elements of the tuple depend on `y`. 
+ + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + lhs = x + y + rhs = y + z + return (lhs, rhs) + + @R.function + def replacement( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "optimized_impl", + x, + y, + z, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + lhs = A + B + rhs = B + C + out = R.multiply(lhs, rhs) + return out + + @R.function(private=True) + def expected( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + lhs_rhs = R.call_pure_packed( + "optimized_impl", + A, + B, + C, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + out = R.multiply(lhs_rhs[0], lhs_rhs[1]) + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_no_rewrite_of_implicit_tuple_when_shared_wildcard_is_mismatched(): + """Tuple elements must match simultaneously + + Each element of the tuple matches individually, but the two + elements both depend on `B`. Because the first tuple element + would require `y = B`, while the second tuple element would + require `y = C`, the match fails. 
+ + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + lhs = x + y + rhs = y + z + return (lhs, rhs) + + @R.function + def replacement( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "optimized_impl", + A, + B, + C, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + D: R.Tensor([16], "float32"), + ): + lhs = A + B + rhs = C + D + out = R.multiply(lhs, rhs) + return out + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_implicit_tuple_may_not_introduce_extra_compute(): + """Matching of implicit tuple may not cause extra compute + + Here, the `(proj_A, proj_B)` tuple could be an implcit tuple + match, but that would repeat the computation of `proj_A`. It + would be computed once on its own, to be used for `proj_A_on_B`, + and once for computing `(proj_A, proj_B)`. 
+ + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16, 16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16, 16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + # This function has no location at which a tuple + # `(proj_A,proj_B)` could be constructed, then unpacked. + + proj_A = R.matmul(A, state) + + # A tuple `(proj_A, proj_B)` could not be constructed at this + # location, because `proj_B` has not yet been computed. + + proj_A_on_B = R.matmul(proj_A, B) + proj_B = R.matmul(proj_A_on_B, state) + + # A tuple `(proj_A, proj_B)` could be constructed here, but a + # use-site of `proj_A` has already occurred. Implicit + # matching of a tuple is only allowed if it would replace + # every use-site of a variable. 
+ + out = proj_A + proj_B + return out + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple_with_three_elements(): + """Implicit tuples may contain three elements""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(qkv: R.Tensor([12288], "float32")): + qkv_tuple = R.split(qkv, 3, axis=0) + q = qkv_tuple[0] + k = qkv_tuple[1] + v = qkv_tuple[2] + q_embed = R.call_pure_packed( + "rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32") + ) + k_embed = R.call_pure_packed( + "rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32") + ) + + return (q_embed, k_embed, v) + + @R.function + def replacement(qkv: R.Tensor([12288], "float32")): + return R.call_pure_packed( + "split_rotary_embedding", + [qkv], + sinfo_args=[ + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + ], + ) + + @R.function(private=True) + def before( + state: R.Tensor([4096], "float32"), + proj_qkv: R.Tensor([12288, 4096], "float32"), + kv_cache: R.Object, + ): + qkv = R.matmul(proj_qkv, state) + qkv_tuple = R.split(qkv, 3, axis=0) + q = qkv_tuple[0] + k = qkv_tuple[1] + v = qkv_tuple[2] + q_embed = R.call_pure_packed( + "rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32") + ) + k_embed = R.call_pure_packed( + "rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32") + ) + + attention = R.call_pure_packed( + "compute_self_attention", + [q_embed, k_embed, v, kv_cache], + sinfo_args=R.Tensor([4096]), + ) + + return attention + + @R.function(private=True) + def expected( + state: R.Tensor([4096], "float32"), + proj_qkv: R.Tensor([12288, 4096], "float32"), + kv_cache: R.Object, + ): + qkv = R.matmul(proj_qkv, state) + embedded_qkv_tuple = R.call_pure_packed( + "split_rotary_embedding", + [qkv], + sinfo_args=[ + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + ], + ) + + v = 
embedded_qkv_tuple[2] + q_embed = embedded_qkv_tuple[0] + k_embed = embedded_qkv_tuple[1] + + attention = R.call_pure_packed( + "compute_self_attention", + [q_embed, k_embed, v, kv_cache], + sinfo_args=R.Tensor([4096]), + ) + + return attention + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_pattern_matching_may_not_reorder_across_impure_functions(): + """Matched pattern must be ordered with respect to impure functions + + To ensure that debug printouts, memory management, performance + measurements, etc are not impacted by a pattern match, a pattern + must be entirely before, or entirely after an impure function. A + pattern match in which some parts of the matched expression are + performed before an impure function, while others are performed + afterwards, is not allowed. + + In this test, the matmul and the add may not be fused, because the + impure print statement occurs between them. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + state = R.matmul(weights, state) + state = R.add(bias, state) + return state + + @R.function + def replacement( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + + @R.function(private=True, pure=False) + def before( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.matmul(weights, state) + R.print(format="After matmul, before add") + state = R.add(bias, state) + R.print(format="End of function") + return state + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def 
test_pattern_matching_may_occur_between_impure_functions(): + """Matched pattern may be adjacent to impure functions + + To ensure that debug printouts, memory management, performance + measurements, etc are not impacted by a pattern match, a pattern + must be entirely before, or entirely after an impure function. A + pattern match in which some parts of the matched expression are + performed before an impure function, while others are performed + afterwards, is not allowed. + + In this test, the matmul and the add may be fused, because the + pattern occurs without an impure print statement in-between. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + state = R.matmul(weights, state) + state = R.add(bias, state) + return state + + @R.function + def replacement( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + + @R.function(private=True, pure=False) + def before( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.matmul(weights, state) + state = R.add(bias, state) + R.print(format="End of function") + return state + + @R.function(private=True, pure=False) + def expected( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + R.print(format="End of function") + return state + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def 
test_rewrite_may_apply_within_conditional():
+    """Rewrites may apply to inner dataflow regions
+
+    While dataflow regions may not contain conditionals, they may
+    occur within the body of conditionals.
+
+    """
+
+    @R.rewriter
+    class Rewriter:
+        @R.function
+        def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
+            return A + B
+
+        @R.function
+        def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
+            return R.call_pure_packed(
+                "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
+            )
+
+    @R.function(private=True)
+    def before(
+        A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")
+    ):
+        if cond:
+            out = A + B
+        else:
+            C = A + B
+            out = C + B
+        return out
+
+    @R.function(private=True)
+    def expected(
+        A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")
+    ):
+        if cond:
+            out = R.call_pure_packed(
+                "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
+            )
+        else:
+            C = R.call_pure_packed(
+                "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
+            )
+            out = R.call_pure_packed(
+                "my_optimized_add_impl", C, B, sinfo_args=R.Tensor([16], "float32")
+            )
+        return out
+
+    after = Rewriter(before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_match_dynamic_shape():
+    """Pattern match/rewrites may be dynamic
+
+    The tuple being replaced does not need to explicitly exist within
+    the updated Relax function.  So long as each element of the tuple
+    pattern matches a Relax expression, the pattern match can apply.
+
+    This rule ensures that pattern-matching is never broken when
+    `CanonicalizeBindings` is applied.
+
+    This test is identical to `test_rewrite_of_explicit_relax_tuple`,
+    except that the function does not contain the round trip of
+    packing `proj_A` and `proj_B` into a tuple, then immediately
+    unpacking them from the tuple.
+
+    """
+
+    @R.rewriter
+    class Rewriter:
+        @R.function
+        def pattern(
+            lhs_A: R.Tensor(["N1", "M"], "float32"),
+            lhs_B: R.Tensor(["N2", "M"], "float32"),
+            rhs: R.Tensor(["M"], "float32"),
+        ):
+            proj_A = R.matmul(lhs_A, rhs)
+            proj_B = R.matmul(lhs_B, rhs)
+            return (proj_A, proj_B)
+
+        @R.function
+        def replacement(
+            lhs_A: R.Tensor(["N1", "M"], "float32"),
+            lhs_B: R.Tensor(["N2", "M"], "float32"),
+            rhs: R.Tensor(["M"], "float32"),
+        ):
+            N1 = T.int64()
+            N2 = T.int64()
+
+            lhs = R.concat([lhs_A, lhs_B])
+            proj_concat = R.matmul(lhs, rhs)
+            proj_A: R.Tensor([N1], "float32") = R.strided_slice(
+                proj_concat, axes=[0], begin=[0], end=[N1]
+            )
+            proj_B: R.Tensor([N2], "float32") = R.strided_slice(
+                proj_concat, axes=[0], begin=[N1], end=[N2 + N1]
+            )
+            return (proj_A, proj_B)
+
+    @R.function(private=True)
+    def before(
+        state: R.Tensor([16], "float32"),
+        A: R.Tensor([16, 16], "float32"),
+        B: R.Tensor([16, 16], "float32"),
+    ):
+        proj_A = R.matmul(A, state)
+        proj_B = R.matmul(B, state)
+        out = proj_A + proj_B
+        return out
+
+    @R.function(private=True)
+    def expected(
+        state: R.Tensor([16], "float32"),
+        A: R.Tensor([16, 16], "float32"),
+        B: R.Tensor([16, 16], "float32"),
+    ):
+        concat_AB = R.concat([A, B])
+        proj_concat = R.matmul(concat_AB, state)
+        proj_A = R.strided_slice(proj_concat, axes=[0], begin=[0], end=[16])
+        proj_B = R.strided_slice(proj_concat, axes=[0], begin=[16], end=[32])
+        out = proj_A + proj_B
+        return out
+
+    after = Rewriter(before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_match_dynamic_pattern_against_dynamic_shape():
+    """A dynamic pattern may match a dynamic shape"""
+
+    @R.rewriter
+    class Rewriter:
+        @R.function
+        def pattern(
+            A: R.Tensor(["M", "N"], "float32"),
+            B: R.Tensor(["N", "N"], "float32"),
+        ):
+            return R.matmul(A, B)
+
+        @R.function
+        def replacement(
+            A: R.Tensor(["M", "N"], "float32"),
+            B: R.Tensor(["N", "N"], "float32"),
+        ):
+            M = T.int64()
+            N = T.int64()
+            return R.call_pure_packed(
"my_optimized_square_matmul", + A, + B, + sinfo_args=R.Tensor([M, N], "float32"), + ) + + @R.function(private=True) + def before( + A: R.Tensor(["N", "N*2"], "float32"), + B: R.Tensor(["N*2", "N*2"], "float32"), + C: R.Tensor(["N", "N"], "float32"), + ): + N = T.int64() + D: R.Tensor([N, N * 2], "float32") = R.matmul(A, B) + E: R.Tensor([N * 2, N], "float32") = R.permute_dims(D) + F: R.Tensor([N * 2, N], "float32") = R.matmul(E, C) + return F + + @R.function(private=True) + def expected( + A: R.Tensor(["N", "N*2"], "float32"), + B: R.Tensor(["N*2", "N*2"], "float32"), + C: R.Tensor(["N", "N"], "float32"), + ): + N = T.int64() + + D: R.Tensor([N, N * 2], "float32") = R.call_pure_packed( + "my_optimized_square_matmul", + A, + B, + sinfo_args=R.Tensor([N, N * 2], "float32"), + ) + E: R.Tensor([N * 2, N], "float32") = R.permute_dims(D) + F: R.Tensor([N * 2, N], "float32") = R.call_pure_packed( + "my_optimized_square_matmul", + E, + C, + sinfo_args=R.Tensor([N * 2, N], "float32"), + ) + return F + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +if __name__ == "__main__": + tvm.testing.main() From 28a72e9580236a58451084edb9f4c847aa9c5bc5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 10 Jul 2024 15:05:38 -0500 Subject: [PATCH 05/17] lint fixes --- tests/python/relax/test_dataflow_rewriter.py | 36 +++++-------------- .../test_transform_canonicalize_bindings.py | 12 ++----- 2 files changed, 12 insertions(+), 36 deletions(-) diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py index 1d917c59523b..05bbe429bbcc 100644 --- a/tests/python/relax/test_dataflow_rewriter.py +++ b/tests/python/relax/test_dataflow_rewriter.py @@ -257,12 +257,8 @@ def before(x: R.Tensor([16], "float32")): @R.function(private=True) def expected(x: R.Tensor([16], "float32")): - y = R.call_pure_packed( - "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") - ) - z = R.call_pure_packed( - 
"my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32") - ) + y = R.call_pure_packed("my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32")) + z = R.call_pure_packed("my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32")) return z after = Rewriter(before) @@ -316,12 +312,8 @@ def expected( B: R.Tensor([16], "float32"), C: R.Tensor([16], "float32"), ): - D = R.call_pure_packed( - "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") - ) - E = R.call_pure_packed( - "my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32") - ) + D = R.call_pure_packed("my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")) + E = R.call_pure_packed("my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32")) return E rewriter = RewriteAdd | RewriteMultiply @@ -457,9 +449,7 @@ def pattern(A: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32")): - return R.call_tir( - RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32") - ) + return R.call_tir(RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32")) @T.prim_func(private=True) def subroutine_mul(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): @@ -537,9 +527,7 @@ def pattern(A: R.Tensor([16], "float32")): @R.function def replacement(A: R.Tensor([16], "float32")): - return R.call_tir( - RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32") - ) + return R.call_tir(RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32")) @T.prim_func(private=True) def subroutine(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): @@ -559,9 +547,7 @@ class Expected: @R.function def main(A: R.Tensor([16], "float32")): B = Expected.subroutine(A) - C = R.call_tir( - Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32") - ) + C = R.call_tir(Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32")) return C @R.function(private=True) @@ -1212,9 +1198,7 @@ def replacement(A: 
R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): ) @R.function(private=True) - def before( - A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool") - ): + def before(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")): if cond: out = A + B else: @@ -1223,9 +1207,7 @@ def before( return out @R.function(private=True) - def expected( - A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool") - ): + def expected(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")): if cond: out = R.call_pure_packed( "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py b/tests/python/relax/test_transform_canonicalize_bindings.py index 3255889960d4..ea3b1c249b8b 100644 --- a/tests/python/relax/test_transform_canonicalize_bindings.py +++ b/tests/python/relax/test_transform_canonicalize_bindings.py @@ -352,9 +352,7 @@ def main( # The symbolic shapes propagate downstream. lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0) - proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul( - lhs, rhs, out_dtype="void" - ) + proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void") proj_A = R.strided_slice( proj_concat, (R.prim_value(0),), @@ -384,9 +382,7 @@ def main( # statically-known shapes. 
lhs: R.Tensor([32, 16], dtype="float32") = R.concat((A, B), axis=0) - proj_concat: R.Tensor([32], dtype="float32") = R.matmul( - lhs, state, out_dtype="void" - ) + proj_concat: R.Tensor([32], dtype="float32") = R.matmul(lhs, state, out_dtype="void") proj_A: R.Tensor([16], dtype="float32") = R.strided_slice( proj_concat, [R.prim_value(0)], @@ -435,9 +431,7 @@ def main( rhs = R.match_cast(state, R.Tensor([M], dtype="float32")) lhs: R.Tensor([N1 + N2, M], "float32") = R.concat((lhs_A, lhs_B), axis=0) - proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul( - lhs, rhs, out_dtype="void" - ) + proj_concat: R.Tensor([N1 + N2], "float32") = R.matmul(lhs, rhs, out_dtype="void") proj_A = R.strided_slice( proj_concat, (R.prim_value(0),), From 07935d19976a37af17853fb8ed47dc1c0b615677 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 11 Jul 2024 07:33:14 -0500 Subject: [PATCH 06/17] Remove unnecessary change which broke a unit test --- python/tvm/script/parser/core/entry.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py index 3d35416d941a..e7a7f98b7651 100644 --- a/python/tvm/script/parser/core/entry.py +++ b/python/tvm/script/parser/core/entry.py @@ -83,8 +83,7 @@ def parse( The parsed TVMScript program. 
""" if extra_vars is None: - extra_vars = {} - extra_vars = {**extra_vars, **_default_globals()} + extra_vars = _default_globals() ann = {} if inspect.isfunction(program): From 0c038be6f703f203e704e98d30e3364d0ee66fe4 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 11 Jul 2024 11:54:42 -0500 Subject: [PATCH 07/17] lint fix for import order --- python/tvm/relax/dpl/rewrite.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index f124c11f7077..47d151e36a2e 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -17,12 +17,13 @@ """APIs for pattern-based rewriting.""" from typing import Dict, Callable, Union -from .pattern import DFPattern -from .context import PatternContext from tvm.ir import IRModule from tvm.runtime import Object from tvm._ffi import register_object + +from .pattern import DFPattern +from .context import PatternContext from ..expr import Expr, Function, Var from . import _ffi as ffi From c3144ca499f0ae995d14a3ec3d31bb2aec00b556 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 11 Jul 2024 12:54:10 -0500 Subject: [PATCH 08/17] Add docstrings --- python/tvm/relax/dpl/rewrite.py | 78 +++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index 47d151e36a2e..9cd14e2c4cd8 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -30,11 +30,38 @@ @register_object("relax.dpl.ExprRewriter") class ExprRewriter(Object): + """A pattern-matching rewriter for Relax""" + @staticmethod def from_pattern( pattern: DFPattern, func: Callable[[Expr, Dict[DFPattern, Expr]], Expr], ) -> "ExprRewriter": + """Construct from a pattern and rewriter-function + + The replacements performed by the rewriter will be equivalent + to using the `pattern` and `func` as arguments to + `rewrite_call`. 
+ + Parameters + ---------- + pattern: DFPattern + + The pattern to be matched against. + + func: Callable[[Expr, Dict[DFPattern, Expr]], Expr] + + A function that returns the rewritten expression. See + `rewrite_call` for details and examples. + + + Returns + ------- + rewriter_obj: ExprRewriter + + The rewriter object + + """ return ffi.ExprRewriterFromPattern( pattern, func, @@ -42,12 +69,63 @@ def from_pattern( @staticmethod def from_module(mod: IRModule) -> "ExprRewriter": + """Construct a rewriter from an IRModule + + Parameters + ---------- + mod: IRModule + + A module with `pattern` and `replacement` functions, + defining a rewrite rule. + + + Returns + ------- + rewriter_obj: ExprRewriter + + The rewriter object + + """ return ffi.ExprRewriterFromModule(mod) # type: ignore def __call__(self, obj: Union[Expr, IRModule]) -> Union[Expr, IRModule]: + """Apply the rewriter + + Parameters + ---------- + obj: Union[Expr, IRModule]) + + The object to be rewritten. May be applied to either a + relax expression, or an IRModule. + + Returns + ------- + updated: Union[Expr, IRModule] + + The rewritten object + + """ return ffi.ExprRewriterApply(self, obj) def __or__(self, other: "ExprRewriter") -> "ExprRewriter": + """Compose two rewriters + + Composing two rewrite rules together allows them to be applied + in a single Relax-level transformation. 
+ + Parameters + ---------- + other: ExprRewriter + + Another rewrite rule + + Returns + ------- + ExprRewriter + + A rewriter that will apply either rewrite pattern + + """ return OrRewriter(self, other) From 9e7bd0be4fb520e564860615593687f9d387431f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 11 Jul 2024 14:59:55 -0500 Subject: [PATCH 09/17] lint fix --- include/tvm/relax/block_builder.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index 17347fcfd84d..e70b14c6082a 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -163,7 +163,7 @@ class BlockBuilderNode : public Object { /*! * \brief Append a definition to the cuurrent scope. * - * \param Var A variable within the current scope. + * \param var A variable within the current scope. * * \note This function should be called when a new variable is * defined that may impact struct inference (e.g. MatchCast) From df86a8197925ca0e8e668615b2bbf7f64c391088 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 11 Jul 2024 20:16:10 -0500 Subject: [PATCH 10/17] Lint fix --- src/relax/ir/dataflow_matcher.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/relax/ir/dataflow_matcher.h b/src/relax/ir/dataflow_matcher.h index 93141af81c7c..c5d58db5b9d0 100644 --- a/src/relax/ir/dataflow_matcher.h +++ b/src/relax/ir/dataflow_matcher.h @@ -18,11 +18,11 @@ */ /*! - * \file src/tvm/relax/dataflow_matcher_impl.h + * \file src/tvm/relax/dataflow_matcher.h * \brief The auxiliary data structure for dataflow matcher. 
*/ -#ifndef TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ -#define TVM_RELAX_IR_DATAFLOW_MATCHER_IMPL_H_ +#ifndef TVM_RELAX_IR_DATAFLOW_MATCHER_H_ +#define TVM_RELAX_IR_DATAFLOW_MATCHER_H_ #include #include @@ -102,4 +102,4 @@ class DFPatternMatcher : public DFPatternFunctor Date: Fri, 12 Jul 2024 12:52:18 -0500 Subject: [PATCH 11/17] lint fixes --- src/relax/ir/dataflow_expr_rewriter.cc | 2 +- src/relax/ir/dataflow_rewriter.h | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 8acaec60c356..e0dca16b2e91 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -44,7 +44,7 @@ namespace relax { namespace { class GlobalVarReplacer : public ExprMutator { public: - GlobalVarReplacer(Map gvar_map) : gvar_map_(gvar_map) {} + explicit GlobalVarReplacer(Map gvar_map) : gvar_map_(gvar_map) {} using ExprMutator::VisitExpr_; Expr VisitExpr_(const GlobalVarNode* op) override { diff --git a/src/relax/ir/dataflow_rewriter.h b/src/relax/ir/dataflow_rewriter.h index d26695a7ce52..c9878343aac4 100644 --- a/src/relax/ir/dataflow_rewriter.h +++ b/src/relax/ir/dataflow_rewriter.h @@ -30,6 +30,9 @@ #include #include +#include +#include +#include #include "dataflow_matcher.h" From 46bbdb041827981fd30f548e01f3662e9051dbac Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 12 Jul 2024 13:56:07 -0500 Subject: [PATCH 12/17] lint fix --- src/relax/ir/dataflow_matcher.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index b6994f017466..417a78f0d04b 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -39,6 +39,7 @@ #include #include #include +#include #include #include #include From 826d270899bd80603be96854fda5dfd3e8bab48f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 15 Jul 2024 09:15:50 -0500 Subject: [PATCH 13/17] Update based on 
review comments --- include/tvm/relax/block_builder.h | 2 +- python/tvm/relax/dpl/rewrite.py | 42 +++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 1 deletion(-) diff --git a/include/tvm/relax/block_builder.h b/include/tvm/relax/block_builder.h index e70b14c6082a..ad2b9820707a 100644 --- a/include/tvm/relax/block_builder.h +++ b/include/tvm/relax/block_builder.h @@ -161,7 +161,7 @@ class BlockBuilderNode : public Object { virtual void BeginInnerScope() = 0; /*! - * \brief Append a definition to the cuurrent scope. + * \brief Append a definition to the current scope. * * \param var A variable within the current scope. * diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index 9cd14e2c4cd8..d059119ace00 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -71,6 +71,48 @@ def from_pattern( def from_module(mod: IRModule) -> "ExprRewriter": """Construct a rewriter from an IRModule + The IRModule must have two publicly-exposed functions, + `pattern` and `replacement`, where `pattern` and `replacement` + have the same function signature, as shown in the example + below. + + .. code-block:: python + + @I.ir_module + class RewriteAddIntoMultiply: + @R.function + def pattern(A: R.Tensor): + B = A + A + return B + + @R.function + def replacement(A: R.Tensor): + B = A * 2 + return B + + rewriter = ExprRewriter.from_module(RewriteAddIntoMultiply) + rewritten_ir_module = rewriter(ir_module) + + To support the common case of defining an IRModule with + TVMScript, then immediately turning it into a rewriter, the + `@R.rewriter` annotation can be used. + + .. 
code-block:: python + + @R.rewriter + class RewriteAddIntoMultiply: + @R.function + def pattern(A: R.Tensor): + B = A + A + return B + + @R.function + def replacement(A: R.Tensor): + B = A * 2 + return B + + rewritten_ir_module = RewriteAddIntoMultiply(ir_module) + Parameters ---------- mod: IRModule From 234ddde3adc9dfe1854f165c3e0150b064c78de7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 15 Jul 2024 10:35:33 -0500 Subject: [PATCH 14/17] Add test case for matching against arbitrary dtype --- src/relax/ir/block_builder.cc | 6 +- src/relax/ir/dataflow_expr_rewriter.cc | 3 + src/relax/ir/expr.cc | 15 ++ src/relax/ir/expr_functor.cc | 3 +- src/relax/utils.cc | 16 +-- tests/python/relax/test_dataflow_rewriter.py | 142 +++++++++++++++++++ 6 files changed, 171 insertions(+), 14 deletions(-) diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 95d16c1abadf..a123a3e45805 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -221,7 +221,8 @@ class BlockBuilderImpl : public BlockBuilderNode { analyzer_.MarkGlobalNonNegValue(shape_var); } else { const PrimExpr& old_shape_expr = (*it).second; - CHECK(analyzer_.CanProveEqual(old_shape_expr, shape_expr)) + CHECK(old_shape_expr.same_as(shape_expr) || + analyzer_.CanProveEqual(old_shape_expr, shape_expr)) << "Inconsistent shape var " << shape_var << " in scope: " << old_shape_expr << " vs " << shape_expr; } @@ -261,6 +262,8 @@ class BlockBuilderImpl : public BlockBuilderNode { cur_frame->bindings.push_back(match_cast); // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. + + AddDefinitionToScope(var); return var; } @@ -296,6 +299,7 @@ class BlockBuilderImpl : public BlockBuilderNode { // NOTE match shape do not follow simple binding rule // as a result should not appear in binding table. 
cur_frame->bindings.push_back(binding); + AddDefinitionToScope(match_cast->var); } else { LOG(FATAL) << "Unsupported binding type: " << binding->GetTypeKey(); } diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index e0dca16b2e91..81d26cb4730c 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -727,6 +727,9 @@ ExprRewriter ExprRewriter::FromModule(IRModule mod) { } else if (auto func = expr.as()) { return ExternFuncPattern(func->global_symbol); + } else if (auto prim = expr.as()) { + return StructInfoPattern(WildcardPattern(), PrimStructInfo(prim->value)); + } else { LOG(FATAL) << "TypeError: " << "Cannot convert Relax expression of type " << expr->GetTypeKey() diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index a14ba1d9aaa1..4850d52546b2 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -21,6 +21,8 @@ #include #include +#include + namespace tvm { namespace relax { @@ -589,6 +591,19 @@ Function::Function(Array params, Expr body, Optional ret_struct ret_struct_info = body_sinfo; } + auto f_shape_var_map = [&] { + auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); + std::unordered_set lookup(tir_vars.begin(), tir_vars.end()); + return [lookup = std::move(lookup)](const tir::Var& var) -> Optional { + if (lookup.count(var)) { + return var; + } else { + return NullOpt; + } + }; + }(); + ret_struct_info = EraseToWellDefined(ret_struct_info.value(), f_shape_var_map); + FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure); // set the fields diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc index c2320de62a75..3ee403a25cda 100644 --- a/src/relax/ir/expr_functor.cc +++ b/src/relax/ir/expr_functor.cc @@ -810,7 +810,8 @@ Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional> param builder_->BeginInnerScope(); // Inner scope also includes any TIR variables that are defined 
by // MatchCast nodes, and are internal to the scope. - Expr ret = ExprFunctor::VisitExpr(expr); + Expr ret = this->VisitExpr(expr); + builder_->EndScope(); // Normalization (and the resulting StructInfo inference) of the diff --git a/src/relax/utils.cc b/src/relax/utils.cc index f0239e424f30..77416dc92b1d 100644 --- a/src/relax/utils.cc +++ b/src/relax/utils.cc @@ -122,11 +122,7 @@ tvm::Map InferSymbolicVarMap( if (!var_sinfo) return; auto expr_sinfo = expr.as(); - CHECK(expr_sinfo) << "Cannot bind expression with struct type " << expr - << " to variable with struct type " << var; - CHECK_EQ(var_sinfo->dtype, expr_sinfo->dtype) - << "Cannot bind expression with struct type " << expr << " to variable with struct type " - << var << ", due to conflicting PrimExpr DataType"; + if (!expr_sinfo) return; if (!var_sinfo->value.defined() || !expr_sinfo->value.defined()) return; @@ -139,15 +135,12 @@ tvm::Map InferSymbolicVarMap( if (!var_shape->values.defined()) return; auto expr_shape = expr.as(); - CHECK(expr_shape) << "Cannot bind expression with struct type " << expr - << " to variable with struct type " << var; + if (!expr_shape) return; if (!expr_shape->values.defined()) return; auto var_shape_arr = var_shape->values.value(); auto expr_shape_arr = expr_shape->values.value(); - CHECK_EQ(var_shape_arr.size(), expr_shape_arr.size()) - << "Cannot bind shape " << expr_shape_arr << " of dimension " << expr_shape_arr.size() - << " to variable with shape " << var_shape_arr << " of dimension " << var_shape_arr.size(); + if (var_shape_arr.size() != expr_shape_arr.size()) return; for (size_t i = 0; i < var_shape_arr.size(); i++) { bind_from_prim_expr(var_shape_arr[i], expr_shape_arr[i]); } @@ -159,8 +152,7 @@ tvm::Map InferSymbolicVarMap( if (!var_tensor->shape.defined()) return; auto expr_tensor = expr.as(); - CHECK(expr_tensor) << "Cannot bind expression with struct type " << expr - << " to variable with struct type " << var; + if (!expr_tensor) return; if 
(!expr_tensor->shape.defined()) return; bind_from_shape(GetStructInfo(var_tensor->shape.value()), diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py index 05bbe429bbcc..1377bd0c1498 100644 --- a/tests/python/relax/test_dataflow_rewriter.py +++ b/tests/python/relax/test_dataflow_rewriter.py @@ -377,6 +377,148 @@ def expected(A: R.Tensor([16], "float32")): tvm.ir.assert_structural_equal(expected, after) +def test_rewrite_of_arbitrary_dtype(): + """A pattern-match may apply to a tensor with unknown dtype + + In this test case, a pattern identifies `R.strided_slice` usage + which returns the last slice of an array, and replaces it with a + view into the input array. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]): + M = T.int64() + N = T.int64() + last_slice_2d: R.Tensor([1, N]) = R.strided_slice(A, axes=[0], begin=[M - 1], end=[M]) + last_slice_1d: R.Tensor([N]) = R.squeeze(last_slice_2d, axis=0) + return last_slice_1d + + @R.function + def replacement(A: R.Tensor(["M", "N"])) -> R.Tensor(["N"]): + M = T.int64() + N = T.int64() + + # TODO(Lunderberg): Improve this syntax. A Relax + # PrimValue (e.g. `A.dtype.bits`) should be usable in any + # Relax context that accepts a `PrimExpr`. Currently, + # this requires `R.match_cast` to produce a TIR symbolic + # variable from the Relax PrimValue. 
+ bits_per_element = T.uint8() + _ = R.match_cast( + A.dtype.bits, + R.Prim(value=bits_per_element), + ) + lanes_per_element = T.uint16() + _ = R.match_cast( + A.dtype.lanes, + R.Prim(value=lanes_per_element), + ) + + last_slice = R.memory.view( + A, + [N], + relative_byte_offset=(M - 1) + * N + * T.ceildiv( + bits_per_element.astype("int64") * lanes_per_element.astype("int64"), 8 + ), + ) + return last_slice + + @I.ir_module + class Before: + @R.function + def main( + A: R.Tensor([32, 16], "float16"), + B: R.Tensor(["P", "Q"], "int4x8"), + C: R.Tensor([16, 32]), + ): + P = T.int64() + Q = T.int64() + + A_slice_2d = R.strided_slice(A, axes=[0], begin=[31], end=[32]) + A_slice_1d = R.squeeze(A_slice_2d, axis=0) + + B_slice_2d = R.strided_slice(B, axes=[0], begin=[P - 1], end=[P]) + B_slice_1d = R.squeeze(B_slice_2d, axis=0) + + C_slice_2d = R.strided_slice(C, axes=[0], begin=[15], end=[16]) + C_slice_1d = R.squeeze(C_slice_2d, axis=0) + + return (A_slice_1d, B_slice_1d, C_slice_1d) + + @I.ir_module + class Expected: + @R.function + def main( + A: R.Tensor([32, 16], "float16"), + B: R.Tensor(["P", "Q"], "int4x8"), + C: R.Tensor([16, 32]), + ): + P = T.int64() + Q = T.int64() + + # The pattern matches any 2-d tensor, with any data type. + # When the match's shape and dtype are both known, + # normalization and canonicalization produces a statically + # known value for `relative_byte_offset`. + # + # Relative offset is `(31 rows) * + # (16 elements/row) * + # (2 bytes/element)` + A_slice_1d = R.memory.view(A, shape=[16], relative_byte_offset=992) + + # The pattern can also match a 2-d tensor with dynamic + # shape. The `relative_byte_offset` uses the known + # datatype (4 bytes for each int4x8), but with dynamic + # shape variables substituted in where required. 
+ # + # Relative offset is `((P-1) rows) * + # (Q elements/row) * + # (4 bytes/element)` + B_slice_1d = R.memory.view(B, shape=[Q], relative_byte_offset=(P - 1) * Q * 4) + + # The pattern can also match a 2-d tensor with static + # shape, but unknown data type. The + # `relative_byte_offset` is determined based on the known + # number of elements, and the dynamic size of each + # element. + # + # Relative offset is `(15 rows) * + # (32 elements/row) * + # (ceildiv(bits*lanes,8) bytes/element)` + C_bits_per_element = T.uint8() + C_bits_prim_value = C.dtype.bits + _ = R.match_cast( + C_bits_prim_value, + R.Prim(value=C_bits_per_element), + ) + C_lanes_per_element = T.uint16() + C_lanes_prim_value = C.dtype.lanes + _ = R.match_cast( + C_lanes_prim_value, + R.Prim(value=C_lanes_per_element), + ) + + C_slice_1d = R.memory.view( + C, + shape=[32], + relative_byte_offset=( + (C_bits_per_element.astype("int64") * C_lanes_per_element.astype("int64") + 7) + // 8 + ) + * 480, + ) + + return (A_slice_1d, B_slice_1d, C_slice_1d) + + after = Rewriter(Before) + tvm.ir.assert_structural_equal(Expected, after) + + def test_rewrite_may_introduce_private_relax_subroutines(): """The replacement may contain subroutines""" From 7f62f70860623e0644d85413a9fd34c5551091b1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 16 Jul 2024 09:04:20 -0500 Subject: [PATCH 15/17] Fix breakage in unit tests One unit test that had been relying on invalid shape propagation. Another unit test that required constructed an ill-formed output to test against. 
--- src/relax/ir/expr.cc | 53 ++++++++++--------- .../test_transform_legalize_ops_manipulate.py | 2 +- 2 files changed, 30 insertions(+), 25 deletions(-) diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 4850d52546b2..6ace974985a5 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -578,32 +578,37 @@ Function::Function(Array params, Expr body, Optional ret_struct body_sinfo = GetStructInfo(body); } - if (ret_struct_info.defined()) { - // allow body to override ret if body is more fine-grained. - if (body_sinfo.defined()) { - if (IsBaseOf(ret_struct_info.value(), body_sinfo.value())) { - ret_struct_info = body_sinfo; - } - } - } else { - CHECK(body_sinfo.defined()) - << "Function do not have a return signature and body is not normalized"; - ret_struct_info = body_sinfo; + CHECK(body_sinfo.defined() || ret_struct_info.defined()) + << "Function must be constructed with either " + << "an explicit struct info for the return type, " + << "or a normalized body with struct info."; + + // Use the body's struct info if there is no explicit return type, + // or if the body may provide a more granular return type. + bool use_body_struct_info = + !ret_struct_info.defined() || + (body_sinfo && ret_struct_info && IsBaseOf(ret_struct_info.value(), body_sinfo.value())); + + if (use_body_struct_info) { + // MatchCast nodes within the body may introduce new symbolic + // variables. These are in-scope for the function body, but not + // for the function's return type. When hoisting the body's type + // to the function return type, symbolic variables may only be + // used if they were defined by the function's parameters. 
+ auto f_shape_var_map = [&] { + auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); + std::unordered_set lookup(tir_vars.begin(), tir_vars.end()); + return [lookup = std::move(lookup)](const tir::Var& var) -> Optional { + if (lookup.count(var)) { + return var; + } else { + return NullOpt; + } + }; + }(); + ret_struct_info = EraseToWellDefined(body_sinfo.value(), f_shape_var_map); } - auto f_shape_var_map = [&] { - auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); - std::unordered_set lookup(tir_vars.begin(), tir_vars.end()); - return [lookup = std::move(lookup)](const tir::Var& var) -> Optional { - if (lookup.count(var)) { - return var; - } else { - return NullOpt; - } - }; - }(); - ret_struct_info = EraseToWellDefined(ret_struct_info.value(), f_shape_var_map); - FuncStructInfo func_sinfo(param_sinfo, ret_struct_info.value(), is_pure); // set the fields diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py index dd0208f5db07..ba5d4d7d1219 100644 --- a/tests/python/relax/test_transform_legalize_ops_manipulate.py +++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py @@ -720,7 +720,7 @@ def reshape( T_reshape[v_ax0] = rxplaceholder[v_ax0 % T.int64(3)] @R.function - def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor((3,), dtype="int64"): + def main(x: R.Tensor((3,), dtype="int64")) -> R.Tensor(ndim=1, dtype="int64"): x_1 = T.int64() gv: R.Shape([3]) = R.call_pure_packed("vm.builtin.tensor_to_shape", x, sinfo_args=(R.Shape([3]),)) y: R.Shape([x_1]) = R.match_cast(gv, R.Shape([x_1])) From 68fbeed1233f0f7de2354c7e3609905263d2b544 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 16 Jul 2024 13:23:30 -0500 Subject: [PATCH 16/17] Updated base class name from ExprRewriter to PatternMatchingRewriter --- python/tvm/relax/dpl/__init__.py | 8 +- python/tvm/relax/dpl/rewrite.py | 34 ++++---- 
python/tvm/script/ir_builder/relax/ir.py | 17 ++-- src/relax/ir/dataflow_block_rewriter.cc | 9 +- src/relax/ir/dataflow_expr_rewriter.cc | 86 ++++++++++---------- src/relax/ir/dataflow_rewriter.h | 61 +++++++------- tests/python/relax/test_dataflow_rewriter.py | 8 +- 7 files changed, 118 insertions(+), 105 deletions(-) diff --git a/python/tvm/relax/dpl/__init__.py b/python/tvm/relax/dpl/__init__.py index cda84424e5ab..a4f3f4063e90 100644 --- a/python/tvm/relax/dpl/__init__.py +++ b/python/tvm/relax/dpl/__init__.py @@ -19,4 +19,10 @@ from .pattern import * from .context import * -from .rewrite import rewrite_call, rewrite_bindings, ExprRewriter, PatternRewriter, OrRewriter +from .rewrite import ( + rewrite_call, + rewrite_bindings, + PatternMatchingRewriter, + ExprPatternRewriter, + OrRewriter, +) diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py index d059119ace00..96c69e9266a2 100644 --- a/python/tvm/relax/dpl/rewrite.py +++ b/python/tvm/relax/dpl/rewrite.py @@ -28,15 +28,15 @@ from . 
import _ffi as ffi -@register_object("relax.dpl.ExprRewriter") -class ExprRewriter(Object): +@register_object("relax.dpl.PatternMatchingRewriter") +class PatternMatchingRewriter(Object): """A pattern-matching rewriter for Relax""" @staticmethod def from_pattern( pattern: DFPattern, func: Callable[[Expr, Dict[DFPattern, Expr]], Expr], - ) -> "ExprRewriter": + ) -> "PatternMatchingRewriter": """Construct from a pattern and rewriter-function The replacements performed by the rewriter will be equivalent @@ -57,18 +57,18 @@ def from_pattern( Returns ------- - rewriter_obj: ExprRewriter + rewriter_obj: PatternMatchingRewriter The rewriter object """ - return ffi.ExprRewriterFromPattern( + return ffi.PatternMatchingRewriterFromPattern( pattern, func, ) # type: ignore @staticmethod - def from_module(mod: IRModule) -> "ExprRewriter": + def from_module(mod: IRModule) -> "PatternMatchingRewriter": """Construct a rewriter from an IRModule The IRModule must have two publicly-exposed functions, @@ -90,7 +90,7 @@ def replacement(A: R.Tensor): B = A * 2 return B - rewriter = ExprRewriter.from_module(RewriteAddIntoMultiply) + rewriter = PatternMatchingRewriter.from_module(RewriteAddIntoMultiply) rewritten_ir_module = rewriter(ir_module) To support the common case of defining an IRModule with @@ -123,12 +123,12 @@ def replacement(A: R.Tensor): Returns ------- - rewriter_obj: ExprRewriter + rewriter_obj: PatternMatchingRewriter The rewriter object """ - return ffi.ExprRewriterFromModule(mod) # type: ignore + return ffi.PatternMatchingRewriterFromModule(mod) # type: ignore def __call__(self, obj: Union[Expr, IRModule]) -> Union[Expr, IRModule]: """Apply the rewriter @@ -147,9 +147,9 @@ def __call__(self, obj: Union[Expr, IRModule]) -> Union[Expr, IRModule]: The rewritten object """ - return ffi.ExprRewriterApply(self, obj) + return ffi.PatternMatchingRewriterApply(self, obj) - def __or__(self, other: "ExprRewriter") -> "ExprRewriter": + def __or__(self, other: 
"PatternMatchingRewriter") -> "PatternMatchingRewriter": """Compose two rewriters Composing two rewrite rules together allows them to be applied @@ -157,13 +157,13 @@ def __or__(self, other: "ExprRewriter") -> "ExprRewriter": Parameters ---------- - other: ExprRewriter + other: PatternMatchingRewriter Another rewrite rule Returns ------- - ExprRewriter + PatternMatchingRewriter A rewriter that will apply either rewrite pattern @@ -171,8 +171,8 @@ def __or__(self, other: "ExprRewriter") -> "ExprRewriter": return OrRewriter(self, other) -@register_object("relax.dpl.PatternRewriter") -class PatternRewriter(ExprRewriter): +@register_object("relax.dpl.ExprPatternRewriter") +class ExprPatternRewriter(PatternMatchingRewriter): def __init__(self, pattern, func): self.__init_handle_by_constructor__( ffi.PatternRewriter, @@ -182,7 +182,7 @@ def __init__(self, pattern, func): @register_object("relax.dpl.OrRewriter") -class OrRewriter(ExprRewriter): +class OrRewriter(PatternMatchingRewriter): def __init__(self, lhs, rhs): self.__init_handle_by_constructor__( ffi.OrRewriter, @@ -192,7 +192,7 @@ def __init__(self, lhs, rhs): @register_object("relax.dpl.TupleRewriter") -class TupleRewriter(ExprRewriter): +class TupleRewriter(PatternMatchingRewriter): def __init__(self, patterns, func): self.__init_handle_by_constructor__( ffi.TupleRewriter, diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index e0beaeb9aade..fc86be6ab881 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -35,7 +35,7 @@ VarBinding, const, ) -from tvm.relax.dpl import ExprRewriter +from tvm.relax.dpl import PatternMatchingRewriter ############################### Operators ############################### from tvm.relax.op import ( @@ -307,7 +307,7 @@ def func_ret_value(value: Expr) -> None: return _ffi_api.FuncRetValue(value) # type: ignore[attr-defined] # pylint: disable=no-member -def rewriter(rewriter_mod: 
Union[IRModule, Type]) -> ExprRewriter: +def rewriter(rewriter_mod: Union[IRModule, Type]) -> PatternMatchingRewriter: """Define a pattern-rewrite rule The IRModule must have two publicly-exposed functions, `pattern` @@ -337,7 +337,7 @@ def replacement(A: R.Tensor): Returns ------- - rewriter: ExprRewriter + rewriter: PatternMatchingRewriter A rewriter object, which can be applied either to a Relax function or to an entire IRModule. @@ -346,7 +346,7 @@ def replacement(A: R.Tensor): if not isinstance(rewriter_mod, IRModule): rewriter_mod = tvm.script.ir_module(rewriter_mod) - return ExprRewriter.from_module(rewriter_mod) + return PatternMatchingRewriter.from_module(rewriter_mod) ############################# BindingBlock ############################## @@ -438,7 +438,10 @@ def _convert_tensor_type(args): new_args = [_convert_tensor_type(x) for x in args] return type(args)(new_args) if isinstance(args, dict): - return {_convert_tensor_type(k): _convert_tensor_type(v) for k, v in args.items()} + return { + _convert_tensor_type(k): _convert_tensor_type(v) + for k, v in args.items() + } if inspect.isfunction(args): args = args() if isinstance(args, ObjectGeneric): @@ -506,7 +509,9 @@ def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Call: A newly created call that calls into a tir function. 
""" primfunc_name_hint = kwargs.pop("primfunc_name_hint", None) - tir_func, call_args, out_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs) + tir_func, call_args, out_sinfo, tir_vars = gen_call_tir_inputs( + func, *args, **kwargs + ) if not primfunc_name_hint: primfunc_name_hint = func.__name__ gvar = decl_function(primfunc_name_hint, tir_func) # type: ignore diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index d07fedd29715..fb08dfe96a17 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -367,7 +367,7 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") return MatchGraph(ctx, dfb); }); -class PatternContextRewriterNode : public ExprRewriterNode { +class PatternContextRewriterNode : public PatternMatchingRewriterNode { public: PatternContext pattern; TypedPackedFunc(Map, Map)> rewriter_func; @@ -381,7 +381,7 @@ class PatternContextRewriterNode : public ExprRewriterNode { } static constexpr const char* _type_key = "relax.dpl.PatternContextRewriter"; - TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextRewriterNode, ExprRewriterNode); + TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextRewriterNode, PatternMatchingRewriterNode); private: Optional> MatchBindings(const Array& bindings) const { @@ -401,13 +401,14 @@ class PatternContextRewriterNode : public ExprRewriterNode { } }; -class PatternContextRewriter : public ExprRewriter { +class PatternContextRewriter : public PatternMatchingRewriter { public: PatternContextRewriter( PatternContext pattern, TypedPackedFunc(Map, Map)> rewriter_func); - TVM_DEFINE_OBJECT_REF_METHODS(PatternContextRewriter, ExprRewriter, PatternContextRewriterNode); + TVM_DEFINE_OBJECT_REF_METHODS(PatternContextRewriter, PatternMatchingRewriter, + PatternContextRewriterNode); }; RewriteSpec PatternContextRewriterNode::RewriteBindings(const Array& bindings) const { diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc 
index 81d26cb4730c..514116c5cadf 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -191,20 +191,20 @@ void RewriteSpec::Append(RewriteSpec other) { } } -TVM_REGISTER_NODE_TYPE(ExprRewriterNode); +TVM_REGISTER_NODE_TYPE(PatternMatchingRewriterNode); -TVM_REGISTER_GLOBAL("relax.dpl.ExprRewriterFromPattern") +TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromPattern") .set_body_typed([](DFPattern pattern, TypedPackedFunc(Expr, Map)> func) { - return ExprRewriter::FromPattern(pattern, func); + return PatternMatchingRewriter::FromPattern(pattern, func); }); -TVM_REGISTER_GLOBAL("relax.dpl.ExprRewriterFromModule").set_body_typed([](IRModule mod) { - return ExprRewriter::FromModule(mod); +TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterFromModule").set_body_typed([](IRModule mod) { + return PatternMatchingRewriter::FromModule(mod); }); -TVM_REGISTER_GLOBAL("relax.dpl.ExprRewriterApply") - .set_body_typed([](ExprRewriter rewriter, +TVM_REGISTER_GLOBAL("relax.dpl.PatternMatchingRewriterApply") + .set_body_typed([](PatternMatchingRewriter rewriter, Variant obj) -> Variant { if (auto expr = obj.as()) { return rewriter(expr.value()); @@ -215,9 +215,9 @@ TVM_REGISTER_GLOBAL("relax.dpl.ExprRewriterApply") } }); -TVM_REGISTER_NODE_TYPE(PatternRewriterNode); +TVM_REGISTER_NODE_TYPE(ExprPatternRewriterNode); -RewriteSpec PatternRewriterNode::RewriteBindings(const Array& bindings) const { +RewriteSpec ExprPatternRewriterNode::RewriteBindings(const Array& bindings) const { Map variable_rewrites; Map binding_lookup; for (const auto& binding : bindings) { @@ -235,8 +235,8 @@ RewriteSpec PatternRewriterNode::RewriteBindings(const Array& bindings) } } -Optional PatternRewriterNode::RewriteExpr(const Expr& expr, - const Map& bindings) const { +Optional ExprPatternRewriterNode::RewriteExpr(const Expr& expr, + const Map& bindings) const { if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings)) { auto matches = 
opt_matches.value(); if (additional_bindings) { @@ -262,14 +262,13 @@ Optional PatternRewriterNode::RewriteExpr(const Expr& expr, TVM_REGISTER_GLOBAL("relax.dpl.PatternRewriter") .set_body_typed([](DFPattern pattern, TypedPackedFunc(Expr, Map)> func) { - return PatternRewriter(pattern, func); + return ExprPatternRewriter(pattern, func); }); -PatternRewriter::PatternRewriter(DFPattern pattern, - TypedPackedFunc(Expr, Map)> func, - Optional> additional_bindings, - Map new_subroutines) { - auto node = make_object(); +ExprPatternRewriter::ExprPatternRewriter( + DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, Map new_subroutines) { + auto node = make_object(); node->pattern = std::move(pattern); node->func = std::move(func); node->additional_bindings = std::move(additional_bindings); @@ -309,11 +308,12 @@ RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) cons return lhs_match; } -TVM_REGISTER_GLOBAL("relax.dpl.OrRewriter").set_body_typed([](ExprRewriter lhs, ExprRewriter rhs) { - return OrRewriter(lhs, rhs); -}); +TVM_REGISTER_GLOBAL("relax.dpl.OrRewriter") + .set_body_typed([](PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { + return OrRewriter(lhs, rhs); + }); -OrRewriter::OrRewriter(ExprRewriter lhs, ExprRewriter rhs) { +OrRewriter::OrRewriter(PatternMatchingRewriter lhs, PatternMatchingRewriter rhs) { auto node = make_object(); node->lhs = std::move(lhs); node->rhs = std::move(rhs); @@ -621,30 +621,30 @@ TupleRewriter::TupleRewriter(Array patterns, data_ = std::move(node); } -ExprRewriter ExprRewriter::FromPattern( +PatternMatchingRewriter PatternMatchingRewriter::FromPattern( DFPattern pattern, TypedPackedFunc(Expr, Map)> func, Optional> additional_bindings, Map new_subroutines) { if (auto or_pattern = pattern.as()) { auto new_additional_bindings = additional_bindings.value_or({}); new_additional_bindings.push_back(pattern); - return OrRewriter( - ExprRewriter::FromPattern(or_pattern->left, func, 
new_additional_bindings, new_subroutines), - ExprRewriter::FromPattern(or_pattern->right, func, new_additional_bindings, - new_subroutines)); + return OrRewriter(PatternMatchingRewriter::FromPattern( + or_pattern->left, func, new_additional_bindings, new_subroutines), + PatternMatchingRewriter::FromPattern( + or_pattern->right, func, new_additional_bindings, new_subroutines)); } else if (auto tuple_pattern = pattern.as()) { auto new_additional_bindings = additional_bindings.value_or({}); new_additional_bindings.push_back(pattern); // If the Tuple appears as a Relax binding, apply it first. As a // fallback, also check for implicit tuples. return OrRewriter( - PatternRewriter(pattern, func, additional_bindings, new_subroutines), + ExprPatternRewriter(pattern, func, additional_bindings, new_subroutines), TupleRewriter(tuple_pattern->fields, func, new_additional_bindings, new_subroutines)); } else { - return PatternRewriter(pattern, func, additional_bindings, new_subroutines); + return ExprPatternRewriter(pattern, func, additional_bindings, new_subroutines); } } -ExprRewriter ExprRewriter::FromModule(IRModule mod) { +PatternMatchingRewriter PatternMatchingRewriter::FromModule(IRModule mod) { Function func_pattern = [&]() { CHECK(mod->ContainGlobalVar("pattern")) << "KeyError: " @@ -780,7 +780,7 @@ ExprRewriter ExprRewriter::FromModule(IRModule mod) { return SeqExpr(new_blocks, func_replacement->body->body); }; - return ExprRewriter::FromPattern(top_pattern, rewriter_func, NullOpt, new_subroutines); + return PatternMatchingRewriter::FromPattern(top_pattern, rewriter_func, NullOpt, new_subroutines); } Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, @@ -807,11 +807,11 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); * \brief Apply pattern matching to each expression, replacing * matches with the output of a user-provided rewriter function. 
*/ -class ExprPatternRewriter : public ExprMutator { +class PatternMatchingMutator : public ExprMutator { public: using ExprMutator::VisitExpr_; - ExprPatternRewriter(const ExprRewriterNode* rewriter) : rewriter_(rewriter) {} + PatternMatchingMutator(const PatternMatchingRewriterNode* rewriter) : rewriter_(rewriter) {} Map GetNewSubroutines() const { return new_subroutines_; } @@ -1018,18 +1018,18 @@ class ExprPatternRewriter : public ExprMutator { } private: - const ExprRewriterNode* rewriter_; + const PatternMatchingRewriterNode* rewriter_; Map new_subroutines_; }; -Expr ExprRewriter::operator()(Expr expr) { - ExprPatternRewriter mutator(get()); +Expr PatternMatchingRewriter::operator()(Expr expr) { + PatternMatchingMutator mutator(get()); auto new_expr = mutator(expr); auto new_subroutines = mutator.GetNewSubroutines(); CHECK_EQ(new_subroutines.size(), 0) - << "If ExprRewriter provides subroutines, " + << "If PatternMatchingRewriter provides subroutines, " << "then it must be applied to an entire IRModule. 
" - << "However, ExprRewriter produced subroutines " << [&]() -> Array { + << "However, PatternMatchingRewriter produced subroutines " << [&]() -> Array { std::vector vec; for (const auto& [gvar, func] : new_subroutines) { vec.push_back(gvar); @@ -1042,9 +1042,9 @@ Expr ExprRewriter::operator()(Expr expr) { return new_expr; } -IRModule ExprRewriterNode::operator()(IRModule mod, - const tvm::transform::PassContext& pass_ctx) const { - ExprPatternRewriter mutator(this); +IRModule PatternMatchingRewriterNode::operator()( + IRModule mod, const tvm::transform::PassContext& pass_ctx) const { + PatternMatchingMutator mutator(this); IRModule updates; for (const auto& [gvar, base_func] : mod->functions) { @@ -1064,13 +1064,13 @@ IRModule ExprRewriterNode::operator()(IRModule mod, return mod; } -tvm::transform::PassInfo ExprRewriterNode::Info() const { - return tvm::transform::PassInfo(0, "ExprRewriter", {}, false); +tvm::transform::PassInfo PatternMatchingRewriterNode::Info() const { + return tvm::transform::PassInfo(0, "PatternMatchingRewriter", {}, false); } Function RewriteCall(const DFPattern& pat, TypedPackedFunc)> rewriter, Function func) { - return Downcast(ExprRewriter::FromPattern(pat, rewriter)(func)); + return Downcast(PatternMatchingRewriter::FromPattern(pat, rewriter)(func)); } TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); diff --git a/src/relax/ir/dataflow_rewriter.h b/src/relax/ir/dataflow_rewriter.h index c9878343aac4..53f934982c59 100644 --- a/src/relax/ir/dataflow_rewriter.h +++ b/src/relax/ir/dataflow_rewriter.h @@ -48,7 +48,7 @@ struct RewriteSpec { void Append(RewriteSpec other); }; -class ExprRewriterNode : public tvm::transform::PassNode { +class PatternMatchingRewriterNode : public tvm::transform::PassNode { public: virtual RewriteSpec RewriteBindings(const Array& bindings) const { return RewriteSpec(); @@ -59,26 +59,26 @@ class ExprRewriterNode : public tvm::transform::PassNode { IRModule operator()(IRModule mod, const 
tvm::transform::PassContext& pass_ctx) const override; tvm::transform::PassInfo Info() const override; - static constexpr const char* _type_key = "relax.dpl.ExprRewriter"; - TVM_DECLARE_BASE_OBJECT_INFO(ExprRewriterNode, PassNode); + static constexpr const char* _type_key = "relax.dpl.PatternMatchingRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(PatternMatchingRewriterNode, PassNode); }; -class ExprRewriter : public tvm::transform::Pass { +class PatternMatchingRewriter : public tvm::transform::Pass { public: - static ExprRewriter FromPattern(DFPattern pattern, - TypedPackedFunc(Expr, Map)> func, - Optional> additional_bindings = NullOpt, - Map new_subroutines = {}); + static PatternMatchingRewriter FromPattern( + DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); - static ExprRewriter FromModule(IRModule mod); + static PatternMatchingRewriter FromModule(IRModule mod); Expr operator()(Expr expr); using Pass::operator(); - TVM_DEFINE_OBJECT_REF_METHODS(ExprRewriter, Pass, ExprRewriterNode); + TVM_DEFINE_OBJECT_REF_METHODS(PatternMatchingRewriter, Pass, PatternMatchingRewriterNode); }; -class PatternRewriterNode : public ExprRewriterNode { +class ExprPatternRewriterNode : public PatternMatchingRewriterNode { public: DFPattern pattern; TypedPackedFunc(Expr, Map)> func; @@ -95,24 +95,25 @@ class PatternRewriterNode : public ExprRewriterNode { visitor->Visit("func", &untyped_func); } - static constexpr const char* _type_key = "relax.dpl.PatternRewriter"; - TVM_DECLARE_BASE_OBJECT_INFO(PatternRewriterNode, ExprRewriterNode); + static constexpr const char* _type_key = "relax.dpl.ExprPatternRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(ExprPatternRewriterNode, PatternMatchingRewriterNode); }; -class PatternRewriter : public ExprRewriter { +class ExprPatternRewriter : public PatternMatchingRewriter { public: - PatternRewriter(DFPattern pattern, - TypedPackedFunc(Expr, Map)> func, - Optional> additional_bindings 
= NullOpt, - Map new_subroutines = {}); + ExprPatternRewriter(DFPattern pattern, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); - TVM_DEFINE_OBJECT_REF_METHODS(PatternRewriter, ExprRewriter, PatternRewriterNode); + TVM_DEFINE_OBJECT_REF_METHODS(ExprPatternRewriter, PatternMatchingRewriter, + ExprPatternRewriterNode); }; -class OrRewriterNode : public ExprRewriterNode { +class OrRewriterNode : public PatternMatchingRewriterNode { public: - ExprRewriter lhs; - ExprRewriter rhs; + PatternMatchingRewriter lhs; + PatternMatchingRewriter rhs; RewriteSpec RewriteBindings(const Array& bindings) const override; @@ -122,17 +123,17 @@ class OrRewriterNode : public ExprRewriterNode { } static constexpr const char* _type_key = "relax.dpl.OrRewriter"; - TVM_DECLARE_BASE_OBJECT_INFO(OrRewriterNode, ExprRewriterNode); + TVM_DECLARE_BASE_OBJECT_INFO(OrRewriterNode, PatternMatchingRewriterNode); }; -class OrRewriter : public ExprRewriter { +class OrRewriter : public PatternMatchingRewriter { public: - OrRewriter(ExprRewriter lhs, ExprRewriter rhs); + OrRewriter(PatternMatchingRewriter lhs, PatternMatchingRewriter rhs); - TVM_DEFINE_OBJECT_REF_METHODS(OrRewriter, ExprRewriter, OrRewriterNode); + TVM_DEFINE_OBJECT_REF_METHODS(OrRewriter, PatternMatchingRewriter, OrRewriterNode); }; -class TupleRewriterNode : public ExprRewriterNode { +class TupleRewriterNode : public PatternMatchingRewriterNode { public: Array patterns; TypedPackedFunc(Expr, Map)> func; @@ -148,7 +149,7 @@ class TupleRewriterNode : public ExprRewriterNode { } static constexpr const char* _type_key = "relax.dpl.TupleRewriter"; - TVM_DECLARE_BASE_OBJECT_INFO(TupleRewriterNode, ExprRewriterNode); + TVM_DECLARE_BASE_OBJECT_INFO(TupleRewriterNode, PatternMatchingRewriterNode); private: struct VarInfo { @@ -165,14 +166,14 @@ class TupleRewriterNode : public ExprRewriterNode { const std::vector& indices) const; }; -class TupleRewriter : public ExprRewriter { +class 
TupleRewriter : public PatternMatchingRewriter { public: TupleRewriter(Array patterns, TypedPackedFunc(Expr, Map)> func, Optional> additional_bindings = NullOpt, Map new_subroutines = {}); - TVM_DEFINE_OBJECT_REF_METHODS(TupleRewriter, ExprRewriter, TupleRewriterNode); + TVM_DEFINE_OBJECT_REF_METHODS(TupleRewriter, PatternMatchingRewriter, TupleRewriterNode); }; } // namespace relax diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py index 1377bd0c1498..828aa92bda28 100644 --- a/tests/python/relax/test_dataflow_rewriter.py +++ b/tests/python/relax/test_dataflow_rewriter.py @@ -17,7 +17,6 @@ import tvm.testing -from tvm.relax.dpl import ExprRewriter from tvm.script import ir as I, relax as R, tir as T import pytest @@ -137,9 +136,10 @@ def replacement(A: R.Tensor([16])): def test_rewriter_may_be_applied_to_ir_module(): """A rewriter may mutate an IRModule - The `ExprRewriter.__call__` implementation may accept either a - single Relax function, or an entire IRModule. If it is passed an - IRModule, then all functions in the `IRModule` are updated. + The `PatternMatchingRewriter.__call__` implementation may accept + either a single Relax function, or an entire IRModule. If it is + passed an IRModule, then all functions in the `IRModule` are + updated. 
""" From a7bf78f7708433ebeebfe524dd54976cf0a32ba1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 16 Jul 2024 13:59:15 -0500 Subject: [PATCH 17/17] lint fix --- python/tvm/script/ir_builder/relax/ir.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index fc86be6ab881..c4be8afac4d2 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -438,10 +438,7 @@ def _convert_tensor_type(args): new_args = [_convert_tensor_type(x) for x in args] return type(args)(new_args) if isinstance(args, dict): - return { - _convert_tensor_type(k): _convert_tensor_type(v) - for k, v in args.items() - } + return {_convert_tensor_type(k): _convert_tensor_type(v) for k, v in args.items()} if inspect.isfunction(args): args = args() if isinstance(args, ObjectGeneric): @@ -509,9 +506,7 @@ def emit_te(func: Callable, *args: Any, **kwargs: Any) -> Call: A newly created call that calls into a tir function. """ primfunc_name_hint = kwargs.pop("primfunc_name_hint", None) - tir_func, call_args, out_sinfo, tir_vars = gen_call_tir_inputs( - func, *args, **kwargs - ) + tir_func, call_args, out_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs) if not primfunc_name_hint: primfunc_name_hint = func.__name__ gvar = decl_function(primfunc_name_hint, tir_func) # type: ignore