[TIR][Schedule] Support for specific consumer block targeting in cache_write (apache#13510)

Add optional consumer blocks to cache_write.
wrongtest-intellif authored and fzi-peccia committed Mar 27, 2023
1 parent 9a3bec8 commit 3008e78
Showing 9 changed files with 198 additions and 23 deletions.
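
Before the per-file diff, a minimal sketch of the new parameter in use. It mirrors the usage exercised in test_cache_write_location below; func_multi_consumer is the test fixture in which block "A" writes a buffer that blocks "B" and "C" both read.

    from tvm import tir

    # Sketch only, assuming `func_multi_consumer` from the test file below.
    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
    block_a = sch.get_block("A")
    block_b = sch.get_block("B")

    # Only "B" reads from the new cache stage; "C" keeps reading the original
    # buffer. Omitting consumer_blocks preserves the old behavior: every
    # consumer reads the original buffer.
    sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b])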
4 changes: 3 additions & 1 deletion include/tvm/tir/schedule/schedule.h
@@ -399,10 +399,12 @@ class ScheduleNode : public runtime::Object {
* \param block_rv The producer of the buffer
* \param write_buffer_index The index of the buffer in block's write region
* \param storage_scope The target storage scope
+   * \param consumer_blocks An optional list of consumers to read from cache directly.
* \return The cache stage block.
*/
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
-                             const String& storage_scope) = 0;
+                             const String& storage_scope,
+                             const Array<BlockRV> consumer_blocks = {}) = 0;
/*!
* \brief Create 2 blocks that read&write a buffer region into a read/write cache.
   * It requires that the target block both reads & writes the target buffer.
11 changes: 10 additions & 1 deletion python/tvm/tir/schedule/schedule.py
@@ -1110,6 +1110,7 @@ def cache_write(
block: Union[BlockRV, str],
write_buffer_index: Union[int, str, Buffer],
storage_scope: str,
+        consumer_blocks=None,
) -> BlockRV:
"""Create a block that reads a buffer region into a write cache. It requires:
@@ -1130,6 +1131,9 @@
storage_scope: str
The target storage scope.
+        consumer_blocks: Optional[List[Union[BlockRV, str]]]
+            An optional list of consumers that should read directly from the cache.
+            If not specified, all consumers will read from the original buffer.
Returns
-------
@@ -1179,14 +1183,19 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
B[vi, vj] = B_local[vi, vj]
"""
+        if consumer_blocks is None:
+            consumer_blocks = []
+
+        # Convert any string block names into Block RVs.
+        consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks]
block = self._normalize_block_arg(block)

if not isinstance(write_buffer_index, int):
_, write_buffer_index, _ = self._normalize_buffer_arg(
block, write_buffer_index, required_buffer_type="write"
)
return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member
-            self, block, write_buffer_index, storage_scope
+            self, block, write_buffer_index, storage_scope, consumer_blocks
)

@type_checked
11 changes: 9 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
@@ -552,10 +552,17 @@ BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer
}

BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
-                                           const String& storage_scope) {
+                                           const String& storage_scope,
+                                           const Array<BlockRV> consumer_blocks) {
StmtSRef result{nullptr};
+  // Create a new array of SRefs from the consumer block list.
+  Array<StmtSRef> consumer_block_refs = {};
+  for (BlockRV block : consumer_blocks) {
+    consumer_block_refs.push_back(this->GetSRef(block));
+  }
TVM_TIR_SCHEDULE_BEGIN();
-  result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope);
+  result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope,
+                           consumer_block_refs);
TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
4 changes: 2 additions & 2 deletions src/tir/schedule/concrete_schedule.h
@@ -114,8 +114,8 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Schedule: Insert cache stages ********/
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope,
const Array<BlockRV> consumer_blocks = {}) override;
-  BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
-                     const String& storage_scope) override;
+  BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope,
+                     const Array<BlockRV> consumer_blocks = {}) override;
Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) override;
Array<BlockRV> CacheIndex(const BlockRV& block_rv, int write_buffer_index) override;
4 changes: 3 additions & 1 deletion src/tir/schedule/primitive.h
@@ -263,10 +263,12 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
* \param block_sref The producer of the buffer
* \param write_buffer_index The index of the buffer in block's write region
* \param storage_scope The target storage scope
+ * \param consumer_blocks Array of blocks that consume the cache.
* \return The cache stage block.
*/
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
-                            const String& storage_scope);
+                            const String& storage_scope,
+                            const Array<StmtSRef> consumer_blocks = {});
/*!
* \brief Create 2 blocks that read&write a buffer region into a read/write cache.
72 changes: 61 additions & 11 deletions src/tir/schedule/primitive/cache_read_write.cc
@@ -382,21 +382,34 @@ class CacheLocDetector : public StmtVisitor {
 * writer block of the buffer being applied cache_read or cache_write
 * \param scope_sref The sref of the scope block of the cached block
 * \param info The cache stage info.
 */
+  template <bool is_cache_read>
static void Detect(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_sref, CacheStageInfo* info) {
std::vector<StmtSRef> related_blocks;
// If consumer is specified, skip detecting the others
-    if (info->consumer_blocks.size() > 0) {
-      for (StmtSRef consumer : info->consumer_blocks) {
-        related_blocks.emplace_back(consumer);
+    if (is_cache_read) {
+      if (info->consumer_blocks.size() > 0) {
+        for (StmtSRef consumer : info->consumer_blocks) {
+          related_blocks.emplace_back(consumer);
+        }
+      } else {
+        for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) {
+          if (def->kind == DepKind::kRAW) {
+            related_blocks.push_back(def->dst);
+          }
+        }
+      }
} else {
for (const Dependency& def : self->GetBlockScope(scope_sref)->GetDepsBySrc(block_sref)) {
if (def->kind == DepKind::kRAW) {
+        if (info->consumer_blocks.count(def->dst)) {
+          continue;
+        }
related_blocks.push_back(def->dst);
}
}
}

if (!related_blocks.empty()) {
CacheLocDetector detector(self, block_sref, scope_sref, related_blocks);
detector(GetRef<Stmt>(scope_sref->stmt));
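
Note the asymmetry introduced by the is_cache_read flag: for cache_read, an explicit consumer list becomes the set of related blocks used to locate the cache stage, while for cache_write the listed consumers are excluded, since they will read the cache buffer directly; only the remaining readers of the original buffer constrain where the write-back stage must be placed. A sketch of the observable effect (fixture and block names are from the tests below):

    from tvm import tir

    # When "C" consumes the cache, "B" still reads the original buffer, so the
    # write-back block "A_global" is inserted inside the outer i loop, ahead of
    # "B" (cf. cache_write_multi_consumer_C_consume_cache in the tests below).
    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
    sch.cache_write(sch.get_block("A"), 0, "global",
                    consumer_blocks=[sch.get_block("C")])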
@@ -739,6 +752,30 @@ class CacheWriteRewriter : public StmtExprMutator {

Stmt VisitStmt_(const BlockNode* block) final {
Block old_stmt = GetRef<Block>(block);

+    // Check if this block is one of the specified cache consumers.
+    // If so, update its reads of the original buffer to the cache.
+    for (StmtSRef consumer_sref : info_->consumer_blocks) {
+      const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_sref);
+      Block consumer_block = GetRef<Block>(consumer_node);
+      if (old_stmt.same_as(consumer_block)) {
+        Array<BufferRegion> reads =
+            ReplaceBuffer(block->reads, info_->write_buffer, info_->read_buffer);
+        Array<MatchBufferRegion> match_buffers =
+            ReplaceBuffer(block->match_buffers, info_->write_buffer, info_->read_buffer);
+        if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
+          auto n = CopyOnWrite(block);
+          n->reads = std::move(reads);
+          n->match_buffers = std::move(match_buffers);
+          n->body = VisitStmt(block->body);
+          Block new_consumer = Block(n);
+          info_->block_reuse.Set(old_stmt, new_consumer);
+          return std::move(new_consumer);
+        }
+        return std::move(old_stmt);
+      }
+    }
+
// We only mutate the block which generates info->write_buffer
if (block != writer_block_sref_->stmt && block != scope_sref_->stmt && !under_writer_block_) {
return std::move(old_stmt);
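
The loop above retargets a listed consumer's reads (and match_buffers) from the original buffer (info_->write_buffer) to the cache buffer (info_->read_buffer), recording the rewrite in block_reuse so sref replacement stays consistent. End to end, this is what the tests verify; a sketch with the same fixture names:

    import tvm
    from tvm import tir

    # After caching with consumer_blocks=[block "B"], "B" reads A_global (the
    # cache) while "C" still reads A; the expected IR is
    # cache_write_multi_consumer_B_consume_cache in the test file below.
    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
    sch.cache_write(sch.get_block("A"), 0, "global",
                    consumer_blocks=[sch.get_block("B")])
    tvm.ir.assert_structural_equal(
        cache_write_multi_consumer_B_consume_cache, sch.mod["main"])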
@@ -1160,7 +1197,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
StmtSRef parent_sref = GetRef<StmtSRef>(write_block_sref->parent);

// Detect insert position
-    CacheLocDetector::Detect(self, write_block_sref, scope_sref, &info);
+    CacheLocDetector::Detect</*is_cache_read=*/true>(self, write_block_sref, scope_sref, &info);
cache_region = RelaxBufferRegion(self, region, write_block_sref, parent_sref, info.loc_sref);
} else {
// Case 2. The buffer is the input block for the scope.
@@ -1190,7 +1227,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
}

StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
-                   const String& storage_scope) {
+                   const String& storage_scope, const Array<StmtSRef> consumer_blocks) {
/*!
* Check:
* - The index is in the array of block reading region
@@ -1219,14 +1256,22 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
// Create the corresponding buffer allocation
info.alloc = info.read_buffer;

+  // info.consumer_blocks indicates which blocks should consume the cache.
+  for (auto consumer : consumer_blocks) {
+    info.consumer_blocks.insert(consumer);
+    for (auto child : tir::GetChildBlocks(self, consumer)) {
+      info.consumer_blocks.insert(child);
+    }
+  }
+
// Step 3. Check the only writer block.
ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get());

// Step 4. Find the producing region and insert position
BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value();
StmtSRef parent_sref = GetRef<StmtSRef>(block_sref->parent);
// Detect insert position
-  CacheLocDetector::Detect(self, block_sref, scope_sref, &info);
+  CacheLocDetector::Detect</*is_cache_read=*/false>(self, block_sref, scope_sref, &info);
BufferRegion cache_region =
RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref);

@@ -1468,21 +1513,26 @@ struct CacheWriteTraits : public UnpackedInstTraits<CacheWriteTraits> {
static constexpr bool kIsPure = false;

private:
-  static constexpr size_t kNumInputs = 1;
+  static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

-  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer write_buffer_index,
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block,
+                                         Array<BlockRV> consumer_blocks, Integer write_buffer_index,
String storage_scope) {
-    return sch->CacheWrite(block, write_buffer_index->value, storage_scope);
+    return sch->CacheWrite(block, write_buffer_index->value, storage_scope, consumer_blocks);
}

-  static String UnpackedAsPython(Array<String> outputs, String block, Integer write_buffer_index,
-                                 String storage_scope) {
+  static String UnpackedAsPython(Array<String> outputs, String block, Array<String> consumer_blocks,
+                                 Integer write_buffer_index, String storage_scope) {
PythonAPICall py("cache_write");
py.Input("block", block);
py.Input("write_buffer_index", write_buffer_index->value);
py.Input("storage_scope", storage_scope);
+    // Only write out consumer blocks if provided.
+    if (!consumer_blocks.empty()) {
+      py.Input("consumer_blocks", consumer_blocks);
+    }
py.SingleOutput(outputs);
return py.Str();
}
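
A note on the traits change: consumer_blocks becomes the second instruction input (hence kNumInputs = 2) so traced schedules serialize it, and the Python printer emits the argument only when the list is non-empty, which keeps reprinted traces from existing schedules unchanged. A traced call with one targeted consumer prints roughly as follows (register names b0/b1/b2 are illustrative):

    # Output of UnpackedAsPython for a cache_write with one consumer block;
    # actual register names depend on the trace.
    b2 = sch.cache_write(block=b0, write_buffer_index=0,
                         storage_scope="global", consumer_blocks=[b1])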
8 changes: 5 additions & 3 deletions src/tir/schedule/traced_schedule.cc
@@ -296,12 +296,14 @@ BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_i
}

BlockRV TracedScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
-                                       const String& storage_scope) {
-  BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope);
+                                       const String& storage_scope,
+                                       const Array<BlockRV> consumer_blocks) {
+  BlockRV result = ConcreteScheduleNode::CacheWrite(block_rv, write_buffer_index, storage_scope,
+                                                    consumer_blocks);

static const InstructionKind& kind = InstructionKind::Get("CacheWrite");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
-                                      /*inputs=*/{block_rv},
+                                      /*inputs=*/{block_rv, consumer_blocks},
/*attrs=*/{Integer(write_buffer_index), storage_scope},
/*outputs=*/{result}));
return result;
4 changes: 2 additions & 2 deletions src/tir/schedule/traced_schedule.h
@@ -74,8 +74,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
/******** Schedule: Insert cache stages ********/
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope,
const Array<BlockRV> consumer_blocks = {}) final;
-  BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
-                     const String& storage_scope) final;
+  BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index, const String& storage_scope,
+                     const Array<BlockRV> consumer_blocks = {}) final;
Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) final;
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
103 changes: 103 additions & 0 deletions tests/python/unittest/test_tir_schedule_cache_read_write.py
@@ -858,6 +858,81 @@ def cache_write_multi_consumer() -> None:
C[vi] = A[vi]


+@T.prim_func
+def cache_write_multi_consumer_B_consume_cache():
+    A = T.alloc_buffer([128], dtype="float32")
+    B = T.alloc_buffer([128], dtype="float32")
+    C = T.alloc_buffer([128], dtype="float32")
+    A_global = T.alloc_buffer([128], dtype="float32")
+    for i in T.serial(8):
+        for j in T.serial(16):
+            with T.block("A"):
+                vi = T.axis.spatial(128, i * 16 + j)
+                A_global[vi] = 1.0
+        for j in T.serial(16):
+            with T.block("B"):
+                vi = T.axis.spatial(128, i * 16 + j)
+                B[vi] = A_global[vi] + 1.0
+    for ax0 in T.serial(128):
+        with T.block("A_global"):
+            v0 = T.axis.spatial(128, ax0)
+            A[v0] = A_global[v0]
+    for i in T.serial(128):
+        with T.block("C"):
+            vi = T.axis.spatial(128, i)
+            C[vi] = A[vi]
+
+
+@T.prim_func
+def cache_write_multi_consumer_C_consume_cache():
+    A = T.alloc_buffer([128], dtype="float32")
+    B = T.alloc_buffer([128], dtype="float32")
+    C = T.alloc_buffer([128], dtype="float32")
+    A_global = T.alloc_buffer([128], dtype="float32")
+    for i in T.serial(8):
+        for j in T.serial(16):
+            with T.block("A"):
+                vi = T.axis.spatial(128, i * 16 + j)
+                A_global[vi] = T.float32(1)
+        for ax0 in T.serial(16):
+            with T.block("A_global"):
+                v0 = T.axis.spatial(128, i * 16 + ax0)
+                A[v0] = A_global[v0]
+        for j in T.serial(16):
+            with T.block("B"):
+                vi = T.axis.spatial(128, i * 16 + j)
+                B[vi] = A[vi] + T.float32(1)
+    for i in T.serial(128):
+        with T.block("C"):
+            vi = T.axis.spatial(128, i)
+            C[vi] = A_global[vi]
+
+
+@T.prim_func
+def cache_write_multi_consumer_all_consume_cache():
+    A = T.alloc_buffer([128], dtype="float32")
+    B = T.alloc_buffer([128], dtype="float32")
+    C = T.alloc_buffer([128], dtype="float32")
+    A_global = T.alloc_buffer([128], dtype="float32")
+    for i in T.serial(8):
+        for j in T.serial(16):
+            with T.block("A"):
+                vi = T.axis.spatial(128, i * 16 + j)
+                A_global[vi] = T.float32(1)
+        for j in T.serial(16):
+            with T.block("B"):
+                vi = T.axis.spatial(128, i * 16 + j)
+                B[vi] = A_global[vi] + T.float32(1)
+    for i in T.serial(128):
+        with T.block("C"):
+            vi = T.axis.spatial(128, i)
+            C[vi] = A_global[vi]
+    for ax0 in T.serial(128):
+        with T.block("A_global"):
+            v0 = T.axis.spatial(128, ax0)
+            A[v0] = A_global[v0]
+
+
@T.prim_func
def continuous_cache_write(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
@@ -1113,6 +1188,34 @@ def test_cache_write_location(use_block_name):
tvm.ir.assert_structural_equal(cache_write_multi_consumer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)

+    # Test that specific consumer block targeting works.
+    # B reads the cache buffer and C reads the original output buffer.
+    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
+    block_a = "A" if use_block_name else sch.get_block("A")
+    block_b = "B" if use_block_name else sch.get_block("B")
+    sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b])
+    tvm.ir.assert_structural_equal(cache_write_multi_consumer_B_consume_cache, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
+
+    # Test that specific consumer block targeting works.
+    # B reads the original output buffer and C reads the cache buffer.
+    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
+    block_a = "A" if use_block_name else sch.get_block("A")
+    block_c = "C" if use_block_name else sch.get_block("C")
+    sch.cache_write(block_a, 0, "global", consumer_blocks=[block_c])
+    tvm.ir.assert_structural_equal(cache_write_multi_consumer_C_consume_cache, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
+
+    # Test that specific consumer block targeting works.
+    # Both B and C read the cache buffer.
+    sch = tir.Schedule(func_multi_consumer, debug_mask="all")
+    block_a = "A" if use_block_name else sch.get_block("A")
+    block_b = "B" if use_block_name else sch.get_block("B")
+    block_c = "C" if use_block_name else sch.get_block("C")
+    sch.cache_write(block_a, 0, "global", consumer_blocks=[block_b, block_c])
+    tvm.ir.assert_structural_equal(cache_write_multi_consumer_all_consume_cache, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)
+

def test_continuous_cache_write(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
