Skip to content

Commit

Permalink
[miroNPU] Add support for TFLite concatenate
Browse files Browse the repository at this point in the history
* Add legalization pass and is_valid checks for concatenate
* Add TIR pass for removing concatenates and replacing them with direct
  writes to the final buffer
* Add tests

Co-authored-by: Matthew Barrett <[email protected]>
  • Loading branch information
ekalda and mbaret committed Nov 25, 2021
1 parent 99b9d42 commit 5c3d5c3
Show file tree
Hide file tree
Showing 16 changed files with 563 additions and 30 deletions.
43 changes: 43 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,48 @@ def __call__(self, *args, **kwargs):
pass


class ConcatRewriter(DFPatternCallback):
"""The newer versions of TFLite converters return a concatenate operator that concatenates
tensors with same QNN params (if the QNN params of tensors were initially different,
the converter adds a requantize node), so this rewriter replaces the QNN concatenate with
"normal" concatenate"""

def __init__(self):
super().__init__(require_type=True, rewrite_once=True)
self.pattern = (
wildcard().has_attr({"Composite": ethosu_patterns.ConcatParams.composite_name})
)(None)

def callback(
self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
) -> tvm.relay.Expr:
# Find the tensors that are inputs to the concat and the scales and zero points
concat_args = list()
for arg in post.args:
if isinstance(arg, tvm.relay.expr.Call):
concat_args.append(arg)

axis = post.op.body.attrs.axis
concat = relay.op.concatenate(relay.Tuple(concat_args), axis=axis)
return concat


@ir.transform.module_pass(opt_level=1)
class LegalizeConcat:
"""This is the pass that wraps ConcatRewriter"""

def transform_module(
self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext
) -> tvm.ir.IRModule:
for global_var, func in mod.functions.items():
func = rewrite(ConcatRewriter(), func)
mod.update_func(global_var, func)
return mod

def __call__(self, *args, **kwargs):
pass


@ir.transform.module_pass(opt_level=1)
class LegalizeEthosU:
"""This is the pass to call graph-rewrites to perform graph transformation
Expand All @@ -915,6 +957,7 @@ def transform_module(
mod = LegalizeMax()(mod)
mod = LegalizeShl()(mod)
mod = LegalizeAbs()(mod)
mod = LegalizeConcat()(mod)
mod = LegalizeReshape()(mod)
mod = LegalizeStridedSlice()(mod)
mod = LegalizeNoOps()(mod)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def get_binary_elementwise_params(
replace_pointer : tvm.tir.Var
The output pointer of the DMA write operation, which is to replace
the binary elementwise output pointer.
is_allocator : bool
Whether this operator allocates its output.
"""
attrs, body = get_op_attrs(stmt)
reversed_operands = attrs["reversed_operands"]
Expand All @@ -83,7 +86,7 @@ def get_binary_elementwise_params(
# Get feature map info
serial_ifm, _ = get_ifm_params(input_pointer, producers)
serial_ifm2, _ = get_ifm_params(input_pointer1, producers)
serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
# Get activation info
serial_activation = SerialActivation(
op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
Expand All @@ -100,4 +103,5 @@ def get_binary_elementwise_params(
),
output_pointer,
replace_pointer,
is_allocator,
)
10 changes: 7 additions & 3 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm.relay.expr_functor import ExprMutator
from tvm.driver.build_module import schedule_to_module

from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants
from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants, RemoveConcatenates
from .scheduler import schedule


Expand Down Expand Up @@ -76,6 +76,7 @@ def lower_ethosu(sch, args, const_dict, name="main"):
mod = schedule_to_module(sch, args, name)

mod = tvm.tir.transform.Simplify()(mod)
mod = RemoveConcatenates()(mod)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.UnrollLoop()(mod)
mod = tvm.tir.transform.Simplify()(mod)
Expand Down Expand Up @@ -117,19 +118,22 @@ class ExtractConstants(ExprMutator):
def __init__(self):
super().__init__()
self.constants = []
self.const_vars = []

def visit_constant(self, const):
if isinstance(const.checked_type, relay.ty.TensorType):
if const.checked_type.concrete_shape != ():
self.constants.append(const.data.asnumpy())
name = "p" + str(len(self.constants))
return relay.var(type_annotation=const.checked_type, name_hint=name)
var = relay.var(type_annotation=const.checked_type, name_hint=name)
self.const_vars.append(var)
return var

return const

def visit_function(self, fn):
new_body = self.visit(fn.body)
new_params = list(relay.analysis.free_vars(new_body))
new_params = list(fn.params) + self.const_vars
return relay.Function(new_params, new_body)

def extract_constants(self, func):
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def get_conv2d_params(stmt, producers, consumers):
replace_pointer : tvm.tir.Var
The output pointer of the DMA write operation, which is to replace
the convolution output pointer.
is_allocator : bool
Whether this operator allocates its output.
"""
attrs, body = get_op_attrs(stmt)
Expand All @@ -61,7 +63,7 @@ def get_conv2d_params(stmt, producers, consumers):
output_pointer = stores[0].buffer_var
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
# Get kernel info
serial_kernel = SerialKernel(
width=int(rw.extent),
Expand Down Expand Up @@ -104,4 +106,5 @@ def get_conv2d_params(stmt, producers, consumers):
),
output_pointer,
replace_pointer,
is_allocator,
)
5 changes: 4 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def get_depthwise_conv2d_params(
replace_pointer : tvm.tir.Var
The output pointer of the DMA write operation, which is to replace
the convolution output pointer.
is_allocator : bool
Whether this operator allocates its output.
"""
attrs, body = get_op_attrs(stmt)
Expand All @@ -70,7 +72,7 @@ def get_depthwise_conv2d_params(
output_pointer = stores[0].buffer_var
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers)
serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers)
# Get kernel info
serial_kernel = SerialKernel(
width=int(rw.extent),
Expand Down Expand Up @@ -114,4 +116,5 @@ def get_depthwise_conv2d_params(
),
output_pointer,
replace_pointer,
is_allocator,
)
14 changes: 12 additions & 2 deletions python/tvm/relay/backend/contrib/ethosu/tir/dma.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def get_ifm_params(pointer, producers):
return serial_ifm, serial_padding


def get_ofm_params(pointer, consumers):
def get_ofm_params(pointer, consumers, producers):
"""Get the parameters associated with the DMA capabilities for an OFM.
Parameters
Expand All @@ -282,18 +282,28 @@ def get_ofm_params(pointer, consumers):
consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt
A dictionary to associate pointers with the loop nest
that consumes their values.
producers : dict of tvm.tir.Var to tvm.tir.AttrStmt
A dictionary to associate pointers with the loop nest
that produces their values.
Returns
-------
serial_ifm : SerialFeatureMap
The serializable OFM.
output_pointer : tvm.tir.Var
The pointer that the OFM DMA pipeline produces.
is_allocator : bool
Whether this operator allocates its output.
"""
convert_to_nhcwb16 = consumers[pointer]
out_channels, _, output_pointer = get_convert_to_nhcwb16_params(convert_to_nhcwb16)
write = consumers[output_pointer]
serial_ofm, _, output_pointer = get_write_params(write)
is_allocator = True
if output_pointer not in producers:
is_allocator = False
elif producers[output_pointer] != write:
is_allocator = False
serial_ofm.channels = out_channels
return serial_ofm, output_pointer
return serial_ofm, output_pointer, is_allocator
17 changes: 14 additions & 3 deletions python/tvm/relay/backend/contrib/ethosu/tir/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur
assert loops[0].extent == 1
loops = loops[1:]

fm_inner = inner.value if fm_type == "ifm" else inner

stride_vars = [l.loop_var for l in loops]
strides = get_strides(inner.value.index, stride_vars)
strides = get_strides(fm_inner.index, stride_vars)

base_address = get_base_address(inner.value.index)
base_address = get_base_address(fm_inner.index)
data_type = inner.buffer_var.type_annotation.element_type.dtype
pointer = inner.value.buffer_var if fm_type == "ifm" else inner.buffer_var
pointer = fm_inner.buffer_var

serial_feature_map = SerialFeatureMap(
data_type=data_type,
Expand Down Expand Up @@ -116,6 +118,8 @@ def get_identity_params(
replace_pointer : tvm.tir.Var
The output pointer of the DMA write operation, which is to replace
the pooling output pointer.
is_allocator : bool
Whether this operator allocates its output.
"""
attrs, _ = get_op_attrs(stmt)
Expand All @@ -134,6 +138,12 @@ def get_identity_params(

replace_pointer = write_output_pointer

is_allocator = True
if write_output_pointer not in producers:
is_allocator = False
elif producers[write_output_pointer] != write:
is_allocator = False

# TODO: We might want to support stand alone ReLU in the future by adding clip_min and
# clip max attributes to the identity operator
serial_activation = SerialActivation(op=attrs["activation"], clip_min=0, clip_max=0)
Expand All @@ -152,4 +162,5 @@ def get_identity_params(
),
output_pointer,
replace_pointer,
is_allocator,
)
Loading

0 comments on commit 5c3d5c3

Please sign in to comment.