From 5f302f3af39174d930609e4427de94c970de9441 Mon Sep 17 00:00:00 2001 From: Elen Kalda Date: Mon, 31 Jan 2022 09:34:05 +0000 Subject: [PATCH] [microNPU] Add support for conv2d running on two cores on U65 The 512 mac variant has two cores that process the weights in parallel, so we need to split the weights and biases into two and encode them separately. Change-Id: I53791f614288ac4df181b9462fc632d35b934a86 --- .../relay/backend/contrib/ethosu/legalize.py | 4 +- .../backend/contrib/ethosu/tir/convolution.py | 71 +++++++++-- .../backend/contrib/ethosu/tir/passes.py | 115 ++++++++++++++---- .../relay/backend/contrib/ethosu/tir/spec.py | 4 + .../relay/backend/contrib/ethosu/tir/utils.py | 18 --- .../contrib/ethosu/tir_to_cs_translator.py | 22 +++- .../relay/backend/contrib/ethosu/vela_api.py | 23 ++-- .../contrib/test_ethosu/test_codegen.py | 6 +- .../test_ethosu/test_encode_constants.py | 26 ++-- .../test_ethosu/test_remove_concatenates.py | 8 +- .../test_ethosu/test_replace_conv2d.py | 94 ++++++++------ .../contrib/test_ethosu/test_replace_copy.py | 6 +- .../contrib/test_ethosu/test_scheduler.py | 4 +- .../test_ethosu/test_tir_to_cs_translator.py | 34 +++--- 14 files changed, 293 insertions(+), 142 deletions(-) diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 64c6cefb8b58c..619cf3418501e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -1151,7 +1151,7 @@ def callback( if axis == [1, 2] and params.keepdims: weight_scale = 1 - weight_values = np.ones([out_channels, filter_height, filter_width, in_channels]) + weight_values = np.ones([out_channels, filter_height, filter_width, 1]) scale_bias = vela_api.pack_biases( biases=np.zeros(ifm_shape[-1]), ifm_scale=params.ifm.q_params.scale_f32, @@ -1216,7 +1216,7 @@ def callback( ) else: weight_scale = 1 / (filter_height * filter_width) - weight_values = 
np.ones([out_channels, filter_height, filter_width, in_channels]) + weight_values = np.ones([out_channels, filter_height, filter_width, 1]) bias = -1 * int(params.ifm.q_params.zero_point) * filter_height * filter_width scale_bias = vela_api.pack_biases( diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 50c27cc016890..dcd0208e5540d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -16,8 +16,10 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Extract parameters from the convolution operators in TIR.""" +import math import tvm -from ..vela_api import SCALE_BIAS_LENGTH +from ethosu.vela import api as vapi +from ..vela_api import SCALE_BIAS_LENGTH, get_accelerator_config from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores from .dma import get_ifm_params, get_ofm_params from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution @@ -50,6 +52,8 @@ def get_conv2d_params(stmt, producers, consumers): Whether this operator allocates its output. 
""" + accel_config = get_accelerator_config() + attrs, body = get_op_attrs(stmt) _, _, _, _, _, inner = get_outer_loops(body, "NHWC") rh = inner @@ -76,17 +80,64 @@ def get_conv2d_params(stmt, producers, consumers): # Get scale_bias info scale_bias_load = loads[3] scale_bias_base = get_base_address(scale_bias_load.index) - serial_scale_bias = SerialAddressRange( - address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), - length=SCALE_BIAS_LENGTH * serial_ofm[3], - ) # Get weight info weight_load = loads[2] weight_base = get_base_address(weight_load.index) - serial_weight = SerialAddressRange( - address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), - length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1] * rc.extent, - ) + channels = serial_ofm[3] if isinstance(serial_ofm[3], int) else serial_ofm[3].value + + if accel_config == vapi.NpuAccelerator.Ethos_U65_512: + scale_bias_length = SCALE_BIAS_LENGTH * math.ceil(channels / 2) + scale_bias2_length = SCALE_BIAS_LENGTH * math.floor(channels / 2) + + serial_scale_bias = SerialAddressRange( + address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), + length=scale_bias_length, + ) + serial_scale_bias2 = SerialAddressRange( + address=tvm.tir.Load( + "uint8", scale_bias_load.buffer_var, scale_bias_base + scale_bias_length + ), + length=scale_bias2_length, + ) + + weight_length = ( + channels * serial_kernel[0] * serial_kernel[1] * math.ceil(rc.extent.value / 2) + ) + weight2_length = ( + channels * serial_kernel[0] * serial_kernel[1] * math.floor(rc.extent.value / 2) + ) + + serial_weight = SerialAddressRange( + address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), + length=weight_length, + ) + serial_weight2 = SerialAddressRange( + address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base + weight_length), + length=weight2_length, + ) + else: + scale_bias_length = SCALE_BIAS_LENGTH * channels + + serial_scale_bias = SerialAddressRange( + 
address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), + length=scale_bias_length, + ) + # Insert -1s into the spec to denote the absence of the other pointer + serial_scale_bias2 = SerialAddressRange( + address=tvm.tir.IntImm("int8", -1), + length=tvm.tir.IntImm("int8", -1), + ) + + weight_length = channels * serial_kernel[0] * serial_kernel[1] * rc.extent.value + + serial_weight = SerialAddressRange( + address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), + length=weight_length, + ) + serial_weight2 = SerialAddressRange( + address=tvm.tir.IntImm("int8", -1), + length=tvm.tir.IntImm("int8", -1), + ) # Get activation info serial_activation = SerialActivation( op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] @@ -97,8 +148,10 @@ def get_conv2d_params(stmt, producers, consumers): ofm=serial_ofm, kernel=serial_kernel, weight=serial_weight, + weight2=serial_weight2, weight_zero_point=attrs["weight_zero_point"], scale_bias=serial_scale_bias, + scale_bias2=serial_scale_bias2, padding=serial_padding, activation=serial_activation, rounding_mode=attrs["rounding_mode"], diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index c2fff8abb9b04..a610eed5d61a8 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -14,13 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=invalid-name, unused-argument, no-else-return, inconsistent-return-statements +# pylint: disable=invalid-name, unused-argument, no-else-return, inconsistent-return-statements, too-many-nested-blocks """The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler.""" from collections import namedtuple import numpy as np # type: ignore import tvm from tvm.relay.backend.contrib.ethosu import vela_api +from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs +from ethosu.vela import api as vapi from .convolution import get_conv2d_params from .depthwise import get_depthwise_conv2d_params from .pooling import get_pooling_params @@ -28,7 +30,6 @@ from .identity import get_identity_params from .unary_elementwise import get_unary_elementwise_params from .transform import get_copy_params -from .utils import get_weights_pointer, get_scale_bias_pointer def RemoveZeroStores(): @@ -306,6 +307,7 @@ def EncodeConstants(const_dict): pointer_to_buffer = {} rewrite_buffer = {} rewrite_pointer = {} + pointer_to_offset = {} accel_config = vela_api.get_accelerator_config() def _align_scale_bias(tir_extern_call, bias): @@ -338,11 +340,44 @@ def _new_buffer(old_buffer, new_value): rewrite_buffer[old_buffer] = new_buffer rewrite_pointer[old_buffer.data] = new_buffer.data + def _encode_weights_or_bias(ptr1, ptr2, stmt, encode_func): + """Encode the weights or align the bias either for one or two cores, + depending on the variant.""" + assert ptr1 in pointer_to_buffer + buffer = pointer_to_buffer[ptr1] + constant = buffer_to_const[buffer] + + # If we have just one core, encode the whole constant + if ptr2 is None: + new_const = encode_func(stmt, constant) + return new_const, len(new_const) + + # Assume OHWI + channels = constant.shape[0] + split_const = np.split(constant, channels, axis=0) + + const_list = [split_const[i] for i in range(channels) if i % 2 == 0] + const_to_encode = np.concatenate(const_list, axis=0) + + new_const = encode_func(stmt, 
const_to_encode) + new_const_length = len(new_const) + + # Encode half of the constant separately for the other core if it exists + assert ptr1.same_as(ptr2) + const2_list = [split_const[i] for i in range(channels) if i % 2 == 1] + const2_to_encode = np.concatenate(const2_list, axis=0) + + new_const2 = encode_func(stmt, const2_to_encode) + new_const = np.append(new_const, new_const2).astype("uint8") + + return new_const, new_const_length + def _visit_encode_pre(stmt): if isinstance(stmt, tvm.tir.Call): + op = str(stmt.args[0].value) # Handle copies as a special-case by propagating the buffer information # from the read to the write pointer. - if stmt.args[0] == "ethosu_copy": + if op == "ethosu_copy": read_pointer = stmt.args[1].buffer_var if read_pointer in pointer_to_buffer: write_pointer = stmt.args[3].buffer_var @@ -350,23 +385,46 @@ def _visit_encode_pre(stmt): assert stmt.args[3].index == 0 assert stmt.args[1].index == 0 pointer_to_buffer[write_pointer] = pointer_to_buffer[read_pointer] - else: + + ops_with_weights = { + "ethosu_conv2d": tirtocs.translate_ethosu_conv2d, + "ethosu_depthwise_conv2d": tirtocs.translate_ethosu_depthwise_conv2d, + } + if op in ops_with_weights.keys(): + npu_op, _ = ops_with_weights[op](stmt) + # Encode the weights - weights_pointer = get_weights_pointer(stmt) - if weights_pointer is not None: - assert weights_pointer in pointer_to_buffer - weights_buffer = pointer_to_buffer[weights_pointer] - weights_value = buffer_to_const[weights_buffer] - new_weights_value = _encode_weights(stmt, weights_value) - _new_buffer(weights_buffer, new_weights_value) - # Align the scale_bias to 16 bytes - scale_bias_pointer = get_scale_bias_pointer(stmt) - if scale_bias_pointer is not None: - assert scale_bias_pointer in pointer_to_buffer - scale_bias_buffer = pointer_to_buffer[scale_bias_pointer] - scale_bias_value = buffer_to_const[scale_bias_buffer] - new_scale_bias_value = _align_scale_bias(stmt, scale_bias_value) - _new_buffer(scale_bias_buffer, 
new_scale_bias_value) + weights_pointer = npu_op.weights[0].address.buffer_var + weights2_pointer = ( + npu_op.weights[1].address.buffer_var + if accel_config == vapi.NpuAccelerator.Ethos_U65_512 + else None + ) + + new_weights, new_weights_length = _encode_weights_or_bias( + weights_pointer, weights2_pointer, stmt, _encode_weights + ) + + weights_buffer = pointer_to_buffer[weights_pointer] + _new_buffer(weights_buffer, new_weights) + pointer_to_offset[weights_pointer] = new_weights_length + + # Align the bias(es) to 16 bytes + scale_bias_pointer = npu_op.biases[0].address.buffer_var + scale_bias2_pointer = ( + npu_op.biases[1].address.buffer_var + if accel_config == vapi.NpuAccelerator.Ethos_U65_512 + else None + ) + + new_scale_bias, new_scale_bias_length = _encode_weights_or_bias( + scale_bias_pointer, scale_bias2_pointer, stmt, _align_scale_bias + ) + + scale_bias_buffer = pointer_to_buffer[scale_bias_pointer] + + _new_buffer(scale_bias_buffer, new_scale_bias) + pointer_to_offset[scale_bias_pointer] = new_scale_bias_length def _visit_encode_post(stmt): # Because encoding may change the data type (e.g. 
bias to uint8) and type information @@ -406,6 +464,14 @@ def _visit_rewrite(stmt): # Only rewrite the arguments of buffers that have been encoded if buffer in new_buffers: new_arg = np.prod(list(pointer_to_buffer[pointer].shape)) + if isinstance(stmt.args[i + 1], tvm.tir.Load): + if pointer.same_as(stmt.args[i + 1].buffer_var): + # we've got a pair of loads from the same buffer + new_arg = stmt.args[i + 1].index.value + elif isinstance(stmt.args[i - 3], tvm.tir.Load): + if pointer.same_as(stmt.args[i - 3].buffer_var): + new_arg = new_arg - load.index.value + new_args.append(new_arg) continue new_args.append(stmt.args[i]) @@ -433,10 +499,11 @@ def _visit_rewrite(stmt): load_pointer = stmt.buffer_var if load_pointer in rewrite_pointer: new_pointer = rewrite_pointer[load_pointer] + offset = stmt.index + if offset != 0: + offset = pointer_to_offset[load_pointer] element_type = new_pointer.type_annotation.element_type.dtype - return tvm.tir.Load( - element_type, new_pointer, stmt.index, stmt.predicate, stmt.span - ) + return tvm.tir.Load(element_type, new_pointer, offset, stmt.predicate, stmt.span) if isinstance(stmt, tvm.tir.AttrStmt): node_pointer = stmt.node if node_pointer in rewrite_pointer: @@ -448,7 +515,7 @@ def _visit_rewrite(stmt): def _ftransform(f, mod, ctx): for i, param in enumerate(f.params): if i in const_dict: - buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten() + buffer_to_const[f.buffer_map[param]] = const_dict[i] pointer_to_buffer[f.buffer_map[param].data] = f.buffer_map[param] # First analyse what needs to be rewritten @@ -469,7 +536,7 @@ def _ftransform(f, mod, ctx): new_value = buffer_to_const[new_buffer] new_const_dict[i] = new_value elif buffer in buffer_to_const: - new_const_dict[i] = buffer_to_const[buffer] + new_const_dict[i] = buffer_to_const[buffer].flatten() new_buffer_map[param] = buffer else: new_buffer_map[param] = buffer diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py 
b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py index f9d38df9d901f..cb1637d9eeb20 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py @@ -184,8 +184,10 @@ def __init__( ofm: SerialFeatureMap, kernel: SerialKernel, weight: SerialAddressRange, + weight2: SerialAddressRange, weight_zero_point: int, scale_bias: SerialAddressRange, + scale_bias2: SerialAddressRange, padding: SerialPadding, activation: SerialActivation, rounding_mode: str, @@ -195,8 +197,10 @@ def __init__( self.ofm = ofm self.kernel = kernel self.weight = weight + self.weight2 = weight2 self.weight_zero_point = weight_zero_point self.scale_bias = scale_bias + self.scale_bias2 = scale_bias2 self.padding = padding self.activation = activation self.rounding_mode = rounding_mode diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index de1c0ab19f6e1..3dd0c13ff08aa 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -20,24 +20,6 @@ from tvm import arith -# TODO(@mbaret): Formalise this with a specification -def get_weights_pointer(tir_extern_call): - """Get the weights pointer from a NPU extern call if it exists""" - supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] - if tir_extern_call.args[0] in supported_ops: - return tir_extern_call.args[41].buffer_var - return None - - -# TODO(@mbaret): Formalise this with a specification -def get_scale_bias_pointer(tir_extern_call): - """Get the scale_bias pointer from a NPU extern call if it exists""" - supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] - if tir_extern_call.args[0] in supported_ops: - return tir_extern_call.args[44].buffer_var - return None - - def get_op_attrs(stmt): """Iterate through nested attribute statements accumulating their values in an attribute dictionary. 
diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index c55a6310ffa58..1b34d7d14ce9d 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -304,6 +304,7 @@ def replace_npu_address_range_with_address(npu_addr_range): buffer = npu_addr_range.address.buffer_var assert buffer in buffer_addresses.keys(), f"searching for buffer : {buffer}, but not found" address, buffer_type = buffer_addresses[buffer] + address = address + npu_addr_range.address.index.value return vapi.NpuAddressRange(_REGION_MAP[buffer_type], address, npu_addr_range.length) def replace_tir_loads(npu_object): @@ -478,13 +479,30 @@ def _create_npu_op_conv2d( """This is a helper function to capture a list of arguments to create Vela NpuConv2DOperation object. """ + has_two_weights = serial_2d_convolution.weight2.address != -1 + has_two_biases = serial_2d_convolution.scale_bias2.address != -1 + npu_conv2d_op = vapi.NpuConv2DOperation() npu_conv2d_op.ifm = _create_npu_feature_map(serial_2d_convolution.ifm) npu_conv2d_op.ofm = _create_npu_feature_map(serial_2d_convolution.ofm) npu_conv2d_op.kernel = _create_npu_kernel(serial_2d_convolution.kernel) - npu_conv2d_op.weights = [_create_npu_address_range(serial_2d_convolution.weight)] + npu_conv2d_op.weights = ( + [ + _create_npu_address_range(serial_2d_convolution.weight), + _create_npu_address_range(serial_2d_convolution.weight2), + ] + if has_two_weights + else [_create_npu_address_range(serial_2d_convolution.weight)] + ) weights_zero_point = np.int64(serial_2d_convolution.weight_zero_point.value) - npu_conv2d_op.biases = [_create_npu_address_range(serial_2d_convolution.scale_bias)] + npu_conv2d_op.biases = ( + [ + _create_npu_address_range(serial_2d_convolution.scale_bias), + _create_npu_address_range(serial_2d_convolution.scale_bias2), + ] + if has_two_biases + else 
[_create_npu_address_range(serial_2d_convolution.scale_bias)] + ) npu_conv2d_op.padding = _create_npu_padding(serial_2d_convolution.padding) npu_conv2d_op.activation = _create_npu_activation(serial_2d_convolution.activation) diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index fd915a504d673..816f6fd59771c 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -141,17 +141,18 @@ def encode_weights( assert op in supported_ops.keys() npu_op, weights_zero_point = supported_ops[op](tir_extern_call) block_config = get_optimal_block_config(npu_op, accel_config) - # The weight layout is assumed to be flat OHWI, always. - assert len(values.shape) == 1 is_depthwise = op == "ethosu_depthwise_conv2d" - shape_ohwi = ( - npu_op.ofm.shape.depth, - npu_op.kernel.height, - npu_op.kernel.width, - 1 if is_depthwise else npu_op.ifm.shape.depth, - ) - assert values.size == np.prod(shape_ohwi) - values = np.reshape(values, shape_ohwi) + # Recover the original shape if we are dealing with a flattened tensor + if len(values.shape) == 1: + shape_ohwi = ( + npu_op.ofm.shape.depth, + npu_op.kernel.height, + npu_op.kernel.width, + 1 if is_depthwise else npu_op.ifm.shape.depth, + ) + assert values.size == np.prod(shape_ohwi) + values = np.reshape(values, shape_ohwi) + return compress_weights( weights=values, weights_zp=weights_zero_point, @@ -217,6 +218,7 @@ def compress_weights( weights.shape[layout_transform_indices[weights_layout][3]], ] block_traversal = calculate_block_traversal_mode(is_depthwise, shape_ohwi, ifm_bitdepth) + compressed_weights = vapi.npu_encode_weights( accelerator=accel_config, weights_volume=weights_ohwi, @@ -389,6 +391,7 @@ def get_accelerator_config() -> vapi.NpuAccelerator: "ethos-u55-64": vapi.NpuAccelerator.Ethos_U55_64, "ethos-u55-32": vapi.NpuAccelerator.Ethos_U55_32, "ethos-u65-256": vapi.NpuAccelerator.Ethos_U65_256, + 
"ethos-u65-512": vapi.NpuAccelerator.Ethos_U65_512, } compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")() accel_config_str = compiler_attrs.accelerator_config diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index 052586387b1fc..8c0ef02cdf064 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -54,11 +54,11 @@ def get_shape_expr(in_expr, out_expr): return shape -@pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 55, 55, 3)]) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES + ["ethos-u65-512"]) +@pytest.mark.parametrize("ifm_shape", [(1, 31, 31, 3), (1, 55, 55, 4)]) @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) -@pytest.mark.parametrize("accel_type", ACCEL_TYPES) @pytest.mark.parametrize("activation", ["NONE", "RELU"]) def test_ethosu_conv2d_single( ifm_shape, @@ -80,7 +80,7 @@ def tf_function(self, x): op = tf.nn.conv2d( x, filters=tf.constant( - np.random.uniform(size=[kernel_shape[0], kernel_shape[1], 3, 3]), + np.random.uniform(size=[kernel_shape[0], kernel_shape[1], ifm_shape[3], 5]), dtype=tf.float32, ), strides=tf_strides, diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 315712996ac88..469cf3fe0620b 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -50,16 +50,16 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 128, T.load("uint8", placeholder_global, 0), 
dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 
16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", 
ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -117,10 +117,10 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", 
placeholder_global, 0), 304, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, T.load("int8", placeholder.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, T.load("int8", ethosu_write.data, 64), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 1, 8, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -177,8 +177,8 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ buffer_3 = T.buffer_var("uint8", "") # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, 
T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 160, 12, T.load("uint8", buffer_3, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 160, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_3, 0), 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -242,19 +242,19 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", 
placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 
dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 
0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_8, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_9, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index f6e0e2d855cd2..0eb7177a7dd8e 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -43,10 +43,10 @@ def 
main(placeholder: T.Buffer[(1, 8, 12, 16), "int8"], placeholder_1: T.Buffer[ buffer_7 = T.buffer_var("uint8", "") # body T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", placeholder_1.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 2992, 12, T.load("uint8", buffer_1, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat.data, 352), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 2992, 12, T.load("uint8", buffer_3, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_4, 0), 2992, 12, T.load("uint8", buffer_5, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_6, 0), 2992, 12, T.load("uint8", buffer_7, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", 
placeholder_1.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 2992, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat_1, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T.load("int8", T_concat.data, 352), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 2992, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_3, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_4, 0), 2992, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_5, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T.load("int8", T_concat.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_6, 0), 2992, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_7, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index 67fb2c7609621..4bd1788f4e3d2 100644 --- 
a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -23,7 +23,7 @@ from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import total_cascader -from .infra import make_ethosu_conv2d, get_convolutional_args +from .infra import make_ethosu_conv2d def _create_serial_conv2d_params( @@ -129,6 +129,28 @@ def _create_serial_conv2d_params( ] +def get_conv2d_args(call, include_buffers=False, remove_constants=False): + """A method to extract the arguments from conv2d extern call.""" + args = call.args + conv_args = [] + remove_indices = [0] + + if remove_constants: + remove_indices += [41, 42, 43, 44, 46, 47, 48, 49] + + for i, arg in enumerate(args): + if i in remove_indices: + continue + elif isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): + conv_args.append(arg.value) + elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: + conv_args.append(arg.index) + else: + conv_args.append(arg) + + return conv_args + + @pytest.mark.parametrize( "trial", [ @@ -321,7 +343,7 @@ def _get_func( def _visit(stmt): if isinstance(stmt, tvm.tir.Call): - data.append(get_convolutional_args(stmt, remove_constants=True)) + data.append(get_conv2d_args(stmt, remove_constants=True)) tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) @@ -342,10 +364,10 @@ def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer buffer_3 = T.buffer_var("uint8", "") # body ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, 
T.load("uint8", buffer_3, 0), 160, 12, T.load("uint8", buffer_2, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 304, 12, T.load("uint8", buffer_1, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 160, 12, T.load("uint8", buffer_2, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 32), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 304, 12, T.load("uint8", buffer_1, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 160, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_2, 0), 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 
32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 304, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, T.load("int8", placeholder_5.data, 12), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 160, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_2, 0), 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 32), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 304, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -361,10 +383,10 @@ def main(placeholder_5: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write_1: T.Buffer buffer_3 = T.buffer_var("uint8", "") # body ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 1312, 12, T.load("uint8", buffer_1, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 
0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 2608, 12, T.load("uint8", buffer, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 1312, 12, T.load("uint8", buffer_1, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 2608, 12, T.load("uint8", buffer, 0), 80, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 1312, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 320, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 256), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 2608, T.int8(-1), 
T.int8(-1), 12, T.load("uint8", buffer, 0), 80, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 48), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 1312, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 320, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 2608, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer, 0), 80, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -380,12 +402,12 @@ def main(placeholder_5: T.Buffer[(1, 16, 16, 3), "int8"], ethosu_write_1: T.Buff buffer_3 = T.buffer_var("uint8", "") # body ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, 
T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, T.load("int8", placeholder_5.data, 576), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, 12, T.load("uint8", buffer_2, 0), 320, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, T.load("int8", ethosu_write_1.data, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, 12, T.load("uint8", buffer_1, 0), 80, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, 
T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_2, 0), 320, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, T.load("int8", ethosu_write_2, 512), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 80,T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, T.load("int8", placeholder_5.data, 192), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_2, 0), 320, T.int8(-1), T.int8(-1), 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, T.load("int8", ethosu_write_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 80, T.int8(-1), T.int8(-1), 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, T.load("int8", placeholder_5.data, 576), 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, 
T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer_3, 0), 880, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_2, 0), 320, T.int8(-1), T.int8(-1), 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, T.load("int8", ethosu_write_1.data, 512), 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, T.load("uint8", buffer, 0), 1744, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 80, T.int8(-1), T.int8(-1), 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -401,10 +423,10 @@ def main(placeholder_5: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write_1: T.Bu buffer_3 = T.buffer_var("uint8", "") # body ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 12, T.load("uint8", buffer_1, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 11040, 12, T.load("uint8", buffer_2, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, 
"int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 12, T.load("uint8", buffer_1, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 1024), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 11040, 12, T.load("uint8", buffer_2, 0), 272, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 352, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 384), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 11040, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_2, 0), 272, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, T.load("int8", placeholder_5.data, 256), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 
T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 352, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, T.load("int8", ethosu_write_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, T.load("int8", ethosu_write_1.data, 1024), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_3, 0), 11040, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_2, 0), 272, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -420,10 +442,10 @@ def main(placeholder: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_write: T.Buffer[(1, buffer_3 = T.buffer_var("uint8", "") # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 160, 12, T.load("uint8", buffer_1, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 304, 12, T.load("uint8", buffer_3, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, T.load("int8", placeholder.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 
32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 160, 12, T.load("uint8", buffer_1, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, T.load("int8", ethosu_write.data, 4096), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 304, 12, T.load("uint8", buffer_3, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 160, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 304, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_3, 0), 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, T.load("int8", placeholder.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 160, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) + 
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, T.load("int8", ethosu_write.data, 4096), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 304, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_3, 0), 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", dtype="handle")) __tvm_meta__ = None @@ -439,8 +461,8 @@ def main(placeholder: T.Buffer[(1, 8, 1, 8, 16), "int8"], ethosu_write: T.Buffer buffer_3 = T.buffer_var("uint8", "") # body ethosu_write_1 = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, 12, T.load("uint8", buffer_1, 0), 352, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 35, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 11040, 12, T.load("uint8", buffer_3, 0), 272, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 1456, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 
0), 352, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 35, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_2, 0), 11040, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_3, 0), 272, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -597,7 +619,7 @@ def main(placeholder_3: T.Buffer[(1, 10, 12, 8), "int8"], ethosu_write_1: T.Buff buffer = T.buffer_var("uint8", "") buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, T.load("int8", placeholder_3.data, 120), 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 848, 12, T.load("uint8", buffer_1, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, T.load("int8", placeholder_3.data, 120), 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer, 0), 848, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -610,7 +632,7 @@ def main(placeholder_3: T.Buffer[(1, 7, 9, 5), "int8"], ethosu_write_1: T.Buffer buffer = T.buffer_var("uint8", "") buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, T.load("int8", placeholder_3.data, 
146), 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 656, 12, T.load("uint8", buffer, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, T.load("int8", placeholder_3.data, 146), 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 656, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -652,8 +674,8 @@ def main(placeholder_3: T.Buffer[(4, 6, 8, 1), "int8"], ethosu_write_1: T.Buffer buffer = T.buffer_var("uint8", "") buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, 
"int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer, 0), 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -666,8 +688,8 @@ def main(placeholder_3: T.Buffer[(1, 24, 8), "int8"], ethosu_write_1: T.Buffer[( buffer = T.buffer_var("uint8", "") buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 
4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer, 0), 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -680,8 +702,8 @@ def main(placeholder_3: T.Buffer[(192, 1), "int8"], ethosu_write_1: T.Buffer[(1, buffer = T.buffer_var("uint8", "") buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, 
T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer, 0), 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None @@ -694,8 +716,8 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(1, 8 buffer = T.buffer_var("uint8", "") buffer_1 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, 12, T.load("uint8", buffer, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", 
ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer, 0), 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, T.load("int8", placeholder_3.data, 72), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, T.load("int8", ethosu_write_1.data, 384), 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, T.load("uint8", buffer_1, 0), 848, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer, 0), 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 7aee57d548fe4..f60aeee691811 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -41,7 +41,7 @@ def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buf placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 304, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 80, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, 12, T.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + 
T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 304, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -87,10 +87,10 @@ def main(placeholder_5: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write_1: T.Buf placeholder_d_global = T.allocate([112], "uint8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 416, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 112, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, 12, T.load("uint8", placeholder_d_global, 0), 112, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, T.load("int8", ethosu_write_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 416, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", 
T.load("uint8", buffer_2, 0), 272, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 64, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, T.load("int8", ethosu_write_1.data, 10), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, 12, T.load("uint8", placeholder_d_global, 0), 64, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder_5.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, T.load("int8", ethosu_write_1.data, 10), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 272, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 64, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index 6a4aba4e38fc8..174e1bc761810 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -194,10 +194,10 @@ def main(input_buffer: T.Buffer[(1, 56, 56, 96), "int8"], output_buffer: T.Buffe T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", weight_buffer, 0), 2608, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", bias_buffer, 0), 240, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, T.load("int8", input_buffer.data, 0), 0, 0, 0, 
T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 2608, 12, T.load("uint8", placeholder_d_global, 0), 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, T.load("int8", input_buffer.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 2608, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", weight_buffer2, 0), 736, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", bias_buffer2, 0), 240, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 736, 12, T.load("uint8", placeholder_d_global, 0), 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 736, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 240, 
T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer, 0), 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", featuremap_buffer2, 0), 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, T.load("int8", output_buffer.data, 0), 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index add8021083c64..4a49682ced25b 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -39,7 +39,7 @@ def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Bu placeholder_4 = T.buffer_var("uint8", "") placeholder_5 = T.buffer_var("uint8", "") # body - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_4, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_4, 0), 0, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_5, 0), 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) # fmt: on 
@@ -58,10 +58,10 @@ def main(placeholder_6: T.Buffer[(1, 8, 8, 3), "int8"], ethosu_conv2d_1: T.Buffe # body ethosu_conv2d_2 = T.allocate([1024], "uint8", "global") ethosu_conv2d_3 = T.allocate([2048], "uint8", "global") - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7, 0), 0, 12, T.load("uint8", placeholder_8, 0), 0, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7, 0), 0, 12, T.load("uint8", placeholder_8, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9, 0), 0, 12, T.load("uint8", placeholder_5, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 
4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7, 0), 0, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_8, 0), 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9, 0), 0, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_5, 0), 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, T.load("uint8", placeholder_6.data, 96), 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_7, 0), 0, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_8, 0), 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, T.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, T.load("uint8", ethosu_conv2d_1.data, 256), 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_9, 0), 0, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_5, 0), 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="uint8")) # fmt: on @@ -80,7 +80,7 @@ def main(placeholder_3: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_conv2d_1: T.Bu placeholder_d_global = 
T.allocate([8], "int32", "global") T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", placeholder_4, 0), 256, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("int32", placeholder_5, 0), 8, T.load("int32", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 0, 12, T.load("uint8", placeholder_d_global, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, T.load("uint8", placeholder_3.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, T.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 0, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", dtype="handle")) # fmt: on @@ -114,16 +114,16 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer, 0), 128, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_1, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, 
T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 128, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 112, 
T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 112, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 
512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 112, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -161,19 +161,19 @@ def main(placeholder: T.Buffer[(1, 16, 16, 32), "int8"], ethosu_write: T.Buffer[ ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, 12, T.load("uint8", buffer_1, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, T.load("int8", placeholder.data, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", buffer, 0), 592, T.int8(-1), T.int8(-1), 12, T.load("uint8", buffer_1, 0), 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_2, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_3, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 
16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 0), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_4, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_5, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 2), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, T.int8(-1), T.int8(-1), 12, T.load("uint8", 
placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_6, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_7, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 4), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_8, 0), 80, T.load("uint8", placeholder_global, 0), dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", T.load("uint8", buffer_9, 0), 32, T.load("uint8", placeholder_d_global, 0), dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, 12, T.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, 
"TFL", "NONE", dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, T.load("int8", ethosu_write_1, 0), 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, T.load("int8", ethosu_write.data, 6), 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, T.load("uint8", placeholder_global, 0), 80, T.int8(-1), T.int8(-1), 12, T.load("uint8", placeholder_d_global, 0), 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", dtype="handle")) __tvm_meta__ = None # fmt: on @@ -255,7 +255,9 @@ def test_buffer_info_extraction(): buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"]) for buffer_var, info in buffer_info.items(): if buffer_var in test_case["param_dict"].keys(): - assert (info.values == test_case["param_dict"][buffer_var]).all() + assert ( + info.values.flatten() == test_case["param_dict"][buffer_var].flatten() + ).all() assert info.dtype == test_case["param_dict"][buffer_var].dtype info.btype == tir_to_cs_translator.BufferType.constant else: