diff --git a/python/tvm/auto_scheduler/testing/tune_onnx.py b/python/tvm/auto_scheduler/testing/tune_onnx.py
index 7e955601dbd5b..260d5814cd6aa 100644
--- a/python/tvm/auto_scheduler/testing/tune_onnx.py
+++ b/python/tvm/auto_scheduler/testing/tune_onnx.py
@@ -20,14 +20,13 @@
 import os
 from distutils.util import strtobool

-import numpy as np  # type: ignore
 import onnx  # type: ignore
 import tvm
 from tvm import auto_scheduler
 from tvm import meta_schedule as ms
 from tvm import relay
 from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
-from tvm.meta_schedule.testing.utils import generate_input_data, f_timer, f_timer_vm, f_per_layer
+from tvm.meta_schedule.testing.utils import generate_input_data, f_timer, f_per_layer
 from tvm.meta_schedule.utils import cpu_count
 from tvm.relay.frontend import from_onnx
 from tvm.support import describe
@@ -103,10 +102,10 @@ def _parse_args():
         help="example: `True / False",
     )
     args.add_argument(
-        "--use-vm",
-        type=lambda x: bool(strtobool(x)),
+        "--backend",
+        type=str,
         required=True,
-        help="example: `True / False",
+        help="example: graph / vm",
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -197,19 +196,21 @@ def main():
             opt_level=3,
             config={"relay.backend.use_auto_scheduler": True},
         ):
-            if ARGS.use_vm:
+            if ARGS.backend == "vm":
                 lib = relay.vm.compile(
                     mod,
                     target=ARGS.target,
                     params=params,
                 )
-            else:
+            elif ARGS.backend == "graph":
                 lib = relay.build(
                     mod,
                     target=ARGS.target,
                     params=params,
                 )
-    if not ARGS.use_vm:
+            else:
+                raise ValueError(f"Backend {ARGS.backend} not supported!")
+    if ARGS.backend == "graph":
         graph, rt_mod, params = lib.graph_json, lib.lib, lib.params

         run_module_via_rpc(
@@ -220,22 +221,13 @@ def main():
             continuation=f_per_layer(graph),
         )

-        run_module_via_rpc(
-            rpc_config=ARGS.rpc_config,
-            lib=lib,
-            dev_type=ARGS.target.kind.name,
-            args=input_data,
-            continuation=f_timer,
-        )
-    else:
-        run_module_via_rpc(
-            rpc_config=ARGS.rpc_config,
-            lib=lib,
-            dev_type=ARGS.target.kind.name,
-            args=input_data,
-            continuation=f_timer_vm,
-            use_vm=True,
-        )
+    run_module_via_rpc(
+        rpc_config=ARGS.rpc_config,
+        lib=lib,
+        dev_type=ARGS.target.kind.name,
+        args=input_data,
+        continuation=f_timer(ARGS.backend),
+    )


 if __name__ == "__main__":
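Note on the hunk above: the boolean `--use-vm` flag becomes a string `--backend` flag. A minimal stand-alone sketch of the new CLI surface (the flag name and help text mirror the patch; the parser below is illustrative, not the script itself):

# sketch_backend_flag.py -- hypothetical demo, not part of the patch
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--backend",
    type=str,
    required=True,
    help="example: graph / vm",
)

args = parser.parse_args(["--backend", "graph"])
# The patched scripts validate the value later and raise
# ValueError(f"Backend {args.backend} not supported!") for anything else.
assert args.backend in ("graph", "vm")

A string enum keeps the CLI open to further executors later, which a boolean flag could not express.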
config={"relay.backend.use_auto_scheduler": True}, ): - if ARGS.use_vm: + if ARGS.backend == "vm": lib = relay.vm.compile( mod, target=ARGS.target, params=params, ) - else: + elif ARGS.backend == "graph": lib = relay.build( mod, target=ARGS.target, params=params, ) - if not ARGS.use_vm: + else: + raise ValueError(f"Backend {ARGS.backend} not supported!") + + if ARGS.backend == "graph": graph, rt_mod, params = lib.graph_json, lib.lib, lib.params run_module_via_rpc( @@ -220,22 +222,13 @@ def main(): continuation=f_per_layer(graph), ) - run_module_via_rpc( - rpc_config=ARGS.rpc_config, - lib=lib, - dev_type=ARGS.target.kind.name, - args=input_data, - continuation=f_timer, - ) - else: - run_module_via_rpc( - rpc_config=ARGS.rpc_config, - lib=lib, - dev_type=ARGS.target.kind.name, - args=input_data, - continuation=f_timer_vm, - use_vm=True, - ) + run_module_via_rpc( + rpc_config=ARGS.rpc_config, + lib=lib, + dev_type=ARGS.target.kind.name, + args=input_data, + continuation=f_timer(ARGS.backend), + ) if __name__ == "__main__": diff --git a/python/tvm/meta_schedule/cost_model/cost_model.py b/python/tvm/meta_schedule/cost_model/cost_model.py index e479cb725428a..2fdb9b93494f9 100644 --- a/python/tvm/meta_schedule/cost_model/cost_model.py +++ b/python/tvm/meta_schedule/cost_model/cost_model.py @@ -73,7 +73,7 @@ def update( _ffi_api.CostModelUpdate(self, context, candidates, results) # type: ignore # pylint: disable=no-member def predict(self, context: TuneContext, candidates: List[MeasureCandidate]) -> np.ndarray: - """Update the cost model given running results. + """Predict normalized score with the cost model. Parameters ---------- diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py index 454689fbd2838..e203848c2cbb8 100644 --- a/python/tvm/meta_schedule/testing/custom_builder_runner.py +++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py @@ -148,7 +148,7 @@ def run_module_via_rpc( dev_type: str, args: Dict[str, "np.ndarray"], continuation: Callable, - use_vm: Optional[bool] = False, + backend: Optional[str] = "graph", ): """Execute a tvm.runtime.Module on RPC remote""" # pylint: disable=import-outside-toplevel @@ -162,14 +162,14 @@ def run_module_via_rpc( with tempfile.TemporaryDirectory() as tmp_dir: filename = os.path.join(tmp_dir, "tvm_tmp_mod." 
diff --git a/python/tvm/meta_schedule/testing/custom_builder_runner.py b/python/tvm/meta_schedule/testing/custom_builder_runner.py
index 454689fbd2838..e203848c2cbb8 100644
--- a/python/tvm/meta_schedule/testing/custom_builder_runner.py
+++ b/python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -148,7 +148,7 @@ def run_module_via_rpc(
     dev_type: str,
     args: Dict[str, "np.ndarray"],
     continuation: Callable,
-    use_vm: Optional[bool] = False,
+    backend: Optional[str] = "graph",
 ):
     """Execute a tvm.runtime.Module on RPC remote"""
     # pylint: disable=import-outside-toplevel
@@ -162,14 +162,14 @@ def run_module_via_rpc(
     with tempfile.TemporaryDirectory() as tmp_dir:
         filename = os.path.join(tmp_dir, "tvm_tmp_mod." + tar.output_format)
-        if use_vm:
+        if backend == "vm":
             code, lib = lib.save()
         lib.export_library(filename, tar)
         session = rpc_config.connect_server()
         session.upload(filename)
         _, filename = os.path.split(filename)
         rt_mod = session.load_module(filename)
-        if use_vm:
+        if backend == "vm":
             rt_mod = session.get_function("runtime.Load_Executable")(code, rt_mod)
         dev = session.device(dev_type=dev_type, dev_id=0)
         nd_args = {k: ndarray.array(v, dev) for k, v in args.items()}
diff --git a/python/tvm/meta_schedule/testing/tune_onnx.py b/python/tvm/meta_schedule/testing/tune_onnx.py
index 0c52010184fbd..b3737b6f99be9 100644
--- a/python/tvm/meta_schedule/testing/tune_onnx.py
+++ b/python/tvm/meta_schedule/testing/tune_onnx.py
@@ -20,14 +20,13 @@
 import logging
 from distutils.util import strtobool

-import numpy as np  # type: ignore
 import onnx  # type: ignore
 import tvm
 from tvm import meta_schedule as ms
 from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
 from tvm.relay.frontend import from_onnx
 from tvm.support import describe
-from .utils import generate_input_data, f_timer, f_timer_vm, f_per_layer
+from .utils import generate_input_data, f_timer, f_per_layer
@@ -100,10 +99,10 @@ def _parse_args():
         help="example: `True / False",
     )
     args.add_argument(
-        "--use-vm",
-        type=lambda x: bool(strtobool(x)),
+        "--backend",
+        type=str,
         required=True,
-        help="example: `True / False",
+        help="example: graph / vm",
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -161,12 +160,12 @@ def main():
         runner=runner,  # type: ignore
         work_dir=ARGS.work_dir,
         params=params,
-        use_vm=ARGS.use_vm,
+        backend=ARGS.backend,
     )
     print("Tuning Time:")
     print(profiler.table())
-    if not ARGS.use_vm:
+    if ARGS.backend == "graph":
         graph, rt_mod, params = lib.graph_json, lib.lib, lib.params

         run_module_via_rpc(
@@ -177,22 +176,13 @@ def main():
             continuation=f_per_layer(graph),
         )

-        run_module_via_rpc(
-            rpc_config=ARGS.rpc_config,
-            lib=lib,
-            dev_type=ARGS.target.kind.name,
-            args=input_data,
-            continuation=f_timer,
-        )
-    else:
-        run_module_via_rpc(
-            rpc_config=ARGS.rpc_config,
-            lib=lib,
-            dev_type=ARGS.target.kind.name,
-            args=input_data,
-            continuation=f_timer_vm,
-            use_vm=True,
-        )
+    run_module_via_rpc(
+        rpc_config=ARGS.rpc_config,
+        lib=lib,
+        dev_type=ARGS.target.kind.name,
+        args=input_data,
+        continuation=f_timer(ARGS.backend),
+    )


 if __name__ == "__main__":
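Usage note for the custom_builder_runner.py change above: callers now pass backend="graph" or backend="vm" instead of use_vm. A hypothetical call site, with rpc_config, lib and input_data assumed to be prepared as in these tuning scripts:

from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.utils import f_timer

run_module_via_rpc(
    rpc_config=rpc_config,       # assumed: an ms.runner.RPCConfig
    lib=lib,                     # a relay VM Executable in this case
    dev_type="cuda",
    args=input_data,             # assumed: Dict[str, np.ndarray]
    continuation=f_timer("vm"),
    backend="vm",                # was: use_vm=True
)

For backend="vm" the function serializes the Executable with lib.save(), uploads the library, and reassembles it on the remote via runtime.Load_Executable, as the second hunk shows.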
diff --git a/python/tvm/meta_schedule/testing/tune_relay.py b/python/tvm/meta_schedule/testing/tune_relay.py
index c295b07a60e08..3afeb80f71b06 100644
--- a/python/tvm/meta_schedule/testing/tune_relay.py
+++ b/python/tvm/meta_schedule/testing/tune_relay.py
@@ -20,12 +20,11 @@
 import logging
 from distutils.util import strtobool

-import numpy as np  # type: ignore
 import tvm
 from tvm import meta_schedule as ms
 from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
 from tvm.meta_schedule.testing.relay_workload import get_network
-from tvm.meta_schedule.testing.utils import generate_input_data, f_timer, f_timer_vm, f_per_layer
+from tvm.meta_schedule.testing.utils import generate_input_data, f_timer, f_per_layer
 from tvm.support import describe
@@ -98,10 +97,11 @@ def _parse_args():
         help="example: `True / False",
     )
     args.add_argument(
-        "--use-vm",
-        type=lambda x: bool(strtobool(x)),
+        "--backend",
+        type=str,
         required=True,
-        help="example: `True / False",
+        choices=["graph", "vm"],
+        help="example: graph / vm",
     )
     parsed = args.parse_args()
     parsed.target = tvm.target.Target(parsed.target)
@@ -161,13 +161,13 @@ def main():
         runner=runner,  # type: ignore
         work_dir=ARGS.work_dir,
         params=params,
-        use_vm=ARGS.use_vm,
+        backend=ARGS.backend,
     )
     print("Tuning Time:")
     print(profiler.table())
-    graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    if not ARGS.use_vm:
+    if ARGS.backend == "graph":
         graph, rt_mod, params = lib.graph_json, lib.lib, lib.params

         run_module_via_rpc(
@@ -178,22 +178,13 @@ def main():
             continuation=f_per_layer(graph),
         )

-        run_module_via_rpc(
-            rpc_config=ARGS.rpc_config,
-            lib=lib,
-            dev_type=ARGS.target.kind.name,
-            args=input_data,
-            continuation=f_timer,
-        )
-    else:
-        run_module_via_rpc(
-            rpc_config=ARGS.rpc_config,
-            lib=lib,
-            dev_type=ARGS.target.kind.name,
-            args=input_data,
-            continuation=f_timer_vm,
-            use_vm=True,
-        )
+    run_module_via_rpc(
+        rpc_config=ARGS.rpc_config,
+        lib=lib,
+        dev_type=ARGS.target.kind.name,
+        args=input_data,
+        continuation=f_timer(ARGS.backend),
+    )


 if __name__ == "__main__":
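Two notes on this file. First, the hunk at line 161 also deletes the stray unconditional "graph, rt_mod, params = lib.graph_json, lib.lib, lib.params" that preceded the if: with backend="vm", lib is a VM Executable without a graph_json attribute, so the unguarded access would crash. Second, unlike the other three scripts, this one adds choices=["graph", "vm"], so an invalid value is rejected at parse time instead of surfacing as a ValueError after tuning. A quick stdlib-only demonstration of that behavior:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--backend", type=str, required=True, choices=["graph", "vm"])

print(parser.parse_args(["--backend", "vm"]).backend)   # -> vm
# parser.parse_args(["--backend", "cpu"])  # SystemExit: invalid choice: 'cpu'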
diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py
index aeb530dd38386..a3986b1168ba9 100644
--- a/python/tvm/meta_schedule/testing/utils.py
+++ b/python/tvm/meta_schedule/testing/utils.py
@@ -17,8 +17,8 @@
 """Testing utility functions in meta schedule"""
 from statistics import median
 from typing import Callable, Dict, Optional, Union, List
-import numpy as np  # type: ignore
 import json
+import numpy as np  # type: ignore
 import tvm
 from tvm.ir import IRModule
@@ -100,71 +100,74 @@ def generate_input_data(input_shape: List[int], input_dtype: str) -> np.ndarray:
     """
     if input_dtype.startswith("float"):
         return np.random.uniform(size=input_shape).astype(input_dtype)
-    elif input_dtype in ["uint8", "int8"]:
-        return np.random.randint(low=0, high=127, size=input_shape, dtype="int32")
-    elif input_dtype in ["int32", "int64"]:
+    if input_dtype in ["uint8", "int8"]:
+        return np.random.randint(
+            low=0,
+            high=127,
+            size=input_shape,
+            dtype="int32",  # TODO(zxybazh): fix the datatype when int8 / uint8 is supported better
+        )
+    if input_dtype in ["int32", "int64"]:
         return np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)
-    else:
-        raise ValueError("Unsupported input datatype!")
+    raise ValueError("Unsupported input datatype!")


-def f_timer(rt_mod: tvm.runtime.Module, dev: tvm.device, input_data: Dict[str, NDArray]) -> None:
-    """Run and benchmark the given runtime module, print out the result.
+def f_timer(backend: str) -> Callable:
+    """Create a function that runs and benchmarks a whole runtime module on the
+    graph executor, or an Executable on the relay vm.

     Parameters
     ----------
-    rt_mod : tvm.runtime.Module
-        The runtime module.
-    dev : tvm.device
-        The device type to run workload.
-    input_data : Dict[str, np.ndarray]
-        The input data as a dictionary.
-    """
-    from tvm.contrib.graph_executor import GraphModule  # pylint:disable=import-outside-toplevel
-
-    mod = GraphModule(rt_mod["default"](dev))
-    for input_name, input_value in input_data.items():
-        mod.set_input(input_name, input_value)
-    ftimer = mod.module.time_evaluator(
-        "run",
-        dev,
-        min_repeat_ms=500,
-        repeat=3,
-    )
-    results = list(np.array(ftimer().results) * 1000.0)  # type: ignore
-    print("Running time in time_evaluator: ", results)
-    print("-------------------------------")
-    print(f"Min    : {min(results)}")
-    print(f"Max    : {max(results)}")
-    print(f"Median : {median(results)}")
+    backend : str
+        The backend to use, graph / vm.
+
+    Returns
+    -------
+    func : Callable
+        The function to benchmark the workload.
+    """

-def f_timer_vm(
-    rt_mod: tvm.runtime.vm.Executable, dev: tvm.device, input_data: Dict[str, NDArray]
-) -> None:
-    """Run and benchmark the given runtime module, print out the result.
+    def func(
+        rt_mod: Union[tvm.runtime.Module, tvm.runtime.vm.Executable],
+        dev: tvm.device,
+        input_data: Dict[str, NDArray],
+    ) -> None:
+        """Run and benchmark the given runtime module, print out the result.

-    Parameters
-    ----------
-    rt_mod : tvm.runtime.vm.Executable
-        The runtime module.
-    dev : tvm.device
-        The device type to run workload.
-    input_data : Dict[str, np.ndarray]
-        The input data as a dictionary.
-    """
-    from tvm.runtime.vm import VirtualMachine  # pylint:disable=import-outside-toplevel
+        Parameters
+        ----------
+        rt_mod : Union[tvm.runtime.Module, tvm.runtime.vm.Executable]
+            The runtime module or vm executable.
+        dev : tvm.device
+            The device type to run workload.
+        input_data : Dict[str, np.ndarray]
+            The input data as a dictionary.
+        """
+        from tvm.contrib.graph_executor import GraphModule  # pylint:disable=import-outside-toplevel
+        from tvm.runtime.vm import VirtualMachine  # pylint:disable=import-outside-toplevel
+
+        if backend == "vm":
+            vm = VirtualMachine(rt_mod, dev)  # pylint: disable=invalid-name
+            results = vm.benchmark(
+                dev, min_repeat_ms=500, repeat=5, number=1, end_to_end=False, **input_data
+            ).results
+        elif backend == "graph":
+            mod = GraphModule(rt_mod["default"](dev))
+            for input_name, input_value in input_data.items():
+                mod.set_input(input_name, input_value)
+            ftimer = mod.module.time_evaluator("run", dev, min_repeat_ms=500, repeat=5, number=1)
+            results = ftimer().results
+        else:
+            raise ValueError(f"Backend {backend} not supported in f_timer!")
+
+        results = list(np.array(results) * 1000.0)  # type: ignore
+
+        print("Running time in time_evaluator: ", results)
+        print("-------------------------------")
+        print(f"Min    : {min(results)}")
+        print(f"Max    : {max(results)}")
+        print(f"Median : {median(results)}")

-    vm = VirtualMachine(rt_mod, dev)  # pylint: disable=invalid-name
-    results = vm.benchmark(
-        dev, min_repeat_ms=500, repeat=3, number=3, end_to_end=True, **input_data
-    ).results
-    results = list(np.array(results) * 1000.0)  # type: ignore
-    print("Running time in time_evaluator: ", results)
-    print("-------------------------------")
-    print(f"Min    : {min(results)}")
-    print(f"Max    : {max(results)}")
-    print(f"Median : {median(results)}")
+    return func


 def f_per_layer(graph: str) -> Callable:
@@ -182,7 +185,11 @@ def f_per_layer(graph: str) -> Callable:
         The function using the json format graph.
     """

-    def func(rt_mod: tvm.runtime.Module, dev: tvm.device, input_data: Dict[str, NDArray]) -> None:
+    def func(
+        rt_mod: tvm.runtime.Module,
+        dev: tvm.device,
+        input_data: Dict[str, NDArray],
+    ) -> None:
         """Run and benchmark the per-layer performance of given runtime module,
         print out the result.
@@ -197,7 +204,6 @@ def func(rt_mod: tvm.runtime.Module, dev: tvm.device, input_data: Dict[str, NDAr
         """
         # pylint:disable=import-outside-toplevel
         from tvm.contrib.debugger.debug_executor import create
-        from tabulate import tabulate

         # pylint:enable=import-outside-toplevel
@@ -212,7 +218,7 @@ def func(rt_mod: tvm.runtime.Module, dev: tvm.device, input_data: Dict[str, NDAr
         print("|graph_nodes| = ", len(graph_nodes))
         print("|graph_time| = ", len(graph_time))
-        graph_nodes_time = [(k, float(v) * 1e6) for k, v in zip(graph_nodes, graph_time)]
-        print(tabulate(graph_nodes_time, headers=["Layer", "Time(us)"]))
+        for k, v in zip(graph_nodes, graph_time):
+            print(k, float(v) * 1e6, "us")

     return func
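Two remarks on the utils.py rewrite above. The inner func now sets results in each branch and converts to milliseconds once afterwards; as originally posted, the graph branch left ftimer uncalled and the trailing "results = list(np.array(ftimer().results) * 1000.0)" would raise NameError on the vm path, so that line is rewritten here to use results from either branch. The merged f_timer is also a factory whose closure matches the continuation signature func(rt_mod, dev, input_data). A sketch of driving it directly (rt_mod is assumed to be any built graph-executor module; tvm.cpu() and the input shape are illustrative):

import tvm
from tvm.meta_schedule.testing.utils import f_timer, generate_input_data

bench = f_timer("graph")  # returns func(rt_mod, dev, input_data)
input_data = {"data": generate_input_data([1, 3, 224, 224], "float32")}
bench(rt_mod, tvm.cpu(), input_data)  # prints min / max / median in ms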
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index b9dc2b7e94bb2..437a4d627c40d 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -525,7 +525,7 @@ def tune_relay(
     postprocs: Optional[FnPostproc] = None,
     mutator_probs: Optional[FnMutatorProb] = None,
     num_threads: Optional[int] = None,
-    use_vm: Optional[bool] = False,
+    backend: Optional[str] = "graph",
 ) -> Module:
     """Tune a TIR IRModule with a given target.
@@ -589,7 +589,9 @@ def tune_relay(
             opt_level=3,
             config={"relay.backend.use_meta_schedule": True},
         ):
-            if not use_vm:
+            if backend == "graph":
                 return relay_build(mod, target=target, params=params)
-            else:
+            elif backend == "vm":
                 return relay.vm.compile(mod, target=target, params=params)
+            else:
+                raise ValueError(f"Backend {backend} not supported in tune_relay!")
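Finally, a hedged sketch of the updated tune_relay entry point; mod, params, target and the tuning config are assumed to be prepared as in the testing scripts above:

from tvm import meta_schedule as ms

lib = ms.tune_relay(
    mod=mod,
    params=params,
    target=target,
    config=config,          # assumed: an ms.TuneConfig
    work_dir="./tune_tmp",
    backend="vm",           # was: use_vm=True; returns a relay VM Executable
)

With backend="graph" (the default) the return value is the graph-executor factory from relay_build, exactly as before this patch.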