From 4220cbebe6c1b6d766712883ef06eda42c8c3b6f Mon Sep 17 00:00:00 2001 From: masahi Date: Tue, 12 Oct 2021 02:46:15 +0900 Subject: [PATCH] [AlterLayout] Strided slice layout transform fix (disallow NCHW4c -> NCHW etc properly) (#9245) * prohibit propagating through packed to unpacked layout * add test --- src/relay/op/tensor/transform.cc | 16 +--- .../python/relay/test_pass_alter_op_layout.py | 74 +++++++++++++------ 2 files changed, 54 insertions(+), 36 deletions(-) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9dc520a65406..1340ddfbb058 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2651,24 +2651,19 @@ InferCorrectLayoutOutput StridedSliceInferCorrectLayout( params->strides = new_strides; layout = new_layout; } - } else { + } else if (old_layout_name.size() < + new_layout_name.size()) { // prohibit transforms such as NCHW4c -> NCHW if (params->axes) { auto axes = params->axes.value(); Array new_axes; - for (size_t i = 0; i < axes.size(); ++i) { auto old_idx = axes[i]; auto new_idx = new_layout.IndexOf(layout[old_idx]); new_axes.push_back(new_idx); const LayoutAxis& axis = layout[old_idx]; - if (!axis.IsPrimal()) { - // original layout that contains splitted axes is not supported - return out_default; - } - + ICHECK(axis.IsPrimal()); auto factor = new_layout.FactorOf(axis); - if (factor == -1) { new_begin.push_back(begin[i]); new_end.push_back(end[i]); @@ -2688,10 +2683,7 @@ InferCorrectLayoutOutput StridedSliceInferCorrectLayout( } else { for (size_t i = 0; i < begin.size(); i++) { const LayoutAxis& axis = layout[i]; - if (!axis.IsPrimal()) { - // original layout that contains splitted axes is not supported - return out_default; - } + ICHECK(axis.IsPrimal()); auto factor = new_layout.FactorOf(axis); if (factor == -1) { new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 3310b6b2ed69..19685b127d86 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -1397,28 +1397,54 @@ def expected(): assert tvm.ir.structural_equal(a, b) +def test_conv2d_strided_slice_packed_to_unpacked(): + """We do not support propagating through packed to unpacked layout""" + x_shape = (1, 1, 1, 1, 4) + w_shape = (9, 1, 3, 3, 4, 4) + + def before(): + x = relay.var("x", shape=x_shape) + weight = relay.var("weight", shape=w_shape) + y = relay.nn.conv2d( + x, + weight, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW4c", + kernel_layout="OIHW4i4o", + ) + y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8]) + return relay.Function([x, weight], y) + + def expected(): + x = relay.var("x", shape=x_shape) + weight = relay.var("weight", shape=w_shape) + x_nchw = relay.layout_transform(x, src_layout="NCHW4c", dst_layout="NCHW") + weight_oihw = relay.layout_transform(weight, src_layout="OIHW4i4o", dst_layout="OIHW") + y = relay.nn.conv2d( + x_nchw, + weight_oihw, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NCHW4c") + y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8]) + return relay.Function([x, weight], y) + + def alter_conv2d(attrs, inputs, tinfos, out_type): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs["data_layout"] = "NCHW" + new_attrs["kernel_layout"] = "OIHW" + return relay.nn.conv2d(data, weight, **new_attrs) + + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): + a = run_opt_pass(before(), transform.AlterOpLayout()) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b) + + if __name__ == "__main__": - test_alter_op() - test_alter_return_none() - test_alter_layout() - test_alter_layout_dual_path() - test_alter_layout_lrn() - test_alter_layout_resnet() - test_alter_layout_broadcast_op() - test_alter_layout_broadcast_scalar_op() - test_alter_layout_scalar() - test_alter_layout_concatenate() - test_alter_layout_nchw_upsamping_op() - test_alter_layout_strided_slice() - test_alter_layout_depthwise_conv2d() - test_alter_layout_prelu() - test_alter_layout_pad() - test_alter_layout_pool() - test_alter_layout_sum() - test_alter_layout_nhwc_arm() - test_alter_layout_nhwc_int8_aarch64() - test_alter_op_with_global_var() - test_alter_op_dense() - test_alter_layout_strided_slice_axes_nhwc() - test_not_inplace_modify() - test_alter_op_dense_packed_data() + pytest.main([__file__])