Skip to content

Commit

Permalink
Address @gromero comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrdadh committed Nov 30, 2021
1 parent 7b931f2 commit 830f1c1
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 45 deletions.
28 changes: 21 additions & 7 deletions apps/microtvm/arduino/template_project/microtvm_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@

BOARDS = API_SERVER_DIR / "boards.json"

ARDUINO_CLI_CMD = shutil.which("arduino-cli")

# Data structure to hold the information microtvm_api_server.py needs
# to communicate with each of these boards.
try:
Expand All @@ -71,14 +73,19 @@ class BoardAutodetectFailed(Exception):
PROJECT_OPTIONS = [
server.ProjectOption(
"arduino_board",
required=["generate_project", "build", "flash", "open_transport"],
required=["build", "flash", "open_transport"],
choices=list(BOARD_PROPERTIES),
type="str",
help="Name of the Arduino board to build for.",
),
server.ProjectOption(
"arduino_cli_cmd",
optional=["build", "flash", "open_transport"],
required=["generate_project", "build", "flash", "open_transport"]
if not ARDUINO_CLI_CMD
else None,
optional=["generate_project", "build", "flash", "open_transport"]
if ARDUINO_CLI_CMD
else None,
default="arduino-cli",
type="str",
help="Path to the arduino-cli tool.",
Expand Down Expand Up @@ -251,7 +258,9 @@ def _convert_includes(self, project_dir, source_dir):
with filename.open() as file:
try:
lines = file.readlines()
except:
# TODO: This exception only happens using `tvmc micro` and is not catched on Arduino tests.
# Needs more investigation.
except UnicodeDecodeError:
pass

for i in range(len(lines)):
Expand Down Expand Up @@ -321,7 +330,7 @@ def _get_platform_version(self, arduino_cli_path: str) -> float:

def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options):
# Check Arduino version
version = self._get_platform_version(options["arduino_cli_cmd"])
version = self._get_platform_version(self._get_arduino_cli_cmd(options))
if version != ARDUINO_CLI_VERSION:
message = f"Arduino CLI version found is not supported: found {version}, expected {ARDUINO_CLI_VERSION}."
if options.get("warning_as_error") is not None and options["warning_as_error"]:
Expand Down Expand Up @@ -371,7 +380,7 @@ def build(self, options):
BUILD_DIR.mkdir()

compile_cmd = [
options["arduino_cli_cmd"],
self._get_arduino_cli_cmd(options),
"compile",
"./project/",
"--fqbn",
Expand All @@ -388,6 +397,11 @@ def build(self, options):

BOARD_LIST_HEADERS = ("Port", "Type", "Board Name", "FQBN", "Core")

def _get_arduino_cli_cmd(self, options: dict):
arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD)
assert arduino_cli_cmd, "'arduino_cli_cmd' command not passed and not found by default!"
return arduino_cli_cmd

def _parse_boards_tabular_str(self, tabular_str):
"""Parses the tabular output from `arduino-cli board list` into a 2D array
Expand Down Expand Up @@ -420,7 +434,7 @@ def _parse_boards_tabular_str(self, tabular_str):
yield parsed_row

def _auto_detect_port(self, options):
list_cmd = [options["arduino_cli_cmd"], "board", "list"]
list_cmd = [self._get_arduino_cli_cmd(options), "board", "list"]
list_cmd_output = subprocess.run(
list_cmd, check=True, stdout=subprocess.PIPE
).stdout.decode("utf-8")
Expand All @@ -446,7 +460,7 @@ def flash(self, options):
port = self._get_arduino_port(options)

upload_cmd = [
options["arduino_cli_cmd"],
self._get_arduino_cli_cmd(options),
"upload",
"./project",
"--fqbn",
Expand Down
25 changes: 18 additions & 7 deletions apps/microtvm/zephyr/template_project/microtvm_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@
ZEPHYR_VERSION = 2.5


WEST_CMD = default = sys.executable + " -m west" if sys.executable else None

ZEPHYR_BASE = os.getenv("ZEPHYR_BASE")

# Data structure to hold the information microtvm_api_server.py needs
# to communicate with each of these boards.
try:
Expand Down Expand Up @@ -271,8 +275,8 @@ def _get_nrf_device_args(options):
),
server.ProjectOption(
"west_cmd",
optional=["generate_project"],
default=sys.executable + " -m west" if sys.executable else None,
optional=["build"],
default=WEST_CMD,
type="str",
help=(
"Path to the west tool. If given, supersedes both the zephyr_base "
Expand All @@ -281,8 +285,9 @@ def _get_nrf_device_args(options):
),
server.ProjectOption(
"zephyr_base",
optional=["build", "open_transport"],
default=os.getenv("ZEPHYR_BASE"),
required=["generate_project", "open_transport"] if not ZEPHYR_BASE else None,
optional=["generate_project", "open_transport", "build"] if ZEPHYR_BASE else ["build"],
default=ZEPHYR_BASE,
type="str",
help="Path to the zephyr base directory.",
),
Expand Down Expand Up @@ -314,6 +319,13 @@ def _get_nrf_device_args(options):
]


def get_zephyr_base(options: dict):
"""Returns Zephyr base path"""
zephyr_base = options.get("zephyr_base", ZEPHYR_BASE)
assert zephyr_base, "'zephyr_base' not passed and not found by default!"
return zephyr_base


class Handler(server.ProjectAPIHandler):
def __init__(self):
super(Handler, self).__init__()
Expand Down Expand Up @@ -402,7 +414,7 @@ def _get_platform_version(self, zephyr_base: str) -> float:

def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options):
# Check Zephyr version
version = self._get_platform_version(options["zephyr_base"])
version = self._get_platform_version(get_zephyr_base(options))
if version != ZEPHYR_VERSION:
message = f"Zephyr version found is not supported: found {version}, expected {ZEPHYR_VERSION}."
if options.get("warning_as_error") is not None and options["warning_as_error"]:
Expand Down Expand Up @@ -574,8 +586,7 @@ def _set_nonblock(fd):
class ZephyrSerialTransport:
@classmethod
def _lookup_baud_rate(cls, options):
zephyr_base = options.get("zephyr_base", os.environ["ZEPHYR_BASE"])
sys.path.insert(0, os.path.join(zephyr_base, "scripts", "dts"))
sys.path.insert(0, os.path.join(get_zephyr_base(options), "scripts", "dts"))
try:
import dtlib # pylint: disable=import-outside-toplevel
finally:
Expand Down
1 change: 0 additions & 1 deletion tests/micro/arduino/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def make_workspace_dir(test_name, board):
board_workspace = pathlib.Path(str(board_workspace) + f"-{number}")
board_workspace.parent.mkdir(exist_ok=True, parents=True)
t = tvm.contrib.utils.tempdir(board_workspace)
# time.sleep(200)
return t


Expand Down
61 changes: 31 additions & 30 deletions tests/micro/common/test_tvmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import tvm
from tvm.contrib.download import download_testdata
from tvm.relay.backend import Executor, Runtime

from ..zephyr.test_utils import ZEPHYR_BOARDS
from ..arduino.test_utils import ARDUINO_BOARDS
Expand All @@ -36,7 +35,7 @@
MODEL_URL = "https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/examples/micro_speech/micro_speech.tflite"
MODEL_FILE = "micro_speech.tflite"


# TODO(mehrdadh): replace this with _main from tvm.driver.tvmc.main
def _run_tvmc(cmd_args: list, *args, **kwargs):
"""Run a tvmc command and return the results"""
cmd_args_list = TVMC_COMMAND + cmd_args
Expand Down Expand Up @@ -74,8 +73,8 @@ def test_tvmc_model_build_only(board):
tar_path = str(temp_dir / "model.tar")
project_dir = str(temp_dir / "project")

runtime = str(Runtime("crt"))
executor = str(Executor("graph"))
runtime = "crt"
executor = "graph"

cmd_result = _run_tvmc(
[
Expand All @@ -99,18 +98,19 @@ def test_tvmc_model_build_only(board):
)
assert cmd_result == 0, "tvmc failed in step: compile"

cmd_result = _run_tvmc(
[
"micro",
"create-project",
project_dir,
tar_path,
platform,
"--project-option",
"project_type=host_driven",
f"{platform}_board={board}",
]
)
create_project_cmd = [
"micro",
"create-project",
project_dir,
tar_path,
platform,
"--project-option",
"project_type=host_driven",
]
if platform == "zephyr":
create_project_cmd.append(f"{platform}_board={board}")

cmd_result = _run_tvmc(create_project_cmd)
assert cmd_result == 0, "tvmc micro failed in step: create-project"

cmd_result = _run_tvmc(
Expand All @@ -129,8 +129,8 @@ def test_tvmc_model_run(board):
tar_path = str(temp_dir / "model.tar")
project_dir = str(temp_dir / "project")

runtime = str(Runtime("crt"))
executor = str(Executor("graph"))
runtime = "crt"
executor = "graph"

cmd_result = _run_tvmc(
[
Expand All @@ -154,18 +154,19 @@ def test_tvmc_model_run(board):
)
assert cmd_result == 0, "tvmc failed in step: compile"

cmd_result = _run_tvmc(
[
"micro",
"create-project",
project_dir,
tar_path,
platform,
"--project-option",
"project_type=host_driven",
f"{platform}_board={board}",
]
)
create_project_cmd = [
"micro",
"create-project",
project_dir,
tar_path,
platform,
"--project-option",
"project_type=host_driven",
]
if platform == "zephyr":
create_project_cmd.append(f"{platform}_board={board}")

cmd_result = _run_tvmc(create_project_cmd)
assert cmd_result == 0, "tvmc micro failed in step: create-project"

cmd_result = _run_tvmc(
Expand Down

0 comments on commit 830f1c1

Please sign in to comment.