diff --git a/python/tvm/relax/dpl/__init__.py b/python/tvm/relax/dpl/__init__.py
index 6451238428c2..cda84424e5ab 100644
--- a/python/tvm/relax/dpl/__init__.py
+++ b/python/tvm/relax/dpl/__init__.py
@@ -19,4 +19,4 @@
 
 from .pattern import *
 from .context import *
-from .rewrite import rewrite_call, rewrite_bindings
+from .rewrite import rewrite_call, rewrite_bindings, ExprRewriter, PatternRewriter, OrRewriter
diff --git a/python/tvm/relax/dpl/rewrite.py b/python/tvm/relax/dpl/rewrite.py
index 291061090fc2..f124c11f7077 100644
--- a/python/tvm/relax/dpl/rewrite.py
+++ b/python/tvm/relax/dpl/rewrite.py
@@ -15,16 +15,75 @@
 # specific language governing permissions and limitations
 # under the License.
 """APIs for pattern-based rewriting."""
-from typing import Dict, Callable
+
+from typing import Dict, Callable, Union
 from .pattern import DFPattern
 from .context import PatternContext
 
+from tvm.ir import IRModule
+from tvm.runtime import Object
+from tvm._ffi import register_object
 from ..expr import Expr, Function, Var
 from . import _ffi as ffi
 
 
+@register_object("relax.dpl.ExprRewriter")
+class ExprRewriter(Object):
+    @staticmethod
+    def from_pattern(
+        pattern: DFPattern,
+        func: Callable[[Expr, Dict[DFPattern, Expr]], Expr],
+    ) -> "ExprRewriter":
+        return ffi.ExprRewriterFromPattern(
+            pattern,
+            func,
+        )  # type: ignore
+
+    @staticmethod
+    def from_module(mod: IRModule) -> "ExprRewriter":
+        return ffi.ExprRewriterFromModule(mod)  # type: ignore
+
+    def __call__(self, obj: Union[Expr, IRModule]) -> Union[Expr, IRModule]:
+        return ffi.ExprRewriterApply(self, obj)
+
+    def __or__(self, other: "ExprRewriter") -> "ExprRewriter":
+        return OrRewriter(self, other)
+
+
+@register_object("relax.dpl.PatternRewriter")
+class PatternRewriter(ExprRewriter):
+    def __init__(self, pattern, func):
+        self.__init_handle_by_constructor__(
+            ffi.PatternRewriter,
+            pattern,
+            func,
+        )  # type: ignore
+
+
+@register_object("relax.dpl.OrRewriter")
+class OrRewriter(ExprRewriter):
+    def __init__(self, lhs, rhs):
+        self.__init_handle_by_constructor__(
+            ffi.OrRewriter,
+            lhs,
+            rhs,
+        )  # type: ignore
+
+
+@register_object("relax.dpl.TupleRewriter")
+class TupleRewriter(ExprRewriter):
+    def __init__(self, patterns, func):
+        self.__init_handle_by_constructor__(
+            ffi.TupleRewriter,
+            patterns,
+            func,
+        )  # type: ignore
+
+
 def rewrite_call(
-    pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
+    pattern: DFPattern,
+    rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr],
+    func: Function,
 ) -> Function:
     """
     Rewrite a function with the given pattern and the rewriter function.
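For orientation, the snippet below is a minimal usage sketch of the Python API added above: `ExprRewriter.from_pattern` builds a rewriter from a `DFPattern` and a callback, rewriters compose with `|` (the left-hand rule is tried first), and the result may be applied to a Relax function or to an entire IRModule. The pattern helpers `wildcard` and `is_op` are pre-existing `tvm.relax.dpl` utilities; the module `Mod`, the callback names, and the example rewrite itself are illustrative only and are not part of this patch.

.. code-block:: python

    import tvm
    from tvm import relax
    from tvm.relax.dpl import ExprRewriter, is_op, wildcard
    from tvm.script import ir as I, relax as R

    # Pattern: any call to relax.add, with both arguments captured.
    lhs, rhs = wildcard(), wildcard()
    add_pat = is_op("relax.add")(lhs, rhs)

    def rewrite_self_add(orig, matches):
        # Replace `x + x` with `x * 2`.  Returning the original
        # expression unchanged declines the rewrite for this match.
        if matches[lhs].same_as(matches[rhs]):
            return relax.op.multiply(matches[lhs], relax.const(2.0, "float32"))
        return orig

    @I.ir_module
    class Mod:  # hypothetical input module, for illustration only
        @R.function
        def main(x: R.Tensor([16], "float32")):
            y = R.add(x, x)
            return y

    double_to_multiply = ExprRewriter.from_pattern(add_pat, rewrite_self_add)

    # Rewriters compose with `|`; the combined rewriter can be applied
    # either to a relax.Function or to a whole IRModule.
    mul_pat = is_op("relax.multiply")(wildcard(), wildcard())
    keep_multiply = ExprRewriter.from_pattern(mul_pat, lambda orig, matches: orig)
    rewritten_mod = (double_to_multiply | keep_multiply)(Mod)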
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index ef9ae775450b..e0beaeb9aade 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -20,11 +20,11 @@
 import builtins
 import functools
 import inspect
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type
 
 import tvm
 from tvm import DataType, relax
-from tvm.ir import PrimExpr, VDevice
+from tvm.ir import PrimExpr, VDevice, IRModule
 from tvm.relax import (
     Call,
     Expr,
@@ -35,6 +35,7 @@
     VarBinding,
     const,
 )
+from tvm.relax.dpl import ExprRewriter
 
 ############################### Operators ###############################
 from tvm.relax.op import (
@@ -306,6 +307,48 @@ def func_ret_value(value: Expr) -> None:
     return _ffi_api.FuncRetValue(value)  # type: ignore[attr-defined]  # pylint: disable=no-member
 
 
+def rewriter(rewriter_mod: Union[IRModule, Type]) -> ExprRewriter:
+    """Define a pattern-rewrite rule
+
+    The IRModule must have two publicly-exposed functions, `pattern`
+    and `replacement`, where `pattern` and `replacement` have the same
+    function signature.
+
+    .. code-block:: python
+
+        @R.rewriter
+        class RewriteAddIntoMultiply:
+            @R.function
+            def pattern(A: R.Tensor):
+                B = A + A
+                return B
+
+            @R.function
+            def replacement(A: R.Tensor):
+                B = A * 2
+                return B
+
+    Parameters
+    ----------
+    rewriter_mod: Union[IRModule, Type]
+
+        Either an IRModule that defines a rewrite pattern, or a
+        TVMScript class that can be parsed into an IRModule.
+
+    Returns
+    -------
+    rewriter: ExprRewriter
+
+        A rewriter object, which can be applied either to a Relax
+        function or to an entire IRModule.
+
+    """
+    if not isinstance(rewriter_mod, IRModule):
+        rewriter_mod = tvm.script.ir_module(rewriter_mod)
+
+    return ExprRewriter.from_module(rewriter_mod)
+
+
 ############################# BindingBlock ##############################
 
 
@@ -765,6 +808,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr:
     "dequantize",
     "repeat",
     "reshape",
+    "rewriter",
     "tensor_to_shape",
     "shape_to_tensor",
     "rocm",
diff --git a/python/tvm/script/parser/core/entry.py b/python/tvm/script/parser/core/entry.py
index e7a7f98b7651..3d35416d941a 100644
--- a/python/tvm/script/parser/core/entry.py
+++ b/python/tvm/script/parser/core/entry.py
@@ -83,7 +83,8 @@ def parse(
         The parsed TVMScript program.
     """
     if extra_vars is None:
-        extra_vars = _default_globals()
+        extra_vars = {}
+    extra_vars = {**extra_vars, **_default_globals()}
 
     ann = {}
     if inspect.isfunction(program):
diff --git a/python/tvm/script/parser/core/utils.py b/python/tvm/script/parser/core/utils.py
index 3edae3f25a33..8ad64f5dbc68 100644
--- a/python/tvm/script/parser/core/utils.py
+++ b/python/tvm/script/parser/core/utils.py
@@ -100,19 +100,29 @@ def is_defined_in_class(frames: List[FrameType], obj: Any) -> bool:
     res : bool
         The result if the object is defined in a class scope.
     """
+
+    def _is_tvmscript_class_annotator(line: str) -> bool:
+        """Checks if the line contains a TVMScript annotator for a class
+
+        These match either `@I.ir_module` or `@R.rewriter`, or their
+        imported names `@ir_module` or `@rewriter`.
+ """ + + return line.startswith("@") and ("ir_module" in line or "rewriter" in line) + if len(frames) > 2: frame_info = frames[2] code_context = frame_info.code_context if code_context is None: return False line = code_context[0].strip() - if line.startswith("@") and "ir_module" in line: + if _is_tvmscript_class_annotator(line): return True if line.startswith("class"): lineno = frame_info.lineno if lineno >= 2: source, _ = findsource(obj) line = source[lineno - 2].strip() - if line.startswith("@") and "ir_module" in line: + if _is_tvmscript_class_annotator(line): return True return False diff --git a/src/relax/ir/dataflow_block_rewriter.cc b/src/relax/ir/dataflow_block_rewriter.cc index c5af0493a54c..d07fedd29715 100644 --- a/src/relax/ir/dataflow_block_rewriter.cc +++ b/src/relax/ir/dataflow_block_rewriter.cc @@ -37,6 +37,7 @@ #include #include "dataflow_matcher.h" +#include "dataflow_rewriter.h" namespace tvm { namespace relax { @@ -287,18 +288,21 @@ static std::optional MatchTree( return std::nullopt; } -Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb, +Optional> MatchGraph(const PatternContext& ctx, + const Array& binding_arr, const Map& bindings) { // TODO(@ganler): Handle non-may external use. ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is supported yet."; DFPatternMatcher matcher(bindings); MatcherUseDefAnalysis ud_analysis; - ud_analysis.VisitBindingBlock_(dfb.get()); + for (const auto& binding : binding_arr) { + ud_analysis.VisitBinding(binding); + } // First construct a graph of PNode and RNode. std::unordered_map var2node; - var2node.reserve(dfb->bindings.size()); + var2node.reserve(bindings.size()); for (const VarNode* cur_var : ud_analysis.vars) { const auto& uses = ud_analysis.def2use.at(cur_var); @@ -355,7 +359,7 @@ Optional> MatchGraph(const PatternContext& ctx, const Datafl } Optional> MatchGraph(const PatternContext& ctx, const DataflowBlock& dfb) { - return MatchGraph(ctx, dfb, AnalyzeVar2Value(dfb)); + return MatchGraph(ctx, dfb->bindings, AnalyzeVar2Value(dfb)); } TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") @@ -363,124 +367,82 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_dfb") return MatchGraph(ctx, dfb); }); -/*! - * \brief Apply pattern matching to each dataflow block, replacing matches - * with the output of a user-provided rewriter function. 
- */ -class BlockPatternRewriter : ExprMutator { +class PatternContextRewriterNode : public ExprRewriterNode { public: - using ExprMutator::VisitBindingBlock_; - using ExprMutator::VisitExpr_; - - BlockPatternRewriter( - const PatternContext& ctx, - TypedPackedFunc(Map, Map)> rewriter_func) - : ctx_(ctx), rewriter_func_(rewriter_func) {} - - template - static Function Run( - PatternType pat, - TypedPackedFunc(Map, Map)> rewriter_func, - Function func) { - BlockPatternRewriter rewriter(pat, rewriter_func); - - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } + PatternContext pattern; + TypedPackedFunc(Map, Map)> rewriter_func; + + RewriteSpec RewriteBindings(const Array& bindings) const override; - BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) override { - return RewriteDataflowBlockFixedPoint(GetRef(block_node)); + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("pattern", &pattern); + PackedFunc untyped_func = rewriter_func; + visitor->Visit("rewriter_func", &untyped_func); } - private: - void EmitUsedVars(Expr val, const Array& pending_bindings, - std::unordered_set* emitted_vars) { - std::unordered_set unemitted_vars; - PostOrderVisit(val, [=, &unemitted_vars](Expr e) { - if (auto v = e.as(); v && !emitted_vars->count(v)) { - unemitted_vars.insert(v); - } - }); + static constexpr const char* _type_key = "relax.dpl.PatternContextRewriter"; + TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextRewriterNode, ExprRewriterNode); - if (unemitted_vars.empty()) { - return; + private: + Optional> MatchBindings(const Array& bindings) const { + Map var_lookup; + for (const auto& binding : bindings) { + var_lookup.Set(binding->var, GetBoundValue(binding)); } - size_t num_unemitted = unemitted_vars.size(); - for (size_t i = 0; i < pending_bindings.size(); ++i) { - const auto& binding = pending_bindings[i]; - if (auto var_bind = binding.as(); - var_bind && unemitted_vars.count(var_bind->var.get())) { - // var_bind->value may also depend on other unemitted vars in this range - Array prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i); - EmitUsedVars(var_bind->value, prev_bindings, emitted_vars); - this->VisitBinding(binding); - emitted_vars->insert(var_bind->var.get()); - if (--num_unemitted == 0) { - return; - } + if (auto matches = MatchGraph(pattern, bindings, var_lookup)) { + Map replacements = rewriter_func(matches.value(), var_lookup); + if (replacements.size()) { + return replacements; } } + + return NullOpt; } +}; - // Repeat until all matchable subsets of bindings are rewritten. 
- BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) { - auto df_block = Downcast(block); - Map bindings = AnalyzeVar2Value(df_block); - if (auto matches = MatchGraph(ctx_, df_block, bindings)) { - builder_->BeginDataflowBlock(); - Map replacements = rewriter_func_(matches.value(), bindings); - - std::unordered_set emitted_vars; - - bool changed = false; - for (size_t i = 0; i < block->bindings.size(); ++i) { - const auto& binding = block->bindings[i]; - if (auto var_bind = binding.as()) { - if (auto new_val = replacements.Get(var_bind->var).value_or(var_bind->value); - !StructuralEqual()(var_bind->value, new_val)) { - Array pending_bindings(block->bindings.begin() + i + 1, block->bindings.end()); - // Make sure there is no unbound variable used in the new value before it is emitted - EmitUsedVars(new_val, pending_bindings, &emitted_vars); - this->ReEmitBinding(var_bind, builder_->Normalize(new_val)); - changed = true; - } else if (!emitted_vars.count(var_bind->var.get())) { - this->VisitBinding(binding); - emitted_vars.insert(var_bind->var.get()); - } - } else { - this->VisitBinding(binding); - } - } +class PatternContextRewriter : public ExprRewriter { + public: + PatternContextRewriter( + PatternContext pattern, + TypedPackedFunc(Map, Map)> rewriter_func); - auto new_block = builder_->EndBlock(); + TVM_DEFINE_OBJECT_REF_METHODS(PatternContextRewriter, ExprRewriter, PatternContextRewriterNode); +}; - if (!changed) return new_block; - return RewriteDataflowBlockFixedPoint(new_block); +RewriteSpec PatternContextRewriterNode::RewriteBindings(const Array& bindings) const { + std::vector remaining_bindings{bindings.begin(), bindings.end()}; + + Map variable_rewrites; + while (auto opt = MatchBindings(remaining_bindings)) { + auto new_rewrites = opt.value(); + remaining_bindings.erase(std::remove_if(remaining_bindings.begin(), remaining_bindings.end(), + [&new_rewrites](const Binding& binding) { + return new_rewrites.count(binding->var); + }), + remaining_bindings.end()); + for (const auto& [var, expr] : new_rewrites) { + variable_rewrites.Set(var, expr); } - return block; } - /*! \brief The pattern constraint contexts for rewriting dataflow blocks */ - PatternContext ctx_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Map, Map) -> Map - * - * Given the map of patterns and corresponding variables (bound - * variables or parameters), it should return a map that - * specifies new values for matched bound variables. It can refer - * to the passed bindings to create the replacement expressions. 
- */ - TypedPackedFunc(Map, Map)> rewriter_func_; -}; + return RewriteSpec{variable_rewrites, {}}; +} + +PatternContextRewriter::PatternContextRewriter( + PatternContext pattern, + TypedPackedFunc(Map, Map)> rewriter_func) { + auto node = make_object(); + node->pattern = std::move(pattern); + node->rewriter_func = std::move(rewriter_func); + data_ = std::move(node); +} Function RewriteBindings( const PatternContext& ctx, TypedPackedFunc(Map, Map)> rewriter, Function func) { - return BlockPatternRewriter::Run(ctx, rewriter, func); + // return BlockPatternRewriter::Run(ctx, rewriter, func); + return Downcast(PatternContextRewriter(ctx, rewriter)(func)); } TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings); diff --git a/src/relax/ir/dataflow_expr_rewriter.cc b/src/relax/ir/dataflow_expr_rewriter.cc index 4793d1d75a30..8acaec60c356 100644 --- a/src/relax/ir/dataflow_expr_rewriter.cc +++ b/src/relax/ir/dataflow_expr_rewriter.cc @@ -22,6 +22,7 @@ * \brief A transform to match a Relax Expr and rewrite */ +#include #include #include #include @@ -31,12 +32,754 @@ #include #include +#include + #include "../transform/utils.h" #include "dataflow_matcher.h" +#include "dataflow_rewriter.h" namespace tvm { namespace relax { +namespace { +class GlobalVarReplacer : public ExprMutator { + public: + GlobalVarReplacer(Map gvar_map) : gvar_map_(gvar_map) {} + + using ExprMutator::VisitExpr_; + Expr VisitExpr_(const GlobalVarNode* op) override { + auto gvar = GetRef(op); + if (auto opt = gvar_map_.Get(gvar)) { + gvar = opt.value(); + } + return gvar; + } + + private: + Map gvar_map_; +}; + +Array TopologicalSort(const Array& bindings) { + std::unordered_set remaining_bindings; + for (const auto& binding : bindings) { + remaining_bindings.insert(binding->var); + } + + // Utility structure used to track bindings that are moved later in + // the list. + struct DelayedBinding { + Binding binding; + std::unordered_set unmet_requirements; + bool emitted; + }; + std::vector delayed_bindings; + Array sorted_bindings; + + // Utility function to append the + auto push_sorted_binding = [&](Binding binding) { + sorted_bindings.push_back(binding); + remaining_bindings.erase(binding->var); + for (auto& delayed_binding : delayed_bindings) { + delayed_binding.unmet_requirements.erase(binding->var); + } + }; + + bool required_sorting = false; + for (const auto& binding : bindings) { + // Collect any variables used by this binding, but are emitted by + // a later binding. + std::unordered_set unmet_requirements; + for (auto free_var : FreeVars(GetBoundValue(binding))) { + if (remaining_bindings.count(free_var)) { + unmet_requirements.insert(free_var); + } + } + + if (unmet_requirements.empty()) { + push_sorted_binding(binding); + } else { + required_sorting = true; + delayed_bindings.push_back(DelayedBinding{binding, unmet_requirements, false}); + } + + bool requires_delayed_binding_check = true; + while (requires_delayed_binding_check) { + requires_delayed_binding_check = false; + for (auto& delayed_binding : delayed_bindings) { + if (!delayed_binding.emitted && delayed_binding.unmet_requirements.empty()) { + // If we find a delayed binding that can be emitted, mark it + // as emitted and push to the sorted list. This may + delayed_binding.emitted = true; + requires_delayed_binding_check = true; + push_sorted_binding(delayed_binding.binding); + + // The break is not necessary for a topological sort, but is + // necessary to minimize the amount of re-ordering that is + // performed. 
With this break, the next binding is always + // the earliest binding that is legal to emit at this point. + break; + } + } + } + + // Remove any delayed bindings that have been emitted, now that we + // are done iterating over the delayed bindings. + delayed_bindings.erase( + std::remove_if(delayed_bindings.begin(), delayed_bindings.end(), + [](const auto& delayed_binding) { return delayed_binding.emitted; }), + delayed_bindings.end()); + } + + // All bindings should be emitted by this point. If any remain, + // then there exists a circular dependency somewhere in the + // remaining bindings. + CHECK(delayed_bindings.empty()) << "ValueError: " + << "Bindings contain circular dependency"; + + if (required_sorting) { + return sorted_bindings; + } else { + return bindings; + } +} +} // namespace + +void RewriteSpec::Append(RewriteSpec other) { + if (variable_rewrites.empty()) { + *this = std::move(other); + return; + } + if (other.variable_rewrites.empty()) { + return; + } + + NameSupply gvar_name_supply(""); + for (const auto& [gvar, func] : new_subroutines) { + gvar_name_supply->ReserveName(gvar->name_hint); + } + + Map gvar_rewrites; + for (auto [gvar, func] : other.new_subroutines) { + if (auto it = new_subroutines.find(gvar); it != new_subroutines.end()) { + // The two rewrites provide the same GlobalVar. + // (e.g. Multiple rewrites of the same pattern.) Ensure that + // they are referring to the same underlying BaseFunc. + CHECK(func.same_as((*it).second)); + } else if (auto new_name = gvar_name_supply->FreshName(gvar->name_hint); + new_name != gvar->name_hint) { + // The two rewrites provide distinct GlobalVar subroutines, + // but with conflicting names. Because an IRModule must have + // enough names for each GlobalVar, even if they are not + // publicly exposed, one of the GlobalVars must be replaced. + // Replacing the GlobalVar here, when the conflict is first + // identified, minimizes the size of the `relax::Expr` that + // must be updated with `GlobalVarReplacer`. 
+ GlobalVar new_gvar = gvar; + new_gvar.CopyOnWrite()->name_hint = new_name; + gvar_rewrites.Set(gvar, new_gvar); + new_subroutines.Set(new_gvar, func); + } else { + new_subroutines.Set(gvar, func); + } + } + + for (auto [var, expr] : other.variable_rewrites) { + if (gvar_rewrites.size()) { + expr = GlobalVarReplacer(gvar_rewrites)(expr); + } + variable_rewrites.Set(var, expr); + } +} + +TVM_REGISTER_NODE_TYPE(ExprRewriterNode); + +TVM_REGISTER_GLOBAL("relax.dpl.ExprRewriterFromPattern") + .set_body_typed([](DFPattern pattern, + TypedPackedFunc(Expr, Map)> func) { + return ExprRewriter::FromPattern(pattern, func); + }); + +TVM_REGISTER_GLOBAL("relax.dpl.ExprRewriterFromModule").set_body_typed([](IRModule mod) { + return ExprRewriter::FromModule(mod); +}); + +TVM_REGISTER_GLOBAL("relax.dpl.ExprRewriterApply") + .set_body_typed([](ExprRewriter rewriter, + Variant obj) -> Variant { + if (auto expr = obj.as()) { + return rewriter(expr.value()); + } else if (auto mod = obj.as()) { + return rewriter(mod.value()); + } else { + LOG(FATAL) << "Unreachable: object does not contain either variant type"; + } + }); + +TVM_REGISTER_NODE_TYPE(PatternRewriterNode); + +RewriteSpec PatternRewriterNode::RewriteBindings(const Array& bindings) const { + Map variable_rewrites; + Map binding_lookup; + for (const auto& binding : bindings) { + auto bound_value = GetBoundValue(binding); + if (auto new_expr = RewriteExpr(bound_value, binding_lookup)) { + variable_rewrites.Set(binding->var, new_expr.value()); + } else { + binding_lookup.Set(binding->var, bound_value); + } + } + if (variable_rewrites.size()) { + return RewriteSpec{variable_rewrites, new_subroutines}; + } else { + return RewriteSpec(); + } +} + +Optional PatternRewriterNode::RewriteExpr(const Expr& expr, + const Map& bindings) const { + if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings)) { + auto matches = opt_matches.value(); + if (additional_bindings) { + // Append any additional matches that from the unwrapped + // `OrPattern`. When matching against `pat = pat_lhs | + // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and + // `pat_rhs` separately. The top-level `pat` is never seen by + // `ExtractMatchedExpr`, and must be re-added afterward. + auto matched_expr = DFPatternMatcher::UnwrapBindings(expr, bindings); + for (const auto& pat : additional_bindings.value()) { + matches.Set(pat, matched_expr); + } + } + + Optional rewritten_expr = func(expr, matches); + if (rewritten_expr.defined() && !rewritten_expr.same_as(expr)) { + return rewritten_expr.value(); + } + } + return NullOpt; +} + +TVM_REGISTER_GLOBAL("relax.dpl.PatternRewriter") + .set_body_typed([](DFPattern pattern, + TypedPackedFunc(Expr, Map)> func) { + return PatternRewriter(pattern, func); + }); + +PatternRewriter::PatternRewriter(DFPattern pattern, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, + Map new_subroutines) { + auto node = make_object(); + node->pattern = std::move(pattern); + node->func = std::move(func); + node->additional_bindings = std::move(additional_bindings); + node->new_subroutines = std::move(new_subroutines); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(OrRewriterNode); + +RewriteSpec OrRewriterNode::RewriteBindings(const Array& bindings) const { + auto lhs_match = lhs->RewriteBindings(bindings); + if (!lhs_match) { + // If no rewrites found on LHS, RHS is allowed to modify any + // variable binding. + return rhs->RewriteBindings(bindings); + } + + // The LHS matched some subset of the bindings. 
These + // replacements may not be normalized expressions, so the RHS may + // only replace variable bindings that haven't been modified by + // the LHS. Variable replacements from the RHS may still occur, + // but will need to wait for the next round of + // iterate-until-converged. + Array remaining_bindings; + for (const auto& binding : bindings) { + if (!lhs_match.variable_rewrites.count(binding->var)) { + remaining_bindings.push_back(binding); + } + } + + if (remaining_bindings.empty()) { + // Early bail-out, the RHS has no bindings available to rewrite. + return lhs_match; + } + + lhs_match.Append(rhs->RewriteBindings(remaining_bindings)); + return lhs_match; +} + +TVM_REGISTER_GLOBAL("relax.dpl.OrRewriter").set_body_typed([](ExprRewriter lhs, ExprRewriter rhs) { + return OrRewriter(lhs, rhs); +}); + +OrRewriter::OrRewriter(ExprRewriter lhs, ExprRewriter rhs) { + auto node = make_object(); + node->lhs = std::move(lhs); + node->rhs = std::move(rhs); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(TupleRewriterNode); + +RewriteSpec TupleRewriterNode::RewriteBindings(const Array& bindings) const { + CHECK_LE(patterns.size(), 3) << "For performance reasons, " + << "matching of implicit tuple patterns is currently limited" + << " to tuples with 3 elements or fewer."; + Map variable_rewrites = GenerateVariableRewrites(bindings); + + if (variable_rewrites.size()) { + return RewriteSpec{variable_rewrites, new_subroutines}; + } else { + return RewriteSpec(); + } +} + +Map TupleRewriterNode::GenerateVariableRewrites(const Array& bindings) const { + Map rewrites; + + Map binding_lookup; + + std::vector info_vec; + + std::unordered_map binding_index_lookup; + + // Initialize a vector of indices, each of which corresponds to a + // potential match for a tuple element. + // + // \param tuple_index_of_current_expr The index for the most recent + // binding. + // + // \param indices An output vector, into which indices will be + // generated. + // + // \returns bool True if the indices could be initialized to a + // potential match. False, otherwise. + auto initialize_indices = [&](size_t tuple_index_of_current_expr, + std::vector& indices) -> bool { + if (!info_vec.back().matches[tuple_index_of_current_expr]) { + return false; + } + + indices = std::vector(patterns.size(), info_vec.size()); + + indices[tuple_index_of_current_expr] = info_vec.size() - 1; + + for (size_t i_rev = 0; i_rev < indices.size(); i_rev++) { + size_t i = indices.size() - i_rev - 1; + if (indices[i] == info_vec.size() - 1) { + continue; + } + + auto binding_index = [&]() -> std::optional { + if (indices[i] == info_vec.size() - 1) { + return info_vec.size() - 1; + } + + for (size_t j_rev = 1; j_rev < info_vec.size(); j_rev++) { + size_t j = info_vec.size() - j_rev - 1; + if (info_vec[j].matches[i] && !info_vec[j].used && + std::all_of(indices.begin() + (j + 1), indices.end(), + [j](size_t prev_binding_index) { return j != prev_binding_index; })) { + return j; + } + } + + return std::nullopt; + }(); + + if (binding_index.has_value()) { + indices[i] = binding_index.value(); + } else { + return false; + } + } + + return true; + }; + + auto decrement_indices = [&](std::vector& indices) -> bool { + ICHECK_EQ(indices.size(), patterns.size()); + + // Step 1, find the first index that can be decremented, while + // still generating a valid set of indices. 
+ size_t i_forward; + for (i_forward = 0; i_forward < indices.size(); i_forward++) { + if (indices[i_forward] == info_vec.size() - 1) { + continue; + } + + bool found_valid = false; + size_t& index = indices[i_forward]; + while (index) { + index--; + if (info_vec[index].matches[i_forward] && !info_vec[index].used && + std::all_of( + indices.begin() + (i_forward + 1), indices.end(), + [index](size_t later_binding_index) { return index != later_binding_index; })) { + found_valid = true; + break; + } + } + if (found_valid) { + break; + } + } + + // Step 2, if we reached the end, then all indices were + // decremented to zero without finding anything. Return false to + // indicate that we've reached the end. + if (i_forward == indices.size()) { + return false; + } + + // Step 3, refill all indices that were decremented to zero before from 0 to + for (size_t i = 0; i < i_forward; i++) { + size_t i_backward = i_forward - (i + 1); + if (indices[i_backward] == info_vec.size() - 1) { + continue; + } + + auto binding_index = [&]() -> std::optional { + for (size_t j_rev = 1; j_rev < info_vec.size(); j_rev++) { + size_t j = info_vec.size() - j_rev - 1; + if (info_vec[j].matches[i_backward] && !info_vec[j].used && + std::all_of(indices.begin() + (j + 1), indices.end(), + [j](size_t prev_binding_index) { return j != prev_binding_index; })) { + return j; + } + } + + return std::nullopt; + }(); + + if (binding_index.has_value()) { + indices[i_backward] = binding_index.value(); + } else { + return false; + } + } + + return true; + }; + + for (size_t i_binding = 0; i_binding < bindings.size(); i_binding++) { + const auto& binding = bindings[i_binding]; + + auto expr = GetBoundValue(binding); + + binding_index_lookup[binding->var] = i_binding; + + info_vec.push_back(VarInfo{ + binding->var, + expr, + patterns.Map( + [&](const DFPattern& pat) { return ExtractMatchedExpr(pat, expr, binding_lookup); }), + std::unordered_set(), + false, + }); + + auto new_match = [&]() -> std::optional, std::vector>> { + std::vector indices; + for (size_t i = 0; i < patterns.size(); i++) { + if (initialize_indices(patterns.size() - i - 1, indices)) { + do { + if (auto match = TryMatchByBindingIndex(info_vec, indices)) { + return std::pair{indices, match.value()}; + } + } while (decrement_indices(indices)); + } + } + return std::nullopt; + }(); + + if (new_match) { + const auto& [indices, exprs] = new_match.value(); + ICHECK_EQ(indices.size(), exprs.size()); + for (size_t i = 0; i < indices.size(); i++) { + ICHECK_LT(indices[i], info_vec.size()); + auto& info = info_vec[indices[i]]; + + ICHECK(!info.used) << "InternalError: " + << "Produced multiple replacements for variable " << info.var; + + rewrites.Set(info.var, exprs[i]); + binding_lookup.erase(info.var); + info.used = true; + } + } else { + binding_lookup.Set(binding->var, expr); + } + + for (const auto& prev_var : FreeVars(expr)) { + if (auto it = binding_index_lookup.find(prev_var); it != binding_index_lookup.end()) { + info_vec[it->second].downstream_usage.insert(binding->var); + } + } + } + + return rewrites; +} + +std::optional> TupleRewriterNode::TryMatchByBindingIndex( + const std::vector& info_vec, const std::vector& indices) const { + ICHECK_GE(indices.size(), 1); + + ICHECK_EQ(indices.size(), patterns.size()); + for (size_t i = 0; i < indices.size(); i++) { + const auto& info = info_vec[indices[i]]; + if (info.used || !info.matches[i]) { + return std::nullopt; + } + } + + Map merged_matches = info_vec[indices[0]].matches[0].value(); + for (size_t i = 1; i < 
indices.size(); i++) { + for (const auto& [pat, expr] : info_vec[indices[i]].matches[i].value()) { + if (auto it = merged_matches.find(pat); it != merged_matches.end()) { + if (!StructuralEqual()(expr, (*it).second)) { + return std::nullopt; + } + } else { + merged_matches.Set(pat, expr); + } + } + } + + bool tuple_element_is_already_used_outside_of_matched_tuple = [&]() -> bool { + std::unordered_set matched_vars; + for (const auto& [pat, expr] : merged_matches) { + if (auto opt = expr.as()) { + matched_vars.insert(opt.value()); + } + } + + for (size_t index : indices) { + const auto& downstream_of_rewritten_var = info_vec[index].downstream_usage; + + for (const auto& uses_matched_var : downstream_of_rewritten_var) { + if (!matched_vars.count(uses_matched_var)) { + return true; + } + } + } + + return false; + }(); + if (tuple_element_is_already_used_outside_of_matched_tuple) { + return std::nullopt; + } + + auto full_tuple = [&]() -> relax::Expr { + Array fields; + for (size_t index : indices) { + fields.push_back(info_vec[index].expr); + } + return relax::Tuple(fields); + }(); + + auto opt_rewritten = func(full_tuple, merged_matches); + if (!opt_rewritten) { + return std::nullopt; + } + auto rewritten = opt_rewritten.value(); + + if (rewritten.same_as(full_tuple)) { + return std::nullopt; + } + + std::vector rewrites; + if (auto inline_tuple = rewritten.as()) { + const auto& fields = inline_tuple->fields; + CHECK_EQ(fields.size(), indices.size()) + << "Expected to receive " << indices.size() << " values to replace TuplePattern with " + << indices.size() << " fields, but received " << fields.size() << " values"; + rewrites = {fields.begin(), fields.end()}; + } else { + for (size_t i = 0; i < indices.size(); i++) { + rewrites.push_back(TupleGetItem(rewritten, i)); + } + } + return rewrites; +} + +TVM_REGISTER_GLOBAL("relax.dpl.TupleRewriter") + .set_body_typed([](Array patterns, + TypedPackedFunc(Expr, Map)> func) { + return TupleRewriter(patterns, func); + }); + +TupleRewriter::TupleRewriter(Array patterns, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, + Map new_subroutines) { + auto node = make_object(); + node->patterns = std::move(patterns); + node->func = std::move(func); + node->additional_bindings = std::move(additional_bindings); + node->new_subroutines = std::move(new_subroutines); + data_ = std::move(node); +} + +ExprRewriter ExprRewriter::FromPattern( + DFPattern pattern, TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings, Map new_subroutines) { + if (auto or_pattern = pattern.as()) { + auto new_additional_bindings = additional_bindings.value_or({}); + new_additional_bindings.push_back(pattern); + return OrRewriter( + ExprRewriter::FromPattern(or_pattern->left, func, new_additional_bindings, new_subroutines), + ExprRewriter::FromPattern(or_pattern->right, func, new_additional_bindings, + new_subroutines)); + } else if (auto tuple_pattern = pattern.as()) { + auto new_additional_bindings = additional_bindings.value_or({}); + new_additional_bindings.push_back(pattern); + // If the Tuple appears as a Relax binding, apply it first. As a + // fallback, also check for implicit tuples. 
+ return OrRewriter( + PatternRewriter(pattern, func, additional_bindings, new_subroutines), + TupleRewriter(tuple_pattern->fields, func, new_additional_bindings, new_subroutines)); + } else { + return PatternRewriter(pattern, func, additional_bindings, new_subroutines); + } +} + +ExprRewriter ExprRewriter::FromModule(IRModule mod) { + Function func_pattern = [&]() { + CHECK(mod->ContainGlobalVar("pattern")) + << "KeyError: " + << "Expected module to contain 'pattern', " + << "a Relax function defining the pattern to be matched, " + << "but the module did not contain a 'pattern' function."; + auto base_func = mod->Lookup("pattern"); + CHECK(base_func->IsInstance()) + << "TypeError: " + << "Expected module to contain 'pattern', " + << "a Relax function defining the pattern to be matched, " + << "but the 'pattern' function was of type " << base_func->GetTypeKey() << "."; + return Downcast(base_func); + }(); + Function func_replacement = [&]() { + CHECK(mod->ContainGlobalVar("replacement")) + << "KeyError: " + + << "Expected module to contain 'replacement', " + << "a Relax function defining the replacement to be matched, " + << "but the module did not contain a 'replacement' function."; + auto base_func = mod->Lookup("replacement"); + CHECK(base_func->IsInstance()) + << "TypeError: " + << "Expected module to contain 'replacement', " + << "a Relax function defining the replacement to be made on a successful match, " + << "but the 'replacement' function was of type " << base_func->GetTypeKey() << "."; + return Downcast(base_func); + }(); + + Map new_subroutines; + for (const auto& [gvar, func] : mod->functions) { + if (gvar->name_hint != "pattern" && gvar->name_hint != "replacement") { + bool is_public = func->GetAttr(tvm::attr::kGlobalSymbol).defined(); + CHECK(!is_public) << "ValueError: " + << "Expected module to have no publicly-exposed functions " + << "other than 'pattern' and 'replacement'. 
" + << "However, function '" << gvar->name_hint << "' of type " + << func->GetTypeKey() << " is publicly exposed."; + new_subroutines.Set(gvar, func); + } + } + + auto sinfo_pattern = GetStructInfo(func_pattern); + auto sinfo_replacement = GetStructInfo(func_replacement); + CHECK(StructuralEqual()(sinfo_pattern, sinfo_replacement)) + << "ValueError: " + << "The pattern and replacement must have the same signature, " + << "but the pattern has struct info " << sinfo_pattern + << ", while the replacement has struct info " << sinfo_replacement; + + Array param_wildcards; + Map pattern_lookup; + for (const auto& param : func_pattern->params) { + WildcardPattern wildcard; + param_wildcards.push_back(wildcard); + pattern_lookup.Set(param, StructInfoPattern(wildcard, GetStructInfo(param))); + } + + std::function make_pattern = [&](Expr expr) -> DFPattern { + if (auto var = expr.as()) { + return pattern_lookup[var.value()]; + + } else if (auto call = expr.as()) { + auto op = make_pattern(call->op); + auto args = call->args.Map(make_pattern); + return CallPattern(op, args); + + } else if (auto tuple = expr.as()) { + auto fields = tuple->fields.Map(make_pattern); + return TuplePattern(fields); + + } else if (auto tuple_get_item = expr.as()) { + auto tuple = make_pattern(tuple_get_item->tuple); + return TupleGetItemPattern(tuple, tuple_get_item->index); + + } else if (auto op = expr.as()) { + return ExprPattern(op.value()); + + } else if (auto func = expr.as()) { + return ExternFuncPattern(func->global_symbol); + + } else { + LOG(FATAL) << "TypeError: " + << "Cannot convert Relax expression of type " << expr->GetTypeKey() + << " into pattern-matching rule."; + } + }; + + for (const auto& block : func_pattern->body->blocks) { + for (const auto& binding : block->bindings) { + auto value_pattern = make_pattern(GetBoundValue(binding)); + if (auto match_cast = binding.as()) { + value_pattern = StructInfoPattern(value_pattern, match_cast->struct_info); + } + pattern_lookup.Set(binding->var, value_pattern); + } + } + + DFPattern top_pattern = make_pattern(func_pattern->body->body); + + TypedPackedFunc(Expr, Map)> rewriter_func = + [param_wildcards = std::move(param_wildcards), + orig_func_replacement = std::move(func_replacement)]( + Expr expr, Map matches) -> Optional { + auto func_replacement = CopyWithNewVars(orig_func_replacement); + + Array new_blocks; + + Array wildcard_bindings; + ICHECK_EQ(param_wildcards.size(), func_replacement->params.size()); + for (size_t i = 0; i < param_wildcards.size(); i++) { + Expr matched_expr = matches[param_wildcards[i]]; + + // Introduce an intermediate variable, to ensure that the + // MatchCast's target will be a Var, even for expressions that + // wouldn't normally be normalized into a variable. 
+ Var intermediate_var("intermediate_var", GetStructInfo(matched_expr)); + wildcard_bindings.push_back(VarBinding(intermediate_var, matched_expr)); + wildcard_bindings.push_back( + MatchCast(func_replacement->params[i], intermediate_var, GetStructInfo(matched_expr))); + } + + new_blocks.push_back(DataflowBlock(wildcard_bindings)); + + for (const auto& block : func_replacement->body->blocks) { + new_blocks.push_back(block); + } + + return SeqExpr(new_blocks, func_replacement->body->body); + }; + + return ExprRewriter::FromPattern(top_pattern, rewriter_func, NullOpt, new_subroutines); +} + Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, Optional> bindings_opt) { auto bindings = bindings_opt.value_or({}); @@ -46,12 +789,7 @@ Optional> ExtractMatchedExpr(DFPattern pattern, Expr expr, return NullOpt; } - Map matching; - for (const auto& [pat, matches] : matcher.GetMemo()) { - ICHECK_EQ(matches.size(), 1) << "More than one match for the pattern " << pat; - matching.Set(pat, matches[0]); - } - return matching; + return matcher.GetMemo(); } TVM_REGISTER_GLOBAL("relax.dpl.extract_matched_expr").set_body_typed(ExtractMatchedExpr); @@ -66,34 +804,23 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr); * \brief Apply pattern matching to each expression, replacing * matches with the output of a user-provided rewriter function. */ -class ExprPatternRewriter : ExprMutator { +class ExprPatternRewriter : public ExprMutator { public: using ExprMutator::VisitExpr_; - ExprPatternRewriter(DFPattern pat, - TypedPackedFunc)> rewriter_func) - : pattern_(pat), rewriter_func_(rewriter_func) {} + ExprPatternRewriter(const ExprRewriterNode* rewriter) : rewriter_(rewriter) {} - template - static Function Run(PatternType pat, - TypedPackedFunc)> rewriter_func, - Function func) { - ExprPatternRewriter rewriter(pat, rewriter_func); - func = Downcast(rewriter(func)); - func = Downcast(RemoveAllUnused(func)); - return func; - } + Map GetNewSubroutines() const { return new_subroutines_; } Expr VisitExpr_(const SeqExprNode* seq) override { - auto cache = bindings_; - SeqExpr prev = GetRef(seq); + SeqExpr prev = Downcast(ExprMutator::VisitExpr_(seq)); StructuralEqual struct_equal; - while (true) { - SeqExpr next = Downcast(builder_->Normalize(ExprMutator::VisitExpr_(prev.get()))); + while (auto opt = TryRewriteSeqExpr(prev)) { + SeqExpr next = Downcast(builder_->Normalize(opt.value())); if (struct_equal(prev, next)) { - return std::move(next); + break; } // Canonicalization may result in two previously-different @@ -112,108 +839,235 @@ class ExprPatternRewriter : ExprMutator { } if (struct_equal(prev, next)) { - return std::move(next); + break; } - // Reset all knowledge of bindings that were collected from - // this SeqExpr. The collected bindings are only after - // the point where they were collected, and we are repeating - // the mutation of this SeqExpr. - bindings_ = cache; prev = next; } - } - void VisitBinding_(const VarBindingNode* binding) override { - auto expr = VisitExpr(binding->value); - bindings_.Set(binding->var, expr); - ReEmitBinding(binding, expr); + return prev; } - Expr VisitExpr(const Expr& expr) override { - auto node = ExprMutator::VisitExpr(expr); + Optional TryRewriteSeqExpr(const SeqExpr& seq) { + Array old_blocks = seq->blocks; + + // If the SeqExpr's output is not a variable, treat it as if it + // were the last variable binding of the last block. This + // simplifies the special handling of the SeqExpr's body. 
+ Optional dummy_output_var = NullOpt; + if (!seq->body->IsInstance()) { + dummy_output_var = Var("dummy_output_var", GetStructInfo(seq->body)); + VarBinding dummy_binding(dummy_output_var.value(), seq->body); + + auto last_block = [&]() { + if (seq->blocks.size()) { + auto last_block = old_blocks.back(); + old_blocks.pop_back(); + return last_block; + } else { + return BindingBlock(Array{}); + } + }(); - std::vector matches_top_level; - if (auto rewritten = TryRewrite(node, pattern_, &matches_top_level)) { - return builder_->Normalize(rewritten.value()); + last_block.CopyOnWrite()->bindings.push_back(dummy_binding); + old_blocks.push_back(last_block); } - return node; - } + auto rewrite_block = [&](Array orig_bindings) -> Array { + auto rewrites = rewriter_->RewriteBindings(orig_bindings); + if (!rewrites) return orig_bindings; - private: - Optional TryRewrite(const Expr& expr, const DFPattern& pattern, - std::vector* matches_top_level) { - ICHECK(matches_top_level); - - // Special handling if the user-supplied pattern is a `OrPattern`. - // While the `ExtractMatchedExpr` can handle matching the - // `OrPattern`, it will return on the first match, even if the - // `rewriter_func_` doesn't apply a replacement. Unpacking the - // `OrPattern` here allows the match to be resumed if - // `rewriter_func_` returns the original function unmodified. - // This is only valid for a top-level match. - if (auto or_pattern = pattern.as()) { - matches_top_level->push_back(pattern); - Optional output = TryRewrite(expr, or_pattern->left, matches_top_level); - if (!output.defined()) { - output = TryRewrite(expr, or_pattern->right, matches_top_level); - } - matches_top_level->pop_back(); - return output; - } - - if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) { - auto matches = opt_matches.value(); + for (auto [gvar, func] : rewrites.new_subroutines) { + new_subroutines_.Set(gvar, func); + } - // Append any additional matches that from the unwrapped - // `OrPattern`. When matching against `pat = pat_lhs | - // pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and - // `pat_rhs` separately. The top-level `pat` is never seen by - // `ExtractMatchedExpr`, and must be re-added afterward. - if (matches_top_level->size()) { - auto matched_expr = DFPatternMatcher::UnwrapBindings(expr, bindings_); - for (const auto& pat : *matches_top_level) { - matches.Set(pat, matched_expr); + auto bindings = orig_bindings.Map([&](Binding binding) -> Binding { + if (auto new_expr = rewrites.variable_rewrites.Get(binding->var)) { + if (auto match_cast = binding.as()) { + return MatchCast(binding->var, new_expr.value(), match_cast->struct_info); + } else { + return VarBinding(binding->var, new_expr.value()); + } + } else { + return binding; } + }); + + if (bindings.same_as(orig_bindings)) { + return orig_bindings; } - Expr rewritten_expr = rewriter_func_(expr, matches); - if (!rewritten_expr.same_as(expr)) { - return builder_->Normalize(rewritten_expr); + // The rewriter may have introduced additional dependencies + // between computations. Since pattern-matching only occurs + // within blocks that may be re-ordered, these can be resolved + // by performing a topological sort. + bindings = TopologicalSort(bindings); + + return bindings; + }; + + // Utility function to return the rewrites that should be applied + // to a given block. + auto get_rewrites = [&](BindingBlock block) -> Array { + if (block.as()) { + // Early return for DataflowBlock. 
Since neither control flow + // nor impure functions are allowed within the dataflow block, + // all bindings may be considered at the same time. + return rewrite_block(block->bindings); } + + RewriteSpec rewrites; + + Array collected_bindings; + Array finalized_bindings; + + auto handle_collected_rewrites = [&]() { + if (collected_bindings.size()) { + auto bindings = rewrite_block(collected_bindings); + if (finalized_bindings.empty()) { + finalized_bindings = bindings; + } else { + for (const auto& binding : bindings) { + finalized_bindings.push_back(binding); + } + } + collected_bindings.clear(); + } + }; + + for (const auto& binding : block->bindings) { + auto value = GetBoundValue(binding); + bool is_dataflow = (!value.as()) && + (!(value.as() && IsImpureCall(Downcast(value)))); + if (is_dataflow) { + // This binding satisfies the dataflow constraints. + collected_bindings.push_back(binding); + } else { + // This binding does not satisfy the dataflow constraints. + // Any operations prior to this binding should be checked + // for pattern-match replacements. + handle_collected_rewrites(); + finalized_bindings.push_back(binding); + } + } + + // Check for rewrites in dataflow operations after the last + // non-dataflow segment. + handle_collected_rewrites(); + + return finalized_bindings; + }; + + // Utility function, check for and apply rewrites to a single + // block. + auto visit_block = [&](BindingBlock old_block) -> BindingBlock { + auto new_bindings = get_rewrites(old_block); + if (new_bindings.same_as(old_block->bindings)) { + return old_block; + } + + if (old_block.as()) { + builder_->BeginDataflowBlock(); + } else { + builder_->BeginBindingBlock(); + } + + for (const auto& binding : new_bindings) { + auto value = builder_->Normalize(GetBoundValue(binding)); + + if (binding.as()) { + builder_->EmitNormalized(VarBinding(binding->var, value)); + } else if (auto match_cast = binding.as()) { + builder_->EmitNormalized(MatchCast(binding->var, value, match_cast->struct_info)); + } else { + LOG(FATAL) << "Binding must be either VarBinding or MatchCast"; + } + } + return builder_->EndBlock(); + }; + + auto new_blocks = old_blocks.Map(visit_block); + if (old_blocks.same_as(new_blocks)) { + return NullOpt; } - return NullOpt; + // Restore the body of the SeqExpr, if needed. + auto new_body = [&]() -> Expr { + if (dummy_output_var) { + auto last_block = new_blocks.back(); + new_blocks.pop_back(); + + auto last_binding = last_block->bindings.back(); + last_block.CopyOnWrite()->bindings.pop_back(); + ICHECK(last_binding->var.same_as(dummy_output_var)); + + if (last_block->bindings.size()) { + new_blocks.push_back(last_block); + } + + return GetBoundValue(last_binding); + } else { + return seq->body; + } + }(); + + return SeqExpr(new_blocks, new_body); } - /*! \brief The pattern for rewriting call nodes */ - DFPattern pattern_; - /*! - * \brief The user-provided rewriter function. Its signature and semantics are: - * - * - (Call, Map) -> Call - * - * Given the matched call node and the map of patterns and - * matched expressions, it should return a new call node to - * replace the original one or the original matched call node as - * is. - */ - TypedPackedFunc)> rewriter_func_; - - /*! \brief The known variable bindings - * - * The variable bindings whose value is known. This must be tracked - * separately from the block builder, so that it can be reset after - * each iteration of the mutate-until-converged loop applied to - * `SeqExpr`. 
- */ - Map bindings_; + private: + const ExprRewriterNode* rewriter_; + Map new_subroutines_; }; +Expr ExprRewriter::operator()(Expr expr) { + ExprPatternRewriter mutator(get()); + auto new_expr = mutator(expr); + auto new_subroutines = mutator.GetNewSubroutines(); + CHECK_EQ(new_subroutines.size(), 0) + << "If ExprRewriter provides subroutines, " + << "then it must be applied to an entire IRModule. " + << "However, ExprRewriter produced subroutines " << [&]() -> Array { + std::vector vec; + for (const auto& [gvar, func] : new_subroutines) { + vec.push_back(gvar); + } + std::sort(vec.begin(), vec.end(), + [](const GlobalVar& a, const GlobalVar& b) { return a->name_hint < b->name_hint; }); + return vec; + }() << "when applied to " + << "Relax expression of type " << expr->GetTypeKey(); + return new_expr; +} + +IRModule ExprRewriterNode::operator()(IRModule mod, + const tvm::transform::PassContext& pass_ctx) const { + ExprPatternRewriter mutator(this); + + IRModule updates; + for (const auto& [gvar, base_func] : mod->functions) { + if (auto func = base_func.as()) { + auto rewritten = Downcast(mutator(func.value())); + if (!rewritten.same_as(base_func)) { + updates->Add(gvar, rewritten); + } + } + } + + if (updates->functions.size()) { + auto write_ptr = mod.CopyOnWrite(); + write_ptr->Update(updates); + write_ptr->Update(IRModule(mutator.GetNewSubroutines())); + } + + return mod; +} +tvm::transform::PassInfo ExprRewriterNode::Info() const { + return tvm::transform::PassInfo(0, "ExprRewriter", {}, false); +} + Function RewriteCall(const DFPattern& pat, TypedPackedFunc)> rewriter, Function func) { - return ExprPatternRewriter::Run(pat, rewriter, func); + return Downcast(ExprRewriter::FromPattern(pat, rewriter)(func)); } TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall); diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 989c1174f41d..b6994f017466 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -101,14 +101,13 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr auto expr = UnwrapBindings(expr0, var2val_); if (memoize_ && memo_.count(pattern)) { - ICHECK_EQ(memo_[pattern].size(), 1); - return expr.same_as(memo_[pattern][0]); + return expr.same_as(memo_[pattern]); } else { PrimExpr cached_condition = symbolic_expr_condition_; size_t watermark = matched_nodes_.size(); bool out = DFPatternFunctor::VisitDFPattern(pattern, expr); if (out) { - memo_[pattern].push_back(expr); + memo_[pattern] = expr; matched_nodes_.push_back(pattern); } else { ClearMap(watermark); diff --git a/src/relax/ir/dataflow_matcher.h b/src/relax/ir/dataflow_matcher.h index 9036c7630a54..93141af81c7c 100644 --- a/src/relax/ir/dataflow_matcher.h +++ b/src/relax/ir/dataflow_matcher.h @@ -43,7 +43,7 @@ class DFPatternMatcher : public DFPatternFunctor> GetMemo() { return Map>(memo_); } + Map GetMemo() { return memo_; } /* \brief Unwrap trivial expressions/bindings */ static Expr UnwrapBindings(Expr expr, const Map& bindings); @@ -91,7 +91,7 @@ class DFPatternMatcher : public DFPatternFunctor, ObjectPtrHash, ObjectPtrEqual> memo_; + std::unordered_map memo_; var2val_t var2val_; std::vector matched_nodes_; PrimExpr symbolic_expr_condition_{Bool(true)}; diff --git a/src/relax/ir/dataflow_rewriter.h b/src/relax/ir/dataflow_rewriter.h new file mode 100644 index 000000000000..d26695a7ce52 --- /dev/null +++ b/src/relax/ir/dataflow_rewriter.h @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relax/dataflow_rewriter.h + * \brief Pattern match/rewriters for Relax + */ +#ifndef TVM_RELAX_IR_DATAFLOW_REWRITER_H_ +#define TVM_RELAX_IR_DATAFLOW_REWRITER_H_ + +#include +#include +#include +#include + +#include + +#include "dataflow_matcher.h" + +namespace tvm { +namespace relax { + +struct RewriteSpec { + Map variable_rewrites; + Map new_subroutines; + + explicit operator bool() const { return variable_rewrites.size(); } + + void Append(RewriteSpec other); +}; + +class ExprRewriterNode : public tvm::transform::PassNode { + public: + virtual RewriteSpec RewriteBindings(const Array& bindings) const { + return RewriteSpec(); + } + + void VisitAttrs(AttrVisitor* visitor) {} + + IRModule operator()(IRModule mod, const tvm::transform::PassContext& pass_ctx) const override; + tvm::transform::PassInfo Info() const override; + + static constexpr const char* _type_key = "relax.dpl.ExprRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(ExprRewriterNode, PassNode); +}; + +class ExprRewriter : public tvm::transform::Pass { + public: + static ExprRewriter FromPattern(DFPattern pattern, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); + + static ExprRewriter FromModule(IRModule mod); + + Expr operator()(Expr expr); + using Pass::operator(); + + TVM_DEFINE_OBJECT_REF_METHODS(ExprRewriter, Pass, ExprRewriterNode); +}; + +class PatternRewriterNode : public ExprRewriterNode { + public: + DFPattern pattern; + TypedPackedFunc(Expr, Map)> func; + Optional> additional_bindings; + Map new_subroutines; + + RewriteSpec RewriteBindings(const Array& bindings) const final; + + Optional RewriteExpr(const Expr& expr, const Map& bindings) const; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("pattern", &pattern); + PackedFunc untyped_func = func; + visitor->Visit("func", &untyped_func); + } + + static constexpr const char* _type_key = "relax.dpl.PatternRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(PatternRewriterNode, ExprRewriterNode); +}; + +class PatternRewriter : public ExprRewriter { + public: + PatternRewriter(DFPattern pattern, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); + + TVM_DEFINE_OBJECT_REF_METHODS(PatternRewriter, ExprRewriter, PatternRewriterNode); +}; + +class OrRewriterNode : public ExprRewriterNode { + public: + ExprRewriter lhs; + ExprRewriter rhs; + + RewriteSpec RewriteBindings(const Array& bindings) const override; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("lhs", &lhs); + visitor->Visit("rhs", &rhs); + } + + static constexpr const char* _type_key = "relax.dpl.OrRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(OrRewriterNode, ExprRewriterNode); +}; + +class OrRewriter : public 
ExprRewriter { + public: + OrRewriter(ExprRewriter lhs, ExprRewriter rhs); + + TVM_DEFINE_OBJECT_REF_METHODS(OrRewriter, ExprRewriter, OrRewriterNode); +}; + +class TupleRewriterNode : public ExprRewriterNode { + public: + Array patterns; + TypedPackedFunc(Expr, Map)> func; + Optional> additional_bindings; + Map new_subroutines; + + RewriteSpec RewriteBindings(const Array& bindings) const override; + + void VisitAttrs(AttrVisitor* visitor) { + visitor->Visit("patterns", &patterns); + PackedFunc untyped_func = func; + visitor->Visit("func", &untyped_func); + } + + static constexpr const char* _type_key = "relax.dpl.TupleRewriter"; + TVM_DECLARE_BASE_OBJECT_INFO(TupleRewriterNode, ExprRewriterNode); + + private: + struct VarInfo { + Var var; + Expr expr; + Array>> matches; + std::unordered_set downstream_usage; + bool used = false; + }; + + Map GenerateVariableRewrites(const Array& bindings) const; + + std::optional> TryMatchByBindingIndex(const std::vector& info_vec, + const std::vector& indices) const; +}; + +class TupleRewriter : public ExprRewriter { + public: + TupleRewriter(Array patterns, + TypedPackedFunc(Expr, Map)> func, + Optional> additional_bindings = NullOpt, + Map new_subroutines = {}); + + TVM_DEFINE_OBJECT_REF_METHODS(TupleRewriter, ExprRewriter, TupleRewriterNode); +}; + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_IR_DATAFLOW_REWRITER_H_ diff --git a/tests/python/relax/test_dataflow_rewriter.py b/tests/python/relax/test_dataflow_rewriter.py new file mode 100644 index 000000000000..1d917c59523b --- /dev/null +++ b/tests/python/relax/test_dataflow_rewriter.py @@ -0,0 +1,1388 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +import tvm.testing +from tvm.relax.dpl import ExprRewriter +from tvm.script import ir as I, relax as R, tir as T + +import pytest + + +def test_rewrite_defined_by_ir_module(): + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function + def before(x: R.Tensor([32], "float32")): + R.func_attr({"global_symbol": "main"}) + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = lhs + rhs + return out + + @R.function + def expected(x: R.Tensor([32], "float32")): + R.func_attr({"global_symbol": "main"}) + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = R.call_pure_packed( + "my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32") + ) + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_missing_pattern_raises_error(): + """The rewriter must define a pattern to be matched""" + + with pytest.raises(KeyError, match="pattern"): + + @R.rewriter + class Rewriter: + @R.function + def replacement(): + return R.tuple() + + +def test_incorrect_function_type_of_pattern_raises_error(): + """The rewriter's pattern must be a Relax function""" + + with pytest.raises(TypeError, match="pattern"): + + @R.rewriter + class Rewriter: + @T.prim_func + def pattern(): + pass + + @R.function + def replacement(): + return R.tuple() + + +def test_missing_replacement_raises_error(): + """The rewriter must define a replacement""" + + with pytest.raises(KeyError, match="replacement"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(): + return R.tuple() + + +def test_incorrect_function_type_of_replacement_raises_error(): + """The rewriter's replacement must be a Relax function""" + + with pytest.raises(TypeError, match="replacement"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(): + return R.tuple() + + @T.prim_func + def replacement(): + pass + + +def test_mismatch_of_static_shapes_raises_error(): + """The pattern and replacement must accept the same shapes""" + + with pytest.raises(ValueError, match="must have the same signature"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([32])): + return A + + @R.function + def replacement(A: R.Tensor([16])): + return A + + +def test_rewriter_may_be_applied_to_ir_module(): + """A rewriter may mutate an IRModule + + The `ExprRewriter.__call__` implementation may accept either a + single Relax function, or an entire IRModule. If it is passed an + IRModule, then all functions in the `IRModule` are updated. 
+ + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @I.ir_module + class Before: + @R.function + def func_a(x: R.Tensor([32], "float32")): + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = lhs + rhs + return out + + @R.function + def func_b(x: R.Tensor([16], "float32")): + out = x + x + return out + + @I.ir_module + class Expected: + @R.function + def func_a(x: R.Tensor([32], "float32")): + split = R.split(x, 2) + lhs = split[0] + rhs = split[1] + out = R.call_pure_packed( + "my_optimized_add_impl", lhs, rhs, sinfo_args=R.Tensor([16], "float32") + ) + return out + + @R.function + def func_b(x: R.Tensor([16], "float32")): + out = R.call_pure_packed( + "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + ) + return out + + After = Rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewriter_may_be_used_as_ir_transform(): + """A rewriter may be used as a tvm.ir.transform.Pass""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @I.ir_module + class Before: + @R.function + def main(x: R.Tensor([16], "float32")): + y = x + x + return y + + @I.ir_module + class Expected: + @R.function + def main(x: R.Tensor([16], "float32")): + out = R.call_pure_packed( + "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + ) + return out + + After = tvm.ir.transform.Sequential([Rewriter])(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_same_pattern_applied_multiple_times(): + """The pattern-match may apply multiple times""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.add(A, B) + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before(x: R.Tensor([16], "float32")): + y = x + x + z = y + y + return z + + @R.function(private=True) + def expected(x: R.Tensor([16], "float32")): + y = R.call_pure_packed( + "my_optimized_add_impl", x, x, sinfo_args=R.Tensor([16], "float32") + ) + z = R.call_pure_packed( + "my_optimized_add_impl", y, y, sinfo_args=R.Tensor([16], "float32") + ) + return z + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_composition_of_rewrite_rules(): + """Rewrite rules may be composed together""" + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = A + B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.rewriter + class RewriteMultiply: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: 
R.Tensor([16], "float32")): + C = A * B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")): + C = R.call_pure_packed( + "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + D = A + B + E = C * D + return E + + @R.function(private=True) + def expected( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + D = R.call_pure_packed( + "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + E = R.call_pure_packed( + "my_optimized_mul_impl", C, D, sinfo_args=R.Tensor([16], "float32") + ) + return E + + rewriter = RewriteAdd | RewriteMultiply + + after = rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_recursive_rewrite_rules(): + """Rewrite rules are applied until convergence + + In this test, both the `RewriteAdd` and `RewriteMultiply` patterns + must be applied in order to produce the expected output. However, + the `RewriteMultiply` pattern relies on the expression produced by + the `RewriteAdd` pass. + + """ + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMultiply: + @R.function + def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")): + C = A * B + return C + + @R.function + def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([], "float32")): + C = R.call_pure_packed( + "my_optimized_mul_impl", A, B, sinfo_args=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def before(A: R.Tensor([16], "float32")): + B = A + A + return B + + @R.function(private=True) + def expected(A: R.Tensor([16], "float32")): + B = R.call_pure_packed( + "my_optimized_mul_impl", + A, + R.const(2.0, "float32"), + sinfo_args=R.Tensor([16], "float32"), + ) + return B + + rewriter = RewriteAdd | RewriteMultiply + + after = rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_may_introduce_private_relax_subroutines(): + """The replacement may contain subroutines""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return Rewriter.subroutine(A) + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B + B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine(A) + C = Expected.subroutine(B) + return C + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + After = Rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewrite_only_introduces_private_subroutines_when_required(): + """Only subroutines that are used will be added to the module + + Like `test_rewrite_may_introduce_private_relax_subroutines`, but + the rewritten function only requires some of the subroutines + provided by the rewriter. 
+ + """ + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return RewriteAdd.subroutine_add(A) + + @R.function(private=True) + def subroutine_add(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMul: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A * A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return R.call_tir( + RewriteMul.subroutine_mul, [A], out_sinfo=R.Tensor([16], "float32") + ) + + @T.prim_func(private=True) + def subroutine_mul(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B + B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine_add(A) + C = Expected.subroutine_add(B) + return C + + @R.function(private=True) + def subroutine_add(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + rewriter = RewriteAdd | RewriteMul + + After = rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def test_rewriter_may_not_introduce_public_subroutines(): + """The rewriter may only introduce private functions""" + + with pytest.raises(ValueError, match="is publicly exposed"): + + @R.rewriter + class Rewriter: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return Rewriter.subroutine(A) + + @R.function + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + +def test_rewrite_branches_may_reuse_subroutine_name(): + """Each rewriter is independent, and may reuse subroutine names""" + + @R.rewriter + class RewriteAdd: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A + A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return RewriteAdd.subroutine(A) + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @R.rewriter + class RewriteMul: + @R.function + def pattern(A: R.Tensor([16], "float32")): + return A * A + + @R.function + def replacement(A: R.Tensor([16], "float32")): + return R.call_tir( + RewriteMul.subroutine, [A], out_sinfo=R.Tensor([16], "float32") + ) + + @T.prim_func(private=True) + def subroutine(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + @I.ir_module + class Before: + @R.function + def main(A: R.Tensor([16], "float32")): + B = A + A + C = B * B + return C + + @I.ir_module + class Expected: + @R.function + def main(A: R.Tensor([16], "float32")): + B = Expected.subroutine(A) + C = R.call_tir( + Expected.subroutine_1, [B], out_sinfo=R.Tensor([16], "float32") + ) + return C + + @R.function(private=True) + def subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], "float32"): + return A * R.const(2.0, "float32") + + @T.prim_func(private=True) + def subroutine_1(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")): + for i in range(16): + B[i] = A[i] * A[i] + + rewriter = RewriteAdd | RewriteMul + + After = rewriter(Before) + tvm.ir.assert_structural_equal(Expected, After) + + +def 
test_rewrite_of_explicit_relax_tuple(): + """The rewriter function may return a tuple + + When it occurs explicitly within the Relax function, the tuple + pattern matches against the Relax tuple, and the Relax tuple is + replaced. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + proj_tuple = (proj_A, proj_B) + out = proj_tuple[0] + proj_tuple[1] + return out + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + out = proj_tuple[0] + proj_tuple[1] + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_output_relax_tuple(): + """The rewriter may update a tuple being returned + + Unlike most relax expressions, tuples may appear as nested + expressions. Pattern-matching should be aware of this option. + + Like `test_rewrite_of_explicit_relax_tuple`, but the tuple appears + as the return value in the function being modified. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + return (proj_A, proj_B) + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple(): + """The rewriter function may return a tuple + + The tuple being replaced does not need to explicitly exist within + the updated Relax function. So long as each element of the tuple + pattern matches a Relax expression, the pattern match can apply. + + This rule ensures that pattern-matching is never broken when + `CanonicalizeBindings` is applied. 
+ + This test is identical to `test_rewrite_of_explicit_relax_tuple`, + except that the function does not contain the round trip of + packing `proj_A` and `proj_B` into a tuple, then immediately + unpacking them from the tuple. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(A, state) + proj_B = R.matmul(B, state) + out = proj_A + proj_B + return out + + @R.function(private=True) + def expected( + state: R.Tensor([16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + concat_AB = R.concat([A, B]) + proj_concat = R.matmul(concat_AB, state) + proj_tuple = R.split(proj_concat, 2) + out = proj_tuple[0] + proj_tuple[1] + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple_with_shared_wildcard(): + """Tuple elements may depend on the same input + + Here, both elements of the tuple depend on `y`. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + lhs = x + y + rhs = y + z + return (lhs, rhs) + + @R.function + def replacement( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "optimized_impl", + x, + y, + z, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + lhs = A + B + rhs = B + C + out = R.multiply(lhs, rhs) + return out + + @R.function(private=True) + def expected( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + lhs_rhs = R.call_pure_packed( + "optimized_impl", + A, + B, + C, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + out = R.multiply(lhs_rhs[0], lhs_rhs[1]) + return out + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_no_rewrite_of_implicit_tuple_when_shared_wildcard_is_mismatched(): + """Tuple elements must match simultaneously + + Each element of the tuple matches individually, but the two + elements both depend on `B`. Because the first tuple element + would require `y = B`, while the second tuple element would + require `y = C`, the match fails. 
+ + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + x: R.Tensor([16], "float32"), + y: R.Tensor([16], "float32"), + z: R.Tensor([16], "float32"), + ): + lhs = x + y + rhs = y + z + return (lhs, rhs) + + @R.function + def replacement( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "optimized_impl", + A, + B, + C, + sinfo_args=R.Tuple( + [ + R.Tensor([16], "float32"), + R.Tensor([16], "float32"), + ] + ), + ) + + @R.function(private=True) + def before( + A: R.Tensor([16], "float32"), + B: R.Tensor([16], "float32"), + C: R.Tensor([16], "float32"), + D: R.Tensor([16], "float32"), + ): + lhs = A + B + rhs = C + D + out = R.multiply(lhs, rhs) + return out + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_implicit_tuple_may_not_introduce_extra_compute(): + """Matching of implicit tuple may not cause extra compute + + Here, the `(proj_A, proj_B)` tuple could be an implcit tuple + match, but that would repeat the computation of `proj_A`. It + would be computed once on its own, to be used for `proj_A_on_B`, + and once for computing `(proj_A, proj_B)`. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16, 16], "float32"), + ): + proj_A = R.matmul(lhs_A, rhs) + proj_B = R.matmul(lhs_B, rhs) + proj_tuple = (proj_A, proj_B) + return proj_tuple + + @R.function + def replacement( + lhs_A: R.Tensor([16, 16], "float32"), + lhs_B: R.Tensor([16, 16], "float32"), + rhs: R.Tensor([16, 16], "float32"), + ): + lhs = R.concat([lhs_A, lhs_B]) + proj_concat = R.matmul(lhs, rhs) + proj_tuple = R.split(proj_concat, 2) + return proj_tuple + + @R.function(private=True) + def before( + state: R.Tensor([16, 16], "float32"), + A: R.Tensor([16, 16], "float32"), + B: R.Tensor([16, 16], "float32"), + ): + # This function has no location at which a tuple + # `(proj_A,proj_B)` could be constructed, then unpacked. + + proj_A = R.matmul(A, state) + + # A tuple `(proj_A, proj_B)` could not be constructed at this + # location, because `proj_B` has not yet been computed. + + proj_A_on_B = R.matmul(proj_A, B) + proj_B = R.matmul(proj_A_on_B, state) + + # A tuple `(proj_A, proj_B)` could be constructed here, but a + # use-site of `proj_A` has already occurred. Implicit + # matching of a tuple is only allowed if it would replace + # every use-site of a variable. 
+ + out = proj_A + proj_B + return out + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_of_implicit_tuple_with_three_elements(): + """Implicit tuples may contain three elements""" + + @R.rewriter + class Rewriter: + @R.function + def pattern(qkv: R.Tensor([12288], "float32")): + qkv_tuple = R.split(qkv, 3, axis=0) + q = qkv_tuple[0] + k = qkv_tuple[1] + v = qkv_tuple[2] + q_embed = R.call_pure_packed( + "rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32") + ) + k_embed = R.call_pure_packed( + "rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32") + ) + + return (q_embed, k_embed, v) + + @R.function + def replacement(qkv: R.Tensor([12288], "float32")): + return R.call_pure_packed( + "split_rotary_embedding", + [qkv], + sinfo_args=[ + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + ], + ) + + @R.function(private=True) + def before( + state: R.Tensor([4096], "float32"), + proj_qkv: R.Tensor([12288, 4096], "float32"), + kv_cache: R.Object, + ): + qkv = R.matmul(proj_qkv, state) + qkv_tuple = R.split(qkv, 3, axis=0) + q = qkv_tuple[0] + k = qkv_tuple[1] + v = qkv_tuple[2] + q_embed = R.call_pure_packed( + "rotary_embedding", [q], sinfo_args=R.Tensor([4096], "float32") + ) + k_embed = R.call_pure_packed( + "rotary_embedding", [k], sinfo_args=R.Tensor([4096], "float32") + ) + + attention = R.call_pure_packed( + "compute_self_attention", + [q_embed, k_embed, v, kv_cache], + sinfo_args=R.Tensor([4096]), + ) + + return attention + + @R.function(private=True) + def expected( + state: R.Tensor([4096], "float32"), + proj_qkv: R.Tensor([12288, 4096], "float32"), + kv_cache: R.Object, + ): + qkv = R.matmul(proj_qkv, state) + embedded_qkv_tuple = R.call_pure_packed( + "split_rotary_embedding", + [qkv], + sinfo_args=[ + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + R.Tensor([4096], "float32"), + ], + ) + + v = embedded_qkv_tuple[2] + q_embed = embedded_qkv_tuple[0] + k_embed = embedded_qkv_tuple[1] + + attention = R.call_pure_packed( + "compute_self_attention", + [q_embed, k_embed, v, kv_cache], + sinfo_args=R.Tensor([4096]), + ) + + return attention + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_pattern_matching_may_not_reorder_across_impure_functions(): + """Matched pattern must be ordered with respect to impure functions + + To ensure that debug printouts, memory management, performance + measurements, etc are not impacted by a pattern match, a pattern + must be entirely before, or entirely after an impure function. A + pattern match in which some parts of the matched expression are + performed before an impure function, while others are performed + afterwards, is not allowed. + + In this test, the matmul and the add may not be fused, because the + impure print statement occurs between them. 
+ + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + state = R.matmul(weights, state) + state = R.add(bias, state) + return state + + @R.function + def replacement( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + + @R.function(private=True, pure=False) + def before( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.matmul(weights, state) + R.print(format="After matmul, before add") + state = R.add(bias, state) + R.print(format="End of function") + return state + + expected = before + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_pattern_matching_may_occur_between_impure_functions(): + """Matched pattern may be adjacent to impure functions + + To ensure that debug printouts, memory management, performance + measurements, etc are not impacted by a pattern match, a pattern + must be entirely before, or entirely after an impure function. A + pattern match in which some parts of the matched expression are + performed before an impure function, while others are performed + afterwards, is not allowed. + + In this test, the matmul and the add may be fused, because the + pattern occurs without an impure print statement in-between. + + """ + + @R.rewriter + class Rewriter: + @R.function + def pattern( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + state = R.matmul(weights, state) + state = R.add(bias, state) + return state + + @R.function + def replacement( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + return R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + + @R.function(private=True, pure=False) + def before( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.matmul(weights, state) + state = R.add(bias, state) + R.print(format="End of function") + return state + + @R.function(private=True, pure=False) + def expected( + state: R.Tensor([16], "float32"), + weights: R.Tensor([16, 16], "float32"), + bias: R.Tensor([16], "float32"), + ): + R.print(format="Start of function") + state = R.call_pure_packed( + "my_optimized_fma_impl", + state, + weights, + bias, + sinfo_args=R.Tensor([16], "float32"), + ) + R.print(format="End of function") + return state + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +def test_rewrite_may_apply_within_conditional(): + """Rewrites may apply within to inner dataflow regions + + While dataflow regions may not contain conditionals, they may + occur within the body of conditionals. 
+
+    """
+
+    @R.rewriter
+    class Rewriter:
+        @R.function
+        def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
+            return A + B
+
+        @R.function
+        def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
+            return R.call_pure_packed(
+                "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
+            )
+
+    @R.function(private=True)
+    def before(
+        A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")
+    ):
+        if cond:
+            out = A + B
+        else:
+            C = A + B
+            out = C + B
+        return out
+
+    @R.function(private=True)
+    def expected(
+        A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32"), cond: R.Prim("bool")
+    ):
+        if cond:
+            out = R.call_pure_packed(
+                "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
+            )
+        else:
+            C = R.call_pure_packed(
+                "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
+            )
+            out = R.call_pure_packed(
+                "my_optimized_add_impl", C, B, sinfo_args=R.Tensor([16], "float32")
+            )
+        return out
+
+    after = Rewriter(before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_match_dynamic_shape():
+    """Pattern match/rewrites may be dynamic
+
+    The pattern and replacement may be written in terms of symbolic
+    shape variables (here `N1`, `N2`, and `M`). When the pattern
+    matches a function whose shapes are static, the symbolic
+    variables are bound to the corresponding static extents, and
+    those bindings are applied when generating the replacement.
+
+    This test applies the same implicit-tuple rewrite as
+    `test_rewrite_of_implicit_tuple`, except that the pattern and
+    replacement are written with dynamic shapes, while the function
+    being rewritten uses static shapes.
+
+    """
+
+    @R.rewriter
+    class Rewriter:
+        @R.function
+        def pattern(
+            lhs_A: R.Tensor(["N1", "M"], "float32"),
+            lhs_B: R.Tensor(["N2", "M"], "float32"),
+            rhs: R.Tensor(["M"], "float32"),
+        ):
+            proj_A = R.matmul(lhs_A, rhs)
+            proj_B = R.matmul(lhs_B, rhs)
+            return (proj_A, proj_B)
+
+        @R.function
+        def replacement(
+            lhs_A: R.Tensor(["N1", "M"], "float32"),
+            lhs_B: R.Tensor(["N2", "M"], "float32"),
+            rhs: R.Tensor(["M"], "float32"),
+        ):
+            N1 = T.int64()
+            N2 = T.int64()
+
+            lhs = R.concat([lhs_A, lhs_B])
+            proj_concat = R.matmul(lhs, rhs)
+            proj_A: R.Tensor([N1], "float32") = R.strided_slice(
+                proj_concat, axes=[0], begin=[0], end=[N1]
+            )
+            proj_B: R.Tensor([N2], "float32") = R.strided_slice(
+                proj_concat, axes=[0], begin=[N1], end=[N2 + N1]
+            )
+            return (proj_A, proj_B)
+
+    @R.function(private=True)
+    def before(
+        state: R.Tensor([16], "float32"),
+        A: R.Tensor([16, 16], "float32"),
+        B: R.Tensor([16, 16], "float32"),
+    ):
+        proj_A = R.matmul(A, state)
+        proj_B = R.matmul(B, state)
+        out = proj_A + proj_B
+        return out
+
+    @R.function(private=True)
+    def expected(
+        state: R.Tensor([16], "float32"),
+        A: R.Tensor([16, 16], "float32"),
+        B: R.Tensor([16, 16], "float32"),
+    ):
+        concat_AB = R.concat([A, B])
+        proj_concat = R.matmul(concat_AB, state)
+        proj_A = R.strided_slice(proj_concat, axes=[0], begin=[0], end=[16])
+        proj_B = R.strided_slice(proj_concat, axes=[0], begin=[16], end=[32])
+        out = proj_A + proj_B
+        return out
+
+    after = Rewriter(before)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_match_dynamic_pattern_against_dynamic_shape():
+    """A dynamic pattern may match a dynamic shape"""
+
+    @R.rewriter
+    class Rewriter:
+        @R.function
+        def pattern(
+            A: R.Tensor(["M", "N"], "float32"),
+            B: R.Tensor(["N", "N"], "float32"),
+        ):
+            return R.matmul(A, B)
+
+        @R.function
+        def
replacement( + A: R.Tensor(["M", "N"], "float32"), + B: R.Tensor(["N", "N"], "float32"), + ): + M = T.int64() + N = T.int64() + return R.call_pure_packed( + "my_optimized_square_matmul", + A, + B, + sinfo_args=R.Tensor([M, N], "float32"), + ) + + @R.function(private=True) + def before( + A: R.Tensor(["N", "N*2"], "float32"), + B: R.Tensor(["N*2", "N*2"], "float32"), + C: R.Tensor(["N", "N"], "float32"), + ): + N = T.int64() + D: R.Tensor([N, N * 2], "float32") = R.matmul(A, B) + E: R.Tensor([N * 2, N], "float32") = R.permute_dims(D) + F: R.Tensor([N * 2, N], "float32") = R.matmul(E, C) + return F + + @R.function(private=True) + def expected( + A: R.Tensor(["N", "N*2"], "float32"), + B: R.Tensor(["N*2", "N*2"], "float32"), + C: R.Tensor(["N", "N"], "float32"), + ): + N = T.int64() + + D: R.Tensor([N, N * 2], "float32") = R.call_pure_packed( + "my_optimized_square_matmul", + A, + B, + sinfo_args=R.Tensor([N, N * 2], "float32"), + ) + E: R.Tensor([N * 2, N], "float32") = R.permute_dims(D) + F: R.Tensor([N * 2, N], "float32") = R.call_pure_packed( + "my_optimized_square_matmul", + E, + C, + sinfo_args=R.Tensor([N * 2, N], "float32"), + ) + return F + + after = Rewriter(before) + tvm.ir.assert_structural_equal(expected, after) + + +if __name__ == "__main__": + tvm.testing.main()
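
A condensed usage sketch of the API exercised by the tests above, showing how the pieces added in this patch fit together. The `Model` module and the `"my_optimized_add_impl"` packed function are illustrative placeholders borrowed from the tests, not anything defined by the patch itself.

    import tvm
    from tvm.script import ir as I, relax as R

    # Define a rewrite rule: `pattern` and `replacement` must share a signature.
    @R.rewriter
    class RewriteAdd:
        @R.function
        def pattern(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
            C = R.add(A, B)
            return C

        @R.function
        def replacement(A: R.Tensor([16], "float32"), B: R.Tensor([16], "float32")):
            C = R.call_pure_packed(
                "my_optimized_add_impl", A, B, sinfo_args=R.Tensor([16], "float32")
            )
            return C

    # A module to rewrite (placeholder).
    @I.ir_module
    class Model:
        @R.function
        def main(x: R.Tensor([16], "float32")):
            y = x + x
            return y

    # Apply the rewriter directly, either to a single Relax function or to an
    # entire IRModule ...
    rewritten = RewriteAdd(Model)

    # ... or use it as an ordinary IRModule-to-IRModule pass.  Rewriters also
    # compose with `|`, and the combined rules are applied until convergence.
    rewritten = tvm.ir.transform.Sequential([RewriteAdd])(Model)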