diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 9e4e00ca47ed..2a6507b62a33 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -114,7 +114,7 @@ class ConstructorNode : public ExprNode { /*! \brief The datatype the constructor will construct. */ GlobalTypeVar belong_to; /*! \brief Index in the table of constructors (set when the type is registered). */ - mutable int tag = -1; + mutable int32_t tag = -1; ConstructorNode() {} diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 68b7ccab99c7..d05099f781ac 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -182,7 +182,7 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); class ConstructorValue; struct ConstructorValueNode : ValueNode { - int tag; + int32_t tag; tvm::Array fields; @@ -195,7 +195,7 @@ struct ConstructorValueNode : ValueNode { v->Visit("constructor", &constructor); } - TVM_DLL static ConstructorValue make(int tag, + TVM_DLL static ConstructorValue make(int32_t tag, tvm::Array fields, Constructor construtor = {}); diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 3966a6258a20..e5b20987c293 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -32,6 +32,7 @@ #include #include #include +#include namespace tvm { namespace relay { @@ -133,33 +134,40 @@ class ModuleNode : public RelayNode { TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str); /*! - * \brief Lookup a global function by its variable. + * \brief Look up a global function by its variable. * \param var The global var to lookup. * \returns The function named by the variable argument. */ TVM_DLL Function Lookup(const GlobalVar& var); /*! - * \brief Lookup a global function by its string name + * \brief Look up a global function by its string name * \param name The name of the function. * \returns The function named by the argument. */ TVM_DLL Function Lookup(const std::string& name); /*! - * \brief Lookup a global type definition by its variable. + * \brief Look up a global type definition by its variable. * \param var The var of the global type definition. * \return The type definition. */ TVM_DLL TypeData LookupDef(const GlobalTypeVar& var); /*! - * \brief Lookup a global type definition by its name. + * \brief Look up a global type definition by its name. * \param var The name of the global type definition. * \return The type definition. */ TVM_DLL TypeData LookupDef(const std::string& var); + /*! + * \brief Look up a constructor by its tag. + * \param tag The tag for the constructor. + * \return The constructor object. + */ + TVM_DLL Constructor LookupTag(const int32_t tag); + /*! * \brief Update the functions inside this environment by * functions in another environment. @@ -185,6 +193,9 @@ class ModuleNode : public RelayNode { TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); private: + /*! \brief Helper function for registering a typedef's constructors */ + void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type); + /*! \brief A map from string names to global variables that * ensures global uniqueness. */ @@ -194,6 +205,11 @@ class ModuleNode : public RelayNode { * that ensures global uniqueness. */ tvm::Map global_type_var_map_; + + /*! \brief A map from constructor tags to constructor objects + * for convenient access + */ + std::unordered_map constructor_tag_map_; }; struct Module : public NodeRef { diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index ea25b970f87f..23d0a8481f6c 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -113,17 +113,18 @@ def __init__(self, value): _make.RefValue, value) -def _arg_to_ast(arg): +def _arg_to_ast(mod, arg): if isinstance(arg, TensorValue): return Constant(arg.data.copyto(nd.cpu(0))) elif isinstance(arg, TupleValue): - return Tuple([_arg_to_ast(field) for field in arg.fields]) + return Tuple([_arg_to_ast(mod, field) for field in arg.fields]) elif isinstance(arg, tuple): - return Tuple([_arg_to_ast(field) for field in arg]) + return Tuple([_arg_to_ast(mod, field) for field in arg]) elif isinstance(arg, RefValue): - return RefCreate(_arg_to_ast(arg.value)) + return RefCreate(_arg_to_ast(mod, arg.value)) elif isinstance(arg, ConstructorValue): - return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields]) + return Call(mod.get_constructor(arg.tag), + [_arg_to_ast(mod, field) for field in arg.fields]) elif isinstance(arg, np.ndarray): return Constant(nd.array(arg)) elif isinstance(arg, Constant): @@ -228,7 +229,7 @@ def evaluate(self, expr, binds=None): if binds: scope_builder = ScopeBuilder() for key, value in binds.items(): - scope_builder.let(key, _arg_to_ast(value)) + scope_builder.let(key, _arg_to_ast(self.mod, value)) scope_builder.ret(expr) expr = scope_builder.get() @@ -294,7 +295,7 @@ def _interp_wrapper(*args, **kwargs): relay_args = [] for arg in args: - relay_args.append(_arg_to_ast(arg)) + relay_args.append(_arg_to_ast(self.mod, arg)) if isinstance(expr, GlobalVar): func = self.mod[expr] diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 138dfa882215..f8bebc9e358c 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -165,6 +165,25 @@ def get_global_type_var(self, name): """ return _module.Module_GetGlobalTypeVar(self, name) + def get_constructor(self, tag): + """Look up an ADT constructor by tag. + + Parameters + ---------- + tag: int + The tag for a constructor. + + Returns + ------- + constructor: Constructor + The constructor associated with the given tag, + + Raises + ------ + tvm.TVMError if the corresponding constructor cannot be found. + """ + return _module.Module_LookupTag(self, tag) + @staticmethod def from_expr(expr): return _module.Module_FromExpr(expr) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 1cc81d5174a5..12a14a5e7bf4 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -103,7 +103,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "RefValueNode(" << node->value << ")"; }); -ConstructorValue ConstructorValueNode::make(int tag, +ConstructorValue ConstructorValueNode::make(int32_t tag, tvm::Array fields, Constructor constructor) { NodePtr n = make_node(); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 6b5fee82af89..44912f9752cd 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -52,6 +52,7 @@ Module ModuleNode::make(tvm::Map global_funcs, CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint)) << "Duplicate global type definition name " << kv.first->var->name_hint; n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first); + n->RegisterConstructors(kv.first, kv.second); } return Module(n); @@ -110,15 +111,25 @@ void ModuleNode::Add(const GlobalVar& var, AddUnchecked(var, checked_func); } +void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) { + // We hash the global type var name to use as a globally unique prefix for tags. + // The hash will be used as the most significant byte of the tag, with the index of + // the constructor in the less significant bytes + size_t hash = std::hash()(var->var->name_hint); + int32_t prefix = static_cast(hash & 0xff) << 24; + for (size_t i = 0; i < type->constructors.size(); ++i) { + type->constructors[i]->tag = prefix | static_cast(i); + constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i]; + } +} + void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { this->type_definitions.Set(var, type); // set global type var map CHECK(!global_type_var_map_.count(var->var->name_hint)) << "Duplicate global type definition name " << var->var->name_hint; global_type_var_map_.Set(var->var->name_hint, var); - for (size_t i = 0; i < type->constructors.size(); ++i) { - type->constructors[i]->tag = i; - } + RegisterConstructors(var, type); // need to kind check at the end because the check can look up // a definition potentially @@ -161,6 +172,13 @@ TypeData ModuleNode::LookupDef(const std::string& name) { return this->LookupDef(id); } +Constructor ModuleNode::LookupTag(const int32_t tag) { + auto it = constructor_tag_map_.find(tag); + CHECK(it != constructor_tag_map_.end()) + << "There is no constructor with the tag " << tag; + return (*it).second; +} + void ModuleNode::Update(const Module& mod) { for (auto pair : mod->functions) { this->Update(pair.first, pair.second); @@ -219,6 +237,11 @@ TVM_REGISTER_API("relay._module.Module_LookupDef_str") return mod->LookupDef(var); }); +TVM_REGISTER_API("relay._module.Module_LookupTag") +.set_body_typed([](Module mod, int32_t tag) { + return mod->LookupTag(tag); + }); + TVM_REGISTER_API("relay._module.Module_FromExpr") .set_body_typed([](Expr e) { return ModuleNode::FromExpr(e); diff --git a/tests/python/relay/test_ir_module.py b/tests/python/relay/test_ir_module.py new file mode 100644 index 000000000000..72a92c8697fc --- /dev/null +++ b/tests/python/relay/test_ir_module.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for module functionality.""" +import tvm +from tvm import relay +from tvm.relay import Module +from tvm.relay.prelude import Prelude +from tvm.relay.testing import add_nat_definitions + +def constructor_list(p): + return [p.nil, p.cons, p.rose, p.some, p.none, p.z, p.s] + + +def adt_list(p): + return [p.nat, p.l, p.optional, p.tree] + + +def test_constructor_tag_round_trip(): + mod1 = Module() + p1 = Prelude(mod1) + add_nat_definitions(p1) + mod2 = Module() + p2 = Prelude(mod2) + add_nat_definitions(p2) + + # ensure hashes match across modules + ctors1 = constructor_list(p1) + ctors2 = constructor_list(p2) + + for i in range(len(ctors1)): + tag = ctors1[i].tag + ctor = mod2.get_constructor(tag) + assert ctor == ctors2[i] + assert ctor.name_hint == ctors1[i].name_hint + + +def test_constructor_tag_differences(): + # ensure that if we have the type data for a given ADT, the tags + # for the constructors of the *same ADT* are simple offsets from + # each other + mod = Module() + p = Prelude(mod) + add_nat_definitions(p) + + adts = adt_list(p) + for adt in adts: + data = mod[adt] + for i in range(len(data.constructors) - 1): + ctor1 = data.constructors[i] + ctor2 = data.constructors[i + 1] + assert ctor2.tag - ctor1.tag == 1 + # make sure there is something present at the MSB + assert ctor1.tag - i != 0 + assert ctor2.tag - (i + 1) != 0