diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index d671339fb6..6c72cbeafd 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -282,6 +282,18 @@ class IterSumExpr : public IterMapExpr { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer); +/*! + * \brief Use IterVarMap detector to rewrite and simplify the indices + * + * \param indices The indices to detect pattern for. + * \param input_iters Map from variable to iterator's range. + * \param input_pred The predicate constraints on the input iterators + * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. + * + * \return The indices after rewrite + */ +Array IterMapSimplify(const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, bool require_bijective); /*! * \brief Apply the inverse of the affine transformation to the outputs. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9a09d0ad21..38a15a8143 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -151,6 +151,18 @@ class ScheduleNode : public runtime::Object { * \return The corresponding loop sref */ virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; + /*! + * \brief Get the block srefs corresponding to an array of BlockRVs + * \param block_rvs The BlockRVs to be looked up + * \return The corresponding block srefs + */ + virtual Array GetSRefs(const Array& block_rvs) const = 0; + /*! + * \brief Get the loop srefs corresponding to an array of LoopRVs + * \param loop_rvs The LoopRVs to be looked up + * \return The corresponding loop srefs + */ + virtual Array GetSRefs(const Array& loop_rvs) const = 0; /*! 
* \brief Get the block/loop sref corresponding to the specific statement * \param stmt The statement to be looked up @@ -196,6 +208,25 @@ class ScheduleNode : public runtime::Object { */ virtual Array GetLoops(const BlockRV& block_rv) = 0; /******** Schedule: loops manipulation ********/ + /*! + * \brief Fuse consecutive loops into one. It requires: + * 1) The loops can't have annotations or thread bindings. + * 2) The (i+1)-th loop must be the only child of the i-th loop. + * 3) All loops must start with 0. + * \param loop_rvs The loops to be fused + * \return The fused loop + */ + virtual LoopRV Fuse(const Array& loop_rvs) = 0; + /*! + * \brief Split a specified loop into two or more with the specific factor.It requires: + * 1) The loop can't have annotation or thread binding. + * 2) The loop must start with 0. + * \param loop_rv The loop to be split + * \param factors The tiling factors, and at most one of which is -1, which means that + * factor is inferred. + * \return The loops after splitting + */ + virtual Array Split(const LoopRV& loop_rv, const Array& factors) = 0; /******** Schedule: compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py old mode 100644 new mode 100755 index 2091f4d80a..67350bd109 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -16,7 +16,7 @@ # under the License. 
# pylint: disable=unused-import """The TensorIR schedule class""" -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error @@ -43,7 +43,7 @@ class BlockRV(Object): """A random variable that refers to a block""" -ExprRV = PrimExpr # A random variable that evaluates to an integer +ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # type: ignore # pylint: disable=invalid-name @@ -257,6 +257,133 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]: return _ffi_api_schedule.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member ########## Schedule: loops manipulation ########## + def fuse(self, *loops: List[LoopRV]) -> LoopRV: + """Fuse a list of consecutive loops into one. It requires: + 1) The loops can't have annotations or thread bindings. + 2) The (i+1)-th loop must be the only child of the i-th loop. + 3) All loops must start with 0. + + Parameters + ---------- + *loops : List[LoopRV] + The loops to be fused + + Returns + ---------- + fused_loop : LoopRV + The new loop after fusion + + Examples + -------- + + Before fuse, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_fuse(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do fuse: + + .. code-block:: python + + sch = tir.Schedule(before_fuse, debug_mode=True) + i, j = sch.get_loops(sch.get_block("B")) + sch.fuse(i, j) + print(tvm.script.asscript(sch.mod["main"])) + + After applying fuse, the IR becomes: + + .. 
code-block:: python + + @tvm.script.tir + def after_fuse(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, [128, 128]) + for i0_i1_fused in tir.serial(0, 16384): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, tir.floordiv(i0_i1_fused, 128)) + tir.bind(vj, tir.floormod(i0_i1_fused, 128)) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + return _ffi_api_schedule.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member + + def split( + self, + loop: LoopRV, + factors: List[Optional[ExprRV]], + ) -> List[LoopRV]: + """Split a loop into a list of consecutive loops. It requires: + 1) The loop can't have annotation or thread binding. + 2) The loop must start with 0. + Predicates may be added to ensure the total loop numbers keeps unchanged. + In `factors`, at most one of the factors can be None or -1, + which will be automatically inferred. + Parameters + ---------- + loop : LoopRV + The loop to be split + + factors: List[Optional[ExprRV]] + The splitting factors + + Returns + ---------- + split_loops : List[LoopRV] + The new loops after split + + Examples + -------- + + Before split, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do fuse: + + .. code-block:: python + + sch = tir.Schedule(before_split, debug_mode=True) + i, j = sch.get_loops(sch.get_block("B")) + sch.split(i, factors=[2, 64]) + print(tvm.script.asscript(sch.mod["main"])) + + After applying split, the IR becomes: + + .. 
code-block:: python + + @tvm.script.tir + def after_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, [128, 128]) + for i0_outer, i0_inner, i1 in tir.grid(2, 64, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, ((i0_outer*64) + i0_inner)) + tir.bind(vj, i1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + for i, factor in enumerate(factors): + if factor is None: + factors[i] = -1 + return _ffi_api_schedule.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member + ########## Schedule: compute location ########## def compute_inline(self, block: BlockRV) -> None: """Inline a block into its consumer(s). It requires: diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index e885195b3d..96434ab373 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -515,7 +515,6 @@ class IterMapRewriter : public ExprMutator { */ Optional TryFuseIters(IterSumExpr expr) { if (!is_zero(expr->base)) return NullOpt; - if (expr->args.size() == 1) return expr->args[0]; // select the iterators in order std::vector visited(expr->args.size(), false); std::vector flattened_iters, grouped_iters; @@ -1086,6 +1085,22 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter return NormalizeIterMapToExpr(expr); }); +Array IterMapSimplify(const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, bool require_bijective) { + Analyzer analyzer; + Array rewrite = + DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer); + if (rewrite.empty()) { + return indices; + } else { + Array res; + res.reserve(rewrite.size()); + IterMapToExprNormalizer converter(&analyzer); + for (const auto& expr : rewrite) res.push_back(converter.Convert(expr)); + return res; + } +} + /*! 
* \brief Divider to divide the bindings into two sets of bindings(outer and inner) * such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X. diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index a58e4433da..49ecb85b89 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -799,6 +799,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -881,7 +883,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - + TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y)); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index dd7fee37e2..0d713707a5 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -142,6 +142,12 @@ Array GetLoops(const StmtSRef& block_sref); * \return A list of leaf blocks */ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); +/*! + * \brief Get the direct child Schedulable Stmt (Block and For) + * \param stmt the parent stmt. 
+ * \return the list of child stmts + */ +Array GetChildren(const Stmt& stmt); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index d58dece3c6..7584d36a65 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -298,5 +298,35 @@ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent throw; } +Array GetChildren(const Stmt& stmt) { + /*! \note Nested SeqStmt is not allowed in schedule. */ + Stmt body; + if (const auto* block = stmt.as()) { + body = block->body; + } else if (const auto* loop = stmt.as()) { + body = loop->body; + } else { + LOG(FATAL) << "The Stmt can only be a Block or a For"; + } + if (const auto* seq = body.as()) { + Array ret; + for (const Stmt& child : seq->seq) { + ICHECK(!child->IsInstance()) << "Nested SeqStmt is not allowed in schedule."; + if (child->IsInstance()) { + ret.push_back(child.as()->block); + } else { + ret.push_back(child); + } + } + return ret; + } else { + if (body->IsInstance()) { + return Array{body.as()->block}; + } else { + return Array{body}; + } + } +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 0563d39427..a180bd7613 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -258,6 +258,34 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { } /******** Schedule: loops manipulation ********/ + +LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { + TVM_TIR_SCHEDULE_BEGIN(); + CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)"; + Array loop_srefs = this->GetSRefs(loop_rvs); + StmtSRef fused_sref = tir::Fuse(state_, loop_srefs); + this->state_->DebugVerify(); + return CreateRV(fused_sref); + TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_); + throw; +} + +Array ConcreteScheduleNode::Split(const LoopRV& 
loop_rv, const Array& factor_rvs) { + TVM_TIR_SCHEDULE_BEGIN(); + // Prepare for the splitting + StmtSRef loop_sref = this->GetSRef(loop_rv); + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + Array factors; + factors.reserve(factor_rvs.size()); + for (const ExprRV& factor_rv : factor_rvs) { + factors.push_back(this->Get(factor_rv)); + } + Array results = tir::Split(state_, loop_sref, factors); + return CreateRV(results); + TVM_TIR_SCHEDULE_END("split", this->error_render_level_); + throw; +} + /******** Schedule: compute location ********/ void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 8945fb9ee0..250246a01e 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -68,6 +68,8 @@ class ConcreteScheduleNode : public ScheduleNode { inline PrimExpr Get(const ExprRV& expr_rv) const final; inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; + inline Array GetSRefs(const Array& rvs) const final; + inline Array GetSRefs(const Array& rvs) const final; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } @@ -78,6 +80,8 @@ class ConcreteScheduleNode : public ScheduleNode { BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; /******** Schedule: loops manipulation ********/ + LoopRV Fuse(const Array& loop_rvs) override; + Array Split(const LoopRV& loop_rv, const Array& factors) override; /******** Schedule: compute location ********/ void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; @@ -143,17 +147,22 @@ inline For 
ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { } inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { - auto it = this->symbol_table_.find(expr_rv); - if (it == this->symbol_table_.end()) { - LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << expr_rv; - } - const ObjectRef& obj = (*it).second; - const auto* expr_node = obj.as(); - if (expr_node == nullptr) { - LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: " - << (obj.defined() ? obj->GetTypeKey() : "None"); - } - return GetRef(expr_node); + PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> Optional { + auto it = this->symbol_table_.find(var); + if (it == this->symbol_table_.end()) { + LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; + } + const ObjectRef& obj = (*it).second; + const auto* int_imm = obj.as(); + if (int_imm == nullptr) { + LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: " + << (obj.defined() ? obj->GetTypeKey() : "None"); + } + return Integer(int_imm->value); + }); + PrimExpr simplified = this->analyzer_->Simplify(transformed); + CHECK(is_const_int(transformed)) << "ValueError: The ExprRV does not have a specific value"; + return simplified; } inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { @@ -198,6 +207,24 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { return GetRef(sref); } +template +inline Array GetSRefsHelper(const ConcreteScheduleNode* sch, const Array& rvs) { + Array result; + result.reserve(rvs.size()); + for (const T& rv : rvs) { + result.push_back(sch->GetSRef(rv)); + } + return result; +} + +inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { + return GetSRefsHelper(this, rvs); +} + +inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { + return GetSRefsHelper(this, rvs); +} + /******** Adding/Removing elements in the symbol table ********/ template diff --git 
a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index ab8299e381..4f36910989 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -25,7 +25,27 @@ namespace tvm { namespace tir { /******** Schedule: loops manipulation ********/ - +/*! + * \brief Split a loop into several consecutive loops. It requires: + * 1) The loop can't have annotation or thread binding. + * 2) The loop must start with 0. + * \param self The state of the schedule + * \param loop_sref The sref to the loop being split + * \param factors The splitting factors + * \return An array of srefs to the loops after splitting + */ +TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors); +/*! + * \brief Fuse consecutive loops. It requires: + * 1) The loops can't have annotations or thread bindings. + * 2) The inner loop must be the only child of the outer loop. + * 3) All loops must start with 0. + * \param self The state of the schedule + * \param loop_srefs An array of srefs to the loops to be fused + * \return The sref to the fused loop + */ +TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); /******** Schedule: compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: diff --git a/src/tir/schedule/primitive/fuse_split.cc b/src/tir/schedule/primitive/fuse_split.cc new file mode 100644 index 0000000000..02a8774f91 --- /dev/null +++ b/src/tir/schedule/primitive/fuse_split.cc @@ -0,0 +1,483 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" +namespace tvm { +namespace tir { + +/*! \brief Append a new predicate to each child of type BlockRealize (not recursively) */ +class PredicateUpdater : public StmtMutator { + public: + /*! + * \brief Constructor + * \param predicate The predicate to be appended to BlockRealizeNode + */ + explicit PredicateUpdater(const PrimExpr& predicate, arith::Analyzer* ana) + : predicate_(predicate) { + if (!ana->CanProve(predicate)) { + add_predicate_ = true; + } + } + + private: + // For each direct child of type BlockRealizeNode, append the predicate + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + // We do not recursively do this + if (add_predicate_) { + ObjectPtr n = CopyOnWrite(realize); + n->predicate = n->predicate && predicate_; + return BlockRealize(n); + } else { + return GetRef(realize); + } + } + + /*! \brief The predicate to be added */ + const PrimExpr& predicate_; + /*! \brief whether to add predicate */ + bool add_predicate_; +}; +/*! 
\brief Substitute vars and collect the reuse mapping of opaque blocks */ +class IRSubstituteAndCollectOpaqueBlock : public StmtExprMutator { + public: + explicit IRSubstituteAndCollectOpaqueBlock(std::function(const Var&)> vmap, + Map* opaque_blocks) + : vmap_(vmap), opaque_blocks_(opaque_blocks) {} + + private: + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + Optional ret = vmap_(var); + if (ret.defined()) { + return ret.value(); + } else { + return std::move(var); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + Stmt res = StmtMutator::VisitStmt_(op); + if (op->block->iter_vars.empty()) { + const BlockRealizeNode* realize = res.as(); + opaque_blocks_->Set(op->block, realize->block); + } + return res; + } + + /*! \brief The substitute function */ + std::function(const Var&)> vmap_; + /*! \brief The reuse mapping */ + Map* opaque_blocks_; +}; + +Stmt SubstituteAndCollectOpaqueBlock(Stmt stmt, Map* opaque_blocks, + std::function(const Var&)> vmap) { + return IRSubstituteAndCollectOpaqueBlock(vmap, opaque_blocks)(std::move(stmt)); +} + +/*! 
\brief Simplify the binding of block realize and update the opaque block reuse mapping*/ +class BlockRealizeRewriter : public StmtExprMutator { + public: + explicit BlockRealizeRewriter( + const std::unordered_map& loop_map, + Map* opaque_blocks) + : opaque_blocks_(opaque_blocks) { + loop_map_.insert(loop_map.begin(), loop_map.end()); + } + + private: + Stmt VisitStmt_(const ForNode* op) final { + loop_map_[op->loop_var] = Range::FromMinExtent(op->min, op->extent); + Stmt res = StmtMutator::VisitStmt_(op); + loop_map_.erase(op->loop_var); + return res; + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + // skip opaque block and update mapping + if (op->iter_values.empty()) { + Stmt res = StmtMutator::VisitStmt_(op); + const BlockRealizeNode* realize = res.as(); + for (const std::pair& entry : *opaque_blocks_) { + if (entry.second.same_as(op->block)) { + opaque_blocks_->Set(entry.first, realize->block); + break; + } + } + return res; + } + auto v = arith::IterMapSimplify(op->iter_values, loop_map_, op->predicate, false); + if (v.same_as(op->iter_values)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->iter_values = std::move(v); + return Stmt(n); + } + } + /*! \brief The range of loops */ + std::unordered_map loop_map_; + /*! 
\brief The reuse mapping */ + Map* opaque_blocks_; +}; + +Stmt SimplifyBindings(const Stmt& stmt, const Array& loops, + Map* opaque_blocks) { + std::unordered_map loop_map; + for (const StmtSRef& sref : loops) { + const auto* loop = sref->StmtAs(); + loop_map[loop->loop_var] = Range::FromMinExtent(loop->min, loop->extent); + } + BlockRealizeRewriter rewriter(loop_map, opaque_blocks); + return rewriter(stmt); +} + +class NotLoopError : public ScheduleError { + public: + explicit NotLoopError(IRModule mod, String type) : mod_(mod), type_(type) {} + + String FastErrorString() const final { + return "ScheduleError: this primitive only operates on a " + "loop"; + } + + String DetailRenderTemplate() const final { + return "this primitive only operates on a loop, but the StmtSref passed in points to" + "type: {0} "; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {type_}; } + + IRModule mod_; + String type_; +}; + +class HasAnnotationError : public ScheduleError { + public: + explicit HasAnnotationError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} + + String FastErrorString() const final { + return "ScheduleError: The primitive can't be applied because the loop has annotation"; + } + + String DetailRenderTemplate() const final { + return "The primitive can't be applied because the loop {0} has annotation"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class HasThreadBindingError : public ScheduleError { + public: + explicit HasThreadBindingError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} + + String FastErrorString() const final { + return "ScheduleError: The primitive can't be applied because the loop has thread binding"; + } + + String DetailRenderTemplate() const final { + return "The primitive can't be applied because the loop {0} has thread binding"; + } + + IRModule mod() const final { return 
mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class OuterNotInnerParent : public ScheduleError { + public: + explicit OuterNotInnerParent(IRModule mod, For outer, For inner) + : mod_(mod), outer_(outer), inner_(inner) {} + + String FastErrorString() const final { + return "ScheduleError: the outer loop is not the parent of the inner loop"; + } + + String DetailRenderTemplate() const final { + return "The loops can't be fused because the outer loop {0} is not the parent of the inner " + "loop {1}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {outer_, inner_}; } + + IRModule mod_; + For outer_; + For inner_; +}; + +class NotOnlyChildError : public ScheduleError { + public: + explicit NotOnlyChildError(IRModule mod, For outer, For inner) + : mod_(mod), outer_(outer), inner_(inner) {} + + String FastErrorString() const final { + return "ScheduleError: the inner loop is not the only child of outer loop"; + } + + String DetailRenderTemplate() const final { + return "The loops can't be fused because the inner loop {1} is not the only child of outer " + "loop {0}."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {outer_, inner_}; } + + IRModule mod_; + For outer_; + For inner_; +}; + +class LoopNotStartWithZeroError : public ScheduleError { + public: + explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} + + String FastErrorString() const final { + return "ScheduleError: the primitive only supports loop starting with 0"; + } + + String DetailRenderTemplate() const final { + return "The loop {0} does not start with 0, which is not supported"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class NotSingleInferFactorError : public ScheduleError { + public: + 
explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} + + String FastErrorString() const final { + return "ScheduleError: only one factor can be specified as -1 or none"; + } + + String DetailRenderTemplate() const final { + return "Only one factor can be specified as -1 or none"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; +}; + +class WrongFactorProductError : public ScheduleError { + public: + explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} + + String FastErrorString() const final { + return "ScheduleError: The product of factors is not larger than or equal to the extent of " + "loop"; + } + + String DetailRenderTemplate() const final { + return "The product of factors is not larger than or equal to the extent of loop {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +Array Split(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors) { + // Invariance + // - The total repeat number has not changed for each direct child block with updating predicate. + // - The execution order has not changed. (The block executes with the same args and the same + // order with before. + // Step 1. 
Check correctness + GetScopeRootAndCheckStagePipeline(self, loop_sref); + const auto* loop = loop_sref->StmtAs(); + if (loop == nullptr) { + throw NotLoopError(self->mod, loop_sref->stmt->GetTypeKey()); + } + if (!loop->annotations.empty()) { + throw HasAnnotationError(self->mod, GetRef(loop)); + } + if (loop->thread_binding.defined()) { + throw HasThreadBindingError(self->mod, GetRef(loop)); + } + // Currently, loops not starting with 0 are not supported + arith::Analyzer analyzer; + if (!analyzer.CanProve(loop->min == 0)) { + throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + } + PrimExpr tot_length = 1; + int infer_index = -1; + for (size_t i = 0; i < factors.size(); i++) { + if (!analyzer.CanProve(factors[i] == -1)) { + tot_length *= factors[i]; + } else { + if (infer_index != -1) { + throw NotSingleInferFactorError(self->mod); + } else { + infer_index = i; + } + } + } + // Step 2. infer factors if needed + Array inferred_factors(factors); + if (infer_index != -1) { + inferred_factors.Set(infer_index, + analyzer.Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); + } else { + if (!analyzer.CanProve(tot_length >= loop->extent)) { + throw WrongFactorProductError(self->mod, GetRef(loop)); + } + } + // Step 3. 
Replace all occurrence of the original loop var with new variables + std::vector new_loop_vars; + new_loop_vars.reserve(inferred_factors.size()); + for (size_t i = 0; i < inferred_factors.size(); i++) { + new_loop_vars.push_back(loop->loop_var.copy_with_suffix("_" + std::to_string(i))); + } + PrimExpr substitute_value = 0; + for (size_t i = 0; i < inferred_factors.size(); i++) { + substitute_value *= inferred_factors[i]; + substitute_value += new_loop_vars[i]; + } + Map opaque_block_reuse; + auto substitute_function = [&](const Var& v) -> Optional { + if (v.same_as(loop->loop_var)) { + return substitute_value; + } else { + return NullOpt; + } + }; + Stmt new_loop_body = + SubstituteAndCollectOpaqueBlock(loop->body, &opaque_block_reuse, substitute_function); + for (size_t i = 0; i < inferred_factors.size(); i++) { + analyzer.Bind(new_loop_vars[i], Range::FromMinExtent(0, inferred_factors[i])); + } + // Step 4. Update predicate to guard the loop + PrimExpr predicate = substitute_value < loop->extent; + new_loop_body = PredicateUpdater(predicate, &analyzer)(new_loop_body); + // Step 5. 
Generate nested loops to replace the original loop and simplify the binding + Stmt outer_stmt = new_loop_body; + for (int i = inferred_factors.size() - 1; i >= 0; i--) { + outer_stmt = For(new_loop_vars[i], 0, inferred_factors[i], loop->kind, outer_stmt); + } + + outer_stmt = + Downcast(SimplifyBindings(outer_stmt, GetLoops(loop_sref), &opaque_block_reuse)); + self->Replace(loop_sref, outer_stmt, opaque_block_reuse); + Array result_srefs; + result_srefs.reserve(inferred_factors.size()); + for (size_t i = 0; i < inferred_factors.size(); i++) { + result_srefs.push_back(self->stmt2ref.at(outer_stmt.get())); + const ForNode* outer_loop = outer_stmt.as(); + ICHECK(outer_loop); + outer_stmt = outer_loop->body; + } + return result_srefs; +} + +StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { + // Invariance + // - The total repeat number has not changed for each direct child block. + // - The execution order has not changed. (The block executes with the same + // args and the same order with before.) + std::vector loops; + loops.reserve(loop_srefs.size()); + StmtSRef outer_sref{nullptr}; + const ForNode* outer_loop = nullptr; + arith::Analyzer analyzer; + // Step 1. 
check correctness + GetScopeRootAndCheckStagePipeline(self, loop_srefs[0]); + for (const StmtSRef& sref : loop_srefs) { + const auto* loop = sref->StmtAs(); + if (loop == nullptr) { + throw NotLoopError(self->mod, sref->stmt->GetTypeKey()); + } + if (!loop->annotations.empty()) { + throw HasAnnotationError(self->mod, GetRef(loop)); + } + if (loop->thread_binding.defined()) { + throw HasThreadBindingError(self->mod, GetRef(loop)); + } + if (outer_sref.defined()) { + if (sref->parent != outer_sref.get()) { + throw OuterNotInnerParent(self->mod, GetRef(outer_loop), GetRef(loop)); + } + Array outer_children = GetChildren(GetRef(outer_loop)); + if (outer_children.size() != 1 || outer_children[0].get() != loop) { + throw NotOnlyChildError(self->mod, GetRef(outer_loop), GetRef(loop)); + } + } + outer_sref = sref; + outer_loop = loop; + if (!analyzer.CanProve(loop->min == 0)) { + throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + } + loops.push_back(loop); + } + // Step 2. Create fused loop var and replace the original loop vars + std::string suffix; + for (size_t i = 1; i < loops.size(); i++) { + suffix += "_" + loops[i]->loop_var->name_hint; + } + suffix += "_fused"; + Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); + Array substitute_value; + substitute_value.resize(loops.size()); + PrimExpr tot = fused_var; + for (int i = loops.size() - 1; i >= 0; i--) { + substitute_value.Set(i, floormod(tot, loops[i]->extent)); + tot = floordiv(tot, loops[i]->extent); + } + Stmt loop_body = loops.back()->body; + Map opaque_block_reuse; + auto substitute_function = [&](const Var& v) -> Optional { + for (size_t i = 0; i < loops.size(); i++) { + if (v.same_as(loops[i]->loop_var)) { + return substitute_value[i]; + } + } + return NullOpt; + }; + Stmt new_loop_body = + SubstituteAndCollectOpaqueBlock(loop_body, &opaque_block_reuse, substitute_function); + // Step 3. 
Generate a loop to replace the original loops + PrimExpr fused_min = 0; + PrimExpr fused_extent = 1; + for (size_t i = 0; i < loops.size(); i++) { + fused_extent *= loops[i]->extent; + } + fused_extent = analyzer.Simplify(fused_extent); + For fused_loop = For(fused_var, fused_min, fused_extent, loops[0]->kind, new_loop_body); + fused_loop = + Downcast(SimplifyBindings(fused_loop, GetLoops(loop_srefs[0]), &opaque_block_reuse)); + self->Replace(loop_srefs[0], fused_loop, opaque_block_reuse); + return self->stmt2ref.at(fused_loop.get()); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 115f7936f6..77d17c9dc6 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -123,6 +123,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") .set_body_method(&ScheduleNode::GetLoops); /******** (FFI) loops manipulation ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); /******** (FFI) compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") .set_body_method(&ScheduleNode::ComputeInline); diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py new file mode 100644 index 0000000000..56f1a4a3ff --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -0,0 +1,469 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import ty + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i, j, k in tir.grid(128, 128, n): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic_fused(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i_j_k_fused in tir.serial(0, (n * 16384)): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(i_j_k_fused, (n * 128))) + tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, n), 128)) + tir.bind(vk, tir.floormod(i_j_k_fused, n)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic_split(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i, j, k0, k1 in tir.grid(128, 128, 10, tir.floordiv((n + 9), 10)): + with 
tir.block([128, 128, n], "B") as [vi, vj, vk]: + tir.where((((k0 * tir.floordiv((n + 9), 10)) + k1) < n)) + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, ((k0 * tir.floordiv((n + 9), 10)) + k1)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_seq(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + C = tir.alloc_buffer((128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "C") as [vi, vj, vk]: + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_anno(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128, annotations={"useless_annotation": True}): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_thread_binding(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.thread_binding(0, 128, thread="threadIdx.x"): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_starting_point(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(10, 
128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j, k in tir.grid(128, 128, 128): + with tir.block([], "opaque"): + tir.reads([A[i, j, k]]) + tir.writes([B[i, j, k]]) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_fused(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for fused in tir.serial(0, 2097152): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(fused, 16384)) + tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128)) + tir.bind(vk, tir.floormod(fused, 128)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_case0(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, k1, k2 in tir.grid(2, 1, 64, 4, 32, 16, 8): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, ((i1 * 64) + i3)) + tir.bind(vj, ((j1 * 32) + j2)) + tir.bind(vk, ((k1 * 8) + k2)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_case1(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, j3, k1, k2, k3 in tir.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): + with 
tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i1 * 64 + i3) + tir.bind(vj, j1 * 64 + j3) + tir.bind(vk, k1 * 64 + k3) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_with_predicate(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + for i0, i1, i2, j0, j1, k0, k1 in tir.grid(1000, 2, 3, 1, 129, 3, 43): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.where( + ( + ((((((i0 * 2) + i1) * 3) + i2) < 128) and (((j0 * 129) + j1) < 128)) + and (((k0 * 43) + k1) < 128) + ) + ) + tir.bind(vi, (((i0 * 6) + (i1 * 3)) + i2)) + tir.bind(vj, j1) + tir.bind(vk, ((k0 * 43) + k1)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_fuse_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + for i_j_k_fused in tir.serial(0, 2097152): + with tir.block([], "opaque"): + tir.reads( + [ + A[ + tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), + tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), + tir.floormod(i_j_k_fused, 128), + ] + ] + ) + tir.writes( + [ + B[ + tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), + tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), + tir.floormod(i_j_k_fused, 128), + ] + ] + ) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(i_j_k_fused, 16384)) + tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, 128), 128)) + tir.bind(vk, tir.floormod(i_j_k_fused, 128)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, 
[128, 128, 128]) + + for i0, i1, j, k in tir.grid(8, 16, 128, 128): + with tir.block([], "opaque"): + tir.reads([A[i0 * 16 + i1, j, k]]) + tir.writes([B[i0 * 16 + i1, j, k]]) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i0 * 16 + i1) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16], "float32") + B = tir.match_buffer(b, [16, 16], "float32") + with tir.block([16, 16], "A") as [vi, vj]: + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, vi * 16 + vj, 1) + with tir.block([16, 16], "B") as [vi, vj]: + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +@tvm.script.tir +def opaque_access_fused(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16]) + B = tir.match_buffer(b, [16, 16]) + for i_j_fused in tir.serial(0, 256): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 16)) + tir.bind(vj, tir.floormod(i_j_fused, 16)) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, ((vi * 16) + vj), 1, 1) + for i_j_fused in tir.serial(0, 256): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 16)) + tir.bind(vj, tir.floormod(i_j_fused, 16)) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate( + tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") + ) + + +@tvm.script.tir +def opaque_access_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + B = tir.match_buffer(b, (16, 16)) + for i, j0, j1 in tir.grid(16, 4, 4): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, ((j0 * 4) + j1)) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, ((vi * 16) + vj), 1, 1) + for i, 
j0, j1 in tir.grid(16, 4, 4): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, ((j0 * 4) + j1)) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate( + tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") + ) + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_fuse(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.fuse(i, j, k) + assert sch.state._get_cached_flags(sch.get_sref(block_b)).stage_pipeline + tvm.ir.assert_structural_equal(elementwise_fused, sch.mod["main"]) + + +def test_split(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[2, 1, 64]) + sch.split(j, factors=[4, 32]) + sch.split(k, factors=[16, 8]) + assert sch.state._get_cached_flags(sch.get_sref(block_b)).stage_pipeline + tvm.ir.assert_structural_equal(elementwise_split_case0, sch.mod["main"]) + + +def test_split_with_inferred_factor(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[None, 1, 64]) + sch.split(j, factors=[2, None, 64]) + sch.split(k, factors=[2, 1, -1]) + tvm.ir.assert_structural_equal(elementwise_split_case1, sch.mod["main"]) + + +def test_split_with_predicate(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[1000, 2, 3]) + sch.split(j, factors=[None, 129]) + sch.split(k, factors=[3, None]) + assert sch.state._get_cached_flags(sch.get_sref(block_b)).stage_pipeline + tvm.ir.assert_structural_equal(elementwise_split_with_predicate, sch.mod["main"]) + + +def test_fuse_fail_not_only_child(): + sch = tir.Schedule(elementwise_with_seq, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with 
pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + + +def test_fuse_split_fail_with_annotation(): + sch = tir.Schedule(elementwise_with_anno, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_split_fail_not_start_with_zero(): + sch = tir.Schedule(elementwise_with_starting_point, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_with_opaque_block(): + sch = tir.Schedule(elementwise_with_opaque_block, debug_mode=True) + block_opaque = sch.get_block("opaque") + i, j, k = sch.get_loops(block_opaque) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_fuse_with_opaque_block, sch.mod["main"]) + + +def test_fuse_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mode=True) + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + sch.fuse(i, j) + block_b = sch.get_block("B") + i, j = sch.get_loops(block_b) + sch.fuse(i, j) + tvm.ir.assert_structural_equal(opaque_access_fused, sch.mod["main"]) + + +def test_split_with_opaque_block(): + sch = tir.Schedule(elementwise_with_opaque_block, debug_mode=True) + block_opaque = sch.get_block("opaque") + i, j, k = sch.get_loops(block_opaque) + sch.split(i, factors=[None, 16]) + tvm.ir.assert_structural_equal(elementwise_split_with_opaque_block, sch.mod["main"]) + + +def test_split_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mode=True) + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + sch.split(j, factors=[None, 4]) + block_b = sch.get_block("B") + i, j = sch.get_loops(block_b) + sch.split(j, factors=[None, 4]) + tvm.ir.assert_structural_equal(opaque_access_split, sch.mod["main"]) + + +def 
test_fuse_split_fail_with_thread_binding(): + sch = tir.Schedule(elementwise_with_thread_binding, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_symbolic(): + sch = tir.Schedule(elementwise_symbolic, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_symbolic_fused, sch.mod["main"]) + + +def test_split_symbolic(): + sch = tir.Schedule(elementwise_symbolic, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(k, factors=[10, None]) + tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"]) + + +if __name__ == "__main__": + test_fuse() + test_fuse_with_opaque_block() + test_fuse_with_opaque_access() + test_fuse_symbolic() + test_split() + test_split_with_inferred_factor() + test_split_with_opaque_block() + test_split_with_opaque_access() + test_split_with_predicate() + test_split_symbolic() + test_fuse_fail_not_only_child() + test_fuse_split_fail_with_annotation() + test_fuse_split_fail_not_start_with_zero() + test_fuse_split_fail_with_thread_binding()