Skip to content

Commit

Permalink
Add a 'rolling_buffer' scheduling primitive (apache#9444)
Browse files Browse the repository at this point in the history
* Add a 'rolling_buffer' scheduling primitive

Co-authored-by: Matthew Barrett <[email protected]>

* Fix lint problems

Change-Id: I5e27a66105fccca84327e41b4c68836ac2515126

* Remove designated initializers

Change-Id: Ic148264239eac7df7d976a6a3e15236935232792

Co-authored-by: Matthew Barrett <[email protected]>
  • Loading branch information
2 people authored and yangulei committed Jan 11, 2022
1 parent 1e5c772 commit d279d80
Show file tree
Hide file tree
Showing 10 changed files with 627 additions and 2 deletions.
7 changes: 7 additions & 0 deletions include/tvm/te/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,11 @@ class Stage : public ObjectRef {
* \return reference to self.
*/
TVM_DLL Stage& double_buffer(); // NOLINT(*)
/*!
* \brief Compute current stage with rolling buffering.
* \return reference to self.
*/
TVM_DLL Stage& rolling_buffer(); // NOLINT(*)
/*!
* \brief whether the stage has been scheduled.
* \return whether the stage has been scheduled.
Expand Down Expand Up @@ -493,6 +498,8 @@ class StageNode : public Object {
bool is_output{false};
/*! \brief Whether apply double buffer optimization to this stage */
bool double_buffer{false};
/*! \brief Whether apply rolling buffer optimization to this stage */
bool rolling_buffer{false};
/*!
* \brief The parent group of the current stage.
* The stage cannot be assigned to stages outside the group.
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,8 @@ constexpr const char* double_buffer_scope = "double_buffer_scope";
* \brief Marks region used by double buffer write
*/
constexpr const char* double_buffer_write = "double_buffer_write";
/*! \brief Mark realization for rolling buffer optimization */
constexpr const char* rolling_buffer_scope = "rolling_buffer_scope";
/*! \brief Mark of scan update scope */
constexpr const char* scan_update_scope = "scan_update_scope";
/*! \brief Mark of scan init scope */
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/te/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,14 @@ def double_buffer(self):
"""
_ffi_api.StageDoubleBuffer(self)

def rolling_buffer(self):
"""Compute the current stage via rolling buffering.
This can only be applied to intermediate stage.
This will change the storage cost of the current stage.
"""
_ffi_api.StageRollingBuffer(self)


@tvm._ffi.register_object
class SpecializedCondition(Object):
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,17 @@ def InjectDoubleBuffer():
return _ffi_api.InjectDoubleBuffer() # type: ignore


def InjectRollingBuffer():
"""Inject rolling buffer statements.
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InjectRollingBuffer() # type: ignore


def StorageRewrite():
"""Rewrite storage allocation pattern.
Expand Down
3 changes: 2 additions & 1 deletion src/te/operation/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,8 @@ ComputeLoopNest ComputeLoopNest::Create(const BaseComputeOpNode* self, const Sta
}
ret.init_nest = MakeLoopNest(stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap),
debug_keep_trivial_loop);
ret.init_predicates = MakeBoundCheck(stage, dom_map, ret.init_vmap, true, skip_iter);
ret.init_predicates =
MakeBoundCheck(stage, dom_map, ret.init_vmap, !stage->rolling_buffer, skip_iter);
for (auto& e : ret.init_predicates) {
e = likely(e);
}
Expand Down
9 changes: 9 additions & 0 deletions src/te/schedule/schedule_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,13 @@ Stage& Stage::double_buffer() {
return *this;
}

Stage& Stage::rolling_buffer() {
StageNode* self = operator->();
ICHECK(!self->is_output) << "Cannot apply rolling buffer on output";
self->rolling_buffer = true;
return *this;
}

Stage CopyStage(const Stage& s) {
ObjectPtr<StageNode> n = make_object<StageNode>(*s.operator->());
return Stage(n);
Expand Down Expand Up @@ -886,6 +893,8 @@ TVM_REGISTER_GLOBAL("te.StageStorageAlign").set_body_method(&Stage::storage_alig

TVM_REGISTER_GLOBAL("te.StageDoubleBuffer").set_body_method(&Stage::double_buffer);

TVM_REGISTER_GLOBAL("te.StageRollingBuffer").set_body_method(&Stage::rolling_buffer);

TVM_REGISTER_GLOBAL("te.ScheduleNormalize").set_body_method(&Schedule::normalize);

TVM_REGISTER_GLOBAL("te.ScheduleCreateGroup").set_body_method(&Schedule::create_group);
Expand Down
4 changes: 4 additions & 0 deletions src/te/schedule/schedule_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_
pipeline = SeqStmt({producer, consumer});
}

if (s->rolling_buffer) {
pipeline = AttrStmt(s->op, tir::attr::rolling_buffer_scope, Bool(true), pipeline);
}

return s->op->BuildRealize(s, dom_map, pipeline, s->scope);
}

Expand Down
3 changes: 2 additions & 1 deletion src/te/schedule/schedule_postproc_to_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ class TensorToBufferMapper : public StmtExprMutator {
Stmt VisitStmt_(const AttrStmtNode* op) final {
auto ret = StmtExprMutator::VisitStmt_(op);
op = ret.as<AttrStmtNode>();
if (op->attr_key == tir::attr::double_buffer_scope) {
if (op->attr_key == tir::attr::double_buffer_scope ||
op->attr_key == tir::attr::rolling_buffer_scope) {
Stmt body = op->body;
Operation operation = Downcast<Operation>(op->node);
for (int i = operation->num_outputs(); i != 0; --i) {
Expand Down
Loading

0 comments on commit d279d80

Please sign in to comment.