From f99b024ac24bcac861a772058bf857917ccffc88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Wed, 4 Sep 2019 18:11:01 +0200 Subject: [PATCH 1/2] Fixes #17, update clip operator (#18) * Fixes #17, update clip operator (updated in ONNX 11) * handle min or max = None --- onnxconverter_common/__init__.py | 2 +- onnxconverter_common/case_insensitive_dict.py | 2 +- onnxconverter_common/container.py | 1 - onnxconverter_common/data_types.py | 6 +- onnxconverter_common/float16.py | 1 + onnxconverter_common/metadata_props.py | 1 - onnxconverter_common/onnx_ops.py | 85 +++++++++++++++---- onnxconverter_common/optimizer.py | 4 +- onnxconverter_common/shape_calculator.py | 1 - onnxconverter_common/tree_ensemble.py | 3 - 10 files changed, 75 insertions(+), 31 deletions(-) diff --git a/onnxconverter_common/__init__.py b/onnxconverter_common/__init__.py index c436451..ca159ef 100644 --- a/onnxconverter_common/__init__.py +++ b/onnxconverter_common/__init__.py @@ -9,7 +9,7 @@ This framework performs optimization for ONNX models and includes common utilities for ONNX converters. """ -__version__ = "1.5.3" +__version__ = "1.5.4" __author__ = "Microsoft" __producer__ = "OnnxMLTools" __producer_version__ = __version__ diff --git a/onnxconverter_common/case_insensitive_dict.py b/onnxconverter_common/case_insensitive_dict.py index ecc231b..f1706c9 100644 --- a/onnxconverter_common/case_insensitive_dict.py +++ b/onnxconverter_common/case_insensitive_dict.py @@ -42,7 +42,7 @@ def __eq__(self, other): return dict(self.lower_key_iteritems()) == dict(other.lower_key_iteritems()) def copy(self): - return CaseInsensitiveDict(self._dict.values()) + return CaseInsensitiveDict(self._dict.values()) def __repr__(self): return str(dict(self.items())) diff --git a/onnxconverter_common/container.py b/onnxconverter_common/container.py index 74d908c..99053ad 100644 --- a/onnxconverter_common/container.py +++ b/onnxconverter_common/container.py @@ -5,7 +5,6 @@ # -------------------------------------------------------------------------- import six -import onnx from onnx import helper from .interface import ModelContainer diff --git a/onnxconverter_common/data_types.py b/onnxconverter_common/data_types.py index 094d5e3..b41efd5 100644 --- a/onnxconverter_common/data_types.py +++ b/onnxconverter_common/data_types.py @@ -182,8 +182,7 @@ def to_onnx_type(self): onnx_type.map_type.key_type = onnx_proto.TensorProto.STRING onnx_type.map_type.value_type.CopyFrom( self.value_type.to_onnx_type()) - except AttributeError as e: - import onnx + except AttributeError: msg = "ONNX was not compiled with flag ONNX-ML.\n{0}\n{1}" msg = msg.format(str(self), str(self.value_type.to_onnx_type())) info = [onnx.__version__, str(onnx_type)] @@ -207,8 +206,7 @@ def to_onnx_type(self): try: onnx_type.sequence_type.elem_type.CopyFrom( self.element_type.to_onnx_type()) - except AttributeError as e: - import onnx + except AttributeError: msg = "ONNX was not compiled with flag ONNX-ML.\n{0}\n{1}" msg = msg.format(str(self), str(self.element_type.to_onnx_type())) info = [onnx.__version__, str(onnx_type)] diff --git a/onnxconverter_common/float16.py b/onnxconverter_common/float16.py index fb1a447..73c6abb 100644 --- a/onnxconverter_common/float16.py +++ b/onnxconverter_common/float16.py @@ -10,6 +10,7 @@ from onnx import helper from onnx import onnx_pb as onnx_proto + def _npfloat16_to_int(np_list): ''' Convert numpy float16 to python int. 
diff --git a/onnxconverter_common/metadata_props.py b/onnxconverter_common/metadata_props.py index 530de83..c7f6eb2 100644 --- a/onnxconverter_common/metadata_props.py +++ b/onnxconverter_common/metadata_props.py @@ -1,6 +1,5 @@ import warnings from .case_insensitive_dict import CaseInsensitiveDict -import onnx from onnx import onnx_pb as onnx_proto KNOWN_METADATA_PROPS = CaseInsensitiveDict({ diff --git a/onnxconverter_common/onnx_ops.py b/onnxconverter_common/onnx_ops.py index 6244402..fc5afec 100644 --- a/onnxconverter_common/onnx_ops.py +++ b/onnxconverter_common/onnx_ops.py @@ -112,8 +112,10 @@ def apply_batch_norm(scope, input_names, output_names, container, operator_name= name = _create_name_or_use_existing_one(scope, 'BatchNormalization', operator_name) attrs = {'name': name, 'epsilon': epsilon, 'momentum': momentum} - if container.target_opset < 9: attrs['spatial'] = spatial - if container.target_opset < 7: attrs['is_test'] = is_test + if container.target_opset < 9: + attrs['spatial'] = spatial + if container.target_opset < 7: + attrs['is_test'] = is_test if container.target_opset < 6: attrs['consumed_inputs'] = [0] * len(input_names) @@ -169,20 +171,69 @@ def apply_cast(scope, input_name, output_name, container, operator_name=None, to def apply_clip(scope, input_name, output_name, container, operator_name=None, max=None, min=None): name = _create_name_or_use_existing_one(scope, 'Clip', operator_name) - attrs = {'name': name} - if max is not None: - attrs['max'] = float(max) - if min is not None: - attrs['min'] = float(min) - if container.target_opset < 6: - attrs['consumed_inputs'] = [0] - op_version = 1 + if container.target_opset < 11: + if max is not None: + attrs['max'] = float(max) + if min is not None: + attrs['min'] = float(min) + + if container.target_opset < 6: + attrs['consumed_inputs'] = [0] + op_version = 1 + else: + op_version = 6 + + container.add_node('Clip', input_name, output_name, op_version=op_version, **attrs) else: - op_version = 6 + op_version = 11 + if min is None and max is not None: + raise RuntimeError("Operator 'Clip': min must be specified if max is.") + inputs = [input_name] + + if min is not None: + if isinstance(min, (np.ndarray, float, int)): + # add initializer + if isinstance(min, np.ndarray): + if min.shape != (1, ): + raise RuntimeError("min must an array of one element.") + else: + # container in sklearn-onnx stores the computation type in + # container.dtype. 
+ min = np.array([min], dtype=getattr( + container, 'dtype', np.float32)) + min_name = scope.get_unique_variable_name('clip_min') + container.add_initializer(min_name, getattr(container, 'proto_dtype', + onnx_proto.TensorProto.FLOAT), [1], [min[0]]) + min = min_name + if isinstance(min, str): + inputs.append(min) + else: + raise RuntimeError("Parameter 'min' must be a string or a float.") + + if max is not None: + if min is None: + raise RuntimeError("Parameter 'min' must be specified if 'max' is.") + if isinstance(max, (np.ndarray, float, int)): + # add initializer + if isinstance(max, np.ndarray): + if max.shape != (1, ): + raise RuntimeError("max must an array of one element.") + else: + max = np.array([max], dtype=getattr( + container, 'dtype', np.float32)) + max_name = scope.get_unique_variable_name('clip_max') + container.add_initializer(max_name, getattr(container, 'proto_dtype', + onnx_proto.TensorProto.FLOAT), [1], [max[0]]) + max = max_name + if isinstance(max, str): + inputs.append(max) + else: + raise RuntimeError("Parameter 'max' must be a string or a float.") - container.add_node('Clip', input_name, output_name, op_version=op_version, **attrs) + container.add_node('Clip', input_name, output_name, op_version=op_version, + **attrs) def apply_concat(scope, input_names, output_name, container, operator_name=None, axis=0): @@ -374,9 +425,9 @@ def apply_pad(scope, input_name, output_name, container, operator_name=None, mod def apply_parametric_softplus(scope, input_name, output_name, container, operator_name=None, alpha=None, beta=None): - if alpha == None: + if alpha is None: alpha = [1.0] - if beta == None: + if beta is None: beta = [0.] name = _create_name_or_use_existing_one(scope, 'ParametricSoftplus', operator_name) @@ -515,9 +566,9 @@ def apply_softmax(scope, input_name, output_name, container, operator_name=None, def apply_scaled_tanh(scope, input_name, output_name, container, operator_name=None, alpha=None, beta=None): - if alpha == None: + if alpha is None: alpha = [1.0] - if beta == None: + if beta is None: beta = [1.0] if len(alpha) != 1 or len(beta) != 1: raise ValueError('alpha and beta must be 1-element lists') @@ -621,7 +672,7 @@ def apply_tanh(scope, input_name, output_name, container, operator_name=None): def apply_thresholded_relu(scope, input_name, output_name, container, operator_name=None, alpha=None): - if alpha == None: + if alpha is None: alpha = [1.0] name = _create_name_or_use_existing_one(scope, 'ThresholdedRelu', operator_name) diff --git a/onnxconverter_common/optimizer.py b/onnxconverter_common/optimizer.py index de73712..30515e0 100644 --- a/onnxconverter_common/optimizer.py +++ b/onnxconverter_common/optimizer.py @@ -211,7 +211,7 @@ def generate(self): onode.doc_string = self.origin.doc_string onode.domain = self.origin.domain onode.attribute.extend( - attr for attr in self.origin.attribute if not attr.name in self.attributes) + attr for attr in self.origin.attribute if attr.name not in self.attributes) onode.attribute.extend( helper.make_attribute(attr.name, self.attributes[attr.name]) for attr in self.attributes) @@ -415,7 +415,7 @@ def apply(self, node_list): perm_f = [perm0[idx] for idx in perm1] if self.is_useless_transpose(perm_f): node = self.begin # type: LinkedNode - while node != self.end and len(node.successor) >=1: + while node != self.end and len(node.successor) >= 1: #if node.broadcast: # node.reshape_input_for_broadcast(perm0) node = node.successor[0] diff --git a/onnxconverter_common/shape_calculator.py 
b/onnxconverter_common/shape_calculator.py index 868135e..195f7ec 100644 --- a/onnxconverter_common/shape_calculator.py +++ b/onnxconverter_common/shape_calculator.py @@ -9,7 +9,6 @@ import numpy as np import numbers import six -from .registration import register_shape_calculator from .data_types import Int64TensorType, FloatTensorType, StringTensorType, DictionaryType, SequenceType from .utils import check_input_and_output_numbers, check_input_and_output_types diff --git a/onnxconverter_common/tree_ensemble.py b/onnxconverter_common/tree_ensemble.py index 5bd8f37..d5487f4 100644 --- a/onnxconverter_common/tree_ensemble.py +++ b/onnxconverter_common/tree_ensemble.py @@ -6,9 +6,6 @@ """ Common functions to convert any learner based on trees. """ - -import numpy as np -import numbers, six from .registration import register_converter From 778a6b22a53779d4186414c2a5de2d4922ec65bb Mon Sep 17 00:00:00 2001 From: David Fan Date: Wed, 15 Jan 2020 12:03:29 -0800 Subject: [PATCH 2/2] Support dynamic pads for Pad opset 11 --- onnxconverter_common/onnx_ops.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/onnxconverter_common/onnx_ops.py b/onnxconverter_common/onnx_ops.py index a0d1942..bab150b 100644 --- a/onnxconverter_common/onnx_ops.py +++ b/onnxconverter_common/onnx_ops.py @@ -504,6 +504,8 @@ def apply_pad(scope, input_name, output_name, container, operator_name=None, mod attrs['mode'] = mode if container.target_opset < 11: + if isinstance(pads, str): + raise ValueError("Dynamic pad is not supported for opset < 11.") if value is not None: attrs['value'] = value if container.target_opset < 2: @@ -514,9 +516,12 @@ def apply_pad(scope, input_name, output_name, container, operator_name=None, mod op_version = 2 else: op_version = 11 - pads_name = scope.get_unique_variable_name(name + '_pads') - container.add_initializer(pads_name, onnx_proto.TensorProto.INT64, [len(pads)], pads) - inputs.append(pads_name) + if isinstance(pads, str): + inputs.append(pads) + else: + pads_name = scope.get_unique_variable_name(name + '_pads') + container.add_initializer(pads_name, onnx_proto.TensorProto.INT64, [len(pads)], pads) + inputs.append(pads_name) if value is not None: value_name = scope.get_unique_variable_name(name + '_value') container.add_initializer(value_name, onnx_type, [], [value]) @@ -743,7 +748,7 @@ def apply_slice(scope, input_name, output_name, container, starts, ends, op_version = 10 else: op_version = 11 - inputs = [input_name] + inputs = input_name if isinstance(input_name, list) else [input_name] starts_name = scope.get_unique_variable_name('starts') ends_name = scope.get_unique_variable_name('ends') container.add_initializer(starts_name, onnx_proto.TensorProto.INT64,
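
Note on the Clip change in the first patch: from opset 11 onward, min and max are no longer node attributes but optional inputs, which is why apply_clip now registers constant bounds as initializers (or forwards a tensor name). The following standalone sketch only assumes onnx and numpy are installed; the names clip_min/clip_max are illustrative and it shows the two node layouts directly rather than going through the converter's scope/container machinery:

    import numpy as np
    import onnx
    from onnx import TensorProto, helper, numpy_helper

    # opset 6-10 layout: min/max are float attributes on the node
    clip_v6 = helper.make_node('Clip', inputs=['X'], outputs=['Y'],
                               name='clip_v6', min=0.0, max=6.0)

    # opset 11 layout: min/max become optional inputs; constant bounds are
    # registered as initializers (the Clip-11 spec describes them as scalar
    # tensors, while the patch registers them with shape [1])
    min_init = numpy_helper.from_array(np.array(0.0, dtype=np.float32), name='clip_min')
    max_init = numpy_helper.from_array(np.array(6.0, dtype=np.float32), name='clip_max')
    clip_v11 = helper.make_node('Clip', inputs=['X', 'clip_min', 'clip_max'],
                                outputs=['Y'], name='clip_v11')

    graph = helper.make_graph(
        [clip_v11], 'clip_opset11_example',
        [helper.make_tensor_value_info('X', TensorProto.FLOAT, ['N'])],
        [helper.make_tensor_value_info('Y', TensorProto.FLOAT, ['N'])],
        initializer=[min_init, max_init])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 11)])
    onnx.checker.check_model(model)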
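
The second patch relies on the analogous change for Pad: at opset 11 the pads (and the constant fill value) move from attributes to inputs, so a string passed as pads can simply be appended as the node's second input instead of being baked into an initializer. A hedged sketch of that layout, again using plain onnx helpers and an illustrative graph input named pads_in in place of a runtime-computed tensor:

    import onnx
    from onnx import TensorProto, helper

    # opset 2-10 layout: pads and the fill value are attributes, so they must
    # be known at conversion time
    pad_v2 = helper.make_node('Pad', ['X'], ['Y'], mode='constant',
                              pads=[0, 0, 1, 1], value=0.0)

    # opset 11 layout: pads is a second input, so it may be an initializer or,
    # for dynamic pads, any tensor produced at runtime (here a graph input)
    pad_v11 = helper.make_node('Pad', ['X', 'pads_in'], ['Y'], mode='constant')

    graph = helper.make_graph(
        [pad_v11], 'dynamic_pad_example',
        [helper.make_tensor_value_info('X', TensorProto.FLOAT, ['N', 'C']),
         helper.make_tensor_value_info('pads_in', TensorProto.INT64, [4])],
        [helper.make_tensor_value_info('Y', TensorProto.FLOAT, None)])
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 11)])
    onnx.checker.check_model(model)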