diff --git a/src/tir/usmp/algo/greedy_by_size.cc b/src/tir/usmp/algo/greedy_by_size.cc index c657ad41a82e4..7aafe8c0974fb 100644 --- a/src/tir/usmp/algo/greedy_by_size.cc +++ b/src/tir/usmp/algo/greedy_by_size.cc @@ -34,13 +34,13 @@ namespace tir { namespace usmp { namespace algo { -size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, - const int& byte_alignment) { +static size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, + const int& byte_alignment) { return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment; } -bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, - const size_t& size_bytes) { +static bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, + const size_t& size_bytes) { if (candidate_pool->size_hint_bytes == -1) { // this means pool is not bounded return true; @@ -53,7 +53,7 @@ bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, return false; } -PoolInfo SelectPlacementPool( +static PoolInfo SelectPlacementPool( const Array& pool_candidates, const std::unordered_map& pool_offsets) { for (const auto& pool_info : pool_candidates) { diff --git a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py b/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py index a24a79ba85efa..6bd7832f533b1 100644 --- a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py +++ b/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py @@ -77,6 +77,15 @@ def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): return ret +def _assign_targets_to_primfuncs_irmodule(mod, target): + """helper to assign target for PrimFunc in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = basefunc.with_attr("target", target) + return ret + + def _check_max_workspace_size(buffer_pool_allocations, pool_info, size): max_workspace_size = 0 for buffer_info, pool_allocation in buffer_pool_allocations.items(): @@ -143,7 +152,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) @T.prim_func - def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + def run_model(input: T.handle, output: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) # body @@ -159,19 +168,21 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: def test_linear(): + target = Target("c") fast_memory_pool = usmp_utils.PoolInfo( pool_name="fast_memory", - target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, size_hint_bytes=200704, ) slow_memory_pool = usmp_utils.PoolInfo( - pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS} + pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} ) tir_mod = LinearStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) tir_mod = assign_poolinfos_to_allocates_in_irmodule( tir_mod, [fast_memory_pool, slow_memory_pool] ) - main_func = tir_mod["tvmgen_default_run_model"] + main_func = tir_mod["run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") @@ -184,13 +195,15 @@ def test_linear(): buffer_info_map_names[buf_info.name_hint] = buf_info # check conflicts - _verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map_names) - _verify_conflicts("Conv2dOutput_7", ["PaddedInput_7", "sid_8"], buffer_info_map_names) - _verify_conflicts("PaddedInput_7", ["sid_9", "Conv2dOutput_7"], buffer_info_map_names) + _verify_conflicts("PaddedInput_7", ["sid_9", "sid_8", "Conv2dOutput_7"], buffer_info_map_names) _verify_conflicts("tensor_2", ["sid_8"], buffer_info_map_names) _verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map_names) + _verify_conflicts( + "sid_8", ["PaddedInput_7", "Conv2dOutput_7", "tensor_2"], buffer_info_map_names + ) + _verify_conflicts("Conv2dOutput_7", ["sid_8", "PaddedInput_7"], buffer_info_map_names) - _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 802816) + _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 1418528) _check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, 200704) @@ -316,11 +329,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place def test_fanout(): + target = Target("c") global_workspace_pool = usmp_utils.PoolInfo( pool_name="global_workspace", - target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, ) tir_mod = ResnetStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) main_func = tir_mod["tvmgen_default_run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) @@ -336,35 +351,108 @@ def test_fanout(): # check conflicts _verify_conflicts( - "sid_6", - ["Conv2dOutput_2", "sid_2", "PaddedInput_3", "Conv2dOutput_3"], + "Conv2dOutput_1", + [ + "PaddedInput_1", + "sid_7", + ], buffer_info_map_names, ) - _verify_conflicts("PaddedInput_1", ["sid_8", "sid_2", "Conv2dOutput_1"], buffer_info_map_names) - _verify_conflicts("PaddedInput_2", ["sid_7", "sid_2", "Conv2dOutput_2"], buffer_info_map_names) - _verify_conflicts("sid_8", ["Conv2dOutput", "sid_2", "PaddedInput_1"], buffer_info_map_names) _verify_conflicts( - "sid_2", + "sid_8", [ "PaddedInput", "Conv2dOutput", - "sid_8", "PaddedInput_1", - "Conv2dOutput_1", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "PaddedInput_2", + [ "sid_7", + "sid_6", + "Conv2dOutput_2", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "sid_2", + [ + "PaddedInput", + "PaddedInput_3", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "Conv2dOutput", + [ + "sid_8", + "PaddedInput", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "sid_7", + [ + "Conv2dOutput_1", + "PaddedInput_1", + "PaddedInput_2", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "sid_6", + [ "PaddedInput_2", "Conv2dOutput_2", + "Conv2dOutput_3", + "PaddedInput_3", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "PaddedInput_3", + [ + "sid_2", + "Conv2dOutput_3", "sid_6", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "Conv2dOutput_3", + [ "PaddedInput_3", + "sid_6", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "PaddedInput", + [ + "sid_2", + "sid_8", + "Conv2dOutput", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "Conv2dOutput_2", + [ + "sid_6", + "PaddedInput_2", ], buffer_info_map_names, ) - _verify_conflicts("PaddedInput", ["sid_2", "Conv2dOutput"], buffer_info_map_names) - _verify_conflicts("sid_7", ["Conv2dOutput_1", "sid_2", "PaddedInput_2"], buffer_info_map_names) - _verify_conflicts("PaddedInput_3", ["sid_6", "sid_2", "Conv2dOutput_3"], buffer_info_map_names) - _verify_conflicts("Conv2dOutput_3", ["PaddedInput_3", "sid_6"], buffer_info_map_names) - _verify_conflicts("Conv2dOutput", ["PaddedInput", "sid_2", "sid_8"], buffer_info_map_names) - _verify_conflicts("Conv2dOutput_1", ["PaddedInput_1", "sid_2", "sid_7"], buffer_info_map_names) - _verify_conflicts("Conv2dOutput_2", ["PaddedInput_2", "sid_2", "sid_6"], buffer_info_map_names) - - _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7920256) + _verify_conflicts( + "PaddedInput_1", + [ + "sid_8", + "Conv2dOutput_1", + "sid_7", + ], + buffer_info_map_names, + ) + + _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7200000)