Skip to content

Commit

Permalink
[TIR] [Hexagon] Add get_vtcm_allocation_sizes with lowering (#14720)
Browse files Browse the repository at this point in the history
This patch adds a utility function for getting the VTCM sizes allocated
in an IRModule. In order to do that, we've exposed the list of lowering
passes to Python and refactored PostprocVerifyVTCMLimit so that it is
computed for the whole module using the same list of lowering passes.
  • Loading branch information
quic-sanirudh authored Apr 26, 2023
1 parent 6314b25 commit 486c498
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 39 deletions.
16 changes: 16 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,22 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func);
*/
TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints);

/*!
 * \brief Utility function to get the list of lowering passes to be applied to calculate the
 * compacted VTCM allocation size.
 * \return The list of passes.
 */
TVM_DLL Array<tvm::transform::Pass> GetVTCMCompactionPasses();

/*!
 * \brief Verifies that the VTCM usage of every function in the given IRModule is within the
 * provided limit. Each function's usage is compared against the limit individually.
 * \param mod The module to be checked.
 * \param limit The limit to check. A limit that is not positive never causes a failure.
 * \return true if the VTCM usage is within the provided limit.
 */
TVM_DLL bool VerifyVTCMLimit(const IRModule& mod, Integer limit);

/*!
* \brief Verifies that the VTCM usage of the given prim_func is within the provided limit.
* \param func The function to be checked.
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# pylint: disable=invalid-name
from typing import Dict, List, Union

import tvm
from tvm import Object
from tvm.ir import IRModule
from tvm.tir.expr import Var
Expand Down Expand Up @@ -384,3 +385,15 @@ def find_anchor_block(mod: IRModule) -> Block:
The anchor block if found, None otherwise.
"""
return _ffi_api.find_anchor_block(mod) # type: ignore # pylint: disable=no-member


def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]:
    """Get the list of lowering passes that are applied in order to calculate
    the compacted VTCM allocation size.

    Returns
    -------
    result : List[tvm.transform.Pass]
        The list of passes
    """
    compaction_passes = _ffi_api.get_vtcm_compaction_passes()  # type: ignore # pylint: disable=no-member
    return compaction_passes
52 changes: 49 additions & 3 deletions python/tvm/topi/hexagon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
"""Common hexagon specific utilities"""
import math
import struct
from typing import Tuple
from tvm import te
from tvm.tir import IndexMap
from typing import Dict, Tuple, Union

import tvm
from tvm import IRModule, te
from tvm.tir import IndexMap, PrimFunc


def n11c_1024c_2d(n, h, w, c):
Expand Down Expand Up @@ -354,3 +356,47 @@ def within_range(val, dtype):
def saturate(x: te.Tensor, dtype: str):
"""Saturate value for the specified data type"""
return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype)))


def get_vtcm_allocation_sizes(
    func_or_mod: Union[PrimFunc, IRModule], compacted=True
) -> Dict[str, int]:
    """Report how much VTCM ("global.vtcm" scope) each function allocates.

    Parameters
    ----------
    func_or_mod : Union[PrimFunc, IRModule]
        The PrimFunc or IRModule to inspect. A PrimFunc is first wrapped into
        an IRModule with ``IRModule.from_expr``.

    compacted : bool
        When true, the passes returned by
        ``tvm.tir.analysis.get_vtcm_compaction_passes`` are run on the module
        before measuring, so the reported sizes are those seen after buffer
        compaction. When false, the module is measured as given.

    Returns
    -------
    result : Dict[str, int]
        One entry per function reported by
        ``tvm.tir.analysis.calculate_allocated_bytes``, keyed by function name.
        The value is that function's "global.vtcm" size, or 0 when the function
        has no "global.vtcm" entry.

    Raises
    ------
    TypeError
        If ``func_or_mod`` is neither a PrimFunc nor an IRModule.
    """
    if isinstance(func_or_mod, PrimFunc):
        module = IRModule.from_expr(func_or_mod)
    elif isinstance(func_or_mod, IRModule):
        module = func_or_mod
    else:
        raise TypeError(
            f"Expected argument to be PrimFunc or IRModule, but received {type(func_or_mod)}"
        )

    if compacted:
        lowering = tvm.transform.Sequential(list(tvm.tir.analysis.get_vtcm_compaction_passes()))
        module = lowering(module)

    vtcm_scope = "global.vtcm"
    per_function = tvm.tir.analysis.calculate_allocated_bytes(module)
    return {
        name: (scopes[vtcm_scope] if vtcm_scope in scopes else 0)
        for name, scopes in per_function.items()
    }
44 changes: 8 additions & 36 deletions src/meta_schedule/postproc/verify_vtcm_limit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,48 +36,20 @@ class VerifyVTCMLimitNode : public PostprocNode {
}

bool Verify(const IRModule& mod) const {
for (const auto& kv : mod->functions) {
if (auto prim_func = kv.second.as<tir::PrimFunc>()) {
if (!tir::VerifyVTCMLimit(prim_func.value(), vtcm_capacity)) {
return false;
}
}
if (!tir::VerifyVTCMLimit(mod, vtcm_capacity)) {
return false;
}
return true;
}

bool Apply(const tir::Schedule& sch) final {
IRModule mod = sch->mod();
for (const auto& kv : mod->functions) {
const GlobalVar& g_var = kv.first;
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<tir::PrimFuncNode>()) {
IRModule lowered{nullptr};
try {
auto pass_list = Array<tvm::transform::Pass>();
pass_list.push_back(tir::transform::LowerInitBlock());
pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation());
pass_list.push_back(tir::transform::ConvertBlocksToOpaque());
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::VectorizeLoop(true));
pass_list.push_back(tir::transform::StorageRewrite());
transform::PassContext pass_ctx = transform::PassContext::Current();
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
} catch (const dmlc::Error& e) {
return false;
}
if (!Verify(lowered)) {
return false;
}
}
IRModule lowered{nullptr};
auto pass_list = tir::GetVTCMCompactionPasses();
transform::PassContext pass_ctx = transform::PassContext::Current();
lowered = tvm::transform::Sequential(pass_list)(std::move(mod));
if (!Verify(lowered)) {
return false;
}
return true;
}
Expand Down
33 changes: 33 additions & 0 deletions src/tir/analysis/calculate_allocated_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/usmp/utils.h>

#include <algorithm>
Expand Down Expand Up @@ -109,6 +110,18 @@ TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes")
}
});

/*!
 * \brief Check the "global.vtcm" allocation of each function in the module against a limit.
 * \param mod The module whose functions are measured via CalculateAllocatedBytes.
 * \param limit The limit to compare against; a limit that is not positive never fails.
 * \return false as soon as one function exceeds a positive limit, true otherwise.
 */
bool VerifyVTCMLimit(const IRModule& mod, Integer limit) {
  const auto bytes_per_func = CalculateAllocatedBytes(mod);
  for (const auto& entry : bytes_per_func) {
    // Functions without a "global.vtcm" entry count as using zero bytes.
    const auto vtcm_bytes = entry.second.Get("global.vtcm").value_or(0);
    if (limit.IntValue() <= 0) {
      // Limit is not positive: nothing to enforce for this function.
      continue;
    }
    if (vtcm_bytes.IntValue() > limit.IntValue()) {
      return false;
    }
  }
  return true;
}

bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) {
auto sizes = CalculateAllocatedBytes(func)["main"];
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
Expand All @@ -127,6 +140,26 @@ int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) {
return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", Integer(0)).value()->value;
}

/*!
 * \brief Build the list of lowering passes used to calculate the compacted VTCM allocation size.
 * \return The passes, in the order they are listed below.
 */
Array<tvm::transform::Pass> GetVTCMCompactionPasses() {
  return Array<tvm::transform::Pass>{
      tir::transform::LowerInitBlock(),
      tir::transform::PlanAndUpdateBufferAllocationLocation(),
      tir::transform::ConvertBlocksToOpaque(),
      tir::transform::CompactBufferAllocation(),
      tir::transform::LowerMatchBuffer(),
      tir::transform::InjectSoftwarePipeline(),
      tir::transform::LowerOpaqueBlock(),
      tir::transform::FlattenBuffer(),
      tir::transform::Simplify(),
      tir::transform::VectorizeLoop(true),
      tir::transform::StorageRewrite(),
  };
}

// Expose GetVTCMCompactionPasses through the global function registry under
// the name "tir.analysis.get_vtcm_compaction_passes".
TVM_REGISTER_GLOBAL("tir.analysis.get_vtcm_compaction_passes")
    .set_body_typed(GetVTCMCompactionPasses);

namespace transform {

Pass VerifyVTCMLimit(Optional<Target> default_target) {
Expand Down

0 comments on commit 486c498

Please sign in to comment.