Skip to content

Commit

Permalink
[Bugfix] Shape inference of weight for grouped nn.conv3d (#11681)
Browse files Browse the repository at this point in the history
* Fix `nn.conv3d` weight shape inference.

* Add test for conv3d type inference with groups.
  • Loading branch information
wzh99 authored Jun 12, 2022
1 parent 8f6543e commit 8341e33
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
14 changes: 2 additions & 12 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,18 +438,8 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (param->kernel_size.defined() && param->channels.defined()) {
ICHECK_EQ(param->kernel_size.size(), 3);
ICHECK_EQ(param->dilation.size(), 3);
Array<IndexExpr> wshape;
tvm::tir::ExprDeepEqual expr_equal;

if (expr_equal(param->channels, param->groups) && !expr_equal(param->channels, 1)) {
// infer weight's shape for depthwise convolution
wshape = {{dshape_ncdhw[1], indexdiv(param->groups, dshape_ncdhw[1]), param->kernel_size[0],
param->kernel_size[1], param->kernel_size[2]}};
} else {
wshape = {{param->channels, indexdiv(dshape_ncdhw[1], param->groups), param->kernel_size[0],
param->kernel_size[1], param->kernel_size[2]}};
}

Array<IndexExpr> wshape({param->channels, indexdiv(dshape_ncdhw[1], param->groups),
param->kernel_size[0], param->kernel_size[1], param->kernel_size[2]});
wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels;
dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
Expand Down
7 changes: 7 additions & 0 deletions tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,13 @@ def test_conv3d_infer_type():
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((n, d, h, w, 16), "int32")

# Infer with groups
x = relay.var("x", relay.TensorType((1, 16, 224, 224, 224), "float32"))
w = relay.var("w", relay.TensorType((4, 4, 1, 1, 1), "float32"))
y = relay.nn.conv3d(x, w, groups=4, kernel_size=(1, 1, 1), channels=4)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType((1, 4, 224, 224, 224), "float32")


@tvm.testing.uses_gpu
def test_conv3d_run():
Expand Down

0 comments on commit 8341e33

Please sign in to comment.