Skip to content

Commit

Permalink
[TIR][USMP] adding the pass to convert to pool offsets
Browse files Browse the repository at this point in the history
This commit adds a transform pass that consumes
the planned pool allocations using memory planning algorithm
that convertes them to pool offsets.

* adds two test cases for a linear structure with two pools
* adds test case with a single pool for residual structures

Change-Id: I9d31e854461b5c21df72d1452120d286b96791c0
  • Loading branch information
manupak committed Nov 1, 2021
1 parent ceb9abe commit 93a81d8
Show file tree
Hide file tree
Showing 12 changed files with 921 additions and 18 deletions.
7 changes: 7 additions & 0 deletions include/tvm/tir/usmp/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,13 @@ Array<BufferInfo> CreateArrayBufferInfo(const Map<Stmt, BufferInfo>& buffer_info
*/
static constexpr const char* kPoolCandidatesAllocateAttr = "candidate_memory_pools";

/*!
* \brief Calculate the size of the extents in bytes
*
* \param op the allocate node
*/
Integer CalculateExtentsSize(const AllocateNode* op);

} // namespace usmp
} // namespace tir
} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""TVMScript for TIR"""

# Type system
from .ty import int8, int16, int32, int64, float16, float32, float64
from .ty import uint8, int8, int16, int32, int64, float16, float32, float64
from .ty import boolean, handle, Ptr, Tuple

from .prim_func import prim_func
1 change: 1 addition & 0 deletions python/tvm/script/tir/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __getitem__(self, vtypes):
return ConcreteType(tvm.ir.TupleType([vtype.evaluate() for vtype in vtypes]))


uint8 = ConcreteType("uint8")
int8 = ConcreteType("int8")
int16 = ConcreteType("int16")
int32 = ConcreteType("int32")
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/usmp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
"""Namespace for Unified Static Memory Planner"""

from . import analysis
from . import transform
from .utils import BufferInfo
20 changes: 20 additions & 0 deletions python/tvm/tir/usmp/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-import, redefined-builtin
"""Namespace for Unified Static Memory Planner"""

from .transform import convert_pool_allocations_to_offsets
21 changes: 21 additions & 0 deletions python/tvm/tir/usmp/transform/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.tir.usmp.analysis"""
import tvm._ffi


tvm._ffi._init_api("tir.usmp.transform", __name__)
40 changes: 40 additions & 0 deletions python/tvm/tir/usmp/transform/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""USMP Transform Python API for passes"""
# pylint: disable=invalid-name

from typing import Dict

from . import _ffi_api
from ....tir import Stmt
from ..utils import PoolAllocation


def convert_pool_allocations_to_offsets(pool_allocations: Dict[Stmt, PoolAllocation]):
"""Convert pool allocations to Load nodes with offsets from pools.
Parameters
----------
pool_allocations : Dict[Stmt, PoolAllocation]
Allocate or AllocateConst node to pool allocation mapping
Returns
-------
ret: tvm.transform.Pass
The registered pass that converts the allocations to offsets.
"""
return _ffi_api.ConvertPoolAllocationsToOffsets(pool_allocations)
7 changes: 4 additions & 3 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,10 +417,11 @@ class TextPrinter {

Doc PrintFinal(const ObjectRef& node) {
Doc doc;
if (node->IsInstance<IRModuleNode>()) {
if (node.defined() && node->IsInstance<IRModuleNode>()) {
doc << PrintMod(Downcast<IRModule>(node));
} else if (node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() ||
node->IsInstance<tir::StmtNode>()) {
} else if (node.defined() &&
(node->IsInstance<tir::PrimFuncNode>() || node->IsInstance<PrimExprNode>() ||
node->IsInstance<tir::StmtNode>())) {
doc << tir_text_printer_.Print(node);
} else {
doc << relay_text_printer_.PrintFinal(node);
Expand Down
14 changes: 0 additions & 14 deletions src/tir/usmp/analysis/extract_buffer_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,6 @@ void BufferInfoExtractor::VisitStmt(const Stmt& n) {
StmtExprVisitor::VisitStmt(n);
}

static Integer CalculateExtentsSize(const AllocateNode* op) {
size_t element_size_bytes = op->dtype.bytes();
size_t num_elements = 1;
for (const auto& ext : op->extents) {
if (ext->IsInstance<IntImmNode>()) {
num_elements *= Downcast<IntImm>(ext)->value;
} else {
// We can't statically calculate workspace for dynamic shapes
return Integer();
}
}
return Integer(num_elements * element_size_bytes);
}

void BufferInfoExtractor::VisitStmt_(const AllocateNode* op) {
const auto& currect_scope_info = scope_stack_.top();
const auto& type = Downcast<PointerType>(op->buffer_var->type_annotation);
Expand Down
Loading

0 comments on commit 93a81d8

Please sign in to comment.