From 8314ee1d86ec7edd4eedc7a5e538ff651e12132b Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 13 Jun 2019 16:44:03 -0700 Subject: [PATCH 1/8] Use hash of ADT name and constructor idx to generate tag, add reverse mapping to module and use where appropriate --- include/tvm/relay/adt.h | 2 +- include/tvm/relay/interpreter.h | 4 +-- include/tvm/relay/module.h | 23 ++++++++++--- python/tvm/relay/backend/interpreter.py | 44 ++++++++++++------------- python/tvm/relay/module.py | 19 +++++++++++ src/relay/backend/interpreter.cc | 2 +- src/relay/ir/module.cc | 25 ++++++++++++-- 7 files changed, 86 insertions(+), 33 deletions(-) diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 9e4e00ca47ed..61cd64a4fab7 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -114,7 +114,7 @@ class ConstructorNode : public ExprNode { /*! \brief The datatype the constructor will construct. */ GlobalTypeVar belong_to; /*! \brief Index in the table of constructors (set when the type is registered). */ - mutable int tag = -1; + mutable int64_t tag = -1; ConstructorNode() {} diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 68b7ccab99c7..0cb38ab5b1bf 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -182,7 +182,7 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); class ConstructorValue; struct ConstructorValueNode : ValueNode { - int tag; + int64_t tag; tvm::Array fields; @@ -195,7 +195,7 @@ struct ConstructorValueNode : ValueNode { v->Visit("constructor", &constructor); } - TVM_DLL static ConstructorValue make(int tag, + TVM_DLL static ConstructorValue make(int64_t tag, tvm::Array fields, Constructor construtor = {}); diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index 3966a6258a20..bfc32cf557bc 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -133,33 +133,40 @@ class ModuleNode : public RelayNode { TVM_DLL GlobalTypeVar GetGlobalTypeVar(const std::string& str); /*! - * \brief Lookup a global function by its variable. + * \brief Look up a global function by its variable. * \param var The global var to lookup. * \returns The function named by the variable argument. */ TVM_DLL Function Lookup(const GlobalVar& var); /*! - * \brief Lookup a global function by its string name + * \brief Look up a global function by its string name * \param name The name of the function. * \returns The function named by the argument. */ TVM_DLL Function Lookup(const std::string& name); /*! - * \brief Lookup a global type definition by its variable. + * \brief Look up a global type definition by its variable. * \param var The var of the global type definition. * \return The type definition. */ TVM_DLL TypeData LookupDef(const GlobalTypeVar& var); /*! - * \brief Lookup a global type definition by its name. + * \brief Look up a global type definition by its name. * \param var The name of the global type definition. * \return The type definition. */ TVM_DLL TypeData LookupDef(const std::string& var); + /*! + * \brief Look up a constructor by its tag. + * \param tag The tag for the constructor. + * \return The constructor object. + */ + TVM_DLL Constructor LookupTag(const size_t tag); + /*! * \brief Update the functions inside this environment by * functions in another environment. @@ -185,6 +192,9 @@ class ModuleNode : public RelayNode { TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node); private: + /*! \brief Helper function for registering a typedef's constructors */ + void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type); + /*! \brief A map from string names to global variables that * ensures global uniqueness. */ @@ -194,6 +204,11 @@ class ModuleNode : public RelayNode { * that ensures global uniqueness. */ tvm::Map global_type_var_map_; + + /*! \brief A map from constructor tags to constructor objects + * for convenient access + */ + std::unordered_map constructor_tag_map_; }; struct Module : public NodeRef { diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index ea25b970f87f..d65392e7efb2 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -112,26 +112,6 @@ def __init__(self, value): self.__init_handle_by_constructor__( _make.RefValue, value) - -def _arg_to_ast(arg): - if isinstance(arg, TensorValue): - return Constant(arg.data.copyto(nd.cpu(0))) - elif isinstance(arg, TupleValue): - return Tuple([_arg_to_ast(field) for field in arg.fields]) - elif isinstance(arg, tuple): - return Tuple([_arg_to_ast(field) for field in arg]) - elif isinstance(arg, RefValue): - return RefCreate(_arg_to_ast(arg.value)) - elif isinstance(arg, ConstructorValue): - return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields]) - elif isinstance(arg, np.ndarray): - return Constant(nd.array(arg)) - elif isinstance(arg, Constant): - return arg - else: - return const(arg) - - class Executor(object): """An abstract interface for executing Relay programs.""" @@ -228,7 +208,7 @@ def evaluate(self, expr, binds=None): if binds: scope_builder = ScopeBuilder() for key, value in binds.items(): - scope_builder.let(key, _arg_to_ast(value)) + scope_builder.let(key, self._arg_to_ast(value)) scope_builder.ret(expr) expr = scope_builder.get() @@ -264,6 +244,26 @@ def __init__(self, mod, ctx, target): self.target = target self._intrp = _backend.CreateInterpreter(mod, ctx, target) + def _arg_to_ast(self, arg): + if isinstance(arg, TensorValue): + return Constant(arg.data.copyto(nd.cpu(0))) + elif isinstance(arg, TupleValue): + return Tuple([self._arg_to_ast(field) for field in arg.fields]) + elif isinstance(arg, tuple): + return Tuple([self._arg_to_ast(field) for field in arg]) + elif isinstance(arg, RefValue): + return RefCreate(self._arg_to_ast(arg.value)) + elif isinstance(arg, ConstructorValue): + return Call(self.mod.get_constructor(arg.tag), + [self._arg_to_ast(field) for field in arg.fields]) + elif isinstance(arg, np.ndarray): + return Constant(nd.array(arg)) + elif isinstance(arg, Constant): + return arg + else: + return const(arg) + + def optimize(self, expr): """Optimize an expr. @@ -294,7 +294,7 @@ def _interp_wrapper(*args, **kwargs): relay_args = [] for arg in args: - relay_args.append(_arg_to_ast(arg)) + relay_args.append(self._arg_to_ast(arg)) if isinstance(expr, GlobalVar): func = self.mod[expr] diff --git a/python/tvm/relay/module.py b/python/tvm/relay/module.py index 138dfa882215..f8bebc9e358c 100644 --- a/python/tvm/relay/module.py +++ b/python/tvm/relay/module.py @@ -165,6 +165,25 @@ def get_global_type_var(self, name): """ return _module.Module_GetGlobalTypeVar(self, name) + def get_constructor(self, tag): + """Look up an ADT constructor by tag. + + Parameters + ---------- + tag: int + The tag for a constructor. + + Returns + ------- + constructor: Constructor + The constructor associated with the given tag, + + Raises + ------ + tvm.TVMError if the corresponding constructor cannot be found. + """ + return _module.Module_LookupTag(self, tag) + @staticmethod def from_expr(expr): return _module.Module_FromExpr(expr) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 1cc81d5174a5..810d692b7f30 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -103,7 +103,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "RefValueNode(" << node->value << ")"; }); -ConstructorValue ConstructorValueNode::make(int tag, +ConstructorValue ConstructorValueNode::make(int64_t tag, tvm::Array fields, Constructor constructor) { NodePtr n = make_node(); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 6b5fee82af89..82943dff6468 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -52,6 +52,7 @@ Module ModuleNode::make(tvm::Map global_funcs, CHECK(!n->global_type_var_map_.count(kv.first->var->name_hint)) << "Duplicate global type definition name " << kv.first->var->name_hint; n->global_type_var_map_.Set(kv.first->var->name_hint, kv.first); + n->RegisterConstructors(kv.first, kv.second); } return Module(n); @@ -110,15 +111,21 @@ void ModuleNode::Add(const GlobalVar& var, AddUnchecked(var, checked_func); } +void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) { + for (size_t i = 0; i < type->constructors.size(); ++i) { + type->constructors[i]->tag = static_cast(dmlc::HashCombine(std::hash()(var->var->name_hint), + std::hash()(i))); + constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i]; + } +} + void ModuleNode::AddDef(const GlobalTypeVar& var, const TypeData& type) { this->type_definitions.Set(var, type); // set global type var map CHECK(!global_type_var_map_.count(var->var->name_hint)) << "Duplicate global type definition name " << var->var->name_hint; global_type_var_map_.Set(var->var->name_hint, var); - for (size_t i = 0; i < type->constructors.size(); ++i) { - type->constructors[i]->tag = i; - } + RegisterConstructors(var, type); // need to kind check at the end because the check can look up // a definition potentially @@ -161,6 +168,13 @@ TypeData ModuleNode::LookupDef(const std::string& name) { return this->LookupDef(id); } +Constructor ModuleNode::LookupTag(const size_t tag) { + auto it = constructor_tag_map_.find(tag); + CHECK(it != constructor_tag_map_.end()) + << "There is no constructor with the tag " << tag; + return (*it).second; +} + void ModuleNode::Update(const Module& mod) { for (auto pair : mod->functions) { this->Update(pair.first, pair.second); @@ -219,6 +233,11 @@ TVM_REGISTER_API("relay._module.Module_LookupDef_str") return mod->LookupDef(var); }); +TVM_REGISTER_API("relay._module.Module_LookupTag") +.set_body_typed([](Module mod, size_t tag) { + return mod->LookupTag(tag); + }); + TVM_REGISTER_API("relay._module.Module_FromExpr") .set_body_typed([](Expr e) { return ModuleNode::FromExpr(e); From 5ce6f635f10e502c0b63487b992ec51969f90b04 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 13 Jun 2019 16:52:23 -0700 Subject: [PATCH 2/8] Lint and build fixes --- include/tvm/relay/module.h | 1 + src/relay/ir/module.cc | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index bfc32cf557bc..b6a3349f5351 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -32,6 +32,7 @@ #include #include #include +#include namespace tvm { namespace relay { diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 82943dff6468..b92115c8a003 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -113,8 +113,9 @@ void ModuleNode::Add(const GlobalVar& var, void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) { for (size_t i = 0; i < type->constructors.size(); ++i) { - type->constructors[i]->tag = static_cast(dmlc::HashCombine(std::hash()(var->var->name_hint), - std::hash()(i))); + size_t hash = dmlc::HashCombine(std::hash()(var->var->name_hint), + std::hash()(i)); + type->constructors[i]->tag = static_cast(hash); constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i]; } } From ac1ed8c6f11d5f1b87a1876fc38b1968f0d6c6ac Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 13 Jun 2019 16:59:13 -0700 Subject: [PATCH 3/8] Add round-tripping test for getting constructors by tag --- tests/python/relay/test_ir_module.py | 43 ++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/python/relay/test_ir_module.py diff --git a/tests/python/relay/test_ir_module.py b/tests/python/relay/test_ir_module.py new file mode 100644 index 000000000000..0d9c5cac59cd --- /dev/null +++ b/tests/python/relay/test_ir_module.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for module functionality.""" +import tvm +from tvm import relay +from tvm.relay import Module +from tvm.relay.prelude import Prelude +from tvm.relay.testing import add_nat_definitions + +def constructor_list(p): + return [p.nil, p.cons, p.rose, p.some, p.none, p.z, p.s] + + +def test_constructor_tag_round_trip(): + mod1 = Module() + p1 = Prelude(mod1) + add_nat_definitions(p1) + mod2 = Module() + p2 = Prelude(mod2) + add_nat_definitions(p2) + + # ensure hashes match across modules + ctors1 = constructor_list(p1) + ctors2 = constructor_list(p2) + + for i in range(len(ctors1)): + tag = ctors1[i].tag + ctor = mod2.get_constructor(tag) + assert ctor.name_hint == ctors1[i].name_hint From 484083466b3f5ce06a88aaa43343f659af871090 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 13 Jun 2019 17:33:33 -0700 Subject: [PATCH 4/8] Use int64_t everywhere for tags --- include/tvm/relay/module.h | 2 +- src/relay/ir/module.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index b6a3349f5351..e23e3b99de02 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -166,7 +166,7 @@ class ModuleNode : public RelayNode { * \param tag The tag for the constructor. * \return The constructor object. */ - TVM_DLL Constructor LookupTag(const size_t tag); + TVM_DLL Constructor LookupTag(const int64_t tag); /*! * \brief Update the functions inside this environment by diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index b92115c8a003..7c32a4b0fbbf 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -169,7 +169,7 @@ TypeData ModuleNode::LookupDef(const std::string& name) { return this->LookupDef(id); } -Constructor ModuleNode::LookupTag(const size_t tag) { +Constructor ModuleNode::LookupTag(const int64_t tag) { auto it = constructor_tag_map_.find(tag); CHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag; @@ -235,7 +235,7 @@ TVM_REGISTER_API("relay._module.Module_LookupDef_str") }); TVM_REGISTER_API("relay._module.Module_LookupTag") -.set_body_typed([](Module mod, size_t tag) { +.set_body_typed([](Module mod, int64_t tag) { return mod->LookupTag(tag); }); From 9e35c671a0ecd90a7de153cb7f5ec9bde1677f1e Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Thu, 13 Jun 2019 17:57:33 -0700 Subject: [PATCH 5/8] Add additional identity check --- tests/python/relay/test_ir_module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relay/test_ir_module.py b/tests/python/relay/test_ir_module.py index 0d9c5cac59cd..60a46224cef0 100644 --- a/tests/python/relay/test_ir_module.py +++ b/tests/python/relay/test_ir_module.py @@ -40,4 +40,5 @@ def test_constructor_tag_round_trip(): for i in range(len(ctors1)): tag = ctors1[i].tag ctor = mod2.get_constructor(tag) + assert ctor == ctors2[i] assert ctor.name_hint == ctors1[i].name_hint From 8737357b37c2ba53b169c333eb702f9aa0871dcc Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Fri, 14 Jun 2019 13:03:31 -0700 Subject: [PATCH 6/8] Bring out _arg_to_ast again --- python/tvm/relay/backend/interpreter.py | 45 +++++++++++++------------ 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index d65392e7efb2..23d0a8481f6c 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -112,6 +112,27 @@ def __init__(self, value): self.__init_handle_by_constructor__( _make.RefValue, value) + +def _arg_to_ast(mod, arg): + if isinstance(arg, TensorValue): + return Constant(arg.data.copyto(nd.cpu(0))) + elif isinstance(arg, TupleValue): + return Tuple([_arg_to_ast(mod, field) for field in arg.fields]) + elif isinstance(arg, tuple): + return Tuple([_arg_to_ast(mod, field) for field in arg]) + elif isinstance(arg, RefValue): + return RefCreate(_arg_to_ast(mod, arg.value)) + elif isinstance(arg, ConstructorValue): + return Call(mod.get_constructor(arg.tag), + [_arg_to_ast(mod, field) for field in arg.fields]) + elif isinstance(arg, np.ndarray): + return Constant(nd.array(arg)) + elif isinstance(arg, Constant): + return arg + else: + return const(arg) + + class Executor(object): """An abstract interface for executing Relay programs.""" @@ -208,7 +229,7 @@ def evaluate(self, expr, binds=None): if binds: scope_builder = ScopeBuilder() for key, value in binds.items(): - scope_builder.let(key, self._arg_to_ast(value)) + scope_builder.let(key, _arg_to_ast(self.mod, value)) scope_builder.ret(expr) expr = scope_builder.get() @@ -244,26 +265,6 @@ def __init__(self, mod, ctx, target): self.target = target self._intrp = _backend.CreateInterpreter(mod, ctx, target) - def _arg_to_ast(self, arg): - if isinstance(arg, TensorValue): - return Constant(arg.data.copyto(nd.cpu(0))) - elif isinstance(arg, TupleValue): - return Tuple([self._arg_to_ast(field) for field in arg.fields]) - elif isinstance(arg, tuple): - return Tuple([self._arg_to_ast(field) for field in arg]) - elif isinstance(arg, RefValue): - return RefCreate(self._arg_to_ast(arg.value)) - elif isinstance(arg, ConstructorValue): - return Call(self.mod.get_constructor(arg.tag), - [self._arg_to_ast(field) for field in arg.fields]) - elif isinstance(arg, np.ndarray): - return Constant(nd.array(arg)) - elif isinstance(arg, Constant): - return arg - else: - return const(arg) - - def optimize(self, expr): """Optimize an expr. @@ -294,7 +295,7 @@ def _interp_wrapper(*args, **kwargs): relay_args = [] for arg in args: - relay_args.append(self._arg_to_ast(arg)) + relay_args.append(_arg_to_ast(self.mod, arg)) if isinstance(expr, GlobalVar): func = self.mod[expr] From 840ec935faa05d43bda995126c09b48f3aacf851 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 1 Jul 2019 13:15:29 -0700 Subject: [PATCH 7/8] Use 8-bit hash of GTV name as MSB of tag, index as LSB for more readable tags --- src/relay/ir/module.cc | 9 ++++++--- tests/python/relay/test_ir_module.py | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 7c32a4b0fbbf..78d0cce0ab46 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -112,10 +112,13 @@ void ModuleNode::Add(const GlobalVar& var, } void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& type) { + // We hash the global type var name to use as a globally unique prefix for tags. + // The hash will be used as the most significant byte of the tag, with the index of + // the constructor in the less significant bytes + size_t hash = std::hash()(var->var->name_hint); + int64_t prefix = static_cast(hash & 0xff) << 56; for (size_t i = 0; i < type->constructors.size(); ++i) { - size_t hash = dmlc::HashCombine(std::hash()(var->var->name_hint), - std::hash()(i)); - type->constructors[i]->tag = static_cast(hash); + type->constructors[i]->tag = prefix | static_cast(i); constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i]; } } diff --git a/tests/python/relay/test_ir_module.py b/tests/python/relay/test_ir_module.py index 60a46224cef0..72a92c8697fc 100644 --- a/tests/python/relay/test_ir_module.py +++ b/tests/python/relay/test_ir_module.py @@ -25,6 +25,10 @@ def constructor_list(p): return [p.nil, p.cons, p.rose, p.some, p.none, p.z, p.s] +def adt_list(p): + return [p.nat, p.l, p.optional, p.tree] + + def test_constructor_tag_round_trip(): mod1 = Module() p1 = Prelude(mod1) @@ -42,3 +46,23 @@ def test_constructor_tag_round_trip(): ctor = mod2.get_constructor(tag) assert ctor == ctors2[i] assert ctor.name_hint == ctors1[i].name_hint + + +def test_constructor_tag_differences(): + # ensure that if we have the type data for a given ADT, the tags + # for the constructors of the *same ADT* are simple offsets from + # each other + mod = Module() + p = Prelude(mod) + add_nat_definitions(p) + + adts = adt_list(p) + for adt in adts: + data = mod[adt] + for i in range(len(data.constructors) - 1): + ctor1 = data.constructors[i] + ctor2 = data.constructors[i + 1] + assert ctor2.tag - ctor1.tag == 1 + # make sure there is something present at the MSB + assert ctor1.tag - i != 0 + assert ctor2.tag - (i + 1) != 0 From 47ace40d51a9dcc9492c206f78c14dc7fbe26b20 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Mon, 1 Jul 2019 15:35:18 -0700 Subject: [PATCH 8/8] Use int32 instead of int64 for tag --- include/tvm/relay/adt.h | 2 +- include/tvm/relay/interpreter.h | 4 ++-- include/tvm/relay/module.h | 4 ++-- src/relay/backend/interpreter.cc | 2 +- src/relay/ir/module.cc | 8 ++++---- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/include/tvm/relay/adt.h b/include/tvm/relay/adt.h index 61cd64a4fab7..2a6507b62a33 100644 --- a/include/tvm/relay/adt.h +++ b/include/tvm/relay/adt.h @@ -114,7 +114,7 @@ class ConstructorNode : public ExprNode { /*! \brief The datatype the constructor will construct. */ GlobalTypeVar belong_to; /*! \brief Index in the table of constructors (set when the type is registered). */ - mutable int64_t tag = -1; + mutable int32_t tag = -1; ConstructorNode() {} diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 0cb38ab5b1bf..d05099f781ac 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -182,7 +182,7 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); class ConstructorValue; struct ConstructorValueNode : ValueNode { - int64_t tag; + int32_t tag; tvm::Array fields; @@ -195,7 +195,7 @@ struct ConstructorValueNode : ValueNode { v->Visit("constructor", &constructor); } - TVM_DLL static ConstructorValue make(int64_t tag, + TVM_DLL static ConstructorValue make(int32_t tag, tvm::Array fields, Constructor construtor = {}); diff --git a/include/tvm/relay/module.h b/include/tvm/relay/module.h index e23e3b99de02..e5b20987c293 100644 --- a/include/tvm/relay/module.h +++ b/include/tvm/relay/module.h @@ -166,7 +166,7 @@ class ModuleNode : public RelayNode { * \param tag The tag for the constructor. * \return The constructor object. */ - TVM_DLL Constructor LookupTag(const int64_t tag); + TVM_DLL Constructor LookupTag(const int32_t tag); /*! * \brief Update the functions inside this environment by @@ -209,7 +209,7 @@ class ModuleNode : public RelayNode { /*! \brief A map from constructor tags to constructor objects * for convenient access */ - std::unordered_map constructor_tag_map_; + std::unordered_map constructor_tag_map_; }; struct Module : public NodeRef { diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 810d692b7f30..12a14a5e7bf4 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -103,7 +103,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "RefValueNode(" << node->value << ")"; }); -ConstructorValue ConstructorValueNode::make(int64_t tag, +ConstructorValue ConstructorValueNode::make(int32_t tag, tvm::Array fields, Constructor constructor) { NodePtr n = make_node(); diff --git a/src/relay/ir/module.cc b/src/relay/ir/module.cc index 78d0cce0ab46..44912f9752cd 100644 --- a/src/relay/ir/module.cc +++ b/src/relay/ir/module.cc @@ -116,9 +116,9 @@ void ModuleNode::RegisterConstructors(const GlobalTypeVar& var, const TypeData& // The hash will be used as the most significant byte of the tag, with the index of // the constructor in the less significant bytes size_t hash = std::hash()(var->var->name_hint); - int64_t prefix = static_cast(hash & 0xff) << 56; + int32_t prefix = static_cast(hash & 0xff) << 24; for (size_t i = 0; i < type->constructors.size(); ++i) { - type->constructors[i]->tag = prefix | static_cast(i); + type->constructors[i]->tag = prefix | static_cast(i); constructor_tag_map_[type->constructors[i]->tag] = type->constructors[i]; } } @@ -172,7 +172,7 @@ TypeData ModuleNode::LookupDef(const std::string& name) { return this->LookupDef(id); } -Constructor ModuleNode::LookupTag(const int64_t tag) { +Constructor ModuleNode::LookupTag(const int32_t tag) { auto it = constructor_tag_map_.find(tag); CHECK(it != constructor_tag_map_.end()) << "There is no constructor with the tag " << tag; @@ -238,7 +238,7 @@ TVM_REGISTER_API("relay._module.Module_LookupDef_str") }); TVM_REGISTER_API("relay._module.Module_LookupTag") -.set_body_typed([](Module mod, int64_t tag) { +.set_body_typed([](Module mod, int32_t tag) { return mod->LookupTag(tag); });