Skip to content

Commit

Permalink
[microTVM][TVMC] Add TVMC test for Arduino and Zephyr (apache#9584)
Browse files Browse the repository at this point in the history
* Add TVMC test for Arduino and Zephyr

* Address @gromero comments

* address comments
  • Loading branch information
mehrdadh authored and yangulei committed Jan 10, 2022
1 parent 2ff494c commit 845a627
Show file tree
Hide file tree
Showing 13 changed files with 471 additions and 138 deletions.
56 changes: 35 additions & 21 deletions apps/microtvm/arduino/template_project/microtvm_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@

BOARDS = API_SERVER_DIR / "boards.json"

ARDUINO_CLI_CMD = shutil.which("arduino-cli")

# Data structure to hold the information microtvm_api_server.py needs
# to communicate with each of these boards.
try:
Expand All @@ -78,7 +80,15 @@ class BoardAutodetectFailed(Exception):
),
server.ProjectOption(
"arduino_cli_cmd",
required=["build", "flash", "open_transport"],
required=(
["generate_project", "build", "flash", "open_transport"]
if not ARDUINO_CLI_CMD
else None
),
optional=(
["generate_project", "build", "flash", "open_transport"] if ARDUINO_CLI_CMD else None
),
default=ARDUINO_CLI_CMD,
type="str",
help="Path to the arduino-cli tool.",
),
Expand Down Expand Up @@ -247,22 +257,21 @@ def _convert_includes(self, project_dir, source_dir):
"""
for ext in ("c", "h", "cpp"):
for filename in source_dir.rglob(f"*.{ext}"):
with filename.open() as file:
lines = file.readlines()

for i in range(len(lines)):
# Check if line has an include
result = re.search(r"#include\s*[<\"]([^>]*)[>\"]", lines[i])
if not result:
continue
new_include = self._find_modified_include_path(
project_dir, filename, result.groups()[0]
)

lines[i] = f'#include "{new_include}"\n'

with filename.open("w") as file:
file.writelines(lines)
with filename.open("rb") as src_file:
lines = src_file.readlines()
with filename.open("wb") as dst_file:
for i, line in enumerate(lines):
line_str = str(line, "utf-8")
# Check if line has an include
result = re.search(r"#include\s*[<\"]([^>]*)[>\"]", line_str)
if not result:
dst_file.write(line)
else:
new_include = self._find_modified_include_path(
project_dir, filename, result.groups()[0]
)
updated_line = f'#include "{new_include}"\n'
dst_file.write(updated_line.encode("utf-8"))

# Most of the files we used to be able to point to directly are under "src/standalone_crt/include/".
# Howver, crt_config.h lives under "src/standalone_crt/crt_config/", and more exceptions might
Expand Down Expand Up @@ -317,7 +326,7 @@ def _get_platform_version(self, arduino_cli_path: str) -> float:

def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options):
# Check Arduino version
version = self._get_platform_version(options["arduino_cli_cmd"])
version = self._get_platform_version(self._get_arduino_cli_cmd(options))
if version != ARDUINO_CLI_VERSION:
message = f"Arduino CLI version found is not supported: found {version}, expected {ARDUINO_CLI_VERSION}."
if options.get("warning_as_error") is not None and options["warning_as_error"]:
Expand Down Expand Up @@ -367,7 +376,7 @@ def build(self, options):
BUILD_DIR.mkdir()

compile_cmd = [
options["arduino_cli_cmd"],
self._get_arduino_cli_cmd(options),
"compile",
"./project/",
"--fqbn",
Expand All @@ -384,6 +393,11 @@ def build(self, options):

BOARD_LIST_HEADERS = ("Port", "Type", "Board Name", "FQBN", "Core")

def _get_arduino_cli_cmd(self, options: dict):
arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD)
assert arduino_cli_cmd, "'arduino_cli_cmd' command not passed and not found by default!"
return arduino_cli_cmd

def _parse_boards_tabular_str(self, tabular_str):
"""Parses the tabular output from `arduino-cli board list` into a 2D array
Expand Down Expand Up @@ -416,7 +430,7 @@ def _parse_boards_tabular_str(self, tabular_str):
yield parsed_row

def _auto_detect_port(self, options):
list_cmd = [options["arduino_cli_cmd"], "board", "list"]
list_cmd = [self._get_arduino_cli_cmd(options), "board", "list"]
list_cmd_output = subprocess.run(
list_cmd, check=True, stdout=subprocess.PIPE
).stdout.decode("utf-8")
Expand All @@ -442,7 +456,7 @@ def flash(self, options):
port = self._get_arduino_port(options)

upload_cmd = [
options["arduino_cli_cmd"],
self._get_arduino_cli_cmd(options),
"upload",
"./project",
"--fqbn",
Expand Down
29 changes: 20 additions & 9 deletions apps/microtvm/zephyr/template_project/microtvm_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@
ZEPHYR_VERSION = 2.5


WEST_CMD = default = sys.executable + " -m west" if sys.executable else None

ZEPHYR_BASE = os.getenv("ZEPHYR_BASE")

# Data structure to hold the information microtvm_api_server.py needs
# to communicate with each of these boards.
try:
Expand Down Expand Up @@ -271,8 +275,8 @@ def _get_nrf_device_args(options):
),
server.ProjectOption(
"west_cmd",
optional=["generate_project"],
default=sys.executable + " -m west" if sys.executable else None,
optional=["build"],
default=WEST_CMD,
type="str",
help=(
"Path to the west tool. If given, supersedes both the zephyr_base "
Expand All @@ -281,8 +285,9 @@ def _get_nrf_device_args(options):
),
server.ProjectOption(
"zephyr_base",
optional=["build", "open_transport"],
default=os.getenv("ZEPHYR_BASE"),
required=(["generate_project", "open_transport"] if not ZEPHYR_BASE else None),
optional=(["generate_project", "open_transport", "build"] if ZEPHYR_BASE else ["build"]),
default=ZEPHYR_BASE,
type="str",
help="Path to the zephyr base directory.",
),
Expand Down Expand Up @@ -314,6 +319,13 @@ def _get_nrf_device_args(options):
]


def get_zephyr_base(options: dict):
"""Returns Zephyr base path"""
zephyr_base = options.get("zephyr_base", ZEPHYR_BASE)
assert zephyr_base, "'zephyr_base' option not passed and not found by default!"
return zephyr_base


class Handler(server.ProjectAPIHandler):
def __init__(self):
super(Handler, self).__init__()
Expand Down Expand Up @@ -388,8 +400,8 @@ def _create_prj_conf(self, project_dir, options):
"aot_demo": "memory microtvm_rpc_common common",
}

def _get_platform_version(self) -> float:
with open(pathlib.Path(os.getenv("ZEPHYR_BASE")) / "VERSION", "r") as f:
def _get_platform_version(self, zephyr_base: str) -> float:
with open(pathlib.Path(zephyr_base) / "VERSION", "r") as f:
lines = f.readlines()
for line in lines:
line = line.replace(" ", "").replace("\n", "").replace("\r", "")
Expand All @@ -402,7 +414,7 @@ def _get_platform_version(self) -> float:

def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options):
# Check Zephyr version
version = self._get_platform_version()
version = self._get_platform_version(get_zephyr_base(options))
if version != ZEPHYR_VERSION:
message = f"Zephyr version found is not supported: found {version}, expected {ZEPHYR_VERSION}."
if options.get("warning_as_error") is not None and options["warning_as_error"]:
Expand Down Expand Up @@ -574,8 +586,7 @@ def _set_nonblock(fd):
class ZephyrSerialTransport:
@classmethod
def _lookup_baud_rate(cls, options):
zephyr_base = options.get("zephyr_base", os.environ["ZEPHYR_BASE"])
sys.path.insert(0, os.path.join(zephyr_base, "scripts", "dts"))
sys.path.insert(0, os.path.join(get_zephyr_base(options), "scripts", "dts"))
try:
import dtlib # pylint: disable=import-outside-toplevel
finally:
Expand Down
21 changes: 17 additions & 4 deletions python/tvm/micro/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,17 @@ def read(self, n, timeout_sec):
return self._api_client.read_transport(n, timeout_sec)["data"]


def prepare_options(received_options: dict, all_options: dict) -> dict:
"""Add default value of options that are not passed to the project options."""
options_default = {option["name"]: option["default"] for option in all_options}
prepared_options = dict(received_options) if received_options else dict()

for option, def_val in options_default.items():
if option not in prepared_options and def_val is not None:
prepared_options[option] = def_val
return prepared_options


class TemplateProjectError(Exception):
"""Raised when the Project API server given to GeneratedProject reports is_template=True."""

Expand All @@ -69,10 +80,10 @@ def from_directory(cls, project_dir: Union[pathlib.Path, str], options: dict):

def __init__(self, api_client, options):
self._api_client = api_client
self._options = options
self._info = self._api_client.server_info_query(__version__)
if self._info["is_template"]:
raise TemplateProjectError()
self._options = prepare_options(options, self.info()["project_options"])

def build(self):
self._api_client.build(self._options)
Expand Down Expand Up @@ -121,16 +132,18 @@ def _check_project_options(self, options: dict):
Here is a list of available options:{list(available_options)}."""
)

def generate_project_from_mlf(self, model_library_format_path, project_dir, options):
def generate_project_from_mlf(self, model_library_format_path, project_dir, options: dict):
"""Generate a project from MLF file."""
self._check_project_options(options)
prepared_options = prepare_options(options, self.info()["project_options"])
self._api_client.generate_project(
model_library_format_path=str(model_library_format_path),
standalone_crt_dir=get_standalone_crt_dir(),
project_dir=project_dir,
options=options,
options=prepared_options,
)

return GeneratedProject.from_directory(project_dir, options)
return GeneratedProject.from_directory(project_dir, prepared_options)

def info(self):
return self._info
Expand Down
16 changes: 16 additions & 0 deletions tests/micro/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
84 changes: 1 addition & 83 deletions tests/micro/arduino/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,9 @@
# specific language governing permissions and limitations
# under the License.

import datetime
import pathlib
import json
import pytest

import tvm.target.target
from tvm.micro import project
from tvm import relay
from tvm.relay.backend import Executor, Runtime

TEMPLATE_PROJECT_DIR = pathlib.Path(tvm.micro.get_microtvm_template_projects("arduino"))


BOARDS = TEMPLATE_PROJECT_DIR / "boards.json"

BOARDS = TEMPLATE_PROJECT_DIR / "boards.json"


def arduino_boards() -> dict:
"""Returns a dict mapping board to target model"""
with open(BOARDS) as f:
board_properties = json.load(f)

boards_model = {board: info["model"] for board, info in board_properties.items()}
return boards_model


ARDUINO_BOARDS = arduino_boards()
from test_utils import ARDUINO_BOARDS


def pytest_addoption(parser):
Expand Down Expand Up @@ -101,60 +76,3 @@ def arduino_cli_cmd(request):
@pytest.fixture(scope="session")
def tvm_debug(request):
return request.config.getoption("--tvm-debug")


def make_workspace_dir(test_name, board):
filepath = pathlib.Path(__file__)
board_workspace = (
filepath.parent
/ f"workspace_{test_name}_{board}"
/ datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
)

number = 0
while board_workspace.exists():
number += 1
board_workspace = pathlib.Path(str(board_workspace) + f"-{number}")
board_workspace.parent.mkdir(exist_ok=True, parents=True)
t = tvm.contrib.utils.tempdir(board_workspace)
# time.sleep(200)
return t


def make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir):
this_dir = pathlib.Path(__file__).parent
model = ARDUINO_BOARDS[board]
build_config = {"debug": tvm_debug}

with open(this_dir.parent / "testdata" / "kws" / "yes_no.tflite", "rb") as f:
tflite_model_buf = f.read()

# TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
try:
import tflite.Model

tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
except AttributeError:
import tflite

tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)

mod, params = relay.frontend.from_tflite(tflite_model)
target = tvm.target.target.micro(model)
runtime = Runtime("crt")
executor = Executor("aot", {"unpacked-api": True})

with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
mod = relay.build(mod, target, runtime=runtime, executor=executor, params=params)

return tvm.micro.generate_project(
str(TEMPLATE_PROJECT_DIR),
mod,
workspace_dir / "project",
{
"arduino_board": board,
"arduino_cli_cmd": arduino_cli_cmd,
"project_type": "example_project",
"verbose": bool(build_config.get("debug")),
},
)
9 changes: 3 additions & 6 deletions tests/micro/arduino/test_arduino_error_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,22 @@
# specific language governing permissions and limitations
# under the License.

import pathlib
import re
import sys

import pytest

import conftest
from tvm.micro.project_api.server import ServerError

import test_utils

# A new project and workspace dir is created for EVERY test
@pytest.fixture
def workspace_dir(request, board):
return conftest.make_workspace_dir("arduino_error_detection", board)
return test_utils.make_workspace_dir("arduino_error_detection", board)


@pytest.fixture
def project(board, arduino_cli_cmd, tvm_debug, workspace_dir):
return conftest.make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir)
return test_utils.make_kws_project(board, arduino_cli_cmd, tvm_debug, workspace_dir)


def test_blank_project_compiles(workspace_dir, project):
Expand Down
Loading

0 comments on commit 845a627

Please sign in to comment.