Skip to content

Commit

Permalink
[TIR][USMP] adding the pass to convert to pool offsets
Browse files Browse the repository at this point in the history
* rebase changes
* making imports absolute
* fixing typos and removing unnecesary lines

Change-Id: I4c94b9955b001513fecb39ca94f81b1ad99c7bfc
  • Loading branch information
manupak committed Dec 8, 2021
1 parent cd33589 commit 52a4f12
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 10 deletions.
7 changes: 4 additions & 3 deletions python/tvm/tir/usmp/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,15 @@

from typing import Dict

import tvm
from tvm.tir import Stmt
from tvm.tir.usmp.utils import PoolAllocation
from . import _ffi_api
from ....tir import Stmt
from ..utils import PoolAllocation


def convert_pool_allocations_to_offsets(
pool_allocations: Dict[Stmt, PoolAllocation], emit_tvmscript_printable: bool = False
):
) -> tvm.transform.Pass:
"""Convert pool allocations to Load nodes with offsets from pools.
Parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace usmp {
* \brief The StmtExpr mutator class to replace allocate nodes
* with offsets within memory pools
*
* This mutator class with add Pool variables recursively to every PrimFunc
* This mutator class will add Pool variables recursively to every PrimFunc
* starting from the main PrimFunc. For all allocate nodes, that have been
* memory planned, will be mutated into an offset using a Let binding.
*/
Expand Down Expand Up @@ -88,7 +88,6 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator {
private:
PrimExpr VisitExpr_(const CallNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override;
// PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const LoadNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override;

Expand Down Expand Up @@ -270,7 +269,6 @@ Stmt PoolAllocationToOffsetConverter::VisitStmt_(const AllocateNode* op) {
PoolAllocation pool_allocation = pool_allocations_[GetRef<Allocate>(op)];
Var param = scope_info.pools_to_params[pool_allocation->pool_info];
Buffer buffer_var = scope_info.buffer_map[param];
ICHECK(pool_allocation->byte_offset < all_pools_sizes_[pool_allocation->pool_info]);
Load load_node =
Load(DataType::UInt(8), buffer_var->data, pool_allocation->byte_offset, op->condition);
Call address_of_load = Call(DataType::Handle(8), builtin::address_of(), {load_node});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,15 @@ def test_mobilenet_subgraph():
tir_mod, [fast_memory_pool, slow_memory_pool]
)
main_func = tir_mod["run_model"]
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_map = buffer_analysis.buffer_info_stmts

fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
buffer_info_arr = fcreate_array_bi(buffer_info_map)
fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr)
buffer_pool_allocations = fusmp_algo_greedy_by_size(
buffer_info_arr, buffer_analysis.memory_pressure
)
fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
Expand Down Expand Up @@ -489,12 +492,15 @@ def test_resnet_subgraph():
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
main_func = tir_mod["run_model"]
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
buffer_info_map = buffer_analysis.buffer_info_stmts

fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
buffer_info_arr = fcreate_array_bi(buffer_info_map)
fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr)
buffer_pool_allocations = fusmp_algo_greedy_by_size(
buffer_info_arr, buffer_analysis.memory_pressure
)
fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations")
pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations)
tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets(
Expand Down

0 comments on commit 52a4f12

Please sign in to comment.