From 830f1c1ae0e285a1b943457887e859a600364e3a Mon Sep 17 00:00:00 2001 From: Mehrdad Hessar Date: Mon, 29 Nov 2021 12:15:38 -0800 Subject: [PATCH] Address @gromero comments --- .../template_project/microtvm_api_server.py | 28 ++++++--- .../template_project/microtvm_api_server.py | 25 +++++--- tests/micro/arduino/test_utils.py | 1 - tests/micro/common/test_tvmc.py | 61 ++++++++++--------- 4 files changed, 70 insertions(+), 45 deletions(-) diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index ffdbd10afe35b..e5a17c7096a6a 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -53,6 +53,8 @@ BOARDS = API_SERVER_DIR / "boards.json" +ARDUINO_CLI_CMD = shutil.which("arduino-cli") + # Data structure to hold the information microtvm_api_server.py needs # to communicate with each of these boards. try: @@ -71,14 +73,19 @@ class BoardAutodetectFailed(Exception): PROJECT_OPTIONS = [ server.ProjectOption( "arduino_board", - required=["generate_project", "build", "flash", "open_transport"], + required=["build", "flash", "open_transport"], choices=list(BOARD_PROPERTIES), type="str", help="Name of the Arduino board to build for.", ), server.ProjectOption( "arduino_cli_cmd", - optional=["build", "flash", "open_transport"], + required=["generate_project", "build", "flash", "open_transport"] + if not ARDUINO_CLI_CMD + else None, + optional=["generate_project", "build", "flash", "open_transport"] + if ARDUINO_CLI_CMD + else None, default="arduino-cli", type="str", help="Path to the arduino-cli tool.", @@ -251,7 +258,9 @@ def _convert_includes(self, project_dir, source_dir): with filename.open() as file: try: lines = file.readlines() - except: + # TODO: This exception only happens using `tvmc micro` and is not catched on Arduino tests. + # Needs more investigation. + except UnicodeDecodeError: pass for i in range(len(lines)): @@ -321,7 +330,7 @@ def _get_platform_version(self, arduino_cli_path: str) -> float: def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): # Check Arduino version - version = self._get_platform_version(options["arduino_cli_cmd"]) + version = self._get_platform_version(self._get_arduino_cli_cmd(options)) if version != ARDUINO_CLI_VERSION: message = f"Arduino CLI version found is not supported: found {version}, expected {ARDUINO_CLI_VERSION}." if options.get("warning_as_error") is not None and options["warning_as_error"]: @@ -371,7 +380,7 @@ def build(self, options): BUILD_DIR.mkdir() compile_cmd = [ - options["arduino_cli_cmd"], + self._get_arduino_cli_cmd(options), "compile", "./project/", "--fqbn", @@ -388,6 +397,11 @@ def build(self, options): BOARD_LIST_HEADERS = ("Port", "Type", "Board Name", "FQBN", "Core") + def _get_arduino_cli_cmd(self, options: dict): + arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD) + assert arduino_cli_cmd, "'arduino_cli_cmd' command not passed and not found by default!" + return arduino_cli_cmd + def _parse_boards_tabular_str(self, tabular_str): """Parses the tabular output from `arduino-cli board list` into a 2D array @@ -420,7 +434,7 @@ def _parse_boards_tabular_str(self, tabular_str): yield parsed_row def _auto_detect_port(self, options): - list_cmd = [options["arduino_cli_cmd"], "board", "list"] + list_cmd = [self._get_arduino_cli_cmd(options), "board", "list"] list_cmd_output = subprocess.run( list_cmd, check=True, stdout=subprocess.PIPE ).stdout.decode("utf-8") @@ -446,7 +460,7 @@ def flash(self, options): port = self._get_arduino_port(options) upload_cmd = [ - options["arduino_cli_cmd"], + self._get_arduino_cli_cmd(options), "upload", "./project", "--fqbn", diff --git a/apps/microtvm/zephyr/template_project/microtvm_api_server.py b/apps/microtvm/zephyr/template_project/microtvm_api_server.py index 7d6383dfc4bc6..13769549de28c 100644 --- a/apps/microtvm/zephyr/template_project/microtvm_api_server.py +++ b/apps/microtvm/zephyr/template_project/microtvm_api_server.py @@ -66,6 +66,10 @@ ZEPHYR_VERSION = 2.5 +WEST_CMD = default = sys.executable + " -m west" if sys.executable else None + +ZEPHYR_BASE = os.getenv("ZEPHYR_BASE") + # Data structure to hold the information microtvm_api_server.py needs # to communicate with each of these boards. try: @@ -271,8 +275,8 @@ def _get_nrf_device_args(options): ), server.ProjectOption( "west_cmd", - optional=["generate_project"], - default=sys.executable + " -m west" if sys.executable else None, + optional=["build"], + default=WEST_CMD, type="str", help=( "Path to the west tool. If given, supersedes both the zephyr_base " @@ -281,8 +285,9 @@ def _get_nrf_device_args(options): ), server.ProjectOption( "zephyr_base", - optional=["build", "open_transport"], - default=os.getenv("ZEPHYR_BASE"), + required=["generate_project", "open_transport"] if not ZEPHYR_BASE else None, + optional=["generate_project", "open_transport", "build"] if ZEPHYR_BASE else ["build"], + default=ZEPHYR_BASE, type="str", help="Path to the zephyr base directory.", ), @@ -314,6 +319,13 @@ def _get_nrf_device_args(options): ] +def get_zephyr_base(options: dict): + """Returns Zephyr base path""" + zephyr_base = options.get("zephyr_base", ZEPHYR_BASE) + assert zephyr_base, "'zephyr_base' not passed and not found by default!" + return zephyr_base + + class Handler(server.ProjectAPIHandler): def __init__(self): super(Handler, self).__init__() @@ -402,7 +414,7 @@ def _get_platform_version(self, zephyr_base: str) -> float: def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): # Check Zephyr version - version = self._get_platform_version(options["zephyr_base"]) + version = self._get_platform_version(get_zephyr_base(options)) if version != ZEPHYR_VERSION: message = f"Zephyr version found is not supported: found {version}, expected {ZEPHYR_VERSION}." if options.get("warning_as_error") is not None and options["warning_as_error"]: @@ -574,8 +586,7 @@ def _set_nonblock(fd): class ZephyrSerialTransport: @classmethod def _lookup_baud_rate(cls, options): - zephyr_base = options.get("zephyr_base", os.environ["ZEPHYR_BASE"]) - sys.path.insert(0, os.path.join(zephyr_base, "scripts", "dts")) + sys.path.insert(0, os.path.join(get_zephyr_base(options), "scripts", "dts")) try: import dtlib # pylint: disable=import-outside-toplevel finally: diff --git a/tests/micro/arduino/test_utils.py b/tests/micro/arduino/test_utils.py index dca1a79743cc7..c107d5b1febfc 100644 --- a/tests/micro/arduino/test_utils.py +++ b/tests/micro/arduino/test_utils.py @@ -58,7 +58,6 @@ def make_workspace_dir(test_name, board): board_workspace = pathlib.Path(str(board_workspace) + f"-{number}") board_workspace.parent.mkdir(exist_ok=True, parents=True) t = tvm.contrib.utils.tempdir(board_workspace) - # time.sleep(200) return t diff --git a/tests/micro/common/test_tvmc.py b/tests/micro/common/test_tvmc.py index fde516b210654..57ea99ee462c1 100644 --- a/tests/micro/common/test_tvmc.py +++ b/tests/micro/common/test_tvmc.py @@ -26,7 +26,6 @@ import tvm from tvm.contrib.download import download_testdata -from tvm.relay.backend import Executor, Runtime from ..zephyr.test_utils import ZEPHYR_BOARDS from ..arduino.test_utils import ARDUINO_BOARDS @@ -36,7 +35,7 @@ MODEL_URL = "https://github.com/tensorflow/tflite-micro/raw/main/tensorflow/lite/micro/examples/micro_speech/micro_speech.tflite" MODEL_FILE = "micro_speech.tflite" - +# TODO(mehrdadh): replace this with _main from tvm.driver.tvmc.main def _run_tvmc(cmd_args: list, *args, **kwargs): """Run a tvmc command and return the results""" cmd_args_list = TVMC_COMMAND + cmd_args @@ -74,8 +73,8 @@ def test_tvmc_model_build_only(board): tar_path = str(temp_dir / "model.tar") project_dir = str(temp_dir / "project") - runtime = str(Runtime("crt")) - executor = str(Executor("graph")) + runtime = "crt" + executor = "graph" cmd_result = _run_tvmc( [ @@ -99,18 +98,19 @@ def test_tvmc_model_build_only(board): ) assert cmd_result == 0, "tvmc failed in step: compile" - cmd_result = _run_tvmc( - [ - "micro", - "create-project", - project_dir, - tar_path, - platform, - "--project-option", - "project_type=host_driven", - f"{platform}_board={board}", - ] - ) + create_project_cmd = [ + "micro", + "create-project", + project_dir, + tar_path, + platform, + "--project-option", + "project_type=host_driven", + ] + if platform == "zephyr": + create_project_cmd.append(f"{platform}_board={board}") + + cmd_result = _run_tvmc(create_project_cmd) assert cmd_result == 0, "tvmc micro failed in step: create-project" cmd_result = _run_tvmc( @@ -129,8 +129,8 @@ def test_tvmc_model_run(board): tar_path = str(temp_dir / "model.tar") project_dir = str(temp_dir / "project") - runtime = str(Runtime("crt")) - executor = str(Executor("graph")) + runtime = "crt" + executor = "graph" cmd_result = _run_tvmc( [ @@ -154,18 +154,19 @@ def test_tvmc_model_run(board): ) assert cmd_result == 0, "tvmc failed in step: compile" - cmd_result = _run_tvmc( - [ - "micro", - "create-project", - project_dir, - tar_path, - platform, - "--project-option", - "project_type=host_driven", - f"{platform}_board={board}", - ] - ) + create_project_cmd = [ + "micro", + "create-project", + project_dir, + tar_path, + platform, + "--project-option", + "project_type=host_driven", + ] + if platform == "zephyr": + create_project_cmd.append(f"{platform}_board={board}") + + cmd_result = _run_tvmc(create_project_cmd) assert cmd_result == 0, "tvmc micro failed in step: create-project" cmd_result = _run_tvmc(