Skip to content

Commit

Permalink
[TIR][USMP] greedy_by_size usmp algo
Browse files Browse the repository at this point in the history
* Adding targets to the PrimFuncs in the tests

Change-Id: Ic91947e23cbcc4fc0020eb62f0ed9df26cf919f9
  • Loading branch information
manupak committed Nov 22, 2021
1 parent e7d4fee commit 313e1ba
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 31 deletions.
10 changes: 5 additions & 5 deletions src/tir/usmp/algo/greedy_by_size.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ namespace tir {
namespace usmp {
namespace algo {

size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
const int& byte_alignment) {
static size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
const int& byte_alignment) {
return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment;
}

bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
const size_t& size_bytes) {
static bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
const size_t& size_bytes) {
if (candidate_pool->size_hint_bytes == -1) {
// this means pool is not bounded
return true;
Expand All @@ -53,7 +53,7 @@ bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
return false;
}

PoolInfo SelectPlacementPool(
static PoolInfo SelectPlacementPool(
const Array<PoolInfo>& pool_candidates,
const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
for (const auto& pool_info : pool_candidates) {
Expand Down
140 changes: 114 additions & 26 deletions tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
return ret


def _assign_targets_to_primfuncs_irmodule(mod, target):
"""helper to assign target for PrimFunc in a IRModule"""
ret = tvm.IRModule()
for global_var, basefunc in mod.functions.items():
if isinstance(basefunc, tvm.tir.PrimFunc):
ret[global_var] = basefunc.with_attr("target", target)
return ret


def _check_max_workspace_size(buffer_pool_allocations, pool_info, size):
max_workspace_size = 0
for buffer_info, pool_allocation in buffer_pool_allocations.items():
Expand Down Expand Up @@ -143,7 +152,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6:
T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)

@T.prim_func
def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
def run_model(input: T.handle, output: T.handle) -> None:
# function attr dict
T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
# body
Expand All @@ -159,19 +168,21 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:


def test_linear():
target = Target("c")
fast_memory_pool = usmp_utils.PoolInfo(
pool_name="fast_memory",
target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
size_hint_bytes=200704,
)
slow_memory_pool = usmp_utils.PoolInfo(
pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}
pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}
)
tir_mod = LinearStructure
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = assign_poolinfos_to_allocates_in_irmodule(
tir_mod, [fast_memory_pool, slow_memory_pool]
)
main_func = tir_mod["tvmgen_default_run_model"]
main_func = tir_mod["run_model"]
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)

fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
Expand All @@ -184,13 +195,15 @@ def test_linear():
buffer_info_map_names[buf_info.name_hint] = buf_info

# check conflicts
_verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map_names)
_verify_conflicts("Conv2dOutput_7", ["PaddedInput_7", "sid_8"], buffer_info_map_names)
_verify_conflicts("PaddedInput_7", ["sid_9", "Conv2dOutput_7"], buffer_info_map_names)
_verify_conflicts("PaddedInput_7", ["sid_9", "sid_8", "Conv2dOutput_7"], buffer_info_map_names)
_verify_conflicts("tensor_2", ["sid_8"], buffer_info_map_names)
_verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map_names)
_verify_conflicts(
"sid_8", ["PaddedInput_7", "Conv2dOutput_7", "tensor_2"], buffer_info_map_names
)
_verify_conflicts("Conv2dOutput_7", ["sid_8", "PaddedInput_7"], buffer_info_map_names)

_check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 802816)
_check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 1418528)
_check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, 200704)


Expand Down Expand Up @@ -316,11 +329,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place


def test_fanout():
target = Target("c")
global_workspace_pool = usmp_utils.PoolInfo(
pool_name="global_workspace",
target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
)
tir_mod = ResnetStructure
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
main_func = tir_mod["tvmgen_default_run_model"]
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
Expand All @@ -336,35 +351,108 @@ def test_fanout():

# check conflicts
_verify_conflicts(
"sid_6",
["Conv2dOutput_2", "sid_2", "PaddedInput_3", "Conv2dOutput_3"],
"Conv2dOutput_1",
[
"PaddedInput_1",
"sid_7",
],
buffer_info_map_names,
)
_verify_conflicts("PaddedInput_1", ["sid_8", "sid_2", "Conv2dOutput_1"], buffer_info_map_names)
_verify_conflicts("PaddedInput_2", ["sid_7", "sid_2", "Conv2dOutput_2"], buffer_info_map_names)
_verify_conflicts("sid_8", ["Conv2dOutput", "sid_2", "PaddedInput_1"], buffer_info_map_names)
_verify_conflicts(
"sid_2",
"sid_8",
[
"PaddedInput",
"Conv2dOutput",
"sid_8",
"PaddedInput_1",
"Conv2dOutput_1",
],
buffer_info_map_names,
)
_verify_conflicts(
"PaddedInput_2",
[
"sid_7",
"sid_6",
"Conv2dOutput_2",
],
buffer_info_map_names,
)
_verify_conflicts(
"sid_2",
[
"PaddedInput",
"PaddedInput_3",
],
buffer_info_map_names,
)
_verify_conflicts(
"Conv2dOutput",
[
"sid_8",
"PaddedInput",
],
buffer_info_map_names,
)
_verify_conflicts(
"sid_7",
[
"Conv2dOutput_1",
"PaddedInput_1",
"PaddedInput_2",
],
buffer_info_map_names,
)
_verify_conflicts(
"sid_6",
[
"PaddedInput_2",
"Conv2dOutput_2",
"Conv2dOutput_3",
"PaddedInput_3",
],
buffer_info_map_names,
)
_verify_conflicts(
"PaddedInput_3",
[
"sid_2",
"Conv2dOutput_3",
"sid_6",
],
buffer_info_map_names,
)
_verify_conflicts(
"Conv2dOutput_3",
[
"PaddedInput_3",
"sid_6",
],
buffer_info_map_names,
)
_verify_conflicts(
"PaddedInput",
[
"sid_2",
"sid_8",
"Conv2dOutput",
],
buffer_info_map_names,
)
_verify_conflicts(
"Conv2dOutput_2",
[
"sid_6",
"PaddedInput_2",
],
buffer_info_map_names,
)
_verify_conflicts("PaddedInput", ["sid_2", "Conv2dOutput"], buffer_info_map_names)
_verify_conflicts("sid_7", ["Conv2dOutput_1", "sid_2", "PaddedInput_2"], buffer_info_map_names)
_verify_conflicts("PaddedInput_3", ["sid_6", "sid_2", "Conv2dOutput_3"], buffer_info_map_names)
_verify_conflicts("Conv2dOutput_3", ["PaddedInput_3", "sid_6"], buffer_info_map_names)
_verify_conflicts("Conv2dOutput", ["PaddedInput", "sid_2", "sid_8"], buffer_info_map_names)
_verify_conflicts("Conv2dOutput_1", ["PaddedInput_1", "sid_2", "sid_7"], buffer_info_map_names)
_verify_conflicts("Conv2dOutput_2", ["PaddedInput_2", "sid_2", "sid_6"], buffer_info_map_names)

_check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7920256)
_verify_conflicts(
"PaddedInput_1",
[
"sid_8",
"Conv2dOutput_1",
"sid_7",
],
buffer_info_map_names,
)

_check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7200000)

0 comments on commit 313e1ba

Please sign in to comment.