Skip to content

Commit

Permalink
[MetaSchedule] Introducing MemHammer (#14164)
Browse files Browse the repository at this point in the history
Introducing MemHammer

Co-authored-by: Wuwei Lin <[email protected]>
Co-authored-by: Junru Shao <[email protected]>
Co-authored-by: Siyuan Feng <[email protected]>
Co-authored-by: Ruihang Lai <[email protected]>
Co-authored-by: Bohan Hou <[email protected]>
Co-authored-by: Hongyi Jin <[email protected]>
  • Loading branch information
7 people authored Mar 20, 2023
1 parent 0627684 commit 36b3097
Show file tree
Hide file tree
Showing 21 changed files with 3,687 additions and 0 deletions.
5 changes: 5 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,11 @@ class ScheduleNode : public runtime::Object {
*/
virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) = 0;
/******** Schedule: Data movement ********/
virtual BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) = 0;
virtual BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) = 0;
/******** Schedule: Compute location ********/
/*!
* \brief Move a producer block under the specific loop, and regenerate the
Expand Down
31 changes: 31 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1609,6 +1609,37 @@ constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layo
*/
constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";

/*!
* \brief Mark that the block need to add predicate for block var bounds during lowering
*/
constexpr const char* require_block_var_bound_predicate = "require_bound_predicate";

/*! \brief Mark that tensor core is enabled in the PrimExpr */
constexpr const char* meta_schedule_tensor_core_enabled = "meta_schedule.tensor_core_enabled";

/*!
* \brief Mark a block as generated by cache_read or cache_write block.
* 0 means cache_read; 1 means cache_write.
* \sa meta_schedule_cache_type_read
* \sa meta_schedule_cache_type_write
*/
constexpr const char* meta_schedule_cache_type = "meta_schedule.cache_type";

/*! \sa meta_schedule_cache_type */
constexpr const int meta_schedule_cache_type_read = 0;

/*! \sa meta_schedule_cache_type */
constexpr const int meta_schedule_cache_type_write = 1;

/*! \brief Mark auto copy for memhammer */
constexpr const char* auto_copy = "auto_copy";

/*! \brief Mark local stage constraint on data copy */
constexpr const char* local_stage = "local_stage";

/*! \brief Mark vectorization length constraint on block */
constexpr const char* vector_bytes = "vector_bytes";

/*!
* \brief Mark that a block is executed by a warp. This implies the extend of threadIdx.x is
* warp size.
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,12 @@ TVM_DLL Pass BindParams(const Array<runtime::NDArray>& constants);
*/
TVM_DLL Pass ExtractPrimFuncConstants();

/*!
* \brief Automatically do memory optimizations for auto copy blocks
* \return The pass.
*/
TVM_DLL Pass LowerAutoCopy();

/*!
* \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
* \return The pass.
Expand Down
24 changes: 24 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,30 @@ def after_reindex(
self, block, buffer_index, buffer_index_type_enum
)

########## Schedule: Data movement ##########

def read_at(
self,
loop: LoopRV,
block: BlockRV,
read_buffer_index: int,
storage_scope: str,
) -> BlockRV:
return _ffi_api.ScheduleReadAt( # type: ignore # pylint: disable=no-member
self, loop, block, read_buffer_index, storage_scope
)

def write_at(
self,
loop: LoopRV,
block: BlockRV,
write_buffer_index: int,
storage_scope: str,
) -> BlockRV:
return _ffi_api.ScheduleWriteAt( # type: ignore # pylint: disable=no-member
self, loop, block, write_buffer_index, storage_scope
)

########## Schedule: Compute location ##########

@type_checked
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,17 @@ def ExtractPrimFuncConstants():
return _ffi_api.ExtractPrimFuncConstants() # type: ignore


def LowerAutoCopy():
"""Automatically do memory optimizations for auto copy blocks
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerAutoCopy() # type: ignore


def RenormalizeSplitPattern():
"""Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerAutoCopy());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
Expand Down
1 change: 1 addition & 0 deletions src/meta_schedule/feature_extractor/per_store_feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ Sequential PassListForPerStoreFeature() {
tir::transform::ConvertBlocksToOpaque(),
tir::transform::UnifyThreadBinding(),
tir::transform::CompactBufferAllocation(),
tir::transform::LowerAutoCopy(),
tir::transform::LowerMatchBuffer(),
tir::transform::Simplify(),
});
Expand Down
1 change: 1 addition & 0 deletions src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerAutoCopy());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
Expand Down
24 changes: 24 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,30 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
return CreateRV<BlockRV>(result);
}

/******** Schedule: Data movement ********/

BlockRV ConcreteScheduleNode::ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv,
int read_buffer_index, const String& storage_scope) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::ReadAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), read_buffer_index,
storage_scope);
TVM_TIR_SCHEDULE_END("read-at", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

BlockRV ConcreteScheduleNode::WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv,
int write_buffer_index, const String& storage_scope) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::WriteAt(state_, this->GetSRef(loop_rv), this->GetSRef(block_rv), write_buffer_index,
storage_scope);
TVM_TIR_SCHEDULE_END("write-at", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

/******** Schedule: Compute location ********/

void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
Expand Down
5 changes: 5 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ class ConcreteScheduleNode : public ScheduleNode {
int cse_thresh) override;
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) override;
/******** Schedule: Data movement ********/
BlockRV ReadAt(const LoopRV& loop_rv, const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) override;
BlockRV WriteAt(const LoopRV& loop_rv, const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
int index = -1) override;
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,15 @@ TVM_DLL Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sre
*/
TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
BufferIndexType buffer_index_type);

/******** Schedule: Data movement ********/

TVM_DLL StmtSRef ReadAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref,
int read_buffer_index, const String& storage_scope);

TVM_DLL StmtSRef WriteAt(ScheduleState self, const StmtSRef& loop_sref, const StmtSRef& block_sref,
int write_buffer_index, const String& storage_scope);

/******** Schedule: Compute location ********/
/*!
* \brief Move a producer block under the specific loop, and regenerate the
Expand Down
Loading

0 comments on commit 36b3097

Please sign in to comment.