From fc27e339b3ba43971b9215a5b8c3876088183e92 Mon Sep 17 00:00:00 2001 From: Manupa Karunaratne Date: Mon, 27 Sep 2021 20:36:21 +0100 Subject: [PATCH] Arm(R) Ethos(TM)-U NPU TIR to CS for Conv2D (#8811) This commit introduces the TIR to Command Stream(CS) translation using Vela API calls for conv2D and copy operations. It will create Vela npu_op objects for each command. Change-Id: I906d2cb333652813142cc70fb39b8372ec498bd0 --- .../contrib/ethosu/tir_to_cs_translator.py | 244 ++++++ .../test_ethosu/test_tir_to_cs_translator.py | 770 ++++++++++++++++++ 2 files changed, 1014 insertions(+) create mode 100644 tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py index ce9abcbd683d..4b28dc5b191e 100644 --- a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -25,6 +25,8 @@ import ethosu.vela.api as vapi # type: ignore import tvm +from tvm.tir import stmt_functor +from tvm.relay.backend.contrib.ethosu import util from tvm.relay.backend.contrib.ethosu import vela_api from tvm.relay.backend.contrib.ethosu.tir import spec @@ -39,6 +41,14 @@ class BufferType(Enum): output = auto() +_REGION_MAP = { + BufferType.constant: 0, + BufferType.scratch: 1, + BufferType.input: 3, + BufferType.output: 4, +} + + class BufferInfo(NamedTuple): """A data structure to hold metadata of the buffer""" @@ -49,6 +59,72 @@ class BufferInfo(NamedTuple): btype: BufferType +def translate(tir_module, params): + """This will take an tir module for the NPU + and compile to command stream + + Parameters + ---------- + tir_module : tvm.IRModule + The TIR module containing ethosu extern calls + params : dict + A dictionary containing TIR primfunc argument ordering + idx to constant NDArray map + accel_type : ethosu.vela.api.NpuAccelerator + the accelerator variant the tir module needs to compiled to + + Returns + ------- + cs : str + An hex string of the bytes of command stream + encoded_constants : str + An hex string of the bytes that includes concat'd + encoded weights, encoded biases and scales. + scratch_size : int + The size of the scratch buffer needed. + """ + + buffer_info = extract_buffer_info(tir_module, params) + extern_calls = extract_extern_calls(tir_module) + _npu_ops = list() + for extern_call in extern_calls: + _npu_ops.append(translate_ethosu_tir_extern_call(extern_call)) + _npu_ops, constant_tensor, scratch_size = assign_addresses(buffer_info, _npu_ops) + target_accel_type = vela_api.get_target_accel_type() + cmds = vapi.npu_generate_register_command_stream(_npu_ops, target_accel_type) + payload = vapi.npu_create_driver_payload(cmds, target_accel_type) + hex_value = "" if constant_tensor is None else constant_tensor.tobytes().hex() + return payload.hex(), hex_value, scratch_size + + +def extract_extern_calls(mod): + """This function will obtain all extern + calls from a TIR module + Parameters + ---------- + mod : tvm.IRModule + The TIR Module for NPU + + Returns + ------- + list + of tvm.tir.Call objects + that are tir extern calls + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + extern_calls = list() + + def populate_extern_calls(stmt): + if isinstance(stmt, tvm.tir.Call) and stmt.op.name == "tir.call_extern": + extern_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_extern_calls) + return extern_calls + + def extract_buffer_info(mod, param_dict): """ This function is to read the tvm.IRModule that @@ -101,6 +177,156 @@ def populate_allocate_buffer_info(stmt): return buffer_info +def assign_addresses(buffer_info, npu_ops): + """This function will assign addresses to tensors + within two buffers : scratch and constants. + The scratch is the buffer created to hold all intermediary data + The constants is the buffer created via unifying all the constant data + (post-encoding). + Parameters + ---------- + buffer_info : dict + This is the dictionary obtained via calling extract_buffer_info. + The key is the buffer name to BufferInfo + npu_ops : list + A list of Vela NpuOps with tir.Loads for addresses + Returns + ------- + npu_ops : list + A list of Vela NpuOps with addesses within scratch and constant buffers + constant_tensor : NDArray + A unified constant data array of uint8 as the constant buffer + scratch_size : int + The size of the scratch tensor. + """ + + def replace_npu_fm_with_address(npu_fm): + assert isinstance(npu_fm.tiles.addresses[0], tvm.tir.Load) + # We currently does not support tiles + # Change this when tiles are needed + # (i.e. when using rolling buffers) + assert npu_fm.tiles.addresses[1:] == [0, 0, 0] + npu_fm.tiles.addresses[1:] = [0, 0, 0] + buffer = npu_fm.tiles.addresses[0].buffer_var + assert buffer in buffer_addresses.keys() + address, buffer_type = buffer_addresses[buffer] + npu_fm.tiles.addresses[0] = address + npu_fm.region = _REGION_MAP[buffer_type] + return npu_fm + + def replace_npu_address_range_with_address(npu_addr_range): + assert isinstance(npu_addr_range.address, tvm.tir.Load) + buffer = npu_addr_range.address.buffer_var + assert buffer in buffer_addresses.keys() + address, buffer_type = buffer_addresses[buffer] + return vapi.NpuAddressRange(_REGION_MAP[buffer_type], address, npu_addr_range.length) + + def replace_tir_loads(npu_object): + if isinstance(npu_object, vapi.NpuFeatureMap): + return replace_npu_fm_with_address(npu_object) + if isinstance(npu_object, vapi.NpuAddressRange): + return replace_npu_address_range_with_address(npu_object) + return npu_object + + def classify_io(buffer): + for _npu_op in npu_ops: + if issubclass(type(_npu_op), vapi.NpuBlockOperation): + if _npu_op.ifm and _npu_op.ifm.tiles.addresses[0].buffer_var == buffer: + return BufferType.input + if _npu_op.ifm2 and _npu_op.ifm2.tiles.addresses[0].buffer_var == buffer: + return BufferType.input + if _npu_op.ofm and _npu_op.ofm.tiles.addresses[0].buffer_var == buffer: + return BufferType.output + + raise ValueError(f"Unused IO : {buffer} in tir module.") + + scratch_size = 0 + constant_tensor = None + buffer_addresses = dict() + for _buffer, info in buffer_info.items(): + if info.values is not None: + assert np.dtype(info.dtype) == np.uint8 + assert info.btype == BufferType.constant + assert len(info.shape) == 1 + if constant_tensor is None: + buffer_addresses[_buffer] = (0, info.btype) + assert info.values.dtype == np.uint8 + size_in_bytes = info.values.size + # Every memory address the NPU access have to be 16 byte aligned + size_in_bytes = util.round_up(size_in_bytes, 16) + constant_tensor = np.resize(info.values, size_in_bytes) + else: + buffer_addresses[_buffer] = (constant_tensor.size, info.btype) + assert info.values.dtype == np.uint8 + size_in_bytes = info.values.size + # Every memory address the NPU access have to be 16 byte aligned + size_in_bytes = util.round_up(size_in_bytes, 16) + constant_tensor = np.append(constant_tensor, np.resize(info.values, size_in_bytes)) + else: + size_in_bytes = int( + (np.iinfo(np.dtype(info.dtype)).bits // 8) * np.prod(list(info.shape)) + ) + # Every memory address the NPU access have to be 16 byte aligned + size_in_bytes = util.round_up(size_in_bytes, 16) + if info.btype == BufferType.input_or_output: + buffer_type = classify_io(_buffer) + assert buffer_type in (BufferType.input, BufferType.output) + address = 0 + buffer_addresses[_buffer] = (address, buffer_type) + else: + assert info.btype == BufferType.scratch + address = scratch_size + scratch_size += size_in_bytes + buffer_addresses[_buffer] = (address, info.btype) + + for npu_op in npu_ops: + for attr_name, attr in npu_op.__dict__.items(): + if isinstance(attr, list): + new_attr = list() + for attr_ in attr: + new_attr.append(replace_tir_loads(attr_)) + setattr(npu_op, attr_name, new_attr) + else: + setattr(npu_op, attr_name, replace_tir_loads(attr)) + + return npu_ops, constant_tensor, scratch_size + + +def translate_ethosu_tir_extern_call(tir_extern_call): + """This is a dispatcher function to dispatch + correct translation call depending on the extern call's + first argument""" + supported_extern_calls = { + "ethosu_conv2d": translate_ethosu_conv2d, + "ethosu_copy": translate_ethosu_copy, + } + ext_call_type = tir_extern_call.args[0].value + assert ext_call_type in supported_extern_calls.keys(), f"{ext_call_type} is not yet supported" + npu_op = supported_extern_calls[ext_call_type](tir_extern_call) + # Some conversions return additional outputs + # if they are needed, the caller should use the function directly + if isinstance(npu_op, tuple): + return npu_op[0] + return npu_op + + +def translate_ethosu_copy(tir_extern_call): + """This function will translate a tir ethosu_copy extern_call + as produced by Relay to TIR compilation. + Parameters + ---------- + tir_extern_call : tvm.tir.Call + + Returns + ------- + ethosu.vela.api.NpuDmaOperation + The vela object containing the params of ethosu_copy + """ + # We skip the first element as it is the extern_call function name + serial_object = spec.create_serial_object(spec.SerialCopy, tir_extern_call.args[1:]) + return _create_npu_dma_op(serial_object) + + def _convert_clip_bounds(npu_op): """ This function will convert the min and max value @@ -330,3 +556,21 @@ def _create_npu_resampling_mode( mode = str(mode.value) assert mode in mode_map.keys() return mode_map[mode] + + +def _create_npu_dma_op(serial_copy): + """This is a helper function to capture the list of arguments + to create a NpuDmaOperation object""" + src = vapi.NpuAddressRange( + # region will be updated later + region=0, + address=serial_copy.read_address, + length=int(serial_copy.length.value), + ) + dest = vapi.NpuAddressRange( + # region will be updated later + region=0, + address=serial_copy.write_address, + length=int(serial_copy.length.value), + ) + return vapi.NpuDmaOperation(src, dest) diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py new file mode 100644 index 000000000000..479a1032453a --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py @@ -0,0 +1,770 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +import pytest + +pytest.importorskip("ethosu.vela") +import numpy as np + +import tvm +from tvm import tir +from tvm.tir import stmt_functor +from tvm.script import ty +from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator +from tvm.relay.backend.contrib.ethosu import util +import ethosu.vela.api as vapi + + +# fmt: off +"""A sample tir test case for translator""" +@tvm.script.tir +class SingleEthosUConv2D: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_conv2d: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_4 = tir.match_buffer(placeholder_1, [1, 1, 3, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder_2, [16], dtype="int32", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 8, 8, 16], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 8, 8, 3, 8, 0, 8, tir.load("uint8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 8, 8, 16, 8, 0, 8, tir.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 16, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_4.data, 0), 0, 12, tir.load("uint8", placeholder_5.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) +# fmt: on + + +# fmt: off +"""A sample tir test case with multiple convolutions for translator""" +@tvm.script.tir +class MultiEthosUConv2D: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_conv2d: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_9 = tir.match_buffer(placeholder_3, [1, 1, 32, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 8, 8, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_7 = tir.match_buffer(placeholder_1, [1, 1, 3, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_6 = tir.match_buffer(placeholder, [1, 8, 8, 3], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_8 = tir.match_buffer(placeholder_2, [32], dtype="int32", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder_4, [8], dtype="int32", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_conv2d_2 = tir.allocate([1024], "uint8", "global") + ethosu_conv2d_3 = tir.allocate([2048], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, tir.load("uint8", placeholder_6.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_7.data, 0), 0, 12, tir.load("uint8", placeholder_8.data, 0), 0, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="uint8")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, tir.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_9.data, 0), 0, 12, tir.load("uint8", placeholder_5.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 4, 8, 3, 4, 0, 8, tir.load("uint8", placeholder_6.data, 96), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "uint8", 4, 8, 32, 4, 0, 8, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_7.data, 0), 0, 12, tir.load("uint8", placeholder_8.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 4, 8, 32, 4, 0, 8, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "uint8", 4, 8, 8, 4, 0, 8, tir.load("uint8", ethosu_conv2d_1.data, 256), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_9.data, 0), 0, 12, tir.load("uint8", placeholder_5.data, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="uint8")) +# fmt: on + + +# fmt: off +"""A sample tir test case with copy operations for translator""" +@tvm.script.tir +class MultiEthosUCopy: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_conv2d: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + placeholder_3 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 16, 16, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder_2, [8], dtype="int32", elem_offset=0, align=128, offset_factor=1) + placeholder_4 = tir.match_buffer(placeholder_1, [8, 1, 1, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + placeholder_global = tir.allocate([256], "uint8", "global") + placeholder_d_global = tir.allocate([8], "int32", "global") + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", placeholder_4.data, 0), 256, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("int32", placeholder_5.data, 0), 8, tir.load("int32", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, tir.load("uint8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 8, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 0, 12, tir.load("uint8", placeholder_d_global, 0), 0, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) +# fmt: on + + +# fmt: off +"""A TIR test module of weight streaming""" +@tvm.script.tir +class WeightStreamOnly: + def main(placeholder: ty.handle, ethosu_conv2d: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer_4 = tir.match_buffer(placeholder_5, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_4, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_7 = tir.match_buffer(placeholder_7, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_5 = tir.match_buffer(placeholder_1, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_6 = tir.match_buffer(placeholder_6, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 16, 16, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_3, [144], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_2, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_9 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_8, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + placeholder_global = tir.allocate([144], "uint8", "global") + placeholder_d_global = tir.allocate([20], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_5.data, 0), 144, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_2.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, tir.load("uint8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 144, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 144, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_3.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, tir.load("uint8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 2), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 144, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_4.data, 0), 144, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_6.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, tir.load("uint8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 4), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 144, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_7.data, 0), 144, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, tir.load("uint8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 6), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 144, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +# fmt: off +"""A TIR test module of weight streaming and direct reading""" +@tvm.script.tir +class MixedRead: + def main(placeholder: ty.handle, placeholder_1: ty.handle, ethosu_conv2d: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle, placeholder_9: ty.handle, placeholder_10: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "main", "tir.noalias": True}) + buffer_5 = tir.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_7 = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_7, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_6 = tir.match_buffer(placeholder_4, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_9 = tir.match_buffer(placeholder_5, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_conv2d_1 = tir.match_buffer(ethosu_conv2d, [1, 16, 16, 8], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_8, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_10, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_11 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_6, [20], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_4 = tir.match_buffer(placeholder_3, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_8 = tir.match_buffer(placeholder_9, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_conv2d_2 = tir.allocate([4096], "uint8", "global") + placeholder_global = tir.allocate([80], "uint8", "global") + placeholder_d_global = tir.allocate([20], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 32, 16, 0, 16, tir.load("uint8", placeholder_11.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "uint8", 16, 16, 16, 16, 0, 16, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_5.data, 0), 592, 12, tir.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_4.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_6.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_9.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_2.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 2), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_3.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 4), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_8.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 20, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "uint8", 16, 16, 16, 16, 0, 16, tir.load("uint8", ethosu_conv2d_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "uint8", 16, 16, 2, 16, 0, 16, tir.load("uint8", ethosu_conv2d_1.data, 6), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 20, 0, 0, 0, 0, "CLIP", 0, 255, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_buffer_info_extraction(): + test_cases = [ + { + # Stimulus + "tir_module": SingleEthosUConv2D(), + "param_dict": { + 1: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 3, 16], "uint8" + ), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [16], "int32"), + }, + # Reference Outputs + "constants": { + "placeholder_4": 1, + "placeholder_5": 2, + }, + "data_buffers": { + "placeholder_3": ( + [1, 8, 8, 3], + "uint8", + tir_to_cs_translator.BufferType.input_or_output, + ), + "ethosu_conv2d_1": ( + [1, 8, 8, 16], + "uint8", + tir_to_cs_translator.BufferType.input_or_output, + ), + }, + }, + { + "tir_module": MultiEthosUConv2D(), + "param_dict": { + 1: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 3, 32], "uint8" + ), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [32], "int32"), + 3: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 32, 8], "uint8" + ), + 4: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [8], "int32"), + }, + # Reference Outputs + "constants": { + "placeholder_5": 4, + "placeholder_7": 1, + "placeholder_8": 2, + "placeholder_9": 3, + }, + "data_buffers": { + "placeholder_6": ( + [1, 8, 8, 3], + "uint8", + tir_to_cs_translator.BufferType.input_or_output, + ), + "ethosu_conv2d_1": ( + [1, 8, 8, 8], + "uint8", + tir_to_cs_translator.BufferType.input_or_output, + ), + "ethosu_conv2d_2": ([1024], "uint8", tir_to_cs_translator.BufferType.scratch), + "ethosu_conv2d_3": ([2048], "uint8", tir_to_cs_translator.BufferType.scratch), + }, + }, + ] + for test_case in test_cases: + buffer_info = tir_to_cs_translator.extract_buffer_info( + test_case["tir_module"], test_case["param_dict"] + ) + for buffer_var, info in buffer_info.items(): + buffer_name = buffer_var.name + if buffer_name in test_case["constants"].keys(): + assert ( + info.values == test_case["param_dict"][test_case["constants"][buffer_name]] + ).all() + assert ( + info.dtype == test_case["param_dict"][test_case["constants"][buffer_name]].dtype + ) + info.btype == tir_to_cs_translator.BufferType.constant + else: + assert list(info.shape) == test_case["data_buffers"][buffer_name][0] + assert info.dtype == test_case["data_buffers"][buffer_name][1] + assert info.btype == test_case["data_buffers"][buffer_name][2] + + +def test_translate_ethosu_conv2d(): + test_cases = [ + { + # Stimulus + "tir_module": SingleEthosUConv2D(), + "param_dict": { + 1: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 3, 16], "uint8" + ), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [16], "int32"), + }, + # Reference outputs + "ref": [ + { + "ifm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(8, 8, 3), + "tiles": vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.5, 10), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(24, 3, 1), + }, + "ofm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(8, 8, 16), + "tiles": vapi.NpuTileBox(8, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.25, 14), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(128, 16, 1), + }, + "kernel": vapi.NpuKernel( + w=1, h=1, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1 + ), + "padding": vapi.NpuPadding(top=0, left=0, bottom=0, right=0), + "activation": { + "op": vapi.NpuActivationOp.NONE_OR_RELU, + "min": -3.5, + "max": 60.25, + }, + "ifm_upscale": vapi.NpuResamplingMode.NONE, + "w_zero_point": 12, + } + ], + }, + { + "tir_module": MultiEthosUConv2D(), + "param_dict": { + 1: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 3, 32], "uint8" + ), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [32], "int32"), + 3: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [1, 1, 32, 8], "uint8" + ), + 4: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [8], "int32"), + }, + # Reference Outputs + "ref": [ + { + "ifm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 3), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.5, 10), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(24, 3, 1), + }, + "ofm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 32), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.25, 14), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(256, 32, 1), + }, + "kernel": vapi.NpuKernel( + w=1, h=1, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1 + ), + "padding": vapi.NpuPadding(top=0, left=0, bottom=0, right=0), + "activation": {"op": None}, + "ifm_upscale": vapi.NpuResamplingMode.NONE, + "w_zero_point": 12, + }, + { + "ifm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 32), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.5, 10), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(256, 32, 1), + }, + "ofm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 8), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.25, 14), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(64, 8, 1), + }, + "kernel": vapi.NpuKernel( + w=1, h=1, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1 + ), + "padding": vapi.NpuPadding(top=0, left=0, bottom=0, right=0), + "activation": { + "op": vapi.NpuActivationOp.NONE_OR_RELU, + "min": -3.5, + "max": 60.25, + }, + "ifm_upscale": vapi.NpuResamplingMode.NONE, + "w_zero_point": 12, + }, + { + "ifm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 3), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.5, 10), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(24, 3, 1), + }, + "ofm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 32), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.25, 14), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(256, 32, 1), + }, + "kernel": vapi.NpuKernel( + w=1, h=1, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1 + ), + "padding": vapi.NpuPadding(top=0, left=0, bottom=0, right=0), + "activation": { + "op": vapi.NpuActivationOp.NONE_OR_RELU, + "min": -3.5, + "max": 60.25, + }, + "ifm_upscale": vapi.NpuResamplingMode.NONE, + "w_zero_point": 12, + }, + { + "ifm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 32), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.5, 10), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(256, 32, 1), + }, + "ofm": { + "data_type": vapi.NpuDataType.UINT8, + "shape": vapi.NpuShape3D(4, 8, 8), + "tiles": vapi.NpuTileBox(4, 0, 8, [0, 0, 0, 0]), + "quantization": vapi.NpuQuantization(0.25, 14), + "layout": vapi.NpuLayout.NHWC, + "strides": vapi.NpuShape3D(64, 8, 1), + }, + "kernel": vapi.NpuKernel( + w=1, h=1, stride_x=1, stride_y=1, dilation_x=1, dilation_y=1 + ), + "padding": vapi.NpuPadding(top=0, left=0, bottom=0, right=0), + "activation": { + "op": vapi.NpuActivationOp.NONE_OR_RELU, + "min": -3.5, + "max": 60.25, + }, + "ifm_upscale": vapi.NpuResamplingMode.NONE, + "w_zero_point": 12, + }, + ], + }, + ] + + def extract_ethosu_conv2d_extern_calls(mod): + """This function will obtain all ethosu_conv2d + calls from a NPU TIR module + Parameters + ---------- + mod : tvm.IRModule + This is a NPU TIR Module + + Returns + ------- + list + of tvm.tir.Call objects + that are tir extern calls + for ethosu_conv2d + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_conv2d_calls = list() + + def populate_ethosu_conv2d_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_conv2d" + ): + ethosu_conv2d_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_conv2d_calls) + return ethosu_conv2d_calls + + for test_case in test_cases: + ethosu_conv2d_calls = extract_ethosu_conv2d_extern_calls(test_case["tir_module"]) + for idx, ethosu_conv2d_call in enumerate(ethosu_conv2d_calls): + ref = test_case["ref"][idx] + npu_op, w_zero_point = tir_to_cs_translator.translate_ethosu_conv2d(ethosu_conv2d_call) + # Compare IFM + assert npu_op.ifm.data_type == ref["ifm"]["data_type"] + assert npu_op.ifm.shape == ref["ifm"]["shape"] + assert npu_op.ifm.tiles.height_0 == ref["ifm"]["tiles"].height_0 + assert npu_op.ifm.tiles.height_1 == ref["ifm"]["tiles"].height_1 + assert npu_op.ifm.tiles.width_0 == ref["ifm"]["tiles"].width_0 + assert npu_op.ifm.quantization == ref["ifm"]["quantization"] + assert npu_op.ifm.layout == ref["ifm"]["layout"] + assert npu_op.ifm.strides == ref["ifm"]["strides"] + # Compare OFM + assert npu_op.ofm.data_type == ref["ofm"]["data_type"] + assert npu_op.ofm.shape == ref["ofm"]["shape"] + assert npu_op.ofm.tiles.height_0 == ref["ofm"]["tiles"].height_0 + assert npu_op.ofm.tiles.height_1 == ref["ofm"]["tiles"].height_1 + assert npu_op.ofm.tiles.width_0 == ref["ofm"]["tiles"].width_0 + assert npu_op.ofm.quantization == ref["ofm"]["quantization"] + assert npu_op.ofm.layout == ref["ofm"]["layout"] + assert npu_op.ofm.strides == ref["ofm"]["strides"] + # Compare kernel and padding + assert npu_op.kernel.__dict__ == ref["kernel"].__dict__ + assert npu_op.padding == ref["padding"] + # Compare activation + if ref["activation"]["op"] is None: + assert npu_op.activation is None + else: + assert npu_op.activation.op_type == ref["activation"]["op"] + assert npu_op.activation.min == ref["activation"]["min"] + assert npu_op.activation.max == ref["activation"]["max"] + # Compare ifm upscaling + assert npu_op.ifm_upscale == ref["ifm_upscale"] + # Compare weight quantization parameters + assert w_zero_point == ref["w_zero_point"] + + +def test_translate_ethosu_copy(): + def extract_ethosu_copy_extern_calls(mod): + """This function will obtain all ethosu_conv2d + calls from a NPU TIR module + Parameters + ---------- + mod : tvm.IRModule + This is a NPU TIR Module + + Returns + ------- + list + of tvm.tir.Call objects + that are tir extern calls + for ethosu_conv2d + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_copy_calls = list() + + def populate_ethosu_copy_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_copy" + ): + ethosu_copy_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_copy_calls) + return ethosu_copy_calls + + test_cases = [ + { + # Stimulus + "tir_module": MultiEthosUCopy(), + "param_dict": { + 1: np.random.randint( + np.iinfo("uint8").min, np.iinfo("uint8").max, [8, 1, 1, 32], "uint8" + ), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [8], "int32"), + }, + # Reference outputs + "ref": [ + { + "src": "placeholder_4", + "dest": "placeholder_global", + "length": 256, + }, + { + "src": "placeholder_5", + "dest": "placeholder_d_global", + "length": 8, + }, + ], + }, + ] + + for test_case in test_cases: + ethosu_copy_calls = extract_ethosu_copy_extern_calls(test_case["tir_module"]) + for idx, ethosu_copy_call in enumerate(ethosu_copy_calls): + npu_dma_op = tir_to_cs_translator.translate_ethosu_tir_extern_call(ethosu_copy_call) + assert npu_dma_op.src.address.buffer_var.name == test_case["ref"][idx]["src"] + assert npu_dma_op.dest.address.buffer_var.name == test_case["ref"][idx]["dest"] + assert npu_dma_op.src.length == test_case["ref"][idx]["length"] + assert npu_dma_op.dest.length == test_case["ref"][idx]["length"] + + +def test_assign_addresses(): + test_cases = [ + { + # Stimulus + "tir_module": WeightStreamOnly(), + "param_dict": { + 2: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [144], "uint8"), + 3: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 4: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [144], "uint8"), + 5: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 6: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [144], "uint8"), + 7: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 8: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [144], "uint8"), + 9: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + }, + }, + { + # Stimulus + "tir_module": MixedRead(), + "param_dict": { + 1: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [592], "uint8"), + 3: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [160], "uint8"), + 4: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), + 5: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 6: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), + 7: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 8: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), + 9: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + 10: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [80], "uint8"), + 11: np.random.randint(np.iinfo("uint8").min, np.iinfo("uint8").max, [20], "uint8"), + }, + }, + ] + + def extract_extern_calls(mod): + """This function will obtain all ethosu_conv2d + calls from a NPU TIR module + Parameters + ---------- + mod : tvm.IRModule + This is a NPU TIR Module + + Returns + ------- + list + of tvm.tir.Call objects + that are tir extern calls + for ethosu_conv2d + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + extern_calls = list() + + def populate_extern_calls(stmt): + if isinstance(stmt, tvm.tir.Call) and stmt.op.name == "tir.call_extern": + extern_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_extern_calls) + return extern_calls + + def collect_tir_buffer_info(npu_ops): + """This is run prior to address assigning to collect tir buffer information + for verification later on""" + _npu_op_tir_buffers = dict() + for npu_op in npu_ops: + if isinstance(npu_op, vapi.NpuDmaOperation): + _npu_op_tir_buffers[npu_op] = (npu_op.src.address, npu_op.dest.address) + elif issubclass(type(npu_op), vapi.NpuBlockOperation): + _npu_op_tir_buffers[npu_op] = ( + npu_op.ifm.tiles.addresses[0], + npu_op.ofm.tiles.addresses[0], + npu_op.weights, + npu_op.biases, + ) + return _npu_op_tir_buffers + + def _check_buffer(address, region, length, buffer_var): + """Checks whether the buffer information is valid with + original tir buffers. + - If its constant, this will check + the slice in the constant tensor has the values. + - If its scratch, this will check + the slice is within scratch and does not have conflicts + with other scratch tensors. + - If its input/output, this will check the + address is zero + """ + inverse_region_map = { + 0: tir_to_cs_translator.BufferType.constant, + 1: tir_to_cs_translator.BufferType.scratch, + 3: tir_to_cs_translator.BufferType.input, + 4: tir_to_cs_translator.BufferType.output, + } + buffer_type = inverse_region_map[region] + if buffer_type == tir_to_cs_translator.BufferType.constant: + ref = buffer_info[buffer_var].values + assert (constant_tensor[address : address + length] == ref).all() + # Every buffer is adjusted to align to 16 bytes + length = util.round_up(length, 16) + # Mark these constants are read at least once + constant_tensor_read_mask[address : address + length] = np.ones(length, dtype="uint8") + elif buffer_type == tir_to_cs_translator.BufferType.scratch: + shape = list(buffer_info[buffer_var].shape) + assert length == np.prod(shape) + assert address < scratch_size + # Every buffer is adjusted to align to 16 bytes + length = util.round_up(length, 16) + assert address + length <= scratch_size + # The scratch area should not be used by anyother buffer + assert not scratch_allocation_mask[address : address + length].any() + # The scratch area is marked as used + scratch_allocation_mask[address : address + length] = np.ones(length, dtype="uint8") + elif buffer_type == tir_to_cs_translator.BufferType.input: + assert address == 0 + else: + assert buffer_type == tir_to_cs_translator.BufferType.output + assert address == 0 + + def verify(npu_ops): + """This wrapper verifies the allocated addresses matches with original tir buffers""" + checked_buffers = set() + + def check_buffer(address, region, length, buffer_var): + if buffer_var not in checked_buffers: + _check_buffer(address, region, length, buffer_var) + checked_buffers.add(buffer_var) + + for npu_op in npu_ops: + if isinstance(npu_op, vapi.NpuDmaOperation): + src_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer_var + check_buffer( + npu_op.src.address, npu_op.src.region, npu_op.src.length, src_tir_buffer_var + ) + dest_tir_load = npu_op_tir_buffers[npu_op][1].buffer_var + check_buffer( + npu_op.dest.address, + npu_op.dest.region, + npu_op.dest.length, + dest_tir_load, + ) + elif issubclass(type(npu_op), vapi.NpuBlockOperation): + ifm_tir_buffer_var = npu_op_tir_buffers[npu_op][0].buffer_var + ifm_length = ( + npu_op.ifm.shape.height * npu_op.ifm.shape.width * npu_op.ifm.shape.depth + ) + check_buffer( + npu_op.ifm.tiles.addresses[0], + npu_op.ifm.region, + ifm_length, + ifm_tir_buffer_var, + ) + ofm_tir_buffer_var = npu_op_tir_buffers[npu_op][1].buffer_var + ofm_length = ( + npu_op.ofm.shape.height * npu_op.ofm.shape.width * npu_op.ofm.shape.depth + ) + check_buffer( + npu_op.ofm.tiles.addresses[0], + npu_op.ofm.region, + ofm_length, + ofm_tir_buffer_var, + ) + for idx, weight in enumerate(npu_op_tir_buffers[npu_op][2]): + assert isinstance(weight, vapi.NpuAddressRange) + check_buffer( + npu_op.weights[idx].address, + npu_op.weights[idx].region, + npu_op.weights[idx].length, + weight.address.buffer_var, + ) + for idx, bias in enumerate(npu_op_tir_buffers[npu_op][3]): + assert isinstance(bias, vapi.NpuAddressRange) + check_buffer( + npu_op.biases[idx].address, + npu_op.biases[idx].region, + npu_op.biases[idx].length, + bias.address.buffer_var, + ) + + for test_case in test_cases: + buffer_info = tir_to_cs_translator.extract_buffer_info( + test_case["tir_module"], test_case["param_dict"] + ) + extern_calls = extract_extern_calls(test_case["tir_module"]) + _npu_ops = list() + for extern_call in extern_calls: + _npu_ops.append(tir_to_cs_translator.translate_ethosu_tir_extern_call(extern_call)) + npu_op_tir_buffers = collect_tir_buffer_info(_npu_ops) + _npu_ops, constant_tensor, scratch_size = tir_to_cs_translator.assign_addresses( + buffer_info, _npu_ops + ) + scratch_allocation_mask = np.zeros(scratch_size, dtype="uint8") + constant_tensor_read_mask = np.zeros(constant_tensor.size, dtype="uint8") + verify(_npu_ops) + # This will be only 1 if all allocated scratch is used. + assert np.prod(scratch_allocation_mask) == 1 + # This will be only 1 if all constant tensors is read at least once. + assert np.prod(constant_tensor_read_mask) == 1 + + +if __name__ == "__main__": + test_buffer_info_extraction() + test_translate_ethosu_conv2d() + test_translate_ethosu_copy() + test_assign_addresses()