Skip to content

Commit

Permalink
fix (#3417)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame authored and tqchen committed Jun 24, 2019
1 parent 311434e commit 25bad44
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 20 deletions.
12 changes: 6 additions & 6 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,42 +123,42 @@ class ModuleNode : public RelayNode {
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
TVM_DLL GlobalVar GetGlobalVar(const std::string& str);
TVM_DLL GlobalVar GetGlobalVar(const std::string& str) const;

/*!
* \brief Look up a global function by its name.
* \param str The unique string specifying the global variable.
* \returns The global variable.
*/
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str);
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;

/*!
* \brief Lookup a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
TVM_DLL Function Lookup(const GlobalVar& var);
TVM_DLL Function Lookup(const GlobalVar& var) const;

/*!
* \brief Lookup a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
TVM_DLL Function Lookup(const std::string& name);
TVM_DLL Function Lookup(const std::string& name) const;

/*!
* \brief Lookup a global type definition by its variable.
* \param var The var of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var);
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const;

/*!
* \brief Lookup a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const std::string& var);
TVM_DLL TypeData LookupDef(const std::string& var) const;

/*!
* \brief Update the functions inside this environment by
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ struct LambdaLifter : ExprMutator {
CHECK(lifted_func.defined());

auto name = GenerateName(lifted_func);
auto global = module_->GetGlobalVar(name);
auto global = GlobalVarNode::make(name);

// Add the lifted function to the module.
module_->Add(global, lifted_func);
Expand Down
22 changes: 9 additions & 13 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,11 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
return Module(n);
}

GlobalVar ModuleNode::GetGlobalVar(const std::string& name) {
GlobalVar ModuleNode::GetGlobalVar(const std::string& name) const {
auto it = global_var_map_.find(name);
if (it == global_var_map_.end()) {
auto gvar = GlobalVarNode::make(name);
global_var_map_.Set(name, gvar);
return gvar;
} else {
return (*it).second;
}
CHECK(it != global_var_map_.end())
<< "Cannot find global var " << name << " in the Module";
return (*it).second;
}

void ModuleNode::AddUnchecked(const GlobalVar& var,
Expand All @@ -84,7 +80,7 @@ void ModuleNode::AddUnchecked(const GlobalVar& var,
global_var_map_.Set(var->name_hint, var);
}

GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) {
GlobalTypeVar ModuleNode::GetGlobalTypeVar(const std::string& name) const {
auto it = global_type_var_map_.find(name);
CHECK(it != global_type_var_map_.end())
<< "Cannot find global type var " << name << " in the Module";
Expand Down Expand Up @@ -137,26 +133,26 @@ void ModuleNode::Remove(const GlobalVar& var) {
gvar_node->data.erase(var->name_hint);
}

Function ModuleNode::Lookup(const GlobalVar& var) {
Function ModuleNode::Lookup(const GlobalVar& var) const {
auto it = functions.find(var);
CHECK(it != functions.end())
<< "There is no definition of " << var->name_hint;
return (*it).second;
}

Function ModuleNode::Lookup(const std::string& name) {
Function ModuleNode::Lookup(const std::string& name) const {
GlobalVar id = this->GetGlobalVar(name);
return this->Lookup(id);
}

TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) {
TypeData ModuleNode::LookupDef(const GlobalTypeVar& var) const {
auto it = type_definitions.find(var);
CHECK(it != type_definitions.end())
<< "There is no definition of " << var->var->name_hint;
return (*it).second;
}

TypeData ModuleNode::LookupDef(const std::string& name) {
TypeData ModuleNode::LookupDef(const std::string& name) const {
GlobalTypeVar id = this->GetGlobalTypeVar(name);
return this->LookupDef(id);
}
Expand Down

0 comments on commit 25bad44

Please sign in to comment.