diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 3966a6258a20..638f75968fd3 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -123,42 +123,42 @@ class ModuleNode : public RelayNode { * \param str The unique string specifying the global variable. * \returns The global variable. */ - TVM_DLL GlobalVar GetGlobalVar(const std::string& str); + TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const; /*! * \brief Look up a global function by its name. * \param str The unique string specifying the global variable. * \returns The global variable. */ - TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str); + TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const; /*! * \brief Lookup a global function by its variable. * \param var The global var to lookup. * \returns The function named by the variable argument. */ - TVM_DLL Function Lookup(const GlobalVar& var); + TVM_DLL Function Lookup(const GlobalVar& var) const; /*! * \brief Lookup a global function by its string name * \param name The name of the function. * \returns The function named by the argument. */ - TVM_DLL Function Lookup(const std::string& name); + TVM_DLL Function Lookup(const std::string& name) const; /*! * \brief Lookup a global type definition by its variable. * \param var The var of the global type definition. * \return The type definition. */ - TVM_DLL TypeData LookupDef(const GlobalTypeVar& var); + TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const; /*! * \brief Lookup a global type definition by its name. * \param var The name of the global type definition. * \return The type definition. */ - TVM_DLL TypeData LookupDef(const std::string& var); + TVM_DLL TypeData LookupDef(const std::string& var) const; /*! * \brief Update the functions inside this environment by diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index a55a9273d078..668c024a8d55 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -112,7 +112,7 @@ struct LambdaLifter : ExprMutator { CHECK(lifted_func.defined()); auto name = GenerateName(lifted_func); - auto global = module_->GetGlobalVar(name); + auto global = GlobalVarNode::make(name); // Add the lifted function to the module. module_->Add(global, lifted_func); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 6b5fee82af89..58f614a3cc77 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -57,15 +57,11 @@ Module ModuleNode::make(tvm::Map global_funcs, return Module(n); } -GlobalVar ModuleNode::GetGlobalVar(const std::string& name) { +GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const { auto it = global_var_map_.find(name); - if (it == global_var_map_.end()) { - auto gvar = GlobalVarNode::make(name); - global_var_map_.Set(name, gvar); - return gvar; - } else { - return (*it).second; - } + CHECK(it != global_var_map_.end()) + << "Cannot find global var " << name << " in the Module"; + return (*it).second; } void ModuleNode::AddUnchecked(const GlobalVar& var, @@ -84,7 +80,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var, global_var_map_.Set(var->name_hint, var); } -GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) { +GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const { auto it = global_type_var_map_.find(name); CHECK(it != global_type_var_map_.end()) << "Cannot find global type var " << name << " in the Module"; @@ -137,26 +133,26 @@ void ModuleNode::Remove(const GlobalVar& var) { gvar_node->data.erase(var->name_hint); } -Function ModuleNode::Lookup(const GlobalVar& var) { +Function ModuleNode::Lookup(const GlobalVar& var) const { auto it = functions.find(var); CHECK(it != functions.end()) << "There is no definition of " << var->name_hint; return (*it).second; } -Function ModuleNode::Lookup(const std::string& name) { +Function ModuleNode::Lookup(const std::string& name) const { GlobalVar id = this->GetGlobalVar(name); return this->Lookup(id); } -TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) { +TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const { auto it = type_definitions.find(var); CHECK(it != type_definitions.end()) << "There is no definition of " << var->var->name_hint; return (*it).second; } -TypeData ModuleNode::LookupDef(const std::string& name) { +TypeData ModuleNode::LookupDef(const std::string& name) const { GlobalTypeVar id = this->GetGlobalTypeVar(name); return this->LookupDef(id); }