Skip to content

Commit

Permalink
Fix issue onnx#2102: Set reduction axis of mean to height and width f…
Browse files Browse the repository at this point in the history
…or adjust_contrast op

Signed-off-by: cosine <[email protected]>
  • Loading branch information
cosineFish committed Mar 18, 2023
1 parent ec01956 commit 429e813
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3280,13 +3280,13 @@ def func(x, x_new_size_):
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})

def test_adjust_contrast(self):
x_shape = [4, 3, 2]
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
y_val = np.array(2.1, np.float32)
def func(x, y):
x_ = tf.image.adjust_contrast(x, y)
return tf.identity(x_, name=_TFOUTPUT)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
for x_shape in [[4, 3, 2], [2, 3, 4, 5], [3, 4, 2, 4, 3]]:
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
y_val = np.array(2.1, np.float32)
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})

@check_opset_min_version(11, "GatherElements")
def test_adjust_saturation(self):
Expand Down
4 changes: 2 additions & 2 deletions tf2onnx/onnx_opset/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1505,8 +1505,8 @@ def version_1(cls, ctx, node, **kwargs):
contrast_factor = ctx.make_node("Cast", [dtype], attr={'to': dtype}).output[0]
rank = ctx.get_rank(images)
utils.make_sure(rank is not None, "AdjustContrastv2 requires input of known rank")
# Reduce everything except channels
axes_to_reduce = list(range(rank))[:-1]
# Reduce height and width only
axes_to_reduce = list(range(rank))[-3:-1]
mean = ctx.make_node("ReduceMean", [images], attr={'axes': axes_to_reduce, 'keepdims': True},
op_name_scope=node.name).output[0]
diff = ctx.make_node("Sub", [images, mean], op_name_scope=node.name).output[0]
Expand Down

0 comments on commit 429e813

Please sign in to comment.