[ONNX] More Unit Tests! (#7956)
* support SAME_LOWER and maxpool in autopad

* fix isinf tests

* lower tolerance on roialign test because the ONNX result is cropped to 4 decimal places

* slow support for bottom-k

* throw a clear error on nullptr in gathernd and scatternd, fix typo

* fix lint

* fix a copy typo
Matthew Brookhart authored May 3, 2021
1 parent dd5379f commit c380a69
Showing 3 changed files with 100 additions and 18 deletions.
98 changes: 88 additions & 10 deletions python/tvm/relay/frontend/onnx.py
@@ -276,6 +276,7 @@ class Pool(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
data = inputs[0]
input_shape = infer_shape(data)
input_dtype = infer_type(data).checked_type.dtype
ndim = len(input_shape)
if "auto_pad" in attr:
attr["auto_pad"] = attr["auto_pad"].decode("utf-8")
@@ -293,7 +294,19 @@ def _impl_v1(cls, inputs, attr, params):
else:
# Warning: Pool does not yet support dynamic shapes,
# one will need to run dynamic_to_static on this model after import
data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim)
if "int" in input_dtype:
pad_val = np.iinfo(np.dtype(input_dtype)).min
else:
pad_val = np.finfo(np.dtype(input_dtype)).min
data = autopad(
data,
attr.get("strides", [1] * (ndim - 2)),
attr["kernel_shape"],
[1] * ndim,
ndim,
pad_value=pad_val,
mode=attr["auto_pad"],
)
elif attr["auto_pad"] == "VALID":
attr["pads"] = tuple([0 for i in range(ndim - 2)])
elif attr["auto_pad"] == "NOTSET":
@@ -356,7 +369,17 @@ def _impl_v1(cls, inputs, attr, params):
return AttrCvt(op_name="instance_norm")(inputs, attr, params)


def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", deconv=False):
def autopad(
data,
strides,
kernel_shape,
dilations,
ndim,
pad_type="constant",
deconv=False,
mode="SAME_UPPER",
pad_value=0.0,
):
"""
Perform autopadding with dynamic input shapes
"""
@@ -391,14 +414,19 @@ def autopad(data, strides, kernel_shape, dilations, ndim, pad_type="constant", d
pad_after = total_pad - pad_before

# combine
pad = _op.concatenate(
[_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1
)
if "LOWER" in mode:
pad = _op.concatenate(
[_op.reshape(pad_after, [-1, 1]), _op.reshape(pad_before, [-1, 1])], axis=1
)
else:
pad = _op.concatenate(
[_op.reshape(pad_before, [-1, 1]), _op.reshape(pad_after, [-1, 1])], axis=1
)

# pad N and C with zeros
pad = _op.concatenate([_op.const(np.zeros([2, 2], dtype="int64"), dtype="int64"), pad], axis=0)

return _op.nn.pad(data, fold_constant(pad), _op.const(0.0), pad_type)
return _op.nn.pad(data, fold_constant(pad), _op.const(pad_value), pad_type)
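SAME_UPPER and SAME_LOWER request the same total padding per axis and differ only in which side takes the odd leftover element; the LOWER branch simply swaps the concatenation order so the larger half lands in front. A small sketch of the split, assuming total_pad is computed as above:

def split_pad(total_pad):
    pad_before = total_pad // 2
    pad_after = total_pad - pad_before
    return pad_before, pad_after

# For total_pad = 3 on one axis:
print(split_pad(3))                    # SAME_UPPER: (1, 2), extra element after
print(tuple(reversed(split_pad(3))))   # SAME_LOWER: (2, 1), extra element before

The Conv, ConvTranspose, and LpPool hunks below just forward mode=attr["auto_pad"] so their callers get the same behavior.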


class Conv(OnnxOpConverter):
@@ -427,6 +455,7 @@ def _impl_v1(cls, inputs, attr, params):
attr["kernel_shape"],
attr.get("dilations", [1] * (ndim - 2)),
ndim,
mode=attr["auto_pad"],
)
elif attr["auto_pad"] == "VALID":
attr["pads"] = tuple([0 for i in range(ndim - 2)])
@@ -485,6 +514,7 @@ def _impl_v1(cls, inputs, attr, params):
attr.get("dilations", [1] * (ndim - 2)),
ndim,
deconv=True,
mode=attr["auto_pad"],
)
elif attr["auto_pad"] == "VALID":
attr["pads"] = tuple([0 for i in range(ndim - 2)])
@@ -757,7 +787,14 @@ def _impl_v1(cls, inputs, attr, params):
if attr["auto_pad"] in ("SAME_UPPER", "SAME_LOWER"):
# Warning: LpPool does not yet support dynamic shapes,
# one will need to run dynamic_to_static on this model after import
data = autopad(data, attr["strides"], attr["kernel_shape"], [1] * ndim, ndim)
data = autopad(
data,
attr["strides"],
attr["kernel_shape"],
[1] * ndim,
ndim,
mode=attr["auto_pad"],
)
elif attr["auto_pad"] == "VALID":
attr["pads"] = tuple([0 for i in range(ndim - 2)])
elif attr["auto_pad"] == "NOTSET":
@@ -1377,7 +1414,7 @@ def _impl_v1(cls, inputs, attr, params):


class ScatterND(OnnxOpConverter):
"""Operator converter for Scatter."""
"""Operator converter for ScatterND."""

@classmethod
def _impl_v11(cls, inputs, attr, params):
@@ -2228,7 +2265,32 @@ def _impl_v1(cls, inputs, attr, params):
largest = attr.get("largest", 1)

if largest == 0:
raise NotImplementedError("TVM only supports finding TopK largest elements")
# TODO(mbrookhart): optimize this by adding a smallest attribute to topi if this
# ever becomes a bottleneck
ndim = len(infer_shape(inputs[0]))
if axis < 0:
axis += ndim
sort = _op.sort(inputs[0], axis=axis)
argsort = _op.argsort(inputs[0], axis=axis, dtype="int64")
begin = [0] * ndim
stride = [1] * ndim
end = _op.concatenate(
[
_op.const([np.iinfo(np.int64).max] * axis, dtype="int64"),
inputs[1],
_op.const([np.iinfo(np.int64).max] * (ndim - axis - 1), dtype="int64"),
],
axis=0,
)
return _expr.TupleWrapper(
_expr.Tuple(
[
_op.strided_slice(sort, begin, end, stride),
_op.strided_slice(argsort, begin, end, stride),
]
),
2,
)

return _op.topk(inputs[0], inputs[1], axis=axis, dtype="int64")
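The largest == 0 branch emulates bottom-k by fully sorting, then strided-slicing the first k entries along the axis; the int64-max entries in end act as a full-range slice on every other axis. This is the "slow support" from the commit message: an O(n log n) sort rather than a true selection. A NumPy analogue of the sort/argsort/slice combination (a sketch, not the converter itself):

import numpy as np

def bottom_k(x, k, axis=-1):
    values = np.sort(x, axis=axis)      # plays the role of _op.sort
    indices = np.argsort(x, axis=axis)  # plays the role of _op.argsort
    sl = [slice(None)] * x.ndim         # slice(None) ~ the int64-max "end"
    sl[axis] = slice(0, k)
    return values[tuple(sl)], indices[tuple(sl)]

vals, idxs = bottom_k(np.array([5, 1, 4, 2]), k=2)
print(vals, idxs)  # [1 2] [1 3]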

@@ -2246,6 +2308,22 @@ def _impl_v1(cls, inputs, attr, params):
)


class IsInf(OnnxOpConverter):
"""Operator converter for IsInf"""

@classmethod
def _impl_v10(cls, inputs, attr, params):
detect_negative = attr.get("detect_negative", 1)
detect_positive = attr.get("detect_positive", 1)
dtype = infer_type(inputs[0]).checked_type.dtype
isinf = _op.isinf(inputs[0])
if not detect_negative:
isinf = isinf * (inputs[0] > _op.const(0, dtype))
if not detect_positive:
isinf = isinf * (inputs[0] < _op.const(0, dtype))
return isinf
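Relay's isinf flags infinities of both signs, so when one of the ONNX detect_* attributes is 0 the converter masks the unwanted sign away with a comparison against zero. A NumPy sketch of the same masking:

import numpy as np

x = np.array([-np.inf, 0.0, np.inf], dtype=np.float32)
print(np.isinf(x) & (x > 0))  # detect_negative=0 -> [False False  True]
print(np.isinf(x) & (x < 0))  # detect_positive=0 -> [ True False False]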


class MaxRoiPool(OnnxOpConverter):
"""Operator converter for MaxRoiPool."""

@@ -2789,7 +2867,7 @@ def _get_convert_map(opset):
"Floor": Renamer("floor"),
"Ceil": Renamer("ceil"),
"Round": Renamer("round"),
"IsInf": Renamer("isinf"),
"IsInf": IsInf.get_converter(opset),
"IsNaN": Renamer("isnan"),
"Sqrt": Renamer("sqrt"),
"Relu": Renamer("relu"),
4 changes: 4 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -1122,6 +1122,8 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

const auto out_shape = data->shape;
const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
ICHECK(mdim) << "ScatterND needs a static shape for the first axis of indices, got "
<< indices->shape;
const size_t kdim = indices->shape.size() - 1;
const size_t ndim = out_shape.size();
ICHECK_LE(size_t(mdim->value), ndim)
@@ -3331,6 +3333,8 @@ bool GatherNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}
const size_t ndim = data->shape.size();
const IntImmNode* mdim = indices->shape[0].as<IntImmNode>();
ICHECK(mdim) << "GatherND needs a static shape for the first axis of indices, got "
<< indices->shape;
const size_t kdim = indices->shape.size() - 1;
ICHECK(size_t(mdim->value) <= ndim) << "GatherND: indices shape does not satisfy.";
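Both new ICHECKs replace a null-pointer dereference with a readable error: the first axis of indices (M below) determines how many data dimensions are consumed, so it must be static for type inference to produce a ranked output. A NumPy reference sketch of relay's gather_nd convention (an illustration of the shape rule under that documented layout, not the TVM kernel):

import numpy as np

def gather_nd_ref(data, indices):
    # indices: (M, Y0, ..., Yk) -> output: (Y0, ..., Yk) + data.shape[M:]
    m = indices.shape[0]  # must be a compile-time constant in Relay
    out = np.empty(indices.shape[1:] + data.shape[m:], dtype=data.dtype)
    for pos in np.ndindex(*indices.shape[1:]):
        coord = tuple(indices[(i,) + pos] for i in range(m))
        out[pos] = data[coord]
    return out

data = np.array([[0, 1], [2, 3]])
idx = np.array([[1, 0], [0, 1]])
print(gather_nd_ref(data, idx))  # [2 1]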

16 changes: 8 additions & 8 deletions tests/python/frontend/onnx/test_forward.py
@@ -4209,12 +4209,8 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
"test_eyelike_populate_off_main_diagonal/",
"test_eyelike_with_dtype/",
"test_eyelike_without_dtype/",
"test_isinf_negative/",
"test_isinf_positive/",
"test_matmulinteger/",
"test_maxpool_2d_dilations/",
"test_maxpool_2d_same_lower/",
"test_maxpool_2d_same_upper/",
"test_maxpool_with_argmax_2d_precomputed_pads/",
"test_maxpool_with_argmax_2d_precomputed_strides/",
"test_maxunpool_export_with_output_shape/",
@@ -4233,7 +4229,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
"test_reversesequence_batch/",
"test_reversesequence_time/",
"test_rnn_seq_length/",
"test_roialign/",
"test_round/",
"test_scan9_sum/",
"test_scan_sum/",
@@ -4252,7 +4247,6 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
"test_tfidfvectorizer_tf_onlybigrams_levelempty/",
"test_tfidfvectorizer_tf_onlybigrams_skip5/",
"test_tfidfvectorizer_tf_uniandbigrams_skip5/",
"test_top_k_smallest/",
"test_unique_not_sorted_without_axis/",
"test_unique_sorted_with_axis/",
"test_unique_sorted_with_axis_3d/",
@@ -4268,6 +4262,12 @@ def test_onnx_nodes(test):
if failure in test:
pytest.skip()
break
atol = 1e-5
rtol = 1e-5
if "roialign" in test:
# for some reason the ONNX test crops the
# roialign results to 4 decimal places
atol = 1e-4
onnx_model = onnx.load(test + "/model.onnx")
inputs = []
outputs = []
@@ -4285,10 +4285,10 @@ def test_onnx_nodes(test):
raise ImportError(str(tensor) + " not labeled as an import or an output")
tvm_val = get_tvm_output_with_vm(onnx_model, inputs, "llvm", tvm.cpu(0))
if len(outputs) == 1:
tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(outputs[0], tvm_val, rtol=rtol, atol=atol)
else:
for output, val in zip(outputs, tvm_val):
tvm.testing.assert_allclose(output, val, rtol=1e-5, atol=1e-5)
tvm.testing.assert_allclose(output, val, rtol=rtol, atol=atol)


def test_wrong_input():
