diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 52303123c12e..e23765e92d8c 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -233,6 +233,9 @@ def build( elif isinstance(inputs, PrimFunc): input_mod = lower(inputs, name=name) elif isinstance(inputs, tvm.IRModule): + assert ( + len(inputs.get_global_vars()) > 0 + ), "Expected a non-empty IRModule, but the IRModule contained no functions." input_mod = lower(inputs) elif not isinstance(inputs, (dict, container.Map)): raise ValueError( diff --git a/python/tvm/relax/vm_build.py b/python/tvm/relax/vm_build.py index ca756d4dc681..243488e5d83f 100644 --- a/python/tvm/relax/vm_build.py +++ b/python/tvm/relax/vm_build.py @@ -243,7 +243,7 @@ def _vmlink( if ext_libs is None: ext_libs = [] lib = None - if tir_mod is not None: + if tir_mod is not None and len(tir_mod.get_global_vars()) > 0: lib = tvm.build( tir_mod, target=target, @@ -348,10 +348,10 @@ def _extract_attrs(mod: tvm.IRModule): ) -def _filter_tir(mod: tvm.IRModule) -> tvm.IRModule: - tir_mod = IRModule({}) - tir_mod = tir_mod.with_attrs(mod.attrs) - for gv in mod.get_global_vars(): - if isinstance(mod[gv], PrimFunc): - tir_mod[gv] = mod[gv] - return tir_mod +def _filter_tir(mod: tvm.IRModule) -> Optional[tvm.IRModule]: + tir_mod = {gvar: func for gvar, func in mod.functions.items() if isinstance(func, PrimFunc)} + + if tir_mod: + return IRModule(tir_mod, attrs=mod.attrs) + else: + return None diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 17cd5c49a1bf..4eca8aebd769 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -451,6 +451,7 @@ void CheckAndUpdateHostConsistency(Map* targets, Target* host) runtime::Module TIRToRuntime(const Map& inputs_arg, const Target& target_host_arg) { + CHECK(inputs_arg.size()) << "TIRToRuntime expects at least one IRModule as input."; std::vector device_modules; Map inputs = inputs_arg; Target target_host = target_host_arg; diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 4f28c4a47a69..180535231d98 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -59,6 +59,26 @@ def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7) +def test_vm_compile_without_target_arg(exec_mode): + """Like test_vm_compile_simple, but with a default target""" + + @tvm.script.ir_module + class mod: + @R.function + def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): + z = R.call_pure_packed( + "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + ) + return y + + ex = relax.build(mod, exec_mode=exec_mode) + inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) + vm = relax.VirtualMachine(ex, tvm.cpu()) + vm["foo"](inp1, inp2) + tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7) + + def test_match_check(exec_mode): @tvm.script.ir_module class TestMatchCheck: