Skip to content

Commit

Permalink
[TIR][Schedule] enhance compute_at primitive to choose proper position
Browse files Browse the repository at this point in the history
  • Loading branch information
yincs-intellif committed Aug 16, 2022
1 parent c477c76 commit 8e7382c
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 28 deletions.
4 changes: 2 additions & 2 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,8 @@ class ScheduleNode : public runtime::Object {
* \param loop_rv The loop where the block to be moved under
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
*/
virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) = 0;
virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
bool to_early_stage = false) = 0;
/*!
* \brief Move a consumer block under the specific loop, and regenerate the
* loops induced by the block so that the buffer region consumed by the consumer block could
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,7 @@ def compute_at(
block: Union[BlockRV, str],
loop: LoopRV,
preserve_unit_loops: bool = False,
to_early_stage: bool = False,
) -> None:
"""Compute-At. Move a producer block under the specific loop, and regenerate the
loops induced by the block so that the buffer region produced by the producer block could
Expand Down Expand Up @@ -1290,6 +1291,9 @@ def compute_at(
preserve_unit_loops: bool
Whether to keep the trivial loops whose extents are 1
to_early_stage: bool
Choose to closed to or away from it's consumer
Examples
--------
Expand Down Expand Up @@ -1347,6 +1351,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None:
block,
loop,
preserve_unit_loops,
to_early_stage,
)

@type_checked
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
/******** Schedule: Compute location ********/

void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
bool preserve_unit_loops, bool to_early_stage) {
static StmtSRef inline_mark = StmtSRef::InlineMark();
static StmtSRef root_mark = StmtSRef::RootMark();
StmtSRef loop_sref = this->GetSRef(loop_rv);
Expand All @@ -579,7 +579,7 @@ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
} else {
TVM_TIR_SCHEDULE_BEGIN();
tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops);
tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, to_early_stage);
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
}
this->state_->DebugVerify();
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ class ConcreteScheduleNode : public ScheduleNode {
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) override;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
bool to_early_stage = false) override;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) override;
void ComputeInline(const BlockRV& block) override;
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buf
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
*/
TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops);
bool preserve_unit_loops, bool to_early_stage = false);
/*!
* \brief Move a consumer block under the specific loop, and regenerate the
* loops induced by the block so that the buffer region consumed by the consumer block could
Expand Down
78 changes: 63 additions & 15 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,38 @@ class NotInSameScopeError : public ScheduleError {

/******** Helper Functions/Classes ********/

Stmt GetBlock(Stmt stmt) {
class Finder : public StmtVisitor {
public:
void VisitStmt_(const BlockRealizeNode* realize) final { blk = realize->block; }
Stmt blk;
};
Finder finder;
finder(stmt);
return finder.blk;
}

/*!
* \brief Find a point where the block can be inserted under the loop
* \tparam require_all_producers_visited Requires all producer blocks to be present under the loop
* \tparam require_all_consumers_visited Requires all consumer blocks to be present under the loop
* \param self The schedule state
* \param scope The scope root block BlockScope
* \param subtrees The subtrees under the loop, among which the insertion points are sought
* \param producer_srefs The producer blocks
* \param consumer_srefs The consumer blocks
* \param block2realize A cache that maps a block to its realize
* \param to_early_stage closed to or away from it's consumer
* \return The last position the new block can be inserted onto, and the
* producer-consumer-relationship is still satisfied.
* \throws ScheduleError if there is no such insertion point found
*/
template <bool require_all_producers_visited, bool require_all_consumers_visited>
int FindInsertionPoint(
const ScheduleState& self, const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs,
const Array<StmtSRef>& consumer_srefs,
std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) {
int FindInsertionPoint(const ScheduleState& self, const BlockScope scope,
const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs,
const Array<StmtSRef>& consumer_srefs,
std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize,
bool to_early_stage) {
ProducerConsumerSplit split =
ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize);
// Step 1. Check if all the producers are visited in the subtrees, if required to
Expand All @@ -160,7 +174,37 @@ int FindInsertionPoint(
// The valid indices are: (last_producer_position, first_consumer_position]
ICHECK(split.last_producer_position < split.first_consumer_position);
// Step 4. Return the last valid insertion point
return split.first_consumer_position;
int insert_position = split.first_consumer_position;
if (require_all_consumers_visited && to_early_stage) {
class Finder : public StmtVisitor {
public:
void VisitStmt_(const BlockRealizeNode* realize) final {
const BlockNode* block = realize->block.get();
if (producer_blocks_.count(block)) {
++this->n_producers_visited_;
}
}

std::unordered_set<const StmtNode*> producer_blocks_;
int n_producers_visited_ = 0;
};
// adjust the inserted position by compute at order
for (int i = split.first_consumer_position; i - 1 > split.last_producer_position; --i) {
auto blk = GetBlock(subtrees[i]);
if (!blk.defined()) break;
auto block_sref = self->stmt2ref.at(blk.get());
Array<StmtSRef> block_producer_srefs = GetProducers(block_sref, scope);
Finder finder;
finder.producer_blocks_.reserve(block_producer_srefs.size());
for (const StmtSRef& block_sref_ : block_producer_srefs) {
finder.producer_blocks_.insert(block_sref_->stmt);
}
finder(subtrees[i - 1]);
if (finder.n_producers_visited_ == 0) break;
insert_position = i - 1;
}
}
return insert_position;
}

/*!
Expand Down Expand Up @@ -556,7 +600,8 @@ void CalculateProvidedRequiredRegions(
template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops,
arith::Analyzer* analyzer, bool check_only = false) {
arith::Analyzer* analyzer, bool check_only = false,
bool to_early_stage = false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
// Step 1. Bunch of checks
Expand Down Expand Up @@ -585,10 +630,11 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize;
block2realize.reserve(self->block_info.size());
int insert_position = FindInsertionPoint<!is_compute_at, is_compute_at>(
/*self=*/self,
/*self=*/self, /*scope=*/scope,
/*subtrees=*/AsArray(loop->body),
/*producer_srefs=*/producer_srefs,
/*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize);
/*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize,
/*to_early_stage*/ to_early_stage);
// Step 4. Calculate the region provided by a single execution instance of `block`,
// as well as the region required by dependent blocks under `loop`.
// Here is the definition of `provide` and `require`:
Expand Down Expand Up @@ -626,10 +672,10 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
}

void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
bool preserve_unit_loops, bool to_early_stage) {
arith::Analyzer analyzer;
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer);
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops, &analyzer,
false, to_early_stage);
}

void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
Expand Down Expand Up @@ -671,20 +717,22 @@ struct ComputeAtTraits : public UnpackedInstTraits<ComputeAtTraits> {

private:
static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
Bool preserve_unit_loops) {
return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
Bool preserve_unit_loops, Bool to_early_stage) {
return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(),
to_early_stage.operator bool());
}

static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
Bool preserve_unit_loops) {
Bool preserve_unit_loops, Bool to_early_stage) {
PythonAPICall py("compute_at");
py.Input("block", block_rv);
py.Input("loop", loop_rv);
py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
py.Input("to_early_stage", to_early_stage.operator bool());
return py.Str();
}

Expand Down
13 changes: 7 additions & 6 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,15 @@ BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
/******** Schedule: Compute location ********/

void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops);
bool preserve_unit_loops, bool to_early_stage) {
ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops, to_early_stage);

static const InstructionKind& kind = InstructionKind::Get("ComputeAt");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
/*attrs=*/{Integer(preserve_unit_loops)},
/*outputs=*/{}));
trace_->Append(
/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
/*attrs=*/{Integer(preserve_unit_loops), Integer(to_early_stage)},
/*outputs=*/{}));
}

void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) final;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final;
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
bool to_early_stage = false) final;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) final;
void ComputeInline(const BlockRV& block_rv) final;
Expand Down
89 changes: 89 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,5 +1353,94 @@ def _create_prim_func():
verify_trace_roundtrip(sch=sch, mod=mod)


def test_compute_at_to_early_stage():
@T.prim_func
def multi_producers_conv(
data: T.Buffer[(1, 3, 224, 224), "int8"],
w: T.Buffer[(16, 3, 7, 7), "int8"],
conv: T.Buffer[(1, 16, 112, 112), "int32"],
) -> None:
pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8")
wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8")
for i0, i1, i2, i3 in T.grid(1, 3, 230, 230):
with T.block("pad"):
i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3])
T.writes(pad[i0_1, i1_1, i2_1, i3_1])
pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227,
data[i0_1, i1_1, i2_1 - 3, i3_1 - 3],
T.int8(0),
dtype="int8",
)
for i0 in T.serial(1):
for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7):
with T.block("wbuf"):
v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(w[v0, v1, v2, v3])
T.writes(wbuf[v0, v1, v2, v3])
wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3]
for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7):
with T.block("conv"):
nn, ff, yy, xx, rc, ry, rx = T.axis.remap(
"SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]
)
T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx])
T.writes(conv[nn, ff, yy, xx])
with T.init():
conv[nn, ff, yy, xx] = 0
conv[nn, ff, yy, xx] = conv[nn, ff, yy, xx] + T.cast(
pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32"
) * T.cast(wbuf[ff, rc, ry, rx], "int32")

@T.prim_func
def multi_producers_after_compute_at(
data: T.Buffer[(1, 3, 224, 224), "int8"],
w: T.Buffer[(16, 3, 7, 7), "int8"],
conv: T.Buffer[(1, 16, 112, 112), "int32"],
) -> None:
pad = T.alloc_buffer([1, 3, 230, 230], dtype="int8")
wbuf = T.alloc_buffer([16, 3, 7, 7], dtype="int8")
for i0 in T.serial(1):
for ax0, ax1, ax2 in T.grid(3, 229, 229):
with T.block("pad"):
i0_1 = T.axis.spatial(1, 0)
i1_1 = T.axis.spatial(3, ax0)
i2_1 = T.axis.spatial(230, ax1)
i3_1 = T.axis.spatial(230, ax2)
T.reads(data[i0_1, i1_1, i2_1 - 3, i3_1 - 3])
T.writes(pad[i0_1, i1_1, i2_1, i3_1])
pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
3 <= i2_1 and i2_1 < 227 and 3 <= i3_1 and i3_1 < 227,
data[i0_1, i1_1, i2_1 - 3, i3_1 - 3],
T.int8(0),
dtype="int8",
)
for ax0, ax1, ax2, ax3 in T.grid(16, 3, 7, 7):
with T.block("wbuf"):
v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(w[v0, v1, v2, v3])
T.writes(wbuf[v0, v1, v2, v3])
wbuf[v0, v1, v2, v3] = w[v0, v1, v2, v3]
for i1, i2, i3, i4, i5, i6 in T.grid(16, 112, 112, 3, 7, 7):
with T.block("conv"):
nn, ff, yy, xx, rc, ry, rx = T.axis.remap(
"SSSSRRR", [i0, i1, i2, i3, i4, i5, i6]
)
T.reads(pad[nn, rc, yy * 2 + ry, xx * 2 + rx], wbuf[ff, rc, ry, rx])
T.writes(conv[nn, ff, yy, xx])
with T.init():
conv[nn, ff, yy, xx] = 0
conv[nn, ff, yy, xx] = conv[nn, ff, yy, xx] + T.cast(
pad[nn, rc, yy * 2 + ry, xx * 2 + rx], "int32"
) * T.cast(wbuf[ff, rc, ry, rx], "int32")

sch = tir.Schedule(multi_producers_conv, debug_mask="all")
block_c = sch.get_block("pad")
axis = sch.get_loops("conv")[0]
sch.compute_at(block_c, axis, to_early_stage=True)
tvm.ir.assert_structural_equal(multi_producers_after_compute_at, sch.mod["main"])


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 8e7382c

Please sign in to comment.