Skip to content

Commit

Permalink
[Relay][VM] Relay VM serialization (#3647)
Browse files Browse the repository at this point in the history
* relay vm serialization

* fix lint

* load params, fix stream

* lint

* fix typo
  • Loading branch information
zhiics authored and jroesch committed Jul 31, 2019
1 parent 0365e50 commit 9045512
Show file tree
Hide file tree
Showing 14 changed files with 2,023 additions and 43 deletions.
2 changes: 1 addition & 1 deletion docs/dev/virtual_machine.rst
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ InvokeClosure
**Arguments**:
::
RegName closure
size_t closure_args_num
size_t num_closure_args
RegName* closure_args

Invokes `closure`, consuming the number of arguments declared in the closure's VMFunction.
Expand Down
27 changes: 23 additions & 4 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ namespace tvm {
namespace runtime {
namespace vm {

/*! \brief Magic number for NDArray list file */
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;

/*! \brief A register name. */
using RegName = int64_t;

Expand Down Expand Up @@ -103,7 +106,7 @@ struct Instruction {
/*! \brief The register containing the closure. */
RegName closure;
/*! \brief The number of arguments to the closure. */
Index closure_args_num;
Index num_closure_args;
/*! \brief The closure arguments as an array. */
RegName* closure_args;
};
Expand All @@ -115,7 +118,7 @@ struct Instruction {
/*! \brief The source register for a move operation. */
RegName from;
};
struct /* Packed Operands */ {
struct /* InvokePacked Operands */ {
/*! \brief The index into the packed function table. */
Index packed_index;
/*! \brief The arity of the packed function. */
Expand Down Expand Up @@ -149,7 +152,7 @@ struct Instruction {
};
struct /* LoadConsti Operands */ {
/* \brief The index into the constant pool. */
size_t val;
Index val;
} load_consti;
struct /* Jump Operands */ {
/*! \brief The jump offset. */
Expand Down Expand Up @@ -284,7 +287,7 @@ struct Instruction {
* \param dst The destination register.
* \return The load_constanti instruction.
*/
static Instruction LoadConsti(size_t val, RegName dst);
static Instruction LoadConsti(Index val, RegName dst);
/*! \brief Construct a move instruction.
* \param src The source register.
* \param dst The destination register.
Expand Down Expand Up @@ -379,6 +382,8 @@ class VirtualMachine : public runtime::ModuleNode {
return "VirtualMachine";
}

/*! \brief The runtime module/library that contains generated code. */
runtime::Module lib;
/*! \brief The virtual machine's packed function table. */
std::vector<PackedFunc> packed_funcs;
/*! \brief The virtual machine's function table. */
Expand Down Expand Up @@ -448,16 +453,30 @@ class VirtualMachine : public runtime::ModuleNode {
void Init(const std::vector<TVMContext>& contexts);
void Run();

/*!
* \brief Load parameters from the parameter bytearray.
* \param params The binary file that contains parameters.
*/
void LoadParams(const std::string& params);

/*! \brief A map from globals (as strings) to their index in the function map.
*/
std::unordered_map<std::string, Index> global_map;

/*! \brief A mapping from the packed function (as string) to the index that
* corresponds to the position of the `packed_funcs` list.
*/
std::unordered_map<std::string, Index> primitive_map;

private:
/*! \brief Invoke a global setting up the VM state to execute.
*
* This does not begin execution of the VM.
*/
void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args);

/*! \brief The parameter name to data mapping. */
std::unordered_map<std::string, Object> params_;
};

} // namespace vm
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from . import param_dict
from . import feature
from .backend import vm
from .backend import serializer
from .backend import deserializer
from .backend import vmobj

# Root operators
Expand Down
81 changes: 81 additions & 0 deletions python/tvm/relay/backend/deserializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# License .to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine deserializer.
Python interface for deserializing a Relay VM.
"""
from tvm import module
from tvm._ffi.runtime_ctypes import TVMByteArray
from . import _vm
from . import vm as rly_vm

def _create_deserializer(code, lib):
"""Create a deserializer object.
Parameters
----------
code : bytearray
The serialized virtual machine code.
lib : :py:class:`~tvm.module.Module`
The serialized runtime module/library that contains the hardware
dependent binary code.
Returns
-------
ret : Deserializer
The created virtual machine deserializer.
"""
if isinstance(code, (bytes, str)):
code = bytearray(code)
elif not isinstance(code, (bytearray, TVMByteArray)):
raise TypeError("vm is expected to be the type of bytearray or " +
"TVMByteArray, but received {}".format(type(code)))

if not isinstance(lib, module.Module):
raise TypeError("lib is expected to be the type of tvm.module.Module" +
", but received {}".format(type(lib)))
return _vm._Deserializer(code, lib)


class Deserializer:
"""Relay VM deserializer.
Parameters
----------
code : bytearray
The serialized virtual machine code.
lib : :py:class:`~tvm.module.Module`
The serialized runtime module/library that contains the hardware
dependent binary code.
"""
def __init__(self, code, lib):
self.mod = _create_deserializer(code, lib)
self._deserialize = self.mod["deserialize"]

def deserialize(self):
"""Deserialize the serialized bytecode into a Relay VM.
Returns
-------
ret : VirtualMachine
The deserialized Relay VM.
"""
return rly_vm.VirtualMachine(self._deserialize())
191 changes: 191 additions & 0 deletions python/tvm/relay/backend/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# License .to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""
The Relay Virtual Machine serializer.
Python interface for serializing a Relay VM.
"""
import tvm
from . import _vm
from . import vm as rly_vm

def _create_serializer(vm):
"""Create a VM serializer.
Parameters
----------
vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`]
The virtual machine to be serialized.
Returns
-------
ret : Serializer
The created virtual machine serializer.
"""
if isinstance(vm, rly_vm.VirtualMachine):
vm = vm.module
elif not isinstance(vm, tvm.module.Module):
raise TypeError("vm is expected to be the type of VirtualMachine or " +
"tvm.Module, but received {}".format(type(vm)))

return _vm._Serializer(vm)


class Serializer:
"""Relay VM serializer."""
def __init__(self, vm):
self.mod = _create_serializer(vm)
self._get_lib = self.mod["get_lib"]
self._get_bytecode = self.mod["get_bytecode"]
self._get_globals = self.mod["get_globals"]
self._get_stats = self.mod["get_stats"]
self._get_primitive_ops = self.mod["get_primitive_ops"]
self._serialize = self.mod["serialize"]

@property
def stats(self):
"""Get the statistics of the Relay VM.
Returns
-------
ret : String
The serialized statistic information.
"""
return self._get_stats()

@property
def primitive_ops(self):
"""Get the name of the primitive ops that are executed in the VM.
Returns
-------
ret : List[:py:class:`~tvm.expr.StringImm`]
The list of primitive ops.
"""
return [prim_op.value for prim_op in self._get_primitive_ops()]

@property
def bytecode(self):
"""Get the bytecode of the Relay VM.
Returns
-------
ret : String
The serialized bytecode.
Notes
-----
The bytecode is in the following format:
func_name reg_file_size num_instructions
param1 param2 ... paramM
instruction1
instruction2
...
instructionN
Each instruction is printed in the following format:
hash opcode field1 ... fieldX # The text format.
The part starting from # is only used for visualization and debugging.
The real serialized code doesn't contain it, therefore the deserializer
doesn't need to deal with it as well.
"""
return self._get_bytecode()

@property
def globals(self):
"""Get the globals used by the Relay VM.
Returns
-------
ret : List[:py:class:`~tvm.expr.StringImm`]
The serialized globals.
"""
return [glb.value for glb in self._get_globals()]

def serialize(self):
"""Serialize the Relay VM.
Returns
-------
code : bytearray
The binary blob representing a serialized Relay VM. It can then be
saved to disk and later deserialized into a new VM.
lib : :py:class:`~tvm.module.Module`
The runtime module that contains the generated code. It is
basically a library that is composed of hardware dependent code.
Notes
-----
The returned code is organized with the following sections in order.
- Global section. This section contains the globals used by the
virtual machine.
- Constant section. This section is used to store the constant pool of
a virtual machine.
- Primitive name section. This section is introduced to accommodate
the list of primitive operator names that will be invoked by the
virtual machine.
- Code section. The VM functions, including bytecode, are sitting in
this section.
Examples
--------
.. code-block:: python
import numpy as np
import tvm
from tvm import relay
# define a simple network.
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x + x)
mod = relay.Module({"main": f})
# create a Relay VM.
ctx = tvm.cpu()
target = "llvm"
compiler = relay.vm.VMCompiler()
vm = compiler.compile(mod, target)
vm.init(ctx)
# serialize.
ser = relay.serializer.Serializer(vm)
code, lib = ser.serialize()
# save and load the code and lib file.
tmp = tvm.contrib.util.tempdir()
path_lib = tmp.relpath("lib.so")
lib.export_library(path_lib)
with open(tmp.relpath("code.bc"), "wb") as fo:
fo.write(code)
loaded_lib = tvm.module.load(path_lib)
loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read())
# deserialize.
deser = relay.deserializer.Deserializer(loaded_code, loaded_lib)
des_vm = deser.deserialize()
# execute the deserialized vm.
des_vm.init(ctx)
x_data = np.random.rand(10, 10).astype('float32')
res = des_vm.run(x_data)
print(res.asnumpy())
"""
return self._serialize(), self._get_lib()
Loading

0 comments on commit 9045512

Please sign in to comment.