diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py index d99946d19d66..b6723e9b5729 100644 --- a/tests/python/contrib/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -89,7 +89,7 @@ def test_bias_add(): bshape = (2,) rtol = 1e-2 if dtype == "float16" else 1e-5 x = relay.var("x", shape=xshape, dtype=dtype) - bias = relay.var("bias", dtype=dtype) + bias = relay.var("bias", shape=bshape, dtype=dtype) z = relay.nn.bias_add(x, bias) func = relay.Function([x, bias], z)