Skip to content

Commit

Permalink
[TIR][USMP] adding the pass to convert to pool offsets
Browse files Browse the repository at this point in the history
* removing unnecessary defitinitions
* remove global var map
* adding explaination for let bindings to pointer type

Change-Id: I31bd1a9f3057ee7f06252263565b0f75c51e6d13
  • Loading branch information
manupak committed Dec 6, 2021
1 parent 44594ac commit 977608e
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 18 deletions.
6 changes: 1 addition & 5 deletions include/tvm/tir/usmp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,7 @@ namespace attr {
* a PoolInfo Object in the form of a Map<Var, PoolInfo>.
*/
static constexpr const char* kPoolArgs = "pool_args";
/*!
* \brief This is a BaseFunc attribute to indicate which input var represent
* a PoolInfo Object in the form of a Map<Var, PoolInfo>.
*/
static constexpr const char* kPoolInfoIRModuleAttr = "pool_infos";

} // namespace attr

} // namespace tvm
Expand Down
2 changes: 2 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) {
ICHECK(value.defined());
ICHECK(body.defined());
auto vdtype = value.dtype();
// It is still valid to bind a pointer type
// var to a value that is of type handle.
if (var->type_annotation.as<PointerTypeNode>()) {
ICHECK(vdtype.is_handle());
} else {
Expand Down
27 changes: 14 additions & 13 deletions src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,21 @@ namespace tvm {
namespace tir {
namespace usmp {

/*!
* \brief The StmtExpr mutator class to replace allocate nodes
* with offsets within memory pools
*
* This mutator class with add Pool variables recursively to every PrimFunc
* starting from the main PrimFunc. For all allocate nodes, that have been
* memory planned, will be mutated into an offset using a Let binding.
*/
class PoolAllocationToOffsetConverter : public StmtExprMutator {
public:
explicit PoolAllocationToOffsetConverter(const IRModule& module,
const Map<tir::Stmt, PoolAllocation>& pool_allocations,
bool emit_tvmscript_printable = false)
PoolAllocationToOffsetConverter(const IRModule& module,
const Map<tir::Stmt, PoolAllocation>& pool_allocations,
bool emit_tvmscript_printable = false)
: pool_allocations_(pool_allocations), emit_tvmscript_printable_(emit_tvmscript_printable) {
module_ = module->ShallowCopy();
for (const auto& gv_func : module_->functions) {
function_global_vars_.Set(gv_func.first->name_hint, gv_func.first);
}
for (const auto& kv : pool_allocations) {
// TODO(@manupa-arm): add AllocateConstNode when it is available
ICHECK(kv.first->IsInstance<AllocateNode>());
Expand Down Expand Up @@ -135,10 +140,6 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
std::vector<AllocatedPoolInfo> allocated_pool_ordering_;
/*! \brief The storage of calculated pool size at init */
std::unordered_map<PoolInfo, int, ObjectPtrHash, ObjectPtrEqual> all_pools_sizes_;
/*! \brief The AoT codegen uses extern_calls due to some functions not being exposed in the TIR
* IRModule This maps maintains the map of which to each function
*/
Map<String, GlobalVar> function_global_vars_;
/*! \brief After mutation, each allocate buffer is replaced with tir::Var that is let bounded
* to position from a pool as designated by a PoolAllocation
*/
Expand Down Expand Up @@ -240,8 +241,8 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const CallNode* op) {
if (op->op.same_as(builtin::call_extern()) || op->op.same_as(builtin::tvm_call_cpacked())) {
String func_name = Downcast<StringImm>(op->args[0])->value;
Array<PrimExpr> new_args;
if (function_global_vars_.find(func_name) != function_global_vars_.end()) {
GlobalVar gv = function_global_vars_.at(func_name);
if (module_->ContainGlobalVar(func_name)) {
GlobalVar gv = module_->GetGlobalVar(func_name);
PrimFunc func = Downcast<PrimFunc>(module_->Lookup(gv));
PrimFunc prim_func = CreatePrimFuncWithPoolParams(func);
module_->Update(gv, prim_func);
Expand Down Expand Up @@ -304,7 +305,7 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const LoadNode* op) {
}

IRModule PoolAllocationToOffsetConverter::operator()() {
GlobalVar gv = function_global_vars_.at(::tvm::runtime::symbol::tvm_run_func_suffix);
GlobalVar gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix);
PrimFunc main_func = Downcast<PrimFunc>(module_->Lookup(gv));
ScopeInfo si = UpdateFunctionScopeInfo(main_func);
this->scope_stack.push(si);
Expand Down

0 comments on commit 977608e

Please sign in to comment.