Skip to content

Commit

Permalink
[Relay][Module] Make tags for ADT constructors and ConstructorValues …
Browse files Browse the repository at this point in the history
…more robust (apache#3369)

* Use hash of ADT name and constructor idx to generate tag, add reverse mapping to module and use where appropriate

* Lint and build fixes

* Add round-tripping test for getting constructors by tag

* Use int64_t everywhere for tags

* Add additional identity check

* Bring out _arg_to_ast again

* Use 8-bit hash of GTV name as MSB of tag, index as LSB for more readable tags

* Use int32 instead of int64 for tag
  • Loading branch information
slyubomirsky authored and Wei Chen committed Jul 11, 2019
1 parent f718cda commit 6d69e63
Show file tree
Hide file tree
Showing 8 changed files with 145 additions and 18 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class ConstructorNode : public ExprNode {
/*! \brief The datatype the constructor will construct. */
GlobalTypeVar belong_to;
/*! \brief Index in the table of constructors (set when the type is registered). */
mutable int tag = -1;
mutable int32_t tag = -1;

ConstructorNode() {}

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);
class ConstructorValue;

struct ConstructorValueNode : ValueNode {
int tag;
int32_t tag;

tvm::Array<Value> fields;

Expand All @@ -195,7 +195,7 @@ struct ConstructorValueNode : ValueNode {
v->Visit("constructor", &constructor);
}

TVM_DLL static ConstructorValue make(int tag,
TVM_DLL static ConstructorValue make(int32_t tag,
tvm::Array<Value> fields,
Constructor construtor = {});

Expand Down
24 changes: 20 additions & 4 deletions include/tvm/relay/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <tvm/relay/type.h>
#include <string>
#include <vector>
#include <unordered_map>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -133,33 +134,40 @@ class ModuleNode : public RelayNode {
TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str) const;

/*!
* \brief Lookup a global function by its variable.
* \brief Look up a global function by its variable.
* \param var The global var to lookup.
* \returns The function named by the variable argument.
*/
TVM_DLL Function Lookup(const GlobalVar& var) const;

/*!
* \brief Lookup a global function by its string name
* \brief Look up a global function by its string name
* \param name The name of the function.
* \returns The function named by the argument.
*/
TVM_DLL Function Lookup(const std::string& name) const;

/*!
* \brief Lookup a global type definition by its variable.
* \brief Look up a global type definition by its variable.
* \param var The var of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const GlobalTypeVar& var) const;

/*!
* \brief Lookup a global type definition by its name.
* \brief Look up a global type definition by its name.
* \param var The name of the global type definition.
* \return The type definition.
*/
TVM_DLL TypeData LookupDef(const std::string& var) const;

/*!
* \brief Look up a constructor by its tag.
* \param tag The tag for the constructor.
* \return The constructor object.
*/
TVM_DLL Constructor LookupTag(const int32_t tag);

/*!
* \brief Update the functions inside this environment by
* functions in another environment.
Expand All @@ -185,6 +193,9 @@ class ModuleNode : public RelayNode {
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);

private:
/*! \brief Helper function for registering a typedef's constructors */
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);

/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
Expand All @@ -194,6 +205,11 @@ class ModuleNode : public RelayNode {
* that ensures global uniqueness.
*/
tvm::Map<std::string, GlobalTypeVar> global_type_var_map_;

/*! \brief A map from constructor tags to constructor objects
* for convenient access
*/
std::unordered_map<int32_t, Constructor> constructor_tag_map_;
};

struct Module : public NodeRef {
Expand Down
15 changes: 8 additions & 7 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,18 @@ def __init__(self, value):
_make.RefValue, value)


def _arg_to_ast(arg):
def _arg_to_ast(mod, arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(field) for field in arg.fields])
return Tuple([_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, tuple):
return Tuple([_arg_to_ast(field) for field in arg])
return Tuple([_arg_to_ast(mod, field) for field in arg])
elif isinstance(arg, RefValue):
return RefCreate(_arg_to_ast(arg.value))
return RefCreate(_arg_to_ast(mod, arg.value))
elif isinstance(arg, ConstructorValue):
return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields])
return Call(mod.get_constructor(arg.tag),
[_arg_to_ast(mod, field) for field in arg.fields])
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
Expand Down Expand Up @@ -231,7 +232,7 @@ def evaluate(self, expr=None, binds=None):
if binds:
scope_builder = ScopeBuilder()
for key, value in binds.items():
scope_builder.let(key, _arg_to_ast(value))
scope_builder.let(key, _arg_to_ast(self.mod, value))
scope_builder.ret(expr)
expr = scope_builder.get()

Expand Down Expand Up @@ -294,7 +295,7 @@ def _interp_wrapper(*args, **kwargs):

relay_args = []
for arg in args:
relay_args.append(_arg_to_ast(arg))
relay_args.append(_arg_to_ast(self.mod, arg))

# Set the entry function for the module.
if expr is None:
Expand Down
19 changes: 19 additions & 0 deletions python/tvm/relay/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,25 @@ def get_global_type_var(self, name):
"""
return _module.Module_GetGlobalTypeVar(self, name)

def get_constructor(self, tag):
"""Look up an ADT constructor by tag.
Parameters
----------
tag: int
The tag for a constructor.
Returns
-------
constructor: Constructor
The constructor associated with the given tag,
Raises
------
tvm.TVMError if the corresponding constructor cannot be found.
"""
return _module.Module_LookupTag(self, tag)

@staticmethod
def from_expr(expr):
return _module.Module_FromExpr(expr)
2 changes: 1 addition & 1 deletion src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "RefValueNode(" << node->value << ")";
});

ConstructorValue ConstructorValueNode::make(int tag,
ConstructorValue ConstructorValueNode::make(int32_t tag,
tvm::Array<Value> fields,
Constructor constructor) {
NodePtr<ConstructorValueNode> n = make_node<ConstructorValueNode>();
Expand Down
29 changes: 26 additions & 3 deletions src/relay/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs,
CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint))
<< "Duplicate global type definition name " << kv.first->var->name_hint;
n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first);
n->RegisterConstructors(kv.first, kv.second);
}

return Module(n);
Expand Down Expand Up @@ -108,15 +109,25 @@ void ModuleNode::Add(const GlobalVar& var,
AddUnchecked(var, checked_func);
}

void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) {
// We hash the global type var name to use as a globally unique prefix for tags.
// The hash will be used as the most significant byte of the tag, with the index of
// the constructor in the less significant bytes
size_t hash = std::hash<std::string>()(var->var->name_hint);
int32_t prefix = static_cast<int32_t>(hash & 0xff) << 24;
for (size_t i = 0; i < type->constructors.size(); ++i) {
type->constructors[i]->tag = prefix | static_cast<int32_t>(i);
constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i];
}
}

void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) {
this->type_definitions.Set(var, type);
// set global type var map
CHECK(!global_type_var_map_.count(var->var->name_hint))
<< "Duplicate global type definition name " << var->var->name_hint;
global_type_var_map_.Set(var->var->name_hint, var);
for (size_t i = 0; i < type->constructors.size(); ++i) {
type->constructors[i]->tag = i;
}
RegisterConstructors(var, type);

// need to kind check at the end because the check can look up
// a definition potentially
Expand Down Expand Up @@ -159,6 +170,13 @@ TypeData ModuleNode::LookupDef(const std::string& name) const {
return this->LookupDef(id);
}

Constructor ModuleNode::LookupTag(const int32_t tag) {
auto it = constructor_tag_map_.find(tag);
CHECK(it != constructor_tag_map_.end())
<< "There is no constructor with the tag " << tag;
return (*it).second;
}

void ModuleNode::Update(const Module& mod) {
for (auto pair : mod->functions) {
this->Update(pair.first, pair.second);
Expand Down Expand Up @@ -236,6 +254,11 @@ TVM_REGISTER_API("relay._module.Module_LookupDef_str")
return mod->LookupDef(var);
});

TVM_REGISTER_API("relay._module.Module_LookupTag")
.set_body_typed<Constructor(Module, int32_t)>([](Module mod, int32_t tag) {
return mod->LookupTag(tag);
});

TVM_REGISTER_API("relay._module.Module_FromExpr")
.set_body_typed<Module(Expr)>([](Expr e) {
return ModuleNode::FromExpr(e);
Expand Down
68 changes: 68 additions & 0 deletions tests/python/relay/test_ir_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for module functionality."""
import tvm
from tvm import relay
from tvm.relay import Module
from tvm.relay.prelude import Prelude
from tvm.relay.testing import add_nat_definitions

def constructor_list(p):
return [p.nil, p.cons, p.rose, p.some, p.none, p.z, p.s]


def adt_list(p):
return [p.nat, p.l, p.optional, p.tree]


def test_constructor_tag_round_trip():
mod1 = Module()
p1 = Prelude(mod1)
add_nat_definitions(p1)
mod2 = Module()
p2 = Prelude(mod2)
add_nat_definitions(p2)

# ensure hashes match across modules
ctors1 = constructor_list(p1)
ctors2 = constructor_list(p2)

for i in range(len(ctors1)):
tag = ctors1[i].tag
ctor = mod2.get_constructor(tag)
assert ctor == ctors2[i]
assert ctor.name_hint == ctors1[i].name_hint


def test_constructor_tag_differences():
# ensure that if we have the type data for a given ADT, the tags
# for the constructors of the *same ADT* are simple offsets from
# each other
mod = Module()
p = Prelude(mod)
add_nat_definitions(p)

adts = adt_list(p)
for adt in adts:
data = mod[adt]
for i in range(len(data.constructors) - 1):
ctor1 = data.constructors[i]
ctor2 = data.constructors[i + 1]
assert ctor2.tag - ctor1.tag == 1
# make sure there is something present at the MSB
assert ctor1.tag - i != 0
assert ctor2.tag - (i + 1) != 0

0 comments on commit 6d69e63

Please sign in to comment.