diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 63d6fa375c..dce9736ade 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -96,22 +96,20 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr); TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr); /*! - * \brief Whether e expression used any var in variable set. - * \param expr The expression to be checked. - * \param vset_contains The check function to see if var is in the vset. - * \return Whether e uses vset. + * \brief Whether the given Stmt uses any var in the given variable set. + * \param stmt The Stmt to be checked. + * \param vset_contains The check function to see if a var is in the variable set. + * \return Whether `stmt` uses any var in the given variable set. */ -TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains); +TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains); /*! - * \brief Whether e expression used var. - * \param expr The expression to be checked. - * \param var The variable. - * \return Whether e uses v. + * \brief Whether the given PrimExpr uses any var in the given variable set. + * \param expr The PrimExpr to be checked. + * \param vset_contains The check function to see if a var is in the variable set. + * \return Whether `expr` uses any var in the given variable set. */ -inline bool ExprUseVar(const PrimExpr& expr, const Var& var) { - return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; }); -} +TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains); /*! * \brief Verifies whether the IR stmt or Expr is in SSA form. diff --git a/include/tvm/tir/schedule/block_scope.h b/include/tvm/tir/schedule/block_scope.h index fb08583b77..5756eecbed 100644 --- a/include/tvm/tir/schedule/block_scope.h +++ b/include/tvm/tir/schedule/block_scope.h @@ -255,14 +255,14 @@ class BlockScopeNode : public Object { class BlockScope : public ObjectRef { public: /*! \brief The constructor creating an empty block scope with no dependency information */ - TVM_DLL BlockScope(); + TVM_DLL explicit BlockScope(); /*! * \brief Create the object with the specific leaf blocks, and compute the dependency information * between the leaf blocks. * \param child_block_srefs The srefs to the leaf blocks * \note We assume the leaf blocks are given in pre-DFS order */ - TVM_DLL BlockScope(const Array<StmtSRef>& child_block_srefs); + TVM_DLL explicit BlockScope(const Array<StmtSRef>& child_block_srefs); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode); }; diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 38a15a8143..4bae8d77d5 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -254,6 +254,17 @@ class ScheduleNode : public runtime::Object { /******** Schedule: loop binding/annotation ********/ /******** Schedule: cache read/write ********/ /******** Schedule: reduction ********/ + /*! + * \brief Factor a reduction block by the specified loop + * \details See python/tvm/tir/schedule/schedule.py + * \param loop_rv The loop outside the block for which we want to do rfactor + * \param factor_axis The position where the new dimension is placed in the newly introduced rfactor + * buffer.
Suppose the original reduction block writes to buffer `B` with + * ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1, + * ndim(B)]`, and the negative index will be normalized to a non-negative one + * \return The rfactor block + */ + virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0; /******** Schedule: blockize & tensorize ********/ }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index cc10c218c8..beba189bdc 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -860,6 +860,7 @@ class For : public Stmt { Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode); }; /*! @@ -1356,6 +1357,23 @@ TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span()); // overload printing of for type. TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind); +// inline implementations +inline const char* IterVarType2String(ForKind t) { + switch (t) { + case ForKind::kSerial: + return "Serial"; + case ForKind::kParallel: + return "Parallel"; + case ForKind::kVectorized: + return "Vectorized"; + case ForKind::kUnrolled: + return "Unrolled"; + case ForKind::kThreadBinding: + return "ThreadBinding"; + } + return "Unknown"; +} + } // namespace tir } // namespace tvm #endif  // TVM_TIR_STMT_H_ diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 67350bd109..9fc656cd08 100755 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -424,7 +424,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None: .. code-block:: python - sch = tir.Schedule(before_inline, debug_mode=True) + sch = tir.Schedule(before_inline) sch.compute_inline(sch.get_block("B")) print(tvm.script.asscript(sch.mod["main"])) @@ -484,7 +484,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None: .. code-block:: python - sch = tir.Schedule(before_inline, debug_mode=True) + sch = tir.Schedule(before_inline) sch.reverse_compute_inline(sch.get_block("C")) print(tvm.script.asscript(sch.mod["main"])) @@ -505,6 +505,150 @@ def after_inline(a: ty.handle, c: ty.handle) -> None: ########## Schedule: loop binding/annotation ########## ########## Schedule: cache read/write ########## ########## Schedule: reduction ########## + def rfactor(self, loop: LoopRV, factor_axis: int) -> BlockRV: + """Factorize an associative reduction block by the specified loop. + + An associative reduction cannot be parallelized directly, + because it leads to potential race conditions during accumulation. + Alternatively, the reduction could be factorized on a loop with the following steps: + - Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent; + - Step 2: compute the chunks separately and write the results into `n` intermediate buffers; + - Step 3: accumulate the `n` separate buffers into the result buffer. + Note that Step 2 above introduces opportunities for parallelization. + + RFactor is a schedule primitive that implements the transformation described above: + Given a block that writes to buffer `B`, it factorizes a loop of extent `n`. + + For example, the pseudocode below accumulates `B[i] = sum(A[i, : , : ])`: + + + .. code-block:: python
+ + for i in range(128): # loop i is a data parallel loop + for j in range(128): # loop j is a reduction loop + for k in range(128): # loop k is a reduction loop + B[i] = B[i] + A[i, j, k] + + + Suppose RFactor is applied on the innermost loop `k` and `factor_axis = 1`. + RFactor then creates an intermediate buffer and two blocks. + + - The intermediate buffer, or "rf-buffer", is a buffer of rank `ndim(B) + 1` and + size `size(B) * n`, whose shape expands from `shape(B)` by adding an axis of `n` + at the position specified by `factor_axis`. For example, + + * shape(B) = [1, 2, 3], factor_axis = 0 => shape(B_rf) = [n, 1, 2, 3] + * shape(B) = [1, 2, 3], factor_axis = 1 => shape(B_rf) = [1, n, 2, 3] + * shape(B) = [1, 2, 3], factor_axis = 2 => shape(B_rf) = [1, 2, n, 3] + * shape(B) = [1, 2, 3], factor_axis = 3 => shape(B_rf) = [1, 2, 3, n] + + - The rfactor block, or "rf-block", is a block that writes to the `rf-buffer` without + accumulating over the loop `k`, i.e. the loop `k` is converted from a reduction loop + to a data parallel loop. In our example, the rf-block is: + + + .. code-block:: python + + B_rf = np.zeros((128, 128)) # the rf-buffer + for k in range(128): # loop k is converted to a data parallel loop + for i in range(128): # loop i is a data parallel loop (unchanged) + for j in range(128): # loop j is a reduction loop (unchanged) + B_rf[i, k] = B_rf[i, k] + A[i, j, k] + + + - The write-back block, or "wb-block", is a block that accumulates the rf-buffer into + the result buffer. All the reduction loops are removed except the loop `k` for accumulation. + In our example, the wb-block is: + + .. code-block:: python + + for i in range(128): # loop i is a data parallel loop (unchanged) + # loop j is removed because it is a reduction loop + for k in range(128): # loop k is a reduction loop (unchanged) + B[i] = B[i] + B_rf[i, k] + + Parameters + ---------- + loop : LoopRV + The loop outside the block for which we want to do rfactor + factor_axis : int + The position where the new dimension is placed in the newly introduced rfactor buffer + + Returns + ------- + rf_block : BlockRV + The block which computes partial results over each slice (i.e., the first block + as described in the above illustration) + + Examples + -------- + + Before rfactor, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_rfactor(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128), "float32") + B = tir.match_buffer(b, (128,), "float32") + with tir.block([128, tir.reduce_axis(0, 128), + tir.reduce_axis(0, 128)], "B") as [vii, vi, vj]: + with tir.init(): + B[vii] = 0.0 + B[vii] = B[vii] + A[vii, vi, vj] + + Create the schedule and do rfactor: + + .. code-block:: python + + sch = tir.Schedule(before_rfactor) + _, _, k = sch.get_loops(sch.get_block("B")) + sch.rfactor(k, 0) + print(tvm.script.asscript(sch.mod["main"])) + + After applying rfactor, the IR becomes: + + .. code-block:: python
+ + @tvm.script.tir + def after_rfactor(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128]) + B_rf = tir.alloc_buffer([128, 128]) + with tir.block([128, 128, tir.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]: + with tir.init(): + B_rf[vi2, vii] = 0.0 + B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2]) + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]: + with tir.init(): + B[vii_1] = 0.0 + B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1]) + + + Note + ---- + + Rfactor requires: + 1) `loop` has only one child block, and it is a reduction block; + 2) `loop` is a reduction loop, i.e. the loop variable is bound to only reduction variables + in the block binding; + 3) `loop` is not parallelized, vectorized, unrolled or bound to any thread axis; + 4) The block scope that `loop` is in is a stage pipeline; + 5) The outermost loop outside the reduction block should have the reduction block as its first child block; + 6) The outermost reduction loop should have only one child block; + 7) A unary-extent loop that is not bound to any reduction or data parallel variables in the block binding + should not appear under any reduction loop; + 8) The reduction block should write to only one buffer, and its `init` and `body` should each be + a simple `BufferStore` whose update pattern is registered as an associative reducer. + The pre-defined patterns include: plus, multiplication, min and max; + 9) Each of the loops on top of the block must not be bound to both a data parallel block + binding and a reduction block binding at the same time; + 10) `factor_axis` should be in range `[-ndim(B) - 1, ndim(B)]`, + where `B` is the buffer that the reduction block writes to. + Negative indexing is normalized according to numpy convention.
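+ + For instance (a small sketch of requirement 10, assuming the `before_rfactor` workload + above, where `ndim(B) = 1` and thus the valid range is `[-2, 1]`), the two calls below + request the same placement, because a negative `factor_axis` is normalized by adding + `ndim(B) + 1`: + + .. code-block:: python + + sch.rfactor(k, 0) # place the new axis at position 0 + sch.rfactor(k, -2) # -2 + ndim(B) + 1 == 0, same placement as above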
+ """ + return _ffi_api_schedule.ScheduleRFactor(self, loop, factor_axis) # pylint: disable=no-member + ########## Schedule: blockize & tensorize ########## diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index ba549959ac..94db659e25 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -1137,8 +1137,10 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op) // and recursively mark the corresponding components for (size_t i = 0; i < simplified_result.size(); ++i) if (!used[i]) { - if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) || - ExprUseVar(simplified_result[idx], op->combiner->rhs[i])) + if (UsesVar(simplified_result[idx], + [v = op->combiner->lhs[i].get()](const VarNode* var) { return var == v; }) || + UsesVar(simplified_result[idx], + [v = op->combiner->rhs[i].get()](const VarNode* var) { return var == v; })) mark_used(i); } }; diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index f0634feac0..d81159bf05 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -108,7 +108,7 @@ class LinearEqDetector : public ExprFunctor DetectLinearEquation(const PrimExpr& e, const Array& vars) for (size_t i = vars.size(); i > 1; --i) { vset.insert(vars[i - 1].get()); // The previous coeff contains the variable - if (ExprUseVar(coeff[i - 2], vset_contains)) { + if (UsesVar(coeff[i - 2], vset_contains)) { return Array(); } } diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc index 76fed053fd..240adf14b3 100644 --- a/src/te/autodiff/ad_simplify.cc +++ b/src/te/autodiff/ad_simplify.cc @@ -834,7 +834,7 @@ std::pair ImplicationNotContainingVars( return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) && (pair_b.first || pair_a.second) && (pair_a.second || pair_b.second)}; - } else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) { + } else if (!tir::UsesVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) { return {cond, const_true()}; } else { return {const_true(), cond}; @@ -1014,7 +1014,7 @@ PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond, // Keep only those variables of the new vars which are used in the new_expr Array used_res_variables; for (const Var& var : res->dst->variables) { - if (ExprUseVar(new_expr, var)) { + if (tir::UsesVar(new_expr, [&var](const VarNode* var_) { return var_ == var.get(); })) { ICHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred."; used_res_variables.push_back(var); } diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 9a4eadb356..84fd745598 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -591,7 +591,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map -#include #include namespace tvm { namespace tir { -class VarTouchVisitor : public ExprVisitor { +class VarTouchVisitor : public StmtExprVisitor { public: - explicit VarTouchVisitor(std::function var_set) : var_set_(var_set) {} + explicit VarTouchVisitor(std::function var_set) + : var_set_(std::move(var_set)) {} + + void VisitStmt(const Stmt& stmt) final { + if (use_var_) return; + StmtExprVisitor::VisitStmt(stmt); + } void VisitExpr(const PrimExpr& e) final { if (use_var_) return; - ExprVisitor::VisitExpr(e); + StmtExprVisitor::VisitExpr(e); } void VisitExpr_(const VarNode* op) final { Handle(op); } + void VisitStmt_(const 
StoreNode* op) final { + Handle(op->buffer_var.get()); + StmtVisitor::VisitStmt_(op); + } + void VisitExpr_(const LoadNode* op) final { Handle(op->buffer_var.get()); ExprVisitor::VisitExpr_(op); @@ -54,9 +64,15 @@ class VarTouchVisitor : public ExprVisitor { std::function var_set_; }; -bool ExprUseVar(const PrimExpr& e, std::function var_set) { - VarTouchVisitor visitor(var_set); - visitor(e); +bool UsesVar(const Stmt& stmt, std::function var_set) { + VarTouchVisitor visitor(std::move(var_set)); + visitor(stmt); + return visitor.use_var_; +} + +bool UsesVar(const PrimExpr& expr, std::function var_set) { + VarTouchVisitor visitor(std::move(var_set)); + visitor(expr); return visitor.use_var_; } diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 352f75abdf..231261e4b5 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -21,13 +21,10 @@ * \file expr.cc */ #include -#include +#include #include #include -#include -#include - #include "../../support/str_escape.h" namespace tvm { diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 0d713707a5..518a821c16 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -41,30 +41,35 @@ void VerifySRefTree(const ScheduleState& self); */ void VerifyCachedFlags(const ScheduleState& self); -/******** Scope ********/ +/******** IR Module ********/ /*! - * \brief Gets the sref to the scope root block, exclusive - * \param sref The block or loop sref to be retrieved - * \return The sref to the scope root block. NullOpt if `sref` is the root block of the IR + * \brief Get PrimFunc and GlobalVar that the root block belongs to + * \param mod The IRModule + * \param root_block The root block of the PrimFunc + * \param result_g_var The result GlobalVar + * \return The result PrimFunc where the root block belongs to + * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write */ -Optional GetScopeRoot(const StmtSRef& sref); +const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, + GlobalVar* result_g_var); +/******** Scope ********/ /*! * \brief Checks if scope the specified sref is in is a stage-pipeline and return it - * \param prim The name of the schedule primitive * \param self The schedule state * \param sref The sref whose scope is to be checked + * \param require_stage_pipeline A boolean indicating whether to check stage pipeline * \throw ScheduleError if the sref has been the root of the AST (so it has no scope root), or its * scope root is not a stage pipeline * \return The block sref to the scope root */ -StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref); +StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, bool require_stage_pipeline); /*! 
* \brief Checks whether the block is a complete block under the scope * \param self The schedule state * \param block_sref The block to be checked - * \param scope_root The sref to the root block of the scope that `block_sref` is in + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in * \return A boolean indicating if the block is a complete block * \note Definition of a complete block: * 1) All block vars are data parallel @@ -73,10 +78,10 @@ StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const Stmt * 3) No overlap between the buffers the block reads and writes */ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, - const StmtSRef& scope_root); + const StmtSRef& scope_root_sref); /*! - * \brief Checks if the block is a complete block + * \brief Check if the block is a complete block under the scope * \param self The schedule state * \param block_sref The sref to the block whose completeness is to be checked * \param scope_root_sref The scope root of the block @@ -85,6 +90,33 @@ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); +/*! + * \brief Check whether the block is a reduction block under the scope + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in + * \return A boolean indicating if the block is a reduction block + * \note Definition of a reduction block: + * 1) The block has the `init` statement + * 2) All the block bindings are quasi-affine expressions + * 3) All block vars are either data parallel block vars or reduction block vars + * 4) Dominant: the block is the only writer of its output, dominating the reader of its output + * buffers + * 5) The reduction block vars are not used to index the output buffers + */ +bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref); + +/*! + * \brief Check if the block is a reduction block under the scope + * \param self The schedule state + * \param block_sref The sref of the block to be checked + * \param scope_root_sref The scope root of the block + * \throw ScheduleError If the block is not a reduction block + */ +void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref); + /******** Binding ********/ /*! * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. @@ -119,6 +151,19 @@ Map LoopDomainOfSRefTreePath(const StmtSRef& low_inclusive, */ Map GetBindings(const BlockRealize& realize); +/*! + * \brief Get the vars involved in the bindings of data parallel block vars and reduction block + * vars, respectively + * \param block_realize The BlockRealize to be analyzed + * \param data_par_vars The vars that appear in the binding of any data parallel block iter + * \param reduce_vars The vars that appear in the binding of any reduction block iter + * \return A boolean indicating whether the block has block iters that is neither a data parallel + * block iter nor a reduction block iter + */ +bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, + std::unordered_set* data_par_vars, + std::unordered_set* reduce_vars); + /******** Block-loop relation ********/ /*! 
* \brief Retrieves blocks in a specific function with its name @@ -128,6 +173,7 @@ Map GetBindings(const BlockRealize& realize); * \return A list of blocks with the specific name */ Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name); + /*! * \brief Gets the parent loops of the block in its scope, from outer to inner * \param self The schedule state @@ -135,13 +181,22 @@ Array GetBlocks(const ScheduleState& self, const String& name, const S * \return A list of loops above the given block in its scope, from outer to inner */ Array GetLoops(const StmtSRef& block_sref); + /*! - * \brief Gets the leaf blocks of a scope where a specific block/loop is in + * \brief Gets StmtSRefs of leaf blocks of a scope where a specific block/loop is in * \param self The schedule state * \param parent_sref The StmtSRef that points to the parent block/loop - * \return A list of leaf blocks + * \return A list of StmtSRefs of leaf block + */ +Array GetChildBlockSRefOnSRefTree(const ScheduleState& self, const StmtSRef& parent_sref); + +/*! + * \brief Gets the BlockRealize of the leaf blocks of a scope where a specific block/loop is in + * \param parent_sref The StmtSRef that points to the parent block/loop + * \return A list of leaf BlockRealize */ -Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); +Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref); + /*! * \brief Get the direct child Schedulable Stmt (Block and For) * \param stmt the parent stmt. @@ -149,6 +204,47 @@ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent */ Array GetChildren(const Stmt& stmt); +/*! + * \brief Get the BlockRealize of the single child block of the block or loop specified by + * `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks + * \param self The schedule state + * \param parent_sref The StmtSRef that points to the parent block/loop + * \return The BlockRealize of the single child block + * \throw ScheduleError If there is 0 or multiple child blocks + */ +BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref); + +/*! + * \brief Get the BlockRealize of the input block + * \param self The schedule state + * \param block_sref The StmtSRef of the queried block + * \return The BlockRealize of the input block + */ +BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref); + +/******** Commutative Reducer ********/ + +/*! + * \brief Get the list of the registered reducer-getter functions + * \return The list of the registered reducer-getter functions + * \sa ReducerRegistry + */ +std::vector> GetReducerGetters(); + +/*! + * \brief Given the input identity and the combiner BufferStore of a reduction, extract the + * corresponding commutative reducer and its lhs, rhs if possible. 
+ * \param identity The identity of the reduction + * \param combiner The combiner of the reduction + * \param result_reducer The extracted CommReducer + * \param lhs The extracted lhs of the reducer + * \param rhs The extracted rhs of the reducer + * \return A boolean indicating whether a corresponding commutative reducer is found + */ +bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, + CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 7584d36a65..a730f80e68 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -21,8 +21,37 @@ namespace tvm { namespace tir { +/******** IR Module ********/ + +const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block, + GlobalVar* result_g_var) { + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* func = base_func.as()) { + if (const auto* realize = func->body.as()) { + if (realize->block.get() == root_block) { + if (result_g_var != nullptr) { + *result_g_var = g_var; + } + return func; + } + } + } + } + LOG(FATAL) << "IndexError: Could not get the corresponding function in the schedule state of the " + "statement:\n" + << GetRef(root_block); + throw; +} + /******** Scope ********/ +/*! + * \brief Gets the sref to the scope root block, exclusive + * \param sref The block or loop sref to be retrieved + * \return The sref to the scope root block. NullOpt if `sref` is the root block of the IR + */ Optional GetScopeRoot(const StmtSRef& sref) { for (const StmtSRefNode* p = sref->parent; p != nullptr; p = p->parent) { if (p->stmt->IsInstance()) { @@ -32,7 +61,8 @@ Optional GetScopeRoot(const StmtSRef& sref) { return NullOpt; } -StmtSRef GetScopeRootAndCheckStagePipeline(const ScheduleState& self, const StmtSRef& sref) { +StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, + bool require_stage_pipeline) { class RootBlockError : public ScheduleError { public: explicit RootBlockError(IRModule mod) : mod_(mod) {} @@ -75,7 +105,7 @@ Definition of a scope that is a stage pipeline: throw RootBlockError(self->mod); } bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline; - if (stage_pipeline == false) { + if (require_stage_pipeline && stage_pipeline == false) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); throw NotStagePipelineError(self->mod, GetRef(block)); } @@ -106,20 +136,29 @@ bool IsDominantBlock(const BlockScope& self, const StmtSRef& block_sref) { return true; } -bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, - const StmtSRef& scope_root) { - BlockScope scope = self->GetBlockScope(scope_root); +/*! 
+ * \brief A helper function that checks whether a given block is a complete block under the scope, + * or return the condition it violates if it is not a complete block + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in + * \return 0 if the block is a complete block, or a positive integer indicating which condition is + * first violated + */ +int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + BlockScope scope = self->GetBlockScope(scope_root_sref); // Cond 1. All block vars are data parallel - const auto* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); for (const IterVar& iter_var : block->iter_vars) { if (iter_var->iter_type != kDataPar) { - return false; + return 1; } } // Cond 2. Dominant: the block is the only writer of its output, // dominating the reader of its output buffers if (!IsDominantBlock(scope, block_sref)) { - return false; + return 2; } // Cond 3. No overlap between the buffers the block reads and writes std::unordered_set written_buffers; @@ -129,35 +168,150 @@ bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, } for (const BufferRegion& read : block->reads) { if (written_buffers.count(read->buffer.get())) { - return false; + return 3; } } - return true; + return 0; +} + +bool IsCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + return CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref) == 0; } void CheckCompleteBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref) { class IncompleteBlockError : public ScheduleError { public: - explicit IncompleteBlockError(IRModule mod, Block block) : mod_(mod), block_(block) {} + explicit IncompleteBlockError(IRModule mod, Block block, int violated_cond) + : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} String FastErrorString() const final { return "ScheduleError: Incomplete block"; } String DetailRenderTemplate() const final { - return R"(The block {0} is not a complete block. -Definition of a complete block: + std::ostringstream os; + os << "The block {0} is not a complete block - it violates condition #" << violated_cond_ + << ".\n" + << R"(Definition of a complete block: 1) All block vars are data parallel 2) Dominant: the block is the only writer of its output, dominating the reader of its output buffers 3) No overlap between the buffers the block reads and writes)"; + return os.str(); } IRModule mod() const final { return mod_; } Array LocationsOfInterest() const final { return {block_}; } IRModule mod_; Block block_; + int violated_cond_; }; - bool result = IsCompleteBlock(self, block_sref, scope_root_sref); - if (result == false) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root_sref); - throw IncompleteBlockError(self->mod, GetRef(block)); + int error_code = CheckCompleteBlockErrorCode(self, block_sref, scope_root_sref); + if (error_code != 0) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + throw IncompleteBlockError(self->mod, GetRef(block), error_code); + } +} + +/*! 
+ * \brief A helper function that checks whether a given block is a reduction block under the scope, + * or return the condition it violates if it is not a reduction block + * \param self The schedule state + * \param block_sref The block to be checked + * \param scope_root_sref The sref to the root block of the scope that `block_sref` is in + * \return 0 if the block is a reduction block, or a positive integer indicating which condition is + * first violated + */ +int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + BlockScope scope = self->GetBlockScope(scope_root_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + // Cond 1. The block has the `init` statement. + if (!block->init.defined()) { + return 1; + } + // Cond 2. All the block bindings are quasi-affine expressions. + if (!self->IsAffineBlockBinding(block_sref)) { + return 2; + } + // Cond 3. All block vars are either data parallel block vars or reduction block vars. Meanwhile, + // we collect all the reduction block vars. + std::unordered_set reduction_block_vars; + reduction_block_vars.reserve(block->iter_vars.size()); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { + return 3; + } else if (iter_var->iter_type == kCommReduce) { + reduction_block_vars.insert(iter_var->var.get()); + } + } + // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its + // output buffers. + if (!IsDominantBlock(scope, block_sref)) { + return 4; + } + // Cond 5. The reduction block vars are not used to index the output buffers. + std::unordered_set buffer_written; + buffer_written.reserve(block->writes.size()); + for (const BufferRegion& write_region : block->writes) { + buffer_written.insert(write_region->buffer.get()); + } + bool affected = false; + PreOrderVisit(block->body, [&](const ObjectRef& obj) { + if (affected) { + return false; + } + if (const auto* store = obj.as()) { + ICHECK(buffer_written.count(store->buffer.get())) + << "ValueError: The buffer \"" << store->buffer + << "\" is written in the block but is not in the block's signature"; + for (const PrimExpr& index : store->indices) { + if (UsesVar(index, [&reduction_block_vars](const VarNode* var) { + return reduction_block_vars.count(var); + })) { + affected = true; + return false; + } + } + return false; + } + return true; + }); + return !affected ? 
0 : 5; +} + +bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + return CheckReductionBlockErrorCode(self, block_sref, scope_root_sref) == 0; +} + +void CheckReductionBlock(const ScheduleState& self, const StmtSRef& block_sref, + const StmtSRef& scope_root_sref) { + class NotReductionBlockError : public ScheduleError { + public: + explicit NotReductionBlockError(IRModule mod, Block block, int violated_cond) + : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} + String FastErrorString() const final { return "ScheduleError: Not a reduction block"; } + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The block {0} is not a reduction block - it violates condition #" << violated_cond_ + << ".\n" + << R"(Definition of a reduction block: 1) The block has the `init` statement 2) All the block bindings are quasi-affine expressions 3) All block vars are either data parallel block vars or reduction block vars 4) Dominant: the block is the only writer of its output, dominating the reader of its output buffers 5) The reduction block vars are not used to index the output buffers)"; + return os.str(); + } + IRModule mod() const final { return mod_; } + Array<ObjectRef> LocationsOfInterest() const final { return {block_}; } + IRModule mod_; + Block block_; + int violated_cond_; + }; + + int error_code = CheckReductionBlockErrorCode(self, block_sref, scope_root_sref); + if (error_code != 0) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + throw NotReductionBlockError(self->mod, GetRef<Block>(block), error_code); } } @@ -229,6 +383,38 @@ Map<Var, PrimExpr> GetBindings(const BlockRealize& realize) { return result; } +bool GetVarsTouchedByBlockIters(const BlockRealize& block_realize, + std::unordered_set<const VarNode*>* data_par_vars, + std::unordered_set<const VarNode*>* reduce_vars) { + Block block = block_realize->block; + bool has_block_vars_of_other_types = false; + ICHECK_EQ(block->iter_vars.size(), block_realize->iter_values.size()); + int n = static_cast<int>(block->iter_vars.size()); + for (int i = 0; i < n; ++i) { + const IterVar& iter_var = block->iter_vars[i]; + const PrimExpr& iter_value = block_realize->iter_values[i]; + std::unordered_set<const VarNode*>* set = nullptr; + if (iter_var->iter_type == IterVarType::kDataPar) { + set = data_par_vars; + } else if (iter_var->iter_type == IterVarType::kCommReduce) { + set = reduce_vars; + } else { + // Neither data parallel nor reduction: report it only via the return value. + has_block_vars_of_other_types = true; + continue; + } + + Array<Var> vars_in_binding = UndefinedVars(iter_value); + for (const Var& var : vars_in_binding) { + set->insert(var.get()); + } + } + + return has_block_vars_of_other_types; +} + /******** Block-loop relation ********/ Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const String& func_name) { @@ -265,34 +451,38 @@ Array<StmtSRef> GetLoops(const StmtSRef& block_sref) { return {result.rbegin(), result.rend()}; } -Array<StmtSRef> GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref) { +Array<StmtSRef> GetChildBlockSRefOnSRefTree(const ScheduleState& self, const StmtSRef& parent_sref) { + Array<BlockRealize> child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + Array<StmtSRef> child_block_srefs; + child_block_srefs.reserve(child_block_realize.size()); + + for (const BlockRealize& realize : child_block_realize) { + child_block_srefs.push_back(self->stmt2ref.at(realize->block.get())); + } + return child_block_srefs;
+} + +Array GetChildBlockRealizeOnSRefTree(const StmtSRef& parent_sref) { struct Collector : public StmtVisitor { - public: - static Array Collect(const ScheduleState& self, const Stmt& stmt) { - Collector collector(self); + static Array Collect(const Stmt& stmt) { + Collector collector; collector(stmt); return std::move(collector.result_); } - private: - explicit Collector(const ScheduleState& self) : self_(self) {} - - void VisitStmt_(const BlockNode* block) final { - auto it = self_->stmt2ref.find(block); - ICHECK(it != self_->stmt2ref.end()); - result_.push_back(it->second); + void VisitStmt_(const BlockRealizeNode* block_realize) final { + result_.push_back(GetRef(block_realize)); } - const ScheduleState& self_; - Array result_; + Array result_; }; if (parent_sref->stmt->IsInstance()) { const auto* loop = static_cast(parent_sref->stmt); - return Collector::Collect(self, loop->body); + return Collector::Collect(loop->body); } else if (parent_sref->stmt->IsInstance()) { const auto* block = static_cast(parent_sref->stmt); - return Collector::Collect(self, block->body); + return Collector::Collect(block->body); } ICHECK(false) << "Unreachable"; throw; @@ -328,5 +518,393 @@ Array GetChildren(const Stmt& stmt) { } } +BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self, + const StmtSRef& parent_sref) { + class NonSingleChildBlockError : public ScheduleError { + public: + explicit NonSingleChildBlockError(IRModule mod, const StmtSRef& sref) + : mod_(std::move(mod)), stmt_(GetRef(sref->stmt)) { + sref_type_ = stmt_.as() != nullptr ? "block" : "loop"; + } + + String FastErrorString() const final { + std::ostringstream os; + os << "ScheduleError: The " << sref_type_ << " is required to have only one child block"; + return os.str(); + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "The " << sref_type_ << " {0} is required to have only one child block"; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {stmt_}; } + + IRModule mod_; + Stmt stmt_; + String sref_type_; + }; + + Array child_block_realize = GetChildBlockRealizeOnSRefTree(parent_sref); + if (child_block_realize.size() != 1) { + throw NonSingleChildBlockError(self->mod, parent_sref); + } + return child_block_realize[0]; +} + +BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref) { + struct BlockRealizeFinder : public StmtVisitor { + explicit BlockRealizeFinder(const BlockNode* target_block) + : target_block(target_block), result(nullptr) {} + + void VisitStmt(const Stmt& stmt) final { + if (result != nullptr) { + return; + } + StmtVisitor::VisitStmt(stmt); + } + + void VisitStmt_(const BlockRealizeNode* block_realize) final { + if (block_realize->block.get() == target_block) { + result = block_realize; + } + // No need to visit recursively, since the deeper BlockRealizes must not be the result. + } + + const BlockNode* target_block; + const BlockRealizeNode* result; + }; + + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block_sref->parent == nullptr) { + const PrimFuncNode* func = GetRootPrimFunc(self->mod, block, nullptr); + return Downcast(func->body); + } else { + BlockRealizeFinder finder(block); + finder(GetRef(block_sref->parent->stmt)); + ICHECK(finder.result != nullptr) + << "InternalError: Cannot find the BlockRealize of block " << GetRef(block); + return GetRef(finder.result); + } +} + +/******** Pattern Matcher ********/ + +/*! 
+ * \brief PrimExpr pattern matcher. + * + * It is different from the pattern matcher in arith/pattern_match.h, which is dedicated + * for compile-time constant patterns. This pattern matcher can work on dynamic user-specific + * patterns. + * + * The code below shows how to use the pattern matcher. + * + * \code + * + * Var x("x"), y("y"); + * // use PrimExpr to declare patterns, x, y are holes that can be filled with + * PatternMatcher pattern_matcher(x + y); + * // expr = C[i, j] + A[i, k] * B[k, j], which is the expr we want to match + * pattern_matcher.Match(expr); + * + * if (pattern_matcher.Success()) { + * pattern_matcher.Eval(x) // C[i, j] + * pattern_matcher.Eval(y) // A[i, k] * B[k, j] + * } + * + * \endcode + */ +class PatternMatcher : public ExprVisitor { + public: + explicit PatternMatcher(PrimExpr pattern) : pattern_(std::move(pattern)) {} + + void VisitExpr_(const VarNode* op) final { + auto it = filled_map_.find(op); + if (it == filled_map_.end()) { + filled_map_[op] = expr_to_match_; + } else { + ExprDeepEqual equal; + if (it->second.same_as(expr_to_match_) || equal(it->second, expr_to_match_)) return; + match_success_ = false; + } + } + + void VisitExpr_(const LoadNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->buffer_var.same_as(ptr->buffer_var)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->predicate; + VisitExpr(op->predicate); + expr_to_match_ = ptr->index; + VisitExpr(op->index); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const LetNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->var; + VisitExpr(op->var); + expr_to_match_ = ptr->value; + VisitExpr(op->value); + expr_to_match_ = ptr->body; + VisitExpr(op->body); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const CallNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->op.same_as(ptr->op)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->args.size(); ++i) { + expr_to_match_ = ptr->args[i]; + VisitExpr(op->args[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + +#define TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OpName) \ + void VisitExpr_(const OpName* op) { \ + const auto* ptr = expr_to_match_.as(); \ + if (ptr == nullptr) { \ + match_success_ = false; \ + } else { \ + PrimExpr current = expr_to_match_; \ + expr_to_match_ = ptr->a; \ + VisitExpr(op->a); \ + expr_to_match_ = ptr->b; \ + VisitExpr(op->b); \ + std::swap(expr_to_match_, current); \ + } \ + } + + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AddNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(SubNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MulNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(DivNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(ModNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorDivNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorModNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MinNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MaxNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(EQNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(NENode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LTNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LENode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GTNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GENode); + 
TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AndNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OrNode); + + void VisitExpr_(const CastNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!runtime::TypeEqual(op->dtype, ptr->dtype)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->value; + VisitExpr(op->value); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const NotNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->a; + VisitExpr(op->a); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const SelectNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->condition; + VisitExpr(op->condition); + expr_to_match_ = ptr->true_value; + VisitExpr(op->true_value); + expr_to_match_ = ptr->false_value; + VisitExpr(op->false_value); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const RampNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->lanes != ptr->lanes) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->base; + VisitExpr(op->base); + expr_to_match_ = ptr->stride; + VisitExpr(op->stride); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const BroadcastNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->lanes != ptr->lanes) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->value; + VisitExpr(op->value); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const ShuffleNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->vectors.size() != ptr->vectors.size() || op->indices.size() != ptr->indices.size()) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->indices.size(); ++i) { + expr_to_match_ = ptr->indices[i]; + VisitExpr(op->indices[i]); + } + for (size_t i = 0; i < op->vectors.size(); ++i) { + expr_to_match_ = ptr->vectors[i]; + VisitExpr(op->vectors[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const IntImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const FloatImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const StringImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const BufferLoadNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->buffer.same_as(ptr->buffer) || op->indices.size() != ptr->indices.size()) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->indices.size(); ++i) { + expr_to_match_ = ptr->indices[i]; + VisitExpr(op->indices[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + + void Match(const PrimExpr& expr_to_match) { + 
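// Reset the matcher state (success flag and previously filled pattern variables) so + // that a single PatternMatcher instance can be reused across multiple matches. +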
this->match_success_ = true; + this->filled_map_.clear(); + this->expr_to_match_ = expr_to_match; + this->operator()(pattern_); + } + + PrimExpr Eval(const Var& var) { + auto it = filled_map_.find(var.operator->()); + ICHECK(it != filled_map_.end()) << "Unknown pattern variable"; + ICHECK(match_success_) << "Match failed"; + return it->second; + } + + bool Success() const { return match_success_; } + + private: + bool match_success_{true}; + PrimExpr pattern_, expr_to_match_; + std::unordered_map filled_map_; +}; + +/******** Commutative Reducer ********/ + +bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const PrimExpr& combiner, + const BufferLoad& load, PrimExpr* lhs, PrimExpr* rhs) { + if (!ExprDeepEqual()(reducer->identity_element[0], identity)) { + return false; + } + PatternMatcher pattern_matcher(reducer->result[0]); + pattern_matcher.Match(combiner); + if (pattern_matcher.Success()) { + PrimExpr lhs_tmp = pattern_matcher.Eval(reducer->lhs[0]); + PrimExpr rhs_tmp = pattern_matcher.Eval(reducer->rhs[0]); + if (ExprDeepEqual()(load, lhs_tmp)) { + *lhs = std::move(lhs_tmp); + *rhs = std::move(rhs_tmp); + } + return true; + } + return false; +} + +bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, + CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs) { + BufferLoad load(combiner->buffer, combiner->indices); + // Check reduction patterns. + for (const TypedPackedFunc& reducer_getter : GetReducerGetters()) { + CommReducer reducer = reducer_getter(identity.dtype()); + if (MatchReducer(reducer, identity, combiner->value, load, lhs, rhs)) { + *result_reducer = std::move(reducer); + return true; + } + } + return false; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index a180bd7613..6436fd1487 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -305,6 +305,16 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) { /******** Schedule: loop binding/annotation ********/ /******** Schedule: cache read/write ********/ /******** Schedule: reduction ********/ + +BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) { + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::RFactor(state_, this->GetSRef(loop_rv), factor_axis); + TVM_TIR_SCHEDULE_END("rfactor", this->error_render_level_); + this->state_->DebugVerify(); + return CreateRV(result); +} + /******** Schedule: blockize & tensorize ********/ /******** FFI ********/ diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 250246a01e..1d11ca2f5f 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -90,6 +90,9 @@ class ConcreteScheduleNode : public ScheduleNode { /******** Schedule: reduction ********/ /******** Schedule: blockize & tensorize ********/ + /******** Schedule: reduction ********/ + BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override; + /******** Utility functions ********/ protected: /*! 
@@ -136,13 +139,13 @@ class ConcreteScheduleNode : public ScheduleNode { inline Block ConcreteScheduleNode::Get(const BlockRV& block_rv) const { StmtSRef sref = this->GetSRef(block_rv); - const auto* block = TVM_SREF_TO_BLOCK(block, sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, sref); return GetRef(block); } inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { StmtSRef sref = this->GetSRef(loop_rv); - const auto* loop = TVM_SREF_TO_FOR(loop, sref); + const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); return GetRef(loop); } diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 4f36910989..058234ed2d 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -78,6 +78,17 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref /******** Schedule: cache read/write ********/ /******** Schedule: reduction ********/ +/*! + * \brief Factor a reduction block by the specified loop + * \details See python/tvm/tir/schedule/schedule.py + * \param loop_sref The loop outside block for which we want to do rfactor + * \param factor_axis The position where the new dimension is placed in the new introduced rfactor + * buffer. Suppose the original reduction block writes to buffer `B` with + * ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1, + * ndim(B)]`, and the negative index will be normalized to a non-negative one + * \return The sref of the rfactor block + */ +TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis); /******** Schedule: blockize & tensorize ********/ diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index 6bd6388faf..3892f358e0 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -622,7 +622,8 @@ void ComputeInline(ScheduleState self, const StmtSRef& producer_block_sref) { Block producer_block = GetRef(_producer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleWrite(self, producer_block); // Step 1. Get the scope block - StmtSRef scope_root_sref = GetScopeRootAndCheckStagePipeline(self, producer_block_sref); + StmtSRef scope_root_sref = + GetScopeRoot(self, producer_block_sref, /*require_stage_pipeline=*/true); // Step 2. Check completeness CheckCompleteBlock(self, producer_block_sref, scope_root_sref); // Step 3. Analyze the block body @@ -649,7 +650,8 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre Block consumer_block = GetRef(_consumer_block); Buffer inlined_buffer = NotSingleReadWriteBuffer::GetSingleRead(self, consumer_block); // Step 1. Get the scope block - StmtSRef scope_root_sref = GetScopeRootAndCheckStagePipeline(self, consumer_block_sref); + StmtSRef scope_root_sref = + GetScopeRoot(self, consumer_block_sref, /*require_stage_pipeline=*/true); // Step 2. Check completeness CheckCompleteBlock(self, consumer_block_sref, scope_root_sref); // Step 3. Check if the consumer has a single complete producer diff --git a/src/tir/schedule/primitive/fuse_split.cc b/src/tir/schedule/primitive/fuse_split.cc index 02a8774f91..9e69155532 100644 --- a/src/tir/schedule/primitive/fuse_split.cc +++ b/src/tir/schedule/primitive/fuse_split.cc @@ -314,7 +314,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, // - The execution order has not changed. (The block executes with the same args and the same // order with before. // Step 1. 
Check correctness - GetScopeRootAndCheckStagePipeline(self, loop_sref); + GetScopeRoot(self, loop_sref, /*require_stage_pipeline=*/true); const auto* loop = loop_sref->StmtAs(); if (loop == nullptr) { throw NotLoopError(self->mod, loop_sref->stmt->GetTypeKey()); @@ -411,7 +411,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { const ForNode* outer_loop = nullptr; arith::Analyzer analyzer; // Step 1. check correctness - GetScopeRootAndCheckStagePipeline(self, loop_srefs[0]); + GetScopeRoot(self, loop_srefs[0], /*require_stage_pipeline=*/true); for (const StmtSRef& sref : loop_srefs) { const auto* loop = sref->StmtAs(); if (loop == nullptr) { diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc new file mode 100644 index 0000000000..1d89e61b8d --- /dev/null +++ b/src/tir/schedule/primitive/reduction.cc @@ -0,0 +1,964 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/******** Commutative Reducer ********/ + +/*! + * \brief A structure used for registering new commutative reducers, and store all the registered + * reducers. The reducers are preserved in a list, in the form of "reducer-getter function". 
When + * invoking a reducer-getter function with a specific datatype, the reducer-getter will return the + * CommReducer of the corresponding reduction pattern and the specific datatype + */ +struct ReducerRegistry { + ReducerRegistry() + : reducer_getters{CreateReducerGetter([](const Var& x, const Var& y) { return x + y; }, + [](DataType dtype) { return make_const(dtype, 0); }), + CreateReducerGetter([](const Var& x, const Var& y) { return x * y; }, + [](DataType dtype) { return make_const(dtype, 1); }), + CreateReducerGetter([](const Var& x, const Var& y) { return min(x, y); }, + [](DataType dtype) { return max_value(dtype); }), + CreateReducerGetter([](const Var& x, const Var& y) { return max(x, y); }, + [](DataType dtype) { return min_value(dtype); })} {} + + static void RegisterReducer(TypedPackedFunc combiner_getter, + TypedPackedFunc identity_getter) { + ReducerRegistry::Global()->reducer_getters.push_back(ReducerRegistry::CreateReducerGetter( + std::move(combiner_getter), std::move(identity_getter))); + } + + static TypedPackedFunc CreateReducerGetter( + TypedPackedFunc combiner_getter, + TypedPackedFunc identity_getter) { + return [combiner_getter = std::move(combiner_getter), + identity_getter = std::move(identity_getter)](DataType dtype) -> CommReducer { + Var lhs("x", dtype); + Var rhs("y", dtype); + return CommReducer({lhs}, {rhs}, {combiner_getter(lhs, rhs)}, {identity_getter(dtype)}); + }; + } + + static ReducerRegistry* Global() { + static ReducerRegistry instance; + return &instance; + } + + std::vector> reducer_getters; +}; + +std::vector> GetReducerGetters() { + return ReducerRegistry::Global()->reducer_getters; +} + +class NotSerialLoopKindError : public ScheduleError { + public: + explicit NotSerialLoopKindError(IRModule mod, For loop) + : mod_(std::move(mod)), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The input loop of rfactor is required to be `kSerial`"; + } + + String DetailRenderTemplate() const final { + String str_kind = IterVarType2String(loop_->kind); + ICHECK_NE(str_kind, "Unknown"); + std::ostringstream os; + os << "ScheduleError: The input loop {0} of rfactor is required to be `Serial`. 
However, the "
+          "kind of {0} is `"
+       << str_kind << "`";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }
+
+  IRModule mod_;
+  For loop_;
+};
+
+class InitBodyNotBufferStoreError : public ScheduleError {
+ public:
+  explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool init_is_bufferstore,
+                                       bool body_is_bufferstore)
+      : mod_(std::move(mod)),
+        block_(std::move(block)),
+        init_is_bufferstore_(init_is_bufferstore),
+        body_is_bufferstore_(body_is_bufferstore) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The `init` and `body` of the reduction block are both required to be "
+           "BufferStore";
+  }
+
+  String DetailRenderTemplate() const final {
+    if (!init_is_bufferstore_ && !body_is_bufferstore_) {
+      return "The `init` and `body` of block {0} are required to be BufferStore so that rfactor "
+             "can be applied";
+    } else if (!init_is_bufferstore_) {
+      return "The `init` of block {0} is required to be BufferStore so that rfactor can be "
+             "applied";
+    } else {
+      ICHECK(!body_is_bufferstore_);
+      return "The `body` of block {0} is required to be BufferStore so that rfactor can be "
+             "applied";
+    }
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+  IRModule mod_;
+  Block block_;
+  bool init_is_bufferstore_;
+  bool body_is_bufferstore_;
+};
+
+class InitBodyNotSameBufferAccessError : public ScheduleError {
+ public:
+  explicit InitBodyNotSameBufferAccessError(IRModule mod, Block block)
+      : mod_(std::move(mod)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The `init` and `body` of the reduction block are required to have the "
+           "same buffer access pattern";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    const auto* init = block_->init.as<BufferStoreNode>();
+    const auto* update = block_->body.as<BufferStoreNode>();
+    os << "The `init` and `body` of block {0} are required to have the same buffer access "
+          "pattern. However, in block {0} the `init` writes to "
+       << init->buffer->name << init->indices << ", and the `body` writes to "
+       << update->buffer->name << update->indices;
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }
+
+  IRModule mod_;
+  Block block_;
+};
+
+class FactorAxisOutOfRangeError : public ScheduleError {
+ public:
+  explicit FactorAxisOutOfRangeError(IRModule mod, Buffer buffer, int factor_axis)
+      : mod_(std::move(mod)), buffer_(std::move(buffer)), factor_axis_(factor_axis) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: The input `factor_axis` is out of range. It is required to be in range "
+           "[-(ndim + 1), ndim] where `ndim` is the number of dimensions of the write buffer";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    int ndim = static_cast<int>(buffer_->shape.size());
+    os << "The write buffer " << buffer_->name << " has " << ndim
+       << " dimension(s), so `factor_axis` is required to be in [" << -(ndim + 1) << ", " << ndim
+       << "] for rfactor. However, the input `factor_axis` is " << factor_axis_
+       << ", which is out of the expected range";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
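+
+  // For example (illustrative): a 2-dimensional write buffer admits `factor_axis` in [-3, 2];
+  // a negative input such as -3 is normalized by `CheckAndUpdate` below to -3 + (2 + 1) = 0.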
+  static int CheckAndUpdate(const IRModule& mod, const Buffer& buffer, int factor_axis) {
+    int ndim = static_cast<int>(buffer->shape.size());
+    if (factor_axis < -(ndim + 1) || factor_axis > ndim) {
+      throw FactorAxisOutOfRangeError(mod, buffer, factor_axis);
+    }
+    // If factor_axis is negative, convert it to a non-negative one.
+    if (factor_axis < 0) {
+      factor_axis += ndim + 1;
+    }
+    return factor_axis;
+  }
+
+  IRModule mod_;
+  Buffer buffer_;
+  int factor_axis_;
+};
+
+class NoMatchedReducerError : public ScheduleError {
+ public:
+  explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore combiner)
+      : mod_(std::move(mod)), identity_(std::move(identity)), combiner_(std::move(combiner)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: No matched reducer for the identity and the combiner of this reduction "
+           "block. So rfactor cannot be applied.";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "No matched reducer for identity " << identity_ << " and combiner " << combiner_
+       << ". In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for "
+          "the default reducers, or register new reducers.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+  IRModule mod_;
+  PrimExpr identity_;
+  BufferStore combiner_;
+};
+
+class LoopPropertyError : public ScheduleError {
+ public:
+  enum ErrorType {
+    kDataParIterTouchRFactorLoop = 0,
+    kLoopTouchedByBothKindsOfBlockIters = 1,
+    kNotFirstChildBlockOfOutermostLoop = 2,
+    kUnboundLoopUnderReductionLoop = 3
+  };
+
+  explicit LoopPropertyError(IRModule mod, For loop, ErrorType error_type)
+      : mod_(std::move(mod)), loop_(std::move(loop)), error_type_(error_type) {}
+
+  String FastErrorString() const final {
+    switch (error_type_) {
+      case kDataParIterTouchRFactorLoop:
+        return "ScheduleError: The loop to be applied rfactor is required not to be touched by "
+               "any data parallel block iter of the block";
+      case kLoopTouchedByBothKindsOfBlockIters:
+        return "ScheduleError: The loops outside of the reduction block are required not to be "
+               "touched by both data parallel block iters and reduction block iters";
+      case kNotFirstChildBlockOfOutermostLoop:
+        return "ScheduleError: The reduction block should be the first child block of the "
+               "outermost loop outside of it";
+      case kUnboundLoopUnderReductionLoop:
+        return "ScheduleError: A loop whose extent is greater than one and which is not bound to "
+               "any block iter should not appear under a reduction loop";
+    }
+    ICHECK(false) << "Unreachable";
+    throw;
+  }
+
+  String DetailRenderTemplate() const final {
+    switch (error_type_) {
+      case kDataParIterTouchRFactorLoop:
+        return "The loop to be applied rfactor is {0}, which is required not to be touched by any "
+               "data parallel block iter of the block below. 
However, some of the block's data " + "parallel block iters touch this loop"; + case kLoopTouchedByBothKindsOfBlockIters: + return "It is not allowed that the loop {0} is touched by both some data parallel block " + "iters and some reduction block iters"; + case kNotFirstChildBlockOfOutermostLoop: + return "The first child block of the outermost loop {0} is not the reduction block."; + case kUnboundLoopUnderReductionLoop: + return "The loop {0} has extent greater than one, and is not bound to any block iter. " + "Therefore it shouldn't appear under a reduction loop"; + } + ICHECK(false) << "Unreachable"; + throw; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + static void CheckLoopProperty(const ScheduleState& self, const Array& loops, + const ForNode* rf_loop, const Block& block, + const std::unordered_set& data_par_loop_vars, + const std::unordered_set& reduce_loop_vars) { + Array children_of_outermost_loop = + GetChildBlockRealizeOnSRefTree(self->stmt2ref.at(loops[0].get())); + if (!children_of_outermost_loop[0]->block.same_as(block)) { + throw LoopPropertyError(self->mod, loops[0], kNotFirstChildBlockOfOutermostLoop); + } + + bool meet_reduction_loop = false; + for (const For& loop : loops) { + bool data_par_touched = data_par_loop_vars.count(loop->loop_var.get()); + bool reduction_touched = reduce_loop_vars.count(loop->loop_var.get()); + + if (data_par_touched && reduction_touched) { + throw LoopPropertyError(self->mod, loop, kLoopTouchedByBothKindsOfBlockIters); + } else if (data_par_touched) { + if (loop.get() == rf_loop) { + throw LoopPropertyError(self->mod, loop, kDataParIterTouchRFactorLoop); + } + continue; + } else if (reduction_touched) { + if (!meet_reduction_loop) { + CheckGetSingleChildBlockRealizeOnSRefTree(self, self->stmt2ref.at(loop.get())); + meet_reduction_loop = true; + } + continue; + } else if (meet_reduction_loop && !is_one(loop->extent)) { + throw LoopPropertyError(self->mod, loop, kUnboundLoopUnderReductionLoop); + } + } + } + + IRModule mod_; + For loop_; + ErrorType error_type_; +}; + +/*! + * \brief Convert the `init` and `body` of the input block to BufferStores + * \param self The schedule state + * \param block The block to be analyzed + * \return The BufferStores of the `init` and `body` of the input block + * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same + * buffer + */ +std::pair GetBufferStoreNodes(const ScheduleState& self, + const Block& block) { + const auto* init = block->init.as(); + const auto* body = block->body.as(); + if (!(init && body)) { + throw InitBodyNotBufferStoreError(self->mod, block, init != nullptr, body != nullptr); + } + if (!init->buffer.same_as(body->buffer)) { + throw InitBodyNotSameBufferAccessError(self->mod, block); + } + int ndim = static_cast(init->buffer->shape.size()); + for (int i = 0; i < ndim; ++i) { + if (!ExprDeepEqual()(init->indices[i], body->indices[i])) { + throw InitBodyNotSameBufferAccessError(self->mod, block); + } + } + return {GetRef(init), GetRef(body)}; +} + +/*! 
+ * \brief Given a reduction identity and a reduction combiner, detect the corresponding
+ * commutative reducer, and extract the combiner lhs and combiner rhs
+ * \param self The schedule state
+ * \param identity The reduction identity to be analyzed
+ * \param combiner The reduction combiner to be analyzed
+ * \return The corresponding CommReducer, the combiner lhs and the combiner rhs
+ * \throw ScheduleError If no corresponding commutative reducer can be matched
+ */
+std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
+    const ScheduleState& self, const PrimExpr& identity, const BufferStore& combiner) {
+  CommReducer reducer{nullptr};
+  PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr};
+  bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs);
+  if (!matched) {
+    throw NoMatchedReducerError(self->mod, identity, combiner);
+  }
+  return {std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)};
+}
+
+/*!
+ * \brief For each loop in the given array of loops, associate its loop var with the loop itself
+ * using a mapping
+ * \param loops The loops to be analyzed
+ * \return A mapping from loop vars to their corresponding loops
+ */
+std::unordered_map<const VarNode*, For> GetLoopVar2LoopMap(const Array<For>& loops) {
+  std::unordered_map<const VarNode*, For> loop_vars2loop;
+  loop_vars2loop.reserve(loops.size());
+  for (const For& loop : loops) {
+    loop_vars2loop[loop->loop_var.get()] = loop;
+  }
+  return loop_vars2loop;
+}
+
+/*!
+ * \brief Create the intermediate rfactor buffer, which the rfactor block writes to and the
+ * write-back block reads from
+ * \param buffer The buffer written by the reduction block
+ * \param factor_axis The `factor_axis` parameter of rfactor
+ * \param rf_loop The rfactor loop
+ * \return The newly created intermediate rfactor buffer
+ */
+Buffer CreateRFactorBuffer(const Buffer& buffer, int factor_axis, const ForNode* rf_loop) {
+  Array<PrimExpr> rf_shape = buffer->shape;
+  rf_shape.insert(rf_shape.begin() + factor_axis, rf_loop->extent);
+
+  ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
+  n->shape = rf_shape;
+  n->name = buffer->name + ".rf";
+  n->data = buffer->data.copy_with_suffix(".rf");
+  return Buffer(n);
+}
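+
+// For example (illustrative): with an rfactor loop of extent 4 and a write buffer C of shape
+// [128, 128], factor_axis = 0 yields the intermediate buffer C.rf of shape [4, 128, 128], while
+// factor_axis = 2 would yield shape [128, 128, 4].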
+
+/*!
+ * \brief The base class of the rfactor/write-back block creator, which creates the blocks in four
+ * steps:
+ * 1) Create the new block iters and their iter bindings
+ * 2) Create the reduction update of the new block
+ * 3) Create the read/write regions of the new block
+ * 4) Create the new block and the new block-realize
+ */
+class BaseBlockCreator {
+ public:
+  explicit BaseBlockCreator(BlockRealize old_block_realize, For rf_loop,
+                            BufferStore old_reduction_update, CommReducer reducer, Buffer rf_buffer,
+                            bool is_rf_block)
+      : old_block_realize_(std::move(old_block_realize)),
+        rf_loop_(std::move(rf_loop)),
+        old_reduction_update_(std::move(old_reduction_update)),
+        reducer_(std::move(reducer)),
+        rf_buffer_(std::move(rf_buffer)),
+        is_rf_block_(is_rf_block) {
+    n_block_iters_ = static_cast<int>(old_block_realize_->iter_values.size());
+  }
+
+  void CreateBlock() {
+    CreateAdditionalIter();
+    for (int i = 0; i < n_block_iters_; ++i) {
+      CreateNormalIters(i);
+    }
+    CreateReductionUpdate();
+    CreateReadWriteRegions();
+
+    String new_block_name = old_block_realize_->block->name_hint;
+    PrimExpr predicate = Bool(true);
+    if (is_rf_block_) {
+      new_block_name = new_block_name + "_rf";
+      predicate = old_block_realize_->predicate;
+    }
+    new_block_ = Block(
+        /*iter_vars=*/iter_vars_,
+        /*reads=*/read_regions_,
+        /*writes=*/write_regions_,
+        /*name_hint=*/new_block_name,
+        /*body=*/new_reduction_update_,
+        /*init=*/
+        BufferStore(new_reduction_update_->buffer, reducer_->identity_element[0],
+                    new_reduction_update_->indices));
+    new_block_realize_ = BlockRealize(iter_values_, predicate, new_block_);
+  }
+
+ private:
+  virtual void CreateAdditionalIter() = 0;
+  virtual void CreateNormalIters(int idx) = 0;
+  virtual void CreateReductionUpdate() = 0;
+  virtual void CreateReadWriteRegions() = 0;
+
+ public:
+  /*! \brief The newly created block */
+  Block new_block_;
+  /*! \brief The newly created block-realize */
+  BlockRealize new_block_realize_;
+  /*! \brief The indices used to access the intermediate rfactor buffer */
+  Array<PrimExpr> rf_buf_access_indices_;
+
+ protected:
+  /*! \brief The old block-realize */
+  BlockRealize old_block_realize_;
+  /*! \brief The number of block iters in the old block */
+  int n_block_iters_;
+  /*! \brief The rfactor loop */
+  For rf_loop_;
+  /*! \brief The update BufferStore of the old block */
+  BufferStore old_reduction_update_;
+  /*! \brief The matched commutative reducer */
+  CommReducer reducer_;
+  /*! \brief The intermediate rfactor buffer */
+  Buffer rf_buffer_;
+
+  /*! \brief Whether we are creating the rfactor block or the write-back block */
+  bool is_rf_block_;
+  /*! \brief The new block iters of the newly created block */
+  std::vector<IterVar> iter_vars_;
+  /*! \brief The new block iter bindings of the newly created block-realize */
+  std::vector<PrimExpr> iter_values_;
+  /*!
+   * \brief A mapping from old block iters to new expressions. The old iters will be replaced by
+   * the mapped expressions when the two new blocks are substituted later
+   */
+  Map<Var, PrimExpr> var_map_;
+  /*! \brief The update BufferStore of the newly created block */
+  BufferStore new_reduction_update_;
+  /*! \brief The read regions of the newly created block */
+  Array<BufferRegion> read_regions_;
+  /*! \brief The write regions of the newly created block */
+  Array<BufferRegion> write_regions_;
+};
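+
+// For example (illustrative, taken from the matmul test below): factoring the reduction block
+//   with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]
+// on its innermost loop produces an rfactor block "update_rf" with block iters
+// [vi2_inner_inner (data parallel), vi, vj, vi2_outer, vi2_inner_outer] and a write-back block
+// "update" with block iters [vi2_inner_inner (reduction), vi, vj].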
+
+/*!
+ * \brief The derived class of the rfactor block creator, which implements all virtual methods in
+ * the base creator
+ * \details The main difficulty in constructing the rfactor block is the creation of its block
+ * iters, which is done with the following algorithm:
+ * 1. Create a block iter for the rfactor loop. The block binding of this iter is the loop var,
+ *    and the block iter is data parallel.
+ * 2. For each block iter of the old block, there are two cases:
+ *    (a) If it is a data parallel block iter, or a reduction block iter which doesn't touch the
+ *        rfactor loop, we keep it and its block binding in the rfactor block.
+ *    (b) Otherwise it is a reduction block iter which touches the rfactor loop. In this case, we
+ *        "split" the block iter into one or more new block iters and do not keep the old block
+ *        var. More specifically, we create a new reduction block iter for each loop var that
+ *        appears in the reduction block iter's binding (except for the rfactor loop), and the
+ *        binding of the new block iter is exactly the loop var. (Note that for each loop var, we
+ *        create at most one block iter, even if there are multiple old block iters which touch
+ *        both this loop and the rfactor loop.)
+ *        Then we substitute the appearances of the old block iter with the newly created block
+ *        iters by recording two mappings: one maps loop vars to newly created block iters, which
+ *        is used for binding substitution, and another maps old block iters to new expressions,
+ *        which is used for substitutions of the old block iters.
+ */
+class RFactorBlockCreator : public BaseBlockCreator {
+ public:
+  explicit RFactorBlockCreator(BlockRealize old_block_realize, For rf_loop,
+                               BufferStore old_reduction_update, CommReducer reducer,
+                               Buffer rf_buffer,
+                               std::unordered_map<const VarNode*, For> loop_vars2loop,
+                               int factor_axis, PrimExpr combiner_rhs)
+      : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop),
+                         std::move(old_reduction_update), std::move(reducer), std::move(rf_buffer),
+                         true),
+        loop_vars2loop_(std::move(loop_vars2loop)),
+        factor_axis_(factor_axis),
+        combiner_rhs_(std::move(combiner_rhs)) {}
+
+ private:
+  void CreateAdditionalIter() final {
+    // Create a new data parallel block iter for the rfactor loop.
+    additional_iter_ =
+        IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, IterVarType::kDataPar);
+    loop_var2block_binding_[rf_loop_->loop_var.get()] = additional_iter_->var;
+    iter_vars_.push_back(additional_iter_);
+    iter_values_.push_back(rf_loop_->loop_var);
+  }
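+
+  // For example (illustrative, taken from the matmul test below): if the old reduction iter vk is
+  // bound to i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner and the rfactor loop is
+  // i2_inner_inner, then CreateNormalIters creates the reduction iters vi2_outer and
+  // vi2_inner_outer bound to i2_outer and i2_inner_outer, and maps vk to
+  // vi2_outer * 32 + vi2_inner_outer * 4 + vi2_inner_inner.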
+
+  void CreateNormalIters(int idx) final {
+    IterVar old_iter = old_block_realize_->block->iter_vars[idx];
+    PrimExpr old_binding = old_block_realize_->iter_values[idx];
+    if (old_iter->iter_type == IterVarType::kDataPar ||
+        !UsesVar(old_binding,
+                 [v = rf_loop_->loop_var.get()](const VarNode* var) { return var == v; })) {
+      // The old block iter is either a data parallel block iter, or a reduction block iter that
+      // doesn't touch the rfactor loop. In this case reuse the old block iter and its
+      // corresponding binding.
+      iter_vars_.push_back(old_iter);
+      iter_values_.push_back(old_binding);
+      return;
+    }
+    ICHECK(old_iter->iter_type == kCommReduce);
+    // This block iter is a reduction block iter that touches the rfactor loop. So next we try to
+    // create a new block iter for all loop vars that appear in the old binding.
+    Array<Var> vars_in_old_binding = UndefinedVars(old_binding);
+    for (const Var& var : vars_in_old_binding) {
+      auto it = loop_vars2loop_.find(var.get());
+      if (it == loop_vars2loop_.end()) {
+        // `var` is not a loop var. So skip.
+        continue;
+      }
+      const For& loop = it->second;
+      if (loop_var2block_binding_.find(var.get()) == loop_var2block_binding_.end()) {
+        // We haven't created the new block iter for `var`. So here we create it, and append it
+        // and its binding to `iter_vars_` and `iter_values_` respectively.
+        IterVar new_iter_var =
+            IterVarFromLoop(loop, "v" + loop->loop_var->name_hint, IterVarType::kCommReduce);
+        loop_var2block_binding_[var.get()] = new_iter_var->var;
+        iter_vars_.push_back(new_iter_var);
+        iter_values_.push_back(var);
+      }
+    }
+    // Substitute the original binding with the new block iters. Store the result expression
+    // in `var_map_` for future substitution.
+    var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_));
+  }
+
+  void CreateReductionUpdate() final {
+    rf_buf_access_indices_ = old_reduction_update_->indices;
+    rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_,
+                                  additional_iter_->var);
+    new_reduction_update_ = BufferStore(
+        rf_buffer_,
+        (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0],
+        rf_buf_access_indices_);
+    new_reduction_update_ = Downcast<BufferStore>(Substitute(new_reduction_update_, var_map_));
+  }
+
+  void CreateReadWriteRegions() final {
+    const Block& old_block = old_block_realize_->block;
+    read_regions_ = CreateRegions(old_block->reads);
+    write_regions_ = CreateRegions(old_block->writes);
+  }
+
+  Array<BufferRegion> CreateRegions(const Array<BufferRegion>& old_regions) {
+    Array<BufferRegion> new_regions;
+    new_regions.reserve(old_regions.size());
+    for (const BufferRegion& buffer_region : old_regions) {
+      if (buffer_region->buffer.same_as(old_reduction_update_->buffer)) {
+        Array<Range> region = buffer_region->region;
+        region.insert(region.begin() + factor_axis_,
+                      Range::FromMinExtent(additional_iter_->var, 1));
+        new_regions.push_back(BufferRegion(rf_buffer_, Substitute(region, var_map_)));
+      } else {
+        new_regions.push_back(
+            BufferRegion(buffer_region->buffer, Substitute(buffer_region->region, var_map_)));
+      }
+    }
+    return new_regions;
+  }
+
+ public:
+  /*! \brief The generated additional block iter in the rfactor block for the rfactor loop */
+  IterVar additional_iter_;
+
+ private:
+  /*!
+   * \brief A mapping which maps a loop var to its corresponding For loop for all the reduction
+   * block's outer loops
+   */
+  std::unordered_map<const VarNode*, For> loop_vars2loop_;
+  /*! \brief The factor_axis specified for rfactor */
+  int factor_axis_;
+  /*! \brief The rhs of the combiner in the reduction update of the old block */
+  PrimExpr combiner_rhs_;
+  /*!
+   * \brief A mapping which maps loop vars to newly created block iters. This map is used to
+   * substitute the loop vars which appear in the bindings of some old block iters with the newly
+   * created block iters
+   */
+  std::unordered_map<const VarNode*, PrimExpr> loop_var2block_binding_;
+};
+
+/*!
+ * \brief The derived class of the write-back block creator, which implements all virtual methods in + * the base creator + */ +class WriteBackBlockCreator : public BaseBlockCreator { + public: + explicit WriteBackBlockCreator(BlockRealize old_block_realize, For rf_loop, + BufferStore old_reduction_update, CommReducer reducer, + Buffer rf_buffer, IterVar rf_additional_iter, + PrimExpr combiner_lhs, Array rf_buf_access_indices) + : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), + std::move(old_reduction_update), std::move(reducer), std::move(rf_buffer), + false), + rf_additional_iter_(std::move(rf_additional_iter)), + combiner_lhs_(std::move(combiner_lhs)) { + iter_vars_.reserve(n_block_iters_); + iter_values_.reserve(n_block_iters_); + rf_buf_access_indices_ = std::move(rf_buf_access_indices); + } + + private: + void CreateAdditionalIter() final { + // Create a new reduction block iter for the rfactor loop. + IterVar wb_new_block_iter = + IterVarFromLoop(rf_loop_, "v" + rf_loop_->loop_var->name_hint, kCommReduce); + iter_vars_.push_back(wb_new_block_iter); + iter_values_.push_back(rf_loop_->loop_var); + var_map_.Set(rf_additional_iter_->var, wb_new_block_iter->var); + } + + void CreateNormalIters(int idx) final { + IterVar old_block_iter = old_block_realize_->block->iter_vars[idx]; + if (old_block_iter->iter_type == IterVarType::kDataPar) { + iter_vars_.emplace_back(old_block_iter->dom, old_block_iter->var.copy_with_suffix(""), + kDataPar); + iter_values_.push_back(old_block_realize_->iter_values[idx]); + var_map_.Set(old_block_iter->var, iter_vars_.back()); + } + } + + void CreateReductionUpdate() final { + wb_lhs_ = Downcast(Substitute(combiner_lhs_, var_map_)); + wb_rhs_ = + Downcast(Substitute(BufferLoad(rf_buffer_, rf_buf_access_indices_), var_map_)); + new_reduction_update_ = + BufferStore(old_reduction_update_->buffer, (*reducer_.get())({wb_lhs_}, {wb_rhs_})[0], + old_reduction_update_->indices); + new_reduction_update_ = Downcast(Substitute(new_reduction_update_, var_map_)); + } + + void CreateReadWriteRegions() final { + read_regions_.push_back(CreateRegion(wb_lhs_)); + read_regions_.push_back(CreateRegion(wb_rhs_)); + write_regions_.push_back(read_regions_[0]); + } + + static BufferRegion CreateRegion(const BufferLoad& load) { + Array region; + region.reserve(load->indices.size()); + for (const PrimExpr& index : load->indices) { + region.push_back(Range::FromMinExtent(index, 1)); + } + return BufferRegion(load->buffer, std::move(region)); + } + + private: + /*! \brief The new created additional block iter of the rfactor block */ + IterVar rf_additional_iter_; + /*! \brief The lhs of the combiner in the reduction update of the old block */ + PrimExpr combiner_lhs_; + /*! \brief The lhs of the combiner of the write-back block */ + BufferLoad wb_lhs_; + /*! \brief The rhs of the combiner of the write-back block */ + BufferLoad wb_rhs_; +}; + +/*! + * \brief Create new outer loops for the rfactor block, meanwhile update the rfactor block's iter + * bindings to use the new created loop vars + * \param rf_block_realize The BlockRealize of the rfactor block + * \param loops The loops to be wrapped over the rfactor block + * \return A Stmt which is the wrapping result + */ +Stmt CreateLoopOutsideRfactorBlock(BlockRealize rf_block_realize, const Array& loops) { + int n_loops = static_cast(loops.size()); + + // Step 1. Create new loop vars. 
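+  // The original loops stay in place above the write-back block, so the loops wrapped around the
+  // rfactor block are given fresh loop vars here; reusing the old vars would define each of them
+  // twice in the resulting PrimFunc.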
+ Array new_loops; + std::unordered_map new_loop_var_map; + new_loops.reserve(n_loops); + new_loop_var_map.reserve(n_loops); + for (const For& old_loop : loops) { + Var new_loop_var = old_loop->loop_var.copy_with_suffix(""); + new_loop_var_map[old_loop->loop_var.get()] = new_loop_var; + } + + // Step 2. Update the iter bindings of the rfactor block. + Array new_bindings; + new_bindings.reserve(rf_block_realize->iter_values.size()); + for (const PrimExpr& old_binding : rf_block_realize->iter_values) { + new_bindings.push_back(Substitute(old_binding, new_loop_var_map)); + } + rf_block_realize.CopyOnWrite()->iter_values = new_bindings; + + // Step 3. Wrap `rf_block_realize` with outer loops. + Stmt rf_body = rf_block_realize; + for (int i = n_loops - 1; i >= 0; --i) { + ObjectPtr p_loop = make_object(*loops[i].get()); + p_loop->loop_var = Downcast(new_loop_var_map[loops[i]->loop_var.get()]); + p_loop->body = rf_body; + rf_body = For(std::move(p_loop)); + } + + return rf_body; +} + +class BlockReplacer : public StmtMutator { + public: + /*! + * \brief The replace takes the old scope root block as input, and does four things: + * 1) replace the reduction block with the write-back block, + * 2) remove loops outside the write-back block that are touched by reduction block iters, except + * for the rfactor loop + * 3) combine the rfactor block (wrapped with outer loops) and the transformed outermost loop + * into a SeqStmt, and + * 4) insert the rfactor buffer into the scope root block's `alloc_buffers` + * After transformation, the function returns the new scope root block + * \param scope_root_block The old scope root block + * \param rf_body The rfactor block, which is already wrapped with outer loops + * \param outermost_loop The loop that is outermost among all loops outside the reduction block + * \param wb_block_realize The new created BlockRealize of the write-back block + * \param old_block_realize The BlockRealize of the reduction block + * \param rf_loop The rfactor loop, which should be kept outside the write-back block + * \param reduce_loop_vars The loops that are touched by reduction block iters, used to remove + * loops outside the write-back block + * \param loop_vars2loop The mapping from loop vars to loops that are outside the reduction block, + * which is used to reduce redundant recursive visits + * \param rf_buffer The rfactor buffer to be added into the scope root's `alloc_buffers` + * \return The transformed new scope root block + */ + static Block Replace(Block scope_root_block, Stmt rf_body, For outermost_loop, + BlockRealize wb_block_realize, BlockRealize old_block_realize, For rf_loop, + std::unordered_set reduce_loop_vars, + std::unordered_map loop_vars2loop, + const Buffer& rf_buffer) { + BlockReplacer replacer(std::move(rf_body), std::move(outermost_loop), + std::move(wb_block_realize), std::move(old_block_realize), + std::move(rf_loop), std::move(reduce_loop_vars), + std::move(loop_vars2loop)); + Block new_scope_root = Downcast(replacer(std::move(scope_root_block))); + BlockNode* p = new_scope_root.CopyOnWrite(); + p->alloc_buffers.push_back(rf_buffer); + return new_scope_root; + } + + private: + explicit BlockReplacer(Stmt rf_body, For outermost_loop, BlockRealize wb_block_realize, + BlockRealize old_block_realize, For rf_loop, + std::unordered_set reduce_loop_vars, + std::unordered_map loop_vars2loop) + : rf_body_(std::move(rf_body)), + outermost_loop_(std::move(outermost_loop)), + wb_block_realize_(std::move(wb_block_realize)), + 
old_block_realize_(std::move(old_block_realize)), + rf_loop_(std::move(rf_loop)), + reduce_loop_vars_(std::move(reduce_loop_vars)), + loop_vars2loop_(std::move(loop_vars2loop)) {} + + Stmt VisitStmt_(const ForNode* loop) final { + // Step 1. Check whether this loop is outside the reduction block. Given that we've made sure + // that the scope root block has stage-pipeline property, if this loop is not outside the + // reduction block, there's no need to recursively mutate. + if (!loop_vars2loop_.count(loop->loop_var.get())) { + return GetRef(loop); + } + + // Step 2. Recursively mutate. + Stmt body = StmtMutator::VisitStmt(loop->body); + + // Step 3. If this loop is the rfactor loop and isn't touched by any reduction block iter, it + // should be kept outside the write-back block. Otherwise it shouldn't. + if (loop == rf_loop_.get() || !reduce_loop_vars_.count(loop->loop_var.get())) { + ObjectPtr p_loop = CopyOnWrite(loop); + p_loop->body = body; + body = Stmt(p_loop); + } + + // Step 4. If this loop is the outermost loop of the reduction block, return the combination of + // `rf_body_` and the mutation result `body`. Otherwise return the mutation result. + return loop == outermost_loop_.get() ? SeqStmt({rf_body_, body}) : body; + } + + Stmt VisitStmt_(const BlockRealizeNode* block_realize) final { + // Due to the visitor's behavior on ForNode, this block-realize must be the reduction block's + // block-realize. And we directly return the new `wb_block_realize`. + ICHECK_EQ(block_realize, old_block_realize_.get()); + return wb_block_realize_; + } + + Stmt VisitStmt_(const SeqStmtNode* seq) final { + Array new_stmts; + new_stmts.reserve(static_cast(seq->seq.size())); + + for (const Stmt old_stmt : seq->seq) { + new_stmts.push_back(VisitStmt(old_stmt)); + } + return SeqStmt::Flatten(new_stmts); + } + + private: + Stmt rf_body_; + For outermost_loop_; + BlockRealize wb_block_realize_; + BlockRealize old_block_realize_; + For rf_loop_; + std::unordered_set reduce_loop_vars_; + std::unordered_map loop_vars2loop_; +}; + +StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_axis) { + // ***************************************************** + // * Condition Checks and Information Collection * + // ***************************************************** + + // Step 1. Check some basic conditions for rfactor. Get the block and block-realize. + BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, rf_loop_sref); + const StmtSRef& block_sref = self->stmt2ref.at(block_realize->block.get()); + const Block& block = block_realize->block; + StmtSRef scope_root = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/true); + CheckReductionBlock(self, block_sref, scope_root); + const ForNode* rf_loop = TVM_SREF_TO_FOR(rf_loop, rf_loop_sref); + if (rf_loop->kind != ForKind::kSerial) { + throw NotSerialLoopKindError(self->mod, GetRef(rf_loop)); + } + + // Step 2. Collect loop vars that are touched by data parallel block iters and reduction block + // iters, respectively. + std::unordered_set data_par_loop_vars; + std::unordered_set reduce_loop_vars; + GetVarsTouchedByBlockIters(block_realize, &data_par_loop_vars, &reduce_loop_vars); + + // Step 3. Collect the loops of the reduction block. Construct a mapping from loops to + // corresponding loop vars. + Array loops = LoopSRefs2Loops(GetLoops(block_sref)); + std::unordered_map loop_vars2loop = GetLoopVar2LoopMap(loops); + + // Step 4. 
Check four properties that the loops should have:
+  //   - the rfactor loop cannot be touched by any data parallel block iter;
+  //   - no loop can be touched by both data parallel block iters and reduction block iters;
+  //   - the outermost loop should have the reduction block as its first child block;
+  //   - the outermost loop that is touched by some reduction block iters can only have one child
+  //     block.
+  LoopPropertyError::CheckLoopProperty(self, loops, rf_loop, block, data_par_loop_vars,
+                                       reduce_loop_vars);
+
+  // Step 5. Get the `init` identity and the `update` combiner of the reduction. Extract the
+  // commutative reducer, combiner lhs and combiner rhs from the reduction identity and the
+  // reduction combiner. The lhs will be used when constructing the write-back block, and the rhs
+  // will be used when constructing the rfactor block.
+  BufferStore init;
+  BufferStore update;
+  CommReducer reducer;
+  PrimExpr combiner_lhs, combiner_rhs;
+  std::tie(init, update) = GetBufferStoreNodes(self, block);
+  std::tie(reducer, combiner_lhs, combiner_rhs) =
+      GetReducerAndCombinerLhsRhs(self, init->value, update);
+
+  // Step 6. Check whether `factor_axis` is in the correct range, and convert it to a
+  // non-negative one if it is negative.
+  factor_axis = FactorAxisOutOfRangeError::CheckAndUpdate(self->mod, update->buffer, factor_axis);
+
+  // *****************************************************
+  // *                 IR Manipulation                   *
+  // *****************************************************
+  // Since rfactor splits the reduction block into two, we call the first one the "rfactor block",
+  // the latter one the "write-back block", and the intermediate buffer the "rfactor buffer".
+
+  // Step 1. Create the intermediate buffer (a.k.a. the rfactor buffer), which has an additional
+  // dimension specified by `factor_axis` and `rf_loop`.
+  Buffer rf_buffer = CreateRFactorBuffer(update->buffer, factor_axis, rf_loop);
+
+  // Step 2. Create the rfactor block.
+  RFactorBlockCreator rf_block_creator(block_realize, GetRef<For>(rf_loop), update, reducer,
+                                       rf_buffer, loop_vars2loop, factor_axis,
+                                       std::move(combiner_rhs));
+  rf_block_creator.CreateBlock();
+
+  // Step 3. Create the write-back block.
+  WriteBackBlockCreator wb_block_creator(block_realize, GetRef<For>(rf_loop), update, reducer,
+                                         rf_buffer, std::move(rf_block_creator.additional_iter_),
+                                         std::move(combiner_lhs),
+                                         std::move(rf_block_creator.rf_buf_access_indices_));
+  wb_block_creator.CreateBlock();
+
+  // Step 4. Wrap the rfactor block with loops.
+  Stmt rf_body = CreateLoopOutsideRfactorBlock(rf_block_creator.new_block_realize_, loops);
+
+  // *****************************************************
+  // *           Schedule Replacement & Update           *
+  // *****************************************************
+
+  // Step 1. Substitute the old scope root block with the new scope root block.
+  Block old_scope_root_block = GetRef<Block>(scope_root->StmtAs<BlockNode>());
+  Block new_scope_root_block = BlockReplacer::Replace(
+      old_scope_root_block, rf_body, loops[0], wb_block_creator.new_block_realize_, block_realize,
+      GetRef<For>(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffer);
+  self->Replace(scope_root, new_scope_root_block, {{old_scope_root_block, new_scope_root_block}});
+
+  // Step 2. Update the scope information of the two new blocks.
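+  // The rfactor and write-back blocks are constructed above with affine iter bindings and with
+  // read regions fully covered by their producers, and the surrounding scope stays a stage
+  // pipeline, so the three flags can be set to true directly instead of being recomputed.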
+  std::vector<StmtSRef> new_block_srefs{self->stmt2ref.at(rf_block_creator.new_block_.get()),
+                                        self->stmt2ref.at(wb_block_creator.new_block_.get())};
+  for (const StmtSRef& new_block_sref : new_block_srefs) {
+    BlockInfo& info = self->block_info[new_block_sref];
+    info.affine_binding = true;
+    info.region_cover = true;
+    info.scope->stage_pipeline = true;
+  }
+  return new_block_srefs[0];
+}
+
+/******** FFI ********/
+
+TVM_REGISTER_GLOBAL("tir.schedule.RegisterReducer")
+    .set_body_typed([](PackedFunc combiner_getter, PackedFunc identity_getter) {
+      ReducerRegistry::RegisterReducer(std::move(combiner_getter), std::move(identity_getter));
+    });
+
+}  // namespace tir
+}  // namespace tvm
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 77d17c9dc6..eae04bc76d 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -135,5 +135,9 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline")
 /******** (FFI) reduction ********/
+
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor")
+    .set_body_method<Schedule>(&ScheduleNode::RFactor);
+
 /******** (FFI) blockize & tensorize ********/
 
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index ca61dfea27..41bc644155 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -161,34 +161,6 @@ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new
   sref->stmt = new_stmt;
 }
 
-/*!
- * \brief Get PrimFunc and GlobalVar that the root block belongs to
- * \param mod The IRModule
- * \param root_block The root block of the PrimFunc
- * \param result_g_var The result GlobalVar
- * \return The result PrimFunc where the root block belongs to
- * \note This function returns the pointer instead of ObjectRef to avoid later copy-on-write
- */
-const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
-                                    GlobalVar* result_g_var) {
-  for (const auto& kv : mod->functions) {
-    const GlobalVar& g_var = kv.first;
-    const BaseFunc& base_func = kv.second;
-    if (const auto* func = base_func.as<PrimFuncNode>()) {
-      if (const auto* realize = func->body.as<BlockRealizeNode>()) {
-        if (realize->block.get() == root_block) {
-          *result_g_var = g_var;
-          return func;
-        }
-      }
-    }
-  }
-  LOG(FATAL) << "IndexError: Could not get the correpsonding function in the schedule state of the "
-                "statement:\n"
-             << GetRef<Stmt>(root_block);
-  throw;
-}
-
 /**************** Creation ****************/
 
 /*!
\brief A helper class to create a new ScheduleStateNode from an IRModule */ @@ -737,7 +709,8 @@ class SRefUpdater : public StmtVisitor { void UpdateBlockInfo(const StmtSRef& block_sref) { using TIter = std::unordered_map::iterator; // The caller is responsible for correcting the flags - BlockInfo new_info(BlockScope(GetChildBlocks(self_, block_sref))); + Array child_block_srefs = GetChildBlockSRefOnSRefTree(self_, block_sref); + BlockInfo new_info((BlockScope(child_block_srefs))); std::pair insert_result = self_->block_info.emplace(block_sref, new_info); bool inserted = insert_result.second; BlockInfo& info = insert_result.first->second; @@ -1045,7 +1018,7 @@ void ScheduleStateNode::DebugVerify() const { /**************** BlockInfo-related ****************/ BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const { - const auto* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); auto it = this->block_info.find(block_sref); CHECK(it != this->block_info.end()) << "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 19ed995ac8..d31cea5781 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -98,6 +98,21 @@ namespace tir { << "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \ << "`, but gets: " << (From.defined() ? From->GetTypeKey() : "None") +/*! + * \brief Convert an array of loop StmtSRefs to an array of loops + * \param loop_srefs The loop StmtSRefs to be converted + * \return The conversion result loops + */ +inline Array LoopSRefs2Loops(const Array& loop_srefs) { + Array loops; + loops.reserve(loop_srefs.size()); + for (StmtSRef loop_sref : loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + loops.push_back(GetRef(loop)); + } + return loops; +} + /******** Storage scope ********/ /*! @@ -143,6 +158,18 @@ inline Stmt RemoveFromSeqStmt(const SeqStmt& seq, const Stmt& to_remove) { return SeqStmt::Flatten(new_stmts); } +/*! + * \brief Create a new IterVar for the input For loop, with specified name and type + * \param loop The loop to be created from + * \param name The name of the new IterVar + * \param iter_var_type The type of the new IterVar + * \return The newly created IterVar + */ +inline IterVar IterVarFromLoop(const For& loop, String name, IterVarType iter_var_type) { + return IterVar(Range::FromMinExtent(loop->min, loop->extent), + Var(std::move(name), loop->loop_var.dtype()), iter_var_type); +} + /******** Integer set ********/ /*! diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index f1d816f0ba..6ccc2b18ff 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -84,19 +85,6 @@ using Partition = std::unordered_map; -bool ExprUseVars(PrimExpr expr, const std::unordered_set& vars) { - bool success = false; - PostOrderVisit(expr, [&vars, &success](const ObjectRef& node) { - if (const VarNode* v = node.as()) { - if (vars.count(v)) { - success = true; - return; - } - } - }); - return success; -} - // Select potential candidate IRs that can be partitioned. 
// Rule: // - the range should not be const @@ -200,7 +188,8 @@ class PartitionFinder : public StmtExprVisitor { } void VisitStmt_(const ForNode* op) final { - if (ExprUseVars(op->min, out_vars_) || ExprUseVars(op->extent, out_vars_)) return; + auto f_vset_contains = [this](const VarNode* var) { return out_vars_.count(var); }; + if (UsesVar(op->min, f_vset_contains) || UsesVar(op->extent, f_vset_contains)) return; const VarNode* var = op->loop_var.get(); hint_map_.insert({var, IntSet::Interval(op->min, op->min + op->extent - 1)}); @@ -230,7 +219,7 @@ class PartitionFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::likely())) { PrimExpr cond = op->args[0]; - if (ExprUseVars(cond, std::unordered_set({current_var_.get()}))) { + if (UsesVar(cond, [this](const VarNode* var) {return var == current_var_.get(); })) { // For cond, find out the interval, if exists, in which we can prove that cond is // true. Also find the interval, if exists, in which we can prove that cond is // false. diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index b95681a936..31a6b0f541 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -250,7 +250,7 @@ class WarpAccessRewriter : protected StmtExprMutator { PrimExpr local_index, group; std::tie(local_index, group) = SplitIndexByGroup(op->index); // invariance: local index must do not contain warp id - ICHECK(!ExprUseVar(local_index, warp_index_)) + ICHECK(!UsesVar(local_index, [this](const VarNode* var) { return var == warp_index_.get(); })) << "LowerWarpMemory failed to rewrite load to shuffle for index " << op->index << " local_index=" << local_index; PrimExpr load_value = Load(op->dtype, op->buffer_var, local_index, op->predicate); diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 60a324727f..8b504df120 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -218,7 +218,7 @@ def test_rfactor(): assert set(BF.op.body[0].axis) == set([k2]) assert s[B].op.body[0].axis[0].dom.extent == n assert len(s[B].all_iter_vars) == 2 - # schedule with splot + # schedule with split s = te.create_schedule(B.op) ko, ki = s[B].split(k1, factor=4) xo, xi = s[B].split(B.op.axis[0], factor=8) diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py new file mode 100644 index 0000000000..ebf03d102b --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -0,0 +1,661 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest + +import numpy as np +import tvm +import tvm.testing +from tvm import tir +from tvm.script import ty + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def transformed_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4): + with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + tir.bind(vi, i0) + tir.bind(vj, i1) + tir.bind(vk, (((i2_outer*32) + (i2_inner_outer*4)) + i2_inner_inner)) + tir.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) + tir.writes([C[vi, vj]]) + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vj, vk])) + + +@tvm.script.tir +def matmul_rfactor(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + C_rf = tir.alloc_buffer([4, 128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4): + with tir.block([4, 128, 128, tir.reduce_axis(0, 4), tir.reduce_axis(0, 8)], "update_rf") as [vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer]: + tir.bind(vi2_inner_inner, i2_inner_inner) + tir.bind(vi, i0) + tir.bind(vj, i1) + tir.bind(vi2_outer, i2_outer) + tir.bind(vi2_inner_outer, i2_inner_outer) + with tir.init(): + C_rf[vi2_inner_inner, vi, vj] = 0.0 + C_rf[vi2_inner_inner, vi, vj] = (C_rf[vi2_inner_inner, vi, vj] + (A[vi, (((vi2_outer*32) + (vi2_inner_outer*4)) + vi2_inner_inner)]*B[vj, (((vi2_outer*32) + (vi2_inner_outer*4)) + vi2_inner_inner)])) + + for i0_1, i1_1, i2_inner_inner_1 in tir.grid(128, 128, 4): + with tir.block([tir.reduce_axis(0, 4), 128, 128], "update") as [vi2_inner_inner_1, vi_1, vj_1]: + tir.bind(vi2_inner_inner_1, i2_inner_inner_1) + tir.bind(vi_1, i0_1) + tir.bind(vj_1, i1_1) + with tir.init(): + C[vi_1, vj_1] = 0.0 + C[vi_1, vj_1] = (C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]) + + +@tvm.script.tir +def matmul_not_stage_pipeline(a: ty.handle, b: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [256, 256]) + B = tir.match_buffer(b, [256, 256]) + D = tir.match_buffer(d, [256, 256]) + C = tir.alloc_buffer([256, 256]) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + with tir.block([256, 256], "D") as [vi, vj]: + D[vi, vj] = C[vi, vj] + + +@tvm.script.tir +def matmul_not_same_buffer_access(a: ty.handle, b: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + C = tir.match_buffer(c, (128, 128)) + + with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with tir.init(): + C[vi, vj] = 0.0 + C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj] + + +@tvm.script.tir +def matmul_loop_multiple_children(a: ty.handle, b:ty.handle, c: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128]) + B = tir.match_buffer(b, [128, 128]) + C = tir.match_buffer(c, [128, 128]) + D = tir.match_buffer(d, [128, 128]) + + for k, i, j in tir.grid(128, 128, 128): + with tir.block([tir.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]: + tir.bind(ck, k) + tir.bind(ci, i) + tir.bind(cj, j) + with tir.init(): + C[ci, cj] = 0.0 + C[ci, cj] = C[ci, cj] + A[ci, 
ck] * B[ck, cj] + with tir.block([tir.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]: + tir.bind(dk, k) + tir.bind(di, i) + tir.bind(dj, j) + with tir.init(): + D[di, dj] = 0.0 + D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj] + + +@tvm.script.tir +def square_sum(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + C = tir.match_buffer(c, [16]) + + with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]: + with tir.init(): + C[b] = 0.0 + C[b] = C[b] + A[b, i, j] * A[b, i, j] + + +@tvm.script.tir +def square_sum_rfactor(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + C = tir.match_buffer(c, [16]) + C_rf = tir.alloc_buffer([16, 256]) + + for i0, i1, i2 in tir.grid(16, 256, 256): + with tir.block([256, 16, tir.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]: + tir.bind(vi2, i2) + tir.bind(b, i0) + tir.bind(i, i1) + with tir.init(): + C_rf[b, vi2] = 0.0 + C_rf[b, vi2] = (C_rf[b, vi2] + (A[b, i, vi2]*A[b, i, vi2])) + + for i0_1, i2_1 in tir.grid(16, 256): + with tir.block([tir.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]: + tir.bind(vi2_1, i2_1) + tir.bind(b_1, i0_1) + with tir.init(): + C[b_1] = 0.0 + C[b_1] = (C[b_1] + C_rf[b_1, vi2_1]) + + +@tvm.script.tir +def transformed_square_sum_square_root(a: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + D = tir.match_buffer(d, [16]) + C = tir.alloc_buffer([16]) + + for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1): + with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]: + tir.bind(b, i0) + tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256)) + tir.bind(j, tir.floormod(i1_i2_fused_outer, 256)) + tir.reads([C[b], A[b, i, j]]) + tir.writes([C[b]]) + with tir.init(): + C[b] = 0.0 + C[b] = (C[b] + (A[b, i, j]*A[b, i, j])) + for i0_1 in tir.serial(0, 16): + with tir.block([16], "D") as [b_1]: + tir.bind(b_1, i0_1) + tir.reads([C[b_1]]) + tir.writes([D[b_1]]) + D[b_1] = tir.sqrt(C[b_1], dtype="float32") + + +@tvm.script.tir +def square_sum_square_root_rfactor(a: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, [16, 256, 256]) + D = tir.match_buffer(d, [16]) + C = tir.alloc_buffer([16]) + C_rf = tir.alloc_buffer([1, 16]) + + for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1): + with tir.block([1, 16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C_rf") as [vi1_i2_fused_inner, b, i, j]: + tir.bind(vi1_i2_fused_inner, i1_i2_fused_inner) + tir.bind(b, i0) + tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256)) + tir.bind(j, tir.floormod(i1_i2_fused_outer, 256)) + with tir.init(): + C_rf[vi1_i2_fused_inner, b] = 0.0 + C_rf[vi1_i2_fused_inner, b] = (C_rf[vi1_i2_fused_inner, b] + (A[b, i, j]*A[b, i, j])) + + for i0_1, i1_i2_fused_inner_1 in tir.grid(16, 1): + with tir.block([tir.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]: + tir.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1) + tir.bind(b_1, i0_1) + with tir.init(): + C[b_1] = 0.0 + C[b_1] = (C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]) + + for i0_2 in tir.serial(0, 16): + with tir.block([16], "D") as [b_2]: + tir.bind(b_2, i0_2) + D[b_2] = tir.sqrt(C[b_2], dtype="float32") + + +@tvm.script.tir +def element_wise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + +@tvm.script.tir +def rowsum(a: ty.handle, b: ty.handle) -> None: + A = 
tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_quasi_affine(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for i, k in tir.grid(128, 16): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i) + tir.bind(vk, tir.floordiv(k * k, 2)) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_dominant(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi, vk] = 0.0 + B[vi, vk] = B[vi, vk] + A[vi, vk] + + +@tvm.script.tir +def rowsum_not_serial(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for i in tir.serial(0, 128): + for k in tir.parallel(0, 128): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, i) + tir.bind(vk, k) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_wrong_reduce_pattern1(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 1.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_wrong_reduce_pattern2(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] - A[vi, vk] + + +@tvm.script.tir +def rowsum_transformed(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128,)) + + for io, ii_ko_fused, ki in tir.grid(32, 128, 4): + with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vi, vk]: + tir.bind(vi, io * 4 + tir.floordiv(ii_ko_fused, 32)) + tir.bind(vk, tir.floormod(ii_ko_fused, 32) * 4 + ki) + with tir.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + +@tvm.script.tir +def rowsum_zero_dim(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128]) + B = tir.match_buffer(b, []) + + with tir.block([tir.reduce_axis(0, 128)], "B") as [k]: + with tir.init(): + B[()] = 0.0 + B[()] = B[()] + A[k] + + +@tvm.script.tir +def rowsum_zero_dim_rfactor(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128]) + B = tir.match_buffer(b, []) + B_rf = tir.alloc_buffer([128]) + + with tir.block([128], "B_rf") as [vi0]: + with tir.init(): + B_rf[vi0] = 0.0 + B_rf[vi0] = (B_rf[vi0] + A[vi0]) + + with tir.block([tir.reduce_axis(0, 128)], "B") as [vi0_1]: + with tir.init(): + B[()] = 0.0 + B[()] = (B[()] + B_rf[vi0_1]) + + +@tvm.script.tir +def multiple_reduction_blocks(a: ty.handle, f: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16, 16)) + C = tir.alloc_buffer((16, 16)) + D = tir.alloc_buffer((16, 16)) + E = tir.alloc_buffer((16, 16)) + F = tir.match_buffer(f, (16, 16)) + + for i in tir.serial(0, 16): + for j1 in tir.serial(0, 16): + for k1o, k1i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "C") as [ci, cj, ck]: + tir.bind(ci, i) + tir.bind(cj, j1) + tir.bind(ck, k1o * 4 + k1i) + with tir.init(): + C[ci, cj] = 0.0 + C[ci, cj] = C[ci, cj] + A[ci, cj, ck] 
+ for k2o, k2i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]: + tir.bind(di, i) + tir.bind(dj, j1) + tir.bind(dk, k2o * 4 + k2i) + with tir.init(): + D[di, dj] = 0.0 + D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj] + for j2 in tir.serial(0, 16): + for k3o, k3i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]: + tir.bind(ei, i) + tir.bind(ej, j2) + tir.bind(ek, k3o * 4 + k3i) + with tir.init(): + E[ei, ej] = 0.0 + E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej] + for k4o, k4i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]: + tir.bind(fi, i) + tir.bind(fj, j2) + tir.bind(fk, k4o * 4 + k4i) + with tir.init(): + F[fi, fj] = 0.0 + F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj] + + +@tvm.script.tir +def multiple_reduction_blocks_rfactor(a: ty.handle, f: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16, 16]) + C = tir.alloc_buffer([16, 16]) + D = tir.alloc_buffer([16, 16]) + E = tir.alloc_buffer([16, 16]) + F = tir.match_buffer(f, [16, 16]) + C_rf = tir.alloc_buffer([16, 16, 4]) + + for i, j1, k1o, k1i in tir.grid(16, 16, 4, 4): + with tir.block([4, 16, 16, tir.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]: + tir.bind(vk1o, k1o) + tir.bind(ci, i) + tir.bind(cj, j1) + tir.bind(vk1i, k1i) + with tir.init(): + C_rf[ci, cj, vk1o] = 0.0 + C_rf[ci, cj, vk1o] = (C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o*4) + vk1i)]) + for i_1 in tir.serial(0, 16): + for j1_1 in tir.serial(0, 16): + for k1o_1 in tir.serial(0, 4): + with tir.block([tir.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]: + tir.bind(vk1o_1, k1o_1) + tir.bind(ci_1, i_1) + tir.bind(cj_1, j1_1) + with tir.init(): + C[ci_1, cj_1] = 0.0 + C[ci_1, cj_1] = (C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]) + for k2o, k2i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "D") as [di, dj, dk]: + tir.bind(di, i_1) + tir.bind(dj, j1_1) + tir.bind(dk, ((k2o*4) + k2i)) + with tir.init(): + D[di, dj] = 0.0 + D[di, dj] = ((D[di, dj] + A[di, dj, dk]) + C[di, dj]) + for j2 in tir.serial(0, 16): + for k3o, k3i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "E") as [ei, ej, ek]: + tir.bind(ei, i_1) + tir.bind(ej, j2) + tir.bind(ek, ((k3o*4) + k3i)) + with tir.init(): + E[ei, ej] = 0.0 + E[ei, ej] = ((E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]) + for k4o, k4i in tir.grid(4, 4): + with tir.block([16, 16, tir.reduce_axis(0, 16)], "F") as [fi, fj, fk]: + tir.bind(fi, i_1) + tir.bind(fj, j2) + tir.bind(fk, ((k4o*4) + k4i)) + with tir.init(): + F[fi, fj] = 0.0 + F[fi, fj] = ((F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]) + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_reduction_rfactor_matmul(): + s = tir.Schedule(transformed_matmul, debug_mode=True) + C = s.get_block("update") + _, _, _, _, kii = s.get_loops(C) + rf_block = s.rfactor(kii, 0) + tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) + + func = tvm.build(s.mod["main"], target="llvm") + a_np = np.random.uniform(size=(128, 128)).astype("float32") + b_np = np.random.uniform(size=(128, 128)).astype("float32") + a = tvm.nd.array(a_np) + b = tvm.nd.array(b_np) + c = tvm.nd.array(np.zeros((128, 128), dtype="float32")) + func(a, b, c) + c_np = np.matmul(a_np, b_np.T) + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4) + + +def test_reduction_rfactor_square_sum(): + s = tir.Schedule(square_sum, debug_mode=True) + C = 
+    C = s.get_block("C")
+    _, _, j = s.get_loops(C)
+    rf_block = s.rfactor(j, 1)
+    tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor)
+    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
+
+    func = tvm.build(s.mod["main"], target="llvm")
+    a_np = np.random.uniform(size=(16, 256, 256)).astype("float32")
+    a = tvm.nd.array(a_np)
+    c = tvm.nd.array(np.zeros((16,), dtype="float32"))
+    func(a, c)
+    c_np = np.sum(a_np * a_np, axis=(1, 2))
+    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4)
+
+
+def test_reduction_rfactor_square_sum_square_root():
+    s = tir.Schedule(transformed_square_sum_square_root, debug_mode=True)
+    C = s.get_block("C")
+    _, _, fi = s.get_loops(C)
+    rf_block = s.rfactor(fi, 0)
+    tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor)
+    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
+
+    func = tvm.build(s.mod["main"], target="llvm")
+    a_np = np.random.uniform(size=(16, 256, 256)).astype("float32")
+    a = tvm.nd.array(a_np)
+    d = tvm.nd.array(np.zeros((16,), dtype="float32"))
+    func(a, d)
+    d_np = np.sqrt(np.sum(a_np * a_np, axis=(1, 2)))
+    tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-4, atol=1e-4)
+
+
+def test_reduction_rfactor_loop_multiple_children():
+    s = tir.Schedule(matmul_loop_multiple_children, debug_mode=True)
+    C = s.get_block("C")
+    k, _, _ = s.get_loops(C)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_not_stage_pipeline():
+    s = tir.Schedule(matmul_not_stage_pipeline, debug_mode=True)
+    C = s.get_block("C")
+    _, _, k = s.get_loops(C)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_not_reduction_block1():
+    s = tir.Schedule(element_wise, debug_mode=True)
+    B = s.get_block("B")
+    i, _ = s.get_loops(B)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(i, 0)
+
+
+def test_reduction_rfactor_not_reduction_block2():
+    s = tir.Schedule(rowsum_not_quasi_affine, debug_mode=True)
+    B = s.get_block("B")
+    _, k = s.get_loops(B)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_not_reduction_block3():
+    s = tir.Schedule(rowsum_not_dominant, debug_mode=True)
+    B = s.get_block("B")
+    _, k = s.get_loops(B)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_not_serial_loop():
+    s = tir.Schedule(rowsum_not_serial, debug_mode=True)
+    B = s.get_block("B")
+    _, k = s.get_loops(B)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_not_same_buffer_access():
+    s = tir.Schedule(matmul_not_same_buffer_access, debug_mode=True)
+    C = s.get_block("C")
+    _, _, k = s.get_loops(C)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_factor_axis_range():
+    s = tir.Schedule(transformed_matmul, debug_mode=True)
+    C = s.get_block("update")
+    _, _, _, _, kii = s.get_loops(C)
+    # The "update" block writes a 2-d buffer, so factor_axis values outside the
+    # valid range must be rejected, while -3 is the normalized equivalent of 0
+    # and must reproduce matmul_rfactor.
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(kii, 3)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(kii, -4)
+
+    rf_block = s.rfactor(kii, -3)
+    tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor)
+    assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
+
+    func = tvm.build(s.mod["main"], target="llvm")
+    a_np = np.random.uniform(size=(128, 128)).astype("float32")
+    b_np = np.random.uniform(size=(128, 128)).astype("float32")
+    a = tvm.nd.array(a_np)
+    b = tvm.nd.array(b_np)
+    c = tvm.nd.array(np.zeros((128, 128), dtype="float32"))
+    func(a, b, c)
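+    # After rfactor with the normalized negative factor_axis, the schedule must
+    # still compute the same A * B^T as before.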
+    c_np = np.matmul(a_np, b_np.T)
+    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4)
+
+
+def test_reduction_rfactor_wrong_reduce_pattern1():
+    s = tir.Schedule(rowsum_wrong_reduce_pattern1, debug_mode=True)
+    B = s.get_block("B")
+    _, k = s.get_loops(B)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_wrong_reduce_pattern2():
+    s = tir.Schedule(rowsum_wrong_reduce_pattern2, debug_mode=True)
+    B = s.get_block("B")
+    _, k = s.get_loops(B)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_wrong_loops1():
+    s = tir.Schedule(rowsum, debug_mode=True)
+    B = s.get_block("B")
+    i, _ = s.get_loops(B)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(i, 0)
+
+
+def test_reduction_rfactor_wrong_loops2():
+    s = tir.Schedule(rowsum_transformed, debug_mode=True)
+    B = s.get_block("B")
+    _, _, ki = s.get_loops(B)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(ki, 0)
+
+
+def test_reduction_rfactor_zero_dim():
+    s = tir.Schedule(rowsum_zero_dim, debug_mode=True)
+    B = s.get_block("B")
+    (k,) = s.get_loops(B)
+    s.rfactor(k, 0)
+    tvm.ir.assert_structural_equal(s.mod["main"], rowsum_zero_dim_rfactor)
+
+    func = tvm.build(s.mod["main"], target="llvm")
+    a_np = np.random.uniform(size=(128,)).astype("float32")
+    a = tvm.nd.array(a_np)
+    b = tvm.nd.array(np.array(1, dtype="float32"))
+    func(a, b)
+    b_np = np.array(np.sum(a_np))
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-4, atol=1e-4)
+
+
+def test_reduction_rfactor_outermost_loop_multiple_children():
+    s = tir.Schedule(multiple_reduction_blocks, debug_mode=True)
+    D = s.get_block("D")
+    E = s.get_block("E")
+    F = s.get_block("F")
+    _, _, k2o, k2i = s.get_loops(D)
+    _, _, k3o, k3i = s.get_loops(E)
+    _, _, k4o, k4i = s.get_loops(F)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k2o, 0)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k2i, 0)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k3o, 0)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k3i, 0)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k4o, 0)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k4i, 0)
+
+    C = s.get_block("C")
+    i, j1, k1o, k1i = s.get_loops(C)
+    s.rfactor(k1o, 2)
+    tvm.ir.assert_structural_equal(s.mod["main"], multiple_reduction_blocks_rfactor)
+
+    func = tvm.build(s.mod["main"], target="llvm")
+    a_np = np.random.uniform(size=(16, 16, 16)).astype("float32")
+    a = tvm.nd.array(a_np)
+    f = tvm.nd.array(np.zeros((16, 16), dtype="float32"))
+    func(a, f)
+    # D, E, F each fold the previous buffer into a 16-iteration reduction, so
+    # F = (1 + 16 + 16 ** 2 + 16 ** 3) * rowsum(A) = 4369 * rowsum(A).
+    f_np = np.sum(a_np, axis=2) * 4369
+    tvm.testing.assert_allclose(f.numpy(), f_np, rtol=1e-4, atol=1e-4)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])