diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index fc80c9ed6171..ea8b826588b1 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -101,18 +101,6 @@ def schedule_lrn_cuda(attrs, outs, target):
         return topi.cuda.schedule_lrn(outs)
 
 
-def naive_schedule(_, outs, target):
-    """Return the naive default schedule"""
-    if "gpu" in target.keys:
-        # For GPU, we at least need thread binding to make a valid schedule.
-        # So the naive schedule cannot be compiled.
-        raise RuntimeError(
-            "Cannot compile for GPU targets if no tuned schedule is found. "
-            "Please see the warning messages above for more information about the failed workloads."
-        )
-    return tvm.te.create_schedule(outs[-1].op)
-
-
 @conv2d_strategy.register(["cuda", "gpu"])
 def conv2d_strategy_cuda(attrs, inputs, out_type, target):
     """conv2d cuda strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 15c7f2f7fa17..e888eb4d037b 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -27,6 +27,18 @@
 logger = logging.getLogger("strategy")
 
 
+def naive_schedule(_, outs, target):
+    """Return the naive default schedule"""
+    if "gpu" in target.keys:
+        # For GPU, we at least need thread binding to make a valid schedule.
+        # So the naive schedule cannot be compiled.
+        raise RuntimeError(
+            "Cannot compile for GPU targets if no tuned schedule is found. "
+            "Please see the warning messages above for more information about the failed workloads."
+        )
+    return te.create_schedule(outs[-1].op)
+
+
 def wrap_topi_schedule(topi_schedule):
     """Wrap TOPI schedule which doesn't use attrs"""
 
@@ -357,7 +369,6 @@ def wrap_compute_deformable_conv2d(topi_compute):
     """wrap deformable_conv2d topi compute"""
 
     def _compute_deformable_conv2d(attrs, inputs, out_dtype):
-        assert attrs.data_layout == "NCHW"
         padding = get_const_tuple(attrs.padding)
         strides = get_const_tuple(attrs.strides)
         dilation = get_const_tuple(attrs.dilation)
@@ -384,15 +395,24 @@ def _compute_deformable_conv2d(attrs, inputs, out_dtype):
 @override_native_generic_func("deformable_conv2d_strategy")
 def deformable_conv2d_strategy(attrs, inputs, out_type, target):
     """deformable_conv2d generic strategy"""
-    logger.warning("deformable_conv2d is not optimized for this platform.")
     layout = attrs.data_layout
-    assert layout == "NCHW"
     strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nchw),
-        wrap_topi_schedule(topi.generic.schedule_deformable_conv2d_nchw),
-        name="deformable_conv2d.generic",
-    )
+
+    if layout == "NCHW":
+        strategy.add_implementation(
+            wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nchw),
+            wrap_topi_schedule(topi.generic.schedule_deformable_conv2d_nchw),
+            name="deformable_conv2d_nchw.generic",
+        )
+    elif layout == "NHWC":
+        # This implementation should never be picked by autotvm
+        strategy.add_implementation(
+            wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nhwc),
+            wrap_topi_schedule(naive_schedule),
+            name="deformable_conv2d_nhwc.generic",
+        )
+    else:
+        raise RuntimeError("Layout %s is not supported in deformable conv2d" % layout)
     return strategy
 
 
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index 13e87a54b9d8..50e07a604764 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -1106,9 +1106,43 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
   const auto* weight = types[2].as<TensorTypeNode>();
 
   ICHECK(data);
+  static const Layout kNCHW("NCHW");
+  static const Layout kOIHW("OIHW");
+
   auto* param = attrs.as<DeformableConv2DAttrs>();
-  ICHECK_EQ(param->data_layout, "NCHW") << "data layout not supported.";
-  ICHECK_EQ(param->kernel_layout, "OIHW") << "kernel_layout not supported.";
+  ICHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->kernel_layout);
+
+  const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
+  if (!trans_in_layout.defined()) {
+    reporter->GetDiagCtx().Emit(
+        Diagnostic::Error(reporter->GetSpan())
+        << "deformable_conv2d only support input layouts that are convertible from NCHW."
+        << " The provided layout is: " << in_layout);
+    return false;
+  }
+
+  const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
+  if (!trans_kernel_layout.defined()) {
+    reporter->GetDiagCtx().Emit(
+        Diagnostic::Error(reporter->GetSpan())
+        << "deformable_conv2d only support kernel layouts that are convertible from OIHW."
+        << " The provided layout is: " << kernel_layout);
+    return false;
+  }
+
+  Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+  const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
+  if (!trans_out_layout.defined()) {
+    reporter->GetDiagCtx().Emit(
+        Diagnostic::Error(reporter->GetSpan())
+        << "deformable_conv2d only support output layouts that are convertible from NCHW."
+        << "The provided layout is: " << out_layout);
+    return false;
+  }
+
+  Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
 
   IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x;
 
@@ -1116,8 +1150,10 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
   if (param->kernel_size.defined() && param->channels.defined()) {
     ICHECK_EQ(param->kernel_size.size(), 2);
     ICHECK_EQ(param->dilation.size(), 2);
-    Array<IndexExpr> wshape({param->channels, indexdiv(data->shape[1], param->groups),
+    Array<IndexExpr> wshape({param->channels, indexdiv(dshape_nchw[1], param->groups),
                              param->kernel_size[0], param->kernel_size[1]});
+
+    wshape = trans_kernel_layout.BackwardShape(wshape);
     channels = param->channels;
     ksize_y = param->kernel_size[0];
     ksize_x = param->kernel_size[1];
@@ -1128,7 +1164,8 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
   } else {
     // use weight to infer the conv shape.
     if (weight == nullptr) return false;
-    auto wshape = weight->shape;
+    auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
+
     if (param->kernel_size.defined()) {
       ICHECK_EQ(param->kernel_size.size(), 2);
       // check the size
@@ -1142,8 +1179,8 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
           << "DeformableConv2D: shape of weight is inconsistent with channels, "
          << " channels=" << param->channels << " wshape=" << wshape;
     }
-    if (!data->shape[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
-      ICHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
+    if (!dshape_nchw[1].as<tir::AnyNode>() && !wshape[1].as<tir::AnyNode>()) {
+      ICHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[1]));
     }
     channels = wshape[0];
     ksize_y = wshape[2];
@@ -1152,22 +1189,24 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
     dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
   }
   // dilation
-  Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
+  Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
 
   IndexExpr pad_h, pad_w;
   GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
-  oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
-  oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
+  oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y, param->strides[0]) + 1);
+  oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x, param->strides[1]) + 1);
   DataType out_dtype = param->out_dtype;
 
   // infer offset shape
   Array<IndexExpr> offset_shape(
-      {data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]});
+      {dshape_nchw[0], 2 * ksize_y * ksize_x * param->deformable_groups, oshape[2], oshape[3]});
+  offset_shape = trans_in_layout.BackwardShape(offset_shape);
   reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
   if (out_dtype.bits() == 0) {
     out_dtype = data->dtype;
   }
 
+  oshape = trans_out_layout.BackwardShape(oshape);
   reporter->Assign(types[3], TensorType(oshape, out_dtype));
   return true;
 }
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index f114957f3cab..b3b65539cf81 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -787,12 +787,34 @@ def verify_yolo_reorg(shape, stride):
 
 @tvm.testing.uses_gpu
 def test_deformable_conv2d():
-    def test_infer_type(batch, in_channel, size, out_channel, deformable_groups, groups):
-        data_shape = (batch, in_channel, size, size)
+    def test_infer_type(batch, in_channel, size, out_channel, deformable_groups, groups, layout):
+        kernel_size = (3, 3)
+        if layout == "NCHW":
+            kernel_layout = "OIHW"
+            data_shape = (batch, in_channel, size, size)
+            weight_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1])
+            out_shape = (batch, out_channel, size, size)
+            offset_shape = (
+                batch,
+                2 * kernel_size[0] * kernel_size[1] * deformable_groups,
+                out_shape[2],
+                out_shape[3],
+            )
+        else:
+            kernel_layout = "HWIO"
+            data_shape = (batch, size, size, in_channel)
+            weight_shape = (kernel_size[0], kernel_size[1], in_channel // groups, out_channel)
+            out_shape = (batch, size, size, out_channel)
+            offset_shape = (
+                batch,
+                out_shape[1],
+                out_shape[2],
+                2 * kernel_size[0] * kernel_size[1] * deformable_groups,
+            )
+
         data = relay.var("data", shape=data_shape)
         offset = relay.var("offset")
         kernel = relay.var("kernel")
-        kernel_size = (3, 3)
         y = relay.nn.deformable_conv2d(
             data,
             offset,
@@ -800,26 +822,22 @@ def test_infer_type(batch, in_channel, size, out_channel, deformable_groups, gro
             strides=(1, 1),
             padding=(1, 1),
             dilation=(1, 1),
+            data_layout=layout,
+            kernel_layout=kernel_layout,
             kernel_size=kernel_size,
             deformable_groups=deformable_groups,
             groups=groups,
             channels=out_channel,
         )
-        weight_shape = (out_channel, in_channel // groups, kernel_size[0], kernel_size[1])
-        out_shape = (batch, out_channel, size, size)
-        offset_shape = (
-            batch,
-            2 * kernel_size[0] * kernel_size[1] * deformable_groups,
-            out_shape[2],
-            out_shape[3],
-        )
         yy = run_infer_type(y)
-        assert yy.checked_type == relay.TensorType(out_shape)
+        assert yy.checked_type == relay.TensorType(out_shape), yy.checked_type
         assert yy.args[1].checked_type == relay.TensorType(offset_shape), yy.args[1].checked_type
-        assert yy.args[2].checked_type == relay.TensorType(weight_shape)
+        assert yy.args[2].checked_type == relay.TensorType(weight_shape), yy.args[2].checked_type
 
-    test_infer_type(1, 4, 16, 4, 4, 1)
-    test_infer_type(2, 4, 16, 4, 1, 2)
+    test_infer_type(1, 4, 16, 4, 4, 1, "NCHW")
+    test_infer_type(2, 4, 16, 4, 1, 2, "NCHW")
+    test_infer_type(1, 4, 16, 4, 4, 1, "NHWC")
+    test_infer_type(2, 4, 16, 4, 1, 2, "NHWC")
 
     def test_run(batch, in_channel, size, out_channel, deformable_groups, groups):
         kernel_size = (3, 3)
@@ -1216,4 +1234,3 @@ def verify_batch_to_space_nd(dshape, block_shape, crops):
     test_affine_grid()
     test_grid_sample()
     test_space_to_batch_nd()
-    test_batch_to_space_nd()