diff --git a/tests/test_backend.py b/tests/test_backend.py index bc31646c3..ae7606b94 100755 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -3280,13 +3280,13 @@ def func(x, x_new_size_): self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size}) def test_adjust_contrast(self): - x_shape = [4, 3, 2] - x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape) - y_val = np.array(2.1, np.float32) def func(x, y): x_ = tf.image.adjust_contrast(x, y) return tf.identity(x_, name=_TFOUTPUT) - self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val}) + for x_shape in [[4, 3, 2], [2, 3, 4, 5], [3, 4, 2, 4, 3]]: + x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape) + y_val = np.array(2.1, np.float32) + self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val}) @check_opset_min_version(11, "GatherElements") def test_adjust_saturation(self): diff --git a/tf2onnx/onnx_opset/nn.py b/tf2onnx/onnx_opset/nn.py index 174780db4..d2940b466 100644 --- a/tf2onnx/onnx_opset/nn.py +++ b/tf2onnx/onnx_opset/nn.py @@ -1505,8 +1505,8 @@ def version_1(cls, ctx, node, **kwargs): contrast_factor = ctx.make_node("Cast", [dtype], attr={'to': dtype}).output[0] rank = ctx.get_rank(images) utils.make_sure(rank is not None, "AdjustContrastv2 requires input of known rank") - # Reduce everything except channels - axes_to_reduce = list(range(rank))[:-1] + # Reduce height and width only + axes_to_reduce = list(range(rank))[-3:-1] mean = ctx.make_node("ReduceMean", [images], attr={'axes': axes_to_reduce, 'keepdims': True}, op_name_scope=node.name).output[0] diff = ctx.make_node("Sub", [images, mean], op_name_scope=node.name).output[0]