diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 1fa909e748a0..23b0a5956d8f 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -1107,6 +1107,7 @@ def legalize_conv2d_backward_weight(attrs, inputs, types): padding=attrs.padding, dilation=attrs.strides, groups=in_channel * batch, + channels=attrs.channels, ) # infer shape of backward_weight diff --git a/python/tvm/topi/testing/conv2d_backcward_weight_python.py b/python/tvm/topi/testing/conv2d_backcward_weight_python.py index 36a6b0616053..305fb138795c 100644 --- a/python/tvm/topi/testing/conv2d_backcward_weight_python.py +++ b/python/tvm/topi/testing/conv2d_backcward_weight_python.py @@ -20,7 +20,9 @@ # Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h -def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding): +def conv2d_backward_weight_nchw_python( + dy_np, x_np, kernel_size, stride, padding, groups=1, channels=None +): """Gradient of the conv2d op with respect to weight, in NCHW layout. Parameters @@ -51,17 +53,34 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding R, S = kernel_size pad_h, pad_w = padding stride_h, stride_w = stride - dw = np.zeros((K, C, R, S)).astype(dy_np.dtype) + is_depth_wise = C == K and C == groups + + if is_depth_wise: + channel_mult = channels // groups + dw = np.zeros((K, channel_mult, R, S)).astype(dy_np.dtype) + else: + dw = np.zeros((K, C // groups, R, S)).astype(dy_np.dtype) + channel_mult = 1 for k in range(K): for r in range(R): for s in range(S): - for c in range(C): + for c in range(dw.shape[1]): acc = 0 for n in range(N): for p in range(P): for q in range(Q): - coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s) + if not is_depth_wise: + in_c = c + else: + in_c = k // channel_mult + + coord = ( + n, + in_c, + p * stride_h - pad_h + r, + q * stride_w - pad_w + s, + ) if ( coord[2] < H @@ -76,7 +95,9 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding return dw -def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"): +def conv2d_backward_weight_python( + dy_np, x_np, kernel_size, stride, padding, layout="NCHW", groups=1, channels=None +): """Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout. Parameters @@ -99,6 +120,12 @@ def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, lay layout: string Layout of dy_np and x_np + groups: int + Number of groups for grouped convolution. + + channels : int + Number of output channels of this convolution. + Returns ------- dw_np : np.ndarray @@ -106,7 +133,9 @@ def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, lay [num_filter, filter_height, filter_width, in_channel] for NHWC layout. """ if layout == "NCHW": - return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding) + return conv2d_backward_weight_nchw_python( + dy_np, x_np, kernel_size, stride, padding, groups, channels + ) dw_np_oihw = conv2d_backward_weight_nchw_python( np.transpose(dy_np, [0, 3, 1, 2]), @@ -114,5 +143,7 @@ def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, lay kernel_size, stride, padding, + groups, + channels, ) return np.transpose(dw_np_oihw, [0, 2, 3, 1]) diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 30386bbf4415..4dedd85eb242 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -638,10 +638,23 @@ bool Conv2DBackwardWeightRel(const Array& types, int num_inputs, const Att auto in_channels = dshape_nchw[1]; auto out_channels = grad_shape_nchw[1]; - - Array wshape_oihw( - {out_channels, in_channels, param->kernel_size[0], param->kernel_size[1]}); - + auto in_channels_intimm = in_channels.as(); + auto out_channels_intimm = out_channels.as(); + ICHECK(in_channels_intimm); + ICHECK(out_channels_intimm); + + IndexExpr weight_dim_i; + if (in_channels_intimm->value == out_channels_intimm->value && + in_channels_intimm->value == param->groups) { + // depthwise + ICHECK(param->channels.defined()) << "out_channels attribute not specified for depth wise conv2d."; + weight_dim_i = indexdiv(param->channels, param->groups); + } else { + weight_dim_i = indexdiv(in_channels, param->groups); + } + + Array wshape_oihw{out_channels, weight_dim_i, param->kernel_size[0], + param->kernel_size[1]}; auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw); reporter->Assign(types[2], TensorType(wshape, data->dtype)); return true; diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py index a5fc630f61dc..ff053a1c0ad8 100644 --- a/tests/python/relay/test_op_grad_level2.py +++ b/tests/python/relay/test_op_grad_level2.py @@ -229,17 +229,26 @@ def test_batch_flatten_grad(): verify_batch_flatten_grad((1, 8)) -def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, padding): +def verify_conv2d_backward_weight( + dy_shape, x_shape, kernel_size, stride, padding, groups=1, out_channels=None +): dtype = "float32" dy = relay.var("dy", shape=dy_shape, dtype=dtype) x = relay.var("x", shape=x_shape, dtype=dtype) dw_func = relay.Function( [dy, x], relay.nn.conv2d_backward_weight( - dy, x, strides=stride, padding=padding, kernel_size=kernel_size + dy, + x, + strides=stride, + padding=padding, + kernel_size=kernel_size, + groups=groups, + channels=out_channels, ), ) dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize()) + print(run_infer_type(dw_func_legalized)) for dw, target in [(dw_func_legalized, "llvm"), (dw_func, "cuda -libs=cudnn")]: if "cudnn" in target and not tvm.contrib.cudnn.exists(): @@ -251,16 +260,22 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin dw_np = relay.create_executor(device=dev, target=target).evaluate(dw)(dy_np, x_np).numpy() ref_dw_np = tvm.topi.testing.conv2d_backward_weight_python( - dy_np, x_np, kernel_size, stride, padding + dy_np, x_np, kernel_size, stride, padding, groups=groups, channels=out_channels ) np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4) def test_conv2d_backward_weight(): - verify_conv2d_backward_weight((2, 8, 32, 32), (2, 4, 32, 32), (3, 3), (1, 1), (1, 1)) - verify_conv2d_backward_weight((2, 16, 15, 15), (2, 3, 32, 32), (3, 3), (2, 2), (0, 0)) + # verify_conv2d_backward_weight((2, 8, 32, 32), (2, 4, 32, 32), (3, 3), (1, 1), (1, 1)) + # verify_conv2d_backward_weight((2, 16, 15, 15), (2, 3, 32, 32), (3, 3), (2, 2), (0, 0)) + verify_conv2d_backward_weight((1, 16, 32, 32), (1, 16, 32, 32), (3, 3), (1, 1), (1, 1), groups=16, out_channels=16) + # verify_conv2d_backward_weight( + # (1, 32, 32, 32), (1, 16, 32, 32), (3, 3), (1, 1), (1, 1), groups=8 + # ) + # verify_conv2d_backward_weight((1, 32, 32, 32), (1, 16, 32, 32), (3, 3), (1, 1), (1, 1), groups=16, out_channels=32) if __name__ == "__main__": - pytest.main([__file__]) + # pytest.main([__file__]) + test_conv2d_backward_weight()