From ecb3c3feae7c4b7bbba5082585297612d5d9889c Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne <manupa.karunaratne@arm.com> Date: Thu, 8 Jul 2021 13:44:25 +0100 Subject: [PATCH 1/8] [TIR][USMP] greedy_by_size usmp algo * Implementation of greedy by size memory planning algorithm * Added a test case of linear sequence of operators with two pools * Added a test case with residual structures Change-Id: I03b41292eab85ddb43710356c23dd123beb24462 --- src/tir/usmp/algo/greedy_by_size.cc | 128 ++++++ .../test_tir_usmp_algo_greedy_by_size.py | 370 ++++++++++++++++++ 2 files changed, 498 insertions(+) create mode 100644 src/tir/usmp/algo/greedy_by_size.cc create mode 100644 tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py diff --git a/src/tir/usmp/algo/greedy_by_size.cc b/src/tir/usmp/algo/greedy_by_size.cc new file mode 100644 index 000000000000..c657ad41a82e --- /dev/null +++ b/src/tir/usmp/algo/greedy_by_size.cc @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/usmp/algo/greedy_by_size.cc + * \brief Implement greedy by size memory planning algorithm + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/runtime/device_api.h> +#include <tvm/tir/builtin.h> +#include <tvm/tir/function.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/usmp/utils.h> + +namespace tvm { +namespace tir { +namespace usmp { +namespace algo { + +size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, + const int& byte_alignment) { + return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment; +} + +bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, + const size_t& size_bytes) { + if (candidate_pool->size_hint_bytes == -1) { + // this means pool is not bounded + return true; + } + auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value); + auto max_address = next_offset + size_bytes; + if (max_address <= pool_size) { + return true; + } + return false; +} + +PoolInfo SelectPlacementPool( + const Array<PoolInfo>& pool_candidates, + const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) { + for (const auto& pool_info : pool_candidates) { + if (pool_offsets.count(pool_info)) { + return pool_info; + } + } + ICHECK(false) << "TVM USMP Internal Error: no candidate have been selected!"; + return PoolInfo(); +} + +Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) { + std::vector<BufferInfo> buffer_info_vec; + Map<BufferInfo, PoolAllocation> pool_allocations; + for (const auto& buffer_info : buffer_info_arr) { + buffer_info_vec.push_back(std::move(buffer_info)); + } + std::sort(buffer_info_vec.begin(), buffer_info_vec.end(), + [](const BufferInfo& a, 
const BufferInfo& b) { + if (a->size_bytes->value == b->size_bytes->value) { + if (a->conflicts.size() == b->conflicts.size()) { + auto a_name_hash = std::hash<std::string>{}(a->name_hint->data); + auto b_name_hash = std::hash<std::string>{}(b->name_hint->data); + return a_name_hash > b_name_hash; + } else { + return a->conflicts.size() > b->conflicts.size(); + } + } + return a->size_bytes > b->size_bytes; + }); + + for (const auto& buf_info : buffer_info_vec) { + std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates; + for (const auto& pool_info : buf_info->pool_candidates) { + if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) { + pool_offset_candidates[pool_info] = 0; + } + } + + for (const auto& conflict_buf_info_obj : buf_info->conflicts) { + auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj); + size_t next_offset = 0; + if (pool_allocations.count(conflict_buf_info)) { + auto pool_allocation = pool_allocations[conflict_buf_info]; + next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes; + next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value); + if (IsValidPlacement(pool_allocation->pool_info, next_offset, + buf_info->size_bytes->value)) { + if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) { + pool_offset_candidates[pool_allocation->pool_info] = next_offset; + } + } else { + pool_offset_candidates.erase(pool_allocation->pool_info); + } + } + } + auto selected_pool = SelectPlacementPool(buf_info->pool_candidates, pool_offset_candidates); + pool_allocations.Set( + buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool]))); + } + return pool_allocations; +} + +TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size") + .set_body_typed([](Array<BufferInfo> buffer_info_arr) { + return GreedyBySize(buffer_info_arr); + }); + +} // namespace algo +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py b/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py new file mode 100644 index 000000000000..a24a79ba85ef --- /dev/null +++ b/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py @@ -0,0 +1,370 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
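
For orientation while reading the tests that follow, here is a minimal, self-contained Python sketch of the placement loop that GreedyBySize implements in src/tir/usmp/algo/greedy_by_size.cc above. The Pool and Buffer classes and the greedy_by_size helper below are simplified stand-ins introduced only for illustration; they are not TVM APIs and are not part of this patch.

    from dataclasses import dataclass, field
    from typing import Dict, List, Tuple


    @dataclass(frozen=True)
    class Pool:
        name: str
        size_hint_bytes: int = -1  # -1 means the pool is unbounded


    @dataclass
    class Buffer:
        name: str
        size_bytes: int
        alignment: int = 16
        pool_candidates: Tuple[Pool, ...] = ()
        conflicts: List["Buffer"] = field(default_factory=list)


    def _round_up(offset: int, alignment: int) -> int:
        return ((offset + alignment - 1) // alignment) * alignment


    def _fits(pool: Pool, offset: int, size: int) -> bool:
        return pool.size_hint_bytes == -1 or offset + size <= pool.size_hint_bytes


    def greedy_by_size(buffers: List[Buffer]) -> Dict[str, Tuple[Pool, int]]:
        """Place the largest buffers first, packing each at the lowest
        conflict-free offset of the first candidate pool that can hold it."""
        allocations: Dict[str, Tuple[Pool, int]] = {}
        for buf in sorted(buffers, key=lambda b: (b.size_bytes, len(b.conflicts)), reverse=True):
            # Every candidate pool starts at offset 0; the offset is then pushed
            # past each already-placed buffer that conflicts in the same pool.
            offsets = {p: 0 for p in buf.pool_candidates if _fits(p, 0, buf.size_bytes)}
            for other in buf.conflicts:
                if other.name not in allocations:
                    continue
                pool, other_offset = allocations[other.name]
                candidate = _round_up(other_offset + other.size_bytes, other.alignment)
                if not _fits(pool, candidate, buf.size_bytes):
                    offsets.pop(pool, None)
                elif pool in offsets and candidate > offsets[pool]:
                    offsets[pool] = candidate
            # Respect the user-given pool ordering when several pools remain valid.
            pool = next(p for p in buf.pool_candidates if p in offsets)
            allocations[buf.name] = (pool, offsets[pool])
        return allocations

Sorting by size keeps the largest buffers at low offsets so that smaller, shorter-lived buffers can be packed above them; the C++ implementation additionally breaks remaining ties by conflict count and a hash of the buffer name so the order is fully determined.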
+import pytest + +import tvm +from tvm import tir, script +from tvm.script import tir as T +from tvm.tir import stmt_functor +from tvm.tir.usmp import utils as usmp_utils +from tvm.target import Target + + +def _replace_stmt_with_buf_var_names(buffer_info_map): + """helper to replace tir.allocates with buffer names""" + new_buffer_info_map = dict() + for k, v in buffer_info_map.items(): + new_buffer_info_map[v.buffer_var.name] = k + return new_buffer_info_map + + +def _verify_conflicts(main_buf_name, conflicting_buf_names, buffer_info_map): + """helper to check expected liveness conflicts""" + buf_info = buffer_info_map[main_buf_name] + for conflict in buf_info.conflicts: + assert conflict.name_hint in conflicting_buf_names + + +def _get_allocates(primfunc): + """helper to extract all allocate nodes by name""" + allocates = dict() + + def get_allocate(stmt): + if isinstance(stmt, tvm.tir.Allocate): + allocates[str(stmt.buffer_var.name)] = stmt + + stmt_functor.post_order_visit(primfunc.body, get_allocate) + return allocates + + +def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): + """helper to assing poolinfos to allocate nodes in a tir.PrimFunc""" + + def set_poolinfos(stmt): + if isinstance(stmt, tvm.tir.Allocate): + return tvm.tir.Allocate( + buffer_var=stmt.buffer_var, + dtype=stmt.dtype, + extents=stmt.extents, + condition=stmt.condition, + body=stmt.body, + annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos}, + ) + + return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos)) + + +def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): + """helper to assing poolinfos to allocate nodes in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) + return ret + + +def _check_max_workspace_size(buffer_pool_allocations, pool_info, size): + max_workspace_size = 0 + for buffer_info, pool_allocation in buffer_pool_allocations.items(): + if pool_allocation.pool_info == pool_info: + size_candidate = pool_allocation.byte_offset + buffer_info.size_bytes + if size_candidate > max_workspace_size: + max_workspace_size = size_candidate + assert max_workspace_size == size + + +# fmt: off +@tvm.script.ir_module +class LinearStructure: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr 
dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_7, ff_3, 0, True) + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + for ax3_inner_7 in T.serial(0, 64): + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global") + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + + @T.prim_func + def 
tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_9 = T.allocate([301056], "int8", "global") + sid_8 = T.allocate([802816], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +def test_linear(): + fast_memory_pool = usmp_utils.PoolInfo( + pool_name="fast_memory", + target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + size_hint_bytes=200704, + ) + slow_memory_pool = usmp_utils.PoolInfo( + pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS} + ) + tir_mod = LinearStructure + tir_mod = assign_poolinfos_to_allocates_in_irmodule( + tir_mod, [fast_memory_pool, slow_memory_pool] + ) + main_func = tir_mod["tvmgen_default_run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + + fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") + buffer_info_arr = fcreate_array_bi(buffer_info_map) + fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") + buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr) + + buffer_info_map_names = dict() + for buf_info in buffer_info_arr: + buffer_info_map_names[buf_info.name_hint] = buf_info + + # check conflicts + _verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map_names) + _verify_conflicts("Conv2dOutput_7", ["PaddedInput_7", "sid_8"], buffer_info_map_names) + _verify_conflicts("PaddedInput_7", ["sid_9", "Conv2dOutput_7"], buffer_info_map_names) + _verify_conflicts("tensor_2", ["sid_8"], buffer_info_map_names) + _verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map_names) + + _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 802816) + _check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, 200704) + + +# fmt: off +@tvm.script.ir_module +class ResnetStructure: + @T.prim_func + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + # body + for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): + T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def 
tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) + placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") + T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + # body + PaddedInput_1 = T.allocate([379456], "int16", "global") + for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): + T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): + Conv2dOutput_1 = T.allocate([64], "int32", "global") + for ff_1 in T.serial(0, 64): + T.store(Conv2dOutput_1, ff_1, 0, True) + for ry, rx, rc_1 in T.grid(3, 3, 64): + T.store(Conv2dOutput_1, ff_1, T.load("int32", Conv2dOutput_1, ff_1) + T.cast(T.load("int16", PaddedInput_1, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + for ax3_inner_2 in T.serial(0, 64): + T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) + placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + # body + PaddedInput_2 = T.allocate([360000], "int16", "global") + for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): + T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): + Conv2dOutput_2 = T.allocate([64], "int32", "global") + for ax3_outer_1 in T.serial(0, 4): + for ff_2 in T.serial(0, 64): + T.store(Conv2dOutput_2, ff_2, 0, True) + for rc_2 in T.serial(0, 64): + T.store(Conv2dOutput_2, ff_2, T.load("int32", Conv2dOutput_2, ff_2) + T.cast(T.load("int16", PaddedInput_2, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) + for ax3_inner_3 in T.serial(0, 64): + T.store(T_add_1.data, 
ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") + placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") + T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + # body + PaddedInput_3 = T.allocate([360000], "int16", "global") + for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): + T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) + for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): + Conv2dOutput_3 = T.allocate([64], "int32", "global") + for ax3_outer_2 in T.serial(0, 4): + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_3, ff_3, 0, True) + for rc_3 in T.serial(0, 64): + T.store(Conv2dOutput_3, ff_3, T.load("int32", Conv2dOutput_3, ff_3) + T.cast(T.load("int16", PaddedInput_3, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) + for ax3_inner_4 in T.serial(0, 64): + T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_2 = T.allocate([720000], "int8", "global") + sid_6 = T.allocate([5760000], "int8", "global") + sid_7 = T.allocate([720000], "int8", "global") + sid_8 = T.allocate([720000], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32")) + 
T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32")) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) + placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") + T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + # body + PaddedInput = T.allocate([360000], "int16", "global") + for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): + T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): + Conv2dOutput = T.allocate([64], "int32", "global") + for ff in T.serial(0, 64): + T.store(Conv2dOutput, ff, 0, True) + for rc in T.serial(0, 64): + T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + for ax3_inner_1 in T.serial(0, 64): + T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + __tvm_meta__ = None +# fmt: on + + +def test_fanout(): + global_workspace_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + tir_mod = ResnetStructure + tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) + main_func = tir_mod["tvmgen_default_run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + + fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") + buffer_info_arr = fcreate_array_bi(buffer_info_map) + fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") + buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr) + + buffer_info_map_names = dict() + for buf_info in buffer_info_arr: + buffer_info_map_names[buf_info.name_hint] = buf_info + + # check conflicts + _verify_conflicts( + "sid_6", + ["Conv2dOutput_2", "sid_2", "PaddedInput_3", "Conv2dOutput_3"], + buffer_info_map_names, + ) + _verify_conflicts("PaddedInput_1", ["sid_8", "sid_2", "Conv2dOutput_1"], buffer_info_map_names) + _verify_conflicts("PaddedInput_2", ["sid_7", "sid_2", 
"Conv2dOutput_2"], buffer_info_map_names) + _verify_conflicts("sid_8", ["Conv2dOutput", "sid_2", "PaddedInput_1"], buffer_info_map_names) + _verify_conflicts( + "sid_2", + [ + "PaddedInput", + "Conv2dOutput", + "sid_8", + "PaddedInput_1", + "Conv2dOutput_1", + "sid_7", + "PaddedInput_2", + "Conv2dOutput_2", + "sid_6", + "PaddedInput_3", + ], + buffer_info_map_names, + ) + _verify_conflicts("PaddedInput", ["sid_2", "Conv2dOutput"], buffer_info_map_names) + _verify_conflicts("sid_7", ["Conv2dOutput_1", "sid_2", "PaddedInput_2"], buffer_info_map_names) + _verify_conflicts("PaddedInput_3", ["sid_6", "sid_2", "Conv2dOutput_3"], buffer_info_map_names) + _verify_conflicts("Conv2dOutput_3", ["PaddedInput_3", "sid_6"], buffer_info_map_names) + _verify_conflicts("Conv2dOutput", ["PaddedInput", "sid_2", "sid_8"], buffer_info_map_names) + _verify_conflicts("Conv2dOutput_1", ["PaddedInput_1", "sid_2", "sid_7"], buffer_info_map_names) + _verify_conflicts("Conv2dOutput_2", ["PaddedInput_2", "sid_2", "sid_6"], buffer_info_map_names) + + _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7920256) From e0ca1d91f14e80c2ad3042e0f642d7fb81fed75a Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne <manupa.karunaratne@arm.com> Date: Mon, 22 Nov 2021 14:25:41 +0000 Subject: [PATCH 2/8] [TIR][USMP] greedy_by_size usmp algo * Adding targets to the PrimFuncs in the tests Change-Id: Ic91947e23cbcc4fc0020eb62f0ed9df26cf919f9 --- src/tir/usmp/algo/greedy_by_size.cc | 10 +- .../test_tir_usmp_algo_greedy_by_size.py | 140 ++++++++++++++---- 2 files changed, 119 insertions(+), 31 deletions(-) diff --git a/src/tir/usmp/algo/greedy_by_size.cc b/src/tir/usmp/algo/greedy_by_size.cc index c657ad41a82e..7aafe8c0974f 100644 --- a/src/tir/usmp/algo/greedy_by_size.cc +++ b/src/tir/usmp/algo/greedy_by_size.cc @@ -34,13 +34,13 @@ namespace tir { namespace usmp { namespace algo { -size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, - const int& byte_alignment) { +static size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, + const int& byte_alignment) { return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment; } -bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, - const size_t& size_bytes) { +static bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, + const size_t& size_bytes) { if (candidate_pool->size_hint_bytes == -1) { // this means pool is not bounded return true; @@ -53,7 +53,7 @@ bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, return false; } -PoolInfo SelectPlacementPool( +static PoolInfo SelectPlacementPool( const Array<PoolInfo>& pool_candidates, const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) { for (const auto& pool_info : pool_candidates) { diff --git a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py b/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py index a24a79ba85ef..6bd7832f533b 100644 --- a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py +++ b/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py @@ -77,6 +77,15 @@ def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): return ret +def _assign_targets_to_primfuncs_irmodule(mod, target): + """helper to assign target for PrimFunc in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = 
basefunc.with_attr("target", target) + return ret + + def _check_max_workspace_size(buffer_pool_allocations, pool_info, size): max_workspace_size = 0 for buffer_info, pool_allocation in buffer_pool_allocations.items(): @@ -143,7 +152,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) @T.prim_func - def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + def run_model(input: T.handle, output: T.handle) -> None: # function attr dict T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) # body @@ -159,19 +168,21 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: def test_linear(): + target = Target("c") fast_memory_pool = usmp_utils.PoolInfo( pool_name="fast_memory", - target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, size_hint_bytes=200704, ) slow_memory_pool = usmp_utils.PoolInfo( - pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS} + pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} ) tir_mod = LinearStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) tir_mod = assign_poolinfos_to_allocates_in_irmodule( tir_mod, [fast_memory_pool, slow_memory_pool] ) - main_func = tir_mod["tvmgen_default_run_model"] + main_func = tir_mod["run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") @@ -184,13 +195,15 @@ def test_linear(): buffer_info_map_names[buf_info.name_hint] = buf_info # check conflicts - _verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map_names) - _verify_conflicts("Conv2dOutput_7", ["PaddedInput_7", "sid_8"], buffer_info_map_names) - _verify_conflicts("PaddedInput_7", ["sid_9", "Conv2dOutput_7"], buffer_info_map_names) + _verify_conflicts("PaddedInput_7", ["sid_9", "sid_8", "Conv2dOutput_7"], buffer_info_map_names) _verify_conflicts("tensor_2", ["sid_8"], buffer_info_map_names) _verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map_names) + _verify_conflicts( + "sid_8", ["PaddedInput_7", "Conv2dOutput_7", "tensor_2"], buffer_info_map_names + ) + _verify_conflicts("Conv2dOutput_7", ["sid_8", "PaddedInput_7"], buffer_info_map_names) - _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 802816) + _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 1418528) _check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, 200704) @@ -316,11 +329,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place def test_fanout(): + target = Target("c") global_workspace_pool = usmp_utils.PoolInfo( pool_name="global_workspace", - target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, ) tir_mod = ResnetStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) main_func = tir_mod["tvmgen_default_run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) @@ -336,35 +351,108 @@ def test_fanout(): # check conflicts _verify_conflicts( - 
"sid_6", - ["Conv2dOutput_2", "sid_2", "PaddedInput_3", "Conv2dOutput_3"], + "Conv2dOutput_1", + [ + "PaddedInput_1", + "sid_7", + ], buffer_info_map_names, ) - _verify_conflicts("PaddedInput_1", ["sid_8", "sid_2", "Conv2dOutput_1"], buffer_info_map_names) - _verify_conflicts("PaddedInput_2", ["sid_7", "sid_2", "Conv2dOutput_2"], buffer_info_map_names) - _verify_conflicts("sid_8", ["Conv2dOutput", "sid_2", "PaddedInput_1"], buffer_info_map_names) _verify_conflicts( - "sid_2", + "sid_8", [ "PaddedInput", "Conv2dOutput", - "sid_8", "PaddedInput_1", - "Conv2dOutput_1", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "PaddedInput_2", + [ "sid_7", + "sid_6", + "Conv2dOutput_2", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "sid_2", + [ + "PaddedInput", + "PaddedInput_3", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "Conv2dOutput", + [ + "sid_8", + "PaddedInput", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "sid_7", + [ + "Conv2dOutput_1", + "PaddedInput_1", + "PaddedInput_2", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "sid_6", + [ "PaddedInput_2", "Conv2dOutput_2", + "Conv2dOutput_3", + "PaddedInput_3", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "PaddedInput_3", + [ + "sid_2", + "Conv2dOutput_3", "sid_6", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "Conv2dOutput_3", + [ "PaddedInput_3", + "sid_6", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "PaddedInput", + [ + "sid_2", + "sid_8", + "Conv2dOutput", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "Conv2dOutput_2", + [ + "sid_6", + "PaddedInput_2", ], buffer_info_map_names, ) - _verify_conflicts("PaddedInput", ["sid_2", "Conv2dOutput"], buffer_info_map_names) - _verify_conflicts("sid_7", ["Conv2dOutput_1", "sid_2", "PaddedInput_2"], buffer_info_map_names) - _verify_conflicts("PaddedInput_3", ["sid_6", "sid_2", "Conv2dOutput_3"], buffer_info_map_names) - _verify_conflicts("Conv2dOutput_3", ["PaddedInput_3", "sid_6"], buffer_info_map_names) - _verify_conflicts("Conv2dOutput", ["PaddedInput", "sid_2", "sid_8"], buffer_info_map_names) - _verify_conflicts("Conv2dOutput_1", ["PaddedInput_1", "sid_2", "sid_7"], buffer_info_map_names) - _verify_conflicts("Conv2dOutput_2", ["PaddedInput_2", "sid_2", "sid_6"], buffer_info_map_names) - - _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7920256) + _verify_conflicts( + "PaddedInput_1", + [ + "sid_8", + "Conv2dOutput_1", + "sid_7", + ], + buffer_info_map_names, + ) + + _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7200000) From 26fffb0fbfac1558bef905dee1817ba2e5eeabdf Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne <manupa.karunaratne@arm.com> Date: Wed, 24 Nov 2021 18:35:09 +0000 Subject: [PATCH 3/8] [TIR][USMP] Greedy algorithms for USMP This commit implements greedy algorithms for USMP based on size and number of liveness conflicts. * This includes a slight fix for buffer info extraction where non-linear network buffers owned by the main function should not show sporadic liveness. 
Change-Id: I957d543e75b3b0bcf5fc1fbc7870705c875c7a03 --- src/tir/usmp/algo/greedy.cc | 235 +++++++ src/tir/usmp/algo/greedy_by_size.cc | 128 ---- src/tir/usmp/analysis/extract_buffer_info.cc | 134 ++-- ...reedy_by_size.py => test_tir_usmp_algo.py} | 116 ++-- ...st_tir_usmp_analysis_extract_bufferinfo.py | 574 +++++++++++------- 5 files changed, 746 insertions(+), 441 deletions(-) create mode 100644 src/tir/usmp/algo/greedy.cc delete mode 100644 src/tir/usmp/algo/greedy_by_size.cc rename tests/python/unittest/{test_tir_usmp_algo_greedy_by_size.py => test_tir_usmp_algo.py} (94%) diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc new file mode 100644 index 000000000000..f0b1581cd616 --- /dev/null +++ b/src/tir/usmp/algo/greedy.cc @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/usmp/algo/greedy_by_size.cc + * \brief This source contains greedy algorithms for planning + * memory for USMP. There are two algorithms present here : + * 1) greedy_by_size and 2) greedy_by_conflicts. + * + * greedy_by_size : this algorithm prioritizes placing the + * largest size buffer to the given pools. The BufferInfo objects + * are sorted based on the size and placed on each pool adhering + * to size_hint constraint. + * + * greedy_by_conflicts : this algorithm prioritizes placing the + * the most liveness conflicted buffer to the given pools. The + * BufferInfo objects are sorted based on the number of conflicts + * and placed on each pool adhering to size_hint constraint. + */ + +#include <tvm/arith/analyzer.h> +#include <tvm/runtime/device_api.h> +#include <tvm/tir/builtin.h> +#include <tvm/tir/function.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/usmp/utils.h> + +namespace tvm { +namespace tir { +namespace usmp { +namespace algo { + +/*! + * \brief This is the base class for Greedy Algorithms where the sorting + * is specialized in the extended classes based on the greedy criteria. + */ +class GreedyBase { + public: + GreedyBase() {} + /*! + * \brief This function should be implemented by the extended classes to sort the BufferInfo + * objects based on a criteria and then calling PostSortAllocation. + */ + virtual Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) = 0; + + protected: + /*! + * \brief Rounds up the offset to satisfy the alignement requirement + */ + size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, + const int& byte_alignment) { + return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment; + } + + /*! 
+ * \brief A helper function check whether a offset is valid given the constraints + */ + bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, + const size_t& size_bytes) { + if (candidate_pool->size_hint_bytes == -1) { + // this means pool is not bounded + return true; + } + auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value); + auto max_address = next_offset + size_bytes; + if (max_address <= pool_size) { + return true; + } + return false; + } + + /*! + * \brief Selects a pool for placement in the given set of ordered pool candidates + */ + PoolInfo SelectPlacementPool( + const Array<PoolInfo>& pool_candidates, + const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) { + // Here the pool candidates are ordered when it is consumed by the algorithm. + // This could be from order the user has specified. However, schedulers are + // welcome to change the order for performance reasons. + for (const auto& pool_info : pool_candidates) { + if (pool_offsets.count(pool_info)) { + return pool_info; + } + } + ICHECK(false) << "TVM USMP Internal Error: no candidate have been selected!"; + return PoolInfo(); + } + + /*! + * \brief This is the base allocation function that works on sorted BufferInfo objects based + * on the greedy heuristic. The sorting algorithm has to be called before calling this. + */ + Map<BufferInfo, PoolAllocation> PostSortAllocation( + const std::vector<BufferInfo>& buffer_info_vec) { + Map<BufferInfo, PoolAllocation> pool_allocations; + for (const auto& buf_info : buffer_info_vec) { + std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates; + for (const auto& pool_info : buf_info->pool_candidates) { + // Mark pool candidates that satisfy the size constraints. + if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) { + pool_offset_candidates[pool_info] = 0; + } + } + + for (const auto& conflict_buf_info_obj : buf_info->conflicts) { + auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj); + size_t next_offset = 0; + // We only look at already allocated BufferInfo in-terms of conflicts. + if (pool_allocations.count(conflict_buf_info)) { + auto pool_allocation = pool_allocations[conflict_buf_info]; + next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes; + next_offset = + round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value); + // Checks whether the next offset in the same pool as the conflicting BufferInfo is valid. + if (IsValidPlacement(pool_allocation->pool_info, next_offset, + buf_info->size_bytes->value)) { + // There could be multiple conflicting BufferInfo in the same pool. + // Thus, we need to make sure we pick the largest offset of them all. + if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) { + pool_offset_candidates[pool_allocation->pool_info] = next_offset; + } + } else { + pool_offset_candidates.erase(pool_allocation->pool_info); + } + } + } + auto selected_pool = SelectPlacementPool(buf_info->pool_candidates, pool_offset_candidates); + pool_allocations.Set( + buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool]))); + } + return pool_allocations; + } +}; + +/*! + * \brief This class implements Greedy by the size of BufferInfo + * greedy algorithm. Please refer to main documentation of the file + * for more details. 
+ */ +class GreedySize : public GreedyBase { + public: + GreedySize() {} + Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) { + std::vector<BufferInfo> buffer_info_vec; + Map<BufferInfo, PoolAllocation> pool_allocations; + for (const auto& buffer_info : buffer_info_arr) { + buffer_info_vec.push_back(std::move(buffer_info)); + } + std::sort(buffer_info_vec.begin(), buffer_info_vec.end(), + [](const BufferInfo& a, const BufferInfo& b) { + if (a->size_bytes->value == b->size_bytes->value) { + if (a->conflicts.size() == b->conflicts.size()) { + auto a_name_hash = std::hash<std::string>{}(a->name_hint->data); + auto b_name_hash = std::hash<std::string>{}(b->name_hint->data); + return a_name_hash > b_name_hash; + } else { + return a->conflicts.size() > b->conflicts.size(); + } + } + return a->size_bytes > b->size_bytes; + }); + return PostSortAllocation(buffer_info_vec); + } +}; + +/*! + * \brief This class implements Greedy by the number of conflicts of + * BufferInfo greedy algorithm. Please refer to main documentation + * of the file for more details. + */ +class GreedyConflicts : public GreedyBase { + public: + GreedyConflicts() {} + Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) { + std::vector<BufferInfo> buffer_info_vec; + Map<BufferInfo, PoolAllocation> pool_allocations; + for (const auto& buffer_info : buffer_info_arr) { + buffer_info_vec.push_back(std::move(buffer_info)); + } + std::sort(buffer_info_vec.begin(), buffer_info_vec.end(), + [](const BufferInfo& a, const BufferInfo& b) { + if (a->conflicts.size() == b->conflicts.size()) { + if (a->size_bytes->value == b->size_bytes->value) { + auto a_name_hash = std::hash<std::string>{}(a->name_hint->data); + auto b_name_hash = std::hash<std::string>{}(b->name_hint->data); + return a_name_hash > b_name_hash; + } else { + return a->size_bytes->value > b->size_bytes->value; + } + } + return a->conflicts.size() > b->conflicts.size(); + }); + return PostSortAllocation(buffer_info_vec); + } +}; + +Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) { + return GreedySize().PlanMemory(buffer_info_arr); +} + +Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr) { + return GreedyConflicts().PlanMemory(buffer_info_arr); +} + +TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size") + .set_body_typed([](Array<BufferInfo> buffer_info_arr) { + return GreedyBySize(buffer_info_arr); + }); + +TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_conflicts") + .set_body_typed([](Array<BufferInfo> buffer_info_arr) { + return GreedyByConflicts(buffer_info_arr); + }); + +} // namespace algo +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/src/tir/usmp/algo/greedy_by_size.cc b/src/tir/usmp/algo/greedy_by_size.cc deleted file mode 100644 index 7aafe8c0974f..000000000000 --- a/src/tir/usmp/algo/greedy_by_size.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tir/analysis/usmp/algo/greedy_by_size.cc - * \brief Implement greedy by size memory planning algorithm - */ - -#include <tvm/arith/analyzer.h> -#include <tvm/runtime/device_api.h> -#include <tvm/tir/builtin.h> -#include <tvm/tir/function.h> -#include <tvm/tir/stmt_functor.h> -#include <tvm/tir/usmp/utils.h> - -namespace tvm { -namespace tir { -namespace usmp { -namespace algo { - -static size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, - const int& byte_alignment) { - return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment; -} - -static bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, - const size_t& size_bytes) { - if (candidate_pool->size_hint_bytes == -1) { - // this means pool is not bounded - return true; - } - auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value); - auto max_address = next_offset + size_bytes; - if (max_address <= pool_size) { - return true; - } - return false; -} - -static PoolInfo SelectPlacementPool( - const Array<PoolInfo>& pool_candidates, - const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) { - for (const auto& pool_info : pool_candidates) { - if (pool_offsets.count(pool_info)) { - return pool_info; - } - } - ICHECK(false) << "TVM USMP Internal Error: no candidate have been selected!"; - return PoolInfo(); -} - -Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) { - std::vector<BufferInfo> buffer_info_vec; - Map<BufferInfo, PoolAllocation> pool_allocations; - for (const auto& buffer_info : buffer_info_arr) { - buffer_info_vec.push_back(std::move(buffer_info)); - } - std::sort(buffer_info_vec.begin(), buffer_info_vec.end(), - [](const BufferInfo& a, const BufferInfo& b) { - if (a->size_bytes->value == b->size_bytes->value) { - if (a->conflicts.size() == b->conflicts.size()) { - auto a_name_hash = std::hash<std::string>{}(a->name_hint->data); - auto b_name_hash = std::hash<std::string>{}(b->name_hint->data); - return a_name_hash > b_name_hash; - } else { - return a->conflicts.size() > b->conflicts.size(); - } - } - return a->size_bytes > b->size_bytes; - }); - - for (const auto& buf_info : buffer_info_vec) { - std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates; - for (const auto& pool_info : buf_info->pool_candidates) { - if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) { - pool_offset_candidates[pool_info] = 0; - } - } - - for (const auto& conflict_buf_info_obj : buf_info->conflicts) { - auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj); - size_t next_offset = 0; - if (pool_allocations.count(conflict_buf_info)) { - auto pool_allocation = pool_allocations[conflict_buf_info]; - next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes; - next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value); - if (IsValidPlacement(pool_allocation->pool_info, next_offset, - buf_info->size_bytes->value)) { - if (next_offset > 
pool_offset_candidates[pool_allocation->pool_info]) { - pool_offset_candidates[pool_allocation->pool_info] = next_offset; - } - } else { - pool_offset_candidates.erase(pool_allocation->pool_info); - } - } - } - auto selected_pool = SelectPlacementPool(buf_info->pool_candidates, pool_offset_candidates); - pool_allocations.Set( - buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool]))); - } - return pool_allocations; -} - -TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size") - .set_body_typed([](Array<BufferInfo> buffer_info_arr) { - return GreedyBySize(buffer_info_arr); - }); - -} // namespace algo -} // namespace usmp -} // namespace tir -} // namespace tvm diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index c25578fd9779..609481ffeae9 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -26,6 +26,7 @@ * conflicts between other tir.allocate nodes. */ #include <tvm/arith/analyzer.h> +#include <tvm/relay/executor.h> #include <tvm/runtime/device_api.h> #include <tvm/tir/builtin.h> #include <tvm/tir/function.h> @@ -99,11 +100,21 @@ class BufferInfoExtractor : public StmtExprVisitor { */ std::unordered_map<Call, Map<tir::Stmt, Integer>, ObjectPtrHash, ObjectPtrEqual> buffer_info_end_stmt_idx_; + + /*! + * \brief This structure contains information regarding a Allocate node. + */ + struct AllocateInfo { + tir::Stmt Allocate; + PrimFunc prim_func; + Call call; + }; + /*! * \brief Maintains the mapping of buffer variable to their allocate nodes to ensure * that only one BufferInfo object is created. */ - Map<tir::Var, tir::Stmt> allocate_var_to_stmt_map_; + std::unordered_map<tir::Var, AllocateInfo, ObjectPtrHash, ObjectPtrEqual> allocate_infos; /*! * \brief Indicates a count of stmts visited so far to use as a metric of liveness */ @@ -203,29 +214,41 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) { auto size_bytes = CalculateExtentsSize(op); // We only statically memory plan only allocates with known // compile time sizes. - if (size_bytes.defined() && - allocate_var_to_stmt_map_.find(op->buffer_var) == allocate_var_to_stmt_map_.end()) { - // By default, the core compiler is assumed to attach the a default pool to each allocate. - ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr)) - << "Every statically sized allocate node needs an pool candidate attribute"; - auto pool_candidates = Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesAllocateAttr]); - - // TODO(@manupa-arm): improve the error when the responsible component for attaching a single - // pool is added - ICHECK(pool_candidates.size() > 0) - << "The core compiler should at least attach a single PoolInfo. 
If there were no " - "user-given arguments for memory pools, the default behaviour is a single size " - "un-restricted pool is assigned"; - PrimFunc func = scope_stack_.top().func; - Optional<Target> tgt = func->GetAttr<Target>(tvm::attr::kTarget); - ICHECK(tgt) << "There should not be any PrimFuncs without a target attached by now"; - auto workspace_alignment = - tgt.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16); - auto buffer_info = BufferInfo(GetUniqueBufferName(op->buffer_var->name_hint), size_bytes, - pool_candidates, workspace_alignment); - auto allocate = GetRef<Allocate>(op); - allocate_var_to_stmt_map_.Set(op->buffer_var, allocate); - buffer_info_map_.Set(buffer_info, allocate); + if (size_bytes.defined()) { + if (allocate_infos.find(op->buffer_var) == allocate_infos.end()) { + // By default, the core compiler is assumed to attach the a default pool to each allocate. + ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr)) + << "Every statically sized allocate node needs an pool candidate attribute"; + auto pool_candidates = + Downcast<Array<PoolInfo>>(op->annotations[kPoolCandidatesAllocateAttr]); + + // TODO(@manupa-arm): improve the error when the responsible component for attaching a single + // pool is added + ICHECK(pool_candidates.size() > 0) + << "The core compiler should at least attach a single PoolInfo. If there were no " + "user-given arguments for memory pools, the default behaviour is a single size " + "un-restricted pool is assigned"; + PrimFunc func = scope_stack_.top().func; + Optional<tvm::relay::Executor> executor_config = + module_->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor); + Integer workspace_alignment = 16; + if (executor_config) { + // ICHECK(tgt) << "There should not be any PrimFuncs without a target attached by now"; + workspace_alignment = + executor_config.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16); + } + auto buffer_info = BufferInfo(GetUniqueBufferName(op->buffer_var->name_hint), size_bytes, + pool_candidates, workspace_alignment); + auto allocate = GetRef<Allocate>(op); + allocate_infos[op->buffer_var] = + AllocateInfo{allocate, scope_stack_.top().func, scope_stack_.top().call}; + buffer_info_map_.Set(buffer_info, allocate); + } else { + // Update the allocate info with the latest call + AllocateInfo ai = allocate_infos[op->buffer_var]; + ai.call = scope_stack_.top().call; + allocate_infos[op->buffer_var] = ai; + } } } @@ -257,17 +280,25 @@ void BufferInfoExtractor::VisitStmt_(const ForNode* op) { si.initial_stmt_of_the_nested_loops = Integer(current_stmt_idx_); } Call current_call = scope_stack_.top().call; + PrimFunc current_primfunc = scope_stack_.top().func; scope_stack_.push(si); StmtExprVisitor::VisitStmt_(op); // Extending the liveness to beginning of for-loop next and end of the current for-loop for (const Allocate& allocate : scope_stack_.top().allocate_nodes) { + AllocateInfo ai = allocate_infos[allocate->buffer_var]; + Call update_call = current_call; + // If the allocate does not belong to current prim func + // We need to update the call to which the allocate belong to + if (ai.prim_func != current_primfunc) { + update_call = ai.call; + } if (scope_stack_.top().initial_stmt_of_the_nested_loops->value < - buffer_info_start_stmt_idx_[current_call][allocate]) { - buffer_info_start_stmt_idx_[current_call].Set( + buffer_info_start_stmt_idx_[update_call][allocate]) { + buffer_info_start_stmt_idx_[update_call].Set( allocate, scope_stack_.top().initial_stmt_of_the_nested_loops->value); } 
- if (current_stmt_idx_ > buffer_info_end_stmt_idx_[current_call][allocate]) { - buffer_info_end_stmt_idx_[current_call].Set(allocate, current_stmt_idx_); + if (current_stmt_idx_ > buffer_info_end_stmt_idx_[update_call][allocate]) { + buffer_info_end_stmt_idx_[update_call].Set(allocate, current_stmt_idx_); } } scope_stack_.pop(); @@ -286,12 +317,21 @@ void BufferInfoExtractor::VisitStmt_(const StoreNode* op) { void BufferInfoExtractor::VisitExpr_(const VarNode* op) { auto var = GetRef<Var>(op); Call current_call = scope_stack_.top().call; - if (allocate_var_to_stmt_map_.count(var)) { - auto allocate = allocate_var_to_stmt_map_[var]; - if (buffer_info_start_stmt_idx_[current_call].count(allocate) == 0) { - buffer_info_start_stmt_idx_[current_call].Set(allocate, current_stmt_idx_); + PrimFunc current_primfunc = scope_stack_.top().func; + if (allocate_infos.count(var)) { + auto allocate = allocate_infos[var].Allocate; + auto allocate_primfunc = allocate_infos[var].prim_func; + Call update_call = current_call; + if (allocate_primfunc != current_primfunc) { + // If the allocate node does not belong to the current primfunc. + // It's access should be reported to the call to PrimFunc that + // Allocate belong to. + update_call = allocate_infos[var].call; } - buffer_info_end_stmt_idx_[current_call].Set(allocate, current_stmt_idx_); + if (buffer_info_start_stmt_idx_[update_call].count(allocate) == 0) { + buffer_info_start_stmt_idx_[update_call].Set(allocate, current_stmt_idx_); + } + buffer_info_end_stmt_idx_[update_call].Set(allocate, current_stmt_idx_); ScopeInfo& currect_scope_info = scope_stack_.top(); if (currect_scope_info.for_loop.defined()) { @@ -320,13 +360,13 @@ void BufferInfoExtractor::UpdateAliases(const Array<PrimExpr>& args, const PrimF // to the original allocate if (arg->IsInstance<LoadNode>()) { auto load = Downcast<Load>(arg); - if (allocate_var_to_stmt_map_.count(load->buffer_var)) { - allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[load->buffer_var]); + if (allocate_infos.count(load->buffer_var)) { + allocate_infos[param_buf] = allocate_infos[load->buffer_var]; } } else if (arg->IsInstance<VarNode>()) { auto var = Downcast<Var>(arg); - if (allocate_var_to_stmt_map_.count(var)) { - allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[var]); + if (allocate_infos.count(var)) { + allocate_infos[param_buf] = allocate_infos[var]; } } } @@ -415,18 +455,30 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_ // Traverse the liveness events using a open set to track what // is live while updating the conflicts through out the linear traversal - std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set; + // std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set; + std::unordered_map<BufferInfo, int, ObjectPtrHash, ObjectPtrEqual> open_set; for (const auto& le_event : le_events_timeline) { if (le_event.le_type == START) { - for (const auto& open_buffer_info : open_set) { + for (const auto& kv : open_set) { + BufferInfo open_buffer_info = kv.first; open_buffer_info->conflicts.push_back(le_event.buffer_info); if (le_event.buffer_info != open_buffer_info) { le_event.buffer_info->conflicts.push_back(open_buffer_info); } } - open_set.insert(le_event.buffer_info); + // open_set.insert(le_event.buffer_info); + if (open_set.find(le_event.buffer_info) == open_set.end()) { + open_set[le_event.buffer_info] = 1; + } else { + open_set[le_event.buffer_info] += 1; + } } else { - 
open_set.erase(le_event.buffer_info); + if (open_set[le_event.buffer_info] == 1) { + open_set.erase(le_event.buffer_info); + } else { + open_set[le_event.buffer_info] -= 1; + } + // open_set.erase(le_event.buffer_info); } } return this->buffer_info_map_; diff --git a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py b/tests/python/unittest/test_tir_usmp_algo.py similarity index 94% rename from tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py rename to tests/python/unittest/test_tir_usmp_algo.py index 6bd7832f533b..69d52009fefc 100644 --- a/tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -167,7 +167,22 @@ def run_model(input: T.handle, output: T.handle) -> None: # fmt: on -def test_linear(): +def print_conflicts(buffer_info_map): + """_verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map)""" + + for buffer_info_name, buf_info in buffer_info_map.items(): + conflict_str = "[" + for conflict in buf_info.conflicts: + conflict_str += f'"{conflict.name_hint}", ' + conflict_str += "]" + print(f'_verify_conflicts("{buffer_info_name}", {conflict_str}, buffer_info_map_names)') + + +@pytest.mark.parametrize( + ["algorithm", "fast_memory_size", "slow_memory_size"], + [("greedy_by_size", 200704, 1418528), ("greedy_by_conflicts", 200704, 1418528)], +) +def test_linear(algorithm, fast_memory_size, slow_memory_size): target = Target("c") fast_memory_pool = usmp_utils.PoolInfo( pool_name="fast_memory", @@ -187,7 +202,7 @@ def test_linear(): fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") buffer_info_arr = fcreate_array_bi(buffer_info_map) - fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") + fusmp_algo_greedy_by_size = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr) buffer_info_map_names = dict() @@ -203,8 +218,8 @@ def test_linear(): ) _verify_conflicts("Conv2dOutput_7", ["sid_8", "PaddedInput_7"], buffer_info_map_names) - _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 1418528) - _check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, 200704) + _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, slow_memory_size) + _check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, fast_memory_size) # fmt: off @@ -328,7 +343,10 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place # fmt: on -def test_fanout(): +@pytest.mark.parametrize( + ["algorithm", "workspace_size"], [("greedy_by_size", 7920256), ("greedy_by_conflicts", 7200256)] +) +def test_fanout(algorithm, workspace_size): target = Target("c") global_workspace_pool = usmp_utils.PoolInfo( pool_name="global_workspace", @@ -342,7 +360,7 @@ def test_fanout(): fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") buffer_info_arr = fcreate_array_bi(buffer_info_map) - fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") + fusmp_algo_greedy_by_size = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr) buffer_info_map_names = dict() @@ -351,36 +369,31 @@ def test_fanout(): # check conflicts _verify_conflicts( - "Conv2dOutput_1", + "sid_7", [ "PaddedInput_1", - "sid_7", - ], - buffer_info_map_names, - ) - _verify_conflicts( - "sid_8", - [ - "PaddedInput", - "Conv2dOutput", - "PaddedInput_1", + "sid_2", + "Conv2dOutput_1", + "PaddedInput_2", 
], buffer_info_map_names, ) _verify_conflicts( - "PaddedInput_2", + "Conv2dOutput_3", [ - "sid_7", + "PaddedInput_3", "sid_6", - "Conv2dOutput_2", ], buffer_info_map_names, ) _verify_conflicts( - "sid_2", + "sid_6", [ - "PaddedInput", + "Conv2dOutput_2", + "PaddedInput_2", + "sid_2", "PaddedInput_3", + "Conv2dOutput_3", ], buffer_info_map_names, ) @@ -388,43 +401,45 @@ def test_fanout(): "Conv2dOutput", [ "sid_8", + "sid_2", "PaddedInput", ], buffer_info_map_names, ) _verify_conflicts( - "sid_7", + "PaddedInput_3", [ - "Conv2dOutput_1", - "PaddedInput_1", - "PaddedInput_2", + "sid_6", + "sid_2", + "Conv2dOutput_3", ], buffer_info_map_names, ) _verify_conflicts( - "sid_6", + "Conv2dOutput_2", [ "PaddedInput_2", - "Conv2dOutput_2", - "Conv2dOutput_3", - "PaddedInput_3", + "sid_2", + "sid_6", ], buffer_info_map_names, ) _verify_conflicts( - "PaddedInput_3", + "PaddedInput_1", [ + "sid_8", "sid_2", - "Conv2dOutput_3", - "sid_6", + "sid_7", + "Conv2dOutput_1", ], buffer_info_map_names, ) _verify_conflicts( - "Conv2dOutput_3", + "Conv2dOutput_1", [ - "PaddedInput_3", - "sid_6", + "sid_7", + "PaddedInput_1", + "sid_2", ], buffer_info_map_names, ) @@ -438,21 +453,40 @@ def test_fanout(): buffer_info_map_names, ) _verify_conflicts( - "Conv2dOutput_2", + "sid_8", [ - "sid_6", - "PaddedInput_2", + "PaddedInput", + "sid_2", + "Conv2dOutput", + "PaddedInput_1", ], buffer_info_map_names, ) _verify_conflicts( - "PaddedInput_1", + "sid_2", [ + "PaddedInput", "sid_8", + "Conv2dOutput", + "PaddedInput_1", + "sid_7", "Conv2dOutput_1", + "PaddedInput_2", + "Conv2dOutput_2", + "sid_6", + "PaddedInput_3", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "PaddedInput_2", + [ "sid_7", + "sid_2", + "Conv2dOutput_2", + "sid_6", ], buffer_info_map_names, ) - _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7200000) + _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size) diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index fa645f1379ff..abaa0cd2431e 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -661,149 +661,261 @@ def test_inception_structure(): # check conflicts _verify_conflicts( - "PaddedInput_8", + "sid_3", + [ + "sid_4", + "PaddedInput_2", + "sid_2", + "Conv2dOutput_2", + "PaddedInput_1", + "Conv2dOutput_1", + "sid_20", + "PaddedInput_6", + "Conv2dOutput_6", + "sid_19", + "PaddedInput_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput", [ "sid_6", - "Conv2dOutput_8", - "sid_5", + "PaddedInput", ], buffer_info_map, ) _verify_conflicts( - "sid_26", + "Conv2dOutput_7", + [ + "PaddedInput_7", + "sid_8", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_4", [ + "sid_5", + "sid_3", + "PaddedInput_2", + "sid_2", + "Conv2dOutput_2", + "PaddedInput_1", + "Conv2dOutput_1", + "sid_20", + "PaddedInput_6", + "Conv2dOutput_6", + "sid_19", "PaddedInput_4", "Conv2dOutput_4", + "sid_26", "PaddedInput_5", + "Conv2dOutput_5", + "sid_25", + "tensor_3", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput", + "sid_2", [ - "sid_6", - "PaddedInput", + "PaddedInput_2", + "sid_3", + "sid_4", + "Conv2dOutput_2", + "PaddedInput_1", + "Conv2dOutput_1", + "sid_20", + "PaddedInput_6", + "Conv2dOutput_6", + "sid_19", + "PaddedInput_4", + "Conv2dOutput_4", + "sid_26", + "PaddedInput_5", + "Conv2dOutput_5", + "sid_25", + "tensor_3", + 
"sid_32", + "PaddedInput_3", + "Conv2dOutput_3", + "sid_31", ], buffer_info_map, ) _verify_conflicts( - "sid_4", + "sid_19", [ - "sid_5", + "Conv2dOutput_6", + "sid_2", + "PaddedInput_6", "sid_3", + "sid_4", + "PaddedInput_4", + "Conv2dOutput_4", + "sid_26", + "PaddedInput_5", + "Conv2dOutput_5", + "sid_25", "tensor_3", + "sid_32", + "PaddedInput_3", + "Conv2dOutput_3", + "sid_31", ], buffer_info_map, ) _verify_conflicts( - "tensor_2", + "PaddedInput_2", [ - "sid_8", - "sid_7", + "sid_3", + "sid_4", + "sid_2", + "Conv2dOutput_2", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_7", + "Conv2dOutput_6", [ - "sid_8", - "PaddedInput_7", + "sid_2", + "PaddedInput_6", + "sid_3", + "sid_4", + "sid_19", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_1", + "sid_9", [ - "sid_20", - "PaddedInput_1", + "PaddedInput_7", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_4", + "sid_7", [ - "sid_26", - "PaddedInput_4", + "tensor_2", + "PaddedInput", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_2", + "PaddedInput_4", [ - "PaddedInput_2", "sid_2", + "sid_19", + "sid_3", + "sid_4", + "Conv2dOutput_4", + "sid_26", ], buffer_info_map, ) _verify_conflicts( "PaddedInput_3", [ + "sid_2", "sid_32", - "sid_31", + "sid_25", + "sid_19", "Conv2dOutput_3", + "sid_31", ], buffer_info_map, ) _verify_conflicts( - "sid_3", + "sid_5", [ + "PaddedInput_8", + "Conv2dOutput_8", "sid_4", - "PaddedInput_2", - "PaddedInput_1", - "PaddedInput_4", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_6", + "sid_31", [ - "PaddedInput_6", + "Conv2dOutput_3", + "PaddedInput_3", + "sid_2", + "sid_25", "sid_19", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_5", + "PaddedInput", [ - "PaddedInput_5", + "sid_7", + "sid_6", + "Conv2dOutput", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_2", + [ + "sid_2", + "PaddedInput_2", + "sid_3", + "sid_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_32", + [ + "tensor_3", + "sid_2", "sid_25", + "sid_19", + "PaddedInput_3", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_7", + "tensor_2", [ - "sid_9", "sid_8", - "Conv2dOutput_7", + "sid_7", ], buffer_info_map, ) _verify_conflicts( - "sid_7", + "sid_26", [ - "tensor_2", - "PaddedInput", + "Conv2dOutput_4", + "PaddedInput_4", + "sid_2", + "sid_19", + "sid_4", + "PaddedInput_5", ], buffer_info_map, ) _verify_conflicts( - "sid_31", + "Conv2dOutput_3", [ "PaddedInput_3", - "Conv2dOutput_3", - "sid_25", "sid_2", + "sid_25", "sid_19", + "sid_31", ], buffer_info_map, ) _verify_conflicts( - "sid_5", + "PaddedInput_6", [ - "Conv2dOutput_8", - "PaddedInput_8", + "sid_2", + "sid_3", + "sid_20", "sid_4", + "Conv2dOutput_6", + "sid_19", ], buffer_info_map, ) @@ -817,146 +929,132 @@ def test_inception_structure(): buffer_info_map, ) _verify_conflicts( - "sid_20", + "PaddedInput_8", [ - "PaddedInput_1", - "Conv2dOutput_1", - "PaddedInput_6", + "sid_6", + "sid_5", + "Conv2dOutput_8", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_8", + "Conv2dOutput_5", [ - "PaddedInput_8", - "sid_5", + "PaddedInput_5", + "sid_2", + "sid_19", + "sid_4", + "sid_25", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_1", + "Conv2dOutput_1", [ + "PaddedInput_1", + "sid_2", "sid_3", + "sid_4", "sid_20", - "Conv2dOutput_1", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_3", + "tensor_3", [ - "sid_31", - "PaddedInput_3", + "sid_2", + "sid_25", + "sid_19", + "sid_4", + "sid_32", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput", + "sid_8", [ - "sid_7", - "sid_6", - 
"Conv2dOutput", + "Conv2dOutput_7", + "PaddedInput_7", + "tensor_2", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_2", + "sid_20", [ - "sid_3", - "Conv2dOutput_2", + "Conv2dOutput_1", + "PaddedInput_1", "sid_2", - ], - buffer_info_map, - ) - _verify_conflicts( - "sid_19", - [ - "Conv2dOutput_6", + "sid_3", + "sid_4", "PaddedInput_6", - "sid_31", - "sid_2", - "sid_25", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_4", + "Conv2dOutput_8", [ - "sid_3", - "sid_26", - "Conv2dOutput_4", + "sid_5", + "PaddedInput_8", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_5", + "PaddedInput_1", [ - "sid_26", - "Conv2dOutput_5", - "sid_25", + "sid_2", + "sid_3", + "sid_4", + "Conv2dOutput_1", + "sid_20", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_6", + "Conv2dOutput_4", [ - "sid_20", - "Conv2dOutput_6", + "PaddedInput_4", + "sid_2", "sid_19", + "sid_4", + "sid_26", ], buffer_info_map, ) _verify_conflicts( "sid_25", [ - "Conv2dOutput_5", "PaddedInput_5", - "sid_31", + "Conv2dOutput_5", "sid_2", "sid_19", - ], - buffer_info_map, - ) - _verify_conflicts( - "tensor_3", - [ "sid_4", - "sid_32", - ], - buffer_info_map, - ) - _verify_conflicts( - "sid_32", - [ "tensor_3", + "sid_32", "PaddedInput_3", + "Conv2dOutput_3", + "sid_31", ], buffer_info_map, ) _verify_conflicts( - "sid_9", - [ - "PaddedInput_7", - ], - buffer_info_map, - ) - _verify_conflicts( - "sid_2", + "PaddedInput_7", [ - "Conv2dOutput_2", - "PaddedInput_2", - "sid_31", - "sid_25", - "sid_19", + "sid_9", + "Conv2dOutput_7", + "sid_8", ], buffer_info_map, ) _verify_conflicts( - "sid_8", + "PaddedInput_5", [ - "PaddedInput_7", - "Conv2dOutput_7", - "tensor_2", + "sid_2", + "sid_19", + "sid_26", + "sid_4", + "Conv2dOutput_5", + "sid_25", ], buffer_info_map, ) @@ -1276,56 +1374,70 @@ def test_multiple_calls_to_same_primfunc(): # check conflicts _verify_conflicts( - "sid_18", + "sid_6", [ - "sid_19", - "sid_2", - "T_softmax_exp2", - "T_softmax_maxelem2", - "T_softmax_expsum2", - "T_softmax_norm2", + "sid_7", + "sid_12", + "compute", + "compute_global", + "sid_11", + "sid_10", + "T_softmax_exp", + "T_softmax_maxelem", + "sid_5", + "T_softmax_norm", + "T_softmax_expsum", ], buffer_info_map, ) _verify_conflicts( - "sid_3", + "T_softmax_exp", [ - "data_pad", - "conv2d_NCHWc_global", - "sid_2", + "sid_10", + "sid_6", + "T_softmax_maxelem", + "sid_5", + "T_softmax_norm", + "T_softmax_expsum", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_norm", + "T_softmax_expsum2", [ - "T_softmax_expsum", - "T_softmax_exp", - "sid_5", - "sid_6", - "T_softmax_maxelem", - "sid_10", + "T_softmax_exp2", + "T_softmax_norm2", + "sid_18", + "T_softmax_maxelem2", + "sid_2", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_norm2", + "compute", [ - "T_softmax_expsum2", - "T_softmax_maxelem2", - "T_softmax_exp2", - "sid_18", + "sid_12", + "sid_6", + "compute_global", + "sid_11", + "sid_19", + "sid_20", "sid_2", + "compute_global", ], buffer_info_map, ) _verify_conflicts( - "sid_11", + "compute_global", [ "compute", "sid_12", - "compute_global", - "sid_10", + "sid_6", + "sid_11", + "compute", + "sid_19", + "sid_20", + "sid_2", ], buffer_info_map, ) @@ -1334,75 +1446,93 @@ def test_multiple_calls_to_same_primfunc(): [ "sid_11", "sid_6", + "T_softmax_exp", + "T_softmax_maxelem", "sid_5", "T_softmax_norm", "T_softmax_expsum", - "T_softmax_maxelem", - "T_softmax_exp", ], buffer_info_map, ) _verify_conflicts( - "sid_5", + "sid_2", [ - "T_softmax_norm", - "T_softmax_expsum", - "T_softmax_exp", - "sid_6", - "T_softmax_maxelem", - 
"sid_10", - "sid_4", + "sid_3", + "sid_5", "sid_20", + "sid_19", + "compute", + "compute_global", + "sid_18", + "T_softmax_norm2", + "T_softmax_exp2", + "T_softmax_maxelem2", + "T_softmax_expsum2", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_expsum", + "sid_5", [ - "T_softmax_exp", - "T_softmax_norm", - "sid_5", - "sid_6", "T_softmax_maxelem", "sid_10", + "T_softmax_exp", + "sid_6", + "T_softmax_norm", + "T_softmax_expsum", + "sid_4", + "data_pad", + "sid_3", + "conv2d_NCHWc_global", + "sid_2", + "sid_20", ], buffer_info_map, ) _verify_conflicts( - "sid_8", + "T_softmax_norm2", [ - "data_pad", + "sid_18", + "sid_2", + "T_softmax_exp2", + "T_softmax_maxelem2", + "T_softmax_expsum2", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_expsum2", + "sid_20", [ - "T_softmax_maxelem2", - "T_softmax_exp2", - "sid_18", "sid_2", - "T_softmax_norm2", + "sid_5", + "sid_19", + "compute", + "compute_global", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_maxelem2", + "T_softmax_expsum", [ - "T_softmax_exp2", - "sid_18", - "sid_2", - "T_softmax_expsum2", - "T_softmax_norm2", + "sid_5", + "T_softmax_norm", + "T_softmax_maxelem", + "sid_10", + "T_softmax_exp", + "sid_6", ], buffer_info_map, ) _verify_conflicts( - "sid_12", + "data_pad", [ - "sid_11", - "compute", - "compute_global", + "sid_8", + "conv2d_NCHWc_global", + "sid_7", + "sid_4", + "sid_5", + "sid_3", + "conv2d_NCHWc_global", ], buffer_info_map, ) @@ -1410,6 +1540,7 @@ def test_multiple_calls_to_same_primfunc(): "sid_19", [ "sid_20", + "sid_2", "compute", "compute_global", "sid_18", @@ -1423,17 +1554,19 @@ def test_multiple_calls_to_same_primfunc(): "sid_7", "sid_3", "data_pad", + "sid_5", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_exp2", + "sid_18", [ - "sid_18", + "sid_19", "sid_2", + "T_softmax_norm2", + "T_softmax_exp2", "T_softmax_maxelem2", "T_softmax_expsum2", - "T_softmax_norm2", ], buffer_info_map, ) @@ -1447,105 +1580,84 @@ def test_multiple_calls_to_same_primfunc(): buffer_info_map, ) _verify_conflicts( - "data_pad", + "T_softmax_exp2", [ - "sid_8", - "conv2d_NCHWc_global", - "sid_7", - "sid_4", - "sid_3", - "conv2d_NCHWc_global", + "T_softmax_norm2", + "sid_18", + "sid_2", + "T_softmax_maxelem2", + "T_softmax_expsum2", ], buffer_info_map, ) _verify_conflicts( - "sid_20", + "sid_4", [ "sid_5", - "sid_19", - "compute", - "compute_global", + "data_pad", ], buffer_info_map, ) _verify_conflicts( - "sid_4", + "T_softmax_maxelem", [ + "sid_10", + "T_softmax_exp", + "sid_6", "sid_5", - "data_pad", + "T_softmax_norm", + "T_softmax_expsum", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_exp", + "T_softmax_maxelem2", [ - "T_softmax_expsum", - "T_softmax_norm", - "sid_5", - "sid_6", - "T_softmax_maxelem", - "sid_10", + "T_softmax_exp2", + "T_softmax_norm2", + "sid_18", + "sid_2", + "T_softmax_expsum2", ], buffer_info_map, ) _verify_conflicts( - "compute_global", + "sid_11", [ - "sid_12", - "sid_11", - "compute", "compute", - "sid_20", - "sid_19", + "sid_12", + "compute_global", + "sid_6", + "sid_10", ], buffer_info_map, ) _verify_conflicts( - "compute", + "sid_12", [ - "sid_11", - "sid_12", - "compute_global", - "sid_20", - "sid_19", + "sid_6", + "compute", "compute_global", + "sid_11", ], buffer_info_map, ) _verify_conflicts( - "sid_6", + "T_softmax_norm", [ - "sid_7", "sid_5", - "T_softmax_norm", - "T_softmax_expsum", - "T_softmax_exp", "T_softmax_maxelem", "sid_10", - ], - buffer_info_map, - ) - _verify_conflicts( - "T_softmax_maxelem", - [ + "T_softmax_exp", "sid_6", - "sid_5", - "T_softmax_norm", 
"T_softmax_expsum", - "T_softmax_exp", - "sid_10", ], buffer_info_map, ) _verify_conflicts( - "sid_2", + "sid_8", [ - "sid_3", - "sid_18", - "T_softmax_exp2", - "T_softmax_maxelem2", - "T_softmax_expsum2", - "T_softmax_norm2", + "data_pad", ], buffer_info_map, ) From 1d5783214bce107d4f4abfd52483463fe6bc5c98 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne <manupa.karunaratne@arm.com> Date: Mon, 29 Nov 2021 19:10:26 +0000 Subject: [PATCH 4/8] Removing unimplemented Python-APi for USMP There was some remanants of unimplemented python APIs related BufferInfo. They are removed now. Change-Id: I4de1d817eb34187bc20da2ac2b1cb0da5b372833 --- python/tvm/tir/usmp/utils.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/python/tvm/tir/usmp/utils.py b/python/tvm/tir/usmp/utils.py index 0445775869e8..188d4c57810b 100644 --- a/python/tvm/tir/usmp/utils.py +++ b/python/tvm/tir/usmp/utils.py @@ -114,18 +114,6 @@ def __init__( alignment, ) - def set_pool_candidates(self, pool_candidates: list): - """Sets the pool candidate names""" - _ffi_api.BufferInfoSetPoolCandidates(self, pool_candidates) - - def set_pool_offsets(self, pool_name: str, pool_offset: int): - """Sets the pool offset by name""" - _ffi_api.BufferInfoSetPoolOffset(self, pool_name, pool_offset) - - def set_conflicts(self, conflicts: list): - """Sets the the conflicting array of buffer info objects""" - _ffi_api.BufferInfoSetConflicts(self, conflicts) - @register_object("tir.usmp.PoolAllocation") class PoolAllocation(Object): From 78e099b6e35cb54cbf9eafd97096e136689d1a43 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne <manupa.karunaratne@arm.com> Date: Mon, 29 Nov 2021 19:15:32 +0000 Subject: [PATCH 5/8] Removing commented code from extract buffer info pass Change-Id: Ia2d24753398bb388918aecba3b0191100d5100a6 --- src/tir/usmp/analysis/extract_buffer_info.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index 609481ffeae9..f82db06685f4 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -233,7 +233,6 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) { module_->GetAttr<tvm::relay::Executor>(tvm::attr::kExecutor); Integer workspace_alignment = 16; if (executor_config) { - // ICHECK(tgt) << "There should not be any PrimFuncs without a target attached by now"; workspace_alignment = executor_config.value()->GetAttr<Integer>("workspace-byte-alignment").value_or(16); } From 7c93c60c9a383ff74934eff63ae79a9d02fe8357 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne <manupa.karunaratne@arm.com> Date: Tue, 30 Nov 2021 19:55:31 +0000 Subject: [PATCH 6/8] [TIR][USMP] Greedy algorithms for USMP This commits removes commented out lines ,few trivial cleanups and few BufferInfo based tests to check the algorithm. 
Change-Id: I1a12b6a424370e9e4c4a55563dde0ad698b07ea3 --- python/tvm/tir/usmp/utils.py | 4 + src/tir/usmp/algo/greedy.cc | 10 +- src/tir/usmp/analysis/extract_buffer_info.cc | 3 - tests/python/unittest/test_tir_usmp_algo.py | 232 +++++++++++++++++-- 4 files changed, 217 insertions(+), 32 deletions(-) diff --git a/python/tvm/tir/usmp/utils.py b/python/tvm/tir/usmp/utils.py index 188d4c57810b..470765174acb 100644 --- a/python/tvm/tir/usmp/utils.py +++ b/python/tvm/tir/usmp/utils.py @@ -114,6 +114,10 @@ def __init__( alignment, ) + def set_conflicts(self, conflicts: list): + """Sets the the conflicting array of buffer info objects""" + _ffi_api.BufferInfoSetConflicts(self, conflicts) + @register_object("tir.usmp.PoolAllocation") class PoolAllocation(Object): diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc index f0b1581cd616..78afe0333c99 100644 --- a/src/tir/usmp/algo/greedy.cc +++ b/src/tir/usmp/algo/greedy.cc @@ -18,7 +18,7 @@ */ /*! - * \file tir/analysis/usmp/algo/greedy_by_size.cc + * \file tir/analysis/usmp/algo/greedy.cc * \brief This source contains greedy algorithms for planning * memory for USMP. There are two algorithms present here : * 1) greedy_by_size and 2) greedy_by_conflicts. @@ -89,17 +89,17 @@ class GreedyBase { * \brief Selects a pool for placement in the given set of ordered pool candidates */ PoolInfo SelectPlacementPool( - const Array<PoolInfo>& pool_candidates, + const BufferInfo& buf_info, const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) { // Here the pool candidates are ordered when it is consumed by the algorithm. // This could be from order the user has specified. However, schedulers are // welcome to change the order for performance reasons. - for (const auto& pool_info : pool_candidates) { + for (const auto& pool_info : buf_info->pool_candidates) { if (pool_offsets.count(pool_info)) { return pool_info; } } - ICHECK(false) << "TVM USMP Internal Error: no candidate have been selected!"; + CHECK(false) << "TVM USMP Error: no candidate have been selected for " << buf_info; return PoolInfo(); } @@ -141,7 +141,7 @@ class GreedyBase { } } } - auto selected_pool = SelectPlacementPool(buf_info->pool_candidates, pool_offset_candidates); + auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates); pool_allocations.Set( buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool]))); } diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index f82db06685f4..3fea7211672f 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -454,7 +454,6 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_ // Traverse the liveness events using a open set to track what // is live while updating the conflicts through out the linear traversal - // std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set; std::unordered_map<BufferInfo, int, ObjectPtrHash, ObjectPtrEqual> open_set; for (const auto& le_event : le_events_timeline) { if (le_event.le_type == START) { @@ -465,7 +464,6 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_ le_event.buffer_info->conflicts.push_back(open_buffer_info); } } - // open_set.insert(le_event.buffer_info); if (open_set.find(le_event.buffer_info) == open_set.end()) { open_set[le_event.buffer_info] = 1; } else { @@ -477,7 +475,6 @@ Map<BufferInfo, tir::Stmt> 
BufferInfoExtractor::operator()(const PrimFunc& main_ } else { open_set[le_event.buffer_info] -= 1; } - // open_set.erase(le_event.buffer_info); } } return this->buffer_info_map_; diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py index 69d52009fefc..4efb88affd16 100644 --- a/tests/python/unittest/test_tir_usmp_algo.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -51,7 +51,7 @@ def get_allocate(stmt): return allocates -def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): +def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): """helper to assing poolinfos to allocate nodes in a tir.PrimFunc""" def set_poolinfos(stmt): @@ -68,12 +68,12 @@ def set_poolinfos(stmt): return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos)) -def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): +def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): """helper to assing poolinfos to allocate nodes in a IRModule""" ret = tvm.IRModule() for global_var, basefunc in mod.functions.items(): if isinstance(basefunc, tvm.tir.PrimFunc): - ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) + ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) return ret @@ -96,9 +96,204 @@ def _check_max_workspace_size(buffer_pool_allocations, pool_info, size): assert max_workspace_size == size +def test_no_pool_error(): + target = Target("c") + tiny_workspace_pool = usmp_utils.PoolInfo( + pool_name="tiny_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + size_hint_bytes=10, + ) + bi_a = usmp_utils.BufferInfo( + name_hint="bi_a", size_bytes=10, pool_candidates=[tiny_workspace_pool] + ) + bi_b = usmp_utils.BufferInfo( + name_hint="bi_b", size_bytes=10, pool_candidates=[tiny_workspace_pool] + ) + bi_c = usmp_utils.BufferInfo( + name_hint="bi_c", size_bytes=10, pool_candidates=[tiny_workspace_pool] + ) + bi_a.set_conflicts([bi_b]) + bi_b.set_conflicts([bi_c]) + bi_c.set_conflicts([bi_a]) + buffer_info_arr = [bi_a, bi_b, bi_c] + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.greedy_by_size") + with pytest.raises( + tvm.TVMError, match="TVM USMP Error: no candidate have been selected for BufferInfoNode" + ): + buffer_pool_allocations = fusmp_algo(buffer_info_arr) + + +@pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts"]) +def test_name_based_ordering(algorithm): + """ This checks when the size and conlicts are same a stable result is generated""" + + def _test(): + target = Target("c") + global_workspace_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + bi_a = usmp_utils.BufferInfo( + name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool] + ) + bi_b = usmp_utils.BufferInfo( + name_hint="bi_b", size_bytes=10, pool_candidates=[global_workspace_pool] + ) + bi_c = usmp_utils.BufferInfo( + name_hint="bi_c", size_bytes=10, pool_candidates=[global_workspace_pool] + ) + bi_a.set_conflicts([bi_b]) + bi_b.set_conflicts([bi_c]) + bi_c.set_conflicts([bi_a]) + + buffer_info_arr = [bi_a, bi_b, bi_c] + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") + buffer_pool_allocations = fusmp_algo(buffer_info_arr) + assert buffer_pool_allocations[bi_a].byte_offset == 0 + assert buffer_pool_allocations[bi_b].byte_offset == 20 + assert buffer_pool_allocations[bi_c].byte_offset == 10 + + # This is tested 
for several times to check stability + for x in range(0, 10): + _test() + + +@pytest.mark.parametrize( + ["algorithm", "workspace_size"], + [("greedy_by_size", 140), ("greedy_by_conflicts", 140)], +) +def test_linear(algorithm, workspace_size): + """ + The test case here represent BufferInfo objects + that could get generated for a linear sequence + such as : + (Op A) + | + bi_a + | + (Op B) + | + bi_b + | + . + . + . + (Op F) + | + bi_f + """ + target = Target("c") + global_workspace_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + bi_a = usmp_utils.BufferInfo( + name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool] + ) + bi_b = usmp_utils.BufferInfo( + name_hint="bi_b", size_bytes=20, pool_candidates=[global_workspace_pool] + ) + bi_c = usmp_utils.BufferInfo( + name_hint="bi_c", size_bytes=100, pool_candidates=[global_workspace_pool] + ) + bi_d = usmp_utils.BufferInfo( + name_hint="bi_d", size_bytes=40, pool_candidates=[global_workspace_pool] + ) + bi_e = usmp_utils.BufferInfo( + name_hint="bi_e", size_bytes=50, pool_candidates=[global_workspace_pool] + ) + bi_f = usmp_utils.BufferInfo( + name_hint="bi_f", size_bytes=50, pool_candidates=[global_workspace_pool] + ) + + # Creating conflicts for a linear graph + bi_a.set_conflicts([bi_b]) + bi_b.set_conflicts([bi_a, bi_c]) + bi_c.set_conflicts([bi_b, bi_d]) + bi_d.set_conflicts([bi_c, bi_e]) + bi_e.set_conflicts([bi_d, bi_f]) + bi_f.set_conflicts([bi_e]) + + buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f] + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") + buffer_pool_allocations = fusmp_algo(buffer_info_arr) + _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size) + + +@pytest.mark.parametrize( + ["algorithm", "workspace_size"], + [("greedy_by_size", 190), ("greedy_by_conflicts", 320)], +) +def test_fanout(algorithm, workspace_size): + """ + The test case here represent BufferInfo objects + that could get generated for a fanout topology + such as : + (Op A) + | + bi_a --------- + | | + (Op B) (Op C) + | | + bi_b bi_c + | | + (Op D) (Op E) + | | + bi_d bi_e + | | + (Op F) ------ + | + bi_f + | + (Op G) + | + bi_g + """ + target = Target("c") + global_workspace_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + bi_a = usmp_utils.BufferInfo( + name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool] + ) + bi_b = usmp_utils.BufferInfo( + name_hint="bi_b", size_bytes=20, pool_candidates=[global_workspace_pool] + ) + bi_c = usmp_utils.BufferInfo( + name_hint="bi_c", size_bytes=100, pool_candidates=[global_workspace_pool] + ) + bi_d = usmp_utils.BufferInfo( + name_hint="bi_d", size_bytes=40, pool_candidates=[global_workspace_pool] + ) + bi_e = usmp_utils.BufferInfo( + name_hint="bi_e", size_bytes=50, pool_candidates=[global_workspace_pool] + ) + bi_f = usmp_utils.BufferInfo( + name_hint="bi_f", size_bytes=60, pool_candidates=[global_workspace_pool] + ) + bi_g = usmp_utils.BufferInfo( + name_hint="bi_g", size_bytes=70, pool_candidates=[global_workspace_pool] + ) + + # Creating conflicts for a linear graph + bi_a.set_conflicts([bi_b, bi_c]) + bi_b.set_conflicts([bi_a, bi_c, bi_e]) + bi_c.set_conflicts([bi_e, bi_a, bi_b, bi_d]) + bi_d.set_conflicts([bi_b, bi_f, bi_c, bi_e]) + bi_e.set_conflicts([bi_c, bi_f, bi_b, bi_d]) + bi_f.set_conflicts([bi_d, bi_e, bi_f]) + bi_g.set_conflicts([bi_f]) + + 
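    # Note: the parametrized workspace_size values above differ between the two
    # algorithms because they consume this conflict graph in different orders:
    # greedy_by_size places the largest buffers first, while greedy_by_conflicts
    # places the buffers with the most conflicts first.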
buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f, bi_g] + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") + buffer_pool_allocations = fusmp_algo(buffer_info_arr) + _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size) + + # fmt: off @tvm.script.ir_module -class LinearStructure: +class MobilenetStructure: @T.prim_func def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: # function attr dict @@ -167,22 +362,11 @@ def run_model(input: T.handle, output: T.handle) -> None: # fmt: on -def print_conflicts(buffer_info_map): - """_verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map)""" - - for buffer_info_name, buf_info in buffer_info_map.items(): - conflict_str = "[" - for conflict in buf_info.conflicts: - conflict_str += f'"{conflict.name_hint}", ' - conflict_str += "]" - print(f'_verify_conflicts("{buffer_info_name}", {conflict_str}, buffer_info_map_names)') - - @pytest.mark.parametrize( ["algorithm", "fast_memory_size", "slow_memory_size"], [("greedy_by_size", 200704, 1418528), ("greedy_by_conflicts", 200704, 1418528)], ) -def test_linear(algorithm, fast_memory_size, slow_memory_size): +def test_mobilenet_subgraph(algorithm, fast_memory_size, slow_memory_size): target = Target("c") fast_memory_pool = usmp_utils.PoolInfo( pool_name="fast_memory", @@ -192,9 +376,9 @@ def test_linear(algorithm, fast_memory_size, slow_memory_size): slow_memory_pool = usmp_utils.PoolInfo( pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} ) - tir_mod = LinearStructure + tir_mod = MobilenetStructure tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) - tir_mod = assign_poolinfos_to_allocates_in_irmodule( + tir_mod = _assign_poolinfos_to_allocates_in_irmodule( tir_mod, [fast_memory_pool, slow_memory_pool] ) main_func = tir_mod["run_model"] @@ -202,8 +386,8 @@ def test_linear(algorithm, fast_memory_size, slow_memory_size): fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") buffer_info_arr = fcreate_array_bi(buffer_info_map) - fusmp_algo_greedy_by_size = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") - buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr) + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") + buffer_pool_allocations = fusmp_algo(buffer_info_arr) buffer_info_map_names = dict() for buf_info in buffer_info_arr: @@ -346,7 +530,7 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place @pytest.mark.parametrize( ["algorithm", "workspace_size"], [("greedy_by_size", 7920256), ("greedy_by_conflicts", 7200256)] ) -def test_fanout(algorithm, workspace_size): +def test_resnet_subgraph(algorithm, workspace_size): target = Target("c") global_workspace_pool = usmp_utils.PoolInfo( pool_name="global_workspace", @@ -354,14 +538,14 @@ def test_fanout(algorithm, workspace_size): ) tir_mod = ResnetStructure tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) - tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) + tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) main_func = tir_mod["tvmgen_default_run_model"] buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") buffer_info_arr = fcreate_array_bi(buffer_info_map) - fusmp_algo_greedy_by_size = 
tvm.get_global_func(f"tir.usmp.algo.{algorithm}") - buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr) + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") + buffer_pool_allocations = fusmp_algo(buffer_info_arr) buffer_info_map_names = dict() for buf_info in buffer_info_arr: From 6101e61017619ee52f9d582de101a491d7172d5e Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne <manupa.karunaratne@arm.com> Date: Wed, 1 Dec 2021 01:54:06 +0000 Subject: [PATCH 7/8] [TIR][USMP] Greedy algorithms for USMP Changed sorting criteria use alphabetic ordering as opposed to hashes of string as it seemed different accross different platforms. Change-Id: Ia7938d1b0d1374924c3ec7287526ccf374c54eb7 --- src/tir/usmp/algo/greedy.cc | 8 ++------ tests/python/unittest/test_tir_usmp_algo.py | 6 +++--- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc index 78afe0333c99..8d65e0a10ef0 100644 --- a/src/tir/usmp/algo/greedy.cc +++ b/src/tir/usmp/algo/greedy.cc @@ -167,9 +167,7 @@ class GreedySize : public GreedyBase { [](const BufferInfo& a, const BufferInfo& b) { if (a->size_bytes->value == b->size_bytes->value) { if (a->conflicts.size() == b->conflicts.size()) { - auto a_name_hash = std::hash<std::string>{}(a->name_hint->data); - auto b_name_hash = std::hash<std::string>{}(b->name_hint->data); - return a_name_hash > b_name_hash; + return std::string(a->name_hint->data) > std::string(b->name_hint->data); } else { return a->conflicts.size() > b->conflicts.size(); } @@ -198,9 +196,7 @@ class GreedyConflicts : public GreedyBase { [](const BufferInfo& a, const BufferInfo& b) { if (a->conflicts.size() == b->conflicts.size()) { if (a->size_bytes->value == b->size_bytes->value) { - auto a_name_hash = std::hash<std::string>{}(a->name_hint->data); - auto b_name_hash = std::hash<std::string>{}(b->name_hint->data); - return a_name_hash > b_name_hash; + return std::string(a->name_hint->data) > std::string(b->name_hint->data); } else { return a->size_bytes->value > b->size_bytes->value; } diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py index 4efb88affd16..32cd30fb5bed 100644 --- a/tests/python/unittest/test_tir_usmp_algo.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -149,9 +149,9 @@ def _test(): buffer_info_arr = [bi_a, bi_b, bi_c] fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") buffer_pool_allocations = fusmp_algo(buffer_info_arr) - assert buffer_pool_allocations[bi_a].byte_offset == 0 - assert buffer_pool_allocations[bi_b].byte_offset == 20 - assert buffer_pool_allocations[bi_c].byte_offset == 10 + assert buffer_pool_allocations[bi_a].byte_offset == 20 + assert buffer_pool_allocations[bi_b].byte_offset == 10 + assert buffer_pool_allocations[bi_c].byte_offset == 0 # This is tested for several times to check stability for x in range(0, 10): From 87b0ac9cda13170fd634c09492a0c3e3fbf0f35e Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne <manupa.karunaratne@arm.com> Date: Wed, 1 Dec 2021 15:08:39 +0000 Subject: [PATCH 8/8] [TIR][USMP] Greedy algorithms for USMP Improving the error message Change-Id: Ib59efb172fe10b70f88a24f4874a7891e8a9cde7 --- src/tir/usmp/algo/greedy.cc | 4 +++- tests/python/unittest/test_tir_usmp_algo.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc index 8d65e0a10ef0..b98d828ae745 100644 --- a/src/tir/usmp/algo/greedy.cc +++ b/src/tir/usmp/algo/greedy.cc @@ 
-99,7 +99,9 @@ class GreedyBase { return pool_info; } } - CHECK(false) << "TVM USMP Error: no candidate have been selected for " << buf_info; + CHECK(false) << "TVM USMP Error: the space available in the provided pools exceeded when " + "trying to allocate the buffer : " + << buf_info << "\n. Please increase the size_hints for memory pools."; return PoolInfo(); } diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py index 32cd30fb5bed..61a70ad062d1 100644 --- a/tests/python/unittest/test_tir_usmp_algo.py +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -118,7 +118,7 @@ def test_no_pool_error(): buffer_info_arr = [bi_a, bi_b, bi_c] fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.greedy_by_size") with pytest.raises( - tvm.TVMError, match="TVM USMP Error: no candidate have been selected for BufferInfoNode" + tvm.TVMError, match="TVM USMP Error: the space available in the provided pools exceeded" ): buffer_pool_allocations = fusmp_algo(buffer_info_arr)
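
One detail from earlier in this series that is worth illustrating: the liveness
sweep in extract_buffer_info.cc now keeps a per-buffer counter rather than a
plain set, because a buffer owned by a PrimFunc that is called more than once
produces several start/end liveness events, and a plain set would drop it as
soon as the first instance ends. Below is a condensed, self-contained Python
model of that sweep; the names are illustrative only, and the real pass works
on BufferInfo objects and records Call-specific start/end statement indices.

    from collections import Counter

    def compute_conflicts(events):
        # events: ordered ("start"/"end", name) pairs; a name may start and end
        # more than once when the PrimFunc that owns it is called repeatedly.
        conflicts = {name: set() for _, name in events}
        live = Counter()
        for kind, name in events:
            if kind == "start":
                for other in [n for n, count in live.items() if count > 0]:
                    if other != name:
                        conflicts[other].add(name)
                        conflicts[name].add(other)
                live[name] += 1
            else:
                live[name] -= 1  # stays live until every open instance has ended
        return conflicts

    # Two overlapping live ranges of "a" (two calls to the same PrimFunc): a
    # plain set would drop "a" at the first ("end", "a") and miss the a-b
    # conflict; the counter keeps "a" live until both instances have ended.
    events = [("start", "a"), ("start", "a"), ("end", "a"),
              ("start", "b"), ("end", "a"), ("end", "b")]
    print(compute_conflicts(events))  # {'a': {'b'}, 'b': {'a'}}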