From 3ce4fe47ca976695299069630624757075471702 Mon Sep 17 00:00:00 2001
From: Manupa Karunaratne
Date: Thu, 9 Dec 2021 14:45:16 +0000
Subject: [PATCH] [TIR][USMP] adding the pass to convert to pool offsets
 (#9418)

* [TIR][USMP] adding the pass to convert to pool offsets

This commit adds a transform pass that consumes the pool allocations
planned by the memory planning algorithm and converts them to pool
offsets.

* adds two test cases for a linear structure with two pools
* adds test case with a single pool for residual structures

Change-Id: I9d31e854461b5c21df72d1452120d286b96791c0

* [TIR][USMP] adding the pass to convert to pool offsets

* Adding a toggle to produce TIR that is TVMScript printable for unit testing
* Fixing the unit tests
* Ensure deterministic pool variable ordering.

Change-Id: I317675df03327b0ebbf4ca074255384e63f07cd6

* [TIR][USMP] adding the pass to convert to pool offsets

Fixing the references after changes in the memory planning algorithm.

Change-Id: Id7c22356fd5de43d10a2b4fc70e978af2c6d599d

* [TIR][USMP] adding the pass to convert to pool offsets

* fixing the lint

Change-Id: I7ff920b92d14a9919c930a4b35a2169c77a57dd1

* [TIR][USMP] adding the pass to convert to pool offsets

* removing unnecessary definitions
* remove global var map
* adding explanation for let bindings to pointer type

Change-Id: I31bd1a9f3057ee7f06252263565b0f75c51e6d13

* [TIR][USMP] adding the pass to convert to pool offsets

* rebase changes
* making imports absolute
* fixing typos and removing unnecessary lines

Change-Id: I4c94b9955b001513fecb39ca94f81b1ad99c7bfc

* [TIR][USMP] adding the pass to convert to pool offsets

* fixing typos

Change-Id: I42c557fd394aefdf8c2e825c4e88770eb0732f9b
---
 include/tvm/tir/usmp/utils.h                  |  48 ++
 python/tvm/script/tir/__init__.py             |   2 +-
 python/tvm/script/tir/ty.py                   |   1 +
 python/tvm/tir/usmp/__init__.py               |   1 +
 python/tvm/tir/usmp/transform/__init__.py     |  20 +
 python/tvm/tir/usmp/transform/_ffi_api.py     |  21 +
 python/tvm/tir/usmp/transform/transform.py    |  46 ++
 src/printer/text_printer.h                    |   7 +-
 src/tir/ir/stmt.cc                            |   9 +-
 .../convert_pool_allocations_to_offsets.cc    | 349 ++++++++++++
 src/tir/usmp/utils.cc                         |  39 ++
 ...orm_convert_pool_allocations_to_offsets.py | 523 ++++++++++++++++++
 12 files changed, 1061 insertions(+), 5 deletions(-)
 create mode 100644 python/tvm/tir/usmp/transform/__init__.py
 create mode 100644 python/tvm/tir/usmp/transform/_ffi_api.py
 create mode 100644 python/tvm/tir/usmp/transform/transform.py
 create mode 100644 src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
 create mode 100644 tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py

diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index 145c61dd518b..30c8f2ddea49 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -225,6 +225,44 @@ class PoolAllocation : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolAllocation, ObjectRef, PoolAllocationNode);
 };
 
+/*!
+ * \brief This object contains information post-allocation for PoolInfo objects
+ */
+struct AllocatedPoolInfoNode : public Object {
+  /*! \brief The assigned PoolInfo object */
+  PoolInfo pool_info;
+  /*! \brief The allocated size into this pool */
+  Integer allocated_size;
+  /*!
\brief An optional associated pool Var*/ + Optional pool_var; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pool_info", &pool_info); + v->Visit("allocated_size", &allocated_size); + v->Visit("pool_var", &pool_var); + } + + bool SEqualReduce(const AllocatedPoolInfoNode* other, SEqualReducer equal) const { + return equal(pool_info, other->pool_info) && equal(allocated_size, other->allocated_size) && + equal(pool_var, other->pool_var); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(pool_info); + hash_reduce(allocated_size); + hash_reduce(pool_var); + } + + static constexpr const char* _type_key = "tir.usmp.AllocatedPoolInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(AllocatedPoolInfoNode, Object); +}; + +class AllocatedPoolInfo : public ObjectRef { + public: + TVM_DLL AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var = Var()); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AllocatedPoolInfo, ObjectRef, AllocatedPoolInfoNode); +}; + /*! * \brief Convert the IR-bound BufferInfo map to an array of BufferInfo * @@ -248,6 +286,16 @@ Integer CalculateExtentsSize(const AllocateNode* op); } // namespace usmp } // namespace tir + +namespace attr { +/*! + * \brief This is a BaseFunc attribute to indicate which input var represent + * a PoolInfo Object in the form of a Map. + */ +static constexpr const char* kPoolArgs = "pool_args"; + +} // namespace attr + } // namespace tvm #endif // TVM_TIR_USMP_UTILS_H_ diff --git a/python/tvm/script/tir/__init__.py b/python/tvm/script/tir/__init__.py index 472b3de0e43b..de4045913102 100644 --- a/python/tvm/script/tir/__init__.py +++ b/python/tvm/script/tir/__init__.py @@ -17,7 +17,7 @@ """TVMScript for TIR""" # Type system -from .ty import int8, int16, int32, int64, float16, float32, float64 +from .ty import uint8, int8, int16, int32, int64, float16, float32, float64 from .ty import boolean, handle, Ptr, Tuple, Buffer from .prim_func import prim_func diff --git a/python/tvm/script/tir/ty.py b/python/tvm/script/tir/ty.py index 2808e7a48735..0432692f5f4f 100644 --- a/python/tvm/script/tir/ty.py +++ b/python/tvm/script/tir/ty.py @@ -137,6 +137,7 @@ def __getitem__(self, args): pass # pylint: disable=unnecessary-pass +uint8 = ConcreteType("uint8") int8 = ConcreteType("int8") int16 = ConcreteType("int16") int32 = ConcreteType("int32") diff --git a/python/tvm/tir/usmp/__init__.py b/python/tvm/tir/usmp/__init__.py index 8aa0d4ccfe88..514727d52e2e 100644 --- a/python/tvm/tir/usmp/__init__.py +++ b/python/tvm/tir/usmp/__init__.py @@ -18,4 +18,5 @@ """Namespace for Unified Static Memory Planner""" from . import analysis +from . import transform from .utils import BufferInfo diff --git a/python/tvm/tir/usmp/transform/__init__.py b/python/tvm/tir/usmp/transform/__init__.py new file mode 100644 index 000000000000..1a9d83328f8d --- /dev/null +++ b/python/tvm/tir/usmp/transform/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import, redefined-builtin
+"""Namespace for Unified Static Memory Planner"""
+
+from .transform import convert_pool_allocations_to_offsets
diff --git a/python/tvm/tir/usmp/transform/_ffi_api.py b/python/tvm/tir/usmp/transform/_ffi_api.py
new file mode 100644
index 000000000000..7973ca5b0da0
--- /dev/null
+++ b/python/tvm/tir/usmp/transform/_ffi_api.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""FFI APIs for tvm.tir.usmp.transform"""
+import tvm._ffi
+
+
+tvm._ffi._init_api("tir.usmp.transform", __name__)
diff --git a/python/tvm/tir/usmp/transform/transform.py b/python/tvm/tir/usmp/transform/transform.py
new file mode 100644
index 000000000000..f472172cf36f
--- /dev/null
+++ b/python/tvm/tir/usmp/transform/transform.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""USMP Transform Python API for passes"""
+# pylint: disable=invalid-name
+
+from typing import Dict
+
+import tvm
+from tvm.tir import Stmt
+from tvm.tir.usmp.utils import PoolAllocation
+from . import _ffi_api
+
+
+def convert_pool_allocations_to_offsets(
+    pool_allocations: Dict[Stmt, PoolAllocation], emit_tvmscript_printable: bool = False
+) -> tvm.transform.Pass:
+    """Convert pool allocations to Load nodes with offsets from pools.
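+
+    The underlying transform (see the accompanying C++ pass) replaces every
+    memory-planned tir.Allocate with a Let-bound handle that is the address of
+    an offset into the corresponding pool parameter, and it appends one pool
+    parameter per memory pool to each PrimFunc signature and call site.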
+ + Parameters + ---------- + pool_allocations : Dict[Stmt, PoolAllocation] + Allocate or AllocateConst node to pool allocation mapping + emit_tvmscript_printable : bool + A toggle to emit TVMScript printable IRModule for unit tests + removing all attributes that should be attached for integration + + Returns + ------- + ret: tvm.transform.Pass + The registered pass that converts the allocations to offsets. + """ + return _ffi_api.ConvertPoolAllocationsToOffsets(pool_allocations, emit_tvmscript_printable) diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index ebd667ae2ac7..97146b84450d 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -449,10 +449,11 @@ class TextPrinter { Doc PrintFinal(const ObjectRef& node) { Doc doc; - if (node->IsInstance()) { + if (node.defined() && node->IsInstance()) { doc << PrintMod(Downcast(node)); - } else if (node->IsInstance() || node->IsInstance() || - node->IsInstance()) { + } else if (node.defined() && + (node->IsInstance() || node->IsInstance() || + node->IsInstance())) { doc << tir_text_printer_.Print(node); } else { doc << relay_text_printer_.PrintFinal(node); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 0d42c20c2822..078561c447ad 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -35,7 +35,14 @@ namespace tir { LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { ICHECK(value.defined()); ICHECK(body.defined()); - ICHECK_EQ(value.dtype(), var.dtype()); + auto vdtype = value.dtype(); + // It is still valid to bind a pointer type + // var to a value that is of type handle. + if (var->type_annotation.as()) { + ICHECK(vdtype.is_handle()); + } else { + ICHECK_EQ(value.dtype(), var.dtype()); + } ObjectPtr node = make_object(); node->var = std::move(var); diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc new file mode 100644 index 000000000000..5ebf3c557b06 --- /dev/null +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/usmp/transform/convert_pool_allocations_to_offsets.cc + * \brief This pass would convert the pool allocations to offsets from pools + */ + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { +namespace tir { +namespace usmp { + +/*! + * \brief The StmtExpr mutator class to replace allocate nodes + * with offsets within memory pools + * + * This mutator class will add Pool variables recursively to every PrimFunc + * starting from the main PrimFunc. For all allocate nodes, that have been + * memory planned, will be mutated into an offset using a Let binding. 
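+ *
+ * For example (a sketch drawn from the unit tests added in this patch), a planned
+ * allocation such as
+ *   PaddedInput_7 = T.allocate([157323], "int16", "global")
+ * becomes a let binding of an offset into the pool parameter that the pass adds
+ * to the PrimFunc signature:
+ *   PaddedInput_7_let: T.handle = T.address_of(
+ *       T.load("uint8", slow_memory_5_buffer_var.data, 802816), dtype="handle")
+ * where 802816 is the byte offset assigned to this buffer by the memory planner
+ * within the "slow_memory" pool.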
+ */ +class PoolAllocationToOffsetConverter : public StmtExprMutator { + public: + PoolAllocationToOffsetConverter(const IRModule& module, + const Map& pool_allocations, + bool emit_tvmscript_printable = false) + : pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) { + module_ = module->ShallowCopy(); + for (const auto& kv : pool_allocations) { + // TODO(@manupa-arm): add AllocateConstNode when it is available + ICHECK(kv.first->IsInstance()); + Allocate allocate_node = Downcast(kv.first); + PoolAllocation pool_allocation = kv.second; + PoolInfo pool_info = pool_allocation->pool_info; + int byte_pool_offset = pool_allocation->byte_offset->value; + int required_pool_size_for_allocation = + byte_pool_offset + CalculateExtentsSize(allocate_node.operator->()); + if (all_pools_sizes_.find(pool_info) == all_pools_sizes_.end()) { + all_pools_sizes_[pool_info] = required_pool_size_for_allocation; + } else { + int prev_required_pool_size = all_pools_sizes_[pool_info]; + if (prev_required_pool_size < required_pool_size_for_allocation) { + all_pools_sizes_[pool_info] = required_pool_size_for_allocation; + } + } + } + + for (const auto& kv : all_pools_sizes_) { + PoolInfo pi = kv.first; + int allocated_size = kv.second; + allocated_pool_ordering_.push_back(AllocatedPoolInfo(pi, allocated_size)); + } + std::sort(allocated_pool_ordering_.begin(), allocated_pool_ordering_.end(), + [](const AllocatedPoolInfo& lhs, const AllocatedPoolInfo& rhs) { + if (lhs->pool_info->pool_name < rhs->pool_info->pool_name) { + return true; + } + return false; + }); + } + IRModule operator()(); + + private: + PrimExpr VisitExpr_(const CallNode* op) override; + Stmt VisitStmt_(const AllocateNode* op) override; + PrimExpr VisitExpr_(const LoadNode* op) override; + Stmt VisitStmt_(const StoreNode* op) override; + + /*! \brief This is a structure where the modified function + * signature is kept while body of the function is mutated + */ + struct ScopeInfo { + Array params; + Map pools_to_params; + Array allocated_pool_params; + Map buffer_map; + }; + + /*! \brief The function scope information that are needed + * in the mutation of the function need to be stacked and + * popped when each function is entered/exited in the + * mutation process. + */ + std::stack scope_stack; + /*! \brief Each PrimFunc signature needs to be updated + * with pool variables. This is a helper function to + * capture the updated information to ScopeInfo object. + */ + ScopeInfo UpdateFunctionScopeInfo(const PrimFunc& original_func); + /*! \brief This is a helper to create the PrimFunc with + * pool variables that calls the UpdateFunctionScopeInfo + * inside of it. + */ + PrimFunc CreatePrimFuncWithPoolParams(const PrimFunc& original_primfunc); + /*! \brief This is a helper to append the pool args to + * the callsite of the function. + */ + Array AppendPoolParamsToArgs(const Array& args); + /*! \brief Some arguments that used to be Allocate nodes + * should be replaced by Let nodes in the pass that loads + * the space from a pool variable. + */ + Array ReplaceAllocateArgsWithLetArgs(const Array& args); + + /*! \brief The tir::Var map to PoolInfo objects */ + Map primfunc_args_to_pool_info_map_; + /*! \brief The buffer var map to their allocate nodes */ + Map allocate_var_to_stmt_map_; + /*! \brief The IRModule being constructed/mutated */ + IRModule module_; + /*! \brief The input allocate node to PoolAllocation map */ + Map pool_allocations_; + /*! 
\brief The set of ordered pools to ensure an unique order of args for functions */ + std::vector allocated_pool_ordering_; + /*! \brief The storage of calculated pool size at init */ + std::unordered_map all_pools_sizes_; + /*! \brief After mutation, each allocate buffer is replaced with tir::Var that is let bounded + * to position from a pool as designated by a PoolAllocation + */ + Map allocate_buf_to_let_var_; + /*! \brief A counter to give references to pools a reproducible unique set of names */ + int pool_var_count_ = 0; + /*! \brief This toggles to remove non tvmscript printable items for IRModule for unit tests */ + bool emit_tvmscript_printable_ = false; + /*! \brief A counter to give references to pools a reproducible unique set of names */ + std::unordered_set visited_primfuncs; +}; + +PoolAllocationToOffsetConverter::ScopeInfo PoolAllocationToOffsetConverter::UpdateFunctionScopeInfo( + const PrimFunc& original_func) { + ScopeInfo si; + si.params = original_func->params; + si.buffer_map = original_func->buffer_map; + Map ret; + for (const AllocatedPoolInfo& allocated_pool_info : allocated_pool_ordering_) { + PoolInfo pool_info = allocated_pool_info->pool_info; + String pool_ref_name = pool_info->pool_name + "_" + std::to_string(pool_var_count_++); + String var_name = pool_ref_name + "_var"; + DataType elem_dtype = DataType::UInt(8); + Var buffer_var(var_name, PointerType(PrimType(elem_dtype), "global")); + Var pool_var; + if (!emit_tvmscript_printable_) { + pool_var = Var(var_name, PointerType(PrimType(elem_dtype), "global")); + } else { + pool_var = Var(var_name, DataType::Handle(8)); + } + si.params.push_back(pool_var); + si.pools_to_params.Set(pool_info, pool_var); + si.allocated_pool_params.push_back(AllocatedPoolInfo( + allocated_pool_info->pool_info, allocated_pool_info->allocated_size, pool_var)); + + int pool_size = all_pools_sizes_[pool_info]; + String buffer_var_name = pool_ref_name + "_buffer_var"; + si.buffer_map.Set(pool_var, Buffer(buffer_var, elem_dtype, {pool_size}, {1}, 1, buffer_var_name, + 16, 1, BufferType::kDefault)); + } + return si; +} + +PrimFunc PoolAllocationToOffsetConverter::CreatePrimFuncWithPoolParams( + const PrimFunc& original_primfunc) { + // Only create the new function if it was not modified with pool params + if (visited_primfuncs.find(original_primfunc) == visited_primfuncs.end()) { + ScopeInfo si = UpdateFunctionScopeInfo(original_primfunc); + this->scope_stack.push(si); + Stmt new_body = this->VisitStmt(original_primfunc->body); + this->scope_stack.pop(); + DictAttrs original_attrs = original_primfunc->attrs; + // We dont need attrs of PrimFunc that might include non printable attrs such as target + // for unit tests where emit_tvmscript_printable_ is to be used. 
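+  // When emit_tvmscript_printable_ is false, the original attrs are kept and the
+  // rewritten PrimFunc is additionally tagged with the kPoolArgs ("pool_args")
+  // attribute carrying the AllocatedPoolInfo objects, so later stages can tell
+  // which of the appended parameters are pool parameters.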
+ if (emit_tvmscript_printable_) { + original_attrs = DictAttrs(); + } + PrimFunc ret = + PrimFunc(si.params, new_body, original_primfunc->ret_type, si.buffer_map, original_attrs); + if (!emit_tvmscript_printable_) { + return WithAttr(ret, tvm::attr::kPoolArgs, si.allocated_pool_params); + } + visited_primfuncs.insert(ret); + return ret; + } + return original_primfunc; +} + +Array PoolAllocationToOffsetConverter::AppendPoolParamsToArgs( + const Array& args) { + Array new_args; + for (const auto& arg : args) { + new_args.push_back(VisitExpr(arg)); + } + ScopeInfo top_scope = this->scope_stack.top(); + for (const auto& pools_vars : top_scope.pools_to_params) { + tir::Var pool_var = pools_vars.second; + Buffer buffer_var = top_scope.buffer_map[pool_var]; + new_args.push_back(buffer_var->data); + } + return new_args; +} + +Array PoolAllocationToOffsetConverter::ReplaceAllocateArgsWithLetArgs( + const Array& args) { + Array ret; + for (const PrimExpr& arg : args) { + if (arg->IsInstance() && + allocate_buf_to_let_var_.find(Downcast(arg)) != allocate_buf_to_let_var_.end()) { + ret.push_back(allocate_buf_to_let_var_[Downcast(arg)]); + } else { + ret.push_back(arg); + } + } + return ret; +} + +PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) { + if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) { + String func_name = Downcast(op->args[0])->value; + Array new_args; + if (module_->ContainGlobalVar(func_name)) { + GlobalVar gv = module_->GetGlobalVar(func_name); + PrimFunc func = Downcast(module_->Lookup(gv)); + PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); + module_->Update(gv, prim_func); + new_args = AppendPoolParamsToArgs(op->args); + new_args = ReplaceAllocateArgsWithLetArgs(new_args); + } else { + new_args = ReplaceAllocateArgsWithLetArgs(op->args); + } + return Call(op->dtype, op->op, new_args); + } + if (op->op->IsInstance()) { + PrimFunc func = Downcast(op->op); + PrimFunc prim_func = CreatePrimFuncWithPoolParams(func); + Array new_args = AppendPoolParamsToArgs(op->args); + new_args = AppendPoolParamsToArgs(new_args); + new_args = ReplaceAllocateArgsWithLetArgs(new_args); + return Call(op->dtype, prim_func, new_args); + } + return StmtExprMutator::VisitExpr_(op); +} + +Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) { + if (pool_allocations_.count(GetRef(op))) { + ScopeInfo scope_info = scope_stack.top(); + PoolAllocation pool_allocation = pool_allocations_[GetRef(op)]; + Var param = scope_info.pools_to_params[pool_allocation->pool_info]; + Buffer buffer_var = scope_info.buffer_map[param]; + Load load_node = + Load(DataType::UInt(8), buffer_var->data, pool_allocation->byte_offset, op->condition); + Call address_of_load = Call(DataType::Handle(8), builtin::address_of(), {load_node}); + Var tir_var; + if (!emit_tvmscript_printable_) { + tir_var = Var(op->buffer_var->name_hint + "_let", op->buffer_var->type_annotation); + } else { + tir_var = Var(op->buffer_var->name_hint + "_let", DataType::Handle(8)); + } + allocate_buf_to_let_var_.Set(op->buffer_var, tir_var); + Stmt new_body = VisitStmt(op->body); + allocate_buf_to_let_var_.erase(op->buffer_var); + return LetStmt(tir_var, address_of_load, new_body); + } + return StmtExprMutator::VisitStmt_(op); +} + +Stmt PoolAllocationToOffsetConverter::VisitStmt_(const StoreNode* op) { + if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) { + return Store(allocate_buf_to_let_var_[op->buffer_var], VisitExpr(op->value), 
op->index, + VisitExpr(op->predicate)); + } + return StmtExprMutator::VisitStmt_(op); +} + +PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const LoadNode* op) { + if (allocate_buf_to_let_var_.find(op->buffer_var) != allocate_buf_to_let_var_.end()) { + return Load(op->dtype, allocate_buf_to_let_var_[op->buffer_var], op->index, + VisitExpr(op->predicate)); + } + return StmtExprMutator::VisitExpr_(op); +} + +IRModule PoolAllocationToOffsetConverter::operator()() { + GlobalVar gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix); + PrimFunc main_func = Downcast(module_->Lookup(gv)); + ScopeInfo si = UpdateFunctionScopeInfo(main_func); + this->scope_stack.push(si); + Stmt main_func_body = this->VisitStmt(main_func->body); + this->scope_stack.pop(); + // We dont need attrs of PrimFunc that might include non printable attrs such as target + // for unit tests where emit_tvmscript_printable_ is to be used. + if (!emit_tvmscript_printable_) { + main_func = + PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, main_func->attrs); + main_func = WithAttr(main_func, tvm::attr::kPoolArgs, si.allocated_pool_params); + } else { + main_func = + PrimFunc(si.params, main_func_body, main_func->ret_type, si.buffer_map, DictAttrs()); + } + module_->Update(gv, main_func); + if (!emit_tvmscript_printable_) { + return WithAttr(this->module_, tvm::attr::kPoolArgs, si.allocated_pool_params); + } + return this->module_; +} + +namespace transform { + +tvm::transform::Pass ConvertPoolAllocationsToOffsets( + const Map& pool_allocations, + Bool emit_tvmscript_printable = Bool(false)) { + auto pass_func = [=](IRModule m, tvm::transform::PassContext ctx) { + return Downcast(PoolAllocationToOffsetConverter( + m, pool_allocations, emit_tvmscript_printable->value != 0)()); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.usmp.ConvertPoolAllocationsToOffsets", + {}); +} + +TVM_REGISTER_GLOBAL("tir.usmp.transform.ConvertPoolAllocationsToOffsets") + .set_body_typed(ConvertPoolAllocationsToOffsets); + +} // namespace transform + +} // namespace usmp +} // namespace tir +} // namespace tvm diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index 7a6a683770b0..14b3d26641a3 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -135,6 +135,30 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ")"; }); +AllocatedPoolInfo::AllocatedPoolInfo(PoolInfo pool_info, Integer allocated_size, Var pool_var) { + auto allocated_poolinfo_node = make_object(); + allocated_poolinfo_node->pool_info = pool_info; + allocated_poolinfo_node->allocated_size = allocated_size; + if (pool_var.defined()) { + allocated_poolinfo_node->pool_var = pool_var; + } + data_ = std::move(allocated_poolinfo_node); +} + +TVM_REGISTER_NODE_TYPE(AllocatedPoolInfoNode); +TVM_REGISTER_GLOBAL("tir.usmp.AllocatedPoolInfo") + .set_body_typed([](PoolInfo pool_info, Integer allocated_size) { + return AllocatedPoolInfo(pool_info, allocated_size); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "AllocatedPoolInfoNode(\n" + << "pool_info=" << node->pool_info << ",\n allocated_size=" << node->allocated_size + << ")"; + }); + Array CreateArrayBufferInfo(const Map& buffer_info_map) { Array ret; for (const auto& kv : buffer_info_map) { @@ -144,6 +168,19 @@ Array CreateArrayBufferInfo(const Map& buffer_info return ret; } +Map AssignStmtPoolAllocations( + const Map& buffer_info_to_stmt, + const Map& 
buffer_info_to_pool_allocation) { + Map ret; + for (const auto& kv : buffer_info_to_pool_allocation) { + BufferInfo bi = kv.first; + Stmt stmt_ = buffer_info_to_stmt[bi]; + PoolAllocation pa = kv.second; + ret.Set(stmt_, pa); + } + return ret; +} + Integer CalculateExtentsSize(const AllocateNode* op) { size_t element_size_bytes = op->dtype.bytes(); size_t num_elements = 1; @@ -163,6 +200,8 @@ TVM_REGISTER_GLOBAL("tir.usmp.CreateArrayBufferInfo") return (CreateArrayBufferInfo(buffer_info_map)); }); +TVM_REGISTER_GLOBAL("tir.usmp.AssignStmtPoolAllocations").set_body_typed(AssignStmtPoolAllocations); + } // namespace usmp } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py new file mode 100644 index 000000000000..fc615775c160 --- /dev/null +++ b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -0,0 +1,523 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
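+# The tests below run hand-written TVMScript modules through the USMP pipeline and
+# compare the output of this pass against hand-written "planned" reference modules.
+# As an orientation sketch, each test performs roughly the following sequence after
+# targets and candidate PoolInfos have been assigned via the helpers defined below
+# (the exact calls appear in test_mobilenet_subgraph and test_resnet_subgraph):
+#
+#   main_func = tir_mod["run_model"]
+#   buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
+#   buffer_info_arr = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")(
+#       buffer_analysis.buffer_info_stmts)
+#   buffer_pool_allocations = tvm.get_global_func("tir.usmp.algo.greedy_by_size")(
+#       buffer_info_arr, buffer_analysis.memory_pressure)
+#   pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")(
+#       buffer_analysis.buffer_info_stmts, buffer_pool_allocations)
+#   tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
+#       pool_allocations, emit_tvmscript_printable=True)(tir_mod)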
+import pytest +import sys + +import tvm +from tvm.script import tir as T +from tvm.tir import stmt_functor +from tvm.tir.usmp import utils as usmp_utils +from tvm.target import Target + + +def _get_primfuncs_from_module(module): + primfuncs = list() + for gv, primfunc in module.functions.items(): + primfuncs.append(primfunc) + return primfuncs + + +def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos): + """Helper to assign poolinfos to allocate nodes in a tir.PrimFunc""" + + def set_poolinfos(stmt): + if isinstance(stmt, tvm.tir.Allocate): + return tvm.tir.Allocate( + buffer_var=stmt.buffer_var, + dtype=stmt.dtype, + extents=stmt.extents, + condition=stmt.condition, + body=stmt.body, + annotations={tvm.tir.usmp.utils.CANDIDATE_MEMORY_POOL_ATTR: pool_infos}, + ) + + return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos)) + + +def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos): + """Helper to assign poolinfos to allocate nodes in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos) + return ret + + +def _assign_targets_to_primfuncs_irmodule(mod, target): + """Helper to assign target for PrimFunc in a IRModule""" + ret = tvm.IRModule() + for global_var, basefunc in mod.functions.items(): + if isinstance(basefunc, tvm.tir.PrimFunc): + ret[global_var] = basefunc.with_attr("target", target) + return ret + + +# fmt: off +@tvm.script.ir_module +class LinearStructure: + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract", "tir.noalias": True}) + placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16", elem_offset=0, align=128, offset_factor=1) + T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + for ax0_ax1_fused_1 in T.serial(0, 224): + for ax2_1, ax3_inner_1 in T.grid(224, 3): + T.store(T_subtract_1.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1), (T.cast(T.load("uint8", placeholder_4.data, (((ax0_ax1_fused_1*672) + (ax2_1*3)) + ax3_inner_1)), "int16") - T.load("int16", placeholder_5.data, 0)), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", "tir.noalias": True}) + placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32", elem_offset=0, align=128, offset_factor=1) + T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + PaddedInput_7 = T.allocate([157323], "int16", "global") + for i0_i1_fused_7 in T.serial(0, 229): + for i2_7, i3_7 in T.grid(229, 3): + T.store(PaddedInput_7, 
(((i0_i1_fused_7*687) + (i2_7*3)) + i3_7), T.if_then_else(((((2 <= i0_i1_fused_7) and (i0_i1_fused_7 < 226)) and (2 <= i2_7)) and (i2_7 < 226)), T.load("int16", placeholder_65.data, ((((i0_i1_fused_7*672) + (i2_7*3)) + i3_7) - 1350)), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7 = T.allocate([64], "int32", "global") + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_7, ff_3, 0, True) + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + T.store(Conv2dOutput_7, ff_3, (T.load("int32", Conv2dOutput_7, ff_3) + (T.cast(T.load("int16", PaddedInput_7, (((((T.floordiv(ax0_ax1_fused_ax2_fused_7, 112)*1374) + (ry_2*687)) + (T.floormod(ax0_ax1_fused_ax2_fused_7, 112)*6)) + (rx_2*3)) + rc_7)), "int32")*T.cast(T.load("int16", placeholder_66.data, ((((ry_2*1344) + (rx_2*192)) + (rc_7*64)) + ff_3)), "int32"))), True) + for ax3_inner_7 in T.serial(0, 64): + T.store(T_cast_21.data, ((ax0_ax1_fused_ax2_fused_7*64) + ax3_inner_7), T.cast(T.max(T.min(T.q_multiply_shift((T.load("int32", Conv2dOutput_7, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7)), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_max_pool2d_cast", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16", elem_offset=0, align=128, offset_factor=1) + # body + tensor_2 = T.allocate([200704], "uint8", "global") + for ax0_ax1_fused_4 in T.serial(0, 56): + for ax2_4 in T.serial(0, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_init), T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + T.store(tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2), T.max(T.load("uint8", tensor_2, (((ax0_ax1_fused_4*3584) + (ax2_4*64)) + ax3_2)), T.if_then_else(((((ax0_ax1_fused_4*2) + T.floordiv(rv0_rv1_fused_1, 3)) < 112) and (((ax2_4*2) + T.floormod(rv0_rv1_fused_1, 3)) < 112)), T.load("uint8", placeholder_29.data, (((((ax0_ax1_fused_4*14336) + (T.floordiv(rv0_rv1_fused_1, 3)*7168)) + (ax2_4*128)) + (T.floormod(rv0_rv1_fused_1, 3)*64)) + ax3_2)), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5 in T.serial(0, 56): + for ax2_5, ax3_3 in T.grid(56, 64): + T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_9 = T.allocate([301056], "int8", "global") + sid_8 = T.allocate([802816], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8, output, dtype="int32")) +# fmt: on + + +# fmt: off +@tvm.script.ir_module 
+class LinearStructurePlanned: + @T.prim_func + def run_model(input: T.handle, output: T.handle, fast_memory_0_var: T.handle, slow_memory_1_var: T.handle) -> None: + fast_memory_0_buffer_var = T.match_buffer(fast_memory_0_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_1_buffer_var = T.match_buffer(slow_memory_1_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_9_let: T.handle = T.address_of(T.load("uint8", slow_memory_1_buffer_var.data, 1117472), dtype="handle") + sid_8_let: T.handle = T.address_of(T.load("uint8", slow_memory_1_buffer_var.data, 0), dtype="handle") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract", input, T.lookup_param("p0", dtype="handle"), sid_9_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast", sid_9_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_8_let, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_max_pool2d_cast", sid_8_let, output, fast_memory_0_buffer_var.data, slow_memory_1_buffer_var.data, dtype="int32")) + + @T.prim_func + def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6: T.handle, fast_memory_6_var: T.handle, slow_memory_7_var: T.handle) -> None: + placeholder_29 = T.match_buffer(placeholder_28, [1, 112, 112, 64], dtype="uint8") + T_cast_7 = T.match_buffer(T_cast_6, [1, 56, 56, 64], dtype="int16") + fast_memory_6_buffer_var = T.match_buffer(fast_memory_6_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_7_buffer_var = T.match_buffer(slow_memory_7_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + tensor_2_let: T.handle = T.address_of(T.load("uint8", fast_memory_6_buffer_var.data, 0), dtype="handle") + for ax0_ax1_fused_4, ax2_4 in T.grid(56, 56): + for ax3_init in T.serial(0, 64): + T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_init, T.uint8(0), True) + for rv0_rv1_fused_1, ax3_2 in T.grid(9, 64): + T.store(tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2, T.max(T.load("uint8", tensor_2_let, ax0_ax1_fused_4 * 3584 + ax2_4 * 64 + ax3_2), T.if_then_else(ax0_ax1_fused_4 * 2 + rv0_rv1_fused_1 // 3 < 112 and ax2_4 * 2 + rv0_rv1_fused_1 % 3 < 112, T.load("uint8", placeholder_29.data, ax0_ax1_fused_4 * 14336 + rv0_rv1_fused_1 // 3 * 7168 + ax2_4 * 128 + rv0_rv1_fused_1 % 3 * 64 + ax3_2), T.uint8(0), dtype="uint8")), True) + for ax0_ax1_fused_5, ax2_5, ax3_3 in T.grid(56, 56, 64): + T.store(T_cast_7.data, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3, T.cast(T.load("uint8", tensor_2_let, ax0_ax1_fused_5 * 3584 + ax2_5 * 64 + ax3_3), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle, fast_memory_2_var: T.handle, slow_memory_3_var: T.handle) -> None: + placeholder_4 = T.match_buffer(placeholder_2, [1, 224, 224, 3], dtype="uint8") + placeholder_5 = T.match_buffer(placeholder_3, [], dtype="int16") + T_subtract_1 = T.match_buffer(T_subtract, [1, 224, 224, 3], dtype="int16") + fast_memory_2_buffer_var = T.match_buffer(fast_memory_2_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_3_buffer_var = T.match_buffer(slow_memory_3_var, 
[1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + for ax0_ax1_fused_1, ax2_1, ax3_inner_1 in T.grid(224, 224, 3): + T.store(T_subtract_1.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1, T.cast(T.load("uint8", placeholder_4.data, ax0_ax1_fused_1 * 672 + ax2_1 * 3 + ax3_inner_1), "int16") - T.load("int16", placeholder_5.data, 0), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast(placeholder_62: T.handle, placeholder_63: T.handle, placeholder_64: T.handle, T_cast_20: T.handle, fast_memory_4_var: T.handle, slow_memory_5_var: T.handle) -> None: + placeholder_65 = T.match_buffer(placeholder_62, [1, 224, 224, 3], dtype="int16") + placeholder_66 = T.match_buffer(placeholder_63, [7, 7, 3, 64], dtype="int16") + placeholder_67 = T.match_buffer(placeholder_64, [1, 1, 1, 64], dtype="int32") + T_cast_21 = T.match_buffer(T_cast_20, [1, 112, 112, 64], dtype="uint8") + fast_memory_4_buffer_var = T.match_buffer(fast_memory_4_var, [200704], dtype="uint8", strides=[1], elem_offset=1, align=16) + slow_memory_5_buffer_var = T.match_buffer(slow_memory_5_var, [1418528], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + PaddedInput_7_let: T.handle = T.address_of(T.load("uint8", slow_memory_5_buffer_var.data, 802816), dtype="handle") + for i0_i1_fused_7, i2_7, i3_7 in T.grid(229, 229, 3): + T.store(PaddedInput_7_let, i0_i1_fused_7 * 687 + i2_7 * 3 + i3_7, T.if_then_else(2 <= i0_i1_fused_7 and i0_i1_fused_7 < 226 and 2 <= i2_7 and i2_7 < 226, T.load("int16", placeholder_65.data, i0_i1_fused_7 * 672 + i2_7 * 3 + i3_7 - 1350), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_7 in T.serial(0, 12544): + Conv2dOutput_7_let: T.handle = T.address_of(T.load("uint8", fast_memory_4_buffer_var.data, 0), dtype="handle") + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_7_let, ff_3, 0, True) + for ry_2, rx_2, rc_7 in T.grid(7, 7, 3): + T.store(Conv2dOutput_7_let, ff_3, T.load("int32", Conv2dOutput_7_let, ff_3) + T.cast(T.load("int16", PaddedInput_7_let, ax0_ax1_fused_ax2_fused_7 // 112 * 1374 + ry_2 * 687 + ax0_ax1_fused_ax2_fused_7 % 112 * 6 + rx_2 * 3 + rc_7), "int32") * T.cast(T.load("int16", placeholder_66.data, ry_2 * 1344 + rx_2 * 192 + rc_7 * 64 + ff_3), "int32"), True) + for ax3_inner_7 in T.serial(0, 64): + T.store(T_cast_21.data, ax0_ax1_fused_ax2_fused_7 * 64 + ax3_inner_7, T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_7_let, ax3_inner_7) + T.load("int32", placeholder_67.data, ax3_inner_7), 1939887962, 31, -9, dtype="int32"), 255), 0), "uint8"), True) +# fmt: on + + +def test_mobilenet_subgraph(): + target = Target("c") + fast_memory_pool = usmp_utils.PoolInfo( + pool_name="fast_memory", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + size_hint_bytes=200704, + ) + slow_memory_pool = usmp_utils.PoolInfo( + pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS} + ) + tir_mod = LinearStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) + tir_mod = assign_poolinfos_to_allocates_in_irmodule( + tir_mod, [fast_memory_pool, slow_memory_pool] + ) + main_func = tir_mod["run_model"] + buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + buffer_info_map = buffer_analysis.buffer_info_stmts + + fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") + buffer_info_arr = fcreate_array_bi(buffer_info_map) + fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") + 
buffer_pool_allocations = fusmp_algo_greedy_by_size( + buffer_info_arr, buffer_analysis.memory_pressure + ) + fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") + pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) + tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( + pool_allocations, emit_tvmscript_printable=True + )(tir_mod) + + tir_mod_with_offsets_ref = LinearStructurePlanned + tir_mod_with_offsets_ref = tvm.script.from_source( + tir_mod_with_offsets_ref.script(show_meta=False) + ) + # The TIR produced fails on roundtrip TVMScript testing. + # Therefore, indicates the TVMScript produced here and/or the parser + # is lacking functionality. Thus for these tests, uses a string + # version of the TVMScript for each function as a check instead. + for gv, func in tir_mod_with_offsets_ref.functions.items(): + assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str( + tir_mod_with_offsets[gv.name_hint].script() + ) + + +# fmt: off +@tvm.script.ir_module +class ResnetStructure: + @T.prim_func + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", "tir.noalias": True}) + placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + # body + for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): + T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", "tir.noalias": True}) + placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") + T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + # body + PaddedInput_1 = T.allocate([379456], "int16", "global") + for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): + T.store(PaddedInput_1, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): + Conv2dOutput_1 = T.allocate([64], "int32", "global") + for ff_1 in T.serial(0, 64): + T.store(Conv2dOutput_1, ff_1, 0, True) + for ry, rx, rc_1 in T.grid(3, 3, 64): + T.store(Conv2dOutput_1, ff_1, T.load("int32", Conv2dOutput_1, ff_1) + T.cast(T.load("int16", PaddedInput_1, 
T.floordiv(ax0_ax1_fused_ax2_fused_1, 75) * 4928 + ry * 4928 + rx * 64 + T.floormod(ax0_ax1_fused_ax2_fused_1, 75) * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + for ax3_inner_2 in T.serial(0, 64): + T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", "tir.noalias": True}) + placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + # body + PaddedInput_2 = T.allocate([360000], "int16", "global") + for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): + T.store(PaddedInput_2, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): + Conv2dOutput_2 = T.allocate([64], "int32", "global") + for ax3_outer_1 in T.serial(0, 4): + for ff_2 in T.serial(0, 64): + T.store(Conv2dOutput_2, ff_2, 0, True) + for rc_2 in T.serial(0, 64): + T.store(Conv2dOutput_2, ff_2, T.load("int32", Conv2dOutput_2, ff_2) + T.cast(T.load("int16", PaddedInput_2, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) + for ax3_inner_3 in T.serial(0, 64): + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", "tir.noalias": True}) + placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") + placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") + T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + # body + PaddedInput_3 = T.allocate([360000], "int16", "global") + for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): + 
T.store(PaddedInput_3, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) + for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): + Conv2dOutput_3 = T.allocate([64], "int32", "global") + for ax3_outer_2 in T.serial(0, 4): + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_3, ff_3, 0, True) + for rc_3 in T.serial(0, 64): + T.store(Conv2dOutput_3, ff_3, T.load("int32", Conv2dOutput_3, ff_3) + T.cast(T.load("int16", PaddedInput_3, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) + for ax3_inner_4 in T.serial(0, 64): + T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "run_model", "runner_function": True}) + # body + T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_2 = T.allocate([720000], "int8", "global") + sid_6 = T.allocate([5760000], "int8", "global") + sid_7 = T.allocate([720000], "int8", "global") + sid_8 = T.allocate([720000], "int8", "global") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6, output, dtype="int32")) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", "tir.noalias": True}) + placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") + T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + # body + PaddedInput = T.allocate([360000], "int16", "global") + for i0_i1_fused, i2, i3 in T.grid(75, 75, 
64): + T.store(PaddedInput, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): + Conv2dOutput = T.allocate([64], "int32", "global") + for ff in T.serial(0, 64): + T.store(Conv2dOutput, ff, 0, True) + for rc in T.serial(0, 64): + T.store(Conv2dOutput, ff, T.load("int32", Conv2dOutput, ff) + T.cast(T.load("int16", PaddedInput, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + for ax3_inner_1 in T.serial(0, 64): + T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) +# fmt: on + + +# fmt: off +@tvm.script.ir_module +class ResnetStructurePlanned: + @T.prim_func + def tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast(placeholder: T.handle, placeholder_1: T.handle, T_cast: T.handle, global_workspace_1_var: T.handle) -> None: + placeholder_2 = T.match_buffer(placeholder, [1, 75, 75, 64], dtype="uint8") + placeholder_3 = T.match_buffer(placeholder_1, [64], dtype="int32") + T_cast_1 = T.match_buffer(T_cast, [1, 75, 75, 64], dtype="int16") + global_workspace_1_buffer_var = T.match_buffer(global_workspace_1_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + for ax0_ax1_fused, ax2, ax3_outer, ax3_inner in T.grid(75, 75, 4, 16): + T.store(T_cast_1.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.load("uint8", placeholder_2.data, ax0_ax1_fused * 4800 + ax2 * 64 + ax3_outer * 16 + ax3_inner), "int32") - 94, 1843157232, 31, 1, dtype="int32") + T.load("int32", placeholder_3.data, ax3_outer * 16 + ax3_inner), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_(placeholder_22: T.handle, placeholder_23: T.handle, placeholder_24: T.handle, placeholder_25: T.handle, T_cast_6: T.handle, global_workspace_5_var: T.handle) -> None: + placeholder_29 = T.match_buffer(placeholder_22, [1, 75, 75, 64], dtype="int16") + placeholder_27 = T.match_buffer(placeholder_23, [1, 1, 64, 256], dtype="int16") + placeholder_26 = T.match_buffer(placeholder_24, [1, 1, 1, 256], dtype="int32") + placeholder_28 = T.match_buffer(placeholder_25, [1, 75, 75, 256], dtype="int32") + T_cast_7 = T.match_buffer(T_cast_6, [1, 75, 75, 256], dtype="uint8") + global_workspace_5_buffer_var = T.match_buffer(global_workspace_5_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + PaddedInput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 6480000), dtype="handle") + for i0_i1_fused_3, i2_3, i3_3 in T.grid(75, 75, 64): + T.store(PaddedInput_3_let, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3, T.load("int16", placeholder_29.data, i0_i1_fused_3 * 4800 + i2_3 * 64 + i3_3), True) + for ax0_ax1_fused_ax2_fused_3 in T.serial(0, 5625): + Conv2dOutput_3_let: T.handle = T.address_of(T.load("uint8", global_workspace_5_buffer_var.data, 7200000), dtype="handle") + for ax3_outer_2 in T.serial(0, 4): + for ff_3 in T.serial(0, 64): + T.store(Conv2dOutput_3_let, ff_3, 0, True) + for rc_3 in T.serial(0, 64): + T.store(Conv2dOutput_3_let, ff_3, T.load("int32", 
Conv2dOutput_3_let, ff_3) + T.cast(T.load("int16", PaddedInput_3_let, ax0_ax1_fused_ax2_fused_3 * 64 + rc_3), "int32") * T.cast(T.load("int16", placeholder_27.data, rc_3 * 256 + ax3_outer_2 * 64 + ff_3), "int32"), True) + for ax3_inner_4 in T.serial(0, 64): + T.store(T_cast_7.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4, T.cast(T.max(T.min(T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_3_let, ax3_inner_4) + T.load("int32", placeholder_26.data, ax3_outer_2 * 64 + ax3_inner_4), 1343014664, 31, -8, dtype="int32") + 136, 255), 0), "uint8"), "int32") - 136, 1073903788, 31, 1, dtype="int32") + T.load("int32", placeholder_28.data, ax0_ax1_fused_ax2_fused_3 * 256 + ax3_outer_2 * 64 + ax3_inner_4), 255), 0), "uint8"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_(placeholder_16: T.handle, placeholder_17: T.handle, placeholder_18: T.handle, T_add: T.handle, global_workspace_4_var: T.handle) -> None: + placeholder_19 = T.match_buffer(placeholder_16, [1, 75, 75, 64], dtype="int16") + placeholder_20 = T.match_buffer(placeholder_17, [1, 1, 64, 256], dtype="int16") + placeholder_21 = T.match_buffer(placeholder_18, [1, 1, 1, 256], dtype="int32") + T_add_1 = T.match_buffer(T_add, [1, 75, 75, 256], dtype="int32") + global_workspace_4_buffer_var = T.match_buffer(global_workspace_4_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + PaddedInput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7200000), dtype="handle") + for i0_i1_fused_2, i2_2, i3_2 in T.grid(75, 75, 64): + T.store(PaddedInput_2_let, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2, T.load("int16", placeholder_19.data, i0_i1_fused_2 * 4800 + i2_2 * 64 + i3_2), True) + for ax0_ax1_fused_ax2_fused_2 in T.serial(0, 5625): + Conv2dOutput_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_4_buffer_var.data, 7920000), dtype="handle") + for ax3_outer_1 in T.serial(0, 4): + for ff_2 in T.serial(0, 64): + T.store(Conv2dOutput_2_let, ff_2, 0, True) + for rc_2 in T.serial(0, 64): + T.store(Conv2dOutput_2_let, ff_2, T.load("int32", Conv2dOutput_2_let, ff_2) + T.cast(T.load("int16", PaddedInput_2_let, ax0_ax1_fused_ax2_fused_2 * 64 + rc_2), "int32") * T.cast(T.load("int16", placeholder_20.data, rc_2 * 256 + ax3_outer_1 * 64 + ff_2), "int32"), True) + for ax3_inner_3 in T.serial(0, 64): + T.store(T_add_1.data, ax0_ax1_fused_ax2_fused_2 * 256 + ax3_outer_1 * 64 + ax3_inner_3, T.q_multiply_shift(T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_2_let, ax3_inner_3) + T.load("int32", placeholder_21.data, ax3_outer_1 * 64 + ax3_inner_3), 1711626602, 31, -8, dtype="int32") + 132, 255), 0), "uint8"), "int32") - 132, 2094289803, 31, -2, dtype="int32") + 136, True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(placeholder_4: T.handle, placeholder_5: T.handle, placeholder_6: T.handle, T_cast_2: T.handle, global_workspace_2_var: T.handle) -> None: + placeholder_7 = T.match_buffer(placeholder_4, [1, 75, 75, 64], dtype="int16") + placeholder_8 = T.match_buffer(placeholder_5, [1, 1, 64, 64], dtype="int16") + placeholder_9 = T.match_buffer(placeholder_6, [1, 1, 1, 64], dtype="int32") + T_cast_3 = T.match_buffer(T_cast_2, [1, 75, 75, 64], dtype="int16") + global_workspace_2_buffer_var = T.match_buffer(global_workspace_2_var, [7920256], dtype="uint8", strides=[1], 
elem_offset=1, align=16) + # body + PaddedInput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7200000), dtype="handle") + for i0_i1_fused, i2, i3 in T.grid(75, 75, 64): + T.store(PaddedInput_let, i0_i1_fused * 4800 + i2 * 64 + i3, T.load("int16", placeholder_7.data, i0_i1_fused * 4800 + i2 * 64 + i3), True) + for ax0_ax1_fused_ax2_fused in T.serial(0, 5625): + Conv2dOutput_let: T.handle = T.address_of(T.load("uint8", global_workspace_2_buffer_var.data, 7920000), dtype="handle") + for ff in T.serial(0, 64): + T.store(Conv2dOutput_let, ff, 0, True) + for rc in T.serial(0, 64): + T.store(Conv2dOutput_let, ff, T.load("int32", Conv2dOutput_let, ff) + T.cast(T.load("int16", PaddedInput_let, ax0_ax1_fused_ax2_fused * 64 + rc), "int32") * T.cast(T.load("int16", placeholder_8.data, rc * 64 + ff), "int32"), True) + for ax3_inner_1 in T.serial(0, 64): + T.store(T_cast_3.data, ax0_ax1_fused_ax2_fused * 64 + ax3_inner_1, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_let, ax3_inner_1) + T.load("int32", placeholder_9.data, ax3_inner_1), 1843106743, 31, -6, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1(placeholder_10: T.handle, placeholder_11: T.handle, placeholder_12: T.handle, T_cast_4: T.handle, global_workspace_3_var: T.handle) -> None: + placeholder_13 = T.match_buffer(placeholder_10, [1, 75, 75, 64], dtype="int16") + placeholder_14 = T.match_buffer(placeholder_11, [3, 3, 64, 64], dtype="int16") + placeholder_15 = T.match_buffer(placeholder_12, [1, 1, 1, 64], dtype="int32") + T_cast_5 = T.match_buffer(T_cast_4, [1, 75, 75, 64], dtype="int16") + global_workspace_3_buffer_var = T.match_buffer(global_workspace_3_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + PaddedInput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 0), dtype="handle") + for i0_i1_fused_1, i2_1, i3_1 in T.grid(77, 77, 64): + T.store(PaddedInput_1_let, i0_i1_fused_1 * 4928 + i2_1 * 64 + i3_1, T.if_then_else(1 <= i0_i1_fused_1 and i0_i1_fused_1 < 76 and 1 <= i2_1 and i2_1 < 76, T.load("int16", placeholder_13.data, i0_i1_fused_1 * 4800 + i2_1 * 64 + i3_1 - 4864), T.int16(0), dtype="int16"), True) + for ax0_ax1_fused_ax2_fused_1 in T.serial(0, 5625): + Conv2dOutput_1_let: T.handle = T.address_of(T.load("uint8", global_workspace_3_buffer_var.data, 7200000), dtype="handle") + for ff_1 in T.serial(0, 64): + T.store(Conv2dOutput_1_let, ff_1, 0, True) + for ry, rx, rc_1 in T.grid(3, 3, 64): + T.store(Conv2dOutput_1_let, ff_1, T.load("int32", Conv2dOutput_1_let, ff_1) + T.cast(T.load("int16", PaddedInput_1_let, ax0_ax1_fused_ax2_fused_1 // 75 * 4928 + ry * 4928 + rx * 64 + ax0_ax1_fused_ax2_fused_1 % 75 * 64 + rc_1), "int32") * T.cast(T.load("int16", placeholder_14.data, ry * 12288 + rx * 4096 + rc_1 * 64 + ff_1), "int32"), True) + for ax3_inner_2 in T.serial(0, 64): + T.store(T_cast_5.data, ax0_ax1_fused_ax2_fused_1 * 64 + ax3_inner_2, T.cast(T.cast(T.max(T.min(T.q_multiply_shift(T.load("int32", Conv2dOutput_1_let, ax3_inner_2) + T.load("int32", placeholder_15.data, ax3_inner_2), 1608879842, 31, -7, dtype="int32"), 255), 0), "uint8"), "int16"), True) + + @T.prim_func + def run_model(input: T.handle, output: T.handle, global_workspace_0_var: T.handle) -> None: + global_workspace_0_buffer_var = T.match_buffer(global_workspace_0_var, [7920256], dtype="uint8", strides=[1], elem_offset=1, align=16) + # body + 
T.attr("default", "device_id", 0) + T.attr("default", "device_type", 1) + sid_2_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 5760000), dtype="handle") + sid_6_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 0), dtype="handle") + sid_7_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle") + sid_8_let: T.handle = T.address_of(T.load("uint8", global_workspace_0_buffer_var.data, 6480000), dtype="handle") + T.evaluate(T.call_extern("tvmgen_default_fused_cast_subtract_fixed_point_multiply_add_clip_cast_cast", input, T.lookup_param("p0", dtype="handle"), sid_2_let, global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast", sid_2_let, T.lookup_param("p3", dtype="handle"), T.lookup_param("p4", dtype="handle"), sid_8_let, global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast_1", sid_8_let, T.lookup_param("p5", dtype="handle"), T.lookup_param("p6", dtype="handle"), sid_7_let, global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_15934180698220515269_", sid_7_let, T.lookup_param("p7", dtype="handle"), T.lookup_param("p8", dtype="handle"), sid_6_let, global_workspace_0_buffer_var.data, dtype="int32")) + T.evaluate(T.call_extern("tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_add_clip_cast_cast_subtract_fixed_point_4200876283395191415_", sid_2_let, T.lookup_param("p1", dtype="handle"), T.lookup_param("p2", dtype="handle"), sid_6_let, output, global_workspace_0_buffer_var.data, dtype="int32")) + __tvm_meta__ = None +# fmt: on + + +def test_resnet_subgraph(): + target = Target("c") + global_workspace_pool = usmp_utils.PoolInfo( + pool_name="global_workspace", + target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}, + ) + tir_mod = ResnetStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) + tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) + main_func = tir_mod["run_model"] + buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + buffer_info_map = buffer_analysis.buffer_info_stmts + + fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") + buffer_info_arr = fcreate_array_bi(buffer_info_map) + fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") + buffer_pool_allocations = fusmp_algo_greedy_by_size( + buffer_info_arr, buffer_analysis.memory_pressure + ) + fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") + pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) + tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( + pool_allocations, emit_tvmscript_printable=True + )(tir_mod) + + tir_mod_with_offsets_ref = ResnetStructurePlanned + + # The TIR produced fails on roundtrip TVMScript testing. + # Therefore, indicates the TVMScript produced here and/or the parser + # is lacking functionality. Thus for these tests, uses a string + # version of the TVMScript for each function as a check instead. 
+ for gv, func in tir_mod_with_offsets_ref.functions.items(): + assert str(tir_mod_with_offsets_ref[gv.name_hint].script()) == str( + tir_mod_with_offsets[gv.name_hint].script() + ) + + +if __name__ == "__main__": + pytest.main([__file__] + sys.argv[1:])