Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Module] Make tags for ADT constructors and ConstructorValues more robust #3369

Merged
merged 8 commits into from
Jul 5, 2019
2 changes: 1 addition & 1 deletion include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class ConstructorNode : public ExprNode {
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int tag = -1;
mutable int64_t tag = -1;

ConstructorNode() {}

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
class ConstructorValue;

struct ConstructorValueNode : ValueNode {
int tag;
int64_t tag;

tvm::Array<Value> fields;

Expand All @@ -195,7 +195,7 @@ struct ConstructorValueNode : ValueNode {
v->Visit("constructor", &constructor);
}

TVM_DLL static ConstructorValue make(int tag,
TVM_DLL static ConstructorValue make(int64_t tag,
tvm::Array<Value> fields,
Constructor construtor = {});

Expand Down
24 changes: 20 additions & 4 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <tvm/relay/type.h>
#include <string>
#include <vector>
#include <unordered_map>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -133,33 +134,40 @@ class ModuleNode : public RelayNode {
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str);

/*!
* \brief Lookup a global function by its variable.
* \brief Look up a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
TVM_DLL Function Lookup(const GlobalVar& var);

/*!
* \brief Lookup a global function by its string name
* \brief Look up a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
TVM_DLL Function Lookup(const std::string& name);

/*!
* \brief Lookup a global type definition by its variable.
* \brief Look up a global type definition by its variable.
* \param var The var of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var);

/*!
* \brief Lookup a global type definition by its name.
* \brief Look up a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const std::string& var);

/*!
* \brief Look up a constructor by its tag.
* \param tag The tag for the constructor.
* \return The constructor object.
*/
TVM_DLL Constructor LookupTag(const int64_t tag);

/*!
* \brief Update the functions inside this environment by
* functions in another environment.
Expand All @@ -185,6 +193,9 @@ class ModuleNode : public RelayNode {
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);

private:
/*! \brief Helper function for registering a typedef's constructors */
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);

/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
Expand All @@ -194,6 +205,11 @@ class ModuleNode : public RelayNode {
* that ensures global uniqueness.
*/
tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;

/*! \brief A map from constructor tags to constructor objects
* for convenient access
*/
std::unordered_map<int64_t, Constructor> constructor_tag_map_;
};

struct Module : public NodeRef {
Expand Down
15 changes: 8 additions & 7 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,17 +113,18 @@ def __init__(self, value):
_make.RefValue, value)


def _arg_to_ast(arg):
def _arg_to_ast(mod, arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(field) for field in arg.fields])
return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, tuple):
return Tuple([_arg_to_ast(field) for field in arg])
return Tuple([_arg_to_ast(mod, field) for field in arg])
elif isinstance(arg, RefValue):
return RefCreate(_arg_to_ast(arg.value))
return RefCreate(_arg_to_ast(mod, arg.value))
elif isinstance(arg, ConstructorValue):
return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields])
return Call(mod.get_constructor(arg.tag),
[_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
Expand Down Expand Up @@ -228,7 +229,7 @@ def evaluate(self, expr, binds=None):
if binds:
scope_builder = ScopeBuilder()
for key, value in binds.items():
scope_builder.let(key, _arg_to_ast(value))
scope_builder.let(key, _arg_to_ast(self.mod, value))
scope_builder.ret(expr)
expr = scope_builder.get()

Expand Down Expand Up @@ -294,7 +295,7 @@ def _interp_wrapper(*args, **kwargs):

relay_args = []
for arg in args:
relay_args.append(_arg_to_ast(arg))
relay_args.append(_arg_to_ast(self.mod, arg))

if isinstance(expr, GlobalVar):
func = self.mod[expr]
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/relay/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,25 @@ def get_global_type_var(self, name):
"""
return _module.Module_GetGlobalTypeVar(self, name)

def get_constructor(self, tag):
"""Look up an ADT constructor by tag.

Parameters
----------
tag: int
The tag for a constructor.

Returns
-------
constructor: Constructor
The constructor associated with the given tag.

Raises
------
tvm.TVMError if the corresponding constructor cannot be found.
"""
return _module.Module_LookupTag(self, tag)

@staticmethod
def from_expr(expr):
return _module.Module_FromExpr(expr)
2 changes: 1 addition & 1 deletion src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "RefValueNode(" << node->value << ")";
});

ConstructorValue ConstructorValueNode::make(int tag,
ConstructorValue ConstructorValueNode::make(int64_t tag,
tvm::Array<Value> fields,
Constructor constructor) {
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
Expand Down
26 changes: 23 additions & 3 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint))
<< "Duplicate global type definition name " << kv.first->var->name_hint;
n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
n->RegisterConstructors(kv.first, kv.second);
}

return Module(n);
Expand Down Expand Up @@ -110,15 +111,22 @@ void ModuleNode::Add(const GlobalVar& var,
AddUnchecked(var, checked_func);
}

void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a hash value as the tag makes the opcode not very intuitive to understand. How about maintaining a contiguous range [0, num_constructors) for each ADT type?

Then for the following code:

type t = | Alice | Bob | Charlie | David

let test v =
  match v with
  | Alice   -> 1
  | Bob     -> 2
  | Charlie -> 3
  | David   -> 4

We can generate opcode like:

(switch* v
          case int 0: 1
          case int 1: 2
          case int 2: 3
          case int 3: 4)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good point and I agree hashes are not ideal. The reason I went with hashes is to ensure that the tag will be determined completely by the name of the global type var and the index of the constructor in the type data and thus will not be affected, for example, by the order in which types are defined. It also means that tags will not collide across different ADTs. If hashes are undesirable, we can take the approach (as I alluded to before) of requiring a type check and only looking at the index of the constructor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I vote for eliminating the hash.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the constructor value carry information about its type, then? The case that is most difficult is the one in arg_to_ast: given only the constructor value, how do you work backwards?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for the VM, we might need to discard type information in Object, as we already drop the type information in the bytecode we generate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll see what I can do about arg_to_ast then

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure, though, that it will be possible to drop type information for ADT values without ensuring that all tags be globally unique (and I'm not sure of any better way to do that than hashing).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree there is no other way. And it seems like the interpreter needs to map runtime values back to AST nodes to support user-provided bindings.

How about making an 8-bit hash instead and putting it in the most significant byte of the tag? Then, when we debug, we can easily tell the variant from the least significant byte. What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting suggestion, definitely more readable too. I will look into implementing that, thanks.

for (size_t i = 0; i < type->constructors.size(); ++i) {
size_t hash = dmlc::HashCombine(std::hash<std::string>()(var->var->name_hint),
std::hash<size_t>()(i));
type->constructors[i]->tag = static_cast<int64_t>(hash);
constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i];
}
}

void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
this->type_definitions.Set(var, type);
// set global type var map
CHECK(!global_type_var_map_.count(var->var->name_hint))
<< "Duplicate global type definition name " << var->var->name_hint;
global_type_var_map_.Set(var->var->name_hint, var);
for (size_t i = 0; i < type->constructors.size(); ++i) {
type->constructors[i]->tag = i;
}
RegisterConstructors(var, type);

// need to kind check at the end because the check can look up
// a definition potentially
Expand Down Expand Up @@ -161,6 +169,13 @@ TypeData ModuleNode::LookupDef(const std::string& name) {
return this->LookupDef(id);
}

// Find the constructor stored in constructor_tag_map_ under `tag`.
// The CHECK below fails with a descriptive message when the tag is absent.
Constructor ModuleNode::LookupTag(const int64_t tag) {
  const auto it = constructor_tag_map_.find(tag);
  // The condition text is kept as-is because CHECK embeds it in the failure message.
  CHECK(it != constructor_tag_map_.end())
    << "There is no constructor with the tag " << tag;
  return it->second;
}

void ModuleNode::Update(const Module& mod) {
for (auto pair : mod->functions) {
this->Update(pair.first, pair.second);
Expand Down Expand Up @@ -219,6 +234,11 @@ TVM_REGISTER_API("relay._module.Module_LookupDef_str")
return mod->LookupDef(var);
});

// Register ModuleNode::LookupTag under the name "relay._module.Module_LookupTag";
// the Python method Module.get_constructor calls it through _module.Module_LookupTag.
TVM_REGISTER_API("relay._module.Module_LookupTag")
.set_body_typed<Constructor(Module, int64_t)>([](Module mod, int64_t tag) {
return mod->LookupTag(tag);
});

TVM_REGISTER_API("relay._module.Module_FromExpr")
.set_body_typed<Module(Expr)>([](Expr e) {
return ModuleNode::FromExpr(e);
Expand Down
44 changes: 44 additions & 0 deletions tests/python/relay/test_ir_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for module functionality."""
import tvm
from tvm import relay
from tvm.relay import Module
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions

def constructor_list(p):
    """Return the constructors read from ``p`` in a fixed order.

    The attributes are read in the order nil, cons, rose, some, none, z, s,
    so two lists built from different objects line up index by index.
    """
    attribute_names = ("nil", "cons", "rose", "some", "none", "z", "s")
    return [getattr(p, attr) for attr in attribute_names]


def test_constructor_tag_round_trip():
    """Tags taken from one module must resolve to the matching constructor in another.

    Two modules are populated identically (Prelude plus the nat definitions).
    Each constructor tag from the first module is looked up in the second, and
    the result must equal the second module's constructor at the same position
    and share the first constructor's name hint.
    """
    def build_module_with_prelude():
        # Same setup sequence for both modules: Module -> Prelude -> nat defs.
        module = Module()
        prelude = Prelude(module)
        add_nat_definitions(prelude)
        return module, prelude

    mod1, p1 = build_module_with_prelude()
    mod2, p2 = build_module_with_prelude()

    # ensure hashes match across modules
    for first_ctor, second_ctor in zip(constructor_list(p1), constructor_list(p2)):
        resolved = mod2.get_constructor(first_ctor.tag)
        assert resolved == second_ctor
        assert resolved.name_hint == first_ctor.name_hint