diff --git a/apps/microtvm/arduino/template_project/microtvm_api_server.py b/apps/microtvm/arduino/template_project/microtvm_api_server.py index bb4b54d8fb27..95f941fe3473 100644 --- a/apps/microtvm/arduino/template_project/microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/microtvm_api_server.py @@ -33,8 +33,9 @@ from string import Template import re -import serial +from packaging import version import serial.tools.list_ports + from tvm.micro.project_api import server _LOG = logging.getLogger(__name__) @@ -46,10 +47,7 @@ IS_TEMPLATE = not (API_SERVER_DIR / MODEL_LIBRARY_FORMAT_RELPATH).exists() -# Used to check Arduino CLI version installed on the host. -# We only check two levels of the version. -ARDUINO_CLI_VERSION = 0.18 - +MIN_ARDUINO_CLI_VERSION = version.parse("0.18.0") BOARDS = API_SERVER_DIR / "boards.json" @@ -113,7 +111,7 @@ class BoardAutodetectFailed(Exception): ), server.ProjectOption( "warning_as_error", - optional=["generate_project"], + optional=["build", "flash"], type="bool", help="Treat warnings as errors and raise an Exception.", ), @@ -126,6 +124,7 @@ def __init__(self): self._proc = None self._port = None self._serial = None + self._version = None def server_info_query(self, tvm_version): return server.ServerInfo( @@ -314,25 +313,7 @@ def _find_modified_include_path(self, project_dir, file_path, include_path): # It's probably a standard C/C++ header return include_path - def _get_platform_version(self, arduino_cli_path: str) -> float: - # sample output of this command: - # 'arduino-cli alpha Version: 0.18.3 Commit: d710b642 Date: 2021-05-14T12:36:58Z\n' - version_output = subprocess.check_output([arduino_cli_path, "version"], encoding="utf-8") - full_version = re.findall("version: ([\.0-9]*)", version_output.lower()) - full_version = full_version[0].split(".") - version = float(f"{full_version[0]}.{full_version[1]}") - - return version - def generate_project(self, model_library_format_path, standalone_crt_dir, project_dir, options): - # Check Arduino version - version = self._get_platform_version(self._get_arduino_cli_cmd(options)) - if version != ARDUINO_CLI_VERSION: - message = f"Arduino CLI version found is not supported: found {version}, expected {ARDUINO_CLI_VERSION}." - if options.get("warning_as_error") is not None and options["warning_as_error"]: - raise server.ServerError(message=message) - _LOG.warning(message) - # Reference key directories with pathlib project_dir = pathlib.Path(project_dir) project_dir.mkdir() @@ -368,11 +349,45 @@ def generate_project(self, model_library_format_path, standalone_crt_dir, projec # Recursively change includes self._convert_includes(project_dir, source_dir) + def _get_arduino_cli_cmd(self, options: dict): + arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD) + assert arduino_cli_cmd, "'arduino_cli_cmd' command not passed and not found by default!" + return arduino_cli_cmd + + def _get_platform_version(self, arduino_cli_path: str) -> float: + # sample output of this command: + # 'arduino-cli alpha Version: 0.18.3 Commit: d710b642 Date: 2021-05-14T12:36:58Z\n' + version_output = subprocess.run( + [arduino_cli_path, "version"], check=True, stdout=subprocess.PIPE + ).stdout.decode("utf-8") + str_version = re.search(r"Version: ([\.0-9]*)", version_output).group(1) + + # Using too low a version should raise an error. Note that naively + # comparing floats will fail here: 0.7 > 0.21, but 0.21 is a higher + # version (hence we need version.parse) + return version.parse(str_version) + + # This will only be run for build and upload + def _check_platform_version(self, options): + if not self._version: + cli_command = self._get_arduino_cli_cmd(options) + self._version = self._get_platform_version(cli_command) + + if self._version < MIN_ARDUINO_CLI_VERSION: + message = ( + f"Arduino CLI version too old: found {self._version}, " + f"need at least {str(MIN_ARDUINO_CLI_VERSION)}." + ) + if options.get("warning_as_error") is not None and options["warning_as_error"]: + raise server.ServerError(message=message) + _LOG.warning(message) + def _get_fqbn(self, options): o = BOARD_PROPERTIES[options["arduino_board"]] return f"{o['package']}:{o['architecture']}:{o['board']}" def build(self, options): + self._check_platform_version(options) BUILD_DIR.mkdir() compile_cmd = [ @@ -391,19 +406,14 @@ def build(self, options): # Specify project to compile subprocess.run(compile_cmd, check=True) - BOARD_LIST_HEADERS = ("Port", "Type", "Board Name", "FQBN", "Core") + POSSIBLE_BOARD_LIST_HEADERS = ("Port", "Protocol", "Type", "Board Name", "FQBN", "Core") - def _get_arduino_cli_cmd(self, options: dict): - arduino_cli_cmd = options.get("arduino_cli_cmd", ARDUINO_CLI_CMD) - assert arduino_cli_cmd, "'arduino_cli_cmd' command not passed and not found by default!" - return arduino_cli_cmd - - def _parse_boards_tabular_str(self, tabular_str): + def _parse_connected_boards(self, tabular_str): """Parses the tabular output from `arduino-cli board list` into a 2D array Examples -------- - >>> list(_parse_boards_tabular_str(bytes( + >>> list(_parse_connected_boards(bytes( ... "Port Type Board Name FQBN Core \n" ... "/dev/ttyS4 Serial Port Unknown \n" ... "/dev/ttyUSB0 Serial Port (USB) Spresense SPRESENSE:spresense:spresense SPRESENSE:spresense\n" @@ -414,20 +424,21 @@ def _parse_boards_tabular_str(self, tabular_str): """ - str_rows = tabular_str.split("\n")[:-2] - header = str_rows[0] - indices = [header.index(h) for h in self.BOARD_LIST_HEADERS] + [len(header)] + # Which column headers are present depends on the version of arduino-cli + column_regex = r"\s*|".join(self.POSSIBLE_BOARD_LIST_HEADERS) + r"\s*" + str_rows = tabular_str.split("\n") + column_headers = list(re.finditer(column_regex, str_rows[0])) + assert len(column_headers) > 0 for str_row in str_rows[1:]: - parsed_row = [] - for cell_index in range(len(self.BOARD_LIST_HEADERS)): - start = indices[cell_index] - end = indices[cell_index + 1] - str_cell = str_row[start:end] + if not str_row.strip(): + continue + device = {} - # Remove trailing whitespace used for padding - parsed_row.append(str_cell.rstrip()) - yield parsed_row + for column in column_headers: + col_name = column.group(0).strip().lower() + device[col_name] = str_row[column.start() : column.end()].strip() + yield device def _auto_detect_port(self, options): list_cmd = [self._get_arduino_cli_cmd(options), "board", "list"] @@ -436,9 +447,9 @@ def _auto_detect_port(self, options): ).stdout.decode("utf-8") desired_fqbn = self._get_fqbn(options) - for line in self._parse_boards_tabular_str(list_cmd_output): - if line[3] == desired_fqbn: - return line[0] + for device in self._parse_connected_boards(list_cmd_output): + if device["fqbn"] == desired_fqbn: + return device["port"] # If no compatible boards, raise an error raise BoardAutodetectFailed() @@ -453,6 +464,7 @@ def _get_arduino_port(self, options): return self._port def flash(self, options): + self._check_platform_version(options) port = self._get_arduino_port(options) upload_cmd = [ diff --git a/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py b/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py index 00969a5a892b..34659bca5627 100644 --- a/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py +++ b/apps/microtvm/arduino/template_project/tests/test_arduino_microtvm_api_server.py @@ -20,8 +20,11 @@ from pathlib import Path from unittest import mock +from packaging import version import pytest +from tvm.micro.project_api import server + sys.path.insert(0, str(Path(__file__).parent.parent)) import microtvm_api_server @@ -63,53 +66,102 @@ def test_find_modified_include_path(self, mock_pathlib_path): ) assert valid_output == valid_arduino_import - BOARD_CONNECTED_OUTPUT = bytes( + # Format for arduino-cli v0.18.2 + BOARD_CONNECTED_V18 = ( "Port Type Board Name FQBN Core \n" "/dev/ttyACM0 Serial Port (USB) Arduino Nano 33 BLE arduino:mbed_nano:nano33ble arduino:mbed_nano\n" "/dev/ttyACM1 Serial Port (USB) Arduino Nano 33 arduino:mbed_nano:nano33 arduino:mbed_nano\n" "/dev/ttyS4 Serial Port Unknown \n" - "\n", - "utf-8", + "\n" + ) + # Format for arduino-cli v0.21.1 and above + BOARD_CONNECTED_V21 = ( + "Port Protocol Type Board Name FQBN Core \n" + "/dev/ttyACM0 serial arduino:mbed_nano:nano33ble arduino:mbed_nano\n" + "\n" ) - BOARD_DISCONNECTED_OUTPUT = bytes( - "Port Type Board Name FQBN Core\n" - "/dev/ttyS4 Serial Port Unknown \n" - "\n", - "utf-8", + BOARD_DISCONNECTED_V21 = ( + "Port Protocol Type Board Name FQBN Core\n" + "/dev/ttyS4 serial Serial Port Unknown\n" + "\n" ) + def test_parse_connected_boards(self): + h = microtvm_api_server.Handler() + boards = h._parse_connected_boards(self.BOARD_CONNECTED_V21) + assert list(boards) == [ + { + "port": "/dev/ttyACM0", + "protocol": "serial", + "type": "", + "board name": "", + "fqbn": "arduino:mbed_nano:nano33ble", + "core": "arduino:mbed_nano", + } + ] + @mock.patch("subprocess.run") - def test_auto_detect_port(self, mock_subprocess_run): + def test_auto_detect_port(self, mock_run): process_mock = mock.Mock() handler = microtvm_api_server.Handler() # Test it returns the correct port when a board is connected - mock_subprocess_run.return_value.stdout = self.BOARD_CONNECTED_OUTPUT + mock_run.return_value.stdout = bytes(self.BOARD_CONNECTED_V18, "utf-8") + assert handler._auto_detect_port(self.DEFAULT_OPTIONS) == "/dev/ttyACM0" + + # Should work with old or new arduino-cli version + mock_run.return_value.stdout = bytes(self.BOARD_CONNECTED_V21, "utf-8") assert handler._auto_detect_port(self.DEFAULT_OPTIONS) == "/dev/ttyACM0" # Test it raises an exception when no board is connected - mock_subprocess_run.return_value.stdout = self.BOARD_DISCONNECTED_OUTPUT + mock_run.return_value.stdout = bytes(self.BOARD_DISCONNECTED_V21, "utf-8") with pytest.raises(microtvm_api_server.BoardAutodetectFailed): handler._auto_detect_port(self.DEFAULT_OPTIONS) # Test that the FQBN needs to match EXACTLY handler._get_fqbn = mock.MagicMock(return_value="arduino:mbed_nano:nano33") - mock_subprocess_run.return_value.stdout = self.BOARD_CONNECTED_OUTPUT + mock_run.return_value.stdout = bytes(self.BOARD_CONNECTED_V18, "utf-8") assert ( handler._auto_detect_port({**self.DEFAULT_OPTIONS, "arduino_board": "nano33"}) == "/dev/ttyACM1" ) + BAD_CLI_VERSION = "arduino-cli Version: 0.7.1 Commit: 7668c465 Date: 2019-12-31T18:24:32Z\n" + GOOD_CLI_VERSION = "arduino-cli Version: 0.21.1 Commit: 9fcbb392 Date: 2022-02-24T15:41:45Z\n" + + @mock.patch("subprocess.run") + def test_auto_detect_port(self, mock_run): + handler = microtvm_api_server.Handler() + mock_run.return_value.stdout = bytes(self.GOOD_CLI_VERSION, "utf-8") + handler._check_platform_version(self.DEFAULT_OPTIONS) + assert handler._version == version.parse("0.21.1") + + handler = microtvm_api_server.Handler() + mock_run.return_value.stdout = bytes(self.BAD_CLI_VERSION, "utf-8") + with pytest.raises(server.ServerError) as error: + handler._check_platform_version({"warning_as_error": True}) + @mock.patch("subprocess.run") - def test_flash(self, mock_subprocess_run): + def test_flash(self, mock_run): + mock_run.return_value.stdout = bytes(self.GOOD_CLI_VERSION, "utf-8") + handler = microtvm_api_server.Handler() handler._port = "/dev/ttyACM0" # Test no exception thrown when command works handler.flash(self.DEFAULT_OPTIONS) - mock_subprocess_run.assert_called_once() + + # Test we checked version then called upload + assert mock_run.call_count == 2 + assert mock_run.call_args_list[0][0] == (["arduino-cli", "version"],) + assert mock_run.call_args_list[1][0][0][0:2] == ["arduino-cli", "upload"] + mock_run.reset_mock() # Test exception raised when `arduino-cli upload` returns error code - mock_subprocess_run.side_effect = subprocess.CalledProcessError(2, []) + mock_run.side_effect = subprocess.CalledProcessError(2, []) with pytest.raises(subprocess.CalledProcessError): handler.flash(self.DEFAULT_OPTIONS) + + # Version information should be cached and not checked again + mock_run.assert_called_once() + assert mock_run.call_args[0][0][0:2] == ["arduino-cli", "upload"]