Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tom/onnx2pyverify #165

Merged
merged 2 commits into from
Jan 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 52 additions & 21 deletions onnxconverter_common/onnx2py.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
const_dir = None
const_counter = None

np_traced = TracingObject("np")
helper_traced = TracingObject("helper")
numpy_helper_traced = TracingObject("numpy_helper")
TensorProtoTraced = TracingObject("TensorProto")
os_traced = TracingObject("os")
np_traced = TracingObject("np", np)
helper_traced = TracingObject("helper", helper)
numpy_helper_traced = TracingObject("numpy_helper", numpy_helper)
TensorProtoTraced = TracingObject("TensorProto", TensorProto)
os_traced = TracingObject("os", os)


# <Helpers> These can be inlined into the output script #
Expand All @@ -41,6 +41,11 @@ def clear_field(proto, field):
return proto


def order_repeated_field(repeated_proto, key_name, order):
    """Sort ``repeated_proto`` in place so its items follow ``order``.

    Args:
        repeated_proto: Mutable sequence exposing ``sort(key=...)`` whose
            items carry the attribute named by ``key_name``.
        key_name: Name of the attribute read from every item to obtain
            its ordering key.
        order: Iterable of hashable keys giving the desired order. When a
            key occurs more than once, its first position is used.

    Returns:
        None. ``repeated_proto`` is reordered in place.

    Raises:
        ValueError: If an item's key does not appear in ``order``.
    """
    # Record each key's first position once, so the sort does not run a
    # linear scan over ``order`` for every element.
    rank = {}
    for position, key in enumerate(order):
        rank.setdefault(key, position)

    def sort_key(item):
        key = getattr(item, key_name)
        try:
            return rank[key]
        except KeyError:
            # Report unknown keys with the same exception type that a
            # list.index() lookup raises.
            raise ValueError("%r is not in list" % (key,)) from None

    repeated_proto.sort(key=sort_key)


def make_external_tensor(name, data_type, dims, raw_data=None, **kwargs):
tensor = TensorProto()
tensor.data_type = data_type
Expand All @@ -49,13 +54,24 @@ def make_external_tensor(name, data_type, dims, raw_data=None, **kwargs):
if raw_data is not None:
tensor.raw_data = raw_data
external_data_helper.set_external_data(tensor, **kwargs)
order_repeated_field(tensor.external_data, 'key', kwargs.keys())
return tensor


def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs):
    """Build a node through ``helper.make_node`` and normalise the result.

    Positional and keyword arguments are forwarded to ``helper.make_node``
    unchanged. Two adjustments are then applied to the returned node:

    * an explicitly empty ``doc_string`` is written back onto the node
      (presumably because the helper leaves the field unset for an empty
      string -- confirm against the onnx helper);
    * the node's attributes are reordered to follow the order in which the
      keyword arguments were supplied.

    Returns:
        The node produced by ``helper.make_node`` after the adjustments.
    """
    built = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs)
    if doc_string == '':
        built.doc_string = ''
    attribute_order = kwargs.keys()
    order_repeated_field(built.attribute, 'name', attribute_order)
    return built

# </Helpers> #


clear_field_traced = TracingObject("clear_field")
make_external_tensor_traced = TracingObject("make_external_tensor")
clear_field_traced = TracingObject("clear_field", clear_field)
make_external_tensor_traced = TracingObject("make_external_tensor", make_external_tensor)
make_node_traced = TracingObject("make_node", make_node)
DATA_DIR_TRACED = None


def convert_tensor_type(i):
Expand All @@ -67,24 +83,27 @@ def convert_field(field):
if isinstance(field, (int, str, float, bytes)):
return field
elif isinstance(field, onnx.GraphProto):
return convert_graph(field)
converted = convert_graph(field)
elif isinstance(field, onnx.ModelProto):
return convert_model(field)
converted = convert_model(field)
elif isinstance(field, onnx.NodeProto):
return convert_node(field)
converted = convert_node(field)
elif isinstance(field, onnx.TensorProto):
return convert_tensor(field)
converted = convert_tensor(field)
elif isinstance(field, onnx.ValueInfoProto):
return convert_value_info(field)
converted = convert_value_info(field)
elif isinstance(field, onnx.OperatorSetIdProto):
return convert_operatorsetid(field)
converted = convert_operatorsetid(field)
elif isinstance(field, collections.abc.Iterable):
return list(convert_field(x) for x in field)
else:
# Missing handler needs to be added
t = str(type(field))
needed_types.add(t)
return field
# Verify that resulting protobuf is identical to original
# assert TracingObject.get_py_obj(converted) == field
return converted


def convert_value_info(val_info):
Expand All @@ -104,13 +123,17 @@ def convert_shape_denotation(d):
return d.denotation
return None

kwargs["shape"] = [convert_shape_dim(d) for d in val_info.type.tensor_type.shape.dim]
if val_info.type.tensor_type.HasField("shape"):
kwargs["shape"] = [convert_shape_dim(d) for d in val_info.type.tensor_type.shape.dim]
else:
kwargs["shape"] = None
if any(d.HasField("denotation") for d in val_info.type.tensor_type.shape.dim):
kwargs["shape_denotation"] = [convert_shape_denotation(d) for d in val_info.type.tensor_type.shape.dim]

if val_info.HasField("doc_string"):
kwargs["doc_string"].doc_string

helper.make_tensor_value_info
return helper_traced.make_tensor_value_info(name, elem_type, **kwargs)


Expand Down Expand Up @@ -148,11 +171,12 @@ def convert_tensor(tensor):
name = name.replace(c, '_')
const_path = "%s/%s.npy" % (const_dir, name)
np.save(const_path, np_data)
rel_path = TracingObject("os.path.join(DATA_DIR, '%s.npy')" % name)
data_path = os_traced.path.join(DATA_DIR_TRACED, name + '.npy')
const_counter += 1
np_dtype = getattr(np_traced, str(np_data.dtype))
np_shape = list(np_data.shape)
return numpy_helper_traced.from_array(np_traced.load(rel_path).astype(np_dtype).reshape(np_shape), name=tensor.name)
np_array = np_traced.load(data_path).astype(np_dtype).reshape(np_shape)
return numpy_helper_traced.from_array(np_array, name=tensor.name)


def convert_node(node):
Expand All @@ -165,7 +189,7 @@ def convert_node(node):
attrs["to"] = convert_tensor_type(attrs["to"])
inputs = fields.pop("input", [])
outputs = fields.pop("output", [])
return helper_traced.make_node(op_type, inputs=inputs, outputs=outputs, **fields, **attrs)
return make_node_traced(op_type, inputs=inputs, outputs=outputs, **fields, **attrs)


def convert_graph(graph):
Expand Down Expand Up @@ -200,7 +224,7 @@ class MissingHandlerException(Exception):


def convert(model, out_path):
global needed_types, const_dir, const_counter
global needed_types, const_dir, const_counter, DATA_DIR_TRACED
needed_types = set()
if out_path.endswith(".py"):
out_path = out_path[:-3]
Expand All @@ -211,8 +235,9 @@ def convert(model, out_path):
const_counter = 0
TracingObject.reset_cnt(clear_field_traced)
TracingObject.reset_cnt(make_external_tensor_traced)
DATA_DIR_TRACED = TracingObject("DATA_DIR", const_dir)

model_trace = convert_model(model)
model_trace = convert_field(model)
code = "from onnx import helper, numpy_helper, TensorProto"
if TracingObject.get_cnt(make_external_tensor_traced):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! I suggest adding one more comment to the code:
"""
Run the <out_path+".py"> file to recreate the original ONNX model.
Example usage:
python <out_path+".py"> out_model_path.onnx
"""

code += ", external_data_helper"
Expand All @@ -225,12 +250,18 @@ def convert(model, out_path):
code += "\nDATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), %r)\n" % const_dir_name
if TracingObject.get_cnt(clear_field_traced):
code += "\n" + inspect.getsource(clear_field)
code += "\n" + inspect.getsource(order_repeated_field)
if TracingObject.get_cnt(make_external_tensor_traced):
code += "\n" + inspect.getsource(make_external_tensor)
code += "\n" + inspect.getsource(make_node)
code += "\n" + "model = " + repr(model_trace) + "\n"
code += "\nif __name__ == '__main__' and len(sys.argv) == 2:\n"
code += " _, out_path = sys.argv\n"
code += " onnx.save(model, out_path)\n"
if TracingObject.get_cnt(make_external_tensor_traced):
code += " with open(out_path, 'wb') as f:\n"
code += " f.write(model.SerializeToString())\n"
else:
code += " onnx.save(model, out_path)\n"
with open(out_path + ".py", "wt") as file:
file.write(code)
if needed_types:
Expand All @@ -249,7 +280,7 @@ def main():
print("ERROR:", e)

print("Model saved to", out_path)
print("Run '%s output.onnx' to generate ONNX file" % out_path)
print("Run 'python %s output.onnx' to generate ONNX file" % out_path)
print("Import the model with 'from %s import model'" % os.path.basename(out_path[:-3]))


Expand Down
38 changes: 32 additions & 6 deletions onnxconverter_common/pytracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
# license information.
###########################################################################

from collections import OrderedDict
import numpy as np


def indent(s):
    """Return ``s`` with a single space prepended to every line.

    Lines are delimited by ``"\n"`` only; an empty string yields one
    space, because it is treated as a single empty line.
    """
    # Prefix the first line, then re-insert the prefix after each newline.
    return " " + s.replace("\n", "\n ")


class NoPyObjException(Exception):
    """Signal that a tracing object has no associated Python object.

    The exception takes no arguments; its message is fixed.
    """

    def __init__(self):
        message = "Tracing object has no associated python object"
        super().__init__(message)


class TracingObject:
"""
Used by onnx2py to mock a module like numpy or onnx.helper and record calls on that module
Expand All @@ -18,8 +24,9 @@ class TracingObject:
x = np.array(np.product([1, 2, 3]), np.int32)
assert repr(x) == "np.array(np.product([1, 2, 3]), np.int32)"
"""
def __init__(self, trace):
def __init__(self, trace, py_obj=NoPyObjException):
self._trace = trace
self._py_obj = py_obj
self._cnt = 0

@staticmethod
Expand All @@ -32,7 +39,7 @@ def get_cnt(o):

@staticmethod
def from_repr(o):
return TracingObject(TracingObject.get_repr(o))
return TracingObject(TracingObject.get_repr(o), o)

@staticmethod
def get_repr(x):
Expand All @@ -46,18 +53,37 @@ def get_repr(x):
return code
return "[\n" + "".join(indent(s) + ",\n" for s in ls) + "]"

@staticmethod
def get_py_obj(o):
    """Resolve ``o`` to the concrete Python value it stands for.

    * A list is resolved element by element, recursively, into a new list.
    * A ``TracingObject`` yields its stored ``_py_obj``.
    * Any other value is returned unchanged.

    Raises:
        NoPyObjException: If a ``TracingObject`` stores the
            ``NoPyObjException`` class itself as its ``_py_obj``, which
            marks "no object available".
    """
    if isinstance(o, list):
        return list(map(TracingObject.get_py_obj, o))
    if not isinstance(o, TracingObject):
        return o
    underlying = o._py_obj
    if underlying is NoPyObjException:
        raise NoPyObjException()
    return underlying

def __getattr__(self, attr):
self._cnt += 1
return TracingObject(self._trace + "." + attr)
trace = self._trace + "." + attr
if self._py_obj is NoPyObjException:
return TracingObject(trace)
return TracingObject(trace, getattr(self._py_obj, attr))

def __call__(self, *args, **kwargs):
self._cnt += 1
arg_s = [TracingObject.get_repr(o) for o in args]
arg_s += [k + "=" + TracingObject.get_repr(o) for k, o in kwargs.items()]
trace = self._trace + "(" + ", ".join(arg_s) + ")"
if len(trace) <= 200:
return TracingObject(trace)
return TracingObject(self._trace + "(\n" + "".join(indent(s) + ",\n" for s in arg_s) + ")")
if len(trace) > 200:
trace = self._trace + "(\n" + "".join(indent(s) + ",\n" for s in arg_s) + ")"
try:
arg_o = [TracingObject.get_py_obj(a) for a in args]
kwarg_o = OrderedDict((k, TracingObject.get_py_obj(v)) for k, v in kwargs.items())
py_obj = TracingObject.get_py_obj(self)(*arg_o, **kwarg_o)
except NoPyObjException:
py_obj = NoPyObjException
return TracingObject(trace, py_obj)

def __repr__(self):
return self._trace