Skip to content

Commit

Permalink
Port StorageInfo and StaticMemoryPlan data structure (#8297)
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch authored Jun 23, 2021
1 parent 35d71b1 commit 7e7d7fb
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
19 changes: 19 additions & 0 deletions src/relay/backend/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,25 @@ namespace tvm {
namespace relay {
namespace backend {

TVM_REGISTER_NODE_TYPE(StorageInfoNode);

StorageInfo::StorageInfo(std::vector<int64_t> storage_ids, std::vector<DLDeviceType> device_types,
std::vector<int64_t> storage_sizes_in_bytes) {
auto n = make_object<StorageInfoNode>();
n->storage_ids = std::move(storage_ids);
n->device_types = std::move(device_types);
n->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(StaticMemoryPlanNode);

StaticMemoryPlan::StaticMemoryPlan(Map<Expr, StorageInfo> expr_to_storage_info) {
auto n = make_object<StaticMemoryPlanNode>();
n->expr_to_storage_info = std::move(expr_to_storage_info);
data_ = std::move(n);
}

int64_t CalculateRelayExprSizeBytes(const Type& expr_type) {
if (expr_type->IsInstance<TupleTypeNode>()) {
auto tuple_type = Downcast<TupleType>(expr_type);
Expand Down
47 changes: 47 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,53 @@ namespace tvm {
namespace relay {
namespace backend {

/*!
* \brief The static storage information produced by memory planning.
*/
class StorageInfoNode : public Object {
public:
/*! \brief The set of storage ids where the expression is stored. */
std::vector<int64_t> storage_ids;
/* \brief The type of "virtual devices" these expressions are stored on. */
std::vector<DLDeviceType> device_types;
/* \brief The sizes of each storage element. */
std::vector<int64_t> storage_sizes_in_bytes;

// TODO(@jroesch): expose the fields
void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "relay.StorageInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(StorageInfoNode, Object);
};

/*! \brief The storage information for a single expression. */
class StorageInfo : public ObjectRef {
public:
StorageInfo(std::vector<int64_t> storage_ids, std::vector<DLDeviceType> device_types,
std::vector<int64_t> storage_sizes_in_bytes);
TVM_DEFINE_OBJECT_REF_METHODS(StorageInfo, ObjectRef, StorageInfoNode);
};

/*!
* \brief The result of static memory planning.
*/
class StaticMemoryPlanNode : public Object {
public:
Map<Expr, StorageInfo> expr_to_storage_info;

void VisitAttrs(AttrVisitor* v) { v->Visit("expr_to_storage_info", &expr_to_storage_info); }

static constexpr const char* _type_key = "relay.StaticMemoryPlan";
TVM_DECLARE_FINAL_OBJECT_INFO(StaticMemoryPlanNode, Object);
};

/*! \brief The result of running static memory planning. */
class StaticMemoryPlan : public ObjectRef {
public:
explicit StaticMemoryPlan(Map<Expr, StorageInfo> expr_to_storage_info);
TVM_DEFINE_OBJECT_REF_METHODS(StaticMemoryPlan, ObjectRef, StaticMemoryPlanNode);
};

struct FunctionInfoNode : public Object {
Map<Target, Integer> workspace_sizes;
Map<Target, Integer> io_sizes;
Expand Down

0 comments on commit 7e7d7fb

Please sign in to comment.