From 041c094b3646e0f521f5bd2c4f6f6b5b1cff7b97 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sun, 6 Feb 2022 14:19:09 +0900
Subject: [PATCH] depthwise conv2d properly supported for wgrad

---
 python/tvm/contrib/cudnn.py                   | 14 +++++-
 python/tvm/topi/cuda/conv2d.py                |  9 ++--
 .../testing/conv2d_backcward_weight_python.py | 43 ++++++++++++++++---
 src/relay/op/nn/convolution.cc                | 23 ++++++++--
 tests/python/relay/test_op_grad_level2.py     | 20 +++++++--
 5 files changed, 90 insertions(+), 19 deletions(-)

diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py
index c897de74b250c..bfea1ff2e06ef 100644
--- a/python/tvm/contrib/cudnn.py
+++ b/python/tvm/contrib/cudnn.py
@@ -826,10 +826,20 @@ def conv_backward_filter(
         x.shape[0], tvm.tir.expr.IntImm
     ), "Dynamic batch is not supported for cudnn conv2d backward filter yet."
 
+    ic_ind = 1 if tensor_format == 0 else 3
+
+    if groups > 1:
+        assert (
+            x_shape[ic_ind] == dy.shape[ic_ind] and x_shape[ic_ind] == groups
+        ), "Only depthwise wgrad is supported for groups > 1."
+        ic = 1
+    else:
+        ic = x_shape[ic_ind]
+
     if tensor_format == 0:
-        dw_shape = [dy.shape[1], x_shape[1], filter_h, filter_w]
+        dw_shape = [dy.shape[1], ic, filter_h, filter_w]
     else:
-        dw_shape = [dy.shape[3], filter_h, filter_w, x_shape[3]]
+        dw_shape = [dy.shape[3], filter_h, filter_w, ic]
 
     algo = conv_backward_filter_find_algo(
         tensor_format,
diff --git a/python/tvm/topi/cuda/conv2d.py b/python/tvm/topi/cuda/conv2d.py
index 5a5d59a6e2182..bce032040dcd9 100644
--- a/python/tvm/topi/cuda/conv2d.py
+++ b/python/tvm/topi/cuda/conv2d.py
@@ -130,9 +130,12 @@ def conv2d_backward_weight_cudnn(
 ):
     """Compute conv2d wgrad using the cuDNN library."""
     assert layout in ["NCHW", "NHWC"]
-    # cuDNN does not seem to support other combination.
-    assert output_dtype == "float16", "Only supports fp16 output for cuDNN wgrad."
-    conv_dtype = "float32"
+
+    if dy.dtype == "float16":
+        # cuDNN does not seem to support other combinations.
+        assert output_dtype == "float16", "Only fp16 output is supported for cuDNN fp16 wgrad."
+
+    conv_dtype = "float32"  # Accumulation is always fp32
 
     return cudnn.conv_backward_filter(
         dy,
         x,
diff --git a/python/tvm/topi/testing/conv2d_backcward_weight_python.py b/python/tvm/topi/testing/conv2d_backcward_weight_python.py
index 36a6b06160531..3e4a2d0ca36f4 100644
--- a/python/tvm/topi/testing/conv2d_backcward_weight_python.py
+++ b/python/tvm/topi/testing/conv2d_backcward_weight_python.py
@@ -20,7 +20,9 @@
 # Reference: cutlass/tools/util/include/cutlass/util/reference/host/convolution.h
 
 
-def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding):
+def conv2d_backward_weight_nchw_python(
+    dy_np, x_np, kernel_size, stride, padding, groups=1, channels=None
+):
     """Gradient of the conv2d op with respect to weight, in NCHW layout.
 
     Parameters
@@ -51,17 +53,34 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
     R, S = kernel_size
     pad_h, pad_w = padding
     stride_h, stride_w = stride
-    dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)
+    is_depth_wise = C == K and C == groups
+
+    if is_depth_wise:
+        assert channels == groups, "Only channel_mult == 1 is supported for now."
+        dw = np.zeros((K, 1, R, S)).astype(dy_np.dtype)
+    else:
+        assert groups == 1, "General grouped conv2d is not supported for now."
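+        # Regular conv2d: the weight gradient keeps the full input-channel
+        # extent C, one filter slice per (k, c) pair.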
+        dw = np.zeros((K, C, R, S)).astype(dy_np.dtype)
 
     for k in range(K):
         for r in range(R):
             for s in range(S):
-                for c in range(C):
+                for c in range(dw.shape[1]):
                     acc = 0
                     for n in range(N):
                         for p in range(P):
                             for q in range(Q):
-                                coord = (n, c, p * stride_h - pad_h + r, q * stride_w - pad_w + s)
+                                if is_depth_wise:
+                                    in_c = k
+                                else:
+                                    in_c = c
+
+                                coord = (
+                                    n,
+                                    in_c,
+                                    p * stride_h - pad_h + r,
+                                    q * stride_w - pad_w + s,
+                                )
 
                                 if (
                                     coord[2] < H
@@ -76,7 +95,9 @@ def conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding
     return dw
 
 
-def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, layout="NCHW"):
+def conv2d_backward_weight_python(
+    dy_np, x_np, kernel_size, stride, padding, layout="NCHW", groups=1, channels=None
+):
     """Gradient of the conv2d op with respect to weight, in NCHW or NHWC layout.
 
     Parameters
@@ -99,6 +120,12 @@ def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, lay
     layout: string
         Layout of dy_np and x_np
 
+    groups : int
+        Number of groups for grouped convolution.
+
+    channels : int
+        Number of output channels of this convolution.
+
     Returns
     -------
     dw_np : np.ndarray
@@ -106,13 +133,17 @@ def conv2d_backward_weight_python(dy_np, x_np, kernel_size, stride, padding, lay
         [num_filter, filter_height, filter_width, in_channel] for NHWC layout.
     """
    if layout == "NCHW":
-        return conv2d_backward_weight_nchw_python(dy_np, x_np, kernel_size, stride, padding)
+        return conv2d_backward_weight_nchw_python(
+            dy_np, x_np, kernel_size, stride, padding, groups, channels
+        )
 
     dw_np_oihw = conv2d_backward_weight_nchw_python(
         np.transpose(dy_np, [0, 3, 1, 2]),
         np.transpose(x_np, [0, 3, 1, 2]),
         kernel_size,
         stride,
         padding,
+        groups,
+        channels,
     )
     return np.transpose(dw_np_oihw, [0, 2, 3, 1])
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 30386bbf4415a..3ec96713b2a6d 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -639,11 +639,26 @@ bool Conv2DBackwardWeightRel(const Array<Type>& types, int num_inputs, const Att
   auto in_channels = dshape_nchw[1];
   auto out_channels = grad_shape_nchw[1];
 
-  Array<IndexExpr> wshape_oihw(
-      {out_channels, in_channels, param->kernel_size[0], param->kernel_size[1]});
-
+  auto in_channels_intimm = in_channels.as<IntImmNode>();
+  auto out_channels_intimm = out_channels.as<IntImmNode>();
+  ICHECK(in_channels_intimm);
+  ICHECK(out_channels_intimm);
+
+  IndexExpr weight_dim_i;
+  if (in_channels_intimm->value == out_channels_intimm->value &&
+      in_channels_intimm->value == param->groups) {
+    // depthwise
+    ICHECK(param->channels.defined()) << "out_channels attribute not specified for depthwise conv2d.";
+    weight_dim_i = indexdiv(param->channels, param->groups);
+  } else {
+    weight_dim_i = indexdiv(in_channels, param->groups);
+  }
+
+  Array<IndexExpr> wshape_oihw{out_channels, weight_dim_i, param->kernel_size[0], param->kernel_size[1]};
   auto wshape = trans_kernel_layout.BackwardShape(wshape_oihw);
-  reporter->Assign(types[2], TensorType(wshape, data->dtype));
+
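+  // A default-constructed out_dtype means "not specified"; in that case the
+  // weight gradient inherits the dtype of the incoming gradient tensor.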
+  const auto dw_dtype = param->out_dtype == DataType() ? grad->dtype : param->out_dtype;
+  reporter->Assign(types[2], TensorType(wshape, dw_dtype));
   return true;
 }
 
diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py
index a5fc630f61dc8..55189d6525720 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -229,16 +229,24 @@ def test_batch_flatten_grad():
     verify_batch_flatten_grad((1, 8))
 
 
-def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, padding):
+def verify_conv2d_backward_weight(
+    dy_shape, x_shape, kernel_size, stride, padding, groups=1, out_channels=None
+):
     dtype = "float32"
     dy = relay.var("dy", shape=dy_shape, dtype=dtype)
     x = relay.var("x", shape=x_shape, dtype=dtype)
     dw_func = relay.Function(
         [dy, x],
         relay.nn.conv2d_backward_weight(
-            dy, x, strides=stride, padding=padding, kernel_size=kernel_size
+            dy,
+            x,
+            strides=stride,
+            padding=padding,
+            kernel_size=kernel_size,
+            groups=groups,
+            channels=out_channels,
+            out_dtype=dtype,
         ),
     )
     dw_func_legalized = run_opt_pass(dw_func, relay.transform.Legalize())
 
     for dw, target in [(dw_func_legalized, "llvm"), (dw_func, "cuda -libs=cudnn")]:
@@ -251,7 +259,7 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin
         dw_np = relay.create_executor(device=dev, target=target).evaluate(dw)(dy_np, x_np).numpy()
 
         ref_dw_np = tvm.topi.testing.conv2d_backward_weight_python(
-            dy_np, x_np, kernel_size, stride, padding
+            dy_np, x_np, kernel_size, stride, padding, groups=groups, channels=out_channels
         )
 
         np.testing.assert_allclose(dw_np, ref_dw_np, rtol=1e-4, atol=1e-4)
 
 
@@ -260,7 +268,11 @@ def verify_conv2d_backward_weight(dy_shape, x_shape, kernel_size, stride, paddin
 def test_conv2d_backward_weight():
     verify_conv2d_backward_weight((2, 8, 32, 32), (2, 4, 32, 32), (3, 3), (1, 1), (1, 1))
     verify_conv2d_backward_weight((2, 16, 15, 15), (2, 3, 32, 32), (3, 3), (2, 2), (0, 0))
+    # Depthwise: in_channels == out_channels == groups == 16.
+    verify_conv2d_backward_weight(
+        (1, 16, 32, 32), (1, 16, 32, 32), (3, 3), (1, 1), (1, 1), groups=16, out_channels=16
+    )
 
 
 if __name__ == "__main__":
     pytest.main([__file__])
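Below is a usage sketch (illustrative, not part of the patch): it exercises the new depthwise path of the numpy reference added above, assuming this patch is applied. The shapes and random inputs are assumptions made only for the example.

    import numpy as np

    import tvm.topi.testing

    # Depthwise case: in_channels == out_channels == groups == 16 (channel_mult == 1).
    x_np = np.random.randn(1, 16, 32, 32).astype("float32")  # NCHW input
    dy_np = np.random.randn(1, 16, 30, 30).astype("float32")  # grad of a 3x3, stride 1, pad 0 conv

    dw_np = tvm.topi.testing.conv2d_backward_weight_python(
        dy_np, x_np, (3, 3), (1, 1), (0, 0), layout="NCHW", groups=16, channels=16
    )
    # Depthwise wgrad carries a singleton input-channel dimension in OIHW.
    assert dw_np.shape == (16, 1, 3, 3)

The singleton second dimension mirrors the OIHW filter shape that the relay type relation above computes as indexdiv(channels, groups) for the depthwise case.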