Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: apache/tvm
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: main
Choose a base ref
...
head repository: jinhongyii/tvm
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: main
Choose a head ref
Can’t automatically merge. Don’t worry, you can still create the pull request.
  • 16 commits
  • 11 files changed
  • 3 contributors

Commits on Jul 19, 2021

  1. Fuse&split (#408)

    * first commit
    
    * fix cpplint
    
    * fix
    
    * remove redundant blank
    
    * address comments
    
    * lint
    
    * address comments
    
    * address comments
    
    * address comments
    
    * change fuse
    
    * change split
    
    * polish
    
    * lint
    
    * fix rebase
    
    * fix bug and add tests
    
    * clang format
    
    * address comments
    
    * format
    
    * address comments
    
    * address comments
    
    * add symbolic test
    
    * lint
    
    * address comment
    
    * check stage pipeline
    
    * fix mypy
    
    * check stage_pipeline
    
    * Revert "check stage_pipeline"
    
    This reverts commit a5a7f4fe
    
    * add stage_pipeline_assert
    
    Co-authored-by: jinhongyi <[email protected]>
    jinhongyii and jinhongyi committed Jul 19, 2021

    Verified

    This commit was signed with the committer’s verified signature.
    eason9487 Eason
    Copy the full SHA
    375c8b1 View commit details
  2. address comments

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    36527c1 View commit details
  3. address comments

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    0bd76fb View commit details
  4. address comments

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    9673c00 View commit details
  5. address comments

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    8b4bee8 View commit details
  6. fix

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    ce0bd7a View commit details
  7. fix

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    cf9a729 View commit details
  8. fix

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    a097b88 View commit details
  9. fix

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    76e7443 View commit details
  10. fix

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    87c98e2 View commit details
  11. address comment

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    ac7afc7 View commit details
  12. fix

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    5d45b90 View commit details
  13. fix

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    c6167ca View commit details
  14. address comments

    jinhongyii committed Jul 19, 2021
    Copy the full SHA
    09c6a41 View commit details

Commits on Jul 20, 2021

  1. address comments

    jinhongyii committed Jul 20, 2021
    Copy the full SHA
    c603478 View commit details

Commits on Jul 21, 2021

  1. retrigger ci

    junrushao committed Jul 21, 2021
    Copy the full SHA
    5e442f9 View commit details
12 changes: 12 additions & 0 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
@@ -282,6 +282,18 @@ class IterSumExpr : public IterMapExpr {
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);
/*!
* \brief Use IterVarMap detector to rewrite and simplify the indices
*
* \param indices The indices to detect pattern for.
* \param input_iters Map from variable to iterator's range.
* \param input_pred The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
*
* \return The indices after rewrite
*/
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, bool require_bijective);

/*!
* \brief Apply the inverse of the affine transformation to the outputs.
19 changes: 19 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
@@ -196,6 +196,25 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
/******** Schedule: loops manipulation ********/
/*!
* \brief Fuse a list of consecutive loops into one. It requires:
* 1) The loops can't have annotations or thread bindings.
* 2) The (i+1)-th loop must be the only child of the i-th loop.
* 3) All loops must start with 0.
* \param loop_rvs The loops to be fused
* \return The new loop after fusion
*/
virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs) = 0;
/*!
* \brief Split a loop into a list of consecutive loops. It requires:
* 1) The loop can't have annotation or thread binding.
* 2) The loop must start with 0.
* \param loop_rv The loop to be split
* \param factors The tiling factors, and at most one of which is -1, which means that
* factor is inferred.
* \return The new loops after split
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
/******** Schedule: compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
138 changes: 136 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=unused-import
"""The TensorIR schedule class"""
from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
@@ -43,7 +43,10 @@ class BlockRV(Object):
"""A random variable that refers to a block"""


ExprRV = PrimExpr # A random variable that evaluates to an integer
# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370
# This feature is not supported until python 3.10:
# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias
ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer

RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # type: ignore # pylint: disable=invalid-name

@@ -257,6 +260,137 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
return _ffi_api_schedule.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member

########## Schedule: loops manipulation ##########
def fuse(self, *loops: List[LoopRV]) -> LoopRV:
"""Fuse a list of consecutive loops into one. It requires:
1) The loops can't have annotations or thread bindings.
2) The (i+1)-th loop must be the only child of the i-th loop.
3) All loops must start with 0.
Parameters
----------
*loops : List[LoopRV]
The loops to be fused
Returns
----------
fused_loop : LoopRV
The new loop after fusion
Examples
--------
Before applying fuse, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_fuse(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do fuse:
.. code-block:: python
sch = tir.Schedule(before_fuse)
i, j = sch.get_loops(sch.get_block("B"))
sch.fuse(i, j)
print(tvm.script.asscript(sch.mod["main"]))
After applying fuse, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_fuse(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
# the 2 loops are fused into 1
for i_j_fused in tir.serial(0, 16384):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, tir.floordiv(i_j_fused, 128))
tir.bind(vj, tir.floormod(i_j_fused, 128))
B[vi, vj] = A[vi, vj] * 2.0
"""
return _ffi_api_schedule.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member

def split(
self,
loop: LoopRV,
factors: List[Union[ExprRV, None]],
) -> List[LoopRV]:
"""Split a loop into a list of consecutive loops. It requires:
1) The loop can't have annotation or thread binding.
2) The loop must start with 0.
Predicates may be added to ensure the total loop numbers keeps unchanged.
In `factors`, at most one of the factors can be None,
which will be automatically inferred.
Parameters
----------
loop : LoopRV
The loop to be split
factors: List[Union[ExprRV, None]]
The splitting factors
Potential inputs are:
- None
- ExprRV
- Nonnegative constant integers
Returns
----------
split_loops : List[LoopRV]
The new loops after split
Examples
--------
Before split, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_split(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do fuse:
.. code-block:: python
sch = tir.Schedule(before_split)
i, j = sch.get_loops(sch.get_block("B"))
sch.split(i, factors=[2, 64])
print(tvm.script.asscript(sch.mod["main"]))
After applying split, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_split(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
# the original loop is split into 2 loops
for i0, i1, j in tir.grid(2, 64, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, ((i0*64) + i1))
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0
"""
# it will be checked later in C++ implementation
# that there is at most one None in `factors`
return _ffi_api_schedule.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member

########## Schedule: compute location ##########
def compute_inline(self, block: BlockRV) -> None:
"""Inline a block into its consumer(s). It requires:
15 changes: 15 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
@@ -1085,6 +1085,21 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter
return NormalizeIterMapToExpr(expr);
});

Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, bool require_bijective) {
Analyzer analyzer;
Array<IterSumExpr> rewrite =
DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer);
if (rewrite.empty()) {
return indices;
}
Array<PrimExpr> res;
res.reserve(rewrite.size());
IterMapToExprNormalizer converter(&analyzer);
for (const auto& expr : rewrite) res.push_back(converter.Convert(expr));
return res;
}

/*!
* \brief Divider to divide the bindings into two sets of bindings(outer and inner)
* such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X.
4 changes: 4 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
@@ -799,6 +799,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0);

TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));

TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));
@@ -882,6 +884,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0);

TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x));
TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y));

87 changes: 87 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
@@ -258,6 +258,93 @@ Array<LoopRV> ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
}

/******** Schedule: loops manipulation ********/

LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::Fuse(state_, loop_srefs);
TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<LoopRV>(result);
}

Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
const Array<Optional<ExprRV>>& factor_rvs) {
class NotSingleInferFactorError : public ScheduleError {
public:
explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}

String FastErrorString() const final {
return "ScheduleError: only one factor can be specified as -1 or none";
}

String DetailRenderTemplate() const final {
return "Only one factor can be specified as -1 or none";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {}; }

IRModule mod_;
};

class WrongFactorProductError : public ScheduleError {
public:
explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}

String FastErrorString() const final {
return "ScheduleError: The product of factors is not larger than or equal to the extent of "
"loop";
}

String DetailRenderTemplate() const final {
return "The product of factors is not larger than or equal to the extent of loop {0}";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

IRModule mod_;
For loop_;
};
// Prepare for the splitting
StmtSRef loop_sref = this->GetSRef(loop_rv);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
Array<PrimExpr> factors;
factors.reserve(factor_rvs.size());
int infer_index = -1;
PrimExpr tot_length = 1;
Array<StmtSRef> results;
TVM_TIR_SCHEDULE_BEGIN();
// infer factor if needed and check validity of factors
for (size_t i = 0; i < factor_rvs.size(); i++) {
if (!factor_rvs[i].defined()) {
factors.push_back(Integer(-1));
if (infer_index == -1) {
infer_index = i;
} else {
throw NotSingleInferFactorError(state_->mod);
}
} else {
PrimExpr factor = this->Get(factor_rvs[i].value());
factors.push_back(factor);
tot_length *= factor;
}
}
if (infer_index != -1) {
factors.Set(infer_index,
this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length)));
} else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) {
throw WrongFactorProductError(state_->mod, GetRef<For>(loop));
}
results = tir::Split(state_, loop_sref, factors);
TVM_TIR_SCHEDULE_END("split", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<LoopRV>(results);
}

/******** Schedule: compute location ********/

void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) {
43 changes: 32 additions & 11 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
@@ -68,6 +68,8 @@ class ConcreteScheduleNode : public ScheduleNode {
inline PrimExpr Get(const ExprRV& expr_rv) const final;
inline StmtSRef GetSRef(const BlockRV& block_rv) const final;
inline StmtSRef GetSRef(const LoopRV& loop_rv) const final;
inline Array<StmtSRef> GetSRefs(const Array<BlockRV>& rvs) const;
inline Array<StmtSRef> GetSRefs(const Array<LoopRV>& rvs) const;
void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); }
void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); }
void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); }
@@ -78,6 +80,8 @@ class ConcreteScheduleNode : public ScheduleNode {
BlockRV GetBlock(const String& name, const String& func_name = "main") override;
Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
/******** Schedule: loops manipulation ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs) override;
Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) override;
/******** Schedule: compute location ********/
void ComputeInline(const BlockRV& block) override;
void ReverseComputeInline(const BlockRV& block) override;
@@ -143,17 +147,16 @@ inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const {
}

inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const {
auto it = this->symbol_table_.find(expr_rv);
if (it == this->symbol_table_.end()) {
LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << expr_rv;
}
const ObjectRef& obj = (*it).second;
const auto* expr_node = obj.as<PrimExprNode>();
if (expr_node == nullptr) {
LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: "
<< (obj.defined() ? obj->GetTypeKey() : "None");
}
return GetRef<PrimExpr>(expr_node);
PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> Optional<PrimExpr> {
auto it = this->symbol_table_.find(var);
if (it == this->symbol_table_.end()) {
LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var;
}
const ObjectRef& obj = (*it).second;
const auto* int_imm = TVM_TYPE_AS(int_imm, obj, IntImmNode);
return Integer(int_imm->value);
});
return this->analyzer_->Simplify(transformed);
}

inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const {
@@ -198,6 +201,24 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const {
return GetRef<StmtSRef>(sref);
}

template <class T>
inline Array<StmtSRef> GetSRefsHelper(const ConcreteScheduleNode* sch, const Array<T>& rvs) {
Array<StmtSRef> result;
result.reserve(rvs.size());
for (const T& rv : rvs) {
result.push_back(sch->GetSRef(rv));
}
return result;
}

inline Array<StmtSRef> ConcreteScheduleNode::GetSRefs(const Array<BlockRV>& rvs) const {
return GetSRefsHelper(this, rvs);
}

inline Array<StmtSRef> ConcreteScheduleNode::GetSRefs(const Array<LoopRV>& rvs) const {
return GetSRefsHelper(this, rvs);
}

/******** Adding/Removing elements in the symbol table ********/

template <class T>
22 changes: 21 additions & 1 deletion src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
@@ -25,7 +25,27 @@ namespace tvm {
namespace tir {

/******** Schedule: loops manipulation ********/

/*!
* Split a loop into a list of consecutive loops. It requires:
* 1) The loop can't have annotation or thread binding.
* 2) The loop must start with 0.
* \param self The state of the schedule
* \param loop_sref The sref to the loop being split
* \param factors The splitting factors
* \return An array of srefs to the loops after splitting
*/
TVM_DLL Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
const Array<PrimExpr>& factors);
/*!
* \brief Fuse a list of consecutive loops into one. It requires:
* 1) The loops can't have annotations or thread bindings.
* 2) The inner loop must be the only child of the outer loop.
* 3) All loops must start with 0.
* \param self The state of the schedule
* \param loop_srefs An array of srefs to the loops to be fused
* \return The sref to the fused loop
*/
TVM_DLL StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs);
/******** Schedule: compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
389 changes: 389 additions & 0 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,389 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "../utils.h"

namespace tvm {
namespace tir {

/*! \brief Append a new predicate to the each child of type BlockRealize (not recursively) */
class BlockPredicateAppender : public StmtMutator {
public:
/*!
* \brief Constructor
* \param to_append The predicate to be appended to BlockRealizeNode
*/
explicit BlockPredicateAppender(const PrimExpr& to_append) : to_append_(to_append) {}

private:
// For each direct child of type BlockRealizeNode, append the predicate
Stmt VisitStmt_(const BlockRealizeNode* realize) final {
// We do not recursively do this
ObjectPtr<BlockRealizeNode> n = CopyOnWrite(realize);
n->predicate = n->predicate && to_append_;
return BlockRealize(n);
}

/*! \brief The predicate to be appended */
const PrimExpr& to_append_;
};

/*! \brief Substitute vars and collect the reuse mapping of opaque blocks */
class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator {
public:
explicit SubstituteVarAndCollectOpaqueBlock(std::function<Optional<PrimExpr>(const Var&)> vmap,
Map<Block, Block>* opaque_blocks)
: vmap_(vmap), opaque_blocks_(opaque_blocks) {}

private:
PrimExpr VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
if (Optional<PrimExpr> ret = vmap_(var)) {
return ret.value();
} else {
return std::move(var);
}
}

Stmt VisitStmt_(const BlockRealizeNode* op) final {
BlockRealize realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
if (realize->block->iter_vars.empty()) {
opaque_blocks_->Set(op->block, realize->block);
}
return std::move(realize);
}

/*! \brief The substitute function */
std::function<Optional<PrimExpr>(const Var&)> vmap_;
/*! \brief The reuse mapping of opaque blocks */
Map<Block, Block>* opaque_blocks_;
};

/*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */
class IterMapSimplifyBlockBinding : public StmtExprMutator {
public:
explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map<Var, Range> loop_var2extent)
: opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {}

static For SimplifyBindings(Stmt stmt, const Array<StmtSRef>& loop_srefs,
MapNode* opaque_blocks) {
Map<Var, Range> loop_var2extent;
for (const StmtSRef& sref : loop_srefs) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
}
return Downcast<For>(
IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent))(std::move(stmt)));
}

private:
Stmt VisitStmt_(const ForNode* op) final {
loop_var2extent_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent));
Stmt res = StmtMutator::VisitStmt_(op);
loop_var2extent_.erase(op->loop_var);
return res;
}

Stmt VisitStmt_(const BlockRealizeNode* op) final {
// skip opaque block and update mapping
if (op->iter_values.empty()) {
Block block = op->block;
BlockRealize realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
for (const std::pair<ObjectRef, ObjectRef>& entry : *opaque_blocks_) {
if (entry.second.same_as(block)) {
opaque_blocks_->at(entry.first) = realize->block;
break;
}
}
return std::move(realize);
}
Array<PrimExpr> v = arith::IterMapSimplify(/*indices=*/op->iter_values,
/*input_iters=*/loop_var2extent_,
/*input_pred=*/op->predicate,
/*require_bijective=*/false);
if (v.same_as(op->iter_values)) {
return GetRef<Stmt>(op);
} else {
ObjectPtr<BlockRealizeNode> n = CopyOnWrite(op);
n->iter_values = std::move(v);
return Stmt(n);
}
}

/*! \brief The reuse mapping */
MapNode* opaque_blocks_;
/*! \brief The range of loops */
Map<Var, Range> loop_var2extent_;
};

class HasAnnotationOrThreadBindingError : public ScheduleError {
public:
explicit HasAnnotationOrThreadBindingError(IRModule mod, For loop)
: mod_(mod), loop_(std::move(loop)) {}

String FastErrorString() const final {
return "ScheduleError: The primitive can't be applied because the loop has annotation or "
"thread binding";
}

String DetailRenderTemplate() const final {
return "The primitive can't be applied because the loop {0} has annotation or thread binding";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

IRModule mod_;
For loop_;
};

class OuterNotInnerParent : public ScheduleError {
public:
explicit OuterNotInnerParent(IRModule mod, For outer, For inner)
: mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {}

String FastErrorString() const final {
return "ScheduleError: The outer loop is not the parent of the inner loop";
}

String DetailRenderTemplate() const final {
return "The loops can't be fused because the outer loop {0} is not the parent of the inner "
"loop {1}";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {outer_, inner_}; }

IRModule mod_;
For outer_;
For inner_;
};

class NotOnlyChildError : public ScheduleError {
public:
explicit NotOnlyChildError(IRModule mod, For outer, For inner)
: mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {}

String FastErrorString() const final {
return "ScheduleError: The inner loop is not the only child of outer loop";
}

String DetailRenderTemplate() const final {
return "The loops can't be fused because the inner loop {1} is not the only child of outer "
"loop {0}.";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {outer_, inner_}; }

IRModule mod_;
For outer_;
For inner_;
};

class LoopNotStartWithZeroError : public ScheduleError {
public:
explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}

String FastErrorString() const final {
return "ScheduleError: The primitive only supports loop starting with 0";
}

String DetailRenderTemplate() const final {
return "The loop {0} does not start with 0, which is not supported";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

IRModule mod_;
For loop_;
};

class NotSingleInferFactorError : public ScheduleError {
public:
explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}

String FastErrorString() const final {
return "ScheduleError: only one factor can be specified as -1 or none";
}

String DetailRenderTemplate() const final {
return "Only one factor can be specified as -1 or none";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {}; }

IRModule mod_;
};

class WrongFactorProductError : public ScheduleError {
public:
explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}

String FastErrorString() const final {
return "ScheduleError: The product of factors is not larger than or equal to the extent of "
"loop";
}

String DetailRenderTemplate() const final {
return "The product of factors is not larger than or equal to the extent of loop {0}";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

IRModule mod_;
For loop_;
};

Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
const Array<PrimExpr>& factors) {
// Invariance
// - The total repeat number has not changed for each direct child block with updating predicate.
// - The execution order has not changed. (The block executes with the same args and the same
// order with before.
// Step 1. Check correctness
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
if (!loop->annotations.empty() || loop->thread_binding.defined()) {
throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
}
// Currently, loops not starting with 0 are not supported
arith::Analyzer analyzer;
if (!analyzer.CanProve(loop->min == 0)) {
throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
}
// Step 2. Replace all occurrences of the original loop var with new variables
int n = factors.size();
PrimExpr substitute_value = 0;
std::vector<Var> new_loop_vars;
new_loop_vars.reserve(n);
for (int i = 0; i < n; i++) {
const PrimExpr& factor = factors[i];
Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i));
substitute_value = substitute_value * factor + var;
analyzer.Bind(var, Range::FromMinExtent(0, factor));
new_loop_vars.emplace_back(std::move(var));
}
Map<Block, Block> opaque_block_reuse;
Stmt new_stmt = loop->body;
new_stmt = SubstituteVarAndCollectOpaqueBlock(
[&](const Var& v) -> Optional<PrimExpr> {
if (v.same_as(loop->loop_var)) {
return substitute_value;
} else {
return NullOpt;
}
},
&opaque_block_reuse)(std::move(new_stmt));
// Step 3. Update predicate to guard the loop
PrimExpr predicate = substitute_value < loop->extent;
if (!analyzer.CanProve(predicate)) {
new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt));
}
// Step 4. Generate nested loops to replace the original loop and simplify the binding
for (int i = n - 1; i >= 0; i--) {
new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt);
}
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref),
opaque_block_reuse.CopyOnWrite());
self->Replace(loop_sref, new_stmt, opaque_block_reuse);
Array<StmtSRef> result_srefs;
result_srefs.reserve(n);
for (int i = 0; i < n; i++) {
result_srefs.push_back(self->stmt2ref.at(new_stmt.get()));
const ForNode* outer_loop = TVM_TYPE_AS(outer_loop, new_stmt, ForNode);
new_stmt = outer_loop->body;
}
return result_srefs;
}

StmtSRef Fuse(ScheduleState self, const Array<StmtSRef>& loop_srefs) {
// Invariance
// - The total repeat number has not changed for each direct child block.
// - The execution order has not changed. (The block executes with the same
// args and the same order with before.)
std::vector<const ForNode*> loops;
loops.reserve(loop_srefs.size());
StmtSRef outer_loop_sref{nullptr};
const ForNode* outer_loop = nullptr;
arith::Analyzer analyzer;
// Step 1. check correctness
for (const StmtSRef& sref : loop_srefs) {
const ForNode* loop = TVM_SREF_TO_FOR(loop, sref);
if (!loop->annotations.empty() || loop->thread_binding.defined()) {
throw HasAnnotationOrThreadBindingError(self->mod, GetRef<For>(loop));
}
if (outer_loop_sref.defined()) {
if (sref->parent != outer_loop_sref.get()) {
throw OuterNotInnerParent(self->mod, GetRef<For>(outer_loop), GetRef<For>(loop));
}
if (!outer_loop->body.same_as(GetRef<For>(loop))) {
throw NotOnlyChildError(self->mod, GetRef<For>(outer_loop), GetRef<For>(loop));
}
}
outer_loop_sref = sref;
outer_loop = loop;
if (!analyzer.CanProve(loop->min == 0)) {
throw LoopNotStartWithZeroError(self->mod, GetRef<For>(loop));
}
loops.push_back(loop);
}
// Step 2. Create fused loop var and replace the original loop vars
std::string suffix;
int n = loops.size();
for (int i = 1; i < n; i++) {
suffix += "_" + loops[i]->loop_var->name_hint;
}
suffix += "_fused";
Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix);
Array<PrimExpr> substitute_value;
substitute_value.resize(loops.size());
PrimExpr tot = fused_var;
for (int i = static_cast<int>(loops.size()) - 1; i >= 0; i--) {
substitute_value.Set(i, floormod(tot, loops[i]->extent));
tot = floordiv(tot, loops[i]->extent);
}
Stmt new_stmt = loops.back()->body;
Map<Block, Block> opaque_block_reuse;
auto f_substitute = [&](const Var& v) -> Optional<PrimExpr> {
for (int i = 0; i < n; i++) {
if (v.same_as(loops[i]->loop_var)) {
return substitute_value[i];
}
}
return NullOpt;
};
new_stmt =
SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt));
// Step 3. Generate a loop to replace the original loops
PrimExpr fused_extent = 1;
for (int i = 0; i < n; i++) {
fused_extent *= loops[i]->extent;
}
fused_extent = analyzer.Simplify(fused_extent);
new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt);
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(
std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite());
self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse);
return self->stmt2ref.at(new_stmt.get());
}

} // namespace tir
} // namespace tvm
2 changes: 2 additions & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
@@ -123,6 +123,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock")
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops")
.set_body_method<Schedule>(&ScheduleNode::GetLoops);
/******** (FFI) loops manipulation ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method<Schedule>(&ScheduleNode::Split);
/******** (FFI) compute location ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline")
.set_body_method<Schedule>(&ScheduleNode::ComputeInline);
453 changes: 453 additions & 0 deletions tests/python/unittest/test_tir_schedule_split_fuse.py

Large diffs are not rendered by default.