From 94d9f9396aef992c43ee5bb8b4484bcf6430937a Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Tue, 29 Dec 2020 13:39:13 -0500 Subject: [PATCH 1/2] Fixed support for external data format --- onnxconverter_common/onnx2py.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxconverter_common/onnx2py.py b/onnxconverter_common/onnx2py.py index b4630ca..8e987b5 100644 --- a/onnxconverter_common/onnx2py.py +++ b/onnxconverter_common/onnx2py.py @@ -40,7 +40,6 @@ def clear_field(proto, field): proto.ClearField(field) return proto - def make_external_tensor(name, data_type, dims, raw_data=None, **kwargs): tensor = TensorProto() tensor.data_type = data_type From a3f2cddb11a8afa9dc699fdc681c2d69b678ff19 Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Tue, 29 Dec 2020 15:56:03 -0500 Subject: [PATCH 2/2] Fixed bugs in large model support, added verification code for debugging --- onnxconverter_common/onnx2py.py | 74 ++++++++++++++++++++++--------- onnxconverter_common/pytracing.py | 38 +++++++++++++--- 2 files changed, 85 insertions(+), 27 deletions(-) diff --git a/onnxconverter_common/onnx2py.py b/onnxconverter_common/onnx2py.py index 8e987b5..4843850 100644 --- a/onnxconverter_common/onnx2py.py +++ b/onnxconverter_common/onnx2py.py @@ -27,11 +27,11 @@ const_dir = None const_counter = None -np_traced = TracingObject("np") -helper_traced = TracingObject("helper") -numpy_helper_traced = TracingObject("numpy_helper") -TensorProtoTraced = TracingObject("TensorProto") -os_traced = TracingObject("os") +np_traced = TracingObject("np", np) +helper_traced = TracingObject("helper", helper) +numpy_helper_traced = TracingObject("numpy_helper", numpy_helper) +TensorProtoTraced = TracingObject("TensorProto", TensorProto) +os_traced = TracingObject("os", os) # These can be inlined into the output script # @@ -40,6 +40,12 @@ def clear_field(proto, field): proto.ClearField(field) return proto + +def order_repeated_field(repeated_proto, key_name, order): + order = list(order) + repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) + + def make_external_tensor(name, data_type, dims, raw_data=None, **kwargs): tensor = TensorProto() tensor.data_type = data_type @@ -48,13 +54,24 @@ def make_external_tensor(name, data_type, dims, raw_data=None, **kwargs): if raw_data is not None: tensor.raw_data = raw_data external_data_helper.set_external_data(tensor, **kwargs) + order_repeated_field(tensor.external_data, 'key', kwargs.keys()) return tensor + +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): + node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) + if doc_string == '': + node.doc_string = '' + order_repeated_field(node.attribute, 'name', kwargs.keys()) + return node + # # -clear_field_traced = TracingObject("clear_field") -make_external_tensor_traced = TracingObject("make_external_tensor") +clear_field_traced = TracingObject("clear_field", clear_field) +make_external_tensor_traced = TracingObject("make_external_tensor", make_external_tensor) +make_node_traced = TracingObject("make_node", make_node) +DATA_DIR_TRACED = None def convert_tensor_type(i): @@ -66,17 +83,17 @@ def convert_field(field): if isinstance(field, (int, str, float, bytes)): return field elif isinstance(field, onnx.GraphProto): - return convert_graph(field) + converted = convert_graph(field) elif isinstance(field, onnx.ModelProto): - return convert_model(field) + converted = convert_model(field) elif isinstance(field, onnx.NodeProto): - return convert_node(field) + converted = convert_node(field) elif isinstance(field, onnx.TensorProto): - return convert_tensor(field) + converted = convert_tensor(field) elif isinstance(field, onnx.ValueInfoProto): - return convert_value_info(field) + converted = convert_value_info(field) elif isinstance(field, onnx.OperatorSetIdProto): - return convert_operatorsetid(field) + converted = convert_operatorsetid(field) elif isinstance(field, collections.abc.Iterable): return list(convert_field(x) for x in field) else: @@ -84,6 +101,9 @@ def convert_field(field): t = str(type(field)) needed_types.add(t) return field + # Verify that resulting protobuf is identical to original + # assert TracingObject.get_py_obj(converted) == field + return converted def convert_value_info(val_info): @@ -103,13 +123,17 @@ def convert_shape_denotation(d): return d.denotation return None - kwargs["shape"] = [convert_shape_dim(d) for d in val_info.type.tensor_type.shape.dim] + if val_info.type.tensor_type.HasField("shape"): + kwargs["shape"] = [convert_shape_dim(d) for d in val_info.type.tensor_type.shape.dim] + else: + kwargs["shape"] = None if any(d.HasField("denotation") for d in val_info.type.tensor_type.shape.dim): kwargs["shape_denotation"] = [convert_shape_denotation(d) for d in val_info.type.tensor_type.shape.dim] if val_info.HasField("doc_string"): kwargs["doc_string"].doc_string + helper.make_tensor_value_info return helper_traced.make_tensor_value_info(name, elem_type, **kwargs) @@ -147,11 +171,12 @@ def convert_tensor(tensor): name = name.replace(c, '_') const_path = "%s/%s.npy" % (const_dir, name) np.save(const_path, np_data) - rel_path = TracingObject("os.path.join(DATA_DIR, '%s.npy')" % name) + data_path = os_traced.path.join(DATA_DIR_TRACED, name + '.npy') const_counter += 1 np_dtype = getattr(np_traced, str(np_data.dtype)) np_shape = list(np_data.shape) - return numpy_helper_traced.from_array(np_traced.load(rel_path).astype(np_dtype).reshape(np_shape), name=tensor.name) + np_array = np_traced.load(data_path).astype(np_dtype).reshape(np_shape) + return numpy_helper_traced.from_array(np_array, name=tensor.name) def convert_node(node): @@ -164,7 +189,7 @@ def convert_node(node): attrs["to"] = convert_tensor_type(attrs["to"]) inputs = fields.pop("input", []) outputs = fields.pop("output", []) - return helper_traced.make_node(op_type, inputs=inputs, outputs=outputs, **fields, **attrs) + return make_node_traced(op_type, inputs=inputs, outputs=outputs, **fields, **attrs) def convert_graph(graph): @@ -199,7 +224,7 @@ class MissingHandlerException(Exception): def convert(model, out_path): - global needed_types, const_dir, const_counter + global needed_types, const_dir, const_counter, DATA_DIR_TRACED needed_types = set() if out_path.endswith(".py"): out_path = out_path[:-3] @@ -210,8 +235,9 @@ def convert(model, out_path): const_counter = 0 TracingObject.reset_cnt(clear_field_traced) TracingObject.reset_cnt(make_external_tensor_traced) + DATA_DIR_TRACED = TracingObject("DATA_DIR", const_dir) - model_trace = convert_model(model) + model_trace = convert_field(model) code = "from onnx import helper, numpy_helper, TensorProto" if TracingObject.get_cnt(make_external_tensor_traced): code += ", external_data_helper" @@ -224,12 +250,18 @@ def convert(model, out_path): code += "\nDATA_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), %r)\n" % const_dir_name if TracingObject.get_cnt(clear_field_traced): code += "\n" + inspect.getsource(clear_field) + code += "\n" + inspect.getsource(order_repeated_field) if TracingObject.get_cnt(make_external_tensor_traced): code += "\n" + inspect.getsource(make_external_tensor) + code += "\n" + inspect.getsource(make_node) code += "\n" + "model = " + repr(model_trace) + "\n" code += "\nif __name__ == '__main__' and len(sys.argv) == 2:\n" code += " _, out_path = sys.argv\n" - code += " onnx.save(model, out_path)\n" + if TracingObject.get_cnt(make_external_tensor_traced): + code += " with open(out_path, 'wb') as f:\n" + code += " f.write(model.SerializeToString())\n" + else: + code += " onnx.save(model, out_path)\n" with open(out_path + ".py", "wt") as file: file.write(code) if needed_types: @@ -248,7 +280,7 @@ def main(): print("ERROR:", e) print("Model saved to", out_path) - print("Run '%s output.onnx' to generate ONNX file" % out_path) + print("Run 'python %s output.onnx' to generate ONNX file" % out_path) print("Import the model with 'from %s import model'" % os.path.basename(out_path[:-3])) diff --git a/onnxconverter_common/pytracing.py b/onnxconverter_common/pytracing.py index e5901ba..f7727f4 100644 --- a/onnxconverter_common/pytracing.py +++ b/onnxconverter_common/pytracing.py @@ -3,6 +3,7 @@ # license information. ########################################################################### +from collections import OrderedDict import numpy as np @@ -10,6 +11,11 @@ def indent(s): return "\n".join(" " + line for line in s.split("\n")) +class NoPyObjException(Exception): + def __init__(self): + super().__init__("Tracing object has no associated python object") + + class TracingObject: """ Used by onnx2py to mock a module like numpy or onnx.helper and record calls on that module @@ -18,8 +24,9 @@ class TracingObject: x = np.array(np.product([1, 2, 3]), np.int32) assert repr(x) == "np.array(np.product([1, 2, 3]), np.int32)" """ - def __init__(self, trace): + def __init__(self, trace, py_obj=NoPyObjException): self._trace = trace + self._py_obj = py_obj self._cnt = 0 @staticmethod @@ -32,7 +39,7 @@ def get_cnt(o): @staticmethod def from_repr(o): - return TracingObject(TracingObject.get_repr(o)) + return TracingObject(TracingObject.get_repr(o), o) @staticmethod def get_repr(x): @@ -46,18 +53,37 @@ def get_repr(x): return code return "[\n" + "".join(indent(s) + ",\n" for s in ls) + "]" + @staticmethod + def get_py_obj(o): + if isinstance(o, list): + return [TracingObject.get_py_obj(x) for x in o] + if isinstance(o, TracingObject): + if o._py_obj is NoPyObjException: + raise NoPyObjException() + return o._py_obj + return o + def __getattr__(self, attr): self._cnt += 1 - return TracingObject(self._trace + "." + attr) + trace = self._trace + "." + attr + if self._py_obj is NoPyObjException: + return TracingObject(trace) + return TracingObject(trace, getattr(self._py_obj, attr)) def __call__(self, *args, **kwargs): self._cnt += 1 arg_s = [TracingObject.get_repr(o) for o in args] arg_s += [k + "=" + TracingObject.get_repr(o) for k, o in kwargs.items()] trace = self._trace + "(" + ", ".join(arg_s) + ")" - if len(trace) <= 200: - return TracingObject(trace) - return TracingObject(self._trace + "(\n" + "".join(indent(s) + ",\n" for s in arg_s) + ")") + if len(trace) > 200: + trace = self._trace + "(\n" + "".join(indent(s) + ",\n" for s in arg_s) + ")" + try: + arg_o = [TracingObject.get_py_obj(a) for a in args] + kwarg_o = OrderedDict((k, TracingObject.get_py_obj(v)) for k, v in kwargs.items()) + py_obj = TracingObject.get_py_obj(self)(*arg_o, **kwarg_o) + except NoPyObjException: + py_obj = NoPyObjException + return TracingObject(trace, py_obj) def __repr__(self): return self._trace