[M2a][TensorIR] Reduction Factoring #380

Merged (1 commit, Jul 23, 2021)

22 changes: 10 additions & 12 deletions include/tvm/tir/analysis.h
@@ -96,22 +96,20 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);

/*!
* \brief Whether e expression used any var in variable set.
* \param expr The expression to be checked.
* \param vset_contains The check function to see if var is in the vset.
* \return Whether e uses vset.
* \brief Whether the given Stmt uses any var in the given variable set.
* \param stmt The Stmt to be checked.
* \param vset_contains The check function to see if a var is in the variable set.
* \return Whether `stmt` uses any var in the given variable set.
*/
TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);

/*!
* \brief Whether e expression used var.
* \param expr The expression to be checked.
* \param var The variable.
* \return Whether e uses v.
* \brief Whether the given PrimExpr uses any var in the given variable set.
* \param expr The PrimExpr to be checked.
* \param vset_contains The check function to see if var is in the variable set.
* \return Whether `expr` uses any var in the given variable set.
*/
inline bool ExprUseVar(const PrimExpr& expr, const Var& var) {
return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; });
}
TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);

/*!
* \brief Verifies whether the IR stmt or Expr is in SSA form.
4 changes: 2 additions & 2 deletions include/tvm/tir/schedule/block_scope.h
@@ -255,14 +255,14 @@ class BlockScopeNode : public Object {
class BlockScope : public ObjectRef {
public:
/*! \brief The constructor creating an empty block scope with no dependency information */
TVM_DLL BlockScope();
TVM_DLL explicit BlockScope();
/*!
* \brief Create the object with the specific leaf blocks, and compute the dependency information
* between the leaf blocks.
* \param child_block_srefs The srefs to the leaf blocks
* \note We assume the leaf blocks are given in pre-DFS order
*/
TVM_DLL BlockScope(const Array<StmtSRef>& child_block_srefs);
TVM_DLL explicit BlockScope(const Array<StmtSRef>& child_block_srefs);

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode);
};
11 changes: 11 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -242,6 +242,17 @@ class ScheduleNode : public runtime::Object {
/******** Schedule: loop binding/annotation ********/
/******** Schedule: cache read/write ********/
/******** Schedule: reduction ********/
/*!
* \brief Factor a reduction block by the specified loop
* \details See python/tvm/tir/schedule/schedule.py
* \param loop_rv The loop outside the block for which we want to do rfactor
* \param factor_axis The position where the new dimension is placed in the newly introduced rfactor
* buffer. Suppose the original reduction block writes to buffer `B` with
* ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1,
* ndim(B)]`, and the negative index will be normalized to a non-negative one
* \return The rfactor block
*/
virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
/******** Schedule: blockize & tensorize ********/
};

18 changes: 18 additions & 0 deletions include/tvm/tir/stmt.h
@@ -865,6 +865,7 @@ class For : public Stmt {
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);
};

/*!
@@ -1361,6 +1362,23 @@ TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
// overload printing of for type.
TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);

// inline implementations
inline const char* IterVarType2String(ForKind t) {
switch (t) {
case ForKind::kSerial:
return "Serial";
case ForKind::kParallel:
return "Parallel";
case ForKind::kVectorized:
return "Vectorized";
case ForKind::kUnrolled:
return "Unrolled";
case ForKind::kThreadBinding:
return "ThreadBinding";
}
return "Unknown";
}

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_STMT_H_
2 changes: 1 addition & 1 deletion python/tvm/script/special_stmt.py
@@ -225,7 +225,7 @@ def alloc_buffer(
data=None,
strides=None,
elem_offset=None,
scope="",
scope="global",
align=-1,
offset_factor=0,
buffer_type="default",
148 changes: 146 additions & 2 deletions python/tvm/tir/schedule/schedule.py
@@ -431,7 +431,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None:

.. code-block:: python

sch = tir.Schedule(before_inline, debug_mode=True)
sch = tir.Schedule(before_inline)
sch.compute_inline(sch.get_block("B"))
print(tvm.script.asscript(sch.mod["main"]))

Expand Down Expand Up @@ -491,7 +491,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None:

.. code-block:: python

sch = tir.Schedule(before_inline, debug_mode=True)
sch = tir.Schedule(before_inline)
sch.reverse_compute_inline(sch.get_block("C"))
print(tvm.script.asscript(sch.mod["main"]))

@@ -512,6 +512,150 @@ def after_inline(a: ty.handle, c: ty.handle) -> None:
########## Schedule: loop binding/annotation ##########
########## Schedule: cache read/write ##########
########## Schedule: reduction ##########
def rfactor(self, loop: LoopRV, factor_axis: int) -> BlockRV:
"""Factorize an associative reduction block by the specified loop.

An associative reduction cannot be parallelized directly,
because it leads to potential race condition during accumulation.
Alternatively, the reduction could be factorized on a loop with the following steps:
- Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent;
- Step 2: compute the chunks separately and write the results into `n` intermediate buffers;
- Step 3: accumulate the `n` separate buffers into the result buffer.
Note that Step 2 above introduces opportunities for parallelization.

RFactor is a schedule primitive that implements the transformation described above:
Given a block that writes to buffer `B`, it factorizes a loop of extent `n`.

For example, the pseudocode below accumulates `B[i] = sum(A[i, :, :])`:


.. code-block:: python

for i in range(128): # loop i is a data parallel loop
for j in range(128): # loop j is a reduction loop
for k in range(128): # loop k is a reduction loop
B[i] = B[i] + A[i, j, k]


Suppose RFactor is applied on the innermost loop `k` and `factor_axis = 1`.
RFactor then creates an intermediate buffer and two blocks.

- The intermediate buffer, or "rf-buffer" is a buffer of rank `ndim(B) + 1` and
size `size(B) * n`, whose shape expands from `shape(B)` by adding an axis of `n`
at the position specified by `factor_axis`. For example,

* shape(B) = [1, 2, 3], factor_axis = 0 => shape(B_rf) = [n, 1, 2, 3]
* shape(B) = [1, 2, 3], factor_axis = 1 => shape(B_rf) = [1, n, 2, 3]
* shape(B) = [1, 2, 3], factor_axis = 2 => shape(B_rf) = [1, 2, n, 3]
* shape(B) = [1, 2, 3], factor_axis = 3 => shape(B_rf) = [1, 2, 3, n]

- The rfactor block, or "rf-block", is a block that writes to the `rf-buffer` without
accumulating over the loop `k`, i.e. the loop `k` is converted from a reduction loop
to a data parallel loop. In our example, the rf-block is:


.. code-block:: python

B_rf = np.zeros((128, 128)) # the rf-buffer
for k in range(128): # loop k is converted to a data parallel loop
for i in range(128): # loop i is a data parallel loop (unchanged)
for j in range(128): # loop j is a reduction loop (unchanged)
B_rf[i, k] = B_rf[i, k] + A[i, j, k]


- The write-back block, or `wb-block`, is a block that accumulates the rf-buffer into
the result buffer. All the reduction loops are removed except the loop `k` for accumulation.
In our example, the wb-block is:

.. code-block:: python

for i in range(128): # loop i is a data parallel loop (unchanged)
# loop j is removed because it is a reduction loop
for k in range(128): # loop k is a reduction loop (unchanged)
B[i] = B[i] + B_rf[i, k]
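
For intuition, the NumPy sketch below (an illustration only, not TVM code) checks that the
rf-block followed by the wb-block reproduces the original reduction, using the
`factor_axis = 1` layout from the pseudocode above:

.. code-block:: python

    import numpy as np

    A = np.random.rand(128, 128, 128).astype("float32")

    # original reduction: B[i] = sum over j and k of A[i, j, k]
    B_direct = A.sum(axis=(1, 2))

    # rf-block: reduce over j only, so loop k becomes data parallel and the
    # rf-buffer has shape (128, 128) with the k axis at factor_axis = 1
    B_rf = A.sum(axis=1)

    # wb-block: accumulate the rf-buffer over k into the result buffer
    B_two_stage = B_rf.sum(axis=1)

    np.testing.assert_allclose(B_direct, B_two_stage, rtol=1e-5)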

Parameters
----------
loop : LoopRV
The loop outside the block for which we want to do rfactor
factor_axis : int
The position where the new dimension is placed in the newly introduced rfactor buffer

Returns
-------
rf_block : BlockRV
The block which computes partial results over each slice (i.e., the first block
as described in the above illustration)

Examples
--------

Before rfactor, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_rfactor(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128, 128), "float32")
B = tir.match_buffer(b, (128,), "float32")
with tir.block([128, tir.reduce_axis(0, 128),
tir.reduce_axis(0, 128)], "B") as [vii, vi, vj]:
with tir.init():
B[vii] = 0.0
B[vii] = B[vii] + A[vii, vi, vj]

Create the schedule and do rfactor:

.. code-block:: python

sch = tir.Schedule(before_rfactor)
_, _, k = sch.get_loops(sch.get_block("B"))
sch.rfactor(k, 0)
print(tvm.script.asscript(sch.mod["main"]))

After applying rfactor, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_rfactor(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128, 128])
B = tir.match_buffer(b, [128])
B_rf = tir.alloc_buffer([128, 128])
with tir.block([128, 128, tir.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]:
with tir.init():
B_rf[vi2, vii] = 0.0
B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2])
with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]:
with tir.init():
B[vii_1] = 0.0
B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1])


Note
----

Rfactor requires:
1) `loop` has only one child block, and it is a reduction block;
2) `loop` is a reduction loop, i.e. the loop variable is bound to only reduction variables
in the block binding;
3) `loop` is not parallelized, vectorized, unrolled or bound to any thread axis;
4) The block scope that `loop` is in is a staged-pipeline;
5) The outermost loop outside the reduction block should have the reduction block as its first child block;
6) The outermost reduction loop should have only one child block;
7) A unary extent loop that is not bound to any reduction or data parallel variables in the block binding
should not appear under any reduction loop;
8) The reduction block should write to only one buffer, and its init and body must each be a
simple `BufferStore`, with the reduction pattern registered as an associative reducer.
The pre-defined patterns include: plus, multiplication, min and max;
9) Each of the loops on top of the block cannot be bound to both a data parallel and a reduction
block binding at the same time;
10) `factor_axis` should be in range `[-ndim(B) - 1, ndim(B)]`,
where `B` is the buffer that the reduction block writes to.
Negative indexing is normalized according to numpy convention (see the sketch below).
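
The helper below is a hypothetical sketch (not part of the TVM API) of how a negative
`factor_axis` is normalized and how it determines the rf-buffer shape:

.. code-block:: python

    def rf_buffer_shape(b_shape, n, factor_axis):
        # normalize a negative factor_axis, numpy-style, into [0, ndim(B)]
        ndim = len(b_shape)
        assert -ndim - 1 <= factor_axis <= ndim
        if factor_axis < 0:
            factor_axis += ndim + 1
        # insert the new axis of extent n at the normalized position
        return b_shape[:factor_axis] + [n] + b_shape[factor_axis:]

    rf_buffer_shape([1, 2, 3], 128, 1)   # [1, 128, 2, 3]
    rf_buffer_shape([1, 2, 3], 128, -1)  # [1, 2, 3, 128]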
"""
return _ffi_api_schedule.ScheduleRFactor(self, loop, factor_axis) # pylint: disable=no-member

########## Schedule: blockize & tensorize ##########


6 changes: 4 additions & 2 deletions src/arith/canonical_simplify.cc
@@ -1137,8 +1137,10 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
// and recursively mark the corresponding components
for (size_t i = 0; i < simplified_result.size(); ++i)
if (!used[i]) {
if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
ExprUseVar(simplified_result[idx], op->combiner->rhs[i]))
if (UsesVar(simplified_result[idx],
[v = op->combiner->lhs[i].get()](const VarNode* var) { return var == v; }) ||
UsesVar(simplified_result[idx],
[v = op->combiner->rhs[i].get()](const VarNode* var) { return var == v; }))
mark_used(i);
}
};
4 changes: 2 additions & 2 deletions src/arith/detect_linear_equation.cc
@@ -108,7 +108,7 @@ class LinearEqDetector : public ExprFunctor<LinearEqEntry(const PrimExpr&, const
}
LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final {
if (fail_) return LinearEqEntry();
if (ExprUseVar(e, var_)) {
if (UsesVar(e, [this](const VarNode* var) { return var == var_.get(); })) {
fail_ = true;
return LinearEqEntry();
} else {
@@ -159,7 +159,7 @@ Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, const Array<Var>& vars)
for (size_t i = vars.size(); i > 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset_contains)) {
if (UsesVar(coeff[i - 2], vset_contains)) {
return Array<PrimExpr>();
}
}
4 changes: 2 additions & 2 deletions src/te/autodiff/ad_simplify.cc
@@ -834,7 +834,7 @@ std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) &&
(pair_b.first || pair_a.second) &&
(pair_a.second || pair_b.second)};
} else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
} else if (!tir::UsesVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
return {cond, const_true()};
} else {
return {const_true(), cond};
@@ -1014,7 +1014,7 @@ PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond,
// Keep only those variables of the new vars which are used in the new_expr
Array<Var> used_res_variables;
for (const Var& var : res->dst->variables) {
if (ExprUseVar(new_expr, var)) {
if (tir::UsesVar(new_expr, [&var](const VarNode* var_) { return var_ == var.get(); })) {
ICHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred.";
used_res_variables.push_back(var);
}
2 changes: 1 addition & 1 deletion src/te/operation/compute_op.cc
@@ -591,7 +591,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range
auto fbanned = [&](const VarNode* node) { return banned.count(node); };

for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::UsesVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition " << pred
<< " has a conflict with the reset condition";
}
4 changes: 2 additions & 2 deletions src/te/operation/tensorize.cc
@@ -140,13 +140,13 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage,
auto fbanned = [&](const VarNode* node) { return banned.count(node); };

for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::UsesVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize failed, split condition " << pred
<< " relies on var defined inside tensorize scope";
}
}
for (const PrimExpr& pred : n.init_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::UsesVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize failed, split condition " << pred
<< " relies on var defined inside tensorize scope";
}