Skip to content

Commit

Permalink
[ONNX]fix datatype on Reciprocal op (apache#7519)
Browse files Browse the repository at this point in the history
* fix datatype on Reciprocal op

* clean up test case
  • Loading branch information
Matthew Brookhart authored and trevor-m committed Mar 2, 2021
1 parent 0a9c125 commit d65ecaa
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,8 @@ class Reciprocal(OnnxOpConverter):

@classmethod
def _impl_v1(cls, inputs, attr, params):
return _expr.const(1.0) / inputs[0]
dtype = infer_type(inputs[0]).checked_type.dtype
return _expr.const(1.0, dtype=dtype) / inputs[0]


class Flatten(OnnxOpConverter):
Expand Down
11 changes: 7 additions & 4 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1830,23 +1830,26 @@ def test_unary_ops():
dtype = "float32"
out_shape = in_shape

def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5):
def verify_unary_ops(op, x, rtol=1e-5, atol=1e-5, dtype="float32"):
x = x.astype(dtype)
ONNX_DTYPE = mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
z = helper.make_node(op, ["in1"], ["out"])
graph = helper.make_graph(
[z],
"_test",
inputs=[
helper.make_tensor_value_info("in1", TensorProto.FLOAT, list(in_shape)),
helper.make_tensor_value_info("in1", ONNX_DTYPE, list(in_shape)),
],
outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, list(out_shape))],
outputs=[helper.make_tensor_value_info("out", ONNX_DTYPE, list(out_shape))],
)
model = helper.make_model(graph, producer_name="_test")
verify_with_ort_with_inputs(model, [x], rtol=rtol, atol=atol)

x = np.random.uniform(size=in_shape).astype(dtype)
x = np.random.uniform(size=in_shape)
verify_unary_ops("Neg", x)
verify_unary_ops("Abs", x)
verify_unary_ops("Reciprocal", x)
verify_unary_ops("Reciprocal", x, dtype="float16")
verify_unary_ops("Sqrt", x)
verify_unary_ops("Relu", x)
verify_unary_ops("Exp", x)
Expand Down

0 comments on commit d65ecaa

Please sign in to comment.