Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Schedule] Support for specific consumer block targeting in cache_read #12505

Merged
merged 9 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,12 @@ class ScheduleNode : public runtime::Object {
* \param block_rv The consumer block of the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope.
 * \param consumer_blocks An optional list of consumer blocks that should read from the cache;
 *        if empty, all consumers of the buffer are rewritten to use the cache.
* \return The cache stage block.
*/
virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) = 0;
const String& storage_scope,
const Array<BlockRV> consumer_blocks = {}) = 0;
jwfromm marked this conversation as resolved.
Show resolved Hide resolved
/*!
* \brief Create a block that writes a buffer region into a write cache. It requires:
* 1) There is only one block who writes the target buffer.
Expand Down
17 changes: 15 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1012,7 +1012,11 @@ def after_unroll(a: T.handle, b: T.handle) -> None:

@type_checked
def cache_read(
self, block: Union[BlockRV, str], read_buffer_index: int, storage_scope: str
self,
block: Union[BlockRV, str],
read_buffer_index: int,
storage_scope: str,
consumer_blocks: Optional[List[Union[BlockRV, str]]] = None,
) -> BlockRV:
"""Create a block that reads a buffer region into a read cache. It requires:

Expand All @@ -1031,6 +1035,10 @@ def cache_read(
storage_scope: str
The target storage scope.

consumer_blocks: Optional[List[Union[BlockRV, str]]]
jwfromm marked this conversation as resolved.
Show resolved Hide resolved
An optional list of consumers that should read from the cache. If not specified,
all consumers will use the cache.

Returns
-------
cached_block : BlockRV
Expand Down Expand Up @@ -1079,9 +1087,14 @@ def after_cache_read(a: T.handle, b: T.handle) -> None:
B[vi, vj] = A_local[vi, vj] * 2.0

"""
if consumer_blocks is None:
consumer_blocks = []

# Convert any string block names into Block RVs.
consumer_blocks = [self._normalize_block_arg(b) for b in consumer_blocks]
block = self._normalize_block_arg(block)
return _ffi_api.ScheduleCacheRead( # type: ignore # pylint: disable=no-member
self, block, read_buffer_index, storage_scope
self, block, read_buffer_index, storage_scope, consumer_blocks
)

@type_checked
Expand Down
11 changes: 9 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -535,10 +535,17 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) {
/******** Schedule: Insert cache stages ********/

BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) {
const String& storage_scope,
const Array<BlockRV> consumer_blocks) {
StmtSRef result{nullptr};
// Create a new array of SRefs from the consumer block list.
Array<StmtSRef> consumer_block_refs = {};
for (BlockRV block : consumer_blocks) {
consumer_block_refs.push_back(this->GetSRef(block));
}
TVM_TIR_SCHEDULE_BEGIN();
result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope);
result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope,
consumer_block_refs);
TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ class ConcreteScheduleNode : public ScheduleNode {
void Bind(const LoopRV& loop_rv, const String& thread_axis) override;
void Unroll(const LoopRV& loop_rv) override;
/******** Schedule: Insert cache stages ********/
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) override;
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope,
const Array<BlockRV> consumer_blocks = {}) override;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,11 @@ TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref);
* \param block_sref The consumer block of the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope.
* \param consumer_blocks Array of blocks that consume the cache.
* \return The cache stage block.
*/
TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
const String& storage_scope);
const String& storage_scope, const Array<StmtSRef> consumer_blocks = {});
/*!
* \brief Create a block that writes a buffer region into a write cache. It requires:
* 1) There is only one block that writes the target buffer.
Expand Down
61 changes: 44 additions & 17 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ struct CacheStageInfo {
Stmt cache_stage;
/*! \brief The map used for ScheduleStateNode::Replace. */
Map<Block, Block> block_reuse;
/*! \brief A list of blocks that will consume the new cache. */
Array<StmtSRef> consumer_blocks;
};

/*! \brief Return the buffer region related to the buffer */
Expand Down Expand Up @@ -525,7 +527,20 @@ class CacheReadRewriter : public StmtExprMutator {

Stmt VisitStmt_(const BlockNode* block) final {
Block old_stmt = GetRef<Block>(block);
// We don't mutate the block which generates info->read_buffer
// Check if this block is one of the specified consumers.
// If no consumer blocks are specified, all blocks should be considered consumers.
bool is_consumer = info_->consumer_blocks.empty();
// Otherwise check if this is one of the specified blocks.
for (StmtSRef consumer_sref : info_->consumer_blocks) {
const BlockNode* consumer_node = TVM_SREF_TO_BLOCK(consumer_node, consumer_sref);
Block consumer_block = GetRef<Block>(consumer_node);
if (old_stmt.same_as(consumer_block)) {
is_consumer = true;
}
}
// Keep track of this block's consumer status. We'll use this when rewriting loads.
current_block_consumes = is_consumer;
// We don't mutate the block which generates info->read_buffer.
if (block != scope_sref_->stmt &&
GetBufferRegionFromBuffer(block->writes, info_->read_buffer).defined()) {
return std::move(old_stmt);
Expand All @@ -547,23 +562,26 @@ class CacheReadRewriter : public StmtExprMutator {
stmt = Block(n);
} else {
// Otherwise, update read regions and match_buffers
Array<BufferRegion> reads =
ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer);
Array<MatchBufferRegion> match_buffers =
ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer);
if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
n->reads = std::move(reads);
n->match_buffers = std::move(match_buffers);
stmt = Block(n);
// Only make this change if the block is one of the specified consumers.
if (is_consumer) {
Array<BufferRegion> reads =
ReplaceBuffer(block->reads, info_->read_buffer, info_->write_buffer);
Array<MatchBufferRegion> match_buffers =
ReplaceBuffer(block->match_buffers, info_->read_buffer, info_->write_buffer);
if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
n->reads = std::move(reads);
n->match_buffers = std::move(match_buffers);
stmt = Block(n);
}
}
}
info_->block_reuse.Set(old_stmt, stmt);
return std::move(stmt);
}

PrimExpr VisitExpr_(const BufferLoadNode* load) final {
if (load->buffer.same_as(info_->read_buffer)) {
if (load->buffer.same_as(info_->read_buffer) && current_block_consumes) {
ObjectPtr<BufferLoadNode> n = make_object<BufferLoadNode>(*load);
n->buffer = info_->write_buffer;
return PrimExpr(n);
Expand All @@ -588,6 +606,8 @@ class CacheReadRewriter : public StmtExprMutator {
const StmtSRef& scope_sref_;
/*! \brief The info for inserting cache stage */
CacheStageInfo* info_;
/*! \brief Whether the most recently visited block is a specified consumer. */
bool current_block_consumes;
};

/*! \brief Mutator for CacheWrite */
Expand Down Expand Up @@ -963,7 +983,7 @@ class ReIndexRewriter : public StmtExprMutator {
/******** Implementation ********/

StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
const String& storage_scope) {
const String& storage_scope, const Array<StmtSRef> consumer_blocks) {
/*!
* Check:
* - The index is in the array of block reading region
Expand Down Expand Up @@ -992,6 +1012,8 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
info.write_buffer = WithScope(read_buffer, storage_scope);
// Create the corresponding buffer allocation
info.alloc = info.write_buffer;
// Indicate which blocks should consume the cache.
info.consumer_blocks = consumer_blocks;

// Step 3. Update cache stage info.
BufferRegion cache_region{nullptr};
Expand Down Expand Up @@ -1170,21 +1192,26 @@ struct CacheReadTraits : public UnpackedInstTraits<CacheReadTraits> {
static constexpr bool kIsPure = false;

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

/*!
 * \brief Replay a traced CacheRead instruction on a schedule.
 * \param sch The schedule to apply the instruction to.
 * \param block The consumer block whose read buffer is cached.
 * \param consumer_blocks Blocks that should read from the new cache; empty means all consumers.
 * \param read_buffer_index Index of the target buffer in `block`'s read region.
 * \param storage_scope Storage scope of the cache buffer (e.g. "global", "shared").
 * \return The block random variable referring to the newly created cache stage.
 */
static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block,
                                       Array<BlockRV> consumer_blocks, Integer read_buffer_index,
                                       String storage_scope) {
  return sch->CacheRead(block, read_buffer_index->value, storage_scope, consumer_blocks);
}

/*!
 * \brief Render a traced CacheRead instruction as a Python `sch.cache_read(...)` call.
 * \param outputs Names of the output random variables produced by the instruction.
 * \param block Name of the consumer block random variable.
 * \param consumer_blocks Names of the targeted consumer blocks; may be empty.
 * \param read_buffer_index Index of the buffer in the block's read region.
 * \param storage_scope The target storage scope string.
 * \return The Python source string for this instruction.
 */
static String UnpackedAsPython(Array<String> outputs, String block, Array<String> consumer_blocks,
                               Integer read_buffer_index, String storage_scope) {
  PythonAPICall py("cache_read");
  py.Input("block", block);
  py.Input("read_buffer_index", read_buffer_index->value);
  py.Input("storage_scope", storage_scope);
  // Only write out consumer blocks if provided, so traces without the new
  // argument round-trip unchanged.
  if (!consumer_blocks.empty()) {
    py.Input("consumer_blocks", consumer_blocks);
  }
  py.SingleOutput(outputs);
  return py.Str();
}
Expand Down
3 changes: 3 additions & 0 deletions src/tir/schedule/trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ Array<ObjectRef> TranslateInputRVs(const Array<ObjectRef>& inputs,
<< "TypeError: Expect 'tir.Var', but gets: " << dst->GetTypeKey();
return GetRef<Var>(static_cast<const VarNode*>(dst));
}));
} else if (input->IsInstance<ArrayNode>()) {
// Recursively convert elements of the array into a new list of ObjectRefs.
result.push_back(TranslateInputRVs(Downcast<Array<ObjectRef>>(input), rv_map));
} else {
ICHECK(false) << "TypeError: Cannot recognize the type of an input random variable: "
<< input->GetTypeKey();
Expand Down
8 changes: 5 additions & 3 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,14 @@ void TracedScheduleNode::Unroll(const LoopRV& loop_rv) {

/******** Schedule: Insert cache stages ********/
BlockRV TracedScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) {
BlockRV result = ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope);
const String& storage_scope,
const Array<BlockRV> consumer_blocks) {
BlockRV result =
ConcreteScheduleNode::CacheRead(block_rv, read_buffer_index, storage_scope, consumer_blocks);

static const InstructionKind& kind = InstructionKind::Get("CacheRead");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv},
/*inputs=*/{block_rv, consumer_blocks},
/*attrs=*/{Integer(read_buffer_index), storage_scope},
/*outputs=*/{result}));
return result;
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
void Bind(const LoopRV& loop_rv, const String& thread_axis) final;
void Unroll(const LoopRV& loop_rv) final;
/******** Schedule: Insert cache stages ********/
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) final;
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index, const String& storage_scope,
const Array<BlockRV> consumer_blocks = {}) final;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) final;
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
Expand Down
42 changes: 42 additions & 0 deletions tests/python/unittest/test_tir_schedule_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,32 @@ def cache_read_multi_consumer() -> None:
C[vi] = A_global[vi]


# Expected IR after `sch.cache_read(block_b, 0, "global", consumer_blocks=[block_c])`
# on func_multi_consumer: a cache stage copies A into A_global, block "C" reads
# the cache (A_global), while the non-targeted block "B" still reads A directly.
@T.prim_func
def cache_read_multi_consumer_target() -> None:
    A = T.alloc_buffer((128))
    B = T.alloc_buffer((128))
    C = T.alloc_buffer((128))
    A_global = T.alloc_buffer((128))
    for i in T.grid(8):
        for j in T.grid(16):
            with T.block("A"):
                vi = T.axis.S(128, i * 16 + j)
                A[vi] = 1.0
        # Cache stage inserted by cache_read: copies A into A_global.
        for j in T.grid(16):
            with T.block("A"):
                vi = T.axis.S(128, i * 16 + j)
                A_global[vi] = A[vi]
        # Block "B" was not listed in consumer_blocks, so it keeps reading A.
        for j in T.grid(16):
            with T.block("B"):
                vi = T.axis.S(128, i * 16 + j)
                B[vi] = A[vi] + 1.0

    # Block "C" is the targeted consumer: it reads the cache buffer A_global.
    for i in T.grid(128):
        with T.block("C"):
            vi = T.axis.S(128, i)
            C[vi] = A_global[vi]


@T.prim_func
def continuous_cache_read(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (128, 128))
Expand Down Expand Up @@ -783,6 +809,22 @@ def test_cache_read_location(use_block_name):
tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)

# Test that specific consumer block targeting works.
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
block_b = "B" if use_block_name else sch.get_block("B")
block_c = "C" if use_block_name else sch.get_block("C")
sch.cache_read(block_b, 0, "global", consumer_blocks=[block_c])
tvm.ir.assert_structural_equal(cache_read_multi_consumer_target, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)

# Also test setting multiple consumers yields same result as unspecified.
sch = tir.Schedule(func_multi_consumer, debug_mask="all")
block_b = "B" if use_block_name else sch.get_block("B")
block_c = "C" if use_block_name else sch.get_block("C")
sch.cache_read(block_b, 0, "global", consumer_blocks=[block_b, block_c])
tvm.ir.assert_structural_equal(cache_read_multi_consumer, sch.mod["main"])
verify_trace_roundtrip(sch=sch, mod=func_multi_consumer)


def test_continuous_cache_read(use_block_name):
sch = tir.Schedule(elementwise, debug_mask="all")
Expand Down