Skip to content

Commit

Permalink
[ONNX] Fix in interpreting auto_pad parameters SAME_UPPER and SAME_LO…
Browse files Browse the repository at this point in the history
…WER in ConvTranspose operator
  • Loading branch information
padreofthegame committed Oct 31, 2023
1 parent d83cd21 commit 89373d8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
8 changes: 4 additions & 4 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,9 +872,9 @@ def _impl_v1(cls, inputs, attr, params):
if "output_shape" in attr and "auto_pad" not in attr:
pad = right + left
elif "LOWER" in attr["auto_pad"]:
pad = left + right
else:
pad = right + left
else:
pad = left + right
attr["pads"] = pad
else:
data = autopad(
Expand Down Expand Up @@ -966,9 +966,9 @@ def _impl_v11(cls, inputs, attr, params):
if "output_shape" in attr and "auto_pad" not in attr:
pad = right + left
elif "LOWER" in attr["auto_pad"]:
pad = left + right
else:
pad = right + left
else:
pad = left + right
attr["pads"] = pad
elif attr["auto_pad"] == "VALID":
attr["pads"] = tuple([0 for i in range(ndim - 2)])
Expand Down
31 changes: 30 additions & 1 deletion tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3382,6 +3382,36 @@ def repeat(num, dims):
auto_pad="SAME_LOWER",
)

verify_convtranspose_with_output_shape(
(1, 1) + repeat(32, dims),
(1, 2) + repeat(4, dims),
repeat(num, dims),
repeat(4, dims),
repeat(2, dims),
repeat(1, dims),
auto_pad="SAME_UPPER",
)

verify_convtranspose_with_output_shape(
(1, 1, 3, 3),
(1, 2, 3, 3),
(6, 6),
(3, 3),
(2, 2),
(1, 1),
auto_pad="SAME_UPPER",
)

verify_convtranspose_with_output_shape(
(1, 1, 3, 3),
(1, 2, 3, 3),
(6, 6),
(3, 3),
(2, 2),
(1, 1),
auto_pad="SAME_LOWER",
)


@tvm.testing.parametrize_targets
def test_unsqueeze_constant(target, dev):
Expand Down Expand Up @@ -5550,7 +5580,6 @@ def verify_eyelike(indata, dynamic=False):
"test_cast_DOUBLE_to_FLOAT16",
"test_castlike_DOUBLE_to_FLOAT16",
"test_castlike_DOUBLE_to_FLOAT16_expanded",
"test_convtranspose_autopad_same",
"test_convtranspose_dilations",
"test_cumsum_1d",
"test_cumsum_1d_exclusive",
Expand Down

0 comments on commit 89373d8

Please sign in to comment.