diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 6d9f25296a5a..248e4c1c00b7 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -124,7 +124,7 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array ent entry_functions.insert(mod->GetGlobalVar(name)); } for (const auto& [gv, func] : mod->functions) { - if (func->GetLinkageType() == LinkageType::kExternal) { + if (func.as() || func->GetLinkageType() == LinkageType::kExternal) { entry_functions.insert(gv); } } diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 7b749d677880..c0a2d47b19f1 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -497,5 +497,15 @@ def main( verify(Input, Expected) +def test_extern_func(): + """DeadCodeElimination should retain the ExternFunc in the IRModule.""" + + builder = tvm.relax.BlockBuilder() + builder.add_func(tvm.relax.extern("extern_func"), "extern_func") + before = builder.get() + + verify(before, before) + + if __name__ == "__main__": tvm.testing.main()