From 845a627a5995fc6ae38c41aba8c7dead408919db Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Tue, 30 Nov 2021 18:52:08 -0800 Subject: [PATCH] [microTVM][TVMC] Add TVMC test for Arduino and Zephyr (#9584) * Add TVMC test for Arduino and Zephyr * Address @gromero comments * address comments --- .../template_project/microtvm_api_server.py | 56 +++-- .../template_project/microtvm_api_server.py | 29 ++- python/tvm/micro/project.py | 21 +- tests/micro/__init__.py | 16 ++ tests/micro/arduino/conftest.py | 84 +------- .../arduino/test_arduino_error_detection.py | 9 +- .../micro/arduino/test_arduino_rpc_server.py | 20 +- tests/micro/arduino/test_arduino_workflow.py | 8 +- tests/micro/arduino/test_utils.py | 100 +++++++++ tests/micro/common/__init__.py | 16 ++ tests/micro/common/conftest.py | 47 +++++ tests/micro/common/test_tvmc.py | 199 ++++++++++++++++++ tests/scripts/task_python_microtvm.sh | 4 + 13 files changed, 471 insertions(+), 138 deletions(-) create mode 100644 tests/micro/__init__.py create mode 100644 tests/micro/arduino/test_utils.py create mode 100644 tests/micro/common/__init__.py create mode 100644 tests/micro/common/conftest.py create mode 100644 tests/micro/common/test_tvmc.py diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index 3039eb313908..bb4b54d8fb27 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -53,6 +53,8 @@ BOARDS = API_SERVER_DIR / "boards.json" +ARDUINO_CLI_CMD = shutil.which("arduino-cli") + # Data structure to hold the information microtvm_api_server.py needs # to communicate with each of these boards. try: @@ -78,7 +80,15 @@ class BoardAutodetectFailed(Exception): ), server.ProjectOption( "arduino_cli_cmd", - required=["build", "flash", "open_transport"], + required=( + ["generate_project", "build", "flash", "open_transport"] + if not ARDUINO_CLI_CMD + else None + ), + optional=( + ["generate_project", "build", "flash", "open_transport"] if ARDUINO_CLI_CMD else None + ), + default=ARDUINO_CLI_CMD, type="str", help="Path to the arduino-cli tool.", ), @@ -247,22 +257,21 @@ def _convert_includes(self, project_dir, source_dir): """ for ext in ("c", "h", "cpp"): for filename in source_dir.rglob(f"*.{ext}"): - with filename.open() as file: - lines = file.readlines() - - for i in range(len(lines)): - # Check if line has an include - result = re.search(r"#include\s*[<\"]([^>]*)[>\"]", lines[i]) - if not result: - continue - new_include = self._find_modified_include_path( - project_dir, filename, result.groups()[0] - ) - - lines[i] = f'#include "{new_include}"\n' - - with filename.open("w") as file: - file.writelines(lines) + with filename.open("rb") as src_file: + lines = src_file.readlines() + with filename.open("wb") as dst_file: + for i, line in enumerate(lines): + line_str = str(line, "utf-8") + # Check if line has an include + result = re.search(r"#include\s*[<\"]([^>]*)[>\"]", line_str) + if not result: + dst_file.write(line) + else: + new_include = self._find_modified_include_path( + project_dir, filename, result.groups()[0] + ) + updated_line = f'#include "{new_include}"\n' + dst_file.write(updated_line.encode("utf-8")) # Most of the files we used to be able to point to directly are under "src/standalone_crt/include/". # Howver, crt_config.h lives under "src/standalone_crt/crt_config/", and more exceptions might @@ -317,7 +326,7 @@ def _get_platform_version(self, arduino_cli_path: str) -> float: def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): # Check Arduino version - version = self._get_platform_version(options["arduino_cli_cmd"]) + version = self._get_platform_version(self._get_arduino_cli_cmd(options)) if version != ARDUINO_CLI_VERSION: message = f"Arduino CLI version found is not supported: found {version}, expected {ARDUINO_CLI_VERSION}." if options.get("warning_as_error") is not None and options["warning_as_error"]: @@ -367,7 +376,7 @@ def build(self, options): BUILD_DIR.mkdir() compile_cmd = [ - options["arduino_cli_cmd"], + self._get_arduino_cli_cmd(options), "compile", "./project/", "--fqbn", @@ -384,6 +393,11 @@ def build(self, options): BOARD_LIST_HEADERS = ("Port", "Type", "Board Name", "FQBN", "Core") + def _get_arduino_cli_cmd(self, options: dict): + arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD) + assert arduino_cli_cmd, "'arduino_cli_cmd' command not passed and not found by default!" + return arduino_cli_cmd + def _parse_boards_tabular_str(self, tabular_str): """Parses the tabular output from `arduino-cli board list` into a 2D array @@ -416,7 +430,7 @@ def _parse_boards_tabular_str(self, tabular_str): yield parsed_row def _auto_detect_port(self, options): - list_cmd = [options["arduino_cli_cmd"], "board", "list"] + list_cmd = [self._get_arduino_cli_cmd(options), "board", "list"] list_cmd_output = subprocess.run( list_cmd, check=True, stdout=subprocess.PIPE ).stdout.decode("utf-8") @@ -442,7 +456,7 @@ def flash(self, options): port = self._get_arduino_port(options) upload_cmd = [ - options["arduino_cli_cmd"], + self._get_arduino_cli_cmd(options), "upload", "./project", "--fqbn", diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index 3c96f31dfe22..8069f6dcd390 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -66,6 +66,10 @@ ZEPHYR_VERSION = 2.5 +WEST_CMD = default = sys.executable + " -m west" if sys.executable else None + +ZEPHYR_BASE = os.getenv("ZEPHYR_BASE") + # Data structure to hold the information microtvm_api_server.py needs # to communicate with each of these boards. try: @@ -271,8 +275,8 @@ def _get_nrf_device_args(options): ), server.ProjectOption( "west_cmd", - optional=["generate_project"], - default=sys.executable + " -m west" if sys.executable else None, + optional=["build"], + default=WEST_CMD, type="str", help=( "Path to the west tool. If given, supersedes both the zephyr_base " @@ -281,8 +285,9 @@ def _get_nrf_device_args(options): ), server.ProjectOption( "zephyr_base", - optional=["build", "open_transport"], - default=os.getenv("ZEPHYR_BASE"), + required=(["generate_project", "open_transport"] if not ZEPHYR_BASE else None), + optional=(["generate_project", "open_transport", "build"] if ZEPHYR_BASE else ["build"]), + default=ZEPHYR_BASE, type="str", help="Path to the zephyr base directory.", ), @@ -314,6 +319,13 @@ def _get_nrf_device_args(options): ] +def get_zephyr_base(options: dict): + """Returns Zephyr base path""" + zephyr_base = options.get("zephyr_base", ZEPHYR_BASE) + assert zephyr_base, "'zephyr_base' option not passed and not found by default!" + return zephyr_base + + class Handler(server.ProjectAPIHandler): def __init__(self): super(Handler, self).__init__() @@ -388,8 +400,8 @@ def _create_prj_conf(self, project_dir, options): "aot_demo": "memory microtvm_rpc_common common", } - def _get_platform_version(self) -> float: - with open(pathlib.Path(os.getenv("ZEPHYR_BASE")) / "VERSION", "r") as f: + def _get_platform_version(self, zephyr_base: str) -> float: + with open(pathlib.Path(zephyr_base) / "VERSION", "r") as f: lines = f.readlines() for line in lines: line = line.replace(" ", "").replace("\n", "").replace("\r", "") @@ -402,7 +414,7 @@ def _get_platform_version(self) -> float: def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): # Check Zephyr version - version = self._get_platform_version() + version = self._get_platform_version(get_zephyr_base(options)) if version != ZEPHYR_VERSION: message = f"Zephyr version found is not supported: found {version}, expected {ZEPHYR_VERSION}." if options.get("warning_as_error") is not None and options["warning_as_error"]: @@ -574,8 +586,7 @@ def _set_nonblock(fd): class ZephyrSerialTransport: @classmethod def _lookup_baud_rate(cls, options): - zephyr_base = options.get("zephyr_base", os.environ["ZEPHYR_BASE"]) - sys.path.insert(0, os.path.join(zephyr_base, "scripts", "dts")) + sys.path.insert(0, os.path.join(get_zephyr_base(options), "scripts", "dts")) try: import dtlib # pylint: disable=import-outside-toplevel finally: diff --git a/python/tvm/micro/project.py b/python/tvm/micro/project.py index 907590dcd2cf..87cd509fd15f 100644 --- a/python/tvm/micro/project.py +++ b/python/tvm/micro/project.py @@ -56,6 +56,17 @@ def read(self, n, timeout_sec): return self._api_client.read_transport(n, timeout_sec)["data"] +def prepare_options(received_options: dict, all_options: dict) -> dict: + """Add default value of options that are not passed to the project options.""" + options_default = {option["name"]: option["default"] for option in all_options} + prepared_options = dict(received_options) if received_options else dict() + + for option, def_val in options_default.items(): + if option not in prepared_options and def_val is not None: + prepared_options[option] = def_val + return prepared_options + + class TemplateProjectError(Exception): """Raised when the Project API server given to GeneratedProject reports is_template=True.""" @@ -69,10 +80,10 @@ def from_directory(cls, project_dir: Union[pathlib.Path, str], options: dict): def __init__(self, api_client, options): self._api_client = api_client - self._options = options self._info = self._api_client.server_info_query(__version__) if self._info["is_template"]: raise TemplateProjectError() + self._options = prepare_options(options, self.info()["project_options"]) def build(self): self._api_client.build(self._options) @@ -121,16 +132,18 @@ def _check_project_options(self, options: dict): Here is a list of available options:{list(available_options)}.""" ) - def generate_project_from_mlf(self, model_library_format_path, project_dir, options): + def generate_project_from_mlf(self, model_library_format_path, project_dir, options: dict): + """Generate a project from MLF file.""" self._check_project_options(options) + prepared_options = prepare_options(options, self.info()["project_options"]) self._api_client.generate_project( model_library_format_path=str(model_library_format_path), standalone_crt_dir=get_standalone_crt_dir(), project_dir=project_dir, - options=options, + options=prepared_options, ) - return GeneratedProject.from_directory(project_dir, options) + return GeneratedProject.from_directory(project_dir, prepared_options) def info(self): return self._info diff --git a/tests/micro/__init__.py b/tests/micro/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/micro/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/micro/arduino/conftest.py b/tests/micro/arduino/conftest.py index 71cc810affe3..bb4db18f7886 100644 --- a/tests/micro/arduino/conftest.py +++ b/tests/micro/arduino/conftest.py @@ -15,34 +15,9 @@ # specific language governing permissions and limitations # under the License. -import datetime -import pathlib -import json import pytest -import tvm.target.target -from tvm.micro import project -from tvm import relay -from tvm.relay.backend import Executor, Runtime - -TEMPLATE_PROJECT_DIR = pathlib.Path(tvm.micro.get_microtvm_template_projects("arduino")) - - -BOARDS = TEMPLATE_PROJECT_DIR / "boards.json" - -BOARDS = TEMPLATE_PROJECT_DIR / "boards.json" - - -def arduino_boards() -> dict: - """Returns a dict mapping board to target model""" - with open(BOARDS) as f: - board_properties = json.load(f) - - boards_model = {board: info["model"] for board, info in board_properties.items()} - return boards_model - - -ARDUINO_BOARDS = arduino_boards() +from test_utils import ARDUINO_BOARDS def pytest_addoption(parser): @@ -101,60 +76,3 @@ def arduino_cli_cmd(request): @pytest.fixture(scope="session") def tvm_debug(request): return request.config.getoption("--tvm-debug") - - -def make_workspace_dir(test_name, board): - filepath = pathlib.Path(__file__) - board_workspace = ( - filepath.parent - / f"workspace_{test_name}_{board}" - / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - ) - - number = 0 - while board_workspace.exists(): - number += 1 - board_workspace = pathlib.Path(str(board_workspace) + f"-{number}") - board_workspace.parent.mkdir(exist_ok=True, parents=True) - t = tvm.contrib.utils.tempdir(board_workspace) - # time.sleep(200) - return t - - -def make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir): - this_dir = pathlib.Path(__file__).parent - model = ARDUINO_BOARDS[board] - build_config = {"debug": tvm_debug} - - with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f: - tflite_model_buf = f.read() - - # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 - try: - import tflite.Model - - tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) - except AttributeError: - import tflite - - tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) - - mod, params = relay.frontend.from_tflite(tflite_model) - target = tvm.target.target.micro(model) - runtime = Runtime("crt") - executor = Executor("aot", {"unpacked-api": True}) - - with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): - mod = relay.build(mod, target, runtime=runtime, executor=executor, params=params) - - return tvm.micro.generate_project( - str(TEMPLATE_PROJECT_DIR), - mod, - workspace_dir / "project", - { - "arduino_board": board, - "arduino_cli_cmd": arduino_cli_cmd, - "project_type": "example_project", - "verbose": bool(build_config.get("debug")), - }, - ) diff --git a/tests/micro/arduino/test_arduino_error_detection.py b/tests/micro/arduino/test_arduino_error_detection.py index 64e2c14d1c18..9db59b9259c3 100644 --- a/tests/micro/arduino/test_arduino_error_detection.py +++ b/tests/micro/arduino/test_arduino_error_detection.py @@ -15,25 +15,22 @@ # specific language governing permissions and limitations # under the License. -import pathlib -import re import sys - import pytest -import conftest from tvm.micro.project_api.server import ServerError +import test_utils # A new project and workspace dir is created for EVERY test @pytest.fixture def workspace_dir(request, board): - return conftest.make_workspace_dir("arduino_error_detection", board) + return test_utils.make_workspace_dir("arduino_error_detection", board) @pytest.fixture def project(board, arduino_cli_cmd, tvm_debug, workspace_dir): - return conftest.make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir) + return test_utils.make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir) def test_blank_project_compiles(workspace_dir, project): diff --git a/tests/micro/arduino/test_arduino_rpc_server.py b/tests/micro/arduino/test_arduino_rpc_server.py index a0dcb923a197..662b825672af 100644 --- a/tests/micro/arduino/test_arduino_rpc_server.py +++ b/tests/micro/arduino/test_arduino_rpc_server.py @@ -24,27 +24,27 @@ import pathlib import sys - import numpy as np import onnx import pytest + import tvm from PIL import Image from tvm import relay from tvm.relay.testing import byoc from tvm.relay.backend import Executor, Runtime -import conftest +import test_utils # # A new project and workspace dir is created for EVERY test @pytest.fixture def workspace_dir(board): - return conftest.make_workspace_dir("arduino_rpc_server", board) + return test_utils.make_workspace_dir("arduino_rpc_server", board) def _make_session(model, arduino_board, arduino_cli_cmd, workspace_dir, mod, build_config): project = tvm.micro.generate_project( - str(conftest.TEMPLATE_PROJECT_DIR), + str(test_utils.TEMPLATE_PROJECT_DIR), mod, workspace_dir / "project", { @@ -85,7 +85,7 @@ def _make_add_sess(model, arduino_board, arduino_cli_cmd, workspace_dir, build_c def test_compile_runtime(board, arduino_cli_cmd, tvm_debug, workspace_dir): """Test compiling the on-device runtime.""" - model = conftest.ARDUINO_BOARDS[board] + model = test_utils.ARDUINO_BOARDS[board] build_config = {"debug": tvm_debug} # NOTE: run test in a nested function so cPython will delete arrays before closing the session. @@ -110,7 +110,7 @@ def test_basic_add(sess): def test_platform_timer(board, arduino_cli_cmd, tvm_debug, workspace_dir): """Test compiling the on-device runtime.""" - model = conftest.ARDUINO_BOARDS[board] + model = test_utils.ARDUINO_BOARDS[board] build_config = {"debug": tvm_debug} # NOTE: run test in a nested function so cPython will delete arrays before closing the session. @@ -139,7 +139,7 @@ def test_basic_add(sess): @pytest.mark.requires_hardware def test_relay(board, arduino_cli_cmd, tvm_debug, workspace_dir): """Testing a simple relay graph""" - model = conftest.ARDUINO_BOARDS[board] + model = test_utils.ARDUINO_BOARDS[board] build_config = {"debug": tvm_debug} shape = (10,) @@ -171,7 +171,7 @@ def test_relay(board, arduino_cli_cmd, tvm_debug, workspace_dir): @pytest.mark.requires_hardware def test_onnx(board, arduino_cli_cmd, tvm_debug, workspace_dir): """Testing a simple ONNX model.""" - model = conftest.ARDUINO_BOARDS[board] + model = test_utils.ARDUINO_BOARDS[board] build_config = {"debug": tvm_debug} # Load test images. @@ -261,7 +261,7 @@ def check_result( @pytest.mark.requires_hardware def test_byoc_microtvm(board, arduino_cli_cmd, tvm_debug, workspace_dir): """This is a simple test case to check BYOC capabilities of microTVM""" - model = conftest.ARDUINO_BOARDS[board] + model = test_utils.ARDUINO_BOARDS[board] build_config = {"debug": tvm_debug} x = relay.var("x", shape=(10, 10)) @@ -345,7 +345,7 @@ def _make_add_sess_with_shape( @pytest.mark.requires_hardware def test_rpc_large_array(board, arduino_cli_cmd, tvm_debug, workspace_dir, shape): """Test large RPC array transfer.""" - model = conftest.ARDUINO_BOARDS[board] + model = test_utils.ARDUINO_BOARDS[board] build_config = {"debug": tvm_debug} # NOTE: run test in a nested function so cPython will delete arrays before closing the session. diff --git a/tests/micro/arduino/test_arduino_workflow.py b/tests/micro/arduino/test_arduino_workflow.py index fe6ea8fe3b2e..feccafa727d3 100644 --- a/tests/micro/arduino/test_arduino_workflow.py +++ b/tests/micro/arduino/test_arduino_workflow.py @@ -15,15 +15,13 @@ # specific language governing permissions and limitations # under the License. -import datetime import pathlib import re import shutil import sys - import pytest -import conftest +import test_utils """ This unit test simulates a simple user workflow, where we: @@ -41,7 +39,7 @@ # directory for all tests in this file @pytest.fixture(scope="module") def workspace_dir(request, board): - return conftest.make_workspace_dir("arduino_workflow", board) + return test_utils.make_workspace_dir("arduino_workflow", board) @pytest.fixture(scope="module") @@ -52,7 +50,7 @@ def project_dir(workspace_dir): # We MUST pass workspace_dir, not project_dir, or the workspace will be dereferenced too soon @pytest.fixture(scope="module") def project(board, arduino_cli_cmd, tvm_debug, workspace_dir): - return conftest.make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir) + return test_utils.make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir) def _get_directory_elements(directory): diff --git a/tests/micro/arduino/test_utils.py b/tests/micro/arduino/test_utils.py new file mode 100644 index 000000000000..c107d5b1febf --- /dev/null +++ b/tests/micro/arduino/test_utils.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +import pathlib +import requests +import datetime + +import tvm.micro +import tvm.target.target +from tvm.micro import project +from tvm import relay +from tvm.relay.backend import Executor, Runtime + + +TEMPLATE_PROJECT_DIR = pathlib.Path(tvm.micro.get_microtvm_template_projects("arduino")) + +BOARDS = TEMPLATE_PROJECT_DIR / "boards.json" + + +def arduino_boards() -> dict: + """Returns a dict mapping board to target model""" + with open(BOARDS) as f: + board_properties = json.load(f) + + boards_model = {board: info["model"] for board, info in board_properties.items()} + return boards_model + + +ARDUINO_BOARDS = arduino_boards() + + +def make_workspace_dir(test_name, board): + filepath = pathlib.Path(__file__) + board_workspace = ( + filepath.parent + / f"workspace_{test_name}_{board}" + / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + ) + + number = 0 + while board_workspace.exists(): + number += 1 + board_workspace = pathlib.Path(str(board_workspace) + f"-{number}") + board_workspace.parent.mkdir(exist_ok=True, parents=True) + t = tvm.contrib.utils.tempdir(board_workspace) + return t + + +def make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir): + this_dir = pathlib.Path(__file__).parent + model = ARDUINO_BOARDS[board] + build_config = {"debug": tvm_debug} + + with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f: + tflite_model_buf = f.read() + + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 + try: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + except AttributeError: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) + + mod, params = relay.frontend.from_tflite(tflite_model) + target = tvm.target.target.micro(model) + runtime = Runtime("crt") + executor = Executor("aot", {"unpacked-api": True}) + + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + mod = relay.build(mod, target, runtime=runtime, executor=executor, params=params) + + return tvm.micro.generate_project( + str(TEMPLATE_PROJECT_DIR), + mod, + workspace_dir / "project", + { + "arduino_board": board, + "arduino_cli_cmd": arduino_cli_cmd, + "project_type": "example_project", + "verbose": bool(build_config.get("debug")), + }, + ) diff --git a/tests/micro/common/__init__.py b/tests/micro/common/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/micro/common/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/micro/common/conftest.py b/tests/micro/common/conftest.py new file mode 100644 index 000000000000..3fbfdbcbc81d --- /dev/null +++ b/tests/micro/common/conftest.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations + +import pytest + +from ..zephyr.test_utils import ZEPHYR_BOARDS +from ..arduino.test_utils import ARDUINO_BOARDS + + +def pytest_addoption(parser): + parser.addoption( + "--board", + required=True, + choices=list(ARDUINO_BOARDS.keys()) + list(ZEPHYR_BOARDS.keys()), + help="microTVM boards for tests.", + ) + parser.addoption( + "--test-build-only", + action="store_true", + help="Only run tests that don't require physical hardware.", + ) + + +@pytest.fixture +def board(request): + return request.config.getoption("--board") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--test-build-only"): + skip_hardware_tests = pytest.mark.skip(reason="--test-build-only was passed") + for item in items: + if "requires_hardware" in item.keywords: + item.add_marker(skip_hardware_tests) diff --git a/tests/micro/common/test_tvmc.py b/tests/micro/common/test_tvmc.py new file mode 100644 index 000000000000..d462b3fadd9b --- /dev/null +++ b/tests/micro/common/test_tvmc.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import subprocess +import shlex +import sys +import logging +import tempfile +import pathlib +import sys +import os + +import tvm +from tvm.contrib.download import download_testdata + +from ..zephyr.test_utils import ZEPHYR_BOARDS +from ..arduino.test_utils import ARDUINO_BOARDS + +TVMC_COMMAND = [sys.executable, "-m", "tvm.driver.tvmc"] + +MODEL_URL = "https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/examples/micro_speech/micro_speech.tflite" +MODEL_FILE = "micro_speech.tflite" + +# TODO(mehrdadh): replace this with _main from tvm.driver.tvmc.main +# Issue: https://github.com/apache/tvm/issues/9612 +def _run_tvmc(cmd_args: list, *args, **kwargs): + """Run a tvmc command and return the results""" + cmd_args_list = TVMC_COMMAND + cmd_args + cwd_str = "" if "cwd" not in kwargs else f" (in cwd: {kwargs['cwd']})" + logging.debug("run%s: %s", cwd_str, " ".join(shlex.quote(a) for a in cmd_args_list)) + return subprocess.check_call(cmd_args_list, *args, **kwargs) + + +def _get_target_and_platform(board: str): + if board in ZEPHYR_BOARDS.keys(): + target_model = ZEPHYR_BOARDS[board] + platform = "zephyr" + elif board in ARDUINO_BOARDS.keys(): + target_model = ARDUINO_BOARDS[board] + platform = "arduino" + else: + raise ValueError(f"Board {board} is not supported.") + + target = tvm.target.target.micro(target_model) + return str(target), platform + + +@tvm.testing.requires_micro +def test_tvmc_exist(board): + cmd_result = _run_tvmc(["micro", "-h"]) + assert cmd_result == 0 + + +@tvm.testing.requires_micro +def test_tvmc_model_build_only(board): + target, platform = _get_target_and_platform(board) + + model_path = model_path = download_testdata(MODEL_URL, MODEL_FILE, module="data") + temp_dir = pathlib.Path(tempfile.mkdtemp()) + tar_path = str(temp_dir / "model.tar") + project_dir = str(temp_dir / "project") + + runtime = "crt" + executor = "graph" + + cmd_result = _run_tvmc( + [ + "compile", + model_path, + f"--target={target}", + f"--runtime={runtime}", + f"--runtime-crt-system-lib", + str(1), + f"--executor={executor}", + "--executor-graph-link-params", + str(0), + "--output", + tar_path, + "--output-format", + "mlf", + "--pass-config", + "tir.disable_vectorize=1", + "--disabled-pass=AlterOpLayout", + ] + ) + assert cmd_result == 0, "tvmc failed in step: compile" + + create_project_cmd = [ + "micro", + "create-project", + project_dir, + tar_path, + platform, + "--project-option", + "project_type=host_driven", + ] + if platform == "zephyr": + create_project_cmd.append(f"{platform}_board={board}") + + cmd_result = _run_tvmc(create_project_cmd) + assert cmd_result == 0, "tvmc micro failed in step: create-project" + + cmd_result = _run_tvmc( + ["micro", "build", project_dir, platform, "--project-option", f"{platform}_board={board}"] + ) + assert cmd_result == 0, "tvmc micro failed in step: build" + + +@pytest.mark.requires_hardware +@tvm.testing.requires_micro +def test_tvmc_model_run(board): + target, platform = _get_target_and_platform(board) + + model_path = model_path = download_testdata(MODEL_URL, MODEL_FILE, module="data") + temp_dir = pathlib.Path(tempfile.mkdtemp()) + tar_path = str(temp_dir / "model.tar") + project_dir = str(temp_dir / "project") + + runtime = "crt" + executor = "graph" + + cmd_result = _run_tvmc( + [ + "compile", + model_path, + f"--target={target}", + f"--runtime={runtime}", + f"--runtime-crt-system-lib", + str(1), + f"--executor={executor}", + "--executor-graph-link-params", + str(0), + "--output", + tar_path, + "--output-format", + "mlf", + "--pass-config", + "tir.disable_vectorize=1", + "--disabled-pass=AlterOpLayout", + ] + ) + assert cmd_result == 0, "tvmc failed in step: compile" + + create_project_cmd = [ + "micro", + "create-project", + project_dir, + tar_path, + platform, + "--project-option", + "project_type=host_driven", + ] + if platform == "zephyr": + create_project_cmd.append(f"{platform}_board={board}") + + cmd_result = _run_tvmc(create_project_cmd) + assert cmd_result == 0, "tvmc micro failed in step: create-project" + + cmd_result = _run_tvmc( + ["micro", "build", project_dir, platform, "--project-option", f"{platform}_board={board}"] + ) + assert cmd_result == 0, "tvmc micro failed in step: build" + + cmd_result = _run_tvmc( + ["micro", "flash", project_dir, platform, "--project-option", f"{platform}_board={board}"] + ) + assert cmd_result == 0, "tvmc micro failed in step: flash" + + cmd_result = _run_tvmc( + [ + "run", + "--device", + "micro", + project_dir, + "--project-option", + f"{platform}_board={board}", + "--fill-mode", + "random", + ] + ) + assert cmd_result == 0, "tvmc micro failed in step: run" + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:])) diff --git a/tests/scripts/task_python_microtvm.sh b/tests/scripts/task_python_microtvm.sh index 8de8b908ee09..3f3fd61942f5 100755 --- a/tests/scripts/task_python_microtvm.sh +++ b/tests/scripts/task_python_microtvm.sh @@ -40,3 +40,7 @@ run_pytest ctypes python-microtvm-arduino-due tests/micro/arduino --test-build- # STM32 run_pytest ctypes python-microtvm-stm32 tests/micro/stm32 + +# Common Tests +run_pytest ctypes python-microtvm-common-qemu_x86 tests/micro/common --board=qemu_x86 +run_pytest ctypes python-microtvm-common-due tests/micro/common --test-build-only --board=due