diff --git a/python/tvm/driver/tvmc/compiler.py b/python/tvm/driver/tvmc/compiler.py index 66c7b358471f..8f24dd4d7536 100644 --- a/python/tvm/driver/tvmc/compiler.py +++ b/python/tvm/driver/tvmc/compiler.py @@ -201,6 +201,7 @@ def compile_model( disabled_pass: Optional[str] = None, pass_context_configs: Optional[List[str]] = None, additional_target_options: Optional[Dict[str, Dict[str, Any]]] = None, + use_vm: bool = False, ): """Compile a model from a supported framework into a TVM module. @@ -248,7 +249,8 @@ def compile_model( PassContext. additional_target_options: Optional[Dict[str, Dict[str, Any]]] Additional target options in a dictionary to combine with initial Target arguments - + use_vm: bool + Whether to use the VM to compile the model as opposed to the graph executor Returns ------- @@ -291,8 +293,13 @@ def compile_model( opt_level=opt_level, config=config, disabled_pass=disabled_pass ): logger.debug("building relay graph with autoscheduler") - graph_module = relay.build( - mod, target=tvm_target, executor=executor, runtime=runtime, params=params + graph_module = build( + mod, + tvm_target=tvm_target, + executor=executor, + runtime=runtime, + params=params, + use_vm=use_vm, ) else: with autotvm.apply_history_best(tuning_records): @@ -300,16 +307,26 @@ def compile_model( opt_level=opt_level, config=config, disabled_pass=disabled_pass ): logger.debug("building relay graph with tuning records") - graph_module = relay.build( - mod, target=tvm_target, executor=executor, runtime=runtime, params=params + graph_module = build( + mod, + tvm_target=tvm_target, + executor=executor, + runtime=runtime, + params=params, + use_vm=use_vm, ) else: with tvm.transform.PassContext( opt_level=opt_level, config=config, disabled_pass=disabled_pass ): logger.debug("building relay graph (no tuning records provided)") - graph_module = relay.build( - mod, target=tvm_target, executor=executor, runtime=runtime, params=params + graph_module = build( + mod, + tvm_target=tvm_target, + executor=executor, + runtime=runtime, + params=params, + use_vm=use_vm, ) # Generate output dump files with sources @@ -319,7 +336,10 @@ def compile_model( dump_code = [dump_code] dumps = {} for source_type in dump_code: - lib = graph_module.get_lib() + if use_vm: + lib = graph_module.lib + else: + lib = graph_module.get_lib() # TODO lib.get_source call have inconsistent behavior for unsupported # formats (@leandron). source = str(mod) if source_type == "relay" else lib.get_source(source_type) @@ -327,11 +347,7 @@ def compile_model( # Create a new tvmc model package object from the graph definition. package_path = tvmc_model.export_package( - graph_module, - package_path, - cross, - cross_options, - output_format, + graph_module, package_path, cross, cross_options, output_format ) # Write dumps to file. @@ -341,6 +357,41 @@ def compile_model( return TVMCPackage(package_path) +def build( + mod: tvm.IRModule, + tvm_target: str, + executor: Executor, + runtime: Runtime, + params: Dict[str, tvm.nd.NDArray], + use_vm: bool, +): + """ + Builds the model with the provided executor. + + Parameters + ---------- + mod : tvm.IRModule + The relay module corresponding to this model. + tvm_target : str + The target for which to compile. Can be a plain string or + a path. + executor : Executor + The graph executor to build the model if use_vm is not True + runtime : Runtime + The runtime configuration. + params : dict + A parameter dictionary for the model. 
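Note: in this patch the new `use_vm` flag is only exposed through the Python API (`tvmc.compile`, which maps to `compile_model`). A minimal usage sketch, with an illustrative model file name:

```python
from tvm.driver import tvmc

# Illustrative model path; any framework tvmc.load supports would do.
tvmc_model = tvmc.load("my_model.onnx")

# Default path: graph executor via relay.build.
graph_pkg = tvmc.compile(tvmc_model, target="llvm")

# New path: Relay VM via relay.vm.compile.
vm_pkg = tvmc.compile(tvmc_model, target="llvm", use_vm=True)
```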
+ use_vm: bool + Whether to use the VM to compile the model as opposed to the graph executor + + """ + if use_vm: + logger.debug("building with vm compile") + return relay.vm.compile(mod, target=tvm_target, params=params) + logger.debug("building with relay build") + return relay.build(mod, target=tvm_target, executor=executor, runtime=runtime, params=params) + + def save_dumps(module_name: str, dumps: Dict[str, str], dump_root: str = "."): """ Serialize dump files to the disk. diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index 9a2617f3ed53..93ca27c60947 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -57,6 +57,8 @@ from tvm.driver.tvmc import TVMCException from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule from tvm.runtime.module import BenchmarkResult +from tvm.runtime.vm import Executable + try: from tvm.micro import export_model_library_format @@ -182,6 +184,42 @@ def default_package_path(self): """ return self._tmp_dir.relpath("model_package.tar") + def export_vm_format( + self, + vm_exec: Executable, + package_path: Optional[str] = None, + lib_format: str = "so", + ): + """Save this TVMCModel compiled via vm to file. + Parameters + ---------- + vm_exec : vm.Executable + The VM Executable containing compiled the compiled artifacts needed to run this model. + package_path : str, None + Where the model should be saved. Note that it will be packaged as a .tar file. + If not provided, the package will be saved to a generically named file in tmp. + lib_format : str + How to export the modules function library. Must be one of "so" or "tar". + + Returns + ------- + package_path : str + The path that the package was saved to. + """ + lib_name = "lib." + lib_format + temp = self._tmp_dir + if package_path is None: + package_path = self.default_package_path() + + path_lib = temp.relpath(lib_name) + vm_exec.mod.export_library(path_lib) + self.lib_path = path_lib + # Package up all the temp files into a tar file. + with tarfile.open(package_path, "w") as tar: + tar.add(path_lib, lib_name) + + return package_path + def export_classic_format( self, executor_factory: GraphExecutorFactoryModule, @@ -248,7 +286,7 @@ def export_classic_format( def export_package( self, - executor_factory: GraphExecutorFactoryModule, + executor_factory: Union[GraphExecutorFactoryModule, Executable], package_path: Optional[str] = None, cross: Optional[Union[str, Callable]] = None, cross_options: Optional[str] = None, @@ -281,7 +319,9 @@ def export_package( if output_format == "mlf" and cross: raise TVMCException("Specifying the MLF output and a cross compiler is not supported.") - if output_format in ["so", "tar"]: + if isinstance(executor_factory, Executable): + package_path = self.export_vm_format(executor_factory, package_path, output_format) + elif output_format in ["so", "tar"]: package_path = self.export_classic_format( executor_factory, package_path, cross, cross_options, output_format ) @@ -314,9 +354,16 @@ class TVMCPackage(object): project_dir : Path, str If given and loading a MLF file, the path to the project directory that contains the file. + + use_vm : bool + Whether the graph module was compiled with vm or not. 
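For context, `export_vm_format` relies on the fact that a VM `Executable` carries its compiled code in `.mod`, which can be exported like any runtime module and later handed back to a `VirtualMachine` (the same pattern `runner.py` uses below). A rough sketch of that round trip, assuming `mod` and `params` already exist and the output path is illustrative:

```python
import tvm
from tvm import relay
from tvm.runtime import vm

# Compile with the Relay VM (mod/params assumed to exist).
vm_exec = relay.vm.compile(mod, target="llvm", params=params)

# Export the executable's underlying module, as export_vm_format does with "lib.so".
vm_exec.mod.export_library("/tmp/lib.so")

# Reload and construct a VM; runner.py below passes the loaded module the same way.
loaded = tvm.runtime.load_module("/tmp/lib.so")
machine = vm.VirtualMachine(loaded, tvm.cpu())
```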
""" - def __init__(self, package_path: str, project_dir: Optional[Union[Path, str]] = None): + def __init__( + self, + package_path: str, + project_dir: Optional[Union[Path, str]] = None, + ): self._tmp_dir = utils.tempdir() self.package_path = package_path self.import_package(self.package_path) @@ -351,23 +398,40 @@ def import_package(self, package_path: str): self.type = "mlf" else: # Classic format - lib_name_so = "mod.so" - lib_name_tar = "mod.tar" - if os.path.exists(temp.relpath(lib_name_so)): - self.lib_name = lib_name_so - elif os.path.exists(temp.relpath(lib_name_tar)): - self.lib_name = lib_name_tar + classic_lib_name_so = "mod.so" + classic_lib_name_tar = "mod.tar" + + # VM format + vm_lib_name_so = "lib.so" + vm_lib_name_tar = "lib.tar" + + if os.path.exists(temp.relpath(classic_lib_name_so)): + self.lib_name = classic_lib_name_so + self.type = "classic" + elif os.path.exists(temp.relpath(classic_lib_name_tar)): + self.lib_name = classic_lib_name_tar + self.type = "classic" + elif os.path.exists(temp.relpath(vm_lib_name_so)): + self.lib_name = vm_lib_name_so + self.type = "vm" + elif os.path.exists(temp.relpath(vm_lib_name_tar)): + self.lib_name = vm_lib_name_tar + self.type = "vm" else: raise TVMCException("Couldn't find exported library in the package.") - self.lib_path = temp.relpath(self.lib_name) - graph = temp.relpath("mod.json") - params = temp.relpath("mod.params") + self.lib_path = temp.relpath(self.lib_name) - self.type = "classic" + graph, params = None, None + if self.type == "classic": + graph = temp.relpath("mod.json") + params = temp.relpath("mod.params") - with open(params, "rb") as param_file: - self.params = bytearray(param_file.read()) + if params is not None: + with open(params, "rb") as param_file: + self.params = bytearray(param_file.read()) + else: + self.params = None if graph is not None: with open(graph) as graph_file: diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 8db127214c28..1b6d82371230 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -28,9 +28,11 @@ import tvm from tvm import rpc +from tvm.runtime import vm from tvm.autotvm.measure import request_remote from tvm.contrib import graph_executor as executor from tvm.contrib.debugger import debug_executor +from tvm.runtime import profiler_vm from . import TVMCException from .arguments import TVMCSuppressedArgumentParser from .project import ( @@ -530,58 +532,93 @@ def run_module( assert device == "cpu" dev = session.cpu() - # TODO(gromero): Adjust for micro targets. 
- if profile: - logger.debug("Creating executor with profiling enabled.") - module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof") + if tvmc_package.type == "vm": + assert inputs is not None, "vm runner requires inputs to be provided as a dict" + + input_tensor = {} + for e, i in inputs.items(): + input_tensor[e] = tvm.nd.array(i, dev) + + if profile: + logger.debug("Creating vm with profile enabled.") + exe = profiler_vm.VirtualMachineProfiler(lib, dev) + res = exe.profile(**input_tensor, func_name="main") + # This print is intentional + print(res) + else: + exe = vm.VirtualMachine(lib, dev) + + exe_outputs = exe.invoke("main", **input_tensor) + times = exe.benchmark( + dev, + **input_tensor, + func_name="main", + repeat=repeat, + number=number, + end_to_end=end_to_end, + ) + + # Special handling if the output only has a single value + if not isinstance(exe_outputs, list): + exe_outputs = [exe_outputs] + + outputs = {} + for i, val in enumerate(exe_outputs): + output_name = "output_{}".format(i) + outputs[output_name] = val.numpy() else: - if device == "micro": - logger.debug("Creating executor (micro) with profiling disabled.") - module = tvm.micro.create_local_graph_executor(tvmc_package.graph, lib, dev) + # TODO(gromero): Adjust for micro targets. + if profile: + logger.debug("Creating runtime with profiling enabled.") + module = debug_executor.create(tvmc_package.graph, lib, dev, dump_root="./prof") else: - logger.debug("Creating executor with profiling disabled.") - module = executor.create(tvmc_package.graph, lib, dev) + if device == "micro": + logger.debug("Creating runtime (micro) with profiling disabled.") + module = tvm.micro.create_local_graph_executor(tvmc_package.graph, lib, dev) + else: + logger.debug("Creating runtime with profiling disabled.") + module = executor.create(tvmc_package.graph, lib, dev) - logger.debug("Loading params into the runtime module.") - module.load_params(tvmc_package.params) + logger.debug("Loading params into the runtime module.") + module.load_params(tvmc_package.params) - logger.debug("Collecting graph input shape and type:") - shape_dict, dtype_dict = module.get_input_info() - logger.debug("Graph input shape: %s", shape_dict) - logger.debug("Graph input type: %s", dtype_dict) + logger.debug("Collecting graph input shape and type:") + shape_dict, dtype_dict = module.get_input_info() + logger.debug("Graph input shape: %s", shape_dict) + logger.debug("Graph input type: %s", dtype_dict) - inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) + inputs_dict = make_inputs_dict(shape_dict, dtype_dict, inputs, fill_mode) - logger.debug("Setting inputs to the module.") - module.set_input(**inputs_dict) + logger.debug("Setting inputs to the module.") + module.set_input(**inputs_dict) - # Run must be called explicitly if profiling - if profile: - logger.info("Running the module with profiling enabled.") - report = module.profile() - # This print is intentional - print(report) + # Run must be called explicitly if profiling + if profile: + logger.info("Running the module with profiling enabled.") + report = module.profile() + # This print is intentional + print(report) - if device == "micro": - # TODO(gromero): Fix time_evaluator() for micro targets. Once it's - # fixed module.benchmark() can be used instead and this if/else can - # be removed. - module.run() - times = [] - else: - # Call the benchmarking function of the executor. 
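Because the VM path has no graph JSON to derive input shapes and dtypes from, `run_module` cannot fall back to `make_inputs_dict` and instead asserts that explicit inputs are provided. A user-level sketch of running a VM package (the `vm_pkg` name, input name, and shape are illustrative, in the style of the tests below):

```python
import numpy as np
from tvm.driver import tvmc

inputs = {"input_1": np.random.uniform(size=(1, 32, 32, 3)).astype("float32")}
result = tvmc.run(vm_pkg, device="cpu", inputs=inputs)

# With profile=True the VM branch uses VirtualMachineProfiler and prints its report.
tvmc.run(vm_pkg, device="cpu", inputs=inputs, profile=True)
```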
- # Optionally measure e2e data transfers from the - # CPU to device memory overheads (e.g. PCIE - # overheads if the device is a discrete GPU). - if end_to_end: - dev = session.cpu() - times = module.benchmark(dev, number=number, repeat=repeat, end_to_end=end_to_end) - - logger.debug("Collecting the output tensors.") - num_outputs = module.get_num_outputs() - outputs = {} - for i in range(num_outputs): - output_name = "output_{}".format(i) - outputs[output_name] = module.get_output(i).numpy() + if device == "micro": + # TODO(gromero): Fix time_evaluator() for micro targets. Once it's + # fixed module.benchmark() can be used instead and this if/else can + # be removed. + module.run() + times = [] + else: + # Call the benchmarking function of the executor. + # Optionally measure e2e data transfers from the + # CPU to device memory overheads (e.g. PCIE + # overheads if the device is a discrete GPU). + if end_to_end: + dev = session.cpu() + times = module.benchmark(dev, number=number, repeat=repeat, end_to_end=end_to_end) + + logger.debug("Collecting the output tensors.") + num_outputs = module.get_num_outputs() + outputs = {} + for i in range(num_outputs): + output_name = "output_{}".format(i) + outputs[output_name] = module.get_output(i).numpy() return TVMCResult(outputs, times) diff --git a/tests/python/driver/tvmc/test_compiler.py b/tests/python/driver/tvmc/test_compiler.py index 4b21f4edc8d5..bc836de7d554 100644 --- a/tests/python/driver/tvmc/test_compiler.py +++ b/tests/python/driver/tvmc/test_compiler.py @@ -49,21 +49,32 @@ def test_save_dumps(tmpdir_factory): # End to end tests for compilation -def verify_compile_tflite_module(model, shape_dict=None): - pytest.importorskip("tflite") - tvmc_model = tvmc.load(model, shape_dict=shape_dict) - tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW") - dumps_path = tvmc_package.package_path + ".ll" - +def verify_tvmc_package(tvmc_package, dumps_path, use_vm=False): # check for output types assert type(tvmc_package) is TVMCPackage - assert type(tvmc_package.graph) is str - assert type(tvmc_package.lib_path) is str - assert type(tvmc_package.params) is bytearray assert os.path.exists(dumps_path) + assert type(tvmc_package.lib_path) is str + + if use_vm: + assert tvmc_package.graph is None + assert tvmc_package.params is None + else: + assert type(tvmc_package.graph) is str + assert type(tvmc_package.params) is bytearray + + +def verify_compile_tflite_module(model, shape_dict=None, use_vm=False): + pytest.importorskip("tflite") + tvmc_model = tvmc.load(model, shape_dict=shape_dict) + tvmc_package = tvmc.compile( + tvmc_model, target="llvm", dump_code="ll", desired_layout="NCHW", use_vm=use_vm + ) + dumps_path = tvmc_package.package_path + ".ll" + verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm) -def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): +@pytest.mark.parametrize("use_vm", [True, False]) +def test_compile_tflite_module(use_vm, tflite_mobilenet_v1_1_quant): # some CI environments wont offer tflite, so skip in case it is not present pytest.importorskip("tflite") # Check default compilation. 
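Both branches converge on the same `TVMCResult` shape, which is what lets `verify_tvmc_package` and the run tests below stay executor-agnostic. A sketch of consuming the result, assuming a benchmark actually ran (i.e. not a micro target, where `times` is an empty list):

```python
# Outputs are keyed "output_0", "output_1", ... as numpy arrays in both modes.
top = result.outputs["output_0"]

# `times` is a tvm.runtime.module.BenchmarkResult for non-micro targets.
print(result.times.mean, result.times.median, result.times.std)
```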
@@ -71,7 +82,7 @@ def test_compile_tflite_module(tflite_mobilenet_v1_1_quant): # Check with manual shape override shape_string = "input:[1,224,224,3]" shape_dict = tvmc.shape_parser.parse_shape_string(shape_string) - verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict) + verify_compile_tflite_module(tflite_mobilenet_v1_1_quant, shape_dict, use_vm=use_vm) # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. @@ -198,28 +209,23 @@ def test_cross_compile_options_aarch64_keras_module(keras_resnet50): assert os.path.exists(dumps_path) -def verify_compile_onnx_module(model, shape_dict=None): +def verify_compile_onnx_module(model, shape_dict=None, use_vm=False): # some CI environments wont offer onnx, so skip in case it is not present pytest.importorskip("onnx") tvmc_model = tvmc.load(model, shape_dict=shape_dict) - tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll") + tvmc_package = tvmc.compile(tvmc_model, target="llvm", dump_code="ll", use_vm=use_vm) dumps_path = tvmc_package.package_path + ".ll" - - # check for output types - assert type(tvmc_package) is TVMCPackage - assert type(tvmc_package.graph) is str - assert type(tvmc_package.lib_path) is str - assert type(tvmc_package.params) is bytearray - assert os.path.exists(dumps_path) + verify_tvmc_package(tvmc_package, dumps_path, use_vm=use_vm) -def test_compile_onnx_module(onnx_resnet50): +@pytest.mark.parametrize("use_vm", [True, False]) +def test_compile_onnx_module(use_vm, onnx_resnet50): # Test default compilation verify_compile_onnx_module(onnx_resnet50) # Test with manual shape dict shape_string = "data:[1,3,200,200]" shape_dict = tvmc.shape_parser.parse_shape_string(shape_string) - verify_compile_onnx_module(onnx_resnet50, shape_dict) + verify_compile_onnx_module(onnx_resnet50, shape_dict, use_vm=use_vm) # This test will be skipped if the AArch64 cross-compilation toolchain is not installed. 
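The `.ll` dump these tests check for is produced the same way on both paths; the only difference (see `compiler.py` above) is where the compiled library is fetched from. A condensed sketch of that step, assuming an `llvm` target so LLVM IR is available via `get_source`:

```python
# graph_module is either a GraphExecutorFactoryModule or a vm.Executable.
lib = graph_module.lib if use_vm else graph_module.get_lib()
llvm_ir = lib.get_source("ll")   # LLVM IR text; written out by save_dumps as <package>.ll
```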
diff --git a/tests/python/driver/tvmc/test_model.py b/tests/python/driver/tvmc/test_model.py index 5fccfea149b5..74c1c4ded8a4 100644 --- a/tests/python/driver/tvmc/test_model.py +++ b/tests/python/driver/tvmc/test_model.py @@ -17,6 +17,7 @@ import platform import pytest import os +import numpy as np from os import path @@ -29,13 +30,22 @@ platform.machine() == "aarch64", reason="Currently failing on AArch64 - see https://github.com/apache/tvm/issues/10673", ) -def test_tvmc_workflow(keras_simple): +@pytest.mark.parametrize("use_vm", [True, False]) +def test_tvmc_workflow(use_vm, keras_simple): pytest.importorskip("tensorflow") + import tensorflow as tf + + # Reset so the input name remains consistent across unit test runs + tf.keras.backend.clear_session() tvmc_model = tvmc.load(keras_simple) tuning_records = tvmc.tune(tvmc_model, target="llvm", enable_autoscheduler=True, trials=2) - tvmc_package = tvmc.compile(tvmc_model, tuning_records=tuning_records, target="llvm") - result = tvmc.run(tvmc_package, device="cpu", end_to_end=True) + tvmc_package = tvmc.compile( + tvmc_model, tuning_records=tuning_records, target="llvm", use_vm=use_vm + ) + input_dict = {"input_1": np.random.uniform(size=(1, 32, 32, 3)).astype("float32")} + + result = tvmc.run(tvmc_package, device="cpu", end_to_end=True, inputs=input_dict) assert type(tvmc_model) is TVMCModel assert type(tvmc_package) is TVMCPackage assert type(result) is TVMCResult @@ -45,7 +55,8 @@ def test_tvmc_workflow(keras_simple): assert "output_0" in result.outputs.keys() -def test_save_load_model(keras_simple, tmpdir_factory): +@pytest.mark.parametrize("use_vm", [True, False]) +def test_save_load_model(use_vm, keras_simple, tmpdir_factory): pytest.importorskip("onnx") tmpdir = tmpdir_factory.mktemp("data") @@ -55,7 +66,7 @@ def test_save_load_model(keras_simple, tmpdir_factory): tvmc.tune(tvmc_model, target="llvm", trials=2) # Create package artifacts - tvmc.compile(tvmc_model, target="llvm") + tvmc.compile(tvmc_model, target="llvm", use_vm=use_vm) # Save the model to disk model_path = os.path.join(tmpdir, "saved_model.tar") diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py index 30ce2c6f2191..3f4ab11f6ba2 100644 --- a/tests/python/driver/tvmc/test_runner.py +++ b/tests/python/driver/tvmc/test_runner.py @@ -72,18 +72,20 @@ def test_get_top_results_keep_results(): assert len(sut[1]) == expected_number_of_results_per_line +@pytest.mark.parametrize("use_vm", [True, False]) def test_run_tflite_module__with_profile__valid_input( - tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat + use_vm, tflite_mobilenet_v1_1_quant, tflite_compile_model, imagenet_cat ): # some CI environments wont offer TFLite, so skip in case it is not present pytest.importorskip("tflite") inputs = np.load(imagenet_cat) + input_dict = {"input": inputs["input"].astype("uint8")} - tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant) + tflite_compiled_model = tflite_compile_model(tflite_mobilenet_v1_1_quant, use_vm=use_vm) result = tvmc.run( tflite_compiled_model, - inputs=inputs, + inputs=input_dict, hostname=None, device="cpu", profile=True,
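One subtlety in `test_tvmc_workflow`: the input dict hard-codes `input_1`, which is only stable because `tf.keras.backend.clear_session()` resets Keras' layer-name counter before the model is loaded. A more defensive sketch (hypothetical, not part of this patch) would read the input name from the model instead of hard-coding it:

```python
import numpy as np
import tensorflow as tf

tf.keras.backend.clear_session()
model = tf.keras.models.load_model("keras_simple.h5")   # illustrative path

# Derive the input name rather than assuming "input_1".
input_name = model.input_names[0]
input_dict = {input_name: np.random.uniform(size=(1, 32, 32, 3)).astype("float32")}
```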