From 4c4fcee4c2fcf9fafc689056eee57cf4f12ba270 Mon Sep 17 00:00:00 2001 From: "Michael J. Klaiber" Date: Tue, 9 Aug 2022 18:40:55 +0200 Subject: [PATCH] [UMA] UMA v1.0 (#12087) * Add minimal working structure for generic interface * Separate target definition from codegen * Update file structure to support multiple NPU targets * Add scheduling and pass support to codegen * Update schedule function and pass registration * Add generic partitioner for relay graph partitioning * Add pattern-based relay graph partitioning and AOT codegen * Update API * Add UltraTrail relay passes and schedule function * Update UltraTrail relay passes * Add tir_to_runtime hook for UltraTrail * Add operator strategy registration to lowering * Add option to pass constants as attributes * Refactor naming: Generic to UMA * Change API to single user-facing backend class UMABackend * Add initial codegen API * [UMA] add a generic packed function to register targets * Restructure files and add initial codegen * Minor code cleanup * Add UMA config and MergeCompilerRegion example * Move UMA configuration to init parameters * Add python hooks for C-codegen. Still has known restrictons * Fix relay_to_tir hook to keep virtual device in main function * Remove register schedules, scheduling is moved to passes for now * Remove extract constants since non-scalar constants are now supported by TVM * API documentation and some code fixes and cleanup * Fix typo * Fix UMA lowering * Prototype for UMA-based target attribute registration * Add default option and type deduction to register_target_attr * Change pass phases to enum * [Relay] Plumb external codegen target via Target.current() for all external codegen paths (See https://discuss.tvm.apache.org/t/byoc-supporting-cutlass-byoc-with-collage/12796/6 for context, which in turn is part of Collage (https://github.com/apache/tvm-rfcs/blob/main/rfcs/0062-collage.md). We want both old-style (via relay.ext.$toolchain) and new-style (via "RelayToTIR" Pass attribute on target kind) external codegen to be able to access the current 'external codegen' Target instance via Target.current(). - For old-style, plumb the true Target through TEComplier and push it on the context stack before calling relay.ext.$toolchain. - For new-style, pass the CompilationConfig to the RelayToTIRTargetHook pass, make the jump from "Compiler" attribute value to Target via the new CompilationConfig::FindPrimitiveTargetForKind method, and push on the stack before invoking the custom "RelayToTIR" pass. While working on this discovered RelayToTIRTargetHook was incompatible with the VM's compilation flow since RelayToTIRTargetHook assumes all "Compiler" attributed functions are inlined. Generalize it to support both inline and global function styles. Extend Target::IsExternalCodegen to recognize target kinds with "RelayToTIR" attributes as external. Update target hooks unit test to exercise new support for outline-style, picking up the current target, and compiling via the VM. * Use current target in lowering * Use attr:kRelayToTIR * Remove erronousely commited quick fix * Towards test cases for uma * Add test_uma * Initial UMA structure for version 1 * [UMA]: conv2d unit test * [UMA] update of tutorial * [UMA] update of pass format, still issue with conv2d c code * [UMA] refactoring of test_uma_lowering_with_umalower.py * [UMA] refactoring of test_uma_lowering_with_umalower.py * [UMA] Adding backend, codegen, patterns, strategies and run file for MyAiHw * [UMA] update towards my_ai_hw usecase * [UMA] working testcase for conv2d with uma * [UMA] testcase * [UMA] uma lower.py: replaced outdated function create_prim_func_from_outputs to be compatible withe latest content of "main" * UMA: Move torch import to top to avoid free(): invalid pointer error * Add stub files for targets * Add tests for ultratrail codegen * Adopt my_ai_hw accelerator for new target definition * Add unit test for target attributes * Test string arguments * Extend target test * [UMA] tutorial first versin * [UMA] moved unit tests to contrib * [UMA] renaming interfaces * Fix umalower_tests in ci * make uma a python module * [UMA] Update of UMAv1 API + added testcases + tutorialV1 * [UMA] UMAv1 * [UMA] cmake file updated * AOT test infrastructure adapted * UMA: add __init__.py for uma.api * Finish uma tests * Use upstream version of dmlc-core * [UMA] tir_to_runtime documentation update * [UMA] cleanup * [UMA] fix for test_partition * [UMA] lint fix * [UMA] lint fix * [UMA] lint fix * [UMA] lint fix * [UMA] fix of build scripts for arm and i386 * Fix remaining linter errors * [UMA] CMakeLists.txt added UMA tvm_option * [UMA] added UMA tvm_option * [UMA] guard against multiple registrations * [UMA] fixed comments as pointed out in PR 12087 * [UMA] fixed comments as pointed out in PR 12087 * [UMA] skip uma tests if uma is not available * [UMA] added UMA rst * [UMA] Moved tutorial to RST file in gallery * [UMA] moved uma cli to apps * [UMA] change requests according to PR-12087 * [UMA] update and sync of uma_cli and tutorial * [UMA] update of template passe: remove Pad block of Conv2D * [UMA] lint updates * [UMA] Test updates * [UMA] fixes according to comments from PR 12087 discussion * [UMA] lint updates * [UMA] moved UMA _template file to apps * [UMA] lint * [UMA] Remove exceptions when dispatching over targets * [UMA] vanilla pattern update * [UMA] added mobilenet integration test * [UMA] clang lint * Remove tir to runtime * [UMA] Use sequential for UMA relay passes * Use comparison against BYOC flow in test_partition * [UMA] tutorial update: moved code blocks to RST * [UMA] tutorial update and lint fixes * [UMA] removing UMA from i386 build, as there is a fail in the CI pipeline due to missing CLANG for i386 * [BYOC-DNNL] covered case for sum node without attr * [UMA] pylint * [UMA] pylint * [UMA] aot fix * [UMA] Changes PR review * [UMA] cc lint * [UMA] cc lint * Use better function name for te_lowering and annotate current target at TE functions Co-authored-by: Paul Palomero Bernardo Co-authored-by: Christoph Gerum Co-authored-by: mbs-octoml Co-authored-by: Christoph Gerum --- CMakeLists.txt | 2 + apps/uma/_template/__init__.py | 22 ++ apps/uma/_template/backend.py | 45 +++ apps/uma/_template/codegen.py | 28 ++ apps/uma/_template/conv2dnchw.cc | 96 ++++++ apps/uma/_template/passes.py | 136 ++++++++ apps/uma/_template/patterns.py | 25 ++ apps/uma/_template/run.py | 82 +++++ apps/uma/_template/strategies.py | 33 ++ apps/uma/uma_cli.py | 98 ++++++ cmake/config.cmake | 3 + cmake/modules/LibInfo.cmake | 1 + cmake/modules/contrib/UMA.cmake | 22 ++ docs/conf.py | 1 + gallery/tutorial/uma.py | 292 +++++++++++++++++ .../tvm/relay/backend/contrib/uma/__init__.py | 23 ++ .../relay/backend/contrib/uma/api/__init__.py | 25 ++ .../relay/backend/contrib/uma/api/_ffi_api.py | 20 ++ .../relay/backend/contrib/uma/api/codegen.py | 64 ++++ .../relay/backend/contrib/uma/api/lower.py | 165 ++++++++++ .../backend/contrib/uma/api/partitioner.py | 122 ++++++++ .../relay/backend/contrib/uma/api/utils.py | 73 +++++ .../tvm/relay/backend/contrib/uma/backend.py | 293 ++++++++++++++++++ python/tvm/relay/op/contrib/dnnl.py | 2 + python/tvm/testing/aot.py | 12 +- src/relay/backend/contrib/uma/relay_to_tir.cc | 175 +++++++++++ src/relay/backend/contrib/uma/targets.cc | 80 +++++ .../backend/contrib/uma/tir_to_runtime.cc | 82 +++++ src/support/libinfo.cc | 1 + .../python/contrib/test_uma/test_partition.py | 97 ++++++ tests/python/contrib/test_uma/test_target.py | 85 +++++ .../test_uma_lowering_with_umalower.py | 121 ++++++++ .../contrib/test_uma/test_uma_pipeline.py | 136 ++++++++ .../python/contrib/test_uma/test_uma_utils.py | 87 ++++++ .../test_uma/test_uma_vanilla_accelerator.py | 56 ++++ tests/scripts/task_config_build_arm.sh | 1 + tests/scripts/task_config_build_cortexm.sh | 2 + tests/scripts/task_config_build_cpu.sh | 1 + tests/scripts/task_config_build_i386.sh | 1 + 39 files changed, 2605 insertions(+), 5 deletions(-) create mode 100644 apps/uma/_template/__init__.py create mode 100644 apps/uma/_template/backend.py create mode 100644 apps/uma/_template/codegen.py create mode 100644 apps/uma/_template/conv2dnchw.cc create mode 100644 apps/uma/_template/passes.py create mode 100644 apps/uma/_template/patterns.py create mode 100644 apps/uma/_template/run.py create mode 100644 apps/uma/_template/strategies.py create mode 100644 apps/uma/uma_cli.py create mode 100644 cmake/modules/contrib/UMA.cmake create mode 100644 gallery/tutorial/uma.py create mode 100644 python/tvm/relay/backend/contrib/uma/__init__.py create mode 100644 python/tvm/relay/backend/contrib/uma/api/__init__.py create mode 100644 python/tvm/relay/backend/contrib/uma/api/_ffi_api.py create mode 100644 python/tvm/relay/backend/contrib/uma/api/codegen.py create mode 100644 python/tvm/relay/backend/contrib/uma/api/lower.py create mode 100644 python/tvm/relay/backend/contrib/uma/api/partitioner.py create mode 100644 python/tvm/relay/backend/contrib/uma/api/utils.py create mode 100644 python/tvm/relay/backend/contrib/uma/backend.py create mode 100644 src/relay/backend/contrib/uma/relay_to_tir.cc create mode 100644 src/relay/backend/contrib/uma/targets.cc create mode 100644 src/relay/backend/contrib/uma/tir_to_runtime.cc create mode 100644 tests/python/contrib/test_uma/test_partition.py create mode 100644 tests/python/contrib/test_uma/test_target.py create mode 100644 tests/python/contrib/test_uma/test_uma_lowering_with_umalower.py create mode 100644 tests/python/contrib/test_uma/test_uma_pipeline.py create mode 100644 tests/python/contrib/test_uma/test_uma_utils.py create mode 100644 tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 8dc03ee0f40e..7dd061954156 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,6 +113,7 @@ tvm_option(USE_VITIS_AI "Build with VITIS-AI Codegen support" OFF) tvm_option(SUMMARIZE "Print CMake option summary after configuring" OFF) tvm_option(USE_CLML "Build with CLML Codegen support" OFF) tvm_option(USE_CLML_GRAPH_EXECUTOR "Build with CLML graph runtime" OFF) +tvm_option(USE_UMA "Build with UMA support" OFF) # include directories include_directories(${CMAKE_INCLUDE_PATH}) @@ -497,6 +498,7 @@ include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/VitisAI.cmake) include(cmake/modules/contrib/Verilator.cmake) include(cmake/modules/contrib/CLML.cmake) +include(cmake/modules/contrib/UMA.cmake) include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) include(cmake/modules/RustExt.cmake) diff --git a/apps/uma/_template/__init__.py b/apps/uma/_template/__init__.py new file mode 100644 index 000000000000..2cc0ee880d76 --- /dev/null +++ b/apps/uma/_template/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" + +Template files for UMA tutorial + + +""" diff --git a/apps/uma/_template/backend.py b/apps/uma/_template/backend.py new file mode 100644 index 000000000000..5ee7ecc19ef6 --- /dev/null +++ b/apps/uma/_template/backend.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""UMA backend for the my_ai_hw accelerator""" +from passes import MyAiHwConv2dPass +from tvm.relay.backend.contrib.uma.api.utils import PassPhase +from tvm.relay.backend.contrib.uma.backend import UMABackend +from codegen import gen_includes +from patterns import conv2d_pattern + + +class MyAiHwBackend(UMABackend): + """UMA backend for the MyAiHw accelerator.""" + + def __init__(self): + super().__init__() + + # Target configuration + self._register_target_attr("dimension") + + # Relay Pattern registration + self._register_pattern("conv2d", conv2d_pattern()) + + # Relay to TIR function registration + self._register_tir_pass(PassPhase.TIR_PHASE_0, MyAiHwConv2dPass()) + + # TIR to runtime function registration + self._register_codegen(fmt="c", includes=gen_includes) + + @property + def target_name(self): + return "my_ai_hw" diff --git a/apps/uma/_template/codegen.py b/apps/uma/_template/codegen.py new file mode 100644 index 000000000000..5e1d6b45e81f --- /dev/null +++ b/apps/uma/_template/codegen.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""UMA codegen for the my_ai_hw accelerator""" + +import tvm +import pathlib + + +def gen_includes() -> str: + topdir = pathlib.Path(__file__).parent.absolute() + + includes = "" + includes += f'#include "{topdir}/conv2dnchw.cc"' + return includes diff --git a/apps/uma/_template/conv2dnchw.cc b/apps/uma/_template/conv2dnchw.cc new file mode 100644 index 000000000000..bfb4300e2aa3 --- /dev/null +++ b/apps/uma/_template/conv2dnchw.cc @@ -0,0 +1,96 @@ +/* +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +*/ +#include + +// TODO(mjklaiber): leverage pragma import_c in the future +#ifdef __cplusplus +extern "C" +#endif + + /*! + * \brief Conv2D function for mock-accelerator examples. Limited to same-padded Conv2D with + * stride (1,1) and datatype float. \param ifmap Pointer to input feature map data of size + * iw*ih*ic*sizeof(float). \param weights Pointer to weight data of size + * kh*kw*ic**oc*sizeof(float). \param result Pointer to output feature map data of size + * iw*ih*oc*sizeof(float). \param oc Number of channels of output feature map. \param iw Width + * of input feature map, ifmap. \param ih Height of input feature map, ifmap. \param ic Number + * of channels of input feature map. \param kh Height of convolution kernels. \param kw Width of + * convolution kernels. + * + * \return error code + * + */ + int + my_ai_hw_conv2dnchw(float* ifmap, float* weights, float* result, int oc, int iw, int ih, int ic, + int kh, int kw) { + + int kw_low = kw / 2; + int kh_low = kh / 2; + int kw_high = iw + kw / 2; + int kh_high = ih + kh / 2; + + int padded_iw = iw + 2 * kw_low; + int padded_ih = ih + 2 * kh_low; + + // This is only example code. A real hardware accelerator would call a device specific malloc + // function. + float* pad_temp = (float*)malloc( + (((ic * padded_iw * padded_ih) + (padded_ih * padded_iw)) + padded_iw) * sizeof(float)); + + if (pad_temp == NULL) { + return -1; + } + + for (int i1 = 0; i1 < ic; ++i1) { + for (int i2 = 0; i2 < padded_ih; ++i2) { + for (int i3 = 0; i3 < padded_iw; ++i3) { + ((float*)pad_temp)[(((i1 * padded_iw * padded_ih) + (i2 * padded_iw)) + i3)] = + (((((kh_low <= i2) && (i2 < kh_high)) && (kw_low <= i3)) && (i3 < kw_high)) + ? ifmap[((((i1 * iw * ih) + ((i2 - kh_low) * iw)) + i3 - kw_low))] + : 0.000000e+00f); + } + } + } + for (int i11 = 0; i11 < oc; ++i11) { + for (int i21 = 0; i21 < ih; ++i21) { + for (int i31 = 0; i31 < iw; ++i31) { + for (int i4 = 0; i4 < ic; ++i4) { + for (int i5 = 0; i5 < kh; ++i5) { + for (int i6 = 0; i6 < kw; ++i6) { + int cse_var_1 = (((i11 * iw * ih) + (i21 * iw)) + i31); + if (((i4 == 0) && (i5 == 0)) && (i6 == 0)) { + result[cse_var_1] = 0.000000e+00f; + } + result[cse_var_1] = + (result[cse_var_1] + + (((float*) + pad_temp)[i4 * padded_iw * padded_ih + (i21 + i5) * padded_iw + i31 + i6] * + weights[((((i11 * ic * kh * kw) + (i4 * kh * kw)) + (i5 * kw)) + i6)])); + } + } + } + } + } + } + + // This is only example code. A real hardware accelerator would call a device specific free + // function. + free(pad_temp); + return 0; +} diff --git a/apps/uma/_template/passes.py b/apps/uma/_template/passes.py new file mode 100644 index 000000000000..b4f261a5ab49 --- /dev/null +++ b/apps/uma/_template/passes.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Transform passes for the my_ai_hw accelerator""" + +import tvm +from tvm import tir +from tvm.relay.backend.contrib.uma.api.utils import add_llvm_to_block + + +@tvm.tir.transform.prim_func_pass(opt_level=2) +class MyAiHwConv2dPass: + _EXTERNAL_FUNCTION_NAME = "my_ai_hw_conv2dnchw" + _TVM_BLOCK_MATCH_NAME = "conv2d_nchw" + + def transform_function( + self, func: tvm.tir.PrimFunc, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.tir.PrimFunc: + return self._my_ai_hw_conv2d_pass(func, mod, ctx) + + @classmethod + def _my_ai_hw_conv2d_pass(cls, func, mod, ctx): + _loops = dict() + _handles = [] + _entry_node = None + + def _has_block(name: str, func: tvm.tir.PrimFunc) -> bool: + """ + Determine of a tir.block with `name` exists in `func` + """ + + def _hb(op): + if isinstance(op, tvm.tir.Block): + _found_blocks.append(op.name_hint) + + _found_blocks = [] + tvm.tir.stmt_functor.post_order_visit(func.body, _hb) + return name in _found_blocks + + def _detect_and_replace_conv2d( + func: tvm.tir.PrimFunc, mod: tvm.ir.IRModule, ctx: tvm.ir.transform.PassContext + ) -> tvm.tir.PrimFunc: + def _replace_conv2d(op): + if op == _entry_node: + irb = tvm.tir.ir_builder.create() + # Collection of buffer address + buffers = [b[1].data for b in _handles] + # extraction of loop offsets + for k, v in _loops.items(): + assert v.min.value == 0 + offset_order = ["co", "w", "h", "ci", "kh", "kw"] + offsets = [_loops[i].extent.value for i in offset_order] + args = buffers + offsets + irb.emit(tir_call(irb, True, cls._EXTERNAL_FUNCTION_NAME, *args)) + irb_result = irb.get() + return irb_result + elif isinstance(op, tvm.tir.SeqStmt): + # Remove that pad block of TOPI's conv2DNCHW by only returning the 2nd statement + return op.seq[1] + return op + + sch = tir.Schedule(func) + + if _has_block(cls._TVM_BLOCK_MATCH_NAME, func): + conv2d_block = sch.get_block(cls._TVM_BLOCK_MATCH_NAME) + rv_loops = sch.get_loops(conv2d_block) + assert len(rv_loops) == 7 + loops = dict( + n=rv_loops[0], + co=rv_loops[1], + h=rv_loops[2], + w=rv_loops[3], + ci=rv_loops[4], + kh=rv_loops[5], + kw=rv_loops[6], + ) + _entry_node = sch.get(rv_loops[1]) + _loops = {k: sch.get(v) for k, v in loops.items()} + _handles = func.buffer_map.items() + + x = tvm.tir.stmt_functor.ir_transform( + func.body, None, _replace_conv2d, ["tir.For", "tir.SeqStmt"] + ) + return func.with_body(x) + else: + return func + + r = _detect_and_replace_conv2d(func, mod, ctx) + return r + + +def tir_call(ib: tvm.tir.ir_builder, extern: bool, name: str, *args): + """ + ib: ir_builder + extern: bool + True --> tvm.tir.call_extern + False --> tvm.tir.call_packed + name: str + function name + *args: + arguments for function call + """ + + def buf_from_array(ib, arr, dtype): + # Allocate enough memory to store the whole array + var = ib.allocate("int32", (len(arr),), scope="global") + for i, v in enumerate(arr): + var[i] = v + # Declare a buffer, which is basically a view on the chunk of memory that we allocated + buf = tvm.tir.decl_buffer((len(arr),), dtype, data=var, scope="global") + return buf + + if extern: + args = [i.data if isinstance(i, tvm.tir.Buffer) else i for i in args] + return tvm.tir.call_extern("int32", name, *args) + else: + args = [ + buf_from_array(ib, i, "int32") + if isinstance(i, (tuple, list, tvm.ir.container.Array)) + else i + for i in args + ] + return tvm.tir.call_packed(name, *args) diff --git a/apps/uma/_template/patterns.py b/apps/uma/_template/patterns.py new file mode 100644 index 000000000000..ce25fe4dff8e --- /dev/null +++ b/apps/uma/_template/patterns.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Relay graph patterns for the my_ai_hw accelerator""" + +from tvm.relay.dataflow_pattern import is_op, wildcard + + +def conv2d_pattern(): + pattern = is_op("nn.conv2d")(wildcard(), wildcard()) + pattern = pattern.has_attr({"strides": [1, 1], "groups": 1}) + return pattern diff --git a/apps/uma/_template/run.py b/apps/uma/_template/run.py new file mode 100644 index 000000000000..852ae1234d0f --- /dev/null +++ b/apps/uma/_template/run.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER +import tvm +from tvm import relay +from backend import MyAiHwBackend +from tvm.relay import transform +from collections import OrderedDict +import numpy as np + + +from tvm.testing.aot import ( + AOTTestModel as AOTModel, + AOTTestRunner as AOTRunner, + generate_ref_data, + compile_and_run, +) + + +def create_conv2d(groups=1, runner=AOT_DEFAULT_RUNNER, weight_shape=32): + dtype = "float32" + ishape = (1, 32, 14, 14) + wshape = (32, weight_shape, 3, 3) + pass_config = {"tir.usmp.enable": True} + runner = AOTRunner( + makefile=runner.makefile, + prologue=runner.prologue, + epilogue=runner.epilogue, + includes=runner.includes, + parameters=runner.parameters, + pass_config=pass_config, + ) + data0 = relay.var("data", shape=ishape, dtype=dtype) + weight0 = relay.var("weight", shape=wshape, dtype=dtype) + out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=groups) + main_f = relay.Function([data0, weight0], out) + mod = tvm.IRModule() + mod["main"] = main_f + mod = transform.InferType()(mod) + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + w1_data = np.random.uniform(0, 1, wshape).astype(dtype) + inputs = OrderedDict([("data", i_data), ("weight", w1_data)]) + output_list = generate_ref_data(mod, inputs) + return mod, inputs, output_list, runner + + +def main(): + mod, inputs, output_list, runner = create_conv2d() + + uma_backend = MyAiHwBackend() + uma_backend.register() + mod = uma_backend.partition(mod) + target = tvm.target.Target("my_ai_hw", host=tvm.target.Target("c")) + + export_directory = tvm.contrib.utils.tempdir(keep_for_debug=True).path + print(f"Generated files are in {export_directory}") + compile_and_run( + AOTModel(module=mod, inputs=inputs, outputs=output_list), + runner, + interface_api="c", + use_unpacked_api=True, + target=target, + test_dir=str(export_directory), + ) + + +if __name__ == "__main__": + main() diff --git a/apps/uma/_template/strategies.py b/apps/uma/_template/strategies.py new file mode 100644 index 000000000000..aa1ea07280e4 --- /dev/null +++ b/apps/uma/_template/strategies.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Strategies for the my_ai_hw accelerator""" + +# Example how to integrate a custom conv1d strategy: + +# @relay.op.strategy.override_native_generic_func("custom_conv1d_strategy") +# def custom_conv1d_strategy(attrs, inputs, out_type, target): +# strategy = _op.OpStrategy() +# strategy.add_implementation( +# wrap_compute_conv1d(custom_conv1d_compute), +# wrap_topi_schedule(custom_conv1d_schedule), +# name="custom_conv1d.generic", +# return strategy +# + +# For further details see: +# - github.com/apache/tvm-rfcs/blob/main/rfcs/0060_UMA_Unified_Modular_Accelerator_Interface.md +# - $TVM_HOME/python/tvm/relay/op/strategy/x86.py diff --git a/apps/uma/uma_cli.py b/apps/uma/uma_cli.py new file mode 100644 index 000000000000..159fa9e62cb6 --- /dev/null +++ b/apps/uma/uma_cli.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" + UMA Command Line Interface (CLI) + + Tool to create code skeletons for an easy integration of + new AI hardware accelerators/libraries into TVM using UMA +""" + +import argparse +import os +import shutil +import sys +import pathlib +from inflection import camelize, underscore + + +def _parse_args(): + parser = argparse.ArgumentParser(description="UMA Interface command line interface") + parser.add_argument( + "--add_hardware", + type=str, + required=True, + ) + parser.add_argument( + "--tutorial", + type=str, + ) + args = parser.parse_args() + return args + + +def replace_template_name( + files: list, template_name: str, add_hw_name: str, template_source: str = "_template" +) -> None: + """ + Replace names in template skeleton code by new name + """ + for f in files: + with open(f) as read_file: + data = read_file.read() + for case in [underscore, camelize]: + data = data.replace(case(template_name), case(add_hw_name)) + data = data.replace(template_source, underscore(add_hw_name)) + with open(f, "w") as write_file: + write_file.write(data) + + +def main(): + """ + UMA Command Line Interface (CLI) + """ + args = _parse_args() + add_hw_name = args.add_hardware + uma_template_path = pathlib.Path(os.getcwd(), "_template").absolute() + + add_hw_path = os.path.join(uma_template_path.parent, add_hw_name) + if os.path.exists(add_hw_path): + print( + f"Hardware with name {add_hw_name} already exists in UMA file structure: {add_hw_path}" + ) + sys.exit(-1) + else: + os.mkdir(add_hw_path) + + uma_files = ["backend.py", "codegen.py", "passes.py", "patterns.py", "run.py", "strategies.py"] + if args.tutorial == "vanilla": + uma_files.append("conv2dnchw.cc") + + source_files = [os.path.join(uma_template_path, f) for f in uma_files] + destination_files = [os.path.join(add_hw_path, f) for f in uma_files] + + for src, dst in zip(source_files, destination_files): + shutil.copyfile(src, dst) + + template_name = "my_ai_hw" + replace_template_name(destination_files, template_name, add_hw_name) + + print(f"Success: added {add_hw_name} to {add_hw_path}") + + +if __name__ == "__main__": + main() diff --git a/cmake/config.cmake b/cmake/config.cmake index 4cd10f104a83..18725de844b2 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -383,3 +383,6 @@ set(SUMMARIZE OFF) # To enable pass the path to the root libtorch (or PyTorch) directory # OFF or /path/to/torch/ set(USE_LIBTORCH OFF) + +# Whether to use the Universal Modular Accelerator Interface +set(USE_UMA OFF) diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index 3b3d8a4bcc9a..6bc8f6b46390 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -116,6 +116,7 @@ function(add_lib_info src_file) TVM_INFO_USE_VULKAN="${USE_VULKAN}" TVM_INFO_USE_CLML="${USE_CLML}" TVM_INFO_USE_CLML_GRAPH_EXECUTOR="${USE_CLML_GRAPH_EXECUTOR}" + TVM_INFO_USE_UMA="${USE_UMA}" ) endfunction() diff --git a/cmake/modules/contrib/UMA.cmake b/cmake/modules/contrib/UMA.cmake new file mode 100644 index 000000000000..1d3a9a30ec0f --- /dev/null +++ b/cmake/modules/contrib/UMA.cmake @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(USE_UMA) + file(GLOB COMPILER_UMA_SRCS + CONFIGURE_DEPENDS src/relay/backend/contrib/uma/*) + list(APPEND COMPILER_SRCS ${COMPILER_UMA_SRCS}) +endif(USE_UMA) diff --git a/docs/conf.py b/docs/conf.py index 82b0d2962338..d645958ca6db 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -264,6 +264,7 @@ def git_describe_version(original_version): "topi.pi", "cross_compilation_and_rpc.py", "relay_quick_start.py", + "uma.py", ], "compile_models": [ "from_pytorch.py", diff --git a/gallery/tutorial/uma.py b/gallery/tutorial/uma.py new file mode 100644 index 000000000000..ed4fc4cf805c --- /dev/null +++ b/gallery/tutorial/uma.py @@ -0,0 +1,292 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +.. _tutorial-uma: + +Making your Hardware Accelerator TVM-ready with UMA +=================================================== +**Authors**: `Michael J. Klaiber `_, `Christoph Gerum `_, +`Paul Palomero Bernardo `_ + +""" + + +###################################################################### +# This is an introductory tutorial to the **Universal Modular Accelerator Interface** (UMA). +# UMA provides an easy-to-use API to integrate new hardware accelerators into TVM. +# +# This tutorial gives you step-by-step guidance how to use UMA to +# make your hardware accelerator TVM-ready. +# While there is no one-fits-all solution for this problem, UMA targets to provide a stable and Python-only +# API to integrate a number of hardware accelerator classes into TVM. +# +# +# In this tutorial you will get to know the UMA API in three use cases of increasing complexity. +# In these use case the three mock-accelerators +# **Vanilla**, **Strawberry** and **Chocolate** are introduced and +# integrated into TVM using UMA. +# + +# sphinx_gallery_start_ignore +from tvm import testing + +testing.utils.install_request_hook(depth=3) +# sphinx_gallery_end_ignore + + +###################################################################### +# Vanilla +# ------------- +# **Vanilla** is a simple accelerator consisting of a MAC array and has no internal memory. +# It is can ONLY process Conv2D layers, all other layers are executed on a CPU, that also orchestrates **Vanilla**. +# Both the CPU and Vanilla use a shared memory. +# + +###################################################################### +# .. image:: https://raw.githubusercontent.com/apache/tvm-site/main/images/tutorial/uma_vanilla_block_diagram.png +# :width: 100% +# :alt: A block diagram of Vanilla +# + +###################################################################### +# **Vanilla** has a C interface ``vanilla_conv2dnchw(...)``` for carrying out a Conv2D operation (including same-padding), +# that accepts pointers to input feature map, weights and result, +# as well as the dimensions of `Conv2D`: `oc`, `iw`, `ih`, `ic`, `kh`, `kw`. +# +# .. code-block:: c++ +# +# int vanilla_conv2dnchw(float* ifmap, float* weights, float* result, int oc, int iw, int ih, int ic, int kh, int kw); + + +################################################################################ +# The script `uma_cli` creates code skeletons with API-calls into the UMA-API for new accelerators. +# +# For **Vanilla** we use it as follows: (``--tutorial vanilla`` adds all the additional files required for this part of the tutorial) +# +# .. code-block:: bash +# +# pip install inflection +# cd $TVM_HOME/apps/uma +# python uma_cli.py --add_hardware vanilla_accelerator --tutorial vanilla +# + +################################################################################ +# uma_cli.py generates these files in the directory ``vanilla_accelerator`` which we are going to revist. +# +# .. code-block:: bash +# +# backend.py +# codegen.py +# conv2dnchw.cc +# passes.py +# patterns.py +# run.py +# strategies.py + + +################################################################################ +# Vanilla backend +# +# The generated backend for vanilla is found in `vanilla_accelerator/backend.py`: + +###################################################################### +# +# .. code-block:: python +# +# class VanillaAcceleratorBackend(UMABackend): +# """UMA backend for VanillaAccelerator.""" +# +# def __init__(self): +# super().__init__() +# +# self._register_pattern("conv2d", conv2d_pattern()) +# self._register_tir_pass(PassPhase.TIR_PHASE_0, VanillaAcceleratorConv2DPass()) +# self._register_codegen(fmt="c", includes=gen_includes) +# +# @property +# def target_name(self): +# return "vanilla_accelerator" + + +################################################################################ +# Define offloaded patterns +# +# To specify that `Conv2D` is offloaded to **Vanilla**, it is described as Relay dataflow pattern +# (`DFPattern `_) in `vanilla_accelerator/patterns.py` + + +################################################################################ +# +# .. code-block:: python +# +# def conv2d_pattern(): +# pattern = is_op("nn.conv2d")(wildcard(), wildcard()) +# pattern = pattern.has_attr({"strides": [1, 1]}) +# return pattern + + +################################################################################ +# To map **Conv2D** operations from the input graph to **Vanilla**'s +# low level function call ``vanilla_conv2dnchw(...)``, the TIR pass +# *VanillaAcceleratorConv2DPass* (that will be discussed later in this tutorial) +# is registered in `VanillaAcceleratorBackend`. + + +################################################################################ +# Codegen + +################################################################################ +# The file ``vanilla_accelerator/codegen.py`` defines static C-code that is added to the +# resulting C-Code generated by TVMÅ› C-Codegen in ``gen_includes``. +# Here C-code is added to include **Vanilla**'s low level library``vanilla_conv2dnchw()``. +# +# .. code-block:: python +# +# def gen_includes() -> str: +# topdir = pathlib.Path(__file__).parent.absolute() +# +# includes = "" +# includes += f'#include "{topdir}/conv2dnchw.cc"' +# return includes + + +################################################################################ +# As shown above in `VanillaAcceleratorBackend` it is registered to UMA with +# the `self._register_codegen` +# +# .. code-block:: python +# +# self._register_codegen(fmt="c", includes=gen_includes) + + +########################################################### +# Building the Neural Network and run it on Vanilla +# +# To demonstrate UMA's functionality, we will generate C code for a single Conv2D layer and run it on +# the Vanilla accelerator. +# The file ``vanilla_accelerator/run.py`` provides a demo running a Conv2D layer +# making use of Vanilla's C-API. +# +# +# .. code-block:: python +# +# def main(): +# mod, inputs, output_list, runner = create_conv2d() +# +# uma_backend = VanillaAcceleratorBackend() +# uma_backend.register() +# mod = uma_backend.partition(mod) +# target = tvm.target.Target("vanilla_accelerator", host=tvm.target.Target("c")) +# +# export_directory = tvm.contrib.utils.tempdir(keep_for_debug=True).path +# print(f"Generated files are in {export_directory}") +# compile_and_run( +# AOTModel(module=mod, inputs=inputs, outputs=output_list), +# runner, +# interface_api="c", +# use_unpacked_api=True, +# target=target, +# test_dir=str(export_directory), +# ) +# +# +# main() + +############################################################ +# By running ``vanilla_accelerator/run.py`` the output files are generated in the model library format (MLF). +# + +########################################################### +# Output: +# +# .. code-block:: bash +# +# Generated files are in /tmp/tvm-debug-mode-tempdirs/2022-07-13T13-26-22___x5u76h0p/00000 + +########################################################### +# Let's examine the generated files: +# +# +# Output: +# +# .. code-block:: bash +# +# cd /tmp/tvm-debug-mode-tempdirs/2022-07-13T13-26-22___x5u76h0p/00000 +# cd build/ +# ls -1 +# +# codegen +# lib.tar +# metadata.json +# parameters +# runtime +# src + +########################################################### +# To evaluate the generated C code go to ``codegen/host/src/default_lib2.c`` +# +# .. code-block:: bash +# +# cd codegen/host/src/ +# ls -1 +# +# default_lib0.c +# default_lib1.c +# default_lib2.c +# + +########################################################### +# In `default_lib2.c` you can now see that the generated code calls +# into Vanilla's C-API and executes a Conv2D layer: +# +# .. code-block:: c++ +# +# TVM_DLL int32_t tvmgen_default_vanilla_accelerator_main_0(float* placeholder, float* placeholder1, float* conv2d_nchw, uint8_t* global_workspace_1_var) { +# vanilla_accelerator_conv2dnchw(placeholder, placeholder1, conv2d_nchw, 32, 14, 14, 32, 3, 3); +# return 0; +# } +# + + +########################################################### +# Strawberry +# --------------- +# Coming soon ... + +########################################################### +# Chocolate +# -------------- +# Coming soon ... +# + +###################################################################### +# Request for Community Input +# ----------------------------- +# If this tutorial **did not** fit to your accelerator, lease add your requirements to the UMA thread in +# the TVM discuss forum: `Link `_. +# We are eager to extend this tutorial to provide guidance on making further classes of AI hardware +# accelerators TVM-ready using the UMA interface. +# + +###################################################################### +# References +# ----------- +# [UMA-RFC] `UMA: Universal Modular Accelerator Interface `_, +# TVM RFC, June 2022. +# +# [DFPattern] `Pattern Matching in Relay `_ +# diff --git a/python/tvm/relay/backend/contrib/uma/__init__.py b/python/tvm/relay/backend/contrib/uma/__init__.py new file mode 100644 index 000000000000..061a42e23a87 --- /dev/null +++ b/python/tvm/relay/backend/contrib/uma/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""UMA modules for Relay.""" + +from .backend import UMABackend +from .api.utils import uma_available + +__all__ = ["UMABackend", "uma_available"] diff --git a/python/tvm/relay/backend/contrib/uma/api/__init__.py b/python/tvm/relay/backend/contrib/uma/api/__init__.py new file mode 100644 index 000000000000..f826a56016fa --- /dev/null +++ b/python/tvm/relay/backend/contrib/uma/api/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""UMA: Universal Modular Accelerator Interface API""" + +from .codegen import UMACodegen +from .lower import UMALower +from .partitioner import UMAPartitioner + + +__all__ = ["UMACodegen", "UMALower", "UMAPartitioner"] diff --git a/python/tvm/relay/backend/contrib/uma/api/_ffi_api.py b/python/tvm/relay/backend/contrib/uma/api/_ffi_api.py new file mode 100644 index 000000000000..5f67cb7ec246 --- /dev/null +++ b/python/tvm/relay/backend/contrib/uma/api/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for relay transformation passes.""" +import tvm._ffi # type: ignore + +tvm._ffi._init_api("relay.ext.uma", __name__) diff --git a/python/tvm/relay/backend/contrib/uma/api/codegen.py b/python/tvm/relay/backend/contrib/uma/api/codegen.py new file mode 100644 index 000000000000..8bbb77c91b44 --- /dev/null +++ b/python/tvm/relay/backend/contrib/uma/api/codegen.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Codegen base class of the Universal Modular Accelerator Interface (UMA)""" + +from typing import Callable, Optional +import tvm + + +class UMACodegen(object): + """ + Codegen base class of the Universal Modular Accelerator Interface (UMA) + """ + + def __init__(self, target_name: str) -> None: + self.target_name = target_name + + def _register_codegen( + self, fmt: str = "c", includes: Optional[Callable[[], str]] = None, **kwargs + ) -> None: + """Registration codegen in UMA. + + Parameters + ---------- + fmt: str + format of codegen. Currently only "c" is supported. + includes : OptionalCallable[[], str]] + user-defined function that adds C-#include statement to UMA C-Code. + """ + if fmt == "c": + self._register_c_codegen(includes, **kwargs) + else: + raise RuntimeError(f'Unsupported codegen format "{fmt}"') + + def _register_c_codegen(self, includes: Optional[Callable[[], str]] = None) -> None: + """Registration of UMA helper functions, e.g. includes and replace_call_extern. + + Parameters + ---------- + includes : OptionalCallable[[], str]] + user-defined function that adds C-#include statement to UMA C-Code. + """ + if includes is not None: + tvm._ffi.register_func( + f"relay.ext.uma.codegen_c_includes_{self.target_name}", + includes, + override=True, + ) + + def register(self) -> None: + pass diff --git a/python/tvm/relay/backend/contrib/uma/api/lower.py b/python/tvm/relay/backend/contrib/uma/api/lower.py new file mode 100644 index 000000000000..34630949a151 --- /dev/null +++ b/python/tvm/relay/backend/contrib/uma/api/lower.py @@ -0,0 +1,165 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Lowering base class of the Universal Modular Accelerator Interface (UMA)""" + +from typing import List, Tuple, Callable, Optional + +import tvm +from tvm import relay, te +from tvm.relay.op.op import register_strategy +from . import _ffi_api +from .utils import PassPhase + +OperatorStrategies = List[ + Tuple[ + str, + Callable[ + [tvm.ir.Attrs, tvm.ir.Array, tvm.ir.TensorType, tvm.target.Target], + tvm.relay.op.op.OpStrategy, + ], + Optional[int], + ] +] + + +class UMALower: + """Lowering base class of the Universal Modular Accelerator Interface (UMA).""" + + def __init__(self, target_name: str) -> None: + self.target_name = target_name + self._operator_strategies: OperatorStrategies = [] + self._tir_passes: List[Tuple[PassPhase, tvm.tir.transform.PrimFuncPass]] = [] + + def _lower_relay_to_tir(self, relay_prim_func: relay.Function) -> tvm.tir.PrimFunc: + """Lower a Relay primitive function to a S-TIR primitive function. + + Parameters + ---------- + prim_func : tvm.relay.Function + The Relay function to lower. + + Returns + ------- + out : tvm.tir.PrimFunc + The lowered schedulable TensorIR primitive function. + + """ + + def _get_tensors(te_cached_func): + outputs = list(te_cached_func.outputs) + stack = [] + visited = set() + for output_ in outputs: + if output_ not in visited: + visited.add(output_) + stack.append(output_) + + args = [] + while len(stack) != 0: + tensor = stack.pop() + if isinstance(tensor.op, tvm.te.tensor.PlaceholderOp): + args.append(tensor) + elif isinstance(tensor.op, tvm.te.tensor.ComputeOp): + inputs = tensor.op.input_tensors + for input_ in inputs: + if input_ not in visited: + visited.add(input_) + stack.append(input_) + + return args + outputs + + lower_to_te = tvm._ffi.get_global_func("relay.backend.LowerToTE") + te_cached_func = lower_to_te(relay_prim_func) + x = _get_tensors(te_cached_func) + tir_prim_func = te.create_prim_func(x) + tir_prim_func = tir_prim_func.with_attr( + "global_symbol", relay_prim_func.attrs["global_symbol"] + ) + + compiler_attr = relay_prim_func.attrs["Compiler"] + target = tvm.target.Target.current() + if target.kind.name != compiler_attr: + target = tvm.target.Target(compiler_attr) + + tir_prim_func = tir_prim_func.with_attr("target", target) + tir_prim_func = tir_prim_func.with_attr("relay_attrs", relay_prim_func.attrs) + return tir_prim_func + + def _lower_stir_to_nstir(self, prim_func: tvm.tir.PrimFunc) -> tvm.tir.PrimFunc: + """Lower a S-TIR primitive function to a NS-TIR primitive function. + + Parameters + ---------- + prim_func : tvm.tir.PrimFunc + The primitive function to lower. + + Returns + ------- + out : tvm.tir.PrimFunc + The lowered non-schedulable TensorIR primitive function. + + """ + curr_ctxt = tvm.transform.PassContext().current() + assert "tir.add_lower_pass" not in curr_ctxt.config + + pass_map = { + PassPhase.TIR_PHASE_0: 0, + PassPhase.TIR_PHASE_1: 1, + PassPhase.TIR_PHASE_2: 2, + PassPhase.TIR_PHASE_3: 3, + } + lower_passes = [(pass_map[k], v) for k, v in self._tir_passes] + + with tvm.transform.PassContext( + opt_level=curr_ctxt.opt_level, + required_pass=curr_ctxt.required_pass, + disabled_pass=curr_ctxt.disabled_pass, + instruments=curr_ctxt.instruments, + config={**dict(curr_ctxt.config), "tir.add_lower_pass": lower_passes}, + ): + mod = tvm.lower(tvm.ir.IRModule.from_expr(prim_func)) + prim_func = mod[prim_func.attrs["global_symbol"]] + return prim_func + + def relay_to_tir(self, mod: tvm.ir.IRModule) -> tvm.ir.IRModule: + """ + This is the hook for python-based lowering of a Relay module which lowers NPU + external functions to TIR. + + Parameters + ---------- + mod : tvm.ir.IRModule + This is the Relay module. + + Returns + ------- + mod : tvm.ir.IRModule + The Relay module with scheduled NPU external functions. + """ + mod = _ffi_api.OutlineCompilerFunctions(self.target_name)(mod) + for gvar, func in mod.functions.items(): + if "Compiler" in func.attrs and func.attrs["Compiler"] == self.target_name: + func = self._lower_relay_to_tir(func) + func = self._lower_stir_to_nstir(func) + mod.update_func(gvar, func) + return mod + + def register(self) -> None: + """Register all relevant relay-to-tir functions.""" + tvm._ffi.register_func(f"relay.ext.uma.{self.target_name}.relay_to_tir", self.relay_to_tir) + for op, strategy, plevel in self._operator_strategies: + register_strategy(op, strategy, plevel) diff --git a/python/tvm/relay/backend/contrib/uma/api/partitioner.py b/python/tvm/relay/backend/contrib/uma/api/partitioner.py new file mode 100644 index 000000000000..48cac81d13d8 --- /dev/null +++ b/python/tvm/relay/backend/contrib/uma/api/partitioner.py @@ -0,0 +1,122 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Partitioner base class of the Universal Modular Accelerator Interface (UMA)""" + +from typing import Callable, Dict, List, Tuple, Optional + +import tvm +from tvm import relay +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.op.contrib.register import register_pattern_table +from .utils import PassPhase + + +PatternTable = List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]] + + +class UMAPartitioner: + """Partitioner base class of the Universal Modular Accelerator Interface (UMA).""" + + def __init__(self, target_name: str, merge_compiler_regions: bool = True) -> None: + self.target_name = target_name + self.merge_compiler_regions = merge_compiler_regions + + self._relay_passes: List[Tuple[PassPhase, tvm.transform.Pass]] = [] + self._patterns: PatternTable = [] + + def add_pattern( + self, + name: str, + pattern: tvm.relay.dataflow_pattern.DFPattern, + predicate: Optional[Callable] = None, + ) -> None: + """Add pattern to UMA partitioner + + Parameters + ---------- + name : str + relay name of pattern + + pattern: tvm.relay.dataflow_pattern.DFPattern + pattern description as DFPattern + + predicate: Optional[Callable] + Optional predicate + + """ + + name = self.target_name + "." + name + if predicate: + self._patterns.append((name, pattern, predicate)) + else: + self._patterns.append((name, pattern)) + + def _pattern_table(self) -> PatternTable: + return self._patterns + + def register(self) -> None: + """Register all relevant relay-to-relay functions.""" + register_pattern_table(self.target_name, self._pattern_table) + + def partition( + self, mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None + ) -> tvm.IRModule: + """Partition the relay graph in parts supported and unsupported by the + target hardware accelerator. + + Parameters + ---------- + mod : tvm.IRModule + The relay module to be partitioned. + + params: Optional[Dict[str, tvm.runtime.NDArray]] + + Returns + ------- + out : tvm.IRModule + The partitioned relay module. + + """ + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + + pass_sequence = [] + pass_sequence.extend( + [p[1] for p in self._relay_passes if p[0] == PassPhase.PRE_PARTITIONING] + ) + pass_sequence.append(relay.transform.MergeComposite(self._pattern_table())) + pass_sequence.append(relay.transform.AnnotateTarget(self.target_name)) + if self.merge_compiler_regions: + pass_sequence.append(relay.transform.MergeCompilerRegions()) + pass_sequence.append(relay.transform.PartitionGraph()) + pass_sequence.extend( + [p[1] for p in self._relay_passes if p[0] == PassPhase.POST_PARTITIONING_0] + ) + + sequential_passes = tvm.transform.Sequential(pass_sequence) + mod = sequential_passes(mod) + + # Defunctionalize the partitioned functions to allow lowering + for gvar, func in mod.functions.items(): + mod.update_func(gvar, relay.transform.Defunctionalization(func, mod)) + + post_partition_passes_1 = tvm.transform.Sequential( + [p[1] for p in self._relay_passes if p[0] == PassPhase.POST_PARTITIONING_1] + ) + mod = post_partition_passes_1(mod) + + return mod diff --git a/python/tvm/relay/backend/contrib/uma/api/utils.py b/python/tvm/relay/backend/contrib/uma/api/utils.py new file mode 100644 index 000000000000..e217fbf3d6ad --- /dev/null +++ b/python/tvm/relay/backend/contrib/uma/api/utils.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Utility methods for the Universal Modular Accelerator Interface (UMA)""" + +from enum import Enum, auto +import uuid + +import tvm +import tvm.tir +from tvm.contrib import utils, clang + + +def uma_available() -> bool: + registration_func = tvm.get_global_func( + "relay.backend.contrib.uma.RegisterTarget", allow_missing=True + ) + return registration_func is not None + + +class PassPhase(Enum): + """ + UMA pass phases: + + PRE_PARTITIONING: prior to UMA partitioning + POST_PARTITIONING_0: after UMA partitioning, before Defunctionalization + POST_PARTITIONING_1: after UMA partitioning and after Defunctionalization + TIR_PHASE_0: Generates the raw IR and loop levels. + TIR_PHASE_1: Flattens the array storage. + TIR_PHASE_2: Transforms loops, like unroll, vectorization and thread-binding. + TIR_PHASE_3: Does some cleanup work. + + Reference to TIR phases: src/driver/driver_api.c + """ + + PRE_PARTITIONING = auto() + POST_PARTITIONING_0 = auto() + POST_PARTITIONING_1 = auto() + TIR_PHASE_0 = auto() + TIR_PHASE_1 = auto() + TIR_PHASE_2 = auto() + TIR_PHASE_3 = auto() + + +def _c_to_llvm(c_code: str) -> str: + unique_filename = str(uuid.uuid4()) + temp = utils.tempdir() + ll_path = temp.relpath(f"{unique_filename}.ll") + ll_code = clang.create_llvm([c_code], output=ll_path) + return ll_code + + +def add_llvm_to_block( + sch: tvm.tir.Schedule, block_name: str, c_code_str: str = "" +) -> tvm.tir.Schedule: + block = sch.get_block(block_name) + loops = sch.get_loops(block) + assert len(loops) > 0 + sch.annotate(loops[0], "pragma_import_llvm", _c_to_llvm(c_code_str)) + return sch diff --git a/python/tvm/relay/backend/contrib/uma/backend.py b/python/tvm/relay/backend/contrib/uma/backend.py new file mode 100644 index 000000000000..40ec06e45367 --- /dev/null +++ b/python/tvm/relay/backend/contrib/uma/backend.py @@ -0,0 +1,293 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Backend base class of the Universal Modular Accelerator Interface (UMA)""" + +from abc import ABC, abstractmethod +from typing import Union, Dict, Callable, Optional, Any + +import tvm +from tvm.relay.backend.contrib.uma.api.codegen import UMACodegen +from tvm.relay.backend.contrib.uma.api.lower import UMALower +from tvm.relay.backend.contrib.uma.api.partitioner import UMAPartitioner +from tvm.relay.backend.contrib.uma.api.utils import PassPhase + + +class UMABackend(ABC): + """Backend base class of the Universal Modular Accelerator Interface (UMA)""" + + def __init__(self, merge_compiler_regions: bool = True) -> None: + self._target_attrs: Dict = {} + self._target_preprocessor: Callable[[str], Dict[str, Any]] = None + self._relay_to_relay = UMAPartitioner(self.target_name, merge_compiler_regions) + self._relay_to_tir = UMALower(self.target_name) + self._tir_to_runtime = UMACodegen(self.target_name) + + @property + @abstractmethod + def target_name(self) -> str: + """Name of the hardware target. + + Returns + ------- + out : str + The hardware target name. + """ + ... + + # Target configuration + def _register_target_attr( + self, + name: str, + default: Optional[Union[str, int, bool]] = "", + ) -> None: + """Register a target attribute name that can be used during target instantiation. + Parameters + ---------- + name: str + The name of the target attribute. + + default: Optional[Union[str, int, bool]] + A default value for the attribute. + If none is provided, the attribute will be treated as a string. + + Example + ------- + Here is an example of how two attribute options are registered. + + .. code-block:: python + + self._register_target_attr("attrA", default=0) + self._register_target_attr("attrB", default=False) + """ + self._target_attrs[name] = default + + # Relay to Relay function registration + def _register_relay_pass(self, phase: PassPhase, relay_pass: tvm.transform.Pass) -> None: + """Registers a relay pass at the given phase in the lowering process. + + Parameters + ---------- + phase: PassPhase + The phase at which the pass is registered. + + relay_pass: tvm.transform.Pass + The relay pass to be registered. + + Example + ------- + Here is an example of how two relay passes are registered. + Passes of the same phase are executed in the order they are registered. + + .. code-block:: python + + self._register_relay_pass(PassPhase.PRE_PARTITIONING, MyPassA) + self._register_relay_pass(PassPhase.POST_PARTITIONING, MyPassB) + + Where a relay pass can look like this: + + .. code-block:: python + + @tvm.ir.transform.module_pass(opt_level=0) + class MyPassA: + def transform_module(self, mod, ctx): + # My pass functionality... + return mod + """ + self._relay_to_relay._relay_passes.append((phase, relay_pass)) + + def _register_pattern( + self, + name: str, + pattern: tvm.relay.dataflow_pattern.DFPattern, + predicate: Optional[Callable] = None, + ) -> None: + """Registers a dataflow pattern that is used to partition the relay graph. + + Parameters + ---------- + name: str + The name of the pattern + + pattern: tvm.relay.dataflow_pattern.DFPattern + Relay DFPattern + + predicate: Optional[Callable] + Optional predicate for Relay DFPattern + Example + ------- + Here is an example of how two dataflow patterns are registered. + During partioning, patterns are searched in order of registration. + + .. code-block:: python + + self._register_pattern("conv1d", conv1d_pattern) + self._register_pattern("conv2d", conv2d_pattern) + + Where a dataflow pattern can look like this: + + .. code-block:: python + + conv1d_pattern = is_op("nn.conv1d")(wildcard(), wildcard()) + optional_bias = lambda x: is_op("nn.bias_add")(x, wildcard()) + optional_relu = lambda x: is_op("nn.relu")(x) + conv1d_pattern = conv1d_pattern.optional(optional_bias).optional(optional_relu) + """ + self._relay_to_relay.add_pattern(name, pattern, predicate) + + # Relay to TIR function registration + def _register_operator_strategy( + self, + op: str, + strategy: Callable[ + [tvm.ir.Attrs, tvm.ir.Array, tvm.ir.TensorType, tvm.target.Target], + tvm.relay.op.op.OpStrategy, + ], + plevel: Optional[int] = 11, + ) -> None: + """Registers an operator strategy that is used to partition the relay graph. + + Parameters + ---------- + op: str + The name of the operator for which this strategy will be registered. + + strategy: Callable[[tvm.ir.Attrs, tvm.ir.Array, tvm.ir.TensorType, tvm.target.Target], + tvm.relay.op.op.OpStrategy] + The strategy function. + + plevel: Optional[int] = 11 + The priority level of the strategy. Higher plevel equals higher priorization. + The TVM default for topi strategies is 10 so by default new UMA strategies are + always used. + + Example + ------- + Here is an example of how two operator strategies are registered. + + .. code-block:: python + + self._register_operator_strategy("nn.conv1d", custom_conv1d_strategy) + self._register_operator_strategy("nn.conv2d", custom_conv2d_strategy) + + Where a strategy function can look like this: + + .. code-block:: python + + @relay.op.strategy.override_native_generic_func("custom_conv1d_strategy") + def custom_conv1d_strategy(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_conv1d(custom_conv1d_compute), + wrap_topi_schedule(custom_conv1d_schedule), + name="custom_conv1d.generic", + return strategy + """ + self._relay_to_tir._operator_strategies.append((op, strategy, plevel)) + + def _register_tir_pass( + self, phase: PassPhase, tir_pass: tvm.tir.transform.PrimFuncPass + ) -> None: + """Registers a TIR pass at the given phase in the lowering process. + + Parameters + ---------- + phase: PassPhase + The phase at which the pass is registered. + + tir_pass: tvm.tir.transform.PrimFuncPass + The TIR pass to be registered. + Example + ------- + Here is an example of how two TIR passes are registered. + Passes of the same phase are executed in the order they are registered. + + .. code-block:: python + + self._register_tir_pass(PassPhase.TIR_PHASE_0, MyPassA) + self._register_tir_pass(PassPhase.TIR_PHASE_1, MyPassB) + + Where a TIR pass can look like this: + + .. code-block:: python + + @tvm.tir.transform.prim_func_pass(opt_level=0) + class MyPassA: + def transform_function(self, func, mod, ctx): + # My pass functionality... + return func + """ + self._relay_to_tir._tir_passes.append((phase, tir_pass)) + + # TIR to runtime function registration + def _register_codegen(self, fmt: str = "c", **kwargs) -> None: + """Registers a codegen which is used in place of the default C-codegen. + + Parameters + ---------- + fmt: str + The codegen format. For now, only C-codegen is supported by UMA. + + **kwargs + Keyword arguments for the chosen codegen. + + Example + ------- + Here is an example of how the custom C-codegen is registered and configured. + Passes of the same phase are executed in the order they are registered. + + .. code-block:: python + + self._register_codegen( + fmt="c", includes=gen_includes + ) + + The C-codegen currently provides one hook which allows the user to insert code through + the python API. + - `includes` hooks into the include stream and allows insertion of custom includes. + + + The code generation functions can look like this: + + .. code-block:: python + + def gen_includes() -> str: + includes = "#include \n" + return includes + """ + self._tir_to_runtime._register_codegen(fmt, **kwargs) + + # Backend functions + def register(self) -> None: + """ + Registering UMABackend: + registering target attributes, relay_to_relay, relay_to_tir and tir_to_runtime + """ + registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget") + + for name, attr in self._target_attrs: + if attr is None: + raise ValueError("Target attribute None is not supported.") + + if registration_func(self.target_name, self._target_attrs): + self._relay_to_relay.register() + self._relay_to_tir.register() + self._tir_to_runtime.register() + + def partition( + self, mod: tvm.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None + ) -> tvm.IRModule: + return self._relay_to_relay.partition(mod, params) diff --git a/python/tvm/relay/op/contrib/dnnl.py b/python/tvm/relay/op/contrib/dnnl.py index b8c49176ac8f..f76d4bd10daf 100644 --- a/python/tvm/relay/op/contrib/dnnl.py +++ b/python/tvm/relay/op/contrib/dnnl.py @@ -249,6 +249,8 @@ def predicate(expr): for e, op_name in zip([expr, expr.args[0]], ["sum", "bias_add"]): args = get_args(e) attrs = get_attrs(e.args[0]) + if attrs is None: + return False if not checker(attrs, args, op_name): return False return True diff --git a/python/tvm/testing/aot.py b/python/tvm/testing/aot.py index a87e61666d35..5d7fb62cd204 100644 --- a/python/tvm/testing/aot.py +++ b/python/tvm/testing/aot.py @@ -23,7 +23,6 @@ import shutil import subprocess import tarfile -import tempfile import logging from typing import Any, NamedTuple, Union, Optional, List, Dict import numpy as np @@ -837,8 +836,8 @@ def run_and_check_body(base_path): assert AOT_SUCCESS_TOKEN in run_log.read() if test_dir is None: - with tempfile.TemporaryDirectory() as tmpdir: - run_and_check_body(os.path.join(tmpdir, "test")) + tmpdir = utils.tempdir() + run_and_check_body(os.path.join(tmpdir.path, "test")) else: run_and_check_body(test_dir) @@ -854,7 +853,7 @@ def compile_and_run( enable_op_fusion: bool = True, data_linkage: AOTDataLinkage = None, use_runtime_executor: bool = True, - target: str = "c", + target: Union[str, tvm.target.Target, List[tvm.target.Target]] = "c", target_opts: Dict = None, test_dir: str = None, verbose: bool = False, @@ -874,6 +873,9 @@ def compile_and_run( for key, val in target_opts.items(): target += f" {key}={val}" + if isinstance(target, str): + target = tvm.target.Target(target) + compiled_test_mods = compile_models( models=models, interface_api=interface_api, @@ -883,7 +885,7 @@ def compile_and_run( enable_op_fusion=enable_op_fusion, pass_config=runner.pass_config, use_runtime_executor=use_runtime_executor, - target=tvm.target.Target(target), + target=target, schedule_name=schedule_name, ) diff --git a/src/relay/backend/contrib/uma/relay_to_tir.cc b/src/relay/backend/contrib/uma/relay_to_tir.cc new file mode 100644 index 000000000000..8aed69453158 --- /dev/null +++ b/src/relay/backend/contrib/uma/relay_to_tir.cc @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/contrib/uma/codegen.cc + * + * \brief this file contains the target hooks for the Universal Modular Accelerator Interface (UMA). + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace contrib { +namespace uma { + +// TODO(@mjklaiber, @manupa-arm, @areusch) move this to include +/*! + * \brief This mutator outlines functions that are marked with a named + * "Compiler" attribute. Functions that do not match this condition remain + * unaltered. + */ +class OutlineCompilerFunctionsMutator : public MixedModeMutator { + public: + explicit OutlineCompilerFunctionsMutator(const IRModule& mod, const std::string& compiler_name) + : mod_(mod), compiler_name_(compiler_name) {} + + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + Expr var = this->VisitExpr(op->var); + Expr value = this->VisitExpr(op->value); + + // Outlineable function no longer needs let binding + if (this->CanOutlineExpr(value)) { + this->memo_[var] = value; + } + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + + // Drop the let binding + if (this->CanOutlineExpr(value)) { + this->memo_[expr] = this->VisitExpr(op->body); + } else { + Var var = Downcast(this->VisitExpr(op->var)); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + Call call = Downcast(post); + if (CanOutlineExpr(call->op)) { + Function func = Downcast(call->op); + auto gv_name = func->GetAttr("global_symbol").value_or(""); + ICHECK_NE(gv_name, "") + << "Function to be outlined must have global_symbol attribute, but didn't."; + GlobalVar gv(gv_name); + if (func->checked_type_.defined()) { + gv->checked_type_ = func->checked_type(); + } + mod_->Update(gv, func); + return Call(gv, call->args, call->attrs, call->type_args); + } + return post; + } + + private: + /*! + * \brief Check if the expr is a function and has the same + * compiler name as compiler_name_. + * + * \param expr The input expr. + * \return True if is outlineable else False. + */ + bool CanOutlineExpr(const Expr& expr) { + if (!expr->IsInstance()) { + return false; + } + Function func = Downcast(expr); + auto compiler = func->GetAttr(attr::kCompiler); + if (!compiler.defined()) { + return false; + } + if (compiler != compiler_name_) { + return false; + } + return true; + } + + /*! \brief The module that the pass will run on. */ + IRModule mod_; + /*! \brief The name of the compiler to enable outlining on external functions for. */ + std::string compiler_name_; +}; + +/*! + * \brief A pass to outline compiler specific functions. + */ +tvm::transform::Pass OutlineCompilerFunctions(const std::string& compiler_name) { + runtime::TypedPackedFunc pass_func = + [=](IRModule mod, transform::PassContext ctx) { + GlobalVar gv = mod->GetGlobalVar("main"); + Function main_func = Downcast(mod->Lookup("main")); + auto new_main_body = + OutlineCompilerFunctionsMutator(mod, compiler_name).VisitExpr(main_func->body); + if (!new_main_body.same_as(main_func->body)) { + Function new_main_func = WithFields(main_func, main_func->params, new_main_body); + mod->Update(gv, new_main_func); + } + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, + "relay.backend.contrib.uma.OutlineCompilerFunctions", {}); +} + +TVM_REGISTER_GLOBAL("relay.ext.uma.OutlineCompilerFunctions") + .set_body_typed(OutlineCompilerFunctions); + +/*! + * \brief This pass will lower UMA functions in a Relay module to scheduled TIR prim functions. + */ +tvm::transform::Pass RelayToTIR(String target_name) { + runtime::TypedPackedFunc pass_func = + [=](IRModule ir_module, transform::PassContext pass_context) { + auto relay_to_tir_pf = + tvm::runtime::Registry::Get("relay.ext.uma." + target_name + ".relay_to_tir"); + ICHECK(relay_to_tir_pf); + ir_module = (*relay_to_tir_pf)(ir_module); + return ir_module; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "relay.contrib.uma.RelayToTIR", {}); +} + +} // namespace uma +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/uma/targets.cc b/src/relay/backend/contrib/uma/targets.cc new file mode 100644 index 000000000000..a17f6694f79f --- /dev/null +++ b/src/relay/backend/contrib/uma/targets.cc @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/contrib/uma/targets.cc + * + * \brief this file contains the targets for the Universal Modular Accelerator Interface (UMA). + */ + +#include +#include + +namespace tvm { + +namespace relay { +namespace contrib { +namespace uma { +tvm::transform::Pass RelayToTIR(String target_name); +runtime::Module TIRToRuntime(IRModule mod, Target target); +} // namespace uma +} // namespace contrib +} // namespace relay + +TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget") + .set_body_typed([](String target_name, Map attr_options) -> bool { + // @todo(cgerum): We probably should get rid of target.register rather sooner than later + // And use a proper registry for uma backends + for (const String registered_target_name : ::tvm::TargetKindRegEntry::ListTargetKinds()) { + if (registered_target_name == target_name) { + return false; + } + } + + auto target_kind = + ::tvm::TargetKindRegEntry::RegisterOrGet(target_name) + .set_name() + .set_device_type(kDLCPU) + .add_attr_option>("keys") + .add_attr_option("tag") + .add_attr_option("device") + .add_attr_option("model") + .add_attr_option>("libs") + .add_attr_option("host") + .add_attr_option("from_device") + .set_attr(tvm::attr::kRelayToTIR, + relay::contrib::uma::RelayToTIR(target_name)) + .set_attr("TIRToRuntime", relay::contrib::uma::TIRToRuntime); + + for (auto& attr_option : attr_options) { + auto option_name = attr_option.first; + auto default_value = attr_option.second; + if (default_value->IsInstance()) { + target_kind.add_attr_option(option_name, Downcast(default_value)); + } else if (default_value->IsInstance()) { + target_kind.add_attr_option(option_name, Downcast(default_value)); + } else { + LOG(FATAL) << "Only String, Integer, or Bool are supported. Given attribute option type: " + << attr_option.second->GetTypeKey(); + } + } + return true; + }); + +} // namespace tvm diff --git a/src/relay/backend/contrib/uma/tir_to_runtime.cc b/src/relay/backend/contrib/uma/tir_to_runtime.cc new file mode 100644 index 000000000000..4b5cd4332476 --- /dev/null +++ b/src/relay/backend/contrib/uma/tir_to_runtime.cc @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include + +#include "../../../../runtime/file_utils.h" +#include "../../../../target/source/codegen_c.h" +#include "../../../../target/source/codegen_c_host.h" + +namespace tvm { +using namespace tir; +namespace relay { +namespace contrib { +namespace uma { + +class UMACodegen : public codegen::CodeGenCHost { + public: + explicit UMACodegen(String target_str) : target_str_(target_str) {} + + void Init(bool output_ssa, bool emit_asserts) { + auto includes_pf = + tvm::runtime::Registry::Get("relay.ext.uma.codegen_c_includes_" + target_str_); + if (includes_pf) { + String includes = (*includes_pf)(); + decl_stream << includes; + } + std::unordered_set devices; + devices.insert(target_str_); + CodeGenCHost::Init(output_ssa, emit_asserts, target_str_, devices); + } + + /*! + * \brief Emit code that offloads a subgraph to the UMA target + * + * \return string of code that offloads a subgraph to the UMA target + */ + void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); } + + private: + String target_str_; +}; + +runtime::Module TIRToRuntime(IRModule mod, Target target) { + bool output_ssa = false; + bool emit_asserts = false; + UMACodegen codegen(target->kind->name); + Array function_names; + codegen.Init(output_ssa, emit_asserts); + for (auto kv : mod->functions) { + auto prim_func = Downcast(kv.second); + auto global_symbol = prim_func->GetAttr(tvm::attr::kGlobalSymbol); + function_names.push_back(global_symbol.value()); + codegen.AddFunction(prim_func); + } + std::string code = codegen.Finish(); + return codegen::CSourceModuleCreate(code, "c", function_names); +} + +} // namespace uma +}; // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index 6f0a6114f3d9..4b2f6034730d 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -317,6 +317,7 @@ TVM_DLL Map GetLibInfo() { {"USE_VULKAN", TVM_INFO_USE_VULKAN}, {"USE_CLML", TVM_INFO_USE_CLML}, {"USE_CLML_GRAPH_EXECUTOR", TVM_INFO_USE_CLML_GRAPH_EXECUTOR}, + {"USE_UMA", TVM_INFO_USE_UMA}, }; return result; } diff --git a/tests/python/contrib/test_uma/test_partition.py b/tests/python/contrib/test_uma/test_partition.py new file mode 100644 index 000000000000..ec2107f881bc --- /dev/null +++ b/tests/python/contrib/test_uma/test_partition.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm +import tvm.relay as relay + +from tvm.relay.backend.contrib.uma.api import UMAPartitioner +from tvm.relay.op.contrib.register import get_pattern_table +from tvm.relay.testing import resnet, mlp +from tvm.relay.backend.contrib.uma import uma_available + + +pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available") + + +def test_partition_table(): + partitioner = UMAPartitioner("test_partition") + assert get_pattern_table("test_partition") is None + + partitioner.register() + + assert get_pattern_table("test_partition") is not None + + +@pytest.mark.parametrize( + "workload,backend,merge", + [ + ("resnet", "dnnl", False), + ("resnet", "dnnl", True), + ("mlp", "dnnl", False), + ("mlp", "dnnl", True), + ("resnet", "cutlass", False), + ("resnet", "cutlass", True), + ("mlp", "cutlass", False), + ("mlp", "cutlass", True), + ], +) +def test_existing_pattern_tables(workload, backend, merge): + """Tests that uma partitioner creates the same partitions than default BYOC partitioning""" + partitioner = UMAPartitioner(backend, merge) + pattern_table = get_pattern_table(backend) + + for entry in pattern_table: + partitioner.add_pattern(*entry) + + if workload == "resnet": + net = resnet.get_net(1, 10) + elif workload == "mlp": + net = mlp.get_net(1, 10) + else: + assert False, f"don't know how to find workload for {workload}" + + mod = tvm.ir.IRModule() + mod["main"] = net + + partitioner.register() + partitioned_mod = partitioner.partition(mod) + + def partition_default(mod): + """partitions using default BYOC flow""" + + sequence = [ + relay.transform.MergeComposite(pattern_table), + relay.transform.AnnotateTarget(backend), + ] + + if merge: + sequence.append(relay.transform.MergeCompilerRegions()) + + sequence.append(relay.transform.PartitionGraph()) + sequential = tvm.transform.Sequential(sequence) + + return sequential(mod) + + default_partitioned_mod = partition_default(mod) + + assert len(partitioned_mod.functions) == len(default_partitioned_mod.functions) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/contrib/test_uma/test_target.py b/tests/python/contrib/test_uma/test_target.py new file mode 100644 index 000000000000..558c4e518230 --- /dev/null +++ b/tests/python/contrib/test_uma/test_target.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Union + +import pytest +import tvm +from tests.python.contrib.test_uma.test_uma_vanilla_accelerator import VanillaAcceleratorBackend +from tvm.relay.backend.contrib.uma import uma_available + +pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available") + + +@pytest.mark.parametrize( + "target_name,target_attrs,target_args", + [ + ("my_hwa", {}, {}), + ( + "my_hwa2", + { + "local_memory_size": 128 * 1024, + "variant": "version1", + }, + {"local_memory_size": 256 * 1024, "variant": "version2"}, + ), + ], +) +def test_uma_target(target_name, target_attrs, target_args): + registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget") + registration_func(target_name, target_attrs) + + # Test Defaults + my_target = tvm.target.Target(target_name) + + assert str(my_target.kind) == target_name + + for attr in target_attrs.keys(): + assert my_target.attrs[attr] == target_attrs[attr] + + # Test with parameters overwritten + args = " ".join((f"--{k}={v}" for k, v in target_args.items())) + my_target = tvm.target.Target(f"{target_name} {args}") + + for attr in target_args.keys(): + assert my_target.attrs[attr] == target_args[attr] + + +@pytest.mark.parametrize( + "attr_name, target_attr", + [ + ("float_attr", 3.14), + ("none_attr", None), + ], +) +def test_invalid_attr_option(attr_name: str, target_attr: Union[str, int, bool, float, None]): + if target_attr is None: + # None cannot be caught as TVMError, as it causes a SIGKILL, therefore it must be prevented to be + # entered into relay.backend.contrib.uma.RegisterTarget at Python level. + with pytest.raises(ValueError): + uma_backend = VanillaAcceleratorBackend() + uma_backend._target_attrs = {attr_name: target_attr} + uma_backend.register() + else: + registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget") + target_name = f"{attr_name}_{target_attr}" + target_attr = {attr_name: target_attr} + with pytest.raises(tvm.TVMError, match=r"Only String, Integer, or Bool are supported. .*"): + registration_func(target_name, target_attr) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/contrib/test_uma/test_uma_lowering_with_umalower.py b/tests/python/contrib/test_uma/test_uma_lowering_with_umalower.py new file mode 100644 index 000000000000..d2e0af05e3ee --- /dev/null +++ b/tests/python/contrib/test_uma/test_uma_lowering_with_umalower.py @@ -0,0 +1,121 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import pathlib + +import tvm +from tests.python.contrib.test_uma.test_uma_utils import _create_schedule, _generate_io_arrays +from tvm import topi +from apps.uma._template.passes import MyAiHwConv2dPass +import tvm.testing +from tvm import te +from tvm.relay.backend.contrib.uma.api.lower import UMALower +from tvm.relay.backend.contrib.uma.api.utils import PassPhase +from tvm.relay.backend.contrib.uma import uma_available + + +pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available") + + +def _conv2d_te_definition(shapes: dict) -> list: + n, w, h, ci, kw, kh, co = ( + shapes["n"], + shapes["w"], + shapes["h"], + shapes["ci"], + shapes["kw"], + shapes["kh"], + shapes["co"], + ) + ifmap = te.placeholder((n, ci, w, h), dtype="float32", name="ifmap") + weights = te.placeholder((co, ci, kw, kh), dtype="float32", name="weights") + result = topi.nn.conv2d_nchw(ifmap, weights, stride=1, padding=[kw // 2, kh // 2], dilation=1) + return [ifmap, weights, result] + + +def _pepare_conv2d_schedule(shapes, use_external_conv2d_impl=True): + placeholders = _conv2d_te_definition(shapes) + + apps_path = ( + pathlib.Path(str(__file__)).parent.parent.parent.parent.parent.joinpath("apps").absolute() + ) + conv2d_file = apps_path / "uma" / "_template" / "conv2dnchw.cc" + + with conv2d_file.open() as f: + sch_tir = _create_schedule( + placeholders, f, use_external_conv2d_impl=use_external_conv2d_impl + ) + return placeholders, sch_tir + + +def _run_external_conv2d(dut_io_arrays, conv2d_shapes, target): + # Run conv2d with external function + placeholders, schedule = _pepare_conv2d_schedule(conv2d_shapes) + + uma_lower = UMALower("lower_test") + uma_lower._tir_passes.append((PassPhase.TIR_PHASE_0, MyAiHwConv2dPass())) + with tvm.transform.PassContext(): + tir_mod = uma_lower._lower_stir_to_nstir(schedule.mod["main"]) + + ifmap_data, weight_data, result_data = dut_io_arrays + + llvm_conv2d_mod = tvm.build(tir_mod, placeholders, target=target, name="test_external_conv2d") + llvm_conv2d_mod(ifmap_data, weight_data, result_data) + + +def _run_reference_conv2d(reference_io_arrays, conv2d_shapes, target): + placeholders, schedule = _pepare_conv2d_schedule(conv2d_shapes) + ref_mod = tvm.build(schedule.mod, placeholders, target=target, name="test_reference_conv2d") + ifmap, weights, result = reference_io_arrays + ref_mod(ifmap, weights, result) + + +def _prepare_io_arrays(conv2d_shapes, dev): + dut_io_arrays = _generate_io_arrays(conv2d_shapes, dev) + _, _, ref_result = _generate_io_arrays(conv2d_shapes, dev) + reference_io_arrays = [dut_io_arrays[0], dut_io_arrays[1], ref_result] + return dut_io_arrays, reference_io_arrays + + +@pytest.mark.parametrize( + "n, w, h, ci, kw, kh, co", + [ + (1, 224, 224, 3, 3, 3, 4), + (1, 224, 224, 3, 5, 5, 4), + (1, 224, 224, 3, 7, 7, 4), + (1, 224, 320, 3, 7, 7, 4), + (1, 224, 224, 3, 7, 7, 4), + ], +) +def test_lower_with_uma(n, w, h, ci, kw, kh, co): + target = tvm.target.Target(target="llvm", host="llvm") + dev = tvm.device(target.kind.name, 0) + conv2d_shapes = dict(n=n, w=w, h=h, ci=ci, kw=kw, kh=kh, co=co) + + dut_io_arrays, reference_io_arrays = _prepare_io_arrays(conv2d_shapes, dev) + + _run_external_conv2d(dut_io_arrays, conv2d_shapes, target) + _run_reference_conv2d(reference_io_arrays, conv2d_shapes, target) + + # compare results + dut_results = dut_io_arrays[2].numpy() + ref_results = reference_io_arrays[2].numpy() + tvm.testing.assert_allclose(dut_results, ref_results, rtol=1e-5) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/contrib/test_uma/test_uma_pipeline.py b/tests/python/contrib/test_uma/test_uma_pipeline.py new file mode 100644 index 000000000000..49b4a196bbd4 --- /dev/null +++ b/tests/python/contrib/test_uma/test_uma_pipeline.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER +from tvm.relay import transform, testing +from tvm.testing.aot import ( + AOTTestModel, + AOTTestRunner, + generate_ref_data, + compile_and_run, +) + +import tvm +from test_uma_vanilla_accelerator import VanillaAcceleratorBackend +from tvm import relay +import numpy as np +from collections import OrderedDict + +from tvm.relay.backend.contrib.uma.api.utils import uma_available + +pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available") + + +@pytest.mark.parametrize( + "interface_api,use_unpacked_api,test_runner,groups,weight_shape", + [("c", True, AOT_DEFAULT_RUNNER, 1, 32)], +) +def test_conv2d(interface_api, use_unpacked_api, test_runner, groups, weight_shape): + """Test a subgraph with a single conv2d operator.""" + mod, inputs, output_list, test_runner = create_conv2d(groups, test_runner, weight_shape) + + uma_backend = VanillaAcceleratorBackend() + uma_backend.register() + mod = uma_backend.partition(mod) + target = tvm.target.Target("vanilla_accelerator", host=tvm.target.Target("c")) + + compile_and_run( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + target=target, + ) + + +def create_conv2d(groups=1, test_runner=AOT_DEFAULT_RUNNER, weight_shape=32): + dtype = "float32" + ishape = (1, 32, 14, 14) + wshape = (32, weight_shape, 3, 3) + pass_config = {"tir.usmp.enable": True} + test_runner = AOTTestRunner( + makefile=test_runner.makefile, + prologue=test_runner.prologue, + epilogue=test_runner.epilogue, + includes=test_runner.includes, + parameters=test_runner.parameters, + pass_config=pass_config, + ) + data0 = relay.var("data", shape=ishape, dtype=dtype) + weight0 = relay.var("weight", shape=wshape, dtype=dtype) + out = relay.nn.conv2d(data0, weight0, kernel_size=(3, 3), padding=(1, 1), groups=groups) + main_f = relay.Function([data0, weight0], out) + mod = tvm.IRModule() + mod["main"] = main_f + mod = transform.InferType()(mod) + i_data = np.random.uniform(0, 1, ishape).astype(dtype) + w1_data = np.random.uniform(0, 1, wshape).astype(dtype) + inputs = OrderedDict([("data", i_data), ("weight", w1_data)]) + output_list = generate_ref_data(mod, inputs) + return mod, inputs, output_list, test_runner + + +def _generate_runtime_data(input_shapes: dict, output_shapes: dict) -> [OrderedDict, OrderedDict]: + assert len(input_shapes) == 1 + assert len(output_shapes) == 1 + + iname = list(input_shapes.keys())[0] + oname = list(output_shapes.keys())[0] + ishape = input_shapes[iname] + oshape = output_shapes[oname] + i_data = np.random.uniform(0, 1, ishape).astype("float32") + o_data = np.random.uniform(0, 1, oshape).astype("float32") + oname = "output" # name set by relay.build in executor_codegen_metadata.outputs + inputs = OrderedDict([(iname, i_data)]) + outputs = OrderedDict([(oname, o_data)]) + return inputs, outputs + + +def test_mobilenet(): + """Full network test with Mobilenet""" + use_unpacked_api = True + interface_api = "c" + test_runner = AOT_DEFAULT_RUNNER + + mod, params = testing.mobilenet.get_workload(batch_size=1) + + uma_backend = VanillaAcceleratorBackend() + uma_backend.register() + target = tvm.target.Target("vanilla_accelerator", host=tvm.target.Target("c")) + target_c = tvm.target.Target("c") + + data_shape = [int(x) for x in mod["main"].checked_type.arg_types[0].shape] + data = np.random.uniform(size=data_shape).astype("float32") + input_list = {"data": data} + output_list = generate_ref_data(mod, input_list, params) + mod = uma_backend.partition(mod) + aot_test_model = AOTTestModel(module=mod, inputs=input_list, outputs=output_list, params=params) + + compile_and_run( + aot_test_model, + test_runner, + interface_api, + use_unpacked_api, + workspace_byte_alignment=1, + debug_calculated_workspaces=False, + target=[target_c, target], + ) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/contrib/test_uma/test_uma_utils.py b/tests/python/contrib/test_uma/test_uma_utils.py new file mode 100644 index 000000000000..933602806f0e --- /dev/null +++ b/tests/python/contrib/test_uma/test_uma_utils.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import io + +import tvm +from tvm import topi, IRModule +import numpy as np +from tvm.contrib import utils, clang +import tvm.testing +from tvm import te +from typing import Union + + +def _create_schedule( + placeholder: list, + c_code: Union[str, io.TextIOWrapper] = "", + use_external_conv2d_impl: bool = True, +): + # How to do the same with TE + # Add pragma TE + # s = te.create_schedule(result.op) + # axis = result.op.axis + # s[result].pragma(axis[0], "import_llvm", c_to_llvm()) + # with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, my_ai_hw_conv2d_pass)]}): + # mod = tvm.lower(s, [ifmap, weights, result], simple_mode=True) + # + # llvm_mod = tvm.build(mod, [ifmap, weights, result], target=target, name="test_external_conv2d") + # llvm_mod(ifmap_data, weight_data, result_data) + if isinstance(c_code, io.TextIOWrapper): + c_code_str = c_code.read() + elif isinstance(c_code, str): + c_code_str = c_code + else: + raise TypeError() + + assert ( + use_external_conv2d_impl + and c_code_str != "" + or not use_external_conv2d_impl + and c_code_str == "" + ) + + def _c_to_llvm(c_code: str) -> str: + temp = utils.tempdir() + ll_path = temp.relpath("conv2d.ll") + ll_code = clang.create_llvm([c_code], output=ll_path) + return ll_code + + func_tir = te.create_prim_func(placeholder) + ir_module_from_te = IRModule({"main": func_tir}) + sch_tir = tvm.tir.Schedule(ir_module_from_te) + if use_external_conv2d_impl: + conv2d_b = sch_tir.get_block("conv2d_nchw") + conv2d_l = sch_tir.get_loops(conv2d_b) + sch_tir.annotate(conv2d_l[0], "pragma_import_llvm", _c_to_llvm(c_code_str)) + return sch_tir + + +def _generate_io_arrays(shapes: dict, dev): + n, w, h, ci, kw, kh, co = ( + shapes["n"], + shapes["w"], + shapes["h"], + shapes["ci"], + shapes["kw"], + shapes["kh"], + shapes["co"], + ) + + ifmap_data = tvm.nd.array(np.random.uniform(size=(n, ci, w, h)).astype("float32"), dev) + weight_data = tvm.nd.array(np.random.uniform(size=(co, ci, kh, kw)).astype("float32"), dev) + result_data = tvm.nd.array(np.zeros((n, co, w, h)).astype("float32"), dev) + return ifmap_data, weight_data, result_data diff --git a/tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py b/tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py new file mode 100644 index 000000000000..e7a6b21d4ab5 --- /dev/null +++ b/tests/python/contrib/test_uma/test_uma_vanilla_accelerator.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""UMA testcase for the vanilla_accelerator accelerator""" +import pytest + +from tvm.relay.backend.contrib.uma.api.utils import PassPhase +from tvm.relay.backend.contrib.uma.backend import UMABackend +from apps.uma._template.passes import ( + MyAiHwConv2dPass as VanillaAcceleratorConv2dPass, +) +from apps.uma._template.codegen import gen_includes + +from apps.uma._template.patterns import conv2d_pattern +from tvm.relay.backend.contrib.uma import uma_available + +pytestmark = pytest.mark.skipif(not uma_available(), reason="UMA not available") + + +class VanillaAcceleratorBackend(UMABackend): + """UMA backend for the VanillaAccelerator accelerator.""" + + def __init__(self): + super().__init__() + + ####################################################################### + # Relay to Relay function registration + ####################################################################### + self._register_pattern("conv2d", conv2d_pattern()) + + ####################################################################### + # Relay to TIR function registration + ####################################################################### + self._register_tir_pass(PassPhase.TIR_PHASE_0, VanillaAcceleratorConv2dPass()) + + ####################################################################### + # TIR to runtime function registration + ####################################################################### + self._register_codegen(fmt="c", includes=gen_includes) + + @property + def target_name(self): + return "vanilla_accelerator" diff --git a/tests/scripts/task_config_build_arm.sh b/tests/scripts/task_config_build_arm.sh index 189bdc250a8c..a01c1ed6d082 100755 --- a/tests/scripts/task_config_build_arm.sh +++ b/tests/scripts/task_config_build_arm.sh @@ -34,4 +34,5 @@ echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_ARM_COMPUTE_LIB ON\) >> config.cmake echo set\(USE_ARM_COMPUTE_LIB_GRAPH_EXECUTOR "/opt/acl"\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake +echo set\(USE_UMA ON\) >> config.cmake echo set\(SUMMARIZE ON\) >> config.cmake diff --git a/tests/scripts/task_config_build_cortexm.sh b/tests/scripts/task_config_build_cortexm.sh index 29869983b86d..35dbd82110cd 100755 --- a/tests/scripts/task_config_build_cortexm.sh +++ b/tests/scripts/task_config_build_cortexm.sh @@ -27,9 +27,11 @@ echo set\(USE_SORT ON\) >> config.cmake echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_CMSISNN ON\) >> config.cmake echo set\(USE_ETHOSU ON\) >> config.cmake +echo set\(USE_UMA ON\) >> config.cmake echo set\(USE_PROFILER ON\) >> config.cmake echo set\(USE_LLVM llvm-config-10\) >> config.cmake echo set\(CMAKE_CXX_FLAGS -Werror\) >> config.cmake echo set\(HIDE_PRIVATE_SYMBOLS ON\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake echo set\(SUMMARIZE ON\) >> config.cmake + diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 84213be860dc..9dc5c62efaa7 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -48,4 +48,5 @@ echo set\(USE_VERILATOR ON\) >> config.cmake echo set\(USE_LIBBACKTRACE ON\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake echo set\(USE_ETHOSU ON\) >> config.cmake +echo set\(USE_UMA ON\) >> config.cmake echo set\(SUMMARIZE ON\) >> config.cmake diff --git a/tests/scripts/task_config_build_i386.sh b/tests/scripts/task_config_build_i386.sh index c92aed3c1450..a570e9801ad3 100755 --- a/tests/scripts/task_config_build_i386.sh +++ b/tests/scripts/task_config_build_i386.sh @@ -34,5 +34,6 @@ echo set\(USE_VTA_FSIM ON\) >> config.cmake echo set\(USE_VTA_TSIM ON\) >> config.cmake echo set\(USE_VERILATOR ON\) >> config.cmake echo set\(USE_CCACHE OFF\) >> config.cmake +echo set\(USE_UMA OFF\) >> config.cmake echo set\(SUMMARIZE ON\) >> config.cmake