[TensorIR][M2a] Reduction Factoring (RFactor) (#8544)
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
Co-authored-by: Wuwei Lin <[email protected]>
6 people authored Jul 31, 2021
1 parent c8a892b commit 5012462
Showing 29 changed files with 2,656 additions and 165 deletions.
22 changes: 10 additions & 12 deletions include/tvm/tir/analysis.h
@@ -96,22 +96,20 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);

/*!
* \brief Whether e expression used any var in variable set.
* \param expr The expression to be checked.
* \param vset_contains The check function to see if var is in the vset.
* \return Whether e uses vset.
* \brief Whether the given Stmt uses any var in the given variable set.
* \param stmt The Stmt to be checked.
* \param vset_contains The check function to see if a var is in the variable set.
* \return Whether `stmt` uses any var in the given variable set.
*/
TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);

/*!
* \brief Whether e expression used var.
* \param expr The expression to be checked.
* \param var The variable.
* \return Whether e uses v.
* \brief Whether the given PrimExpr uses any var in the given variable set.
* \param expr The PrimExpr to be checked.
* \param vset_contains The check function to see if var is in the variable set.
* \return Whether `expr` uses any var in the given variable set.
*/
inline bool ExprUseVar(const PrimExpr& expr, const Var& var) {
return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; });
}
TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
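For illustration only (not part of this commit): the removed one-variable `ExprUseVar(expr, var)` overload is now expressed by passing a predicate to `UsesVar`, as the call-site hunks further below do. A rough Python analogue of the same check, sketched with `tvm.tir.stmt_functor.post_order_visit` (the helper name `uses_var` is hypothetical):

.. code-block:: python

    from tvm import tir

    def uses_var(expr, var):
        # Sketch of the UsesVar semantics: does `expr` reference `var`?
        found = [False]

        def visit(node):
            if node.same_as(var):
                found[0] = True

        tir.stmt_functor.post_order_visit(expr, visit)
        return found[0]

    x = tir.Var("x", "int32")
    y = tir.Var("y", "int32")
    assert uses_var(x + 1, x)
    assert not uses_var(y * 2, x)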

/*!
* \brief Verifies whether the IR stmt or Expr is in SSA form.
2 changes: 1 addition & 1 deletion include/tvm/tir/schedule/block_scope.h
@@ -262,7 +262,7 @@ class BlockScope : public ObjectRef {
* \param child_block_srefs The srefs to the leaf blocks
* \note We assume the leaf blocks are given in pre-DFS order
*/
TVM_DLL BlockScope(const Array<StmtSRef>& child_block_srefs);
TVM_DLL explicit BlockScope(const Array<StmtSRef>& child_block_srefs);

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode);
};
18 changes: 18 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -242,6 +242,24 @@ class ScheduleNode : public runtime::Object {
/******** Schedule: loop binding/annotation ********/
/******** Schedule: cache read/write ********/
/******** Schedule: reduction ********/
/*!
* \brief Factorize an associative reduction block by the specified loop.
* \details An associative reduction cannot be parallelized directly,
* because it leads to potential race conditions during accumulation.
* Alternatively, the reduction could be factorized on a loop with the following steps:
* - Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent
* - Step 2: compute the chunks separately and write the result into `n` intermediate buffers;
* - Step 3: accumulate the `n` separate buffers into the result buffer.
* Note that the Step 2 above introduces opportunities for parallelization.
* RFactor is a schedule primitive that implements the transformation described above.
* \param loop_rv The loop outside the block for which we want to do rfactor
* \param factor_axis The position where the new dimension is placed in the newly introduced rfactor
* buffer. Suppose the original reduction block writes to buffer `B` with
* ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1,
* ndim(B)]`, and the negative index will be normalized to a non-negative one
* \return The rfactor block
*/
virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
/******** Schedule: blockize & tensorize ********/
};
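A quick editorial illustration (not part of this commit) of the `factor_axis` range and normalization rule stated above, assuming the NumPy-style insertion-index convention the documentation refers to:

.. code-block:: python

    ndim_B = 1  # e.g. B has shape (128,), so the rf-buffer gets ndim(B) + 1 = 2 dims
    for factor_axis in (-2, -1, 0, 1):        # the legal range [-ndim(B) - 1, ndim(B)]
        # negative values normalize to non-negative insertion positions
        normalized = factor_axis if factor_axis >= 0 else factor_axis + ndim_B + 1
        print(factor_axis, "->", normalized)  # -2 -> 0, -1 -> 1, 0 -> 0, 1 -> 1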

19 changes: 19 additions & 0 deletions include/tvm/tir/stmt.h
@@ -865,6 +865,7 @@ class For : public Stmt {
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);
};

/*!
@@ -1359,6 +1360,24 @@ TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
// overload printing of for type.
TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);

// inline implementations
inline const char* ForKind2String(ForKind t) {
switch (t) {
case ForKind::kSerial:
return "serial";
case ForKind::kParallel:
return "parallel";
case ForKind::kVectorized:
return "vectorized";
case ForKind::kUnrolled:
return "unroll";
case ForKind::kThreadBinding:
return "thread_binding";
}
LOG(FATAL) << "Unknown ForKind" << t;
return "Unknown";
}

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_STMT_H_
2 changes: 1 addition & 1 deletion python/tvm/script/special_stmt.py
@@ -225,7 +225,7 @@ def alloc_buffer(
data=None,
strides=None,
elem_offset=None,
scope="",
scope="global",
align=-1,
offset_factor=0,
buffer_type="default",
147 changes: 145 additions & 2 deletions python/tvm/tir/schedule/schedule.py
@@ -431,7 +431,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None:
.. code-block:: python
sch = tir.Schedule(before_inline, debug_mode=True)
sch = tir.Schedule(before_inline)
sch.compute_inline(sch.get_block("B"))
print(tvm.script.asscript(sch.mod["main"]))
@@ -491,7 +491,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None:
.. code-block:: python
sch = tir.Schedule(before_inline, debug_mode=True)
sch = tir.Schedule(before_inline)
sch.reverse_compute_inline(sch.get_block("C"))
print(tvm.script.asscript(sch.mod["main"]))
@@ -512,6 +512,149 @@ def after_inline(a: ty.handle, c: ty.handle) -> None:
########## Schedule: loop binding/annotation ##########
########## Schedule: cache read/write ##########
########## Schedule: reduction ##########
def rfactor(self, loop: LoopRV, factor_axis: int) -> BlockRV:
"""Factorize an associative reduction block by the specified loop.
An associative reduction cannot be parallelized directly,
because it leads to potential race conditions during accumulation.
Alternatively, the reduction could be factorized on a loop with the following steps:
- Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent
- Step 2: compute the chunks separately and write the result into `n` intermediate buffers;
- Step 3: accumulate the `n` separate buffers into the result buffer.
Note that the Step 2 above introduces opportunities for parallelization.
RFactor is a schedule primitive that implements the transformation described above:
Given a block that writes to buffer `B`, it factorizes a loop of extent `n`.
For example, the pseudocode below accumulates `B[i] = sum(A[i, :, :])`:
.. code-block:: python
for i in range(128): # loop i is a data parallel loop
for j in range(128): # loop j is a reduction loop
for k in range(128): # loop k is a reduction loop
B[i] = B[i] + A[i, j, k]
Suppose RFactor is applied on the innermost loop `k` and `factor_axis = 1`.
RFactor then creates an intermediate buffer and two blocks.
1. The intermediate buffer, or "rf-buffer" is a buffer of rank `ndim(B) + 1` and
size `size(B) * n`, whose shape expands from `shape(B)` by adding an axis of `n`
at the position specified by `factor_axis`. For example,
* shape(B) = [1, 2, 3], factor_axis = 0 => shape(B_rf) = [n, 1, 2, 3]
* shape(B) = [1, 2, 3], factor_axis = 1 => shape(B_rf) = [1, n, 2, 3]
* shape(B) = [1, 2, 3], factor_axis = 2 => shape(B_rf) = [1, 2, n, 3]
* shape(B) = [1, 2, 3], factor_axis = 3 => shape(B_rf) = [1, 2, 3, n]
2. The rfactor block, or "rf-block", is a block that writes to the `rf-buffer` without
accumulating over the loop `k`, i.e. the loop `k` is converted from a reduction loop
to a data parallel loop. In our example, the rf-block is:
.. code-block:: python
B_rf = np.zeros((128, 128)) # the rf-buffer
for k in range(128): # loop k is converted to a data parallel loop
for i in range(128): # loop i is a data parallel loop (unchanged)
for j in range(128): # loop j is a reduction loop (unchanged)
B_rf[i, k] = B_rf[i, k] + A[i, j, k]
3. The write-back block, or `wb-block`, is a block that accumulates the rf-buffer into
the result buffer. All the reduction loops are removed except the loop `k` for accumulation.
In our example, the wb-block is:
.. code-block:: python
for i in range(128): # loop i is a data parallel loop (unchanged)
# loop j is removed because it is a reduction loop
for k in range(128): # loop k is a reduction loop (unchanged)
B[i] = B[i] + B_rf[i, k]
Parameters
----------
loop : LoopRV
The loop outside the block for which we want to do rfactor
factor_axis : int
The position where the new dimension is placed in the newly introduced rfactor buffer
Returns
-------
rf_block : BlockRV
The block which computes partial results over each slice (i.e., the first block
as described in the above illustration)
Examples
--------
Before rfactor, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_rfactor(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128, 128))
B = tir.match_buffer(b, (128,))
with tir.block([128, tir.reduce_axis(0, 128),
tir.reduce_axis(0, 128)], "B") as [vii, vi, vj]:
with tir.init():
B[vii] = 0.0
B[vii] = B[vii] + A[vii, vi, vj]
Create the schedule and do rfactor:
.. code-block:: python
sch = tir.Schedule(before_rfactor)
_, _, k = sch.get_loops(sch.get_block("B"))
sch.rfactor(k, 0)
print(tvm.script.asscript(sch.mod["main"]))
After applying rfactor, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_rfactor(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128, 128])
B = tir.match_buffer(b, [128])
B_rf = tir.alloc_buffer([128, 128])
with tir.block([128, 128, tir.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]:
with tir.init():
B_rf[vi2, vii] = 0.0
B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2])
with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]:
with tir.init():
B[vii_1] = 0.0
B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1])
Note
----
Rfactor requires:
1) `loop` has only one child block, and it is a reduction block;
2) `loop` is a reduction loop, i.e. the loop variable is bound to only reduction variables
in the block binding;
3) `loop` is not parallelized, vectorized, unrolled or bound to any thread axis;
4) The block scope that `loop` is in is a staged-pipeline;
5) The outermost loop outside the reduction block should have the reduction block as its
first child block;
6) The outermost reduction loop should have only one child block;
7) A unary extent loop that is not bound to any reduction or data parallel variables in
the block binding should not appear under some reduction loop;
8) The reduction block should write to only one buffer, and its init and body are both
simple `BufferStore`s, and the pattern is registered as an associative reducer.
The pre-defined patterns include: plus, multiplication, min and max;
9) None of the loops on top of the block may be bound to both a data parallel and a
reduction block binding at the same time;
10) `factor_axis` should be in range `[-ndim(B) - 1, ndim(B)]`,
where `B` is the buffer that the reduction block writes to.
Negative indexing is normalized according to numpy convention.
"""
return _ffi_api_schedule.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member
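As a sanity check of the decomposition described in the docstring above, a small NumPy sketch (editorial illustration, not part of this commit) showing that the rf-block plus wb-block reproduce the original reduction `B[i] = sum(A[i, :, :])` when loop `k` is factored out with `factor_axis = 1`:

.. code-block:: python

    import numpy as np

    n = 8  # small extents for illustration; the docstring example uses 128
    A = np.random.rand(n, n, n).astype("float32")

    B_direct = A.sum(axis=(1, 2))  # the original reduction over j and k

    # rf-block: loop k becomes data parallel; partial sums land in B_rf[i, k]
    B_rf = A.sum(axis=1)           # reduce over j only, shape (n, n)

    # wb-block: accumulate the factored axis k back into the result buffer
    B = B_rf.sum(axis=1)

    np.testing.assert_allclose(B_direct, B, rtol=1e-5)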

########## Schedule: blockize & tensorize ##########


6 changes: 4 additions & 2 deletions src/arith/canonical_simplify.cc
@@ -1137,8 +1137,10 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
// and recursively mark the corresponding components
for (size_t i = 0; i < simplified_result.size(); ++i)
if (!used[i]) {
if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
ExprUseVar(simplified_result[idx], op->combiner->rhs[i]))
if (UsesVar(simplified_result[idx],
[v = op->combiner->lhs[i].get()](const VarNode* var) { return var == v; }) ||
UsesVar(simplified_result[idx],
[v = op->combiner->rhs[i].get()](const VarNode* var) { return var == v; }))
mark_used(i);
}
};
4 changes: 2 additions & 2 deletions src/arith/detect_linear_equation.cc
@@ -108,7 +108,7 @@ class LinearEqDetector : public ExprFunctor<LinearEqEntry(const PrimExpr&, const
}
LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final {
if (fail_) return LinearEqEntry();
if (ExprUseVar(e, var_)) {
if (UsesVar(e, [this](const VarNode* var) { return var == var_.get(); })) {
fail_ = true;
return LinearEqEntry();
} else {
@@ -159,7 +159,7 @@ Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, const Array<Var>& vars)
for (size_t i = vars.size(); i > 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset_contains)) {
if (UsesVar(coeff[i - 2], vset_contains)) {
return Array<PrimExpr>();
}
}
17 changes: 0 additions & 17 deletions src/printer/tir_text_printer.cc
@@ -487,23 +487,6 @@ Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) {
return doc;
}

inline const char* ForKind2String(ForKind t) {
switch (t) {
case ForKind::kSerial:
return "serial";
case ForKind::kParallel:
return "parallel";
case ForKind::kVectorized:
return "vectorized";
case ForKind::kUnrolled:
return "unroll";
case ForKind::kThreadBinding:
return "thread_binding";
}
LOG(FATAL) << "Unknown ForKind";
return "Unknown";
}

Doc TIRTextPrinter::VisitStmt_(const ForNode* op) {
Doc doc;
doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", "
17 changes: 0 additions & 17 deletions src/printer/tvmscript_printer.cc
@@ -704,23 +704,6 @@ Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) {
return doc;
}

inline const char* ForKind2String(ForKind t) {
switch (t) {
case ForKind::kSerial:
return "serial";
case ForKind::kParallel:
return "parallel";
case ForKind::kVectorized:
return "vectorized";
case ForKind::kUnrolled:
return "unroll";
case ForKind::kThreadBinding:
return "thread_binding";
}
LOG(FATAL) << "Unknown ForKind";
return "Unknown";
}

Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) {
Doc doc;
var_not_in_headers.insert(op->loop_var.get());
4 changes: 2 additions & 2 deletions src/te/autodiff/ad_simplify.cc
@@ -834,7 +834,7 @@ std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) &&
(pair_b.first || pair_a.second) &&
(pair_a.second || pair_b.second)};
} else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
} else if (!tir::UsesVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
return {cond, const_true()};
} else {
return {const_true(), cond};
@@ -1014,7 +1014,7 @@ PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond,
// Keep only those variables of the new vars which are used in the new_expr
Array<Var> used_res_variables;
for (const Var& var : res->dst->variables) {
if (ExprUseVar(new_expr, var)) {
if (tir::UsesVar(new_expr, [&var](const VarNode* var_) { return var_ == var.get(); })) {
ICHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred.";
used_res_variables.push_back(var);
}
2 changes: 1 addition & 1 deletion src/te/operation/compute_op.cc
@@ -591,7 +591,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range
auto fbanned = [&](const VarNode* node) { return banned.count(node); };

for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::UsesVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition " << pred
<< " has a conflict with the reset condition";
}
4 changes: 2 additions & 2 deletions src/te/operation/tensorize.cc
@@ -140,13 +140,13 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage,
auto fbanned = [&](const VarNode* node) { return banned.count(node); };

for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::UsesVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize failed, split condition " << pred
<< " relies on var defined inside tensorize scope";
}
}
for (const PrimExpr& pred : n.init_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::UsesVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize failed, split condition " << pred
<< " relies on var defined inside tensorize scope";
}
(remaining changed files not shown)