From 5c3d5c34b006fb3985530936aa149850b029fbe1 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Fri, 24 Sep 2021 11:38:13 +0100 Subject: [PATCH] [microNPU] Add support for TFLite concatenate * Add legalization pass and is_valid checks for concatenate * Add TIR pass for removing concatenates and replacing them with direct writes to the final buffer * Add tests Co-authored-by: Matthew Barrett --- .../relay/backend/contrib/ethosu/legalize.py | 43 ++++ .../contrib/ethosu/tir/binary_elementwise.py | 6 +- .../backend/contrib/ethosu/tir/compiler.py | 10 +- .../backend/contrib/ethosu/tir/convolution.py | 5 +- .../backend/contrib/ethosu/tir/depthwise.py | 5 +- .../relay/backend/contrib/ethosu/tir/dma.py | 14 +- .../backend/contrib/ethosu/tir/identity.py | 17 +- .../backend/contrib/ethosu/tir/passes.py | 196 ++++++++++++++++-- .../backend/contrib/ethosu/tir/pooling.py | 3 +- .../backend/contrib/ethosu/tir/scheduler.py | 2 + .../backend/contrib/ethosu/tir/transform.py | 1 + .../contrib/ethosu/tir/unary_elementwise.py | 3 +- python/tvm/relay/op/contrib/ethosu.py | 55 ++++- .../contrib/test_ethosu/test_codegen.py | 73 +++++++ .../contrib/test_ethosu/test_legalize.py | 80 +++++++ .../test_ethosu/test_remove_concatenates.py | 80 +++++++ 16 files changed, 563 insertions(+), 30 deletions(-) create mode 100644 tests/python/contrib/test_ethosu/test_remove_concatenates.py diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 8f2dddbf88a6e..6c7a7d8b51e75 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -890,6 +890,48 @@ def __call__(self, *args, **kwargs): pass +class ConcatRewriter(DFPatternCallback): + """Newer versions of the TFLite converter return a concatenate operator that concatenates + tensors with the same QNN params (if the QNN params of the tensors were initially different, + the converter adds a requantize node), so this rewriter replaces the QNN concatenate with + a "normal" concatenate.""" + + def __init__(self): + super().__init__(require_type=True, rewrite_once=True) + self.pattern = ( + wildcard().has_attr({"Composite": ethosu_patterns.ConcatParams.composite_name}) + )(None) + + def callback( + self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map ) -> tvm.relay.Expr: + # Find the tensors that are inputs to the concat, discarding the scales and + # zero points since they are identical across the inputs + concat_args = list() + for arg in post.args: + if isinstance(arg, tvm.relay.expr.Call): + concat_args.append(arg) + + axis = post.op.body.attrs.axis + concat = relay.op.concatenate(relay.Tuple(concat_args), axis=axis) + return concat + + +@ir.transform.module_pass(opt_level=1) +class LegalizeConcat: + """This is the pass that wraps ConcatRewriter""" + + def transform_module( + self, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.ir.IRModule: + for global_var, func in mod.functions.items(): + func = rewrite(ConcatRewriter(), func) + mod.update_func(global_var, func) + return mod + + def __call__(self, *args, **kwargs): + pass + + @ir.transform.module_pass(opt_level=1) class LegalizeEthosU: """This is the pass to call graph-rewrites to perform graph transformation @@ -915,6 +957,7 @@ def transform_module( mod = LegalizeMax()(mod) mod = LegalizeShl()(mod) mod = LegalizeAbs()(mod) + mod = LegalizeConcat()(mod) mod = LegalizeReshape()(mod) mod = LegalizeStridedSlice()(mod) mod = LegalizeNoOps()(mod) diff --git
a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py index 31d448e5cd7d9..53b46aeafbf53 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/binary_elementwise.py @@ -68,6 +68,9 @@ def get_binary_elementwise_params( replace_pointer : tvm.tir.Var The output pointer of the DMA write operation, which is to replace the binary elementwise output pointer. + is_allocator : bool + Whether this operator allocates its output. + """ attrs, body = get_op_attrs(stmt) reversed_operands = attrs["reversed_operands"] @@ -83,7 +86,7 @@ def get_binary_elementwise_params( # Get feature map info serial_ifm, _ = get_ifm_params(input_pointer, producers) serial_ifm2, _ = get_ifm_params(input_pointer1, producers) - serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) # Get activation info serial_activation = SerialActivation( op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] @@ -100,4 +103,5 @@ def get_binary_elementwise_params( ), output_pointer, replace_pointer, + is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py index b68a5ad14a6f7..e9fcf4927f1ed 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -21,7 +21,7 @@ from tvm.relay.expr_functor import ExprMutator from tvm.driver.build_module import schedule_to_module -from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants +from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants, RemoveConcatenates from .scheduler import schedule @@ -76,6 +76,7 @@ def lower_ethosu(sch, args, const_dict, name="main"): mod = schedule_to_module(sch, args, name) mod = tvm.tir.transform.Simplify()(mod) + mod = RemoveConcatenates()(mod) mod = tvm.tir.transform.StorageFlatten(64)(mod) mod = tvm.tir.transform.UnrollLoop()(mod) mod = tvm.tir.transform.Simplify()(mod) @@ -117,19 +118,22 @@ class ExtractConstants(ExprMutator): def __init__(self): super().__init__() self.constants = [] + self.const_vars = [] def visit_constant(self, const): if isinstance(const.checked_type, relay.ty.TensorType): if const.checked_type.concrete_shape != (): self.constants.append(const.data.asnumpy()) name = "p" + str(len(self.constants)) - return relay.var(type_annotation=const.checked_type, name_hint=name) + var = relay.var(type_annotation=const.checked_type, name_hint=name) + self.const_vars.append(var) + return var return const def visit_function(self, fn): new_body = self.visit(fn.body) - new_params = list(relay.analysis.free_vars(new_body)) + new_params = list(fn.params) + self.const_vars return relay.Function(new_params, new_body) def extract_constants(self, func): diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 5e8ea002783f7..b783486271fce 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -46,6 +46,8 @@ def get_conv2d_params(stmt, producers, consumers): replace_pointer : tvm.tir.Var The output pointer of the DMA write operation, which is to replace the convolution output pointer. 
+ is_allocator : bool + Whether this operator allocates its output. """ attrs, body = get_op_attrs(stmt) @@ -61,7 +63,7 @@ def get_conv2d_params(stmt, producers, consumers): output_pointer = stores[0].buffer_var # Get feature map info serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) - serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) # Get kernel info serial_kernel = SerialKernel( width=int(rw.extent), @@ -104,4 +106,5 @@ def get_conv2d_params(stmt, producers, consumers): ), output_pointer, replace_pointer, + is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py index 9db94b7be76f9..b1a4ebd82a880 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/depthwise.py @@ -56,6 +56,8 @@ def get_depthwise_conv2d_params( replace_pointer : tvm.tir.Var The output pointer of the DMA write operation, which is to replace the convolution output pointer. + is_allocator : bool + Whether this operator allocates its output. """ attrs, body = get_op_attrs(stmt) @@ -70,7 +72,7 @@ def get_depthwise_conv2d_params( output_pointer = stores[0].buffer_var # Get feature map info serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) - serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) # Get kernel info serial_kernel = SerialKernel( width=int(rw.extent), @@ -114,4 +116,5 @@ def get_depthwise_conv2d_params( ), output_pointer, replace_pointer, + is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py index 46df20814eb5b..7670c5d2f7b6d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -272,7 +272,7 @@ def get_ifm_params(pointer, producers): return serial_ifm, serial_padding -def get_ofm_params(pointer, consumers): +def get_ofm_params(pointer, consumers, producers): """Get the parameters associated with the DMA capabilities for an OFM. Parameters @@ -282,6 +282,9 @@ def get_ofm_params(pointer, consumers): consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt A dictionary to associate pointers with the loop nest that consumes their values. + producers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that produces their values. Returns ------- @@ -289,11 +292,18 @@ def get_ofm_params(pointer, consumers): The serializable OFM. output_pointer : tvm.tir.Var The pointer that the OFM DMA pipeline produces. + is_allocator : bool + Whether this operator allocates its output. 
""" convert_to_nhcwb16 = consumers[pointer] out_channels, _, output_pointer = get_convert_to_nhcwb16_params(convert_to_nhcwb16) write = consumers[output_pointer] serial_ofm, _, output_pointer = get_write_params(write) + is_allocator = True + if output_pointer not in producers: + is_allocator = False + elif producers[output_pointer] != write: + is_allocator = False serial_ofm.channels = out_channels - return serial_ofm, output_pointer + return serial_ofm, output_pointer, is_allocator diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py index 7a81a702f0196..e9ce7f0e27f63 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/identity.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/identity.py @@ -57,12 +57,14 @@ def _get_feature_map(stmt: tvm.tir.AttrStmt, fm_type: str) -> Tuple[SerialFeatur assert loops[0].extent == 1 loops = loops[1:] + fm_inner = inner.value if fm_type == "ifm" else inner + stride_vars = [l.loop_var for l in loops] - strides = get_strides(inner.value.index, stride_vars) + strides = get_strides(fm_inner.index, stride_vars) - base_address = get_base_address(inner.value.index) + base_address = get_base_address(fm_inner.index) data_type = inner.buffer_var.type_annotation.element_type.dtype - pointer = inner.value.buffer_var if fm_type == "ifm" else inner.buffer_var + pointer = fm_inner.buffer_var serial_feature_map = SerialFeatureMap( data_type=data_type, @@ -116,6 +118,8 @@ def get_identity_params( replace_pointer : tvm.tir.Var The output pointer of the DMA write operation, which is to replace the pooling output pointer. + is_allocator : bool + Whether this operator allocates its output. """ attrs, _ = get_op_attrs(stmt) @@ -134,6 +138,12 @@ def get_identity_params( replace_pointer = write_output_pointer + is_allocator = True + if write_output_pointer not in producers: + is_allocator = False + elif producers[write_output_pointer] != write: + is_allocator = False + # TODO: We might want to support stand alone ReLU in the future by adding clip_min and # clip max attributes to the identity operator serial_activation = SerialActivation(op=attrs["activation"], clip_min=0, clip_max=0) @@ -152,4 +162,5 @@ def get_identity_params( ), output_pointer, replace_pointer, + is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index cb46ba319edd3..c7309a9fd98a8 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument +# pylint: disable=invalid-name, unused-argument, no-else-return, inconsistent-return-statements """The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler.""" +from collections import namedtuple import numpy as np # type: ignore import tvm @@ -68,6 +69,8 @@ def ReplaceOperators(): replace_output_pointer = {} pointer_to_extents = {} + ReplaceInfo = namedtuple("ReplaceInfo", ["pointer", "reallocate"]) + def _resolve_pointers(stmt): """This pass determines information about the pointers present in the IR. 
In particular, it associates pointers with both the operations that @@ -116,11 +119,13 @@ def _replace_operator(stmt): if stmt.attr_key == "pragma_op" and op_name in op_map: # Get the parameters for the extern call param_func = op_map[op_name] - info, output_pointer, replace_pointer = param_func( + info, output_pointer, replace_pointer, is_allocator = param_func( stmt, pointer_to_producer, pointer_to_consumer ) if replace_pointer is not None: - replace_output_pointer[output_pointer] = replace_pointer + replace_output_pointer[output_pointer] = ReplaceInfo( + replace_pointer, is_allocator + ) # Make the extern call irb = tvm.tir.ir_builder.create() irb.emit(tvm.tir.call_extern("handle", op_name, *info)) @@ -164,25 +169,17 @@ def _replace_pointers(stmt): if isinstance(stmt, tvm.tir.AttrStmt): # If the attribute references a pointer that needs replacing if stmt.node in replace_output_pointer: - replace_pointer = replace_output_pointer[stmt.node] - # If the pointer doesn't have an extent registered to it, - # this means the pointer is to a Buffer. In this case, we - # just want to delete the memory scope attribute - if replace_pointer not in pointer_to_extents: + replace_pointer, reallocate = replace_output_pointer[stmt.node] + if not reallocate: return stmt.body # Otherwise, rewrite the memory scope attribute with the new pointer - return tvm.tir.AttrStmt( - replace_output_pointer[stmt.node], stmt.attr_key, stmt.value, stmt.body - ) + return tvm.tir.AttrStmt(replace_pointer, stmt.attr_key, stmt.value, stmt.body) if isinstance(stmt, tvm.tir.Allocate): # If the allocate allocates a pointer that needs replacing if stmt.buffer_var in replace_output_pointer: - replace_pointer = replace_output_pointer[stmt.buffer_var] - # If the pointer doesn't have an extent registered to it, - # this means the pointer is to a Buffer. In this case, we - # just want to delete the allocation statement - if replace_pointer not in pointer_to_extents: + replace_pointer, reallocate = replace_output_pointer[stmt.buffer_var] + if not reallocate: return stmt.body # Otherwise, rewrite the allocation statement with the new pointer # and the new extent @@ -488,3 +485,170 @@ def _encode_constants(mod): return new_func, new_const_dict return _encode_constants + + +def RemoveConcatenates(): + """Remove concatenate operators by modifying the input buffers to write directly into + the concatenated buffer with the appropriate offset. + + This pass works in two stages. The first stage finds every concatenate operation + (marked by pragma_op = ethosu_concatenate) and performs the following analysis. Each + buffer that is concatenated is marked for replacement with the concat buffer, and the + axis along which it is concatenated, together with the offset along that axis, is + recorded in 'ReplaceInfo'. Once this analysis is completed, the concatenate loop nest, + along with its buffer realization statements, is removed. + + In the second stage, the input buffers to the concatenate operators are rewritten + to use the concat buffer directly. This means applying the correct offset to the + concatenation axis wherever the buffer is loaded or stored.
Additionally, as the + realization statements for the concat buffers were removed in the first stage, they + are rewritten in place of the input buffer realization with the earliest liveness.""" + + in_concat = [False] # Whether the visitor is currently inside a concatenate operator + concat_buffers = [] # The buffers produced by concatenate operators + buffer_replace_map = {} # A map of buffers to be replaced with the concat buffer + attrs_by_buffer = {} # AttrStmts by the buffer they reference + realizes_by_buffer = {} # BufferRealize statements by the buffer they reference + first_replacements = {} # The first buffers to be replaced by a given concat buffer + + ReplaceInfo = namedtuple("ReplaceInfo", ["buffer", "axis", "offset"]) + + def _get_replace_info(buffer_load, concat_buffer): + axis = 0 + offset = 0 + dmap = dict() + + for i, index in enumerate(buffer_load.indices): + if isinstance(index, tvm.tir.Sub): + axis = i + dmap = {} + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Var): + dmap[stmt] = tvm.arith.IntervalSet(0, 0) + + tvm.tir.stmt_functor.post_order_visit(index, _visit) + offset = abs(int(tvm.arith.Analyzer().int_set(index, dmap).max_value)) + return ReplaceInfo(concat_buffer, axis, offset) + + def _pre_remove(stmt): + if isinstance(stmt, tvm.tir.BufferRealize): + # Record the realize statements by buffer as we need to hoist some of these + realizes_by_buffer[stmt.buffer] = stmt + if isinstance(stmt, tvm.tir.AttrStmt): + if stmt.attr_key == "realize_scope" and isinstance(stmt.node, tvm.tir.Buffer): + # Record the realize_scope attrs by buffer as we need to hoist some of these + attrs_by_buffer[stmt.node] = stmt + if stmt.attr_key == "pragma_op" and stmt.value.value == "ethosu_concatenate": + # Record that we're entering a concatenate loop nest + in_concat[0] = True + if isinstance(stmt, tvm.tir.BufferLoad) and in_concat[0]: + # Any buffer loaded inside a concat is a buffer we intend to replace with this pass. + # The buffer_replace_map keeps track of which buffers need replacing with the + # concat buffer. + replace_info = _get_replace_info(stmt, concat_buffers[-1]) + buffer_replace_map[stmt.buffer] = replace_info + if isinstance(stmt, tvm.tir.BufferStore) and in_concat[0]: + # If we're inside a concat, the BufferStore indicates what the concat buffer is + concat_buffers.append(stmt.buffer) + + def _post_remove(stmt): + if isinstance(stmt, tvm.tir.AttrStmt): + if isinstance(stmt.node, tvm.tir.Buffer) and stmt.node in concat_buffers: + return stmt.body + if stmt.attr_key == "pragma_op" and stmt.value.value == "ethosu_concatenate": + # When we leave a concatenate operator, record it and then remove the loop nest + in_concat[0] = False + return tvm.tir.Evaluate(0) + if isinstance(stmt, tvm.tir.BufferRealize): + if stmt.buffer in concat_buffers: + return stmt.body + return None + + def _pre_replace(stmt): + if isinstance(stmt, (tvm.tir.BufferLoad, tvm.tir.BufferStore)): + # The first buffer referenced that needs replacing with a concat buffer shall + # be the one that the concat buffer realize is hoisted to. 
+ if stmt.buffer in buffer_replace_map: + concat_buffer = buffer_replace_map[stmt.buffer].buffer + if concat_buffer not in first_replacements: + first_replacements[concat_buffer] = stmt.buffer + + def _post_replace(stmt): + if isinstance(stmt, tvm.tir.BufferStore): + if stmt.buffer in buffer_replace_map: + # Replace the original buffer store with a new one into the concat buffer + # and adjust the indices accordingly to account for the offset + replace_info = buffer_replace_map[stmt.buffer] + concat_buffer = replace_info.buffer + new_indices = list(stmt.indices) + new_indices[replace_info.axis] += replace_info.offset + new_store = tvm.tir.BufferStore(concat_buffer, stmt.value, new_indices, stmt.span) + return new_store + if isinstance(stmt, tvm.tir.BufferLoad): + if stmt.buffer in buffer_replace_map: + # Replace the original buffer load with a new one into the concat buffer + # and adjust the indices accordingly to account for the offset + replace_info = buffer_replace_map[stmt.buffer] + concat_buffer = replace_info.buffer + new_indices = list(stmt.indices) + new_indices[replace_info.axis] += replace_info.offset + new_load = tvm.tir.BufferLoad(concat_buffer, new_indices, stmt.span) + return new_load + if isinstance(stmt, tvm.tir.BufferRealize): + if stmt.buffer in buffer_replace_map: + concat_buffer = buffer_replace_map[stmt.buffer].buffer + # If this isn't the first buffer replaced, don't hoist the realize + if first_replacements[concat_buffer] != stmt.buffer: + return stmt.body + # Otherwise, do hoist it + else: + concat_realize = realizes_by_buffer[concat_buffer] + new_realize = tvm.tir.BufferRealize( + concat_realize.buffer, + concat_realize.bounds, + concat_realize.condition, + stmt.body, + stmt.span, + ) + return new_realize + if isinstance(stmt, tvm.tir.AttrStmt): + if isinstance(stmt.node, tvm.tir.Buffer) and stmt.node in buffer_replace_map: + concat_buffer = buffer_replace_map[stmt.node].buffer + # If this isn't the first buffer replaced, don't hoist the attrstmt + if first_replacements[concat_buffer] != stmt.node: + return stmt.body + # Otherwise, do hoist it + else: + concat_attr = attrs_by_buffer[concat_buffer] + new_attr = tvm.tir.AttrStmt( + concat_attr.node, + concat_attr.attr_key, + concat_attr.value, + stmt.body, + stmt.span, + ) + return new_attr + + def _ftransform(f, mod, ctx): + f = f.with_body( + tvm.tir.stmt_functor.ir_transform( + f.body, + _pre_remove, + _post_remove, + ["tir.AttrStmt", "tir.BufferLoad", "tir.BufferStore", "tir.BufferRealize"], + ) + ) + return f.with_body( + tvm.tir.stmt_functor.ir_transform( + f.body, + _pre_replace, + _post_replace, + ["tir.AttrStmt", "tir.BufferLoad", "tir.BufferStore", "tir.BufferRealize"], + ) + ) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.remove_concatenates" + )
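[Illustration, not part of the patch: the net effect of RemoveConcatenates can be sketched in plain NumPy. Instead of materializing each input buffer and then copying it into the concat buffer, every producer stores straight into the final buffer at a cumulative offset along the concatenation axis, which is the same index adjustment _post_replace applies to each BufferStore and BufferLoad above. The helper name concat_by_direct_writes is made up for this example; shapes match test_concat below.]

import numpy as np

def concat_by_direct_writes(inputs, axis):
    """Emulate concatenation as offset writes into one pre-allocated buffer."""
    out_shape = list(inputs[0].shape)
    out_shape[axis] = sum(x.shape[axis] for x in inputs)
    out = np.empty(out_shape, dtype=inputs[0].dtype)
    offset = 0
    for x in inputs:
        index = [slice(None)] * x.ndim
        # Each producer's "store" lands at its offset along the concat axis
        index[axis] = slice(offset, offset + x.shape[axis])
        out[tuple(index)] = x
        offset += x.shape[axis]
    return out

a = np.random.randint(-128, 128, (1, 8, 12, 16)).astype("int8")
b = np.random.randint(-128, 128, (1, 8, 10, 16)).astype("int8")
assert (concat_by_direct_writes([a, b], axis=2) == np.concatenate([a, b], axis=2)).all()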
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py index 33dcb36fbbb6d..1572664b52cef 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/pooling.py @@ -60,7 +60,7 @@ def get_pooling_params( output_pointer = rw.body.buffer_var # Get feature map info serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) - serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) # Get kernel info serial_kernel = SerialKernel( width=int(rw.extent), @@ -88,4 +88,5 @@ def get_pooling_params( ), output_pointer, replace_pointer, + is_allocator, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py index 7f892d0c602ae..73d9531781f79 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -162,6 +162,8 @@ def schedule_pragmas(sch): """ def _add_pragmas(stage, ax): + if stage.op.name == "T_concat": + stage.pragma(ax, "op", "ethosu_concatenate") if "op" in [attr for attr, val in stage.op.attrs.items()]: stage.pragma(ax, "op", stage.op.attrs["op"]) for attr, val in stage.op.attrs.items(): diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py index f50975c83838e..83970a8fe4c59 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py @@ -58,4 +58,5 @@ def get_copy_params(stmt, producers, consumers): ), write_store.buffer_var, None, + True, ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py index 6dc801f2b28cd..bfc7430dba2e0 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/unary_elementwise.py @@ -56,7 +56,7 @@ def get_unary_elementwise_params(stmt, producers, consumers): output_pointer = inner.buffer_var # Get feature map info serial_ifm, _ = get_ifm_params(input_pointer, producers) - serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + serial_ofm, replace_pointer, is_allocator = get_ofm_params(output_pointer, consumers, producers) # Get activation info serial_activation = SerialActivation( op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] @@ -71,4 +71,5 @@ def get_unary_elementwise_params(stmt, producers, consumers): ), output_pointer, replace_pointer, + is_allocator, ) diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py index 73de3329c45f8..7d49c76f7ba0d 100644 --- a/python/tvm/relay/op/contrib/ethosu.py +++ b/python/tvm/relay/op/contrib/ethosu.py @@ -25,7 +25,7 @@ from tvm import relay from tvm.relay.expr import Constant, Call # type: ignore from tvm.relay.op.contrib.register import register_pattern_table # type: ignore -from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant # type: ignore +from tvm.relay.dataflow_pattern import wildcard, is_op, is_constant, is_tuple # type: ignore from tvm.relay.build_module import bind_params_by_name # type: ignore try: @@ -915,6 +915,58 @@ def abs_pattern() -> tvm.relay.dataflow_pattern.DFPattern: return pattern +class ConcatParams: + """ + A class to extract and store parameters of Concatenate in an Ethos-U friendly way + """ + + composite_name = "ethos-u.concat" + + def __init__(self, func_body): + self.concat = func_body + self.input_tensors = [TensorParams(tensor) for tensor in list(func_body.args[0])] + self.input_scales = [s.data.asnumpy() for s in list(func_body.args[1])] + self.input_zero_points = [zp.data.asnumpy() for zp in list(func_body.args[2])] + self.axis = func_body.attrs.axis + + def is_valid(self): + """Checks whether Concatenate has compatible attributes with the hardware""" + if not check_valid_dtypes(self.input_tensors, supported_dtypes=[np.int8]): + return False + # Check that the scales and zero points of input
tensors are the same + if not all(self.input_scales == self.input_scales[0]): + return False + if not all(self.input_zero_points == self.input_zero_points[0]): + return False + + input_dim = len(self.input_tensors[0].shape) + for tensor in self.input_tensors: + if len(tensor.shape) != input_dim: + return False + + if self.axis is None: + return False + if self.axis < 0: + return False + if self.axis >= input_dim: + return False + + output_shape = self.concat.checked_type.shape + if len(output_shape) != input_dim: + return False + return True + + +def concat_pattern(): + """Create pattern for concat""" + tensors = is_tuple(None) + scales = is_tuple(None) + zero_points = is_tuple(None) + concat = is_op("qnn.concatenate")(tensors, scales, zero_points, is_constant(), is_constant()) + optional_clip = concat.optional(is_op("clip")) + return optional_clip + + @register_pattern_table("ethos-u") def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]: return [ @@ -983,6 +1035,7 @@ def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Cal abs_pattern(), lambda pat: AbsParams(pat).is_valid(), ), + (ConcatParams.composite_name, concat_pattern(), lambda pat: ConcatParams(pat).is_valid()), ] diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 92a1ad71deda1..847864cd93bce 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -969,5 +969,78 @@ def create_graph_single(input_tensor_name, input_tensor_shape, input_tensor_dtyp assert '__attribute__((section(".rodata.tvm"), aligned(16))) static int8_t weights' in source +@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize( + "shapes, axis", + [ + ([(2, 3), (4, 3)], 0), + ([(3, 2, 1), (3, 1, 1)], 1), + ([(10,), (13,), (14,)], 0), + ([(1, 5, 2, 1), (1, 5, 7, 1), (1, 5, 3, 1)], 2), + ], +) +def test_tflite_concat_codegen(shapes, axis, accel_type): + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, shapes, axis): + op = tf.concat(shapes, axis) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + [tf.TensorSpec(shape, tf.float32) for shape in shapes], axis + ) + + def representative_dataset(): + for _ in range(100): + datas = [np.random.rand(*shape) for shape in shapes] + yield [data.astype(np.float32) for data in datas] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + + return tflite_model + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, params = relay.frontend.from_tflite( + tflite_model, + shape_dict={("ifm" + str(i)): shape for i, shape in enumerate(shapes)}, + dtype_dict={("ifm" + str(i)): "int8" for i, _ in enumerate(shapes)}, + ) + + mod = partition_for_ethosu(relay_module, params) + + # Generate reference data + input_data, output_data = infra.generate_ref_data_tflite(tflite_graph) + + compiled_models = infra.build_source( + mod, + input_data, + output_data, + accel_type, + ) + + # Assumes only two runtime.Modules are created -- i.e. 
single offload module + imported_modules = compiled_models[0].executor_factory.lib.imported_modules + assert len(imported_modules) == 2 + ethosu_module = imported_modules[0] + + # Verify generated C source + get_cs = tvm._ffi.get_global_func("runtime.module.ethos-u.getcs") + cmms = get_cs(ethosu_module) + cmms = bytes.fromhex(cmms) + + infra.print_payload(cmms) + infra.verify_source(compiled_models, accel_type) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py index dbe11cd2d7ad7..a9b2478e45caf 100644 --- a/tests/python/contrib/test_ethosu/test_legalize.py +++ b/tests/python/contrib/test_ethosu/test_legalize.py @@ -1007,5 +1007,85 @@ def verify(ext_func): verify(mod["tvmgen_default_ethos_u_main_0"]) +@pytest.mark.parametrize( + "shapes, axis", + [ + ([(2, 3), (4, 3)], 0), + ([(10, 2, 1), (10, 14, 1)], 1), + ([(10,), (13,), (14,)], 0), + ([(1, 5, 2, 1), (1, 5, 7, 1), (1, 5, 3, 1)], 2), + ], +) +def test_tflite_concat_legalize(shapes, axis): + def create_tflite_graph(): + class Model(tf.Module): + @tf.function + def tf_function(self, shapes, axis): + op = tf.concat(shapes, axis) + return op + + model = Model() + concrete_func = model.tf_function.get_concrete_function( + [tf.TensorSpec(shape, tf.float32) for shape in shapes], axis + ) + + def representative_dataset(): + for _ in range(100): + datas = [np.random.rand(*shape) for shape in shapes] + yield [data.astype(np.float32) for data in datas] + + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) + converter.optimizations = [tf.lite.Optimize.DEFAULT] + converter.representative_dataset = representative_dataset + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] + converter.inference_input_type = tf.int8 + converter.inference_output_type = tf.int8 + tflite_model = converter.convert() + + return tflite_model + + def verify(ext_func): + new_concat_axis = np.sum(shape[axis] for shape in shapes) + out_shape = list(shapes[0]) + out_shape[axis] = new_concat_axis + + op = ext_func.body + for i, _ in enumerate(shapes): + assert list(op.args[0][i].checked_type.shape) == list(shapes[i]) + + assert list(op.checked_type.shape) == out_shape + assert op.checked_type.dtype == "int8" + + concat_pattern_table = [ + ( + ethosu.ConcatParams.composite_name, + ethosu.concat_pattern(), + lambda pat: ethosu.ConcatParams(pat).is_valid(), + ) + ] + + tflite_graph = create_tflite_graph() + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) + + relay_module, _ = relay.frontend.from_tflite( + tflite_model, + shape_dict={("ifm" + str(i)): shape for i, shape in enumerate(shapes)}, + dtype_dict={("ifm" + str(i)): "int8" for i, _ in enumerate(shapes)}, + ) + mod = partition_ethosu_by_table(relay_module, concat_pattern_table) + + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.ConcatRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( + legalize.NoOpRewriter(), mod["tvmgen_default_ethos_u_main_0"] + ) + mod["tvmgen_default_ethos_u_main_0"] = relay.transform.InferType()(mod)[ + "tvmgen_default_ethos_u_main_0" + ] + + verify(mod["tvmgen_default_ethos_u_main_0"]) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py new file mode 100644 index 
0000000000000..8a3ad602fd3b3 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.script +from tvm.script import tir +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from .infra import make_ethosu_conv2d + +import pytest + +# fmt: off +@tvm.script.ir_module +class ReferenceModule: + @tir.prim_func + def main(placeholder: tir.handle, placeholder_1: tir.handle, placeholder_2: tir.handle, placeholder_3: tir.handle, placeholder_4: tir.handle, placeholder_5: tir.handle, placeholder_6: tir.handle, placeholder_7: tir.handle, placeholder_8: tir.handle, placeholder_9: tir.handle, T_concat: tir.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_2, [2992], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_4, [2992], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_10 = tir.match_buffer(placeholder_1, [1, 8, 10, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_9, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_8, [2992], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_4 = tir.match_buffer(placeholder_5, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_5 = tir.match_buffer(placeholder_6, [2992], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + T_concat_1 = tir.match_buffer(T_concat, [1, 8, 32, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_11 = tir.match_buffer(placeholder, [1, 8, 12, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_6 = tir.match_buffer(placeholder_7, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_7 = tir.match_buffer(placeholder_3, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + T_concat_2 = tir.allocate([2816], "int8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, tir.load("int8", placeholder_10.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, tir.load("int8", T_concat_2, 192), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 2992, 12, tir.load("uint8", buffer_4.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, tir.load("int8", T_concat_2, 192), 
0, 0, 0, tir.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, tir.load("int8", T_concat_1.data, 352), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 2992, 12, tir.load("uint8", buffer_2.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, tir.load("int8", placeholder_11.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, tir.load("int8", T_concat_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 2992, 12, tir.load("uint8", buffer_7.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, tir.load("int8", T_concat_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, tir.load("int8", T_concat_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_5.data, 0), 2992, 12, tir.load("uint8", buffer_6.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_concat(): + def _get_func(): + ifm1 = relay.var("ifm1", shape=(1, 8, 12, 16), dtype="int8") + ifm2 = relay.var("ifm2", shape=(1, 8, 10, 16), dtype="int8") + conv1 = make_ethosu_conv2d(ifm1, 16, 16, (3, 3), (1, 1), (1, 1), (1, 1), "NONE", "NHWC") + conv2 = make_ethosu_conv2d(ifm2, 16, 16, (3, 3), (1, 1), (1, 1), (1, 1), "NONE", "NHWC") + conc1 = relay.concatenate((conv1, conv2), axis=2) + conv3 = make_ethosu_conv2d(conc1, 16, 16, (3, 3), (1, 1), (1, 1), (1, 1), "NONE", "NHWC") + conv4 = make_ethosu_conv2d(conv2, 16, 16, (3, 3), (1, 1), (1, 1), (1, 1), "NONE", "NHWC") + conc2 = relay.concatenate((conv3, conv4), axis=2) + func = relay.Function(relay.analysis.free_vars(conc2), conc2) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, _ = lower_to_tir(func) + script = mod.script(show_meta=True) + test_mod = tvm.script.from_source(script) + + reference_mod = ReferenceModule + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + +if __name__ == "__main__": + pytest.main([__file__])
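[Illustration, not part of the patch: why ConcatRewriter can drop the QNN wrapper. qnn.concatenate requantizes every input to the output quantization parameters before joining them; when all inputs already share those parameters, which the TFLite converter guarantees by inserting requantize nodes beforehand, the requantization is the identity and the op reduces to the plain int8 concatenate that the rewriter emits. The quantize/dequantize helpers and the parameter values below are assumed for the demonstration.]

import numpy as np

def dequantize(q, scale, zp):
    return scale * (q.astype(np.int32) - zp)

def quantize(r, scale, zp):
    return np.clip(np.round(r / scale) + zp, -128, 127).astype(np.int8)

scale, zp = 0.25, 3  # QNN params shared by all inputs and the output (hypothetical)
a = np.random.randint(-128, 128, (2, 3)).astype("int8")
b = np.random.randint(-128, 128, (4, 3)).astype("int8")

# qnn.concatenate semantics: requantize each input to the output params, then concat
reference = quantize(
    np.concatenate([dequantize(a, scale, zp), dequantize(b, scale, zp)], axis=0), scale, zp
)
# With shared params this is exactly the integer concatenation
assert (reference == np.concatenate([a, b], axis=0)).all()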
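[Illustration, not part of the patch: how the schedule_pragmas change in scheduler.py marks concatenates. TE stages produced by a Relay concatenate are named "T_concat", and tagging one with stage.pragma(ax, "op", "ethosu_concatenate") surfaces in the lowered TIR as an AttrStmt with key "pragma_op", the marker that RemoveConcatenates matches on. A minimal standalone sketch of the same mechanism, using a toy compute in place of a real concatenate:]

import tvm
from tvm import te

a = te.placeholder((8,), name="a")
b = te.compute((8,), lambda i: a[i] + 1, name="b")
s = te.create_schedule(b.op)
# An "op" pragma on a stage's axis lowers to:
#   attr [...] "pragma_op" = "ethosu_concatenate"
s[b].pragma(s[b].op.axis[0], "op", "ethosu_concatenate")
print(tvm.lower(s, [a, b], simple_mode=True))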