Skip to content

Commit

Permalink
[Hexagon] Improved ergonomics of HexagonLauncher in unit tests. (#10581)
Browse files Browse the repository at this point in the history
* [Hexagon] Improved ergonomics of HexagonLauncher in unit tests.

The goal of this commit is to reduce/eliminate common code required
through unit tests that interact with Hexagon hardware.

- New testing fixtures in `tests/python/contrib/test_hexagon`.  A test
  running on hexagon hardware should only need to use the
  `hexagon_session` fixture.

  - `rpc_server_port`: Iterates through port numbers, selecting an
    unused port for each unit test.  Avoids needing to explicitly
    specify unique ports for each unit test.

  - `tvm_tracker`: Starts a tracker on use, exits after test.  Avoids
    needing to manually start a tracker prior to running the unit
    test.

  - `hexagon_launcher`: Starts a `HexagonLauncher` server on use,
    stops server after test.  Avoids needing to call `start_server()`
    and `stop_server()` in each test.

  - `hexagon_session`: Starts a hexagon session using
    `hexagon_laucnehr.start_session()`, exits after test.

- Added `Session.upload` function, which delegates to
  `HexagonLauncher.upload`.  Avoids needing to interact with both the
  launcher and the session.

- Allowed `tvm.IRModule` as argument passed to `Session.load_module`,
  which will automatically save/upload the module, then load it.
  Avoids needing to handle save/upload of temporary files in each unit
  test.

* Added default port for tracker if not already set.

* Pass through None from hexagon_launcher to hexagon_session.

* Updated launcher to use external tracker if specified.

* Avoid setting up the local tracker unless required.

* Declare previous_port as global, instead of list.

* Corrected type hints.

* Docstring updates
  • Loading branch information
Lunderberg authored Mar 25, 2022
1 parent 14084f4 commit f16286e
Show file tree
Hide file tree
Showing 5 changed files with 290 additions and 223 deletions.
25 changes: 17 additions & 8 deletions python/tvm/contrib/hexagon/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,28 +195,37 @@ def start_session(self) -> Session:
"timeout": 0,
"key": self.HEXAGON_REMOTE_DEVICE_KEY,
}
return Session(hexagon_remote_kw)
return Session(self, hexagon_remote_kw)

def load_module(self, module_name: Union[str, pathlib.Path], session: Session):
def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], session: Session):
"""Load TVM module.
Parameters
----------
module_name : str or pathlib.Path
Name of the module to load. It must be either a bare file name
(without any path components), or a full path in the remote
system. If it is a file name, the file must be placed in the
remote workspace.
module : Union[str, pathlib.Path, tvm.runtime.Module]
The module to load. If `module` is a
`tvm.runtime.Module`, it will be uploaded to the remote
session and loaded.
If the object passed is a string or pathlib.Path, it must
be either a bare file name (without any path components),
or a full path in the remote system. If it is a file name,
the file must already have been uploaded to the remote,
and be placed in the remote workspace.
session : Session
Remote session. The session must be established (via __enter__)
prior to calling this function.
Returns
-------
TVMModule :
TVM module object.
"""
return session.load_module(module_name)
return session.load_module(module)

def get_graph_executor(
self, graph_json: str, module_name: Union[str, pathlib.Path], session: Session
Expand Down
65 changes: 62 additions & 3 deletions python/tvm/contrib/hexagon/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@

import os
import pathlib
import tempfile
from typing import Union

import tvm
from tvm import rpc as _rpc


Expand All @@ -28,19 +31,28 @@ class Session:
Parameters
----------
launcher : HexagonLauncherRPC
The launcher from which this session was started.
remote_kw : dict
Remote configs for RPC tracker.
session_name : str
Hexagon RPC session name.
remote_stack_size_bytes : int
The stack size of the remote device, to be passed to
tvm.contrib.hexagon.create_hexagon_session.
"""

def __init__(
self,
launcher: "HexagonLauncherRPC",
remote_kw: dict,
session_name: str = "hexagon-rpc",
remote_stack_size_bytes: int = 128 * 1024,
):
self._launcher = launcher
self._session_name = session_name
self._remote_stack_size_bytes = remote_stack_size_bytes
self._remote_kw = remote_kw
Expand Down Expand Up @@ -74,6 +86,53 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, exc_traceback):
pass

def load_module(self, path: Union[str, pathlib.Path]):
assert isinstance(path, (str, pathlib.Path)), "Invalid path type:" + str(type(path))
return self._rpc.get_function("tvm.hexagon.load_module")(str(path))
def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
"""Upload a local file to the remote workspace.
Parameters
----------
local_path : str or pathlib.Path
Path to the local file to be copied.
remote_filename : str
Name of the file in the remote workspace.
"""
self._launcher.upload(local_path, remote_filename)

def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
"""Load TVM module.
Parameters
----------
module : Union[str, pathlib.Path, tvm.runtime.Module]
The module to load. If `module` is a
`tvm.runtime.Module`, it will be uploaded to the remote
session and loaded.
If the object passed is a string or pathlib.Path, it must
be either a bare file name (without any path components),
or a full path in the remote system. If it is a file name,
the file must already have been uploaded to the remote,
and be placed in the remote workspace.
session : Session
Remote session. The session must be established (via __enter__)
prior to calling this function.
Returns
-------
TVMModule :
TVM module object.
"""
if isinstance(module, tvm.runtime.Module):
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = pathlib.Path(temp_dir)
binary_name = "test_binary.so"
binary_path = temp_dir / binary_name
module.save(str(binary_path))
self.upload(binary_path, binary_name)
module = binary_name

assert isinstance(module, (str, pathlib.Path)), "Invalid path type:" + str(type(module))
return self._rpc.get_function("tvm.hexagon.load_module")(str(module))
133 changes: 123 additions & 10 deletions tests/python/contrib/test_hexagon/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,15 @@
values from testing parameters """

import os
import random
import socket
from typing import Optional

import pytest

import tvm
from tvm import rpc
import tvm.rpc.tracker
from tvm.contrib.hexagon.build import HexagonLauncher

HEXAGON_TOOLCHAIN = "HEXAGON_TOOLCHAIN"
TVM_TRACKER_HOST = "TVM_TRACKER_HOST"
Expand Down Expand Up @@ -59,27 +64,135 @@ def requires_hexagon_toolchain(*args):


@tvm.testing.fixture
def android_serial_number() -> str:
return os.getenv(ANDROID_SERIAL_NUMBER, default=None)
def android_serial_number() -> Optional[str]:
serial = os.getenv(ANDROID_SERIAL_NUMBER, default="")
# Setting ANDROID_SERIAL_NUMBER to an empty string should be
# equivalent to having it unset.
if not serial.strip():
serial = None
return serial


# NOTE on server ports:
# These tests use different port numbers for the RPC server (7070 + ...).
# The reason is that an RPC session cannot be gracefully closed without
# triggering TIME_WAIT state on the server socket. This prevents another
# server to bind to the same port until the wait time elapses.

listen_port_min = 2000 # Well above the privileged ports (1024 or lower)
listen_port_max = 9000 # Below the search range end (port_end=9199) of RPC server
previous_port = None


def get_free_port():
# https://stackoverflow.com/a/52872579/2689797
def is_port_in_use(port: int) -> bool:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(("localhost", port)) == 0

global previous_port
if previous_port is None:
port = random.randint(listen_port_min, listen_port_max)
else:
port = previous_port + 1

while is_port_in_use(port):
port = port + 1 if port < listen_port_max else listen_port_min

previous_port = port
return port


@tvm.testing.fixture
def tvm_tracker_host() -> str:
return os.getenv(TVM_TRACKER_HOST, default=None)
@pytest.fixture(scope="session")
def _tracker_info() -> (str, int):
env_tracker_host = os.getenv(TVM_TRACKER_HOST, default="")
env_tracker_port = os.getenv(TVM_TRACKER_PORT, default="")

if env_tracker_host or env_tracker_port:
# A tracker is already running, and we should connect to it
# when running tests.
assert env_tracker_host, "TVM_TRACKER_PORT is defined, but TVM_TRACKER_HOST is not"
assert env_tracker_port, "TVM_TRACKER_HOST is defined, but TVM_TRACKER_PORT is not"
env_tracker_port = int(env_tracker_port)

try:
tvm.rpc.connect_tracker(env_tracker_host, env_tracker_port)
except RuntimeError as exc:
message = (
"Could not connect to external tracker "
"specified by $TVM_TRACKER_HOST and $TVM_TRACKER_PORT "
f"({env_tracker_host}:{env_tracker_port})"
)
raise RuntimeError(message) from exc

yield (env_tracker_host, env_tracker_port)

else:
# No tracker is provided to the tests, so we should start one
# for the tests to use.
tracker = tvm.rpc.tracker.Tracker("127.0.0.1", get_free_port())
try:
yield (tracker.host, tracker.port)
finally:
tracker.terminate()


@pytest.fixture(scope="session")
def tvm_tracker_host(_tracker_info) -> str:
host, port = _tracker_info
return host


@pytest.fixture(scope="session")
def tvm_tracker_port(_tracker_info) -> int:
host, port = _tracker_info
return port


@tvm.testing.fixture
def tvm_tracker_port() -> int:
port = os.getenv(TVM_TRACKER_PORT, default=None)
port = int(port) if port else None
return port
def rpc_server_port() -> int:
return get_free_port()


@tvm.testing.fixture
def adb_server_socket() -> str:
return os.getenv(ADB_SERVER_SOCKET, default="tcp:5037")


@tvm.testing.fixture
def hexagon_launcher(request, android_serial_number, rpc_server_port, adb_server_socket):
if android_serial_number is None:
yield None
else:
# Requesting these fixtures sets up a local tracker, if one
# hasn't been provided to us. Delaying the evaluation of
# these fixtures avoids starting a tracker unless necessary.
tvm_tracker_host = request.getfixturevalue("tvm_tracker_host")
tvm_tracker_port = request.getfixturevalue("tvm_tracker_port")

rpc_info = {
"rpc_tracker_host": tvm_tracker_host,
"rpc_tracker_port": tvm_tracker_port,
"rpc_server_port": rpc_server_port,
"adb_server_socket": adb_server_socket,
}
launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info)
launcher.start_server()
try:
yield launcher
finally:
launcher.stop_server()


@tvm.testing.fixture
def hexagon_session(hexagon_launcher):
if hexagon_launcher is None:
yield None
else:
with hexagon_launcher.start_session() as session:
yield session


# If the execution aborts while an RPC server is running, the python
# code that is supposed to shut it dowm will never execute. This will
# keep pytest from terminating (indefinitely), so add a cleanup
Expand Down
47 changes: 16 additions & 31 deletions tests/python/contrib/test_hexagon/test_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def intrin_func(ins, outs):


@requires_hexagon_toolchain
def test_cache_read_write(
android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket
):
def test_cache_read_write(hexagon_session):
size = 128
outer_shape = (size,)
factor = 16
Expand Down Expand Up @@ -105,37 +103,24 @@ def test_cache_read_write(
func = tvm.build(
s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy"
)
temp = utils.tempdir()
dso_binary = "test_binary.so"
dso_binary_path = temp.relpath(dso_binary)
func.save(dso_binary_path)

if not android_serial_number:
if hexagon_session is None:
pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")

rpc_info = {
"rpc_tracker_host": tvm_tracker_host,
"rpc_tracker_port": tvm_tracker_port,
"rpc_server_port": 7070,
"adb_server_socket": adb_server_socket,
}
launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info)
launcher.upload(dso_binary_path, dso_binary)
launcher.start_server()

with launcher.start_session() as sess:
mod = launcher.load_module(dso_binary, sess)
xt = tvm.nd.array(
np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device
)
yt = tvm.nd.array(
np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device
)
zt = tvm.nd.array(
np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device
)
mod["dmacpy"](xt, yt, zt)
launcher.stop_server()
mod = hexagon_session.load_module(func)
xt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=x.dtype),
device=hexagon_session.device,
)
yt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=y.dtype),
device=hexagon_session.device,
)
zt = tvm.nd.array(
np.random.randint(low=-128, high=127, size=size, dtype=z.dtype),
device=hexagon_session.device,
)
mod["dmacpy"](xt, yt, zt)

ref = xt.numpy() + yt.numpy()
np.testing.assert_equal(zt.numpy(), ref)
Loading

0 comments on commit f16286e

Please sign in to comment.