diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 7df54da7ddc6a..61fd6820edd1b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -473,17 +473,29 @@ runtime::Module TIRToRuntime(const Map& inputs_arg, } auto host_target = [&]() -> Target { + // All targets that contain a kIsEntryFunc=True function Array targets_with_entry_func; + + // All targets that can run on the CPU and contain at least one + // function without kIsEntryFunc=False. Array cpu_targets; for (const auto& [target, mod] : split) { - bool contains_entry_func = std::any_of( - mod->functions.begin(), mod->functions.end(), - [](const auto& kv) { return kv.second->HasNonzeroAttr(tvm::tir::attr::kIsEntryFunc); }); + bool contains_entry_func = false; + bool may_contain_entry_func = false; + for (const auto& [gvar, func] : mod->functions) { + Optional is_entry_func = func->attrs.GetAttr(tvm::tir::attr::kIsEntryFunc); + if (is_entry_func.defined() && is_entry_func.value()->value) { + contains_entry_func = true; + } else if (!is_entry_func.defined()) { + may_contain_entry_func = true; + } + } + if (contains_entry_func) { targets_with_entry_func.push_back(target); } - if (target->HasKey("cpu")) { + if (may_contain_entry_func && target->HasKey("cpu")) { cpu_targets.push_back(target); } } diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 29ecaa4e8e435..1a121d1e388e6 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -84,7 +84,8 @@ class HostDeviceSplitter : public StmtMutator { PrimFunc device_func(params, body); device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, device_target}, {tir::attr::kNoAlias, Bool(true)}, - {tir::attr::kIsGlobalFunc, Bool(true)}}); + {tir::attr::kIsGlobalFunc, Bool(true)}, + {tir::attr::kIsEntryFunc, Bool(false)}}); (*device_mod_)->Add(kernel_symbol_global, device_func); Array args = params.Map([](const Var& var) -> PrimExpr { return var; }); diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py index 60bfb8a718d25..9b1095620ff60 100644 --- a/tests/python/unittest/test_tir_transform_split_host_device.py +++ b/tests/python/unittest/test_tir_transform_split_host_device.py @@ -122,6 +122,7 @@ def main_kernel(n: T.int32): "target": T.target("cuda"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n) @@ -162,6 +163,7 @@ def main_kernel(n: T.int32): "target": T.target("cuda"), "tir.noalias": T.bool(True), "tir.is_global_func": True, + "tir.is_entry_func": False, } ) T.evaluate(n)