diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 68900e107d7c..d3ecd8a1135b 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -303,6 +303,18 @@ class ScheduleNode : public runtime::Object {
    * \param ordered_loop_rvs The loops in the new order
    */
   virtual void Reorder(const Array<LoopRV>& ordered_loop_rvs) = 0;
+  /*!
+   * \brief Create a new unit loop on top of the specific block.
+   * \param block_rv The block above which the new loop is created
+   * \return The new loop created
+   */
+  virtual LoopRV AddUnitLoop(const BlockRV& block_rv) = 0;
+  /*!
+   * \brief Create a new unit loop on top of the specific loop.
+   * \param loop_rv The loop above which the new loop is created
+   * \return The new loop created
+   */
+  virtual LoopRV AddUnitLoop(const LoopRV& loop_rv) = 0;
   /******** Schedule: Manipulate ForKind ********/
   /*!
    * \brief Parallelize the input loop. It requires:
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 4179088aa534..d225280b655f 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -15,19 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 """The TensorIR schedule class"""
-from typing import Callable, Dict, List, Optional, Union, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 from tvm._ffi import register_object as _register_object
 from tvm.error import TVMError, register_error
 from tvm.ir import IRModule, PrimExpr
 from tvm.runtime import Object, String
-from tvm.tir import Block, FloatImm, For, IntImm, PrimFunc, Buffer
-from ..function import IndexMap
+from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc
+from ..function import IndexMap
 from . import _ffi_api
+from ._type_checker import type_checked
 from .state import ScheduleState, StmtSRef, _parse_debug_mask, _parse_mod
 from .trace import Trace
-from ._type_checker import type_checked
 
 
 @register_error
@@ -685,6 +685,62 @@ def after_reorder(a: T.handle, b: T.handle) -> None:
         """
         _ffi_api.ScheduleReorder(self, ordered_loops)  # type: ignore # pylint: disable=no-member
 
+    @type_checked
+    def add_unit_loop(self, block_or_loop: Union[LoopRV, BlockRV]) -> LoopRV:
+        """Create a new unit loop on top of the specific block or loop.
+
+        Parameters
+        ----------
+        block_or_loop : Union[LoopRV, BlockRV]
+            The block or loop above which the new loop is created
+
+        Returns
+        -------
+        new_loop : LoopRV
+            The new unit loop
+
+        Examples
+        --------
+
+        Before add_unit_loop, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def before_add_unit_loop(
+                A: T.Buffer[(), "int32"],
+                B: T.Buffer[(), "int32"],
+                C: T.Buffer[(), "int32"],
+            ) -> None:
+                with T.block("C"):
+                    vi = T.axis.spatial(1, 0)
+                    C[()] = A[()] + B[()]
+
+        Create the schedule and do add-unit-loop:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_add_unit_loop)
+            sch.add_unit_loop(sch.get_block("C"))
+            print(sch.mod["main"].script())
+
+        After applying add-unit-loop, the IR becomes:
+
+        .. code-block:: python
+
+            @T.prim_func
+            def after_add_unit_loop(
+                A: T.Buffer[(), "int32"],
+                B: T.Buffer[(), "int32"],
+                C: T.Buffer[(), "int32"],
+            ) -> None:
+                for u in T.serial(1):
+                    with T.block("C"):
+                        vi = T.axis.spatial(1, 0)
+                        C[()] = A[()] + B[()]
+        """
+        return _ffi_api.ScheduleAddUnitLoop(self, block_or_loop)  # type: ignore # pylint: disable=no-member
+
     ########## Schedule: Manipulate ForKind ##########
 
     @type_checked
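Note: since `add_unit_loop` accepts either a block or a loop, the returned `LoopRV` can itself be passed back in to stack unit loops. A minimal usage sketch, reusing `before_add_unit_loop` from the docstring above (the resulting `T.grid(1, 1)` shape matches `test_add_unit_loop_above_loop` at the end of this patch):

.. code-block:: python

    from tvm import tir

    sch = tir.Schedule(before_add_unit_loop)
    outer = sch.add_unit_loop(sch.get_block("C"))  # unit loop above the block
    sch.add_unit_loop(outer)                       # second unit loop above the first
    print(sch.mod["main"].script())                # body is now: for u1, u2 in T.grid(1, 1): ...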
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 590a0f002595..051bd4250625 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -453,6 +453,24 @@ void ConcreteScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
   this->state_->DebugVerify();
 }
 
+LoopRV ConcreteScheduleNode::AddUnitLoop(const BlockRV& block_rv) {
+  LoopRV result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(block_rv)));
+  TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);
+  this->state_->DebugVerify();
+  return result;
+}
+
+LoopRV ConcreteScheduleNode::AddUnitLoop(const LoopRV& loop_rv) {
+  LoopRV result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = CreateRV<LoopRV>(tir::AddUnitLoop(state_, GetSRef(loop_rv)));
+  TVM_TIR_SCHEDULE_END("add-unit-loop", this->error_render_level_);
+  this->state_->DebugVerify();
+  return result;
+}
+
 /******** Schedule: Manipulate ForKind ********/
 
 void ConcreteScheduleNode::Parallel(const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index 70c0265611c3..11d68694a1fe 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -99,6 +99,8 @@ class ConcreteScheduleNode : public ScheduleNode {
   LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
   void Reorder(const Array<LoopRV>& ordered_loop_rvs) override;
+  LoopRV AddUnitLoop(const BlockRV& block_rv) override;
+  LoopRV AddUnitLoop(const LoopRV& loop_rv) override;
   /******** Schedule: Manipulate ForKind ********/
   void Parallel(const LoopRV& loop_rv) override;
   void Vectorize(const LoopRV& loop_rv) override;
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index f4dba69c6b15..af0f417e4cf5 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -186,6 +186,16 @@ TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs);
  */
 TVM_DLL void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs);
 
+/*!
+ * \brief Create a new unit loop on top of the specific block or loop.
+ * \param self The state of the schedule
+ * \param sref The block/loop above which the new unit loop is created
+ * \return The new unit loop
+ */
+TVM_DLL StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref);
+
 /******** Schedule: Manipulate ForKind ********/
 /*!
  * \brief Parallelize the input loop. It requires:
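One motivation for this primitive: a zero-dimensional block has no loops, so there is nothing to bind to a GPU thread axis. A hedged sketch of that follow-up use, relying on the existing `bind` primitive (which is not part of this diff):

.. code-block:: python

    from tvm import tir

    sch = tir.Schedule(before_add_unit_loop)  # the zero-dim example above
    loop = sch.add_unit_loop(sch.get_block("C"))
    sch.bind(loop, "blockIdx.x")  # the fresh unit loop can now carry a thread binding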
diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc
index 5315b139f0f6..66e29518ca5e 100644
--- a/src/tir/schedule/primitive/loop_transformation.cc
+++ b/src/tir/schedule/primitive/loop_transformation.cc
@@ -698,6 +698,43 @@ void Reorder(ScheduleState self, const Array<StmtSRef>& ordered_loop_srefs) {
   self->Replace(GetRef<For>(top), new_loop, {});
 }
 
+StmtSRef AddUnitLoop(ScheduleState self, StmtSRef sref) {
+  if (sref->stmt->IsInstance<ForNode>()) {
+    For new_loop(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<Stmt>(sref->stmt));
+    self->Replace(sref, new_loop, {});
+    return self->stmt2ref.at(new_loop.get());
+  }
+  class NewLoopCreator : public StmtMutator {
+   public:
+    explicit NewLoopCreator(const StmtNode* src_block) : src_block_(src_block) {}
+
+    Stmt VisitStmt_(const BlockRealizeNode* realize) final {
+      if (realize->block.get() == src_block_) {
+        new_loop_ =
+            For(Var("u", DataType::Int(32)), 0, 1, ForKind::kSerial, GetRef<Stmt>(realize));
+        return new_loop_;
+      }
+      return StmtMutator::VisitStmt_(realize);
+    }
+
+    const StmtNode* src_block_;
+    For new_loop_{nullptr};
+  };
+
+  CHECK(sref->parent != nullptr) << "ValueError: Cannot add loops on top of the root block";
+  StmtSRef parent_sref = GetRef<StmtSRef>(sref->parent);
+  NewLoopCreator creator(sref->stmt);
+  Stmt new_stmt = creator(GetRef<Stmt>(parent_sref->stmt));
+  if (new_stmt->IsInstance<ForNode>()) {
+    self->Replace(parent_sref, std::move(new_stmt), {});
+  } else {
+    Block old_parent_block = GetRef<Block>(parent_sref->StmtAs<BlockNode>());
+    Block new_parent_block = Downcast<Block>(new_stmt);
+    self->Replace(parent_sref, new_stmt, {{old_parent_block, new_parent_block}});
+  }
+  return self->stmt2ref.at(creator.new_loop_.get());
+}
+
 /******** InstructionKind Registration ********/
 
 struct SplitTraits : public UnpackedInstTraits<SplitTraits> {
@@ -800,9 +837,41 @@ struct ReorderTraits : public UnpackedInstTraits<ReorderTraits> {
   friend struct ::tvm::tir::UnpackedInstTraits;
 };
 
+struct AddUnitLoopTraits : public UnpackedInstTraits<AddUnitLoopTraits> {
+  static constexpr const char* kName = "AddUnitLoop";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static LoopRV UnpackedApplyToSchedule(Schedule sch, ObjectRef rv) {
+    if (const auto* block = rv.as<BlockRVNode>()) {
+      return sch->AddUnitLoop(GetRef<BlockRV>(block));
+    } else if (const auto* loop = rv.as<LoopRVNode>()) {
+      return sch->AddUnitLoop(GetRef<LoopRV>(loop));
+    } else {
+      LOG(FATAL) << "TypeError: AddUnitLoop expects a loop or block";
+      throw;
+    }
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String rv) {
+    PythonAPICall py("add_unit_loop");
+    py.Input("block_or_loop", rv);
+    py.SingleOutput(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 TVM_REGISTER_INST_KIND_TRAITS(SplitTraits);
 TVM_REGISTER_INST_KIND_TRAITS(FuseTraits);
 TVM_REGISTER_INST_KIND_TRAITS(ReorderTraits);
+TVM_REGISTER_INST_KIND_TRAITS(AddUnitLoopTraits);
 
 }  // namespace tir
 }  // namespace tvm
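`UnpackedAsPython` above determines how the instruction is rendered when a trace is printed. A sketch of the expected rendering (variable names are illustrative; the exact output may differ):

.. code-block:: python

    from tvm import tir

    sch = tir.Schedule(before_add_unit_loop)  # tir.Schedule traces by default
    sch.add_unit_loop(sch.get_block("C"))
    print(sch.trace)
    # Roughly:
    #   b0 = sch.get_block(name="C", func_name="main")
    #   l1 = sch.add_unit_loop(block_or_loop=b0)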
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 3880d0b19eeb..372d94a15025 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -153,6 +153,18 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&Sche
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReorder")
     .set_body_method<Schedule>(&ScheduleNode::Reorder);
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAddUnitLoop")
+    .set_body_typed([](Schedule self, ObjectRef rv) -> LoopRV {
+      if (const auto* loop_rv = rv.as<LoopRVNode>()) {
+        return self->AddUnitLoop(GetRef<LoopRV>(loop_rv));
+      } else if (const auto* block_rv = rv.as<BlockRVNode>()) {
+        return self->AddUnitLoop(GetRef<BlockRV>(block_rv));
+      } else {
+        LOG(FATAL) << "TypeError: Cannot evaluate the random variable of type: " << rv->GetTypeKey()
+                   << ". Its value is: " << rv;
+        throw;
+      }
+    });
 /******** (FFI) Manipulate ForKind ********/
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleParallel")
     .set_body_method<Schedule>(&ScheduleNode::Parallel);
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index d2f627edfd11..95a10e26ac2f 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -198,6 +198,28 @@ void TracedScheduleNode::Reorder(const Array<LoopRV>& ordered_loop_rvs) {
                                       /*outputs=*/{}));
 }
 
+LoopRV TracedScheduleNode::AddUnitLoop(const BlockRV& block_rv) {
+  LoopRV result = ConcreteScheduleNode::AddUnitLoop(block_rv);
+
+  static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{block_rv},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{result}));
+  return result;
+}
+
+LoopRV TracedScheduleNode::AddUnitLoop(const LoopRV& loop_rv) {
+  LoopRV result = ConcreteScheduleNode::AddUnitLoop(loop_rv);
+
+  static const InstructionKind& kind = InstructionKind::Get("AddUnitLoop");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{loop_rv},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{result}));
+  return result;
+}
+
 /******** Schedule: Manipulate ForKind ********/
 
 void TracedScheduleNode::Parallel(const LoopRV& loop_rv) {
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index ba4a4b99cbb2..25bf3d4871ae 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -63,6 +63,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   LoopRV Fuse(const Array<LoopRV>& loop_rvs) final;
   Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factor_rvs) final;
   void Reorder(const Array<LoopRV>& ordered_loop_rvs) final;
+  LoopRV AddUnitLoop(const BlockRV& block_rv) final;
+  LoopRV AddUnitLoop(const LoopRV& loop_rv) final;
   /******** Schedule: Manipulate ForKind ********/
   void Parallel(const LoopRV& loop_rv) final;
   void Vectorize(const LoopRV& loop_rv) final;
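Because both traced overloads append an `AddUnitLoop` instruction, a recorded trace can be serialized and replayed on a fresh schedule. A sketch using the `Trace` helpers (assumed available; this mirrors what `verify_trace_roundtrip` does in the tests below):

.. code-block:: python

    import tvm
    from tvm import tir
    from tvm.tir.schedule import Trace

    sch = tir.Schedule(before_add_unit_loop)
    sch.add_unit_loop(sch.get_block("C"))

    # Replay the recorded trace on a fresh schedule and compare modules.
    replayed = tir.Schedule(before_add_unit_loop)
    Trace.apply_json_to_schedule(sch.trace.as_json(), replayed)
    assert tvm.ir.structural_equal(sch.mod, replayed.mod)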
diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py
index 16eef57c4748..d70748bc8a03 100644
--- a/tests/python/unittest/test_tir_schedule_split_fuse.py
+++ b/tests/python/unittest/test_tir_schedule_split_fuse.py
@@ -524,5 +524,63 @@ def test_fuse_not_affine():
     verify_trace_roundtrip(sch=sch, mod=elementwise_not_affine)
 
 
+def test_add_unit_loop_above_block():
+    @T.prim_func
+    def zero_dim(
+        A: T.Buffer[(), "int32"],
+        B: T.Buffer[(), "int32"],
+        C: T.Buffer[(), "int32"],
+    ) -> None:
+        with T.block("C"):
+            vi = T.axis.spatial(1, 0)
+            C[()] = A[()] + B[()]
+
+    @T.prim_func
+    def zero_dim_added(
+        A: T.Buffer[(), "int32"],
+        B: T.Buffer[(), "int32"],
+        C: T.Buffer[(), "int32"],
+    ) -> None:
+        for u in range(1):
+            with T.block("C"):
+                vi = T.axis.spatial(1, 0)
+                C[()] = A[()] + B[()]
+
+    sch = tir.Schedule(zero_dim, debug_mask="all")
+    block = sch.get_block("C")
+    sch.add_unit_loop(block)
+    tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"])
+
+
+def test_add_unit_loop_above_loop():
+    @T.prim_func
+    def zero_dim(
+        A: T.Buffer[(), "int32"],
+        B: T.Buffer[(), "int32"],
+        C: T.Buffer[(), "int32"],
+    ) -> None:
+        for u in range(1):
+            with T.block("C"):
+                vi = T.axis.spatial(1, 0)
+                C[()] = A[()] + B[()]
+
+    @T.prim_func
+    def zero_dim_added(
+        A: T.Buffer[(), "int32"],
+        B: T.Buffer[(), "int32"],
+        C: T.Buffer[(), "int32"],
+    ) -> None:
+        for u1, u2 in T.grid(1, 1):
+            with T.block("C"):
+                vi = T.axis.spatial(1, 0)
+                C[()] = A[()] + B[()]
+
+    sch = tir.Schedule(zero_dim, debug_mask="all")
+    block = sch.get_block("C")
+    (loop,) = sch.get_loops(block)
+    sch.add_unit_loop(loop)
+    tvm.ir.assert_structural_equal(zero_dim_added, sch.mod["main"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()
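A complementary test one might add (a sketch, not part of this diff): round-trip the recorded trace with `verify_trace_roundtrip`, exercising the `AddUnitLoopTraits` registration above.

.. code-block:: python

    def test_add_unit_loop_trace_roundtrip():
        @T.prim_func
        def zero_dim(
            A: T.Buffer[(), "int32"],
            B: T.Buffer[(), "int32"],
            C: T.Buffer[(), "int32"],
        ) -> None:
            with T.block("C"):
                vi = T.axis.spatial(1, 0)
                C[()] = A[()] + B[()]

        sch = tir.Schedule(zero_dim, debug_mask="all")
        sch.add_unit_loop(sch.get_block("C"))
        verify_trace_roundtrip(sch=sch, mod=zero_dim)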