From c84630be815e50739ec69c8ccb7a85cae1856046 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Thu, 6 May 2021 10:13:39 -0700 Subject: [PATCH] [BYOC] Remove ext params stored in metadata from params to avoid duplication (#7977) * Remove ext params stored in metadata from params to avoid duplication * Add test for duplicate params --- src/relay/backend/build_module.cc | 13 +++++++++++++ tests/python/relay/test_external_codegen.py | 2 ++ 2 files changed, 15 insertions(+) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 71f19a1c21bc..88faff22cd31 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -565,6 +565,19 @@ class RelayBuildModule : public runtime::ModuleNode { auto ext_mods = executor_codegen_->GetExternalModules(); ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost(), executor_codegen_->GetMetadata()); + // Remove external params which were stored in metadata module. + for (tvm::runtime::Module mod : ext_mods) { + auto pf_var = mod.GetFunction("get_const_vars"); + if (pf_var != nullptr) { + Array variables = pf_var(); + for (size_t i = 0; i < variables.size(); i++) { + auto it = ret_.params.find(variables[i].operator std::string()); + if (it != ret_.params.end()) { + ret_.params.erase(it); + } + } + } + } } private: diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index be92ef200c31..156abfc4c22a 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -352,6 +352,8 @@ def test_load_params_with_constants_in_ext_codegen(): mod = transform.PartitionGraph()(mod) graph_module = relay.build(mod, target="llvm", params=params) + # Params will be stored in metadata module. + assert len(graph_module.get_params()) == 0 lib = update_lib(graph_module.get_lib()) rt_mod = tvm.contrib.graph_executor.create(graph_module.get_graph_json(), lib, tvm.cpu(0)) rt_mod.load_params(runtime.save_param_dict(graph_module.get_params()))