Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[USMP] Add performance characteristics to PoolInfo #10005

Merged
merged 2 commits into from
Feb 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 63 additions & 25 deletions include/tvm/tir/usmp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,70 +44,108 @@ constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm";
namespace tir {
namespace usmp {

/*!
* \brief The string parameter to indicate read and write access to a pool
* This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
* python/tvm/tir/usmp/utils.py
*/
static constexpr const char* kTargetPoolReadWriteAccess = "rw";
/*!
* \brief The string parameter to indicate read only access to a pool
* This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
* python/tvm/tir/usmp/utils.py
*/
static constexpr const char* kTargetPoolReadOnlyAccess = "ro";

/*!
* \brief Describes a pool of memory accessible by one or more targets.
*/
struct PoolInfoNode : public Object {
/*! \brief The name of the memory pool */
String pool_name;
/*! \brief The expected size hint to be used by the allocator.
* The size_hint_bytes is defaulted to kUnrestrictedPoolSizeHint
* The size_hint_bytes is set to kUnrestrictedPoolSizeHint
* to indicate the pool is not size restricted.
*/
Integer size_hint_bytes;
/*! \brief The accessibility from each Target*/
/*! \brief The accessibility from each Target */
Map<Target, String> target_access; // 'rw' or 'ro'
/*! \brief The clock frequency of the memory in Hz */
Integer clock_frequency_hz;
/*! \brief The read bandwidth in bytes/cycle */
Integer read_bandwidth_bytes_per_cycle;
/*! \brief The write bandwidth in bytes/cycle */
Integer write_bandwidth_bytes_per_cycle;
/*! \brief The read latency in cycles */
Integer read_latency_cycles;
/*! \brief The write latency in cycles */
Integer write_latency_cycles;
/*! \brief The burst length in bytes for each Target */
Map<Target, Integer> target_burst_bytes;
/*! \brief Whether pool is internally generated.
* The internal pools will be generated as part of
* the entry point code generation of the executor*/
* the entry point code generation of the executor
*/
bool is_internal = false;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pool_name", &pool_name);
v->Visit("size_hint_bytes", &size_hint_bytes);
v->Visit("target_access", &target_access);
v->Visit("clock_frequency_hz", &clock_frequency_hz);
v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle);
v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle);
v->Visit("read_latency_cycles", &read_latency_cycles);
v->Visit("write_latency_cycles", &write_latency_cycles);
v->Visit("target_burst_bytes", &target_burst_bytes);
v->Visit("is_internal", &is_internal);
}

bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const {
return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) &&
equal(target_access, other->target_access) && equal(is_internal, other->is_internal);
equal(target_access, other->target_access) &&
equal(target_access, other->target_access) &&
equal(clock_frequency_hz, other->clock_frequency_hz) &&
equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) &&
equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) &&
equal(read_latency_cycles, other->read_latency_cycles) &&
equal(write_latency_cycles, other->write_latency_cycles) &&
equal(target_burst_bytes, other->target_burst_bytes) &&
equal(is_internal, other->is_internal);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(pool_name);
hash_reduce(size_hint_bytes);
hash_reduce(target_access);
hash_reduce(clock_frequency_hz);
hash_reduce(read_bandwidth_bytes_per_cycle);
hash_reduce(write_bandwidth_bytes_per_cycle);
hash_reduce(read_latency_cycles);
hash_reduce(write_latency_cycles);
hash_reduce(target_burst_bytes);
hash_reduce(is_internal);
}

static constexpr const char* _type_key = "tir.usmp.PoolInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
};

/*!
* \brief The PoolSize is unrestricted for the memory planner
*/
static const int kUnrestrictedPoolSizeHint = -1;

class PoolInfo : public ObjectRef {
public:
TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access,
Integer size_hint_bytes = kUnrestrictedPoolSizeHint,
Bool is_internal = Bool(false));
/*!
* \brief The string parameter to indicate read and write access to a pool
* This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
* python/tvm/tir/usmp/utils.py
*/
static constexpr const char* kTargetPoolReadWriteAccess = "rw";
/*!
* \brief The string parameter to indicate read only access to a pool
* This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
* python/tvm/tir/usmp/utils.py
*/
static constexpr const char* kTargetPoolReadOnlyAccess = "ro";
/*! \brief The PoolSize is unrestricted for the memory planner */
static const int kUnrestrictedPoolSizeHint = -1;
/*! \brief The clock frequency is not known */
static const int kUnknownClockFrequency = -1;
/*! \brief The read bandwidth is not known */
static const int kUnknownReadBandwidth = -1;
/*! \brief The write bandwidth is not known */
static const int kUnknownWriteBandwidth = -1;

TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access, Integer size_hint_bytes,
Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle,
Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles,
Integer write_latency_cycles, Map<Target, Integer> target_burst_bytes,
Bool is_internal);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode);
};

Expand Down
47 changes: 45 additions & 2 deletions python/tvm/tir/usmp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""USMP Utilities and Data Structures"""
# pylint: disable=invalid-name

from typing import Dict, Optional, List
from typing import Dict, Optional, List, Union

from tvm._ffi import register_object
from tvm.runtime import Object
Expand Down Expand Up @@ -52,6 +52,34 @@ class PoolInfo(Object):
The default value would be -1 which means the pool
is not size restricted.

clock_frequency_hz : Optional[int]
The clock frequency that the memory pool runs at in Hz.
If not specified/known, this will default to -1 indicating
it hasn't been defined.

read_bandwidth_bytes_per_cycle : Optional[int]
The read bandwidth of the memory pool in bytes/cycle.
If not specified/known, this will default to -1 indicating
it hasn't been defined.

write_bandwidth_bytes_per_cycle : Optional[int]
The write bandwidth of the memory pool in bytes/cycle.
If not specified/known, this will default to -1 indicating
it hasn't been defined.

read_latency_cycles : Optional[int]
The read latency of the memory pool in cycles.
If not specified/known, this will default to 0.

write_latency_cycles : Optional[int]
The write latency of the memory pool in cycles.
If not specified/known, this will default to 0.

target_burst_bytes : Optional[Union[Dict[Target, int], None]]
The burst length of the memory pool in bytes per target.
If not specified/known for a given target, a burst length
of 1 byte will be assumed.

"""

# The string parameter to indicate read and write access to a pool
Expand All @@ -67,13 +95,28 @@ def __init__(
self,
pool_name: str,
target_access: Dict[Target, str],
size_hint_bytes: Optional[int] = None,
size_hint_bytes: Optional[int] = -1,
clock_frequency_hz: Optional[int] = -1,
read_bandwidth_bytes_per_cycle: Optional[int] = -1,
write_bandwidth_bytes_per_cycle: Optional[int] = -1,
read_latency_cycles: Optional[int] = 0,
write_latency_cycles: Optional[int] = 0,
target_burst_bytes: Optional[Union[Dict[Target, int], None]] = None,
):
if not target_burst_bytes:
target_burst_bytes = dict()

self.__init_handle_by_constructor__(
_ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member
pool_name,
target_access,
size_hint_bytes,
clock_frequency_hz,
read_bandwidth_bytes_per_cycle,
write_bandwidth_bytes_per_cycle,
read_latency_cycles,
write_latency_cycles,
target_burst_bytes,
)


Expand Down
2 changes: 1 addition & 1 deletion src/tir/usmp/algo/greedy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ size_t GreedyBase::round_up_to_byte_alignment(const size_t& non_aligned_byte_off
*/
bool GreedyBase::IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
const size_t& size_bytes) {
if (candidate_pool->size_hint_bytes == -1) {
if (candidate_pool->size_hint_bytes == PoolInfo::kUnrestrictedPoolSizeHint) {
// this means pool is not bounded
return true;
}
Expand Down
8 changes: 5 additions & 3 deletions src/tir/usmp/transform/assign_pool_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ class PoolInfoAssigner : public StmtExprMutator {
ICHECK(target_host) << "main function does not have a target attr";
Array<usmp::PoolInfo> pool_infos =
module->GetAttr<Array<usmp::PoolInfo>>(tvm::attr::kPoolInfoIRModuleAttr)
.value_or({usmp::PoolInfo("global_workspace",
{{target_host.value(), usmp::kTargetPoolReadWriteAccess}},
usmp::kUnrestrictedPoolSizeHint, Bool(true))});
.value_or({usmp::PoolInfo(
"global_workspace", {{target_host.value(), PoolInfo::kTargetPoolReadWriteAccess}},
PoolInfo::kUnrestrictedPoolSizeHint, PoolInfo::kUnknownClockFrequency,
PoolInfo::kUnknownReadBandwidth, PoolInfo::kUnknownWriteBandwidth, 0, 0,
{{target_host.value(), 1}}, Bool(true))});
for (const usmp::PoolInfo& pool_info : pool_infos) {
for (const auto& kv : pool_info->target_access) {
Target tgt = kv.first;
Expand Down
32 changes: 24 additions & 8 deletions src/tir/usmp/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,31 +93,47 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});

PoolInfo::PoolInfo(String pool_name, Map<Target, String> target_access, Integer size_hint_bytes,
Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle,
Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles,
Integer write_latency_cycles, Map<Target, Integer> target_burst_bytes,
Bool is_internal) {
auto poolinfo_node = make_object<PoolInfoNode>();
poolinfo_node->pool_name = pool_name;
poolinfo_node->size_hint_bytes = size_hint_bytes;
poolinfo_node->target_access = target_access;
poolinfo_node->clock_frequency_hz = clock_frequency_hz;
poolinfo_node->read_bandwidth_bytes_per_cycle = read_bandwidth_bytes_per_cycle;
poolinfo_node->write_bandwidth_bytes_per_cycle = write_bandwidth_bytes_per_cycle;
poolinfo_node->read_latency_cycles = read_latency_cycles;
poolinfo_node->write_latency_cycles = write_latency_cycles;
poolinfo_node->target_burst_bytes = target_burst_bytes;
poolinfo_node->is_internal = is_internal;
data_ = std::move(poolinfo_node);
}

TVM_REGISTER_NODE_TYPE(PoolInfoNode);
TVM_REGISTER_GLOBAL("tir.usmp.PoolInfo")
.set_body_typed([](String pool_name, Map<Target, String> target_access,
Integer size_hint_bytes) {
if (size_hint_bytes.defined()) {
return PoolInfo(pool_name, target_access, size_hint_bytes);
}
return PoolInfo(pool_name, target_access);
.set_body_typed([](String pool_name, Map<Target, String> target_access, Integer size_hint_bytes,
Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle,
Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles,
Integer write_latency_cycles, Map<Target, Integer> target_burst_bytes) {
return PoolInfo(pool_name, target_access, size_hint_bytes, clock_frequency_hz,
read_bandwidth_bytes_per_cycle, write_bandwidth_bytes_per_cycle,
read_latency_cycles, write_latency_cycles, target_burst_bytes, Bool(false));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PoolInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const PoolInfoNode*>(ref.get());
p->stream << "PoolInfoNode(\n"
<< "pool_name=" << node->pool_name << ",\n target_access=" << node->target_access
<< ",\n size_hint_bytes=" << node->size_hint_bytes << ")";
<< " pool_name=" << node->pool_name << ",\n target_access=" << node->target_access
<< ",\n size_hint_bytes=" << node->size_hint_bytes
<< ",\n clock_frequency_hz=" << node->clock_frequency_hz
<< ",\n read_bandwidth_bytes_per_cycle=" << node->read_bandwidth_bytes_per_cycle
<< ",\n write_bandwidth_bytes_per_cycle=" << node->write_bandwidth_bytes_per_cycle
<< ",\n read_latency_cycles=" << node->read_latency_cycles
<< ",\n write_latency_cycles=" << node->write_latency_cycles
<< ",\n target_burst_bytes=" << node->target_burst_bytes << ")";
});

PoolAllocation::PoolAllocation(PoolInfo pool_info, Integer byte_offset) {
Expand Down