From 3f1d068bcf398ac88828f97b31e26abeed5e02bb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 11 Mar 2022 08:45:43 -0600 Subject: [PATCH 1/8] [Hexagon] Improved ergonomics of HexagonLauncher in unit tests. The goal of this commit is to reduce/eliminate common code required through unit tests that interact with Hexagon hardware. - New testing fixtures in `tests/python/contrib/test_hexagon`. A test running on hexagon hardware should only need to use the `hexagon_session` fixture. - `rpc_server_port`: Iterates through port numbers, selecting an unused port for each unit test. Avoids needing to explicitly specify unique ports for each unit test. - `tvm_tracker`: Starts a tracker on use, exits after test. Avoids needing to manually start a tracker prior to running the unit test. - `hexagon_launcher`: Starts a `HexagonLauncher` server on use, stops server after test. Avoids needing to call `start_server()` and `stop_server()` in each test. - `hexagon_session`: Starts a hexagon session using `hexagon_laucnehr.start_session()`, exits after test. - Added `Session.upload` function, which delegates to `HexagonLauncher.upload`. Avoids needing to interact with both the launcher and the session. - Allowed `tvm.IRModule` as argument passed to `Session.load_module`, which will automatically save/upload the module, then load it. Avoids needing to handle save/upload of temporary files in each unit test. --- python/tvm/contrib/hexagon/build.py | 25 +- python/tvm/contrib/hexagon/session.py | 58 ++++- tests/python/contrib/test_hexagon/conftest.py | 93 ++++++- .../test_hexagon/test_cache_read_write.py | 47 ++-- .../contrib/test_hexagon/test_launcher.py | 243 ++++++------------ 5 files changed, 250 insertions(+), 216 deletions(-) diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index 2858c4865be6..2b0e0b1da2bb 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -195,19 +195,27 @@ def start_session(self) -> Session: "timeout": 0, "key": self.HEXAGON_REMOTE_DEVICE_KEY, } - return Session(hexagon_remote_kw) + return Session(self, hexagon_remote_kw) - def load_module(self, module_name: Union[str, pathlib.Path], session: Session): + def load_module(self, module: Union[str, pathlib.Path], session: Session): """Load TVM module. Parameters ---------- - module_name : str or pathlib.Path - Name of the module to load. It must be either a bare file name - (without any path components), or a full path in the remote - system. If it is a file name, the file must be placed in the - remote workspace. + module : Union[str, pathlib.Path, tvm.runtime.Module] + + The module to load. If `module` is a + `tvm.runtime.Module`, it will be uploaded to the remote + session and loaded. + + If the object passed is a string or pathlib.Path, it must + be either a bare file name (without any path components), + or a full path in the remote system. If it is a file name, + the file must already have been uploaded to the remote, + and be placed in the remote workspace. + session : Session + Remote session. The session must be established (via __enter__) prior to calling this function. @@ -215,8 +223,9 @@ def load_module(self, module_name: Union[str, pathlib.Path], session: Session): ------- TVMModule : TVM module object. + """ - return session.load_module(module_name) + return session.load_module(module) def get_graph_executor( self, graph_json: str, module_name: Union[str, pathlib.Path], session: Session diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index 2d3f075daa05..e6f2dab966ed 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -19,7 +19,10 @@ import os import pathlib +import tempfile from typing import Union + +import tvm from tvm import rpc as _rpc @@ -37,10 +40,12 @@ class Session: def __init__( self, + launcher: "HexagonLauncherRPC", remote_kw: dict, session_name: str = "hexagon-rpc", remote_stack_size_bytes: int = 128 * 1024, ): + self._launcher = launcher self._session_name = session_name self._remote_stack_size_bytes = remote_stack_size_bytes self._remote_kw = remote_kw @@ -74,6 +79,53 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, exc_traceback): pass - def load_module(self, path: Union[str, pathlib.Path]): - assert isinstance(path, (str, pathlib.Path)), "Invalid path type:" + str(type(path)) - return self._rpc.get_function("tvm.hexagon.load_module")(str(path)) + def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str): + """Upload a local file to the remote workspace. + + Parameters + ---------- + local_path : str or pathlib.Path + Path to the local file to be copied. + remote_filename : str + Name of the file in the remote workspace. + """ + self._launcher.upload(local_path, remote_filename) + + def load_module(self, module: Union[str, pathlib.Path, tvm.IRModule]): + """Load TVM module. + + Parameters + ---------- + module : Union[str, pathlib.Path, tvm.runtime.Module] + + The module to load. If `module` is a + `tvm.runtime.Module`, it will be uploaded to the remote + session and loaded. + + If the object passed is a string or pathlib.Path, it must + be either a bare file name (without any path components), + or a full path in the remote system. If it is a file name, + the file must already have been uploaded to the remote, + and be placed in the remote workspace. + + session : Session + + Remote session. The session must be established (via __enter__) + prior to calling this function. + + Returns + ------- + TVMModule : + TVM module object. + """ + if isinstance(module, tvm.runtime.Module): + with tempfile.TemporaryDirectory() as temp_dir: + temp_dir = pathlib.Path(temp_dir) + binary_name = "test_binary.so" + binary_path = temp_dir / binary_name + module.save(str(binary_path)) + self.upload(binary_path, binary_name) + module = binary_name + + assert isinstance(module, (str, pathlib.Path)), "Invalid path type:" + str(type(module)) + return self._rpc.get_function("tvm.hexagon.load_module")(str(module)) diff --git a/tests/python/contrib/test_hexagon/conftest.py b/tests/python/contrib/test_hexagon/conftest.py index 2f2c5703fb2c..8aba576747ba 100644 --- a/tests/python/contrib/test_hexagon/conftest.py +++ b/tests/python/contrib/test_hexagon/conftest.py @@ -19,10 +19,15 @@ values from testing parameters """ import os +import random +import socket +from typing import Optional + import pytest import tvm -from tvm import rpc +import tvm.rpc.tracker +from tvm.contrib.hexagon.build import HexagonLauncher HEXAGON_TOOLCHAIN = "HEXAGON_TOOLCHAIN" TVM_TRACKER_HOST = "TVM_TRACKER_HOST" @@ -59,8 +64,13 @@ def requires_hexagon_toolchain(*args): @tvm.testing.fixture -def android_serial_number() -> str: - return os.getenv(ANDROID_SERIAL_NUMBER, default=None) +def android_serial_number() -> Optional[str]: + serial = os.getenv(ANDROID_SERIAL_NUMBER, default="") + # Setting ANDROID_SERIAL_NUMBER to an empty string should be + # equivalent to having it unset. + if not serial.strip(): + serial = None + return serial @tvm.testing.fixture @@ -75,11 +85,88 @@ def tvm_tracker_port() -> int: return port +# NOTE on server ports: +# These tests use different port numbers for the RPC server (7070 + ...). +# The reason is that an RPC session cannot be gracefully closed without +# triggering TIME_WAIT state on the server socket. This prevents another +# server to bind to the same port until the wait time elapses. + +# rpc_port_min = 1024 # Lowest unprivileged port +rpc_port_min = 2000 # Well above the privileged ports (1024 or lower) +rpc_port_max = 9000 # Below the search range end (port_end=9199) of RPC server +previous_port = [None] + + +@tvm.testing.fixture +def rpc_server_port() -> int: + print(rpc_port_min) + + # https://stackoverflow.com/a/52872579/2689797 + def is_port_in_use(port: int) -> bool: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + if previous_port[0] is None: + port = random.randint(rpc_port_min, rpc_port_max) + else: + port = previous_port[0] + 1 + + while is_port_in_use(port): + port = port + 1 if port < rpc_port_max else rpc_port_min + + previous_port[0] = port + return port + + @tvm.testing.fixture def adb_server_socket() -> str: return os.getenv(ADB_SERVER_SOCKET, default="tcp:5037") +@tvm.testing.fixture +def tvm_tracker(tvm_tracker_port): + tracker = tvm.rpc.tracker.Tracker("127.0.0.1", tvm_tracker_port) + try: + yield tracker + finally: + tracker.terminate() + + +@tvm.testing.fixture +def rpc_info(tvm_tracker, rpc_server_port, adb_server_socket): + return { + "rpc_tracker_host": tvm_tracker.host, + "rpc_tracker_port": tvm_tracker.port, + "rpc_server_port": rpc_server_port, + "adb_server_socket": adb_server_socket, + } + + +@tvm.testing.fixture +def hexagon_launcher(android_serial_number, tvm_tracker, rpc_server_port, adb_server_socket): + if android_serial_number is None: + yield None + else: + rpc_info = { + "rpc_tracker_host": tvm_tracker.host, + "rpc_tracker_port": tvm_tracker.port, + "rpc_server_port": rpc_server_port, + "adb_server_socket": adb_server_socket, + } + launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) + launcher.start_server() + try: + yield launcher + finally: + launcher.stop_server() + + +@tvm.testing.fixture +def hexagon_session(hexagon_launcher): + with hexagon_launcher.start_session() as session: + yield session + + # If the execution aborts while an RPC server is running, the python # code that is supposed to shut it dowm will never execute. This will # keep pytest from terminating (indefinitely), so add a cleanup diff --git a/tests/python/contrib/test_hexagon/test_cache_read_write.py b/tests/python/contrib/test_hexagon/test_cache_read_write.py index 273f8c25b0f5..6bcd852424bf 100644 --- a/tests/python/contrib/test_hexagon/test_cache_read_write.py +++ b/tests/python/contrib/test_hexagon/test_cache_read_write.py @@ -63,9 +63,7 @@ def intrin_func(ins, outs): @requires_hexagon_toolchain -def test_cache_read_write( - android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket -): +def test_cache_read_write(hexagon_session): size = 128 outer_shape = (size,) factor = 16 @@ -105,37 +103,24 @@ def test_cache_read_write( func = tvm.build( s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy" ) - temp = utils.tempdir() - dso_binary = "test_binary.so" - dso_binary_path = temp.relpath(dso_binary) - func.save(dso_binary_path) - if not android_serial_number: + if hexagon_session is None: pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") - rpc_info = { - "rpc_tracker_host": tvm_tracker_host, - "rpc_tracker_port": tvm_tracker_port, - "rpc_server_port": 7070, - "adb_server_socket": adb_server_socket, - } - launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) - launcher.upload(dso_binary_path, dso_binary) - launcher.start_server() - - with launcher.start_session() as sess: - mod = launcher.load_module(dso_binary, sess) - xt = tvm.nd.array( - np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device - ) - yt = tvm.nd.array( - np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device - ) - zt = tvm.nd.array( - np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device - ) - mod["dmacpy"](xt, yt, zt) - launcher.stop_server() + mod = hexagon_session.load_module(func) + xt = tvm.nd.array( + np.random.randint(low=-128, high=127, size=size, dtype=x.dtype), + device=hexagon_session.device, + ) + yt = tvm.nd.array( + np.random.randint(low=-128, high=127, size=size, dtype=y.dtype), + device=hexagon_session.device, + ) + zt = tvm.nd.array( + np.random.randint(low=-128, high=127, size=size, dtype=z.dtype), + device=hexagon_session.device, + ) + mod["dmacpy"](xt, yt, zt) ref = xt.numpy() + yt.numpy() np.testing.assert_equal(zt.numpy(), ref) diff --git a/tests/python/contrib/test_hexagon/test_launcher.py b/tests/python/contrib/test_hexagon/test_launcher.py index 00d68ee3b559..3e72c38f1909 100644 --- a/tests/python/contrib/test_hexagon/test_launcher.py +++ b/tests/python/contrib/test_hexagon/test_launcher.py @@ -42,7 +42,7 @@ @requires_hexagon_toolchain -def test_add(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket): +def test_add(hexagon_session): dtype = "int8" A = tvm.te.placeholder((2,), dtype=dtype) B = tvm.te.placeholder((1,), dtype=dtype) @@ -54,40 +54,23 @@ def test_add(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_serv sched, [A, B, C], tvm.target.Target(target_hexagon, host=target_hexagon), name="add" ) - temp = utils.tempdir() - dso_binary = "test_binary.so" - dso_binary_path = temp.relpath(dso_binary) - func.save(dso_binary_path) + if hexagon_session is None: + pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.") - if not android_serial_number: - pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") + mod = hexagon_session.load_module(func) - rpc_info = { - "rpc_tracker_host": tvm_tracker_host, - "rpc_tracker_port": tvm_tracker_port, - "rpc_server_port": RPC_SERVER_PORT + 0, # See note at the beginning of the file - "adb_server_socket": adb_server_socket, - } - launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) - launcher.upload(dso_binary_path, dso_binary) - launcher.start_server() - - with launcher.start_session() as sess: - mod = launcher.load_module(dso_binary, sess) - A_data = tvm.nd.array(np.array([2, 3], dtype=dtype), device=sess.device) - assert (A_data.numpy() == np.array([2, 3])).all() - B_data = tvm.nd.array(np.array([4], dtype=dtype), device=sess.device) - assert (B_data.numpy() == np.array([4])).all() - C_data = tvm.nd.array(np.array([0, 0], dtype=dtype), device=sess.device) - assert (C_data.numpy() == np.array([0, 0])).all() - mod["add"](A_data, B_data, C_data) - assert (C_data.numpy() == np.array([6, 7])).all() - - launcher.stop_server() + A_data = tvm.nd.array(np.array([2, 3], dtype=dtype), device=hexagon_session.device) + assert (A_data.numpy() == np.array([2, 3])).all() + B_data = tvm.nd.array(np.array([4], dtype=dtype), device=hexagon_session.device) + assert (B_data.numpy() == np.array([4])).all() + C_data = tvm.nd.array(np.array([0, 0], dtype=dtype), device=hexagon_session.device) + assert (C_data.numpy() == np.array([0, 0])).all() + mod["add"](A_data, B_data, C_data) + assert (C_data.numpy() == np.array([6, 7])).all() @requires_hexagon_toolchain -def test_add_vtcm(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket): +def test_add_vtcm(hexagon_session): dtype = "int8" A = tvm.te.placeholder((2,), dtype=dtype) B = tvm.te.placeholder((1,), dtype=dtype) @@ -99,40 +82,22 @@ def test_add_vtcm(android_serial_number, tvm_tracker_host, tvm_tracker_port, adb sched, [A, B, C], tvm.target.Target(target_hexagon, host=target_hexagon), name="add" ) - temp = utils.tempdir() - dso_binary = "test_binary.so" - dso_binary_path = temp.relpath(dso_binary) - func.save(dso_binary_path) - - if not android_serial_number: - pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") - - rpc_info = { - "rpc_tracker_host": tvm_tracker_host, - "rpc_tracker_port": tvm_tracker_port, - "rpc_server_port": RPC_SERVER_PORT + 1, # See note at the beginning of the file - "adb_server_socket": adb_server_socket, - } - launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) - launcher.upload(dso_binary_path, dso_binary) - launcher.start_server() - - with launcher.start_session() as sess: - mod = launcher.load_module(dso_binary, sess) - A_data = tvm.nd.empty(A.shape, A.dtype, sess.device, "global.vtcm") - A_data.copyfrom(np.array([2, 3])) + if hexagon_session is None: + pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.") - B_data = tvm.nd.empty(B.shape, B.dtype, sess.device, "global.vtcm") - B_data.copyfrom(np.array([4])) + mod = hexagon_session.load_module(func) + A_data = tvm.nd.empty(A.shape, A.dtype, hexagon_session.device, "global.vtcm") + A_data.copyfrom(np.array([2, 3])) - C_data = tvm.nd.empty(C.shape, C.dtype, sess.device, "global.vtcm") - C_data.copyfrom(np.array([0, 0])) + B_data = tvm.nd.empty(B.shape, B.dtype, hexagon_session.device, "global.vtcm") + B_data.copyfrom(np.array([4])) - mod["add"](A_data, B_data, C_data) - result = C_data.numpy() - assert (result == np.array([6, 7])).all() + C_data = tvm.nd.empty(C.shape, C.dtype, hexagon_session.device, "global.vtcm") + C_data.copyfrom(np.array([0, 0])) - launcher.stop_server() + mod["add"](A_data, B_data, C_data) + result = C_data.numpy() + assert (result == np.array([6, 7])).all() class TestMatMul: @@ -141,9 +106,7 @@ class TestMatMul: K = tvm.testing.parameter(32) @requires_hexagon_toolchain - def test_matmul( - self, android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket, M, N, K - ): + def test_matmul(self, hexagon_session, M, N, K): X = te.placeholder((M, K), dtype="float32") Y = te.placeholder((K, N), dtype="float32") k1 = te.reduce_axis((0, K), name="k1") @@ -155,35 +118,19 @@ def test_matmul( schedule, [X, Y, Z], tvm.target.Target(target_hexagon, host=target_hexagon) ) - temp = utils.tempdir() - dso_binary = "test_binary.so" - dso_binary_path = temp.relpath(dso_binary) - func.save(dso_binary_path) + if hexagon_session is None: + pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.") - if not android_serial_number: - pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") - - rpc_info = { - "rpc_tracker_host": tvm_tracker_host, - "rpc_tracker_port": tvm_tracker_port, - "rpc_server_port": RPC_SERVER_PORT + 2, # See note at the beginning of the file - "adb_server_socket": adb_server_socket, - } - launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) - launcher.upload(dso_binary_path, dso_binary) - launcher.start_server() + mod = hexagon_session.load_module(func) x = np.random.uniform(size=[i.value for i in X.shape]).astype(X.dtype) y = np.random.uniform(size=[i.value for i in Y.shape]).astype(Y.dtype) z = np.zeros([i.value for i in Z.shape], dtype=Z.dtype) - with launcher.start_session() as sess: - mod = launcher.load_module(dso_binary, sess) - xt = tvm.nd.array(x, device=sess.device) - yt = tvm.nd.array(y, device=sess.device) - zt = tvm.nd.array(z, device=sess.device) - mod(xt, yt, zt) - launcher.stop_server() + xt = tvm.nd.array(x, device=hexagon_session.device) + yt = tvm.nd.array(y, device=hexagon_session.device) + zt = tvm.nd.array(z, device=hexagon_session.device) + mod(xt, yt, zt) target_llvm = tvm.target.Target("llvm") mod = tvm.build(schedule, [X, Y, Z], tvm.target.Target(target_llvm, host=target_llvm)) @@ -197,9 +144,7 @@ def test_matmul( @requires_hexagon_toolchain -def test_graph_executor( - android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket -): +def test_graph_executor(hexagon_launcher, hexagon_session): dtype = "float32" data = relay.var("data", relay.TensorType((1, 64, 64, 3), dtype)) weight = relay.var("weight", relay.TensorType((5, 5, 3, 8), dtype)) @@ -220,15 +165,15 @@ def test_graph_executor( runtime = Runtime("cpp") executor = Executor("graph") - temp = utils.tempdir() - dso_binary = "test_binary.so" - dso_binary_path = temp.relpath(dso_binary) - weight_in = np.random.rand(5, 5, 3, 8).astype(dtype=dtype) data_in = np.random.rand(1, 64, 64, 3).astype(dtype=dtype) params = {"weight": weight_in} inputs = {"data": data_in} + temp = utils.tempdir() + dso_binary = "test_binary.so" + dso_binary_path = temp.relpath(dso_binary) + with tvm.transform.PassContext(opt_level=3): lowered = tvm.relay.build( relay_mod, @@ -238,26 +183,17 @@ def test_graph_executor( ) lowered.get_lib().save(dso_binary_path) - if not android_serial_number: + if hexagon_session is None: pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") - rpc_info = { - "rpc_tracker_host": tvm_tracker_host, - "rpc_tracker_port": tvm_tracker_port, - "rpc_server_port": RPC_SERVER_PORT + 3, # See note at the beginning of the file - "adb_server_socket": adb_server_socket, - } - launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) - launcher.upload(dso_binary_path, dso_binary) - launcher.start_server() - - with launcher.start_session() as sess: - graph_mod = launcher.get_graph_executor(lowered.get_graph_json(), dso_binary, sess) - graph_mod.set_input(**params) - graph_mod.run(**inputs) - hexagon_output = graph_mod.get_output(0).numpy() + hexagon_launcher.upload(dso_binary_path, dso_binary) - launcher.stop_server() + graph_mod = hexagon_launcher.get_graph_executor( + lowered.get_graph_json(), dso_binary, hexagon_session + ) + graph_mod.set_input(**params) + graph_mod.run(**inputs) + hexagon_output = graph_mod.get_output(0).numpy() target_llvm = tvm.target.Target("llvm") with tvm.transform.PassContext(opt_level=3): @@ -276,9 +212,7 @@ def test_graph_executor( @requires_hexagon_toolchain -def test_graph_executor_multiple_conv2d( - tvm_tracker_host, tvm_tracker_port, android_serial_number, adb_server_socket -): +def test_graph_executor_multiple_conv2d(hexagon_launcher, hexagon_session): dtype = "float32" input_shape = (1, 8, 8, 3) w1_shape = (5, 5, 3, 1) @@ -325,18 +259,10 @@ def test_graph_executor_multiple_conv2d( ) lowered.get_lib().save(dso_binary_path) - if not android_serial_number: + if hexagon_session is None: pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") - rpc_info = { - "rpc_tracker_host": tvm_tracker_host, - "rpc_tracker_port": tvm_tracker_port, - "rpc_server_port": RPC_SERVER_PORT + 4, # See note at the beginning of the file - "adb_server_socket": adb_server_socket, - } - launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) - launcher.upload(dso_binary_path, dso_binary) - launcher.start_server() + hexagon_launcher.upload(dso_binary_path, dso_binary) weight1_data = np.random.rand(w1_shape[0], w1_shape[1], w1_shape[2], w1_shape[3]).astype( dtype=dtype @@ -351,13 +277,12 @@ def test_graph_executor_multiple_conv2d( params = {"weight1": weight1_data, "weight2": weight2_data} inputs = {"data": input_data} - with launcher.start_session() as sess: - graph_mod = launcher.get_graph_executor(lowered.get_graph_json(), dso_binary, sess) - graph_mod.set_input(**params) - graph_mod.run(**inputs) - hexagon_output = graph_mod.get_output(0).numpy() - - launcher.stop_server() + graph_mod = hexagon_launcher.get_graph_executor( + lowered.get_graph_json(), dso_binary, hexagon_session + ) + graph_mod.set_input(**params) + graph_mod.run(**inputs) + hexagon_output = graph_mod.get_output(0).numpy() target_llvm = tvm.target.Target("llvm") with tvm.transform.PassContext(opt_level=3): @@ -387,7 +312,7 @@ def _workaround_create_aot_shared(): @requires_hexagon_toolchain -def test_aot_executor(tvm_tracker_host, tvm_tracker_port, android_serial_number, adb_server_socket): +def test_aot_executor(hexagon_launcher, hexagon_session): dtype = "float32" input_shape = (1, 128, 128, 3) w_shape = (5, 5, 3, 8) @@ -435,26 +360,15 @@ def test_aot_executor(tvm_tracker_host, tvm_tracker_port, android_serial_number, dso_binary_path, fcompile=_workaround_create_aot_shared(), hexagon_arch="v68" ) - if not android_serial_number: - pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") + if hexagon_session is None: + pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.") - rpc_info = { - "rpc_tracker_host": tvm_tracker_host, - "rpc_tracker_port": tvm_tracker_port, - "rpc_server_port": RPC_SERVER_PORT + 5, # See note at the beginning of the file - "adb_server_socket": adb_server_socket, - } - launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) - launcher.upload(dso_binary_path, dso_binary) - launcher.start_server() + hexagon_launcher.upload(dso_binary_path, dso_binary) - with launcher.start_session() as sess: - aot_mod = launcher.get_aot_executor(dso_binary, sess) - aot_mod.set_input(**inputs) - aot_mod.run() - hexagon_output = aot_mod.get_output(0).numpy() - - launcher.stop_server() + aot_mod = hexagon_launcher.get_aot_executor(dso_binary, hexagon_session) + aot_mod.set_input(**inputs) + aot_mod.run() + hexagon_output = aot_mod.get_output(0).numpy() target_llvm = tvm.target.Target("llvm") with tvm.transform.PassContext(opt_level=3): @@ -474,9 +388,7 @@ def test_aot_executor(tvm_tracker_host, tvm_tracker_port, android_serial_number, @requires_hexagon_toolchain -def test_aot_executor_multiple_conv2d( - tvm_tracker_host, tvm_tracker_port, android_serial_number, adb_server_socket -): +def test_aot_executor_multiple_conv2d(hexagon_launcher, hexagon_session): dtype = "float32" input_shape = (1, 8, 8, 3) w1_shape = (5, 5, 3, 1) @@ -540,26 +452,15 @@ def test_aot_executor_multiple_conv2d( dso_binary_path, fcompile=_workaround_create_aot_shared(), hexagon_arch="v68" ) - if not android_serial_number: - pytest.skip(msg="Skip hardware test since ANDROID_SERIAL_NUMBER is not set.") + if hexagon_session is None: + pytest.skip(msg="Skip hardware test, ANDROID_SERIAL_NUMBER is not set.") + + hexagon_launcher.upload(dso_binary_path, dso_binary) - rpc_info = { - "rpc_tracker_host": tvm_tracker_host, - "rpc_tracker_port": tvm_tracker_port, - "rpc_server_port": RPC_SERVER_PORT + 6, # See note at the beginning of the file - "adb_server_socket": adb_server_socket, - } - launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info) - launcher.upload(dso_binary_path, dso_binary) - launcher.start_server() - - with launcher.start_session() as sess: - aot_mod = launcher.get_aot_executor(dso_binary, sess) - aot_mod.set_input(**inputs) - aot_mod.run() - hexagon_output = aot_mod.get_output(0).numpy() - - launcher.stop_server() + aot_mod = hexagon_launcher.get_aot_executor(dso_binary, hexagon_session) + aot_mod.set_input(**inputs) + aot_mod.run() + hexagon_output = aot_mod.get_output(0).numpy() target_llvm = tvm.target.Target("llvm") with tvm.transform.PassContext(opt_level=3): From ebdf9aaba661069a195a50f50341427528292ea0 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 14 Mar 2022 09:18:45 -0500 Subject: [PATCH 2/8] Added default port for tracker if not already set. --- tests/python/contrib/test_hexagon/conftest.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/tests/python/contrib/test_hexagon/conftest.py b/tests/python/contrib/test_hexagon/conftest.py index 8aba576747ba..57ea13db066b 100644 --- a/tests/python/contrib/test_hexagon/conftest.py +++ b/tests/python/contrib/test_hexagon/conftest.py @@ -78,52 +78,43 @@ def tvm_tracker_host() -> str: return os.getenv(TVM_TRACKER_HOST, default=None) -@tvm.testing.fixture -def tvm_tracker_port() -> int: - port = os.getenv(TVM_TRACKER_PORT, default=None) - port = int(port) if port else None - return port - - # NOTE on server ports: # These tests use different port numbers for the RPC server (7070 + ...). # The reason is that an RPC session cannot be gracefully closed without # triggering TIME_WAIT state on the server socket. This prevents another # server to bind to the same port until the wait time elapses. -# rpc_port_min = 1024 # Lowest unprivileged port -rpc_port_min = 2000 # Well above the privileged ports (1024 or lower) -rpc_port_max = 9000 # Below the search range end (port_end=9199) of RPC server +listen_port_min = 2000 # Well above the privileged ports (1024 or lower) +listen_port_max = 9000 # Below the search range end (port_end=9199) of RPC server previous_port = [None] -@tvm.testing.fixture -def rpc_server_port() -> int: - print(rpc_port_min) - +def get_free_port(): # https://stackoverflow.com/a/52872579/2689797 def is_port_in_use(port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex(("localhost", port)) == 0 if previous_port[0] is None: - port = random.randint(rpc_port_min, rpc_port_max) + port = random.randint(listen_port_min, listen_port_max) else: port = previous_port[0] + 1 while is_port_in_use(port): - port = port + 1 if port < rpc_port_max else rpc_port_min + port = port + 1 if port < listen_port_max else listen_port_min previous_port[0] = port return port -@tvm.testing.fixture -def adb_server_socket() -> str: - return os.getenv(ADB_SERVER_SOCKET, default="tcp:5037") +@pytest.fixture(scope="session") +def tvm_tracker_port() -> int: + port = os.getenv(TVM_TRACKER_PORT, default=None) + port = int(port) if port else get_free_port() + return port -@tvm.testing.fixture +@pytest.fixture(scope="session") def tvm_tracker(tvm_tracker_port): tracker = tvm.rpc.tracker.Tracker("127.0.0.1", tvm_tracker_port) try: @@ -132,6 +123,16 @@ def tvm_tracker(tvm_tracker_port): tracker.terminate() +@tvm.testing.fixture +def rpc_server_port() -> int: + return get_free_port() + + +@tvm.testing.fixture +def adb_server_socket() -> str: + return os.getenv(ADB_SERVER_SOCKET, default="tcp:5037") + + @tvm.testing.fixture def rpc_info(tvm_tracker, rpc_server_port, adb_server_socket): return { From 2bbee2232fd5d2e618ad931c27476d0295036572 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 14 Mar 2022 15:44:21 -0500 Subject: [PATCH 3/8] Pass through None from hexagon_launcher to hexagon_session. --- tests/python/contrib/test_hexagon/conftest.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/python/contrib/test_hexagon/conftest.py b/tests/python/contrib/test_hexagon/conftest.py index 57ea13db066b..b9fee1f3795b 100644 --- a/tests/python/contrib/test_hexagon/conftest.py +++ b/tests/python/contrib/test_hexagon/conftest.py @@ -164,8 +164,11 @@ def hexagon_launcher(android_serial_number, tvm_tracker, rpc_server_port, adb_se @tvm.testing.fixture def hexagon_session(hexagon_launcher): - with hexagon_launcher.start_session() as session: - yield session + if hexagon_launcher is None: + yield None + else: + with hexagon_launcher.start_session() as session: + yield session # If the execution aborts while an RPC server is running, the python From 52fda7ec61c632e8a6ea52cc77bff56e6a36861d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 21 Mar 2022 09:31:51 -0500 Subject: [PATCH 4/8] Updated launcher to use external tracker if specified. --- tests/python/contrib/test_hexagon/conftest.py | 82 ++++++++++++------- 1 file changed, 54 insertions(+), 28 deletions(-) diff --git a/tests/python/contrib/test_hexagon/conftest.py b/tests/python/contrib/test_hexagon/conftest.py index b9fee1f3795b..32b22c1eb1f4 100644 --- a/tests/python/contrib/test_hexagon/conftest.py +++ b/tests/python/contrib/test_hexagon/conftest.py @@ -73,11 +73,6 @@ def android_serial_number() -> Optional[str]: return serial -@tvm.testing.fixture -def tvm_tracker_host() -> str: - return os.getenv(TVM_TRACKER_HOST, default=None) - - # NOTE on server ports: # These tests use different port numbers for the RPC server (7070 + ...). # The reason is that an RPC session cannot be gracefully closed without @@ -108,19 +103,58 @@ def is_port_in_use(port: int) -> bool: @pytest.fixture(scope="session") -def tvm_tracker_port() -> int: - port = os.getenv(TVM_TRACKER_PORT, default=None) - port = int(port) if port else get_free_port() - return port +def _tracker_info() -> (str, int): + env_tracker_host = os.getenv(TVM_TRACKER_HOST, default="") + env_tracker_port = os.getenv(TVM_TRACKER_PORT, default="") + + if env_tracker_host or env_tracker_port: + # A tracker is already running, and we should connect to it + # when running tests. + assert env_tracker_host, "TVM_TRACKER_PORT is defined, but TVM_TRACKER_HOST is not" + assert env_tracker_port, "TVM_TRACKER_HOST is defined, but TVM_TRACKER_PORT is not" + env_tracker_port = int(env_tracker_port) + + try: + tvm.rpc.connect_tracker(env_tracker_host, env_tracker_port) + except RuntimeError as exc: + message = ( + "Could not connect to external tracker " + "specified by $TVM_TRACKER_HOST and $TVM_TRACKER_PORT " + f"({env_tracker_host}:{env_tracker_port})" + ) + raise RuntimeError(message) from exc + + yield (env_tracker_host, env_tracker_port) + + else: + # No tracker is provided to the tests, so we should start one + # for the tests to use. + tracker = tvm.rpc.tracker.Tracker("127.0.0.1", get_free_port()) + try: + yield (tracker.host, tracker.port) + finally: + tracker.terminate() @pytest.fixture(scope="session") -def tvm_tracker(tvm_tracker_port): - tracker = tvm.rpc.tracker.Tracker("127.0.0.1", tvm_tracker_port) - try: - yield tracker - finally: - tracker.terminate() +def tvm_tracker_host(_tracker_info) -> str: + host, port = _tracker_info + return host + + +@pytest.fixture(scope="session") +def tvm_tracker_port(_tracker_info) -> int: + host, port = _tracker_info + return port + + +# @pytest.fixture(scope="session") +# def tvm_tracker(tvm_tracker_port): +# tracker = tvm.rpc.tracker.Tracker("127.0.0.1", tvm_tracker_port) +# try: +# yield tracker +# finally: +# tracker.terminate() @tvm.testing.fixture @@ -134,23 +168,15 @@ def adb_server_socket() -> str: @tvm.testing.fixture -def rpc_info(tvm_tracker, rpc_server_port, adb_server_socket): - return { - "rpc_tracker_host": tvm_tracker.host, - "rpc_tracker_port": tvm_tracker.port, - "rpc_server_port": rpc_server_port, - "adb_server_socket": adb_server_socket, - } - - -@tvm.testing.fixture -def hexagon_launcher(android_serial_number, tvm_tracker, rpc_server_port, adb_server_socket): +def hexagon_launcher( + android_serial_number, tvm_tracker_host, tvm_tracker_port, rpc_server_port, adb_server_socket +): if android_serial_number is None: yield None else: rpc_info = { - "rpc_tracker_host": tvm_tracker.host, - "rpc_tracker_port": tvm_tracker.port, + "rpc_tracker_host": tvm_tracker_host, + "rpc_tracker_port": tvm_tracker_port, "rpc_server_port": rpc_server_port, "adb_server_socket": adb_server_socket, } From cca8d47da010b67d23cf77a5f4f0dff96ad024bb Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 22 Mar 2022 15:54:27 -0500 Subject: [PATCH 5/8] Avoid setting up the local tracker unless required. --- tests/python/contrib/test_hexagon/conftest.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/tests/python/contrib/test_hexagon/conftest.py b/tests/python/contrib/test_hexagon/conftest.py index 32b22c1eb1f4..295aea2b0839 100644 --- a/tests/python/contrib/test_hexagon/conftest.py +++ b/tests/python/contrib/test_hexagon/conftest.py @@ -148,15 +148,6 @@ def tvm_tracker_port(_tracker_info) -> int: return port -# @pytest.fixture(scope="session") -# def tvm_tracker(tvm_tracker_port): -# tracker = tvm.rpc.tracker.Tracker("127.0.0.1", tvm_tracker_port) -# try: -# yield tracker -# finally: -# tracker.terminate() - - @tvm.testing.fixture def rpc_server_port() -> int: return get_free_port() @@ -168,12 +159,16 @@ def adb_server_socket() -> str: @tvm.testing.fixture -def hexagon_launcher( - android_serial_number, tvm_tracker_host, tvm_tracker_port, rpc_server_port, adb_server_socket -): +def hexagon_launcher(request, android_serial_number, rpc_server_port, adb_server_socket): if android_serial_number is None: yield None else: + # Requesting these fixtures sets up a local tracker, if one + # hasn't been provided to us. Delaying the evaluation of + # these fixtures avoids starting a tracker unless necessary. + tvm_tracker_host = request.getfixturevalue("tvm_tracker_host") + tvm_tracker_port = request.getfixturevalue("tvm_tracker_port") + rpc_info = { "rpc_tracker_host": tvm_tracker_host, "rpc_tracker_port": tvm_tracker_port, From 27df09cd0b879ebeb3919c4af7cd5b7d22222753 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 24 Mar 2022 09:07:08 -0500 Subject: [PATCH 6/8] Declare previous_port as global, instead of list. --- tests/python/contrib/test_hexagon/conftest.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/python/contrib/test_hexagon/conftest.py b/tests/python/contrib/test_hexagon/conftest.py index 32b22c1eb1f4..16241a9779ca 100644 --- a/tests/python/contrib/test_hexagon/conftest.py +++ b/tests/python/contrib/test_hexagon/conftest.py @@ -81,7 +81,7 @@ def android_serial_number() -> Optional[str]: listen_port_min = 2000 # Well above the privileged ports (1024 or lower) listen_port_max = 9000 # Below the search range end (port_end=9199) of RPC server -previous_port = [None] +previous_port = None def get_free_port(): @@ -90,15 +90,16 @@ def is_port_in_use(port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex(("localhost", port)) == 0 - if previous_port[0] is None: + global previous_port + if previous_port is None: port = random.randint(listen_port_min, listen_port_max) else: - port = previous_port[0] + 1 + port = previous_port + 1 while is_port_in_use(port): port = port + 1 if port < listen_port_max else listen_port_min - previous_port[0] = port + previous_port = port return port From d629df06a2c4237b123570dc29dea86641afe863 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 24 Mar 2022 09:09:01 -0500 Subject: [PATCH 7/8] Corrected type hints. --- python/tvm/contrib/hexagon/build.py | 2 +- python/tvm/contrib/hexagon/session.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/hexagon/build.py b/python/tvm/contrib/hexagon/build.py index 2b0e0b1da2bb..a40903b822ba 100644 --- a/python/tvm/contrib/hexagon/build.py +++ b/python/tvm/contrib/hexagon/build.py @@ -197,7 +197,7 @@ def start_session(self) -> Session: } return Session(self, hexagon_remote_kw) - def load_module(self, module: Union[str, pathlib.Path], session: Session): + def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], session: Session): """Load TVM module. Parameters diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index e6f2dab966ed..bd6f84066daf 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -91,7 +91,7 @@ def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str): """ self._launcher.upload(local_path, remote_filename) - def load_module(self, module: Union[str, pathlib.Path, tvm.IRModule]): + def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]): """Load TVM module. Parameters From f168f9a89471b62dc4473be1aaddc45ad4af2e5c Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 24 Mar 2022 09:14:27 -0500 Subject: [PATCH 8/8] Docstring updates --- python/tvm/contrib/hexagon/session.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py index bd6f84066daf..44c4d145555c 100644 --- a/python/tvm/contrib/hexagon/session.py +++ b/python/tvm/contrib/hexagon/session.py @@ -31,11 +31,18 @@ class Session: Parameters ---------- + launcher : HexagonLauncherRPC + The launcher from which this session was started. + remote_kw : dict Remote configs for RPC tracker. session_name : str Hexagon RPC session name. + + remote_stack_size_bytes : int + The stack size of the remote device, to be passed to + tvm.contrib.hexagon.create_hexagon_session. """ def __init__(