Skip to content

Commit

Permalink
Use coordinate_transformation_mode as default parameter (tf 2.0 will …
Browse files Browse the repository at this point in the history
…change it) (#38)
  • Loading branch information
jiafatom authored Dec 18, 2019
1 parent 85e9bcf commit f684aeb
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions onnxconverter_common/onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def apply_reshape(scope, input_name, output_name, container, operator_name=None,
container.add_node('Reshape', input_name, output_name, op_version=5, name=name)


def apply_resize(scope, input_name, output_name, container, operator_name=None, mode='nearest', scales=None):
def apply_resize(scope, input_name, output_name, container, operator_name=None, mode='nearest', coordinate_transformation_mode='asymmetric', scales=None):
'''
:param mode: "nearest" or "linear"
:param scales: a float tensor for scaling (upsampling or downsampling) all input dimensions
Expand All @@ -590,7 +590,7 @@ def apply_resize(scope, input_name, output_name, container, operator_name=None,
roi = [0.0] * len(scales) + [1.0] * len(scales)
container.add_initializer(roi_tensor_name, onnx_proto.TensorProto.FLOAT, [2 * len(scales)], roi)
inputs.append(roi_tensor_name)
attrs['coordinate_transformation_mode'] = 'asymmetric'
attrs['coordinate_transformation_mode'] = coordinate_transformation_mode
if attrs['mode'] == 'nearest':
attrs['nearest_mode'] = 'floor'

Expand Down Expand Up @@ -806,7 +806,7 @@ def apply_transpose(scope, input_name, output_name, container, operator_name=Non
container.add_node('Transpose', input_name, output_name, name=name, perm=perm)


def apply_upsample(scope, input_name, output_name, container, operator_name=None, mode='nearest', scales=None):
def apply_upsample(scope, input_name, output_name, container, operator_name=None, mode='nearest', coordinate_transformation_mode='asymmetric', scales=None):
'''
:param mode: nearest or linear
:param scales: an integer list of scaling-up rate of all input dimensions
Expand Down Expand Up @@ -838,4 +838,4 @@ def apply_upsample(scope, input_name, output_name, container, operator_name=None
else:
# Upsample op is deprecated in ONNX opset 10
# We implement Upsample through Resize instead
apply_resize(scope, input_name, output_name, container, operator_name, mode, scales)
apply_resize(scope, input_name, output_name, container, operator_name, mode, coordinate_transformation_mode, scales)

0 comments on commit f684aeb

Please sign in to comment.