diff --git a/python/tvm/tir/usmp/utils.py b/python/tvm/tir/usmp/utils.py index 0445775869e8..470765174acb 100644 --- a/python/tvm/tir/usmp/utils.py +++ b/python/tvm/tir/usmp/utils.py @@ -114,14 +114,6 @@ def __init__( alignment, ) - def set_pool_candidates(self, pool_candidates: list): - """Sets the pool candidate names""" - _ffi_api.BufferInfoSetPoolCandidates(self, pool_candidates) - - def set_pool_offsets(self, pool_name: str, pool_offset: int): - """Sets the pool offset by name""" - _ffi_api.BufferInfoSetPoolOffset(self, pool_name, pool_offset) - def set_conflicts(self, conflicts: list): """Sets the the conflicting array of buffer info objects""" _ffi_api.BufferInfoSetConflicts(self, conflicts) diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc new file mode 100644 index 000000000000..b98d828ae745 --- /dev/null +++ b/src/tir/usmp/algo/greedy.cc @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/usmp/algo/greedy.cc + * \brief This source contains greedy algorithms for planning + * memory for USMP. There are two algorithms present here : + * 1) greedy_by_size and 2) greedy_by_conflicts. + * + * greedy_by_size : this algorithm prioritizes placing the + * largest size buffer to the given pools. The BufferInfo objects + * are sorted based on the size and placed on each pool adhering + * to size_hint constraint. + * + * greedy_by_conflicts : this algorithm prioritizes placing the + * the most liveness conflicted buffer to the given pools. The + * BufferInfo objects are sorted based on the number of conflicts + * and placed on each pool adhering to size_hint constraint. + */ + +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace tir { +namespace usmp { +namespace algo { + +/*! + * \brief This is the base class for Greedy Algorithms where the sorting + * is specialized in the extended classes based on the greedy criteria. + */ +class GreedyBase { + public: + GreedyBase() {} + /*! + * \brief This function should be implemented by the extended classes to sort the BufferInfo + * objects based on a criteria and then calling PostSortAllocation. + */ + virtual Map PlanMemory(const Array& buffer_info_arr) = 0; + + protected: + /*! + * \brief Rounds up the offset to satisfy the alignement requirement + */ + size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset, + const int& byte_alignment) { + return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment; + } + + /*! + * \brief A helper function check whether a offset is valid given the constraints + */ + bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset, + const size_t& size_bytes) { + if (candidate_pool->size_hint_bytes == -1) { + // this means pool is not bounded + return true; + } + auto pool_size = static_cast(candidate_pool->size_hint_bytes->value); + auto max_address = next_offset + size_bytes; + if (max_address <= pool_size) { + return true; + } + return false; + } + + /*! + * \brief Selects a pool for placement in the given set of ordered pool candidates + */ + PoolInfo SelectPlacementPool( + const BufferInfo& buf_info, + const std::unordered_map& pool_offsets) { + // Here the pool candidates are ordered when it is consumed by the algorithm. + // This could be from order the user has specified. However, schedulers are + // welcome to change the order for performance reasons. + for (const auto& pool_info : buf_info->pool_candidates) { + if (pool_offsets.count(pool_info)) { + return pool_info; + } + } + CHECK(false) << "TVM USMP Error: the space available in the provided pools exceeded when " + "trying to allocate the buffer : " + << buf_info << "\n. Please increase the size_hints for memory pools."; + return PoolInfo(); + } + + /*! + * \brief This is the base allocation function that works on sorted BufferInfo objects based + * on the greedy heuristic. The sorting algorithm has to be called before calling this. + */ + Map PostSortAllocation( + const std::vector& buffer_info_vec) { + Map pool_allocations; + for (const auto& buf_info : buffer_info_vec) { + std::unordered_map pool_offset_candidates; + for (const auto& pool_info : buf_info->pool_candidates) { + // Mark pool candidates that satisfy the size constraints. + if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) { + pool_offset_candidates[pool_info] = 0; + } + } + + for (const auto& conflict_buf_info_obj : buf_info->conflicts) { + auto conflict_buf_info = Downcast(conflict_buf_info_obj); + size_t next_offset = 0; + // We only look at already allocated BufferInfo in-terms of conflicts. + if (pool_allocations.count(conflict_buf_info)) { + auto pool_allocation = pool_allocations[conflict_buf_info]; + next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes; + next_offset = + round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value); + // Checks whether the next offset in the same pool as the conflicting BufferInfo is valid. + if (IsValidPlacement(pool_allocation->pool_info, next_offset, + buf_info->size_bytes->value)) { + // There could be multiple conflicting BufferInfo in the same pool. + // Thus, we need to make sure we pick the largest offset of them all. + if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) { + pool_offset_candidates[pool_allocation->pool_info] = next_offset; + } + } else { + pool_offset_candidates.erase(pool_allocation->pool_info); + } + } + } + auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates); + pool_allocations.Set( + buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool]))); + } + return pool_allocations; + } +}; + +/*! + * \brief This class implements Greedy by the size of BufferInfo + * greedy algorithm. Please refer to main documentation of the file + * for more details. + */ +class GreedySize : public GreedyBase { + public: + GreedySize() {} + Map PlanMemory(const Array& buffer_info_arr) { + std::vector buffer_info_vec; + Map pool_allocations; + for (const auto& buffer_info : buffer_info_arr) { + buffer_info_vec.push_back(std::move(buffer_info)); + } + std::sort(buffer_info_vec.begin(), buffer_info_vec.end(), + [](const BufferInfo& a, const BufferInfo& b) { + if (a->size_bytes->value == b->size_bytes->value) { + if (a->conflicts.size() == b->conflicts.size()) { + return std::string(a->name_hint->data) > std::string(b->name_hint->data); + } else { + return a->conflicts.size() > b->conflicts.size(); + } + } + return a->size_bytes > b->size_bytes; + }); + return PostSortAllocation(buffer_info_vec); + } +}; + +/*! + * \brief This class implements Greedy by the number of conflicts of + * BufferInfo greedy algorithm. Please refer to main documentation + * of the file for more details. + */ +class GreedyConflicts : public GreedyBase { + public: + GreedyConflicts() {} + Map PlanMemory(const Array& buffer_info_arr) { + std::vector buffer_info_vec; + Map pool_allocations; + for (const auto& buffer_info : buffer_info_arr) { + buffer_info_vec.push_back(std::move(buffer_info)); + } + std::sort(buffer_info_vec.begin(), buffer_info_vec.end(), + [](const BufferInfo& a, const BufferInfo& b) { + if (a->conflicts.size() == b->conflicts.size()) { + if (a->size_bytes->value == b->size_bytes->value) { + return std::string(a->name_hint->data) > std::string(b->name_hint->data); + } else { + return a->size_bytes->value > b->size_bytes->value; + } + } + return a->conflicts.size() > b->conflicts.size(); + }); + return PostSortAllocation(buffer_info_vec); + } +}; + +Map GreedyBySize(const Array& buffer_info_arr) { + return GreedySize().PlanMemory(buffer_info_arr); +} + +Map GreedyByConflicts(const Array& buffer_info_arr) { + return GreedyConflicts().PlanMemory(buffer_info_arr); +} + +TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size") + .set_body_typed([](Array buffer_info_arr) { + return GreedyBySize(buffer_info_arr); + }); + +TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_conflicts") + .set_body_typed([](Array buffer_info_arr) { + return GreedyByConflicts(buffer_info_arr); + }); + +} // namespace algo +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index c25578fd9779..3fea7211672f 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -26,6 +26,7 @@ * conflicts between other tir.allocate nodes. */ #include +#include #include #include #include @@ -99,11 +100,21 @@ class BufferInfoExtractor : public StmtExprVisitor { */ std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_info_end_stmt_idx_; + + /*! + * \brief This structure contains information regarding a Allocate node. + */ + struct AllocateInfo { + tir::Stmt Allocate; + PrimFunc prim_func; + Call call; + }; + /*! * \brief Maintains the mapping of buffer variable to their allocate nodes to ensure * that only one BufferInfo object is created. */ - Map allocate_var_to_stmt_map_; + std::unordered_map allocate_infos; /*! * \brief Indicates a count of stmts visited so far to use as a metric of liveness */ @@ -203,29 +214,40 @@ void BufferInfoExtractor::RecordAllocateNodeInfo(const AllocateNode* op) { auto size_bytes = CalculateExtentsSize(op); // We only statically memory plan only allocates with known // compile time sizes. - if (size_bytes.defined() && - allocate_var_to_stmt_map_.find(op->buffer_var) == allocate_var_to_stmt_map_.end()) { - // By default, the core compiler is assumed to attach the a default pool to each allocate. - ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr)) - << "Every statically sized allocate node needs an pool candidate attribute"; - auto pool_candidates = Downcast>(op->annotations[kPoolCandidatesAllocateAttr]); - - // TODO(@manupa-arm): improve the error when the responsible component for attaching a single - // pool is added - ICHECK(pool_candidates.size() > 0) - << "The core compiler should at least attach a single PoolInfo. If there were no " - "user-given arguments for memory pools, the default behaviour is a single size " - "un-restricted pool is assigned"; - PrimFunc func = scope_stack_.top().func; - Optional tgt = func->GetAttr(tvm::attr::kTarget); - ICHECK(tgt) << "There should not be any PrimFuncs without a target attached by now"; - auto workspace_alignment = - tgt.value()->GetAttr("workspace-byte-alignment").value_or(16); - auto buffer_info = BufferInfo(GetUniqueBufferName(op->buffer_var->name_hint), size_bytes, - pool_candidates, workspace_alignment); - auto allocate = GetRef(op); - allocate_var_to_stmt_map_.Set(op->buffer_var, allocate); - buffer_info_map_.Set(buffer_info, allocate); + if (size_bytes.defined()) { + if (allocate_infos.find(op->buffer_var) == allocate_infos.end()) { + // By default, the core compiler is assumed to attach the a default pool to each allocate. + ICHECK(op->annotations.count(kPoolCandidatesAllocateAttr)) + << "Every statically sized allocate node needs an pool candidate attribute"; + auto pool_candidates = + Downcast>(op->annotations[kPoolCandidatesAllocateAttr]); + + // TODO(@manupa-arm): improve the error when the responsible component for attaching a single + // pool is added + ICHECK(pool_candidates.size() > 0) + << "The core compiler should at least attach a single PoolInfo. If there were no " + "user-given arguments for memory pools, the default behaviour is a single size " + "un-restricted pool is assigned"; + PrimFunc func = scope_stack_.top().func; + Optional executor_config = + module_->GetAttr(tvm::attr::kExecutor); + Integer workspace_alignment = 16; + if (executor_config) { + workspace_alignment = + executor_config.value()->GetAttr("workspace-byte-alignment").value_or(16); + } + auto buffer_info = BufferInfo(GetUniqueBufferName(op->buffer_var->name_hint), size_bytes, + pool_candidates, workspace_alignment); + auto allocate = GetRef(op); + allocate_infos[op->buffer_var] = + AllocateInfo{allocate, scope_stack_.top().func, scope_stack_.top().call}; + buffer_info_map_.Set(buffer_info, allocate); + } else { + // Update the allocate info with the latest call + AllocateInfo ai = allocate_infos[op->buffer_var]; + ai.call = scope_stack_.top().call; + allocate_infos[op->buffer_var] = ai; + } } } @@ -257,17 +279,25 @@ void BufferInfoExtractor::VisitStmt_(const ForNode* op) { si.initial_stmt_of_the_nested_loops = Integer(current_stmt_idx_); } Call current_call = scope_stack_.top().call; + PrimFunc current_primfunc = scope_stack_.top().func; scope_stack_.push(si); StmtExprVisitor::VisitStmt_(op); // Extending the liveness to beginning of for-loop next and end of the current for-loop for (const Allocate& allocate : scope_stack_.top().allocate_nodes) { + AllocateInfo ai = allocate_infos[allocate->buffer_var]; + Call update_call = current_call; + // If the allocate does not belong to current prim func + // We need to update the call to which the allocate belong to + if (ai.prim_func != current_primfunc) { + update_call = ai.call; + } if (scope_stack_.top().initial_stmt_of_the_nested_loops->value < - buffer_info_start_stmt_idx_[current_call][allocate]) { - buffer_info_start_stmt_idx_[current_call].Set( + buffer_info_start_stmt_idx_[update_call][allocate]) { + buffer_info_start_stmt_idx_[update_call].Set( allocate, scope_stack_.top().initial_stmt_of_the_nested_loops->value); } - if (current_stmt_idx_ > buffer_info_end_stmt_idx_[current_call][allocate]) { - buffer_info_end_stmt_idx_[current_call].Set(allocate, current_stmt_idx_); + if (current_stmt_idx_ > buffer_info_end_stmt_idx_[update_call][allocate]) { + buffer_info_end_stmt_idx_[update_call].Set(allocate, current_stmt_idx_); } } scope_stack_.pop(); @@ -286,12 +316,21 @@ void BufferInfoExtractor::VisitStmt_(const StoreNode* op) { void BufferInfoExtractor::VisitExpr_(const VarNode* op) { auto var = GetRef(op); Call current_call = scope_stack_.top().call; - if (allocate_var_to_stmt_map_.count(var)) { - auto allocate = allocate_var_to_stmt_map_[var]; - if (buffer_info_start_stmt_idx_[current_call].count(allocate) == 0) { - buffer_info_start_stmt_idx_[current_call].Set(allocate, current_stmt_idx_); + PrimFunc current_primfunc = scope_stack_.top().func; + if (allocate_infos.count(var)) { + auto allocate = allocate_infos[var].Allocate; + auto allocate_primfunc = allocate_infos[var].prim_func; + Call update_call = current_call; + if (allocate_primfunc != current_primfunc) { + // If the allocate node does not belong to the current primfunc. + // It's access should be reported to the call to PrimFunc that + // Allocate belong to. + update_call = allocate_infos[var].call; } - buffer_info_end_stmt_idx_[current_call].Set(allocate, current_stmt_idx_); + if (buffer_info_start_stmt_idx_[update_call].count(allocate) == 0) { + buffer_info_start_stmt_idx_[update_call].Set(allocate, current_stmt_idx_); + } + buffer_info_end_stmt_idx_[update_call].Set(allocate, current_stmt_idx_); ScopeInfo& currect_scope_info = scope_stack_.top(); if (currect_scope_info.for_loop.defined()) { @@ -320,13 +359,13 @@ void BufferInfoExtractor::UpdateAliases(const Array& args, const PrimF // to the original allocate if (arg->IsInstance()) { auto load = Downcast(arg); - if (allocate_var_to_stmt_map_.count(load->buffer_var)) { - allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[load->buffer_var]); + if (allocate_infos.count(load->buffer_var)) { + allocate_infos[param_buf] = allocate_infos[load->buffer_var]; } } else if (arg->IsInstance()) { auto var = Downcast(arg); - if (allocate_var_to_stmt_map_.count(var)) { - allocate_var_to_stmt_map_.Set(param_buf, allocate_var_to_stmt_map_[var]); + if (allocate_infos.count(var)) { + allocate_infos[param_buf] = allocate_infos[var]; } } } @@ -415,18 +454,27 @@ Map BufferInfoExtractor::operator()(const PrimFunc& main_ // Traverse the liveness events using a open set to track what // is live while updating the conflicts through out the linear traversal - std::unordered_set open_set; + std::unordered_map open_set; for (const auto& le_event : le_events_timeline) { if (le_event.le_type == START) { - for (const auto& open_buffer_info : open_set) { + for (const auto& kv : open_set) { + BufferInfo open_buffer_info = kv.first; open_buffer_info->conflicts.push_back(le_event.buffer_info); if (le_event.buffer_info != open_buffer_info) { le_event.buffer_info->conflicts.push_back(open_buffer_info); } } - open_set.insert(le_event.buffer_info); + if (open_set.find(le_event.buffer_info) == open_set.end()) { + open_set[le_event.buffer_info] = 1; + } else { + open_set[le_event.buffer_info] += 1; + } } else { - open_set.erase(le_event.buffer_info); + if (open_set[le_event.buffer_info] == 1) { + open_set.erase(le_event.buffer_info); + } else { + open_set[le_event.buffer_info] -= 1; + } } } return this->buffer_info_map_; diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py new file mode 100644 index 000000000000..61a70ad062d1 --- /dev/null +++ b/tests/python/unittest/test_tir_usmp_algo.py @@ -0,0 +1,676 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import tir, script +from tvm.script import tir as T +from tvm.tir import stmt_functor +from tvm.tir.usmp import utils as usmp_utils +from tvm.target import Target + + +def _replace_stmt_with_buf_var_names(buffer_info_map): + """helper to replace tir.allocates with buffer names""" + new_buffer_info_map = dict() + for k, v in buffer_info_map.items(): + new_buffer_info_map[v.buffer_var.name] = k + return new_buffer_info_map + + +def _verify_conflicts(main_buf_name, conflicting_buf_names, buffer_info_map): + """helper to check expected liveness conflicts""" + buf_info = buffer_info_map[main_buf_name] + for conflict in buf_info.conflicts: + assert conflict.name_hint in conflicting_buf_names + + +def _get_allocates(primfunc): + """helper to extract all allocate nodes by name""" + allocates = dict() + + def get_allocate(stmt): + if isinstance(stmt, tvm.tir.Allocate): + allocates[str(stmt.buffer_var.name)] = stmt + + stmt_functor.post_order_visit(primfunc.body, get_allocate) + return allocates + + +def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): + """helper to assing poolinfos to allocate nodes in a tir.PrimFunc""" + + def set_poolinfos(stmt): + if isinstance(stmt, tvm.tir.Allocate): + return tvm.tir.Allocate( + buffer_var=stmt.buffer_var, + dtype=stmt.dtype, + extents=stmt.extents, + condition=stmt.condition, + body=stmt.body, + annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos}, + ) + + return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos)) + + +def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): + """helper to assing poolinfos to allocate nodes in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) + return ret + + +def _assign_targets_to_primfuncs_irmodule(mod, target): + """helper to assign target for PrimFunc in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = basefunc.with_attr("target", target) + return ret + + +def _check_max_workspace_size(buffer_pool_allocations, pool_info, size): + max_workspace_size = 0 + for buffer_info, pool_allocation in buffer_pool_allocations.items(): + if pool_allocation.pool_info == pool_info: + size_candidate = pool_allocation.byte_offset + buffer_info.size_bytes + if size_candidate > max_workspace_size: + max_workspace_size = size_candidate + assert max_workspace_size == size + + +def test_no_pool_error(): + target = Target("c") + tiny_workspace_pool = usmp_utils.PoolInfo( + pool_name="tiny_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + size_hint_bytes=10, + ) + bi_a = usmp_utils.BufferInfo( + name_hint="bi_a", size_bytes=10, pool_candidates=[tiny_workspace_pool] + ) + bi_b = usmp_utils.BufferInfo( + name_hint="bi_b", size_bytes=10, pool_candidates=[tiny_workspace_pool] + ) + bi_c = usmp_utils.BufferInfo( + name_hint="bi_c", size_bytes=10, pool_candidates=[tiny_workspace_pool] + ) + bi_a.set_conflicts([bi_b]) + bi_b.set_conflicts([bi_c]) + bi_c.set_conflicts([bi_a]) + buffer_info_arr = [bi_a, bi_b, bi_c] + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.greedy_by_size") + with pytest.raises( + tvm.TVMError, match="TVM USMP Error: the space available in the provided pools exceeded" + ): + buffer_pool_allocations = fusmp_algo(buffer_info_arr) + + +@pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts"]) +def test_name_based_ordering(algorithm): + """ This checks when the size and conlicts are same a stable result is generated""" + + def _test(): + target = Target("c") + global_workspace_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + bi_a = usmp_utils.BufferInfo( + name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool] + ) + bi_b = usmp_utils.BufferInfo( + name_hint="bi_b", size_bytes=10, pool_candidates=[global_workspace_pool] + ) + bi_c = usmp_utils.BufferInfo( + name_hint="bi_c", size_bytes=10, pool_candidates=[global_workspace_pool] + ) + bi_a.set_conflicts([bi_b]) + bi_b.set_conflicts([bi_c]) + bi_c.set_conflicts([bi_a]) + + buffer_info_arr = [bi_a, bi_b, bi_c] + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") + buffer_pool_allocations = fusmp_algo(buffer_info_arr) + assert buffer_pool_allocations[bi_a].byte_offset == 20 + assert buffer_pool_allocations[bi_b].byte_offset == 10 + assert buffer_pool_allocations[bi_c].byte_offset == 0 + + # This is tested for several times to check stability + for x in range(0, 10): + _test() + + +@pytest.mark.parametrize( + ["algorithm", "workspace_size"], + [("greedy_by_size", 140), ("greedy_by_conflicts", 140)], +) +def test_linear(algorithm, workspace_size): + """ + The test case here represent BufferInfo objects + that could get generated for a linear sequence + such as : + (Op A) + | + bi_a + | + (Op B) + | + bi_b + | + . + . + . + (Op F) + | + bi_f + """ + target = Target("c") + global_workspace_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + bi_a = usmp_utils.BufferInfo( + name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool] + ) + bi_b = usmp_utils.BufferInfo( + name_hint="bi_b", size_bytes=20, pool_candidates=[global_workspace_pool] + ) + bi_c = usmp_utils.BufferInfo( + name_hint="bi_c", size_bytes=100, pool_candidates=[global_workspace_pool] + ) + bi_d = usmp_utils.BufferInfo( + name_hint="bi_d", size_bytes=40, pool_candidates=[global_workspace_pool] + ) + bi_e = usmp_utils.BufferInfo( + name_hint="bi_e", size_bytes=50, pool_candidates=[global_workspace_pool] + ) + bi_f = usmp_utils.BufferInfo( + name_hint="bi_f", size_bytes=50, pool_candidates=[global_workspace_pool] + ) + + # Creating conflicts for a linear graph + bi_a.set_conflicts([bi_b]) + bi_b.set_conflicts([bi_a, bi_c]) + bi_c.set_conflicts([bi_b, bi_d]) + bi_d.set_conflicts([bi_c, bi_e]) + bi_e.set_conflicts([bi_d, bi_f]) + bi_f.set_conflicts([bi_e]) + + buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f] + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") + buffer_pool_allocations = fusmp_algo(buffer_info_arr) + _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size) + + +@pytest.mark.parametrize( + ["algorithm", "workspace_size"], + [("greedy_by_size", 190), ("greedy_by_conflicts", 320)], +) +def test_fanout(algorithm, workspace_size): + """ + The test case here represent BufferInfo objects + that could get generated for a fanout topology + such as : + (Op A) + | + bi_a --------- + | | + (Op B) (Op C) + | | + bi_b bi_c + | | + (Op D) (Op E) + | | + bi_d bi_e + | | + (Op F) ------ + | + bi_f + | + (Op G) + | + bi_g + """ + target = Target("c") + global_workspace_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + bi_a = usmp_utils.BufferInfo( + name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool] + ) + bi_b = usmp_utils.BufferInfo( + name_hint="bi_b", size_bytes=20, pool_candidates=[global_workspace_pool] + ) + bi_c = usmp_utils.BufferInfo( + name_hint="bi_c", size_bytes=100, pool_candidates=[global_workspace_pool] + ) + bi_d = usmp_utils.BufferInfo( + name_hint="bi_d", size_bytes=40, pool_candidates=[global_workspace_pool] + ) + bi_e = usmp_utils.BufferInfo( + name_hint="bi_e", size_bytes=50, pool_candidates=[global_workspace_pool] + ) + bi_f = usmp_utils.BufferInfo( + name_hint="bi_f", size_bytes=60, pool_candidates=[global_workspace_pool] + ) + bi_g = usmp_utils.BufferInfo( + name_hint="bi_g", size_bytes=70, pool_candidates=[global_workspace_pool] + ) + + # Creating conflicts for a linear graph + bi_a.set_conflicts([bi_b, bi_c]) + bi_b.set_conflicts([bi_a, bi_c, bi_e]) + bi_c.set_conflicts([bi_e, bi_a, bi_b, bi_d]) + bi_d.set_conflicts([bi_b, bi_f, bi_c, bi_e]) + bi_e.set_conflicts([bi_c, bi_f, bi_b, bi_d]) + bi_f.set_conflicts([bi_d, bi_e, bi_f]) + bi_g.set_conflicts([bi_f]) + + buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f, bi_g] + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") + buffer_pool_allocations = fusmp_algo(buffer_info_arr) + _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size) + + +# fmt: off +@tvm.script.ir_module +class MobilenetStructure: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + T.store(PaddedInput_7, (((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_7, ff_3, 0, True) + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + for ax3_inner_7 in T.serial(0, 64): + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global") + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_9 = T.allocate([301056], "int8", "global") + sid_8 = T.allocate([802816], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize( + ["algorithm", "fast_memory_size", "slow_memory_size"], + [("greedy_by_size", 200704, 1418528), ("greedy_by_conflicts", 200704, 1418528)], +) +def test_mobilenet_subgraph(algorithm, fast_memory_size, slow_memory_size): + target = Target("c") + fast_memory_pool = usmp_utils.PoolInfo( + pool_name="fast_memory", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + size_hint_bytes=200704, + ) + slow_memory_pool = usmp_utils.PoolInfo( + pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} + ) + tir_mod = MobilenetStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) + tir_mod = _assign_poolinfos_to_allocates_in_irmodule( + tir_mod, [fast_memory_pool, slow_memory_pool] + ) + main_func = tir_mod["run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + + fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") + buffer_info_arr = fcreate_array_bi(buffer_info_map) + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") + buffer_pool_allocations = fusmp_algo(buffer_info_arr) + + buffer_info_map_names = dict() + for buf_info in buffer_info_arr: + buffer_info_map_names[buf_info.name_hint] = buf_info + + # check conflicts + _verify_conflicts("PaddedInput_7", ["sid_9", "sid_8", "Conv2dOutput_7"], buffer_info_map_names) + _verify_conflicts("tensor_2", ["sid_8"], buffer_info_map_names) + _verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map_names) + _verify_conflicts( + "sid_8", ["PaddedInput_7", "Conv2dOutput_7", "tensor_2"], buffer_info_map_names + ) + _verify_conflicts("Conv2dOutput_7", ["sid_8", "PaddedInput_7"], buffer_info_map_names) + + _check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, slow_memory_size) + _check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, fast_memory_size) + + +# fmt: off +@tvm.script.ir_module +class ResnetStructure: + @T.prim_func + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + # body + for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): + T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) + placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") + T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + # body + PaddedInput_1 = T.allocate([379456], "int16", "global") + for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): + T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): + Conv2dOutput_1 = T.allocate([64], "int32", "global") + for ff_1 in T.serial(0, 64): + T.store(Conv2dOutput_1, ff_1, 0, True) + for ry, rx, rc_1 in T.grid(3, 3, 64): + T.store(Conv2dOutput_1, ff_1, T.load("int32", Conv2dOutput_1, ff_1) + T.cast(T.load("int16", PaddedInput_1, T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + for ax3_inner_2 in T.serial(0, 64): + T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) + placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + # body + PaddedInput_2 = T.allocate([360000], "int16", "global") + for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): + T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): + Conv2dOutput_2 = T.allocate([64], "int32", "global") + for ax3_outer_1 in T.serial(0, 4): + for ff_2 in T.serial(0, 64): + T.store(Conv2dOutput_2, ff_2, 0, True) + for rc_2 in T.serial(0, 64): + T.store(Conv2dOutput_2, ff_2, T.load("int32", Conv2dOutput_2, ff_2) + T.cast(T.load("int16", PaddedInput_2, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) + for ax3_inner_3 in T.serial(0, 64): + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") + placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") + T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + # body + PaddedInput_3 = T.allocate([360000], "int16", "global") + for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): + T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) + for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): + Conv2dOutput_3 = T.allocate([64], "int32", "global") + for ax3_outer_2 in T.serial(0, 4): + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_3, ff_3, 0, True) + for rc_3 in T.serial(0, 64): + T.store(Conv2dOutput_3, ff_3, T.load("int32", Conv2dOutput_3, ff_3) + T.cast(T.load("int16", PaddedInput_3, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) + for ax3_inner_4 in T.serial(0, 64): + T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_2 = T.allocate([720000], "int8", "global") + sid_6 = T.allocate([5760000], "int8", "global") + sid_7 = T.allocate([720000], "int8", "global") + sid_8 = T.allocate([720000], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32")) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) + placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") + T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + # body + PaddedInput = T.allocate([360000], "int16", "global") + for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): + T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): + Conv2dOutput = T.allocate([64], "int32", "global") + for ff in T.serial(0, 64): + T.store(Conv2dOutput, ff, 0, True) + for rc in T.serial(0, 64): + T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + for ax3_inner_1 in T.serial(0, 64): + T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize( + ["algorithm", "workspace_size"], [("greedy_by_size", 7920256), ("greedy_by_conflicts", 7200256)] +) +def test_resnet_subgraph(algorithm, workspace_size): + target = Target("c") + global_workspace_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + tir_mod = ResnetStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) + tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) + main_func = tir_mod["tvmgen_default_run_model"] + buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + + fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") + buffer_info_arr = fcreate_array_bi(buffer_info_map) + fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}") + buffer_pool_allocations = fusmp_algo(buffer_info_arr) + + buffer_info_map_names = dict() + for buf_info in buffer_info_arr: + buffer_info_map_names[buf_info.name_hint] = buf_info + + # check conflicts + _verify_conflicts( + "sid_7", + [ + "PaddedInput_1", + "sid_2", + "Conv2dOutput_1", + "PaddedInput_2", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "Conv2dOutput_3", + [ + "PaddedInput_3", + "sid_6", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "sid_6", + [ + "Conv2dOutput_2", + "PaddedInput_2", + "sid_2", + "PaddedInput_3", + "Conv2dOutput_3", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "Conv2dOutput", + [ + "sid_8", + "sid_2", + "PaddedInput", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "PaddedInput_3", + [ + "sid_6", + "sid_2", + "Conv2dOutput_3", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "Conv2dOutput_2", + [ + "PaddedInput_2", + "sid_2", + "sid_6", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "PaddedInput_1", + [ + "sid_8", + "sid_2", + "sid_7", + "Conv2dOutput_1", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "Conv2dOutput_1", + [ + "sid_7", + "PaddedInput_1", + "sid_2", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "PaddedInput", + [ + "sid_2", + "sid_8", + "Conv2dOutput", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "sid_8", + [ + "PaddedInput", + "sid_2", + "Conv2dOutput", + "PaddedInput_1", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "sid_2", + [ + "PaddedInput", + "sid_8", + "Conv2dOutput", + "PaddedInput_1", + "sid_7", + "Conv2dOutput_1", + "PaddedInput_2", + "Conv2dOutput_2", + "sid_6", + "PaddedInput_3", + ], + buffer_info_map_names, + ) + _verify_conflicts( + "PaddedInput_2", + [ + "sid_7", + "sid_2", + "Conv2dOutput_2", + "sid_6", + ], + buffer_info_map_names, + ) + + _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size) diff --git a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py index fa645f1379ff..abaa0cd2431e 100644 --- a/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py +++ b/tests/python/unittest/test_tir_usmp_analysis_extract_bufferinfo.py @@ -661,149 +661,261 @@ def test_inception_structure(): # check conflicts _verify_conflicts( - "PaddedInput_8", + "sid_3", + [ + "sid_4", + "PaddedInput_2", + "sid_2", + "Conv2dOutput_2", + "PaddedInput_1", + "Conv2dOutput_1", + "sid_20", + "PaddedInput_6", + "Conv2dOutput_6", + "sid_19", + "PaddedInput_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput", [ "sid_6", - "Conv2dOutput_8", - "sid_5", + "PaddedInput", ], buffer_info_map, ) _verify_conflicts( - "sid_26", + "Conv2dOutput_7", + [ + "PaddedInput_7", + "sid_8", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_4", [ + "sid_5", + "sid_3", + "PaddedInput_2", + "sid_2", + "Conv2dOutput_2", + "PaddedInput_1", + "Conv2dOutput_1", + "sid_20", + "PaddedInput_6", + "Conv2dOutput_6", + "sid_19", "PaddedInput_4", "Conv2dOutput_4", + "sid_26", "PaddedInput_5", + "Conv2dOutput_5", + "sid_25", + "tensor_3", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput", + "sid_2", [ - "sid_6", - "PaddedInput", + "PaddedInput_2", + "sid_3", + "sid_4", + "Conv2dOutput_2", + "PaddedInput_1", + "Conv2dOutput_1", + "sid_20", + "PaddedInput_6", + "Conv2dOutput_6", + "sid_19", + "PaddedInput_4", + "Conv2dOutput_4", + "sid_26", + "PaddedInput_5", + "Conv2dOutput_5", + "sid_25", + "tensor_3", + "sid_32", + "PaddedInput_3", + "Conv2dOutput_3", + "sid_31", ], buffer_info_map, ) _verify_conflicts( - "sid_4", + "sid_19", [ - "sid_5", + "Conv2dOutput_6", + "sid_2", + "PaddedInput_6", "sid_3", + "sid_4", + "PaddedInput_4", + "Conv2dOutput_4", + "sid_26", + "PaddedInput_5", + "Conv2dOutput_5", + "sid_25", "tensor_3", + "sid_32", + "PaddedInput_3", + "Conv2dOutput_3", + "sid_31", ], buffer_info_map, ) _verify_conflicts( - "tensor_2", + "PaddedInput_2", [ - "sid_8", - "sid_7", + "sid_3", + "sid_4", + "sid_2", + "Conv2dOutput_2", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_7", + "Conv2dOutput_6", [ - "sid_8", - "PaddedInput_7", + "sid_2", + "PaddedInput_6", + "sid_3", + "sid_4", + "sid_19", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_1", + "sid_9", [ - "sid_20", - "PaddedInput_1", + "PaddedInput_7", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_4", + "sid_7", [ - "sid_26", - "PaddedInput_4", + "tensor_2", + "PaddedInput", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_2", + "PaddedInput_4", [ - "PaddedInput_2", "sid_2", + "sid_19", + "sid_3", + "sid_4", + "Conv2dOutput_4", + "sid_26", ], buffer_info_map, ) _verify_conflicts( "PaddedInput_3", [ + "sid_2", "sid_32", - "sid_31", + "sid_25", + "sid_19", "Conv2dOutput_3", + "sid_31", ], buffer_info_map, ) _verify_conflicts( - "sid_3", + "sid_5", [ + "PaddedInput_8", + "Conv2dOutput_8", "sid_4", - "PaddedInput_2", - "PaddedInput_1", - "PaddedInput_4", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_6", + "sid_31", [ - "PaddedInput_6", + "Conv2dOutput_3", + "PaddedInput_3", + "sid_2", + "sid_25", "sid_19", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_5", + "PaddedInput", [ - "PaddedInput_5", + "sid_7", + "sid_6", + "Conv2dOutput", + ], + buffer_info_map, + ) + _verify_conflicts( + "Conv2dOutput_2", + [ + "sid_2", + "PaddedInput_2", + "sid_3", + "sid_4", + ], + buffer_info_map, + ) + _verify_conflicts( + "sid_32", + [ + "tensor_3", + "sid_2", "sid_25", + "sid_19", + "PaddedInput_3", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_7", + "tensor_2", [ - "sid_9", "sid_8", - "Conv2dOutput_7", + "sid_7", ], buffer_info_map, ) _verify_conflicts( - "sid_7", + "sid_26", [ - "tensor_2", - "PaddedInput", + "Conv2dOutput_4", + "PaddedInput_4", + "sid_2", + "sid_19", + "sid_4", + "PaddedInput_5", ], buffer_info_map, ) _verify_conflicts( - "sid_31", + "Conv2dOutput_3", [ "PaddedInput_3", - "Conv2dOutput_3", - "sid_25", "sid_2", + "sid_25", "sid_19", + "sid_31", ], buffer_info_map, ) _verify_conflicts( - "sid_5", + "PaddedInput_6", [ - "Conv2dOutput_8", - "PaddedInput_8", + "sid_2", + "sid_3", + "sid_20", "sid_4", + "Conv2dOutput_6", + "sid_19", ], buffer_info_map, ) @@ -817,146 +929,132 @@ def test_inception_structure(): buffer_info_map, ) _verify_conflicts( - "sid_20", + "PaddedInput_8", [ - "PaddedInput_1", - "Conv2dOutput_1", - "PaddedInput_6", + "sid_6", + "sid_5", + "Conv2dOutput_8", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_8", + "Conv2dOutput_5", [ - "PaddedInput_8", - "sid_5", + "PaddedInput_5", + "sid_2", + "sid_19", + "sid_4", + "sid_25", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_1", + "Conv2dOutput_1", [ + "PaddedInput_1", + "sid_2", "sid_3", + "sid_4", "sid_20", - "Conv2dOutput_1", ], buffer_info_map, ) _verify_conflicts( - "Conv2dOutput_3", + "tensor_3", [ - "sid_31", - "PaddedInput_3", + "sid_2", + "sid_25", + "sid_19", + "sid_4", + "sid_32", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput", + "sid_8", [ - "sid_7", - "sid_6", - "Conv2dOutput", + "Conv2dOutput_7", + "PaddedInput_7", + "tensor_2", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_2", + "sid_20", [ - "sid_3", - "Conv2dOutput_2", + "Conv2dOutput_1", + "PaddedInput_1", "sid_2", - ], - buffer_info_map, - ) - _verify_conflicts( - "sid_19", - [ - "Conv2dOutput_6", + "sid_3", + "sid_4", "PaddedInput_6", - "sid_31", - "sid_2", - "sid_25", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_4", + "Conv2dOutput_8", [ - "sid_3", - "sid_26", - "Conv2dOutput_4", + "sid_5", + "PaddedInput_8", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_5", + "PaddedInput_1", [ - "sid_26", - "Conv2dOutput_5", - "sid_25", + "sid_2", + "sid_3", + "sid_4", + "Conv2dOutput_1", + "sid_20", ], buffer_info_map, ) _verify_conflicts( - "PaddedInput_6", + "Conv2dOutput_4", [ - "sid_20", - "Conv2dOutput_6", + "PaddedInput_4", + "sid_2", "sid_19", + "sid_4", + "sid_26", ], buffer_info_map, ) _verify_conflicts( "sid_25", [ - "Conv2dOutput_5", "PaddedInput_5", - "sid_31", + "Conv2dOutput_5", "sid_2", "sid_19", - ], - buffer_info_map, - ) - _verify_conflicts( - "tensor_3", - [ "sid_4", - "sid_32", - ], - buffer_info_map, - ) - _verify_conflicts( - "sid_32", - [ "tensor_3", + "sid_32", "PaddedInput_3", + "Conv2dOutput_3", + "sid_31", ], buffer_info_map, ) _verify_conflicts( - "sid_9", - [ - "PaddedInput_7", - ], - buffer_info_map, - ) - _verify_conflicts( - "sid_2", + "PaddedInput_7", [ - "Conv2dOutput_2", - "PaddedInput_2", - "sid_31", - "sid_25", - "sid_19", + "sid_9", + "Conv2dOutput_7", + "sid_8", ], buffer_info_map, ) _verify_conflicts( - "sid_8", + "PaddedInput_5", [ - "PaddedInput_7", - "Conv2dOutput_7", - "tensor_2", + "sid_2", + "sid_19", + "sid_26", + "sid_4", + "Conv2dOutput_5", + "sid_25", ], buffer_info_map, ) @@ -1276,56 +1374,70 @@ def test_multiple_calls_to_same_primfunc(): # check conflicts _verify_conflicts( - "sid_18", + "sid_6", [ - "sid_19", - "sid_2", - "T_softmax_exp2", - "T_softmax_maxelem2", - "T_softmax_expsum2", - "T_softmax_norm2", + "sid_7", + "sid_12", + "compute", + "compute_global", + "sid_11", + "sid_10", + "T_softmax_exp", + "T_softmax_maxelem", + "sid_5", + "T_softmax_norm", + "T_softmax_expsum", ], buffer_info_map, ) _verify_conflicts( - "sid_3", + "T_softmax_exp", [ - "data_pad", - "conv2d_NCHWc_global", - "sid_2", + "sid_10", + "sid_6", + "T_softmax_maxelem", + "sid_5", + "T_softmax_norm", + "T_softmax_expsum", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_norm", + "T_softmax_expsum2", [ - "T_softmax_expsum", - "T_softmax_exp", - "sid_5", - "sid_6", - "T_softmax_maxelem", - "sid_10", + "T_softmax_exp2", + "T_softmax_norm2", + "sid_18", + "T_softmax_maxelem2", + "sid_2", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_norm2", + "compute", [ - "T_softmax_expsum2", - "T_softmax_maxelem2", - "T_softmax_exp2", - "sid_18", + "sid_12", + "sid_6", + "compute_global", + "sid_11", + "sid_19", + "sid_20", "sid_2", + "compute_global", ], buffer_info_map, ) _verify_conflicts( - "sid_11", + "compute_global", [ "compute", "sid_12", - "compute_global", - "sid_10", + "sid_6", + "sid_11", + "compute", + "sid_19", + "sid_20", + "sid_2", ], buffer_info_map, ) @@ -1334,75 +1446,93 @@ def test_multiple_calls_to_same_primfunc(): [ "sid_11", "sid_6", + "T_softmax_exp", + "T_softmax_maxelem", "sid_5", "T_softmax_norm", "T_softmax_expsum", - "T_softmax_maxelem", - "T_softmax_exp", ], buffer_info_map, ) _verify_conflicts( - "sid_5", + "sid_2", [ - "T_softmax_norm", - "T_softmax_expsum", - "T_softmax_exp", - "sid_6", - "T_softmax_maxelem", - "sid_10", - "sid_4", + "sid_3", + "sid_5", "sid_20", + "sid_19", + "compute", + "compute_global", + "sid_18", + "T_softmax_norm2", + "T_softmax_exp2", + "T_softmax_maxelem2", + "T_softmax_expsum2", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_expsum", + "sid_5", [ - "T_softmax_exp", - "T_softmax_norm", - "sid_5", - "sid_6", "T_softmax_maxelem", "sid_10", + "T_softmax_exp", + "sid_6", + "T_softmax_norm", + "T_softmax_expsum", + "sid_4", + "data_pad", + "sid_3", + "conv2d_NCHWc_global", + "sid_2", + "sid_20", ], buffer_info_map, ) _verify_conflicts( - "sid_8", + "T_softmax_norm2", [ - "data_pad", + "sid_18", + "sid_2", + "T_softmax_exp2", + "T_softmax_maxelem2", + "T_softmax_expsum2", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_expsum2", + "sid_20", [ - "T_softmax_maxelem2", - "T_softmax_exp2", - "sid_18", "sid_2", - "T_softmax_norm2", + "sid_5", + "sid_19", + "compute", + "compute_global", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_maxelem2", + "T_softmax_expsum", [ - "T_softmax_exp2", - "sid_18", - "sid_2", - "T_softmax_expsum2", - "T_softmax_norm2", + "sid_5", + "T_softmax_norm", + "T_softmax_maxelem", + "sid_10", + "T_softmax_exp", + "sid_6", ], buffer_info_map, ) _verify_conflicts( - "sid_12", + "data_pad", [ - "sid_11", - "compute", - "compute_global", + "sid_8", + "conv2d_NCHWc_global", + "sid_7", + "sid_4", + "sid_5", + "sid_3", + "conv2d_NCHWc_global", ], buffer_info_map, ) @@ -1410,6 +1540,7 @@ def test_multiple_calls_to_same_primfunc(): "sid_19", [ "sid_20", + "sid_2", "compute", "compute_global", "sid_18", @@ -1423,17 +1554,19 @@ def test_multiple_calls_to_same_primfunc(): "sid_7", "sid_3", "data_pad", + "sid_5", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_exp2", + "sid_18", [ - "sid_18", + "sid_19", "sid_2", + "T_softmax_norm2", + "T_softmax_exp2", "T_softmax_maxelem2", "T_softmax_expsum2", - "T_softmax_norm2", ], buffer_info_map, ) @@ -1447,105 +1580,84 @@ def test_multiple_calls_to_same_primfunc(): buffer_info_map, ) _verify_conflicts( - "data_pad", + "T_softmax_exp2", [ - "sid_8", - "conv2d_NCHWc_global", - "sid_7", - "sid_4", - "sid_3", - "conv2d_NCHWc_global", + "T_softmax_norm2", + "sid_18", + "sid_2", + "T_softmax_maxelem2", + "T_softmax_expsum2", ], buffer_info_map, ) _verify_conflicts( - "sid_20", + "sid_4", [ "sid_5", - "sid_19", - "compute", - "compute_global", + "data_pad", ], buffer_info_map, ) _verify_conflicts( - "sid_4", + "T_softmax_maxelem", [ + "sid_10", + "T_softmax_exp", + "sid_6", "sid_5", - "data_pad", + "T_softmax_norm", + "T_softmax_expsum", ], buffer_info_map, ) _verify_conflicts( - "T_softmax_exp", + "T_softmax_maxelem2", [ - "T_softmax_expsum", - "T_softmax_norm", - "sid_5", - "sid_6", - "T_softmax_maxelem", - "sid_10", + "T_softmax_exp2", + "T_softmax_norm2", + "sid_18", + "sid_2", + "T_softmax_expsum2", ], buffer_info_map, ) _verify_conflicts( - "compute_global", + "sid_11", [ - "sid_12", - "sid_11", - "compute", "compute", - "sid_20", - "sid_19", + "sid_12", + "compute_global", + "sid_6", + "sid_10", ], buffer_info_map, ) _verify_conflicts( - "compute", + "sid_12", [ - "sid_11", - "sid_12", - "compute_global", - "sid_20", - "sid_19", + "sid_6", + "compute", "compute_global", + "sid_11", ], buffer_info_map, ) _verify_conflicts( - "sid_6", + "T_softmax_norm", [ - "sid_7", "sid_5", - "T_softmax_norm", - "T_softmax_expsum", - "T_softmax_exp", "T_softmax_maxelem", "sid_10", - ], - buffer_info_map, - ) - _verify_conflicts( - "T_softmax_maxelem", - [ + "T_softmax_exp", "sid_6", - "sid_5", - "T_softmax_norm", "T_softmax_expsum", - "T_softmax_exp", - "sid_10", ], buffer_info_map, ) _verify_conflicts( - "sid_2", + "sid_8", [ - "sid_3", - "sid_18", - "T_softmax_exp2", - "T_softmax_maxelem2", - "T_softmax_expsum2", - "T_softmax_norm2", + "data_pad", ], buffer_info_map, )