diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index 3ebb7eef848be..42404ab2cabf3 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -175,7 +175,13 @@ def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mod data = relay.var("data", shape=dshape, dtype=dtype) weight = relay.var("weight", shape=wshape, dtype=dtype) conv = relay.nn.conv2d( - data, weight, strides=strides, padding=padding, dilation=dilation, groups=groups + data, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + out_dtype=dtype, ) fwd_func = relay.Function([data, weight], conv) check_grad(fwd_func, mode=mode)