Skip to content

Commit

Permalink
RFactor
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Jul 23, 2021
1 parent 587f42d commit e3ab7fd
Show file tree
Hide file tree
Showing 27 changed files with 2,639 additions and 134 deletions.
22 changes: 10 additions & 12 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,22 +96,20 @@ TVM_DLL Array<Var> UndefinedVars(const PrimExpr& expr);
TVM_DLL CallEffectKind SideEffect(const PrimExpr& expr);

/*!
* \brief Whether e expression used any var in variable set.
* \param expr The expression to be checked.
* \param vset_contains The check function to see if var is in the vset.
* \return Whether e uses vset.
* \brief Whether the given Stmt uses any var in the given variable set.
* \param stmt The Stmt to be checked.
* \param vset_contains The check function to see if a var is in the variable set.
* \return Whether `stmt` uses any var in the given variable set.
*/
TVM_DLL bool ExprUseVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);
TVM_DLL bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> vset_contains);

/*!
* \brief Whether e expression used var.
* \param expr The expression to be checked.
* \param var The variable.
* \return Whether e uses v.
* \brief Whether the given PrimExpr uses any var in the given variable set.
* \param expr The PrimExpr to be checked.
* \param vset_contains The check function to see if var is in the variable set.
* \return Whether `expr` uses any var in the given variable set.
*/
inline bool ExprUseVar(const PrimExpr& expr, const Var& var) {
return ExprUseVar(expr, [&](const VarNode* node) { return var.get() == node; });
}
TVM_DLL bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> vset_contains);

/*!
* \brief Verifies whether the IR stmt or Expr is in SSA form.
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/tir/schedule/block_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,14 @@ class BlockScopeNode : public Object {
class BlockScope : public ObjectRef {
public:
/*! \brief The constructor creating an empty block scope with on dependency information */
TVM_DLL BlockScope();
TVM_DLL explicit BlockScope();
/*!
* \brief Create the object with the specific leaf blocks, and compute the dependency information
* between the leaf blocks.
* \param child_block_srefs The srefs to the leaf blocks
* \note We assume the leaf blocks are given in pre-DFS order
*/
TVM_DLL BlockScope(const Array<StmtSRef>& child_block_srefs);
TVM_DLL explicit BlockScope(const Array<StmtSRef>& child_block_srefs);

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockScope, ObjectRef, BlockScopeNode);
};
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,17 @@ class ScheduleNode : public runtime::Object {
/******** Schedule: loop binding/annotation ********/
/******** Schedule: cache read/write ********/
/******** Schedule: reduction ********/
/*!
* \brief Factor a reduction block by the specified loop
* \details See python/tvm/tir/schedule/schedule.py
* \param loop_rv The loop outside block we want to do rfactor
* \param factor_axis The position where the new dimension is placed in the new introduced rfactor
* buffer. Suppose the original reduction block writes to buffer `B` with
* ndim(B) dimensions, then `factor_axis` should be in range `[-ndim(B) - 1,
* ndim(B)]`, and the negative index will be normalized to a non-negative one
* \return The rfactor block
*/
virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
/******** Schedule: blockize & tensorize ********/
};

Expand Down
18 changes: 18 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,7 @@ class For : public Stmt {
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ForNode);
};

/*!
Expand Down Expand Up @@ -1356,6 +1357,23 @@ TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());
// overload printing of for type.
TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);

// inline implementations
inline const char* IterVarType2String(ForKind t) {
switch (t) {
case ForKind::kSerial:
return "Serial";
case ForKind::kParallel:
return "Parallel";
case ForKind::kVectorized:
return "Vectorized";
case ForKind::kUnrolled:
return "Unrolled";
case ForKind::kThreadBinding:
return "ThreadBinding";
}
return "Unknown";
}

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_STMT_H_
148 changes: 146 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None:
.. code-block:: python
sch = tir.Schedule(before_inline, debug_mode=True)
sch = tir.Schedule(before_inline)
sch.compute_inline(sch.get_block("B"))
print(tvm.script.asscript(sch.mod["main"]))
Expand Down Expand Up @@ -484,7 +484,7 @@ def before_inline(a: ty.handle, c: ty.handle) -> None:
.. code-block:: python
sch = tir.Schedule(before_inline, debug_mode=True)
sch = tir.Schedule(before_inline)
sch.reverse_compute_inline(sch.get_block("C"))
print(tvm.script.asscript(sch.mod["main"]))
Expand All @@ -505,6 +505,150 @@ def after_inline(a: ty.handle, c: ty.handle) -> None:
########## Schedule: loop binding/annotation ##########
########## Schedule: cache read/write ##########
########## Schedule: reduction ##########
def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV:
"""Factorize an associative reduction block by the specified loop.
An associative reduction cannot be parallelized directly,
because it leads to potential race condition during accumulation.
Alternatively, the reduction could be factorized on a loop with the following steps:
- Step 1: evenly slice the reduction into `n` separate chunks, where `n` is the loop extent
- Step 2: compute the chunks separately and write the result into `n` intermediate buffers;
- Step 3: accumulate the `n` separate buffer into the result buffer.
Note that the Step 2 above introduces opportunities for parallelization.
RFactor is a schedule primitive that implements the transformation described above:
Given a block that writes to buffer `B`, it factorizes a loop of extent `n`.
For example, the pesudocode below accumulates `B[i] = sum(A[i, : , : ])`:
.. code-block:: python
for i in range(128): # loop i is a data parallel loop
for j in range(128): # loop j is a reduction loop
for k in range(128): # loop j is a reduction loop
B[i] = B[i] + A[i, j, k]
Suppose RFactor is applied on the innermost loop `k` and `factor_axis = 1`.
RFactor then creates an intermediate buffer and two blocks.
- The intermediate buffer, or "rf-buffer" is a buffer of rank `ndim(B) + 1` and
size `size(B) * n`, whose shape expands from `shape(B)` by adding an axis of `n`
at the position specified by `factor_axis`. For example,
* shape(B) = [1, 2, 3], factor_axis = 0 => shape(B_rf) = [n, 1, 2, 3]
* shape(B) = [1, 2, 3], factor_axis = 1 => shape(B_rf) = [1, n, 2, 3]
* shape(B) = [1, 2, 3], factor_axis = 2 => shape(B_rf) = [1, 2, n, 3]
* shape(B) = [1, 2, 3], factor_axis = 3 => shape(B_rf) = [1, 2, 3, n]
- The rfactor block, or "rf-block", is a block that writes to the `rf-buffer` without
accumulating over the loop `k`, i.e. the loop `k` is converted from a reduction loop
to a data parallel loop. In our example, the rf-block is:
.. code-block:: python
B_rf = np.zeros((128, 128)) # the rf-buffer
for k in range(128): # loop k is converted to a data parallel loop
for i in range(128): # loop i is a data parallel loop (unchanged)
for j in range(128): # loop j is a reduction loop (unchanged)
B_rf[i, k] = B_rf[i, k] + A[i, j, k]
- The write-back block, or `wb-block`, is a block that accumulates the rf-buffer into
the result buffer. All the reduction loops are removed except the loop `k` for accumulation.
In our example, the wb-block is:
.. code-block:: python
for i in range(128): # loop i is a data parallel loop (unchanged)
# loop j is removed because it is a reduction loop
for k in range(128): # loop k is a reduction loop (unchanged)
B[i] = B[i] + B_rf[i, k]
Parameters
----------
loop : LoopRV
The loop outside block for which we want to do rfactor
factor_axis : int
The position where the new dimension is placed in the new introduced rfactor buffer
Returns
-------
rf_block : BlockRV
The block which computes partial results over each slices (i.e., the first block
as described in the above illustration)
Examples
--------
Before rfactor, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_rfactor(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128, 128), "float32")
B = tir.match_buffer(b, (128,), "float32")
with tir.block([128, tir.reduce_axis(0, 128),
tir.reduce_axis(0, 128)], "B") as [vii, vi, vj]:
with tir.init():
B[vii] = 0.0
B[vii] = B[vii] + A[vii, vi, vj]
Create the schedule and do rfactor:
.. code-block:: python
sch = tir.Schedule(before_rfactor)
_, _, k = sch.get_loops(sch.get_block("B"))
sch.rfactor(k, 0)
print(tvm.script.asscript(sch.mod["main"]))
After applying rfactor, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_rfactor(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128, 128])
B = tir.match_buffer(b, [128])
B_rf = tir.alloc_buffer([128, 128])
with tir.block([128, 128, tir.reduce_axis(0, 128)], "B_rf") as [vi2, vii, vi]:
with tir.init():
B_rf[vi2, vii] = 0.0
B_rf[vi2, vii] = (B_rf[vi2, vii] + A[vii, vi, vi2])
with tir.block([128, tir.reduce_axis(0, 128)], "B") as [vii_1, vi2_1]:
with tir.init():
B[vii_1] = 0.0
B[vii_1] = (B[vii_1] + B_rf[vi2_1, vii_1])
Note
----
Rfactor requires:
1) `loop` has only one child block, and it is a reduction block;
2) `loop` is a reduction loop, i.e. the loop variable is bound to only reduction variables
in the block binding;
3) `loop` is not parallelized, vectorized, unrolled or bound to any thread axis;
4) The block scope that `loop` is in is a staged-pipeline;
5) The outermost loop outside the reduction block should has the reduction block as its first child block;
6) The outermost reduction loop should have only one child block;
7) An unary extent loop that is not bound to any reduction or data parallel variables in the block binding
should not appear under some reduction loop;
8) The reduction block should write to only one buffer, and its init and body block only is
a simple `BufferStore`, and the pattern is registered as associative reducer.
The pre-defined patterns include: plus, multiplication, min and max;
9) Each of the loops on top of the block cannot be bound to a data parallel and a reduction
block binding at the same time;
10) `factor_axis` should be in range `[-ndim(B) - 1, ndim(B)]`,
where `B` is the buffer that the reduction block writes to.
Negative indexing is normalized according to numpy convention.
"""
return _ffi_api_schedule.ScheduleRFactor(self, loop, factor_axis) # pylint: disable=no-member

########## Schedule: blockize & tensorize ##########


Expand Down
6 changes: 4 additions & 2 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1137,8 +1137,10 @@ PrimExpr CanonicalSimplifier::Impl::SimplifyReduceCombiner(const ReduceNode* op)
// and recursively mark the corresponding components
for (size_t i = 0; i < simplified_result.size(); ++i)
if (!used[i]) {
if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
ExprUseVar(simplified_result[idx], op->combiner->rhs[i]))
if (UsesVar(simplified_result[idx],
[v = op->combiner->lhs[i].get()](const VarNode* var) { return var == v; }) ||
UsesVar(simplified_result[idx],
[v = op->combiner->rhs[i].get()](const VarNode* var) { return var == v; }))
mark_used(i);
}
};
Expand Down
4 changes: 2 additions & 2 deletions src/arith/detect_linear_equation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class LinearEqDetector : public ExprFunctor<LinearEqEntry(const PrimExpr&, const
}
LinearEqEntry VisitExprDefault_(const Object* op, const PrimExpr& e) final {
if (fail_) return LinearEqEntry();
if (ExprUseVar(e, var_)) {
if (UsesVar(e, [this](const VarNode* var) { return var == var_.get(); })) {
fail_ = true;
return LinearEqEntry();
} else {
Expand Down Expand Up @@ -159,7 +159,7 @@ Array<PrimExpr> DetectLinearEquation(const PrimExpr& e, const Array<Var>& vars)
for (size_t i = vars.size(); i > 1; --i) {
vset.insert(vars[i - 1].get());
// The previous coeff contains the variable
if (ExprUseVar(coeff[i - 2], vset_contains)) {
if (UsesVar(coeff[i - 2], vset_contains)) {
return Array<PrimExpr>();
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/te/autodiff/ad_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ std::pair<PrimExpr, PrimExpr> ImplicationNotContainingVars(
return {pair_a.first || pair_b.first, (pair_a.first || pair_b.second) &&
(pair_b.first || pair_a.second) &&
(pair_a.second || pair_b.second)};
} else if (!tir::ExprUseVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
} else if (!tir::UsesVar(cond, [&vars](const VarNode* var) { return vars.count(var); })) {
return {cond, const_true()};
} else {
return {const_true(), cond};
Expand Down Expand Up @@ -1014,7 +1014,7 @@ PrimExpr TrySimplifyCompute(const PrimExpr& expr, const PrimExpr& cond,
// Keep only those variables of the new vars which are used in the new_expr
Array<Var> used_res_variables;
for (const Var& var : res->dst->variables) {
if (ExprUseVar(new_expr, var)) {
if (tir::UsesVar(new_expr, [&var](const VarNode* var_) { return var_ == var.get(); })) {
ICHECK(res->dst->ranges.count(var)) << "Range of " << var << " cannot be inferred.";
used_res_variables.push_back(var);
}
Expand Down
2 changes: 1 addition & 1 deletion src/te/operation/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ Stmt TransformUpdate(const Stage& stage, const std::unordered_map<IterVar, Range
auto fbanned = [&](const VarNode* node) { return banned.count(node); };

for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::UsesVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize update transform failed, the condition " << pred
<< " has a conflict with the reset condition";
}
Expand Down
4 changes: 2 additions & 2 deletions src/te/operation/tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,13 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self, const Stage& stage,
auto fbanned = [&](const VarNode* node) { return banned.count(node); };

for (const PrimExpr& pred : n.main_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::UsesVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize failed, split condition " << pred
<< " relies on var defined inside tensorize scope";
}
}
for (const PrimExpr& pred : n.init_predicates) {
if (tir::ExprUseVar(pred, fbanned)) {
if (tir::UsesVar(pred, fbanned)) {
LOG(FATAL) << "Tensorize failed, split condition " << pred
<< " relies on var defined inside tensorize scope";
}
Expand Down
30 changes: 23 additions & 7 deletions src/tir/analysis/var_touch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,33 @@
* \brief Implementation of simple passes
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>

namespace tvm {
namespace tir {

class VarTouchVisitor : public ExprVisitor {
class VarTouchVisitor : public StmtExprVisitor {
public:
explicit VarTouchVisitor(std::function<bool(const VarNode*)> var_set) : var_set_(var_set) {}
explicit VarTouchVisitor(std::function<bool(const VarNode*)> var_set)
: var_set_(std::move(var_set)) {}

void VisitStmt(const Stmt& stmt) final {
if (use_var_) return;
StmtExprVisitor::VisitStmt(stmt);
}

void VisitExpr(const PrimExpr& e) final {
if (use_var_) return;
ExprVisitor::VisitExpr(e);
StmtExprVisitor::VisitExpr(e);
}

void VisitExpr_(const VarNode* op) final { Handle(op); }

void VisitStmt_(const StoreNode* op) final {
Handle(op->buffer_var.get());
StmtVisitor::VisitStmt_(op);
}

void VisitExpr_(const LoadNode* op) final {
Handle(op->buffer_var.get());
ExprVisitor::VisitExpr_(op);
Expand All @@ -54,9 +64,15 @@ class VarTouchVisitor : public ExprVisitor {
std::function<bool(const VarNode*)> var_set_;
};

bool ExprUseVar(const PrimExpr& e, std::function<bool(const VarNode*)> var_set) {
VarTouchVisitor visitor(var_set);
visitor(e);
bool UsesVar(const Stmt& stmt, std::function<bool(const VarNode*)> var_set) {
VarTouchVisitor visitor(std::move(var_set));
visitor(stmt);
return visitor.use_var_;
}

bool UsesVar(const PrimExpr& expr, std::function<bool(const VarNode*)> var_set) {
VarTouchVisitor visitor(std::move(var_set));
visitor(expr);
return visitor.use_var_;
}

Expand Down
Loading

0 comments on commit e3ab7fd

Please sign in to comment.