diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 0c882589e9cb..a6f6390b2110 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -438,18 +438,8 @@ bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, if (param->kernel_size.defined() && param->channels.defined()) { ICHECK_EQ(param->kernel_size.size(), 3); ICHECK_EQ(param->dilation.size(), 3); - Array wshape; - tvm::tir::ExprDeepEqual expr_equal; - - if (expr_equal(param->channels, param->groups) && !expr_equal(param->channels, 1)) { - // infer weight's shape for depthwise convolution - wshape = {{dshape_ncdhw[1], indexdiv(param->groups, dshape_ncdhw[1]), param->kernel_size[0], - param->kernel_size[1], param->kernel_size[2]}}; - } else { - wshape = {{param->channels, indexdiv(dshape_ncdhw[1], param->groups), param->kernel_size[0], - param->kernel_size[1], param->kernel_size[2]}}; - } - + Array wshape({param->channels, indexdiv(dshape_ncdhw[1], param->groups), + param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]}); wshape = trans_kernel_layout.BackwardShape(wshape); channels = param->channels; dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index f54756546470..dd6a54b959cc 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -522,6 +522,13 @@ def test_conv3d_infer_type(): yy = run_infer_type(y) assert yy.checked_type == relay.TensorType((n, d, h, w, 16), "int32") + # Infer with groups + x = relay.var("x", relay.TensorType((1, 16, 224, 224, 224), "float32")) + w = relay.var("w", relay.TensorType((4, 4, 1, 1, 1), "float32")) + y = relay.nn.conv3d(x, w, groups=4, kernel_size=(1, 1, 1), channels=4) + yy = run_infer_type(y) + assert yy.checked_type == relay.TensorType((1, 4, 224, 224, 224), "float32") + @tvm.testing.uses_gpu def test_conv3d_run():