Skip to content

Commit

Permalink
Fix vm usage.
Browse files Browse the repository at this point in the history
  • Loading branch information
zxybazh committed Jun 23, 2022
1 parent e7fd19a commit 0869ed3
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions python/tvm/meta_schedule/testing/custom_builder_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
"""Customized builder and runner methods"""
# pylint: disable=import-outside-toplevel

from typing import TYPE_CHECKING, Callable, Dict, List
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

if TYPE_CHECKING:
import numpy as np # type: ignore
from tvm.ir import IRModule
from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
from tvm.runtime import Device, Module, NDArray
from tvm.target import Target
from tvm.runtime.vm import Executable


def build_relay(
Expand Down Expand Up @@ -143,10 +144,11 @@ def run_with_graph_executor(

def run_module_via_rpc(
rpc_config: "RPCConfig",
lib: "Module",
lib: Union["Module", "Executable"],
dev_type: str,
args: Dict[str, "np.ndarray"],
continuation: Callable,
use_vm: Optional[bool] = False,
):
"""Execute a tvm.runtime.Module on RPC remote"""
# pylint: disable=import-outside-toplevel
Expand All @@ -160,13 +162,15 @@ def run_module_via_rpc(

with tempfile.TemporaryDirectory() as tmp_dir:
filename = os.path.join(tmp_dir, "tvm_tmp_mod." + tar.output_format)
if use_vm:
code, lib = lib.save()
lib.export_library(filename, tar)
session = rpc_config.connect_server()
session.upload(filename)
_, filename = os.path.split(filename)
rt_mod = session.load_module(filename)
if use_vm:
rt_mod = session.get_function("runtime.Load_Executable")(code, rt_mod)
dev = session.device(dev_type=dev_type, dev_id=0)
nd_args = {}
for arg_key, arg_value in args.items():
nd_args[arg_key] = ndarray.array(arg_value, dev)
nd_args = {k: ndarray.array(v, dev) for k, v in args.items()}
return continuation(rt_mod, dev, nd_args)

0 comments on commit 0869ed3

Please sign in to comment.