diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py index 6f37b90f0f97..d83cd403ca14 100644 --- a/python/tvm/relay/backend/contrib/ethosu/legalize.py +++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py @@ -920,7 +920,7 @@ def callback( if axis == [1, 2] and params.keepdims: weight_scale = 1 - weight_values = np.ones([out_channels, filter_height, filter_width, in_channels]) + weight_values = np.ones([out_channels, filter_height, filter_width, 1]) scale_bias = vela_api.pack_biases( biases=np.zeros(ifm_shape[-1]), ifm_scale=params.ifm.q_params.scale_f32, @@ -985,7 +985,7 @@ def callback( ) else: weight_scale = 1 / (filter_height * filter_width) - weight_values = np.ones([out_channels, filter_height, filter_width, in_channels]) + weight_values = np.ones([out_channels, filter_height, filter_width, 1]) bias = -1 * int(params.ifm.q_params.zero_point) * filter_height * filter_width scale_bias = vela_api.pack_biases( diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py index 5a200fa1989b..27aa462d60f2 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -16,8 +16,10 @@ # under the License. # pylint: disable=invalid-name, unused-argument """Extract parameters from the convolution operators in TIR.""" +import math import tvm -from ..vela_api import SCALE_BIAS_LENGTH +from ethosu.vela import api as vapi +from ..vela_api import SCALE_BIAS_LENGTH, get_accelerator_config from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores from .dma import get_ifm_params, get_ofm_params from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution @@ -47,6 +49,8 @@ def get_conv2d_params(stmt, producers_consumers): Whether this operator allocates its output. """ + accel_config = get_accelerator_config() + attrs, body = get_op_attrs(stmt) _, _, _, _, _, inner = get_outer_loops(body, "NHWC") rh = inner @@ -75,17 +79,64 @@ def get_conv2d_params(stmt, producers_consumers): # Get scale_bias info scale_bias_load = loads[3] scale_bias_base = [get_base_address(index) for index in scale_bias_load.indices] - serial_scale_bias = SerialAddressRange( - address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base), - length=SCALE_BIAS_LENGTH * serial_ofm[3], - ) # Get weight info weight_load = loads[2] weight_base = [get_base_address(index) for index in weight_load.indices] - serial_weight = SerialAddressRange( - address=tvm.tir.BufferLoad(weight_load.buffer, weight_base), - length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1] * rc.extent, - ) + channels = serial_ofm[3] if isinstance(serial_ofm[3], int) else serial_ofm[3].value + + if accel_config == vapi.NpuAccelerator.Ethos_U65_512: + scale_bias_length = SCALE_BIAS_LENGTH * math.ceil(channels / 2) + scale_bias2_length = SCALE_BIAS_LENGTH * math.floor(channels / 2) + + serial_scale_bias = SerialAddressRange( + address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base), + length=scale_bias_length, + ) + serial_scale_bias2 = SerialAddressRange( + address=tvm.tir.BufferLoad( + scale_bias_load.buffer, [scale_bias_base[0] + scale_bias_length] + ), + length=scale_bias2_length, + ) + + weight_length = ( + channels * serial_kernel[0] * serial_kernel[1] * math.ceil(rc.extent.value / 2) + ) + weight2_length = ( + channels * serial_kernel[0] * serial_kernel[1] * math.floor(rc.extent.value / 2) + ) + + serial_weight = SerialAddressRange( + address=tvm.tir.BufferLoad(weight_load.buffer, weight_base), + length=weight_length, + ) + serial_weight2 = SerialAddressRange( + address=tvm.tir.BufferLoad(weight_load.buffer, [weight_base[0] + weight_length]), + length=weight2_length, + ) + else: + scale_bias_length = SCALE_BIAS_LENGTH * channels + + serial_scale_bias = SerialAddressRange( + address=tvm.tir.BufferLoad(scale_bias_load.buffer, scale_bias_base), + length=scale_bias_length, + ) + # Insert -1s into the spec to denote the absence of the other pointer + serial_scale_bias2 = SerialAddressRange( + address=tvm.tir.IntImm("int8", -1), + length=tvm.tir.IntImm("int8", -1), + ) + + weight_length = channels * serial_kernel[0] * serial_kernel[1] * rc.extent.value + + serial_weight = SerialAddressRange( + address=tvm.tir.BufferLoad(weight_load.buffer, weight_base), + length=weight_length, + ) + serial_weight2 = SerialAddressRange( + address=tvm.tir.IntImm("int8", -1), + length=tvm.tir.IntImm("int8", -1), + ) # Get activation info serial_activation = SerialActivation( op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] @@ -96,8 +147,10 @@ def get_conv2d_params(stmt, producers_consumers): ofm=serial_ofm, kernel=serial_kernel, weight=serial_weight, + weight2=serial_weight2, weight_zero_point=attrs["weight_zero_point"], scale_bias=serial_scale_bias, + scale_bias2=serial_scale_bias2, padding=serial_padding, activation=serial_activation, rounding_mode=attrs["rounding_mode"], diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py index a35d96a1e4e9..baadede08d66 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -14,13 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-argument, no-else-return, inconsistent-return-statements +# pylint: disable=invalid-name, unused-argument, no-else-return, inconsistent-return-statements, too-many-nested-blocks """The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler.""" from collections import namedtuple import numpy as np # type: ignore import tvm from tvm.relay.backend.contrib.ethosu import vela_api +from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs +from ethosu.vela import api as vapi from .convolution import get_conv2d_params from .depthwise import get_depthwise_conv2d_params from .pooling import get_pooling_params @@ -28,7 +30,6 @@ from .identity import get_identity_params from .unary_elementwise import get_unary_elementwise_params from .transform import get_copy_params -from .utils import get_weights_buffer, get_scale_bias_buffer from .producers_consumers import ProducersConsumers from .. import _ffi_api @@ -227,20 +228,33 @@ def DivideConstants(const_dict): def _visit(stmt): new_args = [] + # We don't want to divide the constant that will be executed on two cores in parallel + is_u65_conv2d = ( + vela_api.get_accelerator_config() == vapi.NpuAccelerator.Ethos_U65_512 + and stmt.args[0] == "ethosu_conv2d" + ) for i, arg in enumerate(stmt.args): if isinstance(arg, tvm.tir.expr.BufferLoad): # If we're trying to load a buffer that maps to a constant if arg.buffer.data in buffer_to_const: const = buffer_to_const[arg.buffer.data] - - assert len(arg.indices) == 1, "Ethos-U passes expects flattened buffers" + flattened_const_shape = np.prod(const.shape) offset = int(arg.indices[0]) # Note by convention the arg after a constant read is the length of the read length = int(stmt.args[i + 1]) # If it's anything other than a full read, create a new buffer - if offset != 0 or len(const) != length: - new_consts.append(const[offset : offset + length]) + if (offset != 0 or flattened_const_shape != length) and not is_u65_conv2d: + out_channels = const.shape[0] + offset_channels = int((offset * out_channels) / flattened_const_shape) + length_channels = int((length * out_channels) / flattened_const_shape) + # split the constant up across channels + split_const = np.split(const, out_channels, axis=0) + # create a new const out of the channels we want to keep + new_const = np.concatenate( + split_const[offset_channels : offset_channels + length_channels], axis=0 + ) + new_consts.append(new_const) new_buffer = tvm.tir.decl_buffer( (length,), arg.dtype, scope=arg.buffer.scope() ) @@ -256,8 +270,8 @@ def _visit(stmt): def _ftransform(f, mod, ctx): for i, param in enumerate(f.params): if i in const_dict: - buffer_to_const[param] = const_dict[i].flatten() - buffer_to_const[f.buffer_map[param].data] = const_dict[i].flatten() + buffer_to_const[param] = const_dict[i] + buffer_to_const[f.buffer_map[param].data] = const_dict[i] new_body = tvm.tir.stmt_functor.ir_transform(f.body, _visit, None, ["tir.Call"]) # Both the params and buffer map need updating for the newly introduced buffers @@ -339,7 +353,7 @@ def _encode_weights(tir_extern_call, weights): value = np.frombuffer(value_bytes, dtype="uint8") return value - def _declare_constant_buffer(old_buffer, encoded_constants): + def _declare_constant_buffer(old_buffer, encoded_constants, split_idx): """Create a new buffer and add the old buffer and its pointer to the rewriting maps.""" new_buffer = tvm.tir.decl_buffer( @@ -354,14 +368,47 @@ def _declare_constant_buffer(old_buffer, encoded_constants): "old_buffer": old_buffer, "new_buffer": new_buffer, "encoded_constants": encoded_constants, + "split_idx": split_idx, } ) + def _encode_weights_or_bias(buffer1, buffer2, stmt, encode_func): + """Encode the weights or align the bias either for one or two cores, + depending on the variant.""" + constant = old_buffer_to_const[buffer1] + + # If we have just one core, encode the whole constant + if buffer2 is None: + new_const = encode_func(stmt, constant) + return new_const, None + + # Assume that the constant tensor has not been flattened yet + assert len(constant.shape) != 1 + channels = constant.shape[0] + split_const = np.split(constant, channels, axis=0) + + const_list = [split_const[i] for i in range(channels) if i % 2 == 0] + const_to_encode = np.concatenate(const_list, axis=0) + + new_const = encode_func(stmt, const_to_encode) + split_idx = len(new_const) + + # Encode half of the constant separately for the other core if it exists + assert buffer1.same_as(buffer2) + const2_list = [split_const[i] for i in range(channels) if i % 2 == 1] + const2_to_encode = np.concatenate(const2_list, axis=0) + + new_const2 = encode_func(stmt, const2_to_encode) + new_const = np.append(new_const, new_const2).astype("uint8") + + return new_const, split_idx + def _visit(stmt): if isinstance(stmt, tvm.tir.Call): + op = str(stmt.args[0].value) # Handle copies as a special-case by propagating the buffer information # from the read to the write pointer. - if stmt.args[0] == "ethosu_copy": + if op == "ethosu_copy": read_buffer = stmt.args[1].buffer write_buffer = stmt.args[3].buffer # Assert writing to the base of the write_var (pre-StorageRewrite) @@ -370,24 +417,50 @@ def _visit(stmt): copied_buffers.append({"source": read_buffer, "dest": write_buffer}) copy_map[write_buffer] = read_buffer - else: + ops_with_weights = { + "ethosu_conv2d": tirtocs.translate_ethosu_conv2d, + "ethosu_depthwise_conv2d": tirtocs.translate_ethosu_depthwise_conv2d, + } + if op in ops_with_weights: + npu_op, _ = ops_with_weights[op](stmt) + # Encode the weights - weights_buffer = get_weights_buffer(stmt) - if weights_buffer is not None: - if weights_buffer in copy_map: - weights_buffer = copy_map[weights_buffer] - unencoded_weights_value = old_buffer_to_const[weights_buffer] - encoded_weights_value = _encode_weights(stmt, unencoded_weights_value) - _declare_constant_buffer(weights_buffer, encoded_weights_value) + weights_buffer = npu_op.weights[0].address.buffer + if weights_buffer in copy_map: + weights_buffer = copy_map[weights_buffer] + + # In case of U65 512 mac variant the weights are split across two cores + # and need to be encoded separately + weights2_buffer = ( + npu_op.weights[1].address.buffer + if accel_config == vapi.NpuAccelerator.Ethos_U65_512 + else None + ) + if weights2_buffer in copy_map: + weights2_buffer = copy_map[weights2_buffer] + + new_weights, split_idx = _encode_weights_or_bias( + weights_buffer, weights2_buffer, stmt, _encode_weights + ) + _declare_constant_buffer(weights_buffer, new_weights, split_idx) # Align the scale_bias to 16 bytes - scale_bias_buffer = get_scale_bias_buffer(stmt) - if scale_bias_buffer is not None: - if scale_bias_buffer in copy_map: - scale_bias_buffer = copy_map[scale_bias_buffer] - scale_bias_value = old_buffer_to_const[scale_bias_buffer] - aligned_scale_bias_value = _align_scale_bias(stmt, scale_bias_value) - _declare_constant_buffer(scale_bias_buffer, aligned_scale_bias_value) + scale_bias_buffer = npu_op.biases[0].address.buffer + if scale_bias_buffer in copy_map: + scale_bias_buffer = copy_map[scale_bias_buffer] + scale_bias2_buffer = ( + npu_op.biases[1].address.buffer + if accel_config == vapi.NpuAccelerator.Ethos_U65_512 + else None + ) + if scale_bias2_buffer in copy_map: + scale_bias2_buffer = copy_map[scale_bias2_buffer] + + new_scale_bias, split_idx = _encode_weights_or_bias( + scale_bias_buffer, scale_bias2_buffer, stmt, _align_scale_bias + ) + + _declare_constant_buffer(scale_bias_buffer, new_scale_bias, split_idx) tvm.tir.stmt_functor.post_order_visit(stmt, _visit) @@ -396,7 +469,9 @@ def _visit(stmt): "constant_buffer_replacements": constant_buffer_replacements, } - def transform_stmt(stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const): + def transform_stmt( + stmt, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const, new_buffer_to_split_idx + ): def _visit_rewrite(stmt): if isinstance(stmt, tvm.tir.Call): # For extern calls, we need to rewrite pairs of arguments corresponding to @@ -411,7 +486,19 @@ def _visit_rewrite(stmt): isinstance(prev_arg, tvm.tir.BufferLoad) and prev_arg.buffer in new_buffer_to_const ): - arg = np.prod(list(prev_arg.buffer.shape)) + buffer_size = np.prod(list(prev_arg.buffer.shape)) + arg = buffer_size + # We have to check for split weights/bias for conv2d and depthwise_conv2d + if old_args[0] in ("ethosu_conv2d", "depthwise_conv2d"): + # We have split weights/bias + if prev_arg.buffer in new_buffer_to_split_idx: + split_idx = new_buffer_to_split_idx[prev_arg.buffer] + # The first half of the split buffer + if prev_arg.indices[0] == 0: + arg = split_idx + # the second half of the split buffer + else: + arg = buffer_size - split_idx new_args.append(arg) @@ -439,7 +526,12 @@ def _visit_rewrite(stmt): # rewrite the nodes which contain the Buffers. if isinstance(stmt, tvm.tir.BufferLoad): if stmt.buffer in buf_remap: - return tvm.tir.BufferLoad(buf_remap[stmt.buffer], stmt.indices, stmt.span) + new_buffer = buf_remap[stmt.buffer] + new_indices = stmt.indices + offset = new_indices[0] + if offset != 0 and new_buffer in new_buffer_to_split_idx: + offset = new_buffer_to_split_idx[new_buffer] + return tvm.tir.BufferLoad(buf_remap[stmt.buffer], [offset], stmt.span) if isinstance(stmt, tvm.tir.AttrStmt): node_pointer = stmt.node @@ -467,7 +559,7 @@ def _ftransform(f, mod, ctx): old_buffer_to_const = {} for i, param in enumerate(f.params): if i in const_dict: - old_buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten() + old_buffer_to_const[f.buffer_map[param]] = const_dict[i] # Step 1: Collect information on the buffers that will be # replaced by encodings. @@ -477,12 +569,16 @@ def _ftransform(f, mod, ctx): # collected information. buf_remap = {} new_buffer_to_const = {} + new_buffer_to_split_idx = {} # Any encoded buffers must be replaced for info in buffer_information["constant_buffer_replacements"]: buf_remap[info["old_buffer"]] = info["new_buffer"] new_buffer_to_const[info["new_buffer"]] = info["encoded_constants"] + if info["split_idx"]: + new_buffer_to_split_idx[info["new_buffer"]] = info["split_idx"] + # Any buffers that are copied into from an encoded buffer must # be replaced. for info in buffer_information["copied_buffers"]: @@ -503,6 +599,9 @@ def _ftransform(f, mod, ctx): if copy_source in new_buffer_to_const: new_buffer_to_const[new_dest] = new_buffer_to_const[copy_source] + if copy_source in new_buffer_to_split_idx: + new_buffer_to_split_idx[new_dest] = new_buffer_to_split_idx[copy_source] + # Define additional dependent lookup tables. var_remap = {old.data: new.data for (old, new) in buf_remap.items()} pointer_to_buffer = { @@ -511,7 +610,12 @@ def _ftransform(f, mod, ctx): # Step 3: Then perform the rewrites new_body = transform_stmt( - f.body, buf_remap, var_remap, pointer_to_buffer, new_buffer_to_const + f.body, + buf_remap, + var_remap, + pointer_to_buffer, + new_buffer_to_const, + new_buffer_to_split_idx, ) # Step 4: Rewrite the buffer map and const dict to instead use the encoded versions @@ -522,9 +626,9 @@ def _ftransform(f, mod, ctx): buffer = buf_remap[buffer] if buffer in new_buffer_to_const: - new_const_dict[i] = new_buffer_to_const[buffer] + new_const_dict[i] = new_buffer_to_const[buffer].flatten() elif buffer in old_buffer_to_const: - new_const_dict[i] = old_buffer_to_const[buffer] + new_const_dict[i] = old_buffer_to_const[buffer].flatten() new_buffer_map[param] = buffer diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py index 8234c90c4ae1..1c647f912013 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py @@ -194,8 +194,10 @@ def __init__( ofm: SerialFeatureMap, kernel: SerialKernel, weight: SerialAddressRange, + weight2: SerialAddressRange, weight_zero_point: int, scale_bias: SerialAddressRange, + scale_bias2: SerialAddressRange, padding: SerialPadding, activation: SerialActivation, rounding_mode: str, @@ -206,8 +208,10 @@ def __init__( self.ofm = ofm self.kernel = kernel self.weight = weight + self.weight2 = weight2 self.weight_zero_point = weight_zero_point self.scale_bias = scale_bias + self.scale_bias2 = scale_bias2 self.padding = padding self.activation = activation self.rounding_mode = rounding_mode diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py index 506f18ba3a99..a823667234df 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -20,24 +20,6 @@ from tvm import arith -# TODO(@mbaret): Formalise this with a specification -def get_weights_buffer(tir_extern_call): - """Get the weights pointer from a NPU extern call if it exists""" - supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] - if tir_extern_call.args[0] in supported_ops: - return tir_extern_call.args[41].buffer - return None - - -# TODO(@mbaret): Formalise this with a specification -def get_scale_bias_buffer(tir_extern_call): - """Get the scale_bias pointer from a NPU extern call if it exists""" - supported_ops = ["ethosu_conv2d", "ethosu_depthwise_conv2d"] - if tir_extern_call.args[0] in supported_ops: - return tir_extern_call.args[44].buffer - return None - - def get_op_attrs(stmt): """Iterate through nested attribute statements accumulating their values in an attribute dictionary. diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index 58ac2d4fba9d..a3d46170dfca 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -441,6 +441,7 @@ def replace_npu_address_range_with_address(npu_addr_range): ) assert buffer in buffer_addresses.keys(), f"searching for buffer : {buffer}, but not found" address, buffer_type = buffer_addresses[buffer] + address = address + int(npu_addr_range.address.indices[0].value) return vapi.NpuAddressRange(_get_region(buffer_type), address, npu_addr_range.length) def replace_tir_loads(npu_object): @@ -606,13 +607,30 @@ def _create_npu_op_conv2d( """This is a helper function to capture a list of arguments to create Vela NpuConv2DOperation object. """ + has_two_weights = serial_2d_convolution.weight2.address != -1 + has_two_biases = serial_2d_convolution.scale_bias2.address != -1 + npu_conv2d_op = vapi.NpuConv2DOperation() npu_conv2d_op.ifm = _create_npu_feature_map(serial_2d_convolution.ifm) npu_conv2d_op.ofm = _create_npu_feature_map(serial_2d_convolution.ofm) npu_conv2d_op.kernel = _create_npu_kernel(serial_2d_convolution.kernel) - npu_conv2d_op.weights = [_create_npu_address_range(serial_2d_convolution.weight)] + npu_conv2d_op.weights = ( + [ + _create_npu_address_range(serial_2d_convolution.weight), + _create_npu_address_range(serial_2d_convolution.weight2), + ] + if has_two_weights + else [_create_npu_address_range(serial_2d_convolution.weight)] + ) weights_zero_point = np.int64(serial_2d_convolution.weight_zero_point.value) - npu_conv2d_op.biases = [_create_npu_address_range(serial_2d_convolution.scale_bias)] + npu_conv2d_op.biases = ( + [ + _create_npu_address_range(serial_2d_convolution.scale_bias), + _create_npu_address_range(serial_2d_convolution.scale_bias2), + ] + if has_two_biases + else [_create_npu_address_range(serial_2d_convolution.scale_bias)] + ) npu_conv2d_op.padding = _create_npu_padding(serial_2d_convolution.padding) npu_conv2d_op.activation = _create_npu_activation(serial_2d_convolution.activation) diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 0eb8ab0cf8f2..6d01e8de57b5 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -140,17 +140,18 @@ def encode_weights( op = str(tir_extern_call.args[0].value) assert op in supported_ops.keys() npu_op, weights_zero_point = supported_ops[op](tir_extern_call) - # The weight layout is assumed to be flat OHWI, always. - assert len(values.shape) == 1 is_depthwise = op == "ethosu_depthwise_conv2d" - shape_ohwi = ( - npu_op.ofm.shape.depth, - npu_op.kernel.height, - npu_op.kernel.width, - 1 if is_depthwise else npu_op.ifm.shape.depth, - ) - assert values.size == np.prod(shape_ohwi) - values = np.reshape(values, shape_ohwi) + # Recover the original shape if we are dealing with a flattened tensor + if len(values.shape) == 1: + shape_ohwi = ( + npu_op.ofm.shape.depth, + npu_op.kernel.height, + npu_op.kernel.width, + 1 if is_depthwise else npu_op.ifm.shape.depth, + ) + assert values.size == np.prod(shape_ohwi) + values = np.reshape(values, shape_ohwi) + return compress_weights( weights=values, weights_zp=weights_zero_point, @@ -216,6 +217,7 @@ def compress_weights( weights.shape[layout_transform_indices[weights_layout][3]], ] block_traversal = calculate_block_traversal_mode(is_depthwise, shape_ohwi, ifm_bitdepth) + compressed_weights = vapi.npu_encode_weights( accelerator=accel_config, weights_volume=weights_ohwi, @@ -388,6 +390,7 @@ def get_accelerator_config() -> vapi.NpuAccelerator: "ethos-u55-64": vapi.NpuAccelerator.Ethos_U55_64, "ethos-u55-32": vapi.NpuAccelerator.Ethos_U55_32, "ethos-u65-256": vapi.NpuAccelerator.Ethos_U65_256, + "ethos-u65-512": vapi.NpuAccelerator.Ethos_U65_512, } compiler_attrs = tvm.get_global_func("relay.ext.ethos-u.get_compiler_attrs")() accel_config_str = compiler_attrs.accelerator_config diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py index d38f4abbc547..4268392f1b78 100644 --- a/tests/python/contrib/test_ethosu/test_codegen.py +++ b/tests/python/contrib/test_ethosu/test_codegen.py @@ -37,11 +37,11 @@ ACCEL_TYPES = ["ethos-u55-256", "ethos-u55-128", "ethos-u55-64", "ethos-u55-32", "ethos-u65-256"] +@pytest.mark.parametrize("accel_type", ACCEL_TYPES + ["ethos-u65-512"]) @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 2), (1, 55, 55, 3)]) @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) -@pytest.mark.parametrize("accel_type", ACCEL_TYPES) @pytest.mark.parametrize("activation", ["NONE", "RELU"]) def test_ethosu_conv2d_single( ifm_shape, @@ -79,7 +79,7 @@ def conv2d(x): @pytest.mark.parametrize("kernel_shape", [(3, 2), (1, 3)]) @pytest.mark.parametrize("strides, dilation", [((1, 1), (2, 1)), ((3, 2), (1, 1))]) @pytest.mark.parametrize("padding", ["SAME", "VALID"]) -@pytest.mark.parametrize("accel_type", ACCEL_TYPES) +@pytest.mark.parametrize("accel_type", ACCEL_TYPES + ["ethos-u65-512"]) @pytest.mark.parametrize("activation", ["NONE", "RELU"]) def test_ethosu_conv2d_double( ifm_shape, diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py index 57bcf0881886..457edc861d06 100644 --- a/tests/python/contrib/test_ethosu/test_encode_constants.py +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -32,7 +32,7 @@ # fmt: off @tvm.script.ir_module -class WeightStreamOnly: +class WeightStreamOnlyU55: @T.prim_func def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict @@ -54,21 +54,77 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), p2_global_1 = T.buffer_decl([32], dtype="uint8", data=p2_global.data) T.evaluate(T.call_extern("ethosu_copy", buffer[0], 128, p1_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 32, p2_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global[0], 128, 12, p2_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global[0], 128, T.int8(-1), T.int8(-1), 12, p2_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 112, p1_global_1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, p2_global_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, 12, p2_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, T.int8(-1), T.int8(-1), 12, p2_global_1[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 112, p1_global_1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, p2_global_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, 12, p2_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, T.int8(-1), T.int8(-1), 12, p2_global_1[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 112, p1_global_1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, p2_global_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, 12, p2_global_1[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, p1_global_1[0], 112, T.int8(-1), T.int8(-1), 12, p2_global_1[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.ir_module +class WeightStreamOnlyU65: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + buffer_encoded = T.buffer_decl([160], dtype="uint8") + buffer_encoded_1 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_2 = T.buffer_decl([160], dtype="uint8") + buffer_encoded_3 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_4 = T.buffer_decl([176], dtype="uint8") + buffer_encoded_5 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_6 = T.buffer_decl([160], dtype="uint8") + buffer_encoded_7 = T.buffer_decl([32], dtype="uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + # body + placeholder_global = T.allocate([176], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global_1 = T.buffer_decl([160], dtype="uint8", data=placeholder_global.data) + placeholder_global_2 = T.buffer_decl([160], dtype="uint8", data=placeholder_global.data) + placeholder_global_3 = T.buffer_decl([160], dtype="uint8", data=placeholder_global.data) + placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_d_global_1 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + placeholder_d_global_2 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + placeholder_d_global_3 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 160, placeholder_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 80, placeholder_global_1[80], 80, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2[0], 160, placeholder_global_2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3[0], 32, placeholder_d_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 80, placeholder_global_2[80], 80, 12, placeholder_d_global_1[0], 16, placeholder_d_global_1[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4[0], 176, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5[0], 32, placeholder_d_global_2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 96, placeholder_global[96], 80, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6[0], 160, placeholder_global_3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7[0], 32, placeholder_d_global_3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_3[0], 80, placeholder_global_3[80], 80, 12, placeholder_d_global_3[0], 16, placeholder_d_global_3[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on -def test_weight_stream_only(): +@pytest.mark.parametrize( + "accelerator, reference_mod, reference_const_sizes", + [ + ( + "ethos-u55-128", + WeightStreamOnlyU55, + [128, 32, 112, 32, 112, 32, 112, 32], + ), + ( + "ethos-u65-512", + WeightStreamOnlyU65, + [160, 32, 160, 32, 176, 32, 160, 32], + ), + ], +) +def test_weight_stream_only(accelerator, reference_mod, reference_const_sizes): def _planner(cached_func, const_dict, sch): weights = cached_func.inputs[1] bias = cached_func.inputs[2] @@ -95,21 +151,23 @@ def _get_func(): func = run_opt_pass(func, relay.transform.InferType()) return func - func = _get_func() - mod, consts = _lower_to_tir(func, cascader=_planner) - script = mod.script(show_meta=True) - test_mod = tvm.script.from_source(script) - reference_mod = WeightStreamOnly - tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + config = { + "accelerator_config": accelerator, + } + with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): + func = _get_func() + mod, consts = _lower_to_tir(func, cascader=_planner) + script = mod.script(show_meta=True) + test_mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) - reference_const_sizes = [128, 32, 112, 32, 112, 32, 112, 32] - test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + test_const_size = [value.size for value in list(consts.values())] + assert reference_const_sizes == test_const_size # fmt: off @tvm.script.ir_module -class RereadWeights: +class RereadWeightsU55: @T.prim_func def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict @@ -123,15 +181,56 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_copy", buffer[0], 304, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 80, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer[0], 304, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 80, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.ir_module +class RereadWeightsU65: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + placeholder_encoded = T.buffer_decl([368], dtype="uint8") + placeholder_encoded_1 = T.buffer_decl([96], dtype="uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + # body + placeholder_global = T.allocate([368], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global_1 = T.buffer_decl([368], dtype="uint8", data=placeholder_global.data) + placeholder_d_global = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_d_global_1 = T.buffer_decl([96], dtype="uint8", data=placeholder_d_global.data) + T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded[0], 368, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 96, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 192, placeholder_global[192], 176, 12, placeholder_d_global[0], 48, placeholder_d_global[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded[0], 368, placeholder_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", placeholder_encoded_1[0], 96, placeholder_d_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 8, 32, 16, 0, 8, placeholder[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 8, 8, 16, 0, 8, ethosu_write[64], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 192, placeholder_global_1[192], 176, 12, placeholder_d_global_1[0], 48, placeholder_d_global_1[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None # fmt: on -def test_re_read_weights(): +@pytest.mark.parametrize( + "accelerator, reference_mod, reference_const_sizes", + [ + ( + "ethos-u55-128", + RereadWeightsU55, + [304, 80], + ), + ( + "ethos-u65-512", + RereadWeightsU65, + [368, 96], + ), + ], +) +def test_re_read_weights(accelerator, reference_mod, reference_const_sizes): def _cascader(cached_func, const_dict, sch): weights = cached_func.inputs[1] bias = cached_func.inputs[2] @@ -158,21 +257,23 @@ def _get_func(): func = run_opt_pass(func, relay.transform.InferType()) return func - func = _get_func() - mod, consts = _lower_to_tir(func, cascader=_cascader) - script = mod.script(show_meta=True) - test_mod = tvm.script.from_source(script) - reference_mod = RereadWeights - tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + config = { + "accelerator_config": accelerator, + } + with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): + func = _get_func() + mod, consts = _lower_to_tir(func, cascader=_cascader) + script = mod.script(show_meta=True) + test_mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) - reference_const_sizes = [304, 80] - test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + test_const_size = [value.size for value in list(consts.values())] + assert reference_const_sizes == test_const_size # fmt: off @tvm.script.ir_module -class DirectReadOnly: +class DirectReadOnlyU55: @T.prim_func def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict @@ -185,13 +286,48 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], "int8", data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, 12, buffer_1[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 160, 12, buffer_3[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 160, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.ir_module +class DirectReadOnlyU65: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + placeholder_encoded = T.buffer_decl([608], dtype="uint8") + placeholder_encoded_1 = T.buffer_decl([160], dtype="uint8") + placeholder_encoded_2 = T.buffer_decl([208], dtype="uint8") + placeholder_encoded_3 = T.buffer_decl([96], dtype="uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + # body + ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded_2[0], 112, placeholder_encoded_2[112], 96, 12, placeholder_encoded_3[0], 48, placeholder_encoded_3[48], 48, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on -def test_direct_read_only(): +@pytest.mark.parametrize( + "accelerator, reference_mod, reference_const_sizes", + [ + ( + "ethos-u55-128", + DirectReadOnlyU55, + [592, 160, 160, 80], + ), + ( + "ethos-u65-512", + DirectReadOnlyU65, + [608, 160, 208, 96], + ), + ], +) +def test_direct_read_only(accelerator, reference_mod, reference_const_sizes): def _get_func(): ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") conv1 = make_ethosu_conv2d( @@ -216,22 +352,25 @@ def _get_func(): func = run_opt_pass(func, relay.transform.InferType()) return func - func = _get_func() - mod, consts = _lower_to_tir(func) + config = { + "accelerator_config": accelerator, + } + with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): + func = _get_func() + mod, consts = _lower_to_tir(func) - script = mod.script(show_meta=True) - test_mod = tvm.script.from_source(script) - reference_mod = DirectReadOnly - tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + print(mod.script()) + script = mod.script(show_meta=True) + test_mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) - reference_const_sizes = [592, 160, 160, 80] - test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + test_const_size = [value.size for value in list(consts.values())] + assert reference_const_sizes == test_const_size # fmt: off @tvm.script.ir_module -class MixedRead: +class MixedReadU55: @T.prim_func def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: # function attr dict @@ -252,24 +391,84 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, 12, buffer_1[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 80, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 80, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 80, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_8[0], 80, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_9[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.ir_module +class MixedReadU65: + @T.prim_func + def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), "int8"]) -> None: + # function attr dict + T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + # buffer definition + buffer_encoded = T.buffer_decl([96], dtype="uint8") + buffer_encoded_1 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_2 = T.buffer_decl([96], dtype="uint8") + buffer_encoded_3 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_4 = T.buffer_decl([96], dtype="uint8") + buffer_encoded_5 = T.buffer_decl([32], dtype="uint8") + buffer_encoded_6 = T.buffer_decl([96], dtype="uint8") + buffer_encoded_7 = T.buffer_decl([32], dtype="uint8") + placeholder_encoded = T.buffer_decl([608], dtype="uint8") + placeholder_encoded_1 = T.buffer_decl([160], dtype="uint8") + T.preflattened_buffer(placeholder, [1, 16, 16, 32], dtype="int8", data=placeholder.data) + T.preflattened_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", data=ethosu_write.data) + # body + ethosu_write_2 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global = T.allocate([96], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_global_1 = T.buffer_decl([96], dtype="uint8", data=placeholder_global.data) + placeholder_global_2 = T.buffer_decl([96], dtype="uint8", data=placeholder_global.data) + placeholder_global_3 = T.buffer_decl([96], dtype="uint8", data=placeholder_global.data) + placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) + placeholder_d_global_1 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + placeholder_d_global_2 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + placeholder_d_global_3 = T.buffer_decl([32], dtype="uint8", data=placeholder_d_global.data) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_encoded[0], 304, placeholder_encoded[304], 304, 12, placeholder_encoded_1[0], 80, placeholder_encoded_1[80], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded[0], 96, placeholder_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_1[0], 32, placeholder_d_global[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 48, placeholder_global[48], 48, 12, placeholder_d_global[0], 16, placeholder_d_global[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_2[0], 96, placeholder_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_3[0], 32, placeholder_d_global_1[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_1[0], 48, placeholder_global_1[48], 48, 12, placeholder_d_global_1[0], 16, placeholder_d_global_1[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_4[0], 96, placeholder_global_2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_5[0], 32, placeholder_d_global_2[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_2[0], 48, placeholder_global_2[48], 48, 12, placeholder_d_global_2[0], 16, placeholder_d_global_2[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_6[0], 96, placeholder_global_3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_copy", buffer_encoded_7[0], 32, placeholder_d_global_3[0], dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global_3[0], 48, placeholder_global_3[48], 48, 12, placeholder_d_global_3[0], 16, placeholder_d_global_3[16], 16, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on -def test_mixed_read(): +@pytest.mark.parametrize( + "accelerator, reference_mod, reference_const_sizes", + [ + ( + "ethos-u55-128", + MixedReadU55, + [592, 160, 80, 32, 80, 32, 80, 32, 80, 32], + ), + ( + "ethos-u65-512", + MixedReadU65, + [608, 160, 96, 32, 96, 32, 96, 32, 96, 32], + ), + ], +) +def test_mixed_read(accelerator, reference_mod, reference_const_sizes): def _planner(cached_func, const_dict, sch): weight = cached_func.inputs[4] scale_bias = cached_func.inputs[5] @@ -305,28 +504,20 @@ def _get_func(): func = run_opt_pass(func, relay.transform.InferType()) return func - func = _get_func() - mod, consts = _lower_to_tir(func, cascader=_planner) - - script = mod.script(show_meta=True) - test_mod = tvm.script.from_source(script) - reference_mod = MixedRead - tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) - - reference_const_sizes = [ - 592, - 160, - 80, - 32, - 80, - 32, - 80, - 32, - 80, - 32, - ] - test_const_size = [value.size for value in list(consts.values())] - assert reference_const_sizes == test_const_size + config = { + "accelerator_config": accelerator, + } + with tvm.transform.PassContext(config={"relay.ext.ethos-u.options": config}): + func = _get_func() + mod, consts = _lower_to_tir(func, cascader=_planner) + + script = mod.script(show_meta=True) + test_mod = tvm.script.from_source(script) + print(mod.script()) + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + test_const_size = [value.size for value in list(consts.values())] + assert reference_const_sizes == test_const_size def test_constant_as_input(): diff --git a/tests/python/contrib/test_ethosu/test_remove_concatenates.py b/tests/python/contrib/test_ethosu/test_remove_concatenates.py index b92b70657e8b..cc996e59412c 100644 --- a/tests/python/contrib/test_ethosu/test_remove_concatenates.py +++ b/tests/python/contrib/test_ethosu/test_remove_concatenates.py @@ -46,10 +46,10 @@ def main(placeholder: T.Buffer[(1536,), "int8"], placeholder_1: T.Buffer[(1280,) T.preflattened_buffer(T_concat, [1, 8, 32, 16], "int8", data=T_concat.data) # body T_concat_1 = T.allocate([2816], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, 12, buffer_1[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat[352], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 2992, 12, buffer_3[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T_concat_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer_4[0], 2992, 12, buffer_5[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T_concat_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T_concat[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_6[0], 2992, 12, buffer_7[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, placeholder_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 160, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 10, 16, 8, 0, 10, T_concat_1[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 10, 16, 8, 0, 10, T_concat[352], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_3[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 12, 16, 8, 0, 12, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 192, 16, 1, "int8", 8, 12, 16, 8, 0, 12, T_concat_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 352, 16, 1, 3, 3, 1, 1, 1, 1, buffer_4[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_5[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 22, 16, 8, 0, 22, T_concat_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 352, 16, 1, "int8", 8, 22, 16, 8, 0, 22, T_concat[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 16, 1, 3, 3, 1, 1, 1, 1, buffer_6[0], 2992, T.int8(-1), T.int8(-1), 12, buffer_7[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py index ca2c0608e9d2..b49890a9cf36 100644 --- a/tests/python/contrib/test_ethosu/test_replace_conv2d.py +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -23,7 +23,7 @@ from tvm.relay.testing import run_opt_pass from tvm.relay.backend.contrib.ethosu.tir.compiler import _lower_to_tir from tvm.relay.backend.contrib.ethosu.tir.scheduler import total_cascader -from .infra import make_ethosu_conv2d, get_convolutional_args +from .infra import make_ethosu_conv2d def _create_serial_conv2d_params( @@ -132,6 +132,28 @@ def _create_serial_conv2d_params( ] +def get_conv2d_args(call, include_buffers=False, remove_constants=False): + """A method to extract the arguments from conv2d extern call.""" + args = call.args + conv_args = [] + remove_indices = [0] + + if remove_constants: + remove_indices += [41, 42, 43, 44, 46, 47, 48, 49] + + for i, arg in enumerate(args): + if i in remove_indices: + continue + elif isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): + conv_args.append(arg.value) + elif isinstance(arg, tvm.tir.expr.BufferLoad) and not include_buffers: + conv_args.append(arg.indices[0]) + else: + conv_args.append(arg) + + return conv_args + + @pytest.mark.parametrize( "trial", [ @@ -324,7 +346,7 @@ def _get_func( def _visit(stmt): if isinstance(stmt, tvm.tir.Call): - data.append(get_convolutional_args(stmt, remove_constants=True)) + data.append(get_conv2d_args(stmt, remove_constants=True)) tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) @@ -347,10 +369,10 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1024], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, 12, buffer_2[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, 12, buffer_1[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[12], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, 12, buffer_2[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[32], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, 12, buffer_1[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, placeholder_5[12], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, buffer_3[0], 160, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[32], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, buffer[0], 304, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -368,10 +390,10 @@ def main(placeholder_5: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(512, T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([1536], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, 12, buffer_1[0], 320, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, 12, buffer[0], 80, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[48], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, 12, buffer_1[0], 320, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, 12, buffer[0], 80, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[256], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, T.int8(-1), T.int8(-1), 12, buffer[0], 80, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[48], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, buffer_2[0], 1312, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, ethosu_write_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, buffer_3[0], 2608, T.int8(-1), T.int8(-1), 12, buffer[0], 80, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -389,12 +411,12 @@ def main(placeholder_5: T.Buffer[(768,), "int8"], ethosu_write_1: T.Buffer[(640, T.preflattened_buffer(ethosu_write_1, [1, 20, 4, 8], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([2560], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, 12, buffer_2[0], 320, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, 12, buffer_1[0], 80, 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, placeholder_5[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, 12, buffer_2[0], 320, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, 12, buffer_1[0], 80, 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, placeholder_5[576], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, 12, buffer_2[0], 320, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, ethosu_write_1[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, 12, buffer_1[0], 80, 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, ethosu_write_2[512], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 2, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, placeholder_5[192], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, ethosu_write_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 1, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, placeholder_5[576], 0, 0, 0, T.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, buffer_3[0], 880, T.int8(-1), T.int8(-1), 12, buffer_2[0], 320, T.int8(-1), T.int8(-1), 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, ethosu_write_1[512], 0, 0, 0, T.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, buffer[0], 1744, T.int8(-1), T.int8(-1), 12, buffer_1[0], 80, T.int8(-1), T.int8(-1), 0, 1, 2, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -412,10 +434,10 @@ def main(placeholder_5: T.Buffer[(1024,), "int8"], ethosu_write_1: T.Buffer[(204 T.preflattened_buffer(ethosu_write_1, [1, 8, 2, 8, 16], 'int8', data=ethosu_write_1.data) # body ethosu_write_2 = T.allocate([2304], "int8", "global", annotations={"disable_lower_builtin": True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, 12, buffer_1[0], 352, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, 12, buffer_2[0], 272, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[256], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, 12, buffer_1[0], 352, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[1024], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, 12, buffer_2[0], 272, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[384], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_2[0], 272, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, placeholder_5[256], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, ethosu_write_2[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, ethosu_write_1[1024], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, buffer_3[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_2[0], 272, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -433,10 +455,10 @@ def main(placeholder: T.Buffer[(192,), "int8"], ethosu_write: T.Buffer[(8192,), T.preflattened_buffer(ethosu_write, [1, 32, 32, 8], 'int8', data=ethosu_write.data) # body ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, 12, buffer_1[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, 12, buffer_3[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, 12, buffer_1[0], 320, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[4096], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, 12, buffer_3[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 4, 8, 3, 4, 0, 8, placeholder[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 512, 32, 1, 1, 1, 1, 1, 1, 1, buffer[0], 160, T.int8(-1), T.int8(-1), 12, buffer_1[0], 320, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 16, 32, 8, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 32, 8, 16, 0, 32, ethosu_write[4096], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 8, 1, 1, 1, 1, 1, 1, 1, buffer_2[0], 304, T.int8(-1), T.int8(-1), 12, buffer_3[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "ZEROS", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -454,8 +476,8 @@ def main(placeholder: T.Buffer[(1024,), "int8"], ethosu_write: T.Buffer[(32768,) T.preflattened_buffer(ethosu_write, [1, 32, 2, 32, 16], 'int8', data=ethosu_write.data) # body ethosu_write_1 = T.allocate([12288], "int8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, 12, buffer_1[0], 352, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, buffer_2[0], 11040, 12, buffer_3[0], 272, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 3, 8, 0, 8, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 768, 16, 256, 3, 3, 1, 1, 1, 1, buffer[0], 1456, T.int8(-1), T.int8(-1), 12, buffer_1[0], 352, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 35, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHCWB16", 768, 16, 256, "int8", 32, 32, 26, 32, 0, 32, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHCWB16", 1024, 16, 512, 3, 3, 1, 1, 1, 1, buffer_2[0], 11040, T.int8(-1), T.int8(-1), 12, buffer_3[0], 272, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NEAREST", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -614,7 +636,7 @@ def main(placeholder_3: T.Buffer[(960,), "int8"], ethosu_write_1: T.Buffer[(1024 T.preflattened_buffer(placeholder_3, [1, 10, 12, 8], 'int8', data=placeholder_3.data) T.preflattened_buffer(ethosu_write_1, [1, 8, 8, 16], 'int8', data=ethosu_write_1.data) # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, 12, buffer_1[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, placeholder_3[120], 0, 0, 0, T.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, buffer[0], 848, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -629,7 +651,7 @@ def main(placeholder_3: T.Buffer[(315,), "int8"], ethosu_write_1: T.Buffer[(240, T.preflattened_buffer(placeholder_3, [1, 7, 9, 5], 'int8', data=placeholder_3.data) T.preflattened_buffer(ethosu_write_1, [1, 3, 5, 16], 'int8', data=ethosu_write_1.data) # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, 12, buffer[0], 160, 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, placeholder_3[146], 0, 0, 0, T.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 656, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -673,8 +695,8 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.preflattened_buffer(placeholder_3, [4, 6, 8, 1], 'int8', data=placeholder_3.data) T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -689,8 +711,8 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.preflattened_buffer(placeholder_3, [1, 24, 8], 'int8', data=placeholder_3.data) T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -705,8 +727,8 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.preflattened_buffer(placeholder_3, [192, 1], 'int8', data=placeholder_3.data) T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None @@ -721,8 +743,8 @@ def main(placeholder_3: T.Buffer[(192,), "int8"], ethosu_write_1: T.Buffer[(768, T.preflattened_buffer(placeholder_3, [192], 'int8', data=placeholder_3.data) T.preflattened_buffer(ethosu_write_1, [1, 8, 6, 16], 'int8', data=ethosu_write_1.data) # body - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, 12, buffer[0], 160, 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 1, 1, 0, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, placeholder_3[72], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, ethosu_write_1[384], 0, 0, 0, T.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, buffer_1[0], 848, T.int8(-1), T.int8(-1), 12, buffer[0], 160, T.int8(-1), T.int8(-1), 0, 1, 1, 1, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py index 23d3d7fe967b..4f06695b25b1 100644 --- a/tests/python/contrib/test_ethosu/test_replace_copy.py +++ b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -43,7 +43,7 @@ def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(204 placeholder_d_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin": True}) T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 304, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer[0], 80, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, 12, placeholder_d_global[0], 80, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 304, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 80, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -93,10 +93,10 @@ def main(placeholder_5: T.Buffer[(8192,), "int8"], ethosu_write_1: T.Buffer[(409 placeholder_d_global_unrolled_iter_1 = T.buffer_decl([64], dtype="uint8", data=placeholder_d_global_unrolled_iter_0.data) T.evaluate(T.call_extern("ethosu_copy", buffer[0], 416, placeholder_global_unrolled_iter_0[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 112, placeholder_d_global_unrolled_iter_0[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_0[0], 416, 12, placeholder_d_global_unrolled_iter_0[0], 112, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 10, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_0[0], 416, T.int8(-1), T.int8(-1), 12, placeholder_d_global_unrolled_iter_0[0], 112, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 272, placeholder_global_unrolled_iter_1[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 64, placeholder_d_global_unrolled_iter_1[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_1[0], 272, 12, placeholder_d_global_unrolled_iter_1[0], 64, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder_5[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 6, 16, 0, 16, ethosu_write_1[10], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_global_unrolled_iter_1[0], 272, T.int8(-1), T.int8(-1), 12, placeholder_d_global_unrolled_iter_1[0], 64, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py index bc0232fc99c6..8a83e769141d 100644 --- a/tests/python/contrib/test_ethosu/test_scheduler.py +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -198,10 +198,10 @@ def main(input_buffer: T.Buffer[(301056,), "int8"], output_buffer: T.Buffer[(752 T.evaluate(T.call_extern("ethosu_copy", weight_buffer[0], 2608, weight_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", bias_buffer[0], 240, bias_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, input_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, weight_global[0], 2608, 12, bias_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 96, 56, 0, 56, input_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 5376, 96, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, weight_global[0], 2608, T.int8(-1), T.int8(-1), 12, bias_global[0], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", weight_buffer2[0], 736, weight_global2[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", bias_buffer2[0], 240, bias_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, weight_global2[0], 736, 12, bias_global[0], 240, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 1344, 24, 1, 1, 1, 1, 1, 1, 1, weight_global2[0], 736, T.int8(-1), T.int8(-1), 12, bias_global[0], 240, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_binary_elementwise", "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, featuremap_buffer2[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "int8", 56, 56, 24, 56, 0, 56, output_buffer[0], 0, 0, 0, T.float32(1), 0, "NHWC", 1344, 24, 1, "ADD", 0, "NONE", 0, 0, "TFL", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py index 3e12a662167d..28522138cafc 100644 --- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -39,7 +39,7 @@ def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_conv2d_1: T.Buffer[(10 placeholder_4 = T.buffer_decl([1], "uint8") placeholder_5 = T.buffer_decl([1], "uint8") # body - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_4[0], 0, 12, placeholder_5[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, placeholder_4[0], 0, T.int8(-1), T.int8(-1), 12, placeholder_5[0], 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="uint8")) # fmt: on @@ -58,10 +58,10 @@ def main(placeholder_6: T.Buffer[(192,), "int8"], ethosu_conv2d_1: T.Buffer[(512 # body ethosu_conv2d_2 = T.allocate([1024], "uint8", "global") ethosu_conv2d_3 = T.allocate([2048], "uint8", "global") - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, placeholder_6[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, placeholder_7[0], 0, 12, placeholder_8[0], 0, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_9[0], 0, 12, placeholder_5[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, placeholder_6[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, placeholder_7[0], 0, 12, placeholder_8[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="uint8")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, ethosu_conv2d_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_9[0], 0, 12, placeholder_5[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, placeholder_6[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, placeholder_7[0], 0, T.int8(-1), T.int8(-1), 12, placeholder_8[0], 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_9[0], 0, T.int8(-1), T.int8(-1), 12, placeholder_5[0], 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, placeholder_6[96], 0, 0, 0, T.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, placeholder_7[0], 0, T.int8(-1), T.int8(-1), 12, placeholder_8[0], 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="uint8")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, ethosu_conv2d_2[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, ethosu_conv2d_1[256], 0, 0, 0, T.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_9[0], 0, T.int8(-1), T.int8(-1), 12, placeholder_5[0], 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="uint8")) # fmt: on @@ -80,7 +80,7 @@ def main(placeholder_3: T.Buffer[(8192,), "int8"], ethosu_conv2d_1: T.Buffer[(20 placeholder_d_global = T.allocate([8], "int32", "global") T.evaluate(T.call_extern("ethosu_copy", placeholder_4[0], 256, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", placeholder_5[0], 8, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 0, 12, placeholder_d_global[0], 0, 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, placeholder_3[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, ethosu_conv2d_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 0, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 0, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "CLIP", 0, 255, "TFL", "NONE", 0, 0, 0, dtype="handle")) # fmt: on @@ -114,16 +114,16 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) T.evaluate(T.call_extern("ethosu_copy", buffer[0], 128, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_1[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 128, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 128, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 112, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 112, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 112, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 112, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -161,19 +161,19 @@ def main(placeholder: T.Buffer[(8192,), "int8"], ethosu_write: T.Buffer[(2048,), ethosu_write_1 = T.allocate([4096], "int8", "global", annotations={"disable_lower_builtin":True}) placeholder_global = T.allocate([80], "uint8", "global", annotations={"disable_lower_builtin":True}) placeholder_d_global = T.allocate([32], "uint8", "global", annotations={"disable_lower_builtin":True}) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, 12, buffer_1[0], 160, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, placeholder[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, buffer[0], 592, T.int8(-1), T.int8(-1), 12, buffer_1[0], 160, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_2[0], 80, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_3[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[0], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_4[0], 80, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_5[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[2], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_6[0], 80, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_7[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[4], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_8[0], 80, placeholder_global[0], dtype="handle")) T.evaluate(T.call_extern("ethosu_copy", buffer_9[0], 32, placeholder_d_global[0], dtype="handle")) - T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, 12, placeholder_d_global[0], 32, 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) + T.evaluate(T.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, ethosu_write_1[0], 0, 0, 0, T.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, ethosu_write[6], 0, 0, 0, T.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, placeholder_global[0], 80, T.int8(-1), T.int8(-1), 12, placeholder_d_global[0], 32, T.int8(-1), T.int8(-1), 0, 0, 0, 0, "NONE", 0, 0, "TFL", "NONE", 0, 0, 0, dtype="handle")) __tvm_meta__ = None # fmt: on @@ -255,7 +255,9 @@ def test_buffer_info_extraction(): buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, test_case["param_dict"]) for buffer_var, info in buffer_info.items(): if buffer_var in test_case["param_dict"].keys(): - assert (info.values == test_case["param_dict"][buffer_var]).all() + assert ( + info.values.flatten() == test_case["param_dict"][buffer_var].flatten() + ).all() assert info.dtype == test_case["param_dict"][buffer_var].dtype info.btype == tir_to_cs_translator.BufferType.constant else: