Skip to content

Commit

Permalink
[TIR][USMP] Augmenting the algo interface with memory pressure (#9649)
Browse files Browse the repository at this point in the history
This commit adds memory pressue to be an arugment to
the USMP algorithm interface as certain iterative algorithms
could use this as a guide determine the termination
criteria.

Change-Id: I3fb5eea3fe5ba43e68c23625d411e557f6dd89a3
  • Loading branch information
manupak authored Dec 7, 2021
1 parent a3e03a3 commit cb132e2
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 45 deletions.
39 changes: 39 additions & 0 deletions include/tvm/tir/usmp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,45 @@ class BufferInfo : public ObjectRef {
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfo, ObjectRef, BufferInfoNode);
};

/*!
* \brief This is a composite node that is produced by extract_buffer_info
* analysis pass that contains useful global information that could be useful
* for memory planning algorithms.
*/
struct BufferInfoAnalysisNode : public Object {
/*! \brief The BufferInfo object and its associated TIR statement */
Map<BufferInfo, tir::Stmt> buffer_info_stmts;
/*! \brief This represent maximum amount of memory being used at
* any point of time in the inference. This value is largely the
* best allocation an algorithm could achieve. Due to
* the complexities of conflict graphs, it would not be feasible
* to achieve this value, practically. However, it can be useful
* for iterative algorithms to know this value to define termination
* criteria.*/
Integer memory_pressure;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("buffer_info_stmts", &buffer_info_stmts);
v->Visit("memory_pressure", &memory_pressure);
}

bool SEqualReduce(const BufferInfoAnalysisNode* other, SEqualReducer equal) const {
return equal(buffer_info_stmts, other->buffer_info_stmts) &&
equal(memory_pressure, other->memory_pressure);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer_info_stmts);
hash_reduce(memory_pressure);
}
};

class BufferInfoAnalysis : public ObjectRef {
public:
TVM_DLL BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(BufferInfoAnalysis, ObjectRef, BufferInfoAnalysisNode);
};

/*!
* \brief The pool allocation produced after the USMP algorithm
*/
Expand Down
14 changes: 8 additions & 6 deletions src/tir/usmp/algo/greedy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,22 +209,24 @@ class GreedyConflicts : public GreedyBase {
}
};

Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) {
Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr,
const Integer& memory_pressure) {
return GreedySize().PlanMemory(buffer_info_arr);
}

Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr) {
Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr,
const Integer& memory_pressure) {
return GreedyConflicts().PlanMemory(buffer_info_arr);
}

TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size")
.set_body_typed([](Array<BufferInfo> buffer_info_arr) {
return GreedyBySize(buffer_info_arr);
.set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) {
return GreedyBySize(buffer_info_arr, memory_pressure);
});

TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_conflicts")
.set_body_typed([](Array<BufferInfo> buffer_info_arr) {
return GreedyByConflicts(buffer_info_arr);
.set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) {
return GreedyByConflicts(buffer_info_arr, memory_pressure);
});

} // namespace algo
Expand Down
31 changes: 15 additions & 16 deletions src/tir/usmp/analysis/extract_buffer_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class BufferInfoExtractor : public StmtExprVisitor {
// Pushing a scope info for the initial body of the main function
scope_stack_.push(ScopeInfo());
}
Map<BufferInfo, tir::Stmt> operator()(const PrimFunc& func);
BufferInfoAnalysis operator()(const PrimFunc& func);

private:
void VisitStmt(const Stmt& n) override;
Expand Down Expand Up @@ -400,7 +400,7 @@ void BufferInfoExtractor::VisitExpr_(const CallNode* op) {
StmtExprVisitor::VisitExpr_(op);
}

Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_func) {
BufferInfoAnalysis BufferInfoExtractor::operator()(const PrimFunc& main_func) {
VisitPrimFunc(main_func, Call());

// Create a vector of liveness events
Expand Down Expand Up @@ -454,33 +454,32 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_

// Traverse the liveness events using a open set to track what
// is live while updating the conflicts through out the linear traversal
std::unordered_map<BufferInfo, int, ObjectPtrHash, ObjectPtrEqual> open_set;

int open_set_size = 0;
int max_open_set_size = 0;
std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
for (const auto& le_event : le_events_timeline) {
if (le_event.le_type == START) {
for (const auto& kv : open_set) {
BufferInfo open_buffer_info = kv.first;
for (const BufferInfo& open_buffer_info : open_set) {
open_buffer_info->conflicts.push_back(le_event.buffer_info);
if (le_event.buffer_info != open_buffer_info) {
le_event.buffer_info->conflicts.push_back(open_buffer_info);
}
}
if (open_set.find(le_event.buffer_info) == open_set.end()) {
open_set[le_event.buffer_info] = 1;
} else {
open_set[le_event.buffer_info] += 1;
open_set_size += le_event.buffer_info->size_bytes;
if (open_set_size > max_open_set_size) {
max_open_set_size = open_set_size;
}
open_set.insert(le_event.buffer_info);
} else {
if (open_set[le_event.buffer_info] == 1) {
open_set.erase(le_event.buffer_info);
} else {
open_set[le_event.buffer_info] -= 1;
}
open_set_size -= le_event.buffer_info->size_bytes;
open_set.erase(le_event.buffer_info);
}
}
return this->buffer_info_map_;
return BufferInfoAnalysis(this->buffer_info_map_, max_open_set_size);
}

Map<BufferInfo, tir::Stmt> ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) {
BufferInfoAnalysis ExtractBufferInfo(const PrimFunc& main_func, const IRModule& mod) {
return BufferInfoExtractor(mod)(main_func);
}

Expand Down
22 changes: 22 additions & 0 deletions src/tir/usmp/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ",\n alignment=" << node->alignment << ")";
});

BufferInfoAnalysis::BufferInfoAnalysis(Map<BufferInfo, tir::Stmt> buffer_info_stmts,
Integer memory_pressure) {
auto bufinfo_analysis_node = make_object<BufferInfoAnalysisNode>();
bufinfo_analysis_node->buffer_info_stmts = buffer_info_stmts;
bufinfo_analysis_node->memory_pressure = memory_pressure;
data_ = std::move(bufinfo_analysis_node);
}

TVM_REGISTER_NODE_TYPE(BufferInfoAnalysisNode);
TVM_REGISTER_GLOBAL("tir.usmp.BufferInfoAnalysis")
.set_body_typed([](Map<BufferInfo, tir::Stmt> buffer_info_stmts, Integer memory_pressure) {
return BufferInfoAnalysis(buffer_info_stmts, memory_pressure);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<BufferInfoAnalysisNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const BufferInfoAnalysisNode*>(ref.get());
p->stream << "BufferInfoAnalysisNode(\n"
<< "buffer_info_stmts=" << node->buffer_info_stmts
<< ",\n memory_pressure=" << node->memory_pressure << ")";
});

PoolInfo::PoolInfo(String pool_name, Map<Target, String> target_access, Integer size_hint_bytes) {
auto poolinfo_node = make_object<PoolInfoNode>();
poolinfo_node->pool_name = pool_name;
Expand Down
22 changes: 12 additions & 10 deletions tests/python/unittest/test_tir_usmp_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_no_pool_error():
with pytest.raises(
tvm.TVMError, match="TVM USMP Error: the space available in the provided pools exceeded"
):
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)


@pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts"])
Expand Down Expand Up @@ -148,7 +148,7 @@ def _test():

buffer_info_arr = [bi_a, bi_b, bi_c]
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
assert buffer_pool_allocations[bi_a].byte_offset == 20
assert buffer_pool_allocations[bi_b].byte_offset == 10
assert buffer_pool_allocations[bi_c].byte_offset == 0
Expand Down Expand Up @@ -216,7 +216,7 @@ def test_linear(algorithm, workspace_size):

buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f]
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
_check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)


Expand Down Expand Up @@ -287,7 +287,7 @@ def test_fanout(algorithm, workspace_size):

buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f, bi_g]
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
_check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)


Expand Down Expand Up @@ -382,12 +382,13 @@ def test_mobilenet_subgraph(algorithm, fast_memory_size, slow_memory_size):
tir_mod, [fast_memory_pool, slow_memory_pool]
)
main_func = tir_mod["run_model"]
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
assert buffer_info_analysis.memory_pressure == 1117718

fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
buffer_info_arr = fcreate_array_bi(buffer_info_map)
buffer_info_arr = fcreate_array_bi(buffer_info_analysis.buffer_info_stmts)
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
buffer_pool_allocations = fusmp_algo(buffer_info_arr, buffer_info_analysis.memory_pressure)

buffer_info_map_names = dict()
for buf_info in buffer_info_arr:
Expand Down Expand Up @@ -540,12 +541,13 @@ def test_resnet_subgraph(algorithm, workspace_size):
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
main_func = tir_mod["tvmgen_default_run_model"]
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
assert buffer_info_analysis.memory_pressure == 7200256

fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
buffer_info_arr = fcreate_array_bi(buffer_info_map)
buffer_info_arr = fcreate_array_bi(buffer_info_analysis.buffer_info_stmts)
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
buffer_pool_allocations = fusmp_algo(buffer_info_arr, buffer_info_analysis.memory_pressure)

buffer_info_map_names = dict()
for buf_info in buffer_info_arr:
Expand Down
25 changes: 15 additions & 10 deletions tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ def test_linear():
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(
tir_mod, [fast_memory_pool, slow_memory_pool]
)
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(tir_mod["run_model"], tir_mod)
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(tir_mod["run_model"], tir_mod)
assert buffer_info_analysis.memory_pressure == 1117718
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)

# check conflicts
_verify_conflicts("PaddedInput_7", ["sid_9", "sid_8", "Conv2dOutput_7"], buffer_info_map)
Expand Down Expand Up @@ -293,8 +294,9 @@ def test_parallel_serial_mixed_for_loops():
all_serial_tir_mod, [global_ws_pool]
)
main_func = all_serial_tir_mod["run_model"]
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, all_serial_tir_mod)
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, all_serial_tir_mod)
assert buffer_info_analysis.memory_pressure == 430848
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)

# When all loops are serial all allocates are touched by USMP
assert len(buffer_info_map) == 3
Expand All @@ -309,10 +311,11 @@ def test_parallel_serial_mixed_for_loops():
parallel_serial_mixed_tir_mod, [global_ws_pool]
)
main_func = parallel_serial_mixed_tir_mod["run_model"]
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(
buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(
main_func, parallel_serial_mixed_tir_mod
)
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
assert buffer_info_analysis.memory_pressure == 430848
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)

# USMP will not touch (yet) the allocates inside parallel for loops
assert len(buffer_info_map) == 2
Expand Down Expand Up @@ -656,8 +659,9 @@ def test_inception_structure():
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
main_func = tir_mod["run_model"]
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
assert buffer_info_analysis.memory_pressure == 1117718
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)

# check conflicts
_verify_conflicts(
Expand Down Expand Up @@ -1369,8 +1373,9 @@ def test_multiple_calls_to_same_primfunc():
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
main_func = tir_mod["run_model"]
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_map)
buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
assert buffer_info_analysis.memory_pressure == 11424
buffer_info_map = _replace_stmt_with_buf_var_names(buffer_info_analysis.buffer_info_stmts)

# check conflicts
_verify_conflicts(
Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_tir_usmp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,10 @@ def test_create_array_buffer_info():
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_ws_pool])
main_func = tir_mod["tvmgen_default_run_model"]
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_array = fcreate_array_bi(buffer_info_map)
buffer_info_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_array = fcreate_array_bi(buffer_info_analysis.buffer_info_stmts)
for buffer_info in buffer_info_array:
assert buffer_info in buffer_info_map.keys()
assert buffer_info in buffer_info_analysis.buffer_info_stmts.keys()


if __name__ == "__main__":
Expand Down

0 comments on commit cb132e2

Please sign in to comment.