Skip to content

Commit

Permalink
[AlterLayout] Strided slice layout transform fix (disallow NCHW4c -> …
Browse files Browse the repository at this point in the history
…NCHW etc properly) (apache#9245)

* prohibit propagating through packed to unpacked layout

* add test
  • Loading branch information
masahi committed Oct 14, 2021
1 parent f891e26 commit 586fa42
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 36 deletions.
16 changes: 4 additions & 12 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2599,24 +2599,19 @@ InferCorrectLayoutOutput StridedSliceInferCorrectLayout(
params->strides = new_strides;
layout = new_layout;
}
} else {
} else if (old_layout_name.size() <
new_layout_name.size()) { // prohibit transforms such as NCHW4c -> NCHW
if (params->axes) {
auto axes = params->axes.value();
Array<Integer> new_axes;

for (size_t i = 0; i < axes.size(); ++i) {
auto old_idx = axes[i];
auto new_idx = new_layout.IndexOf(layout[old_idx]);
new_axes.push_back(new_idx);

const LayoutAxis& axis = layout[old_idx];
if (!axis.IsPrimal()) {
// original layout that contains splitted axes is not supported
return out_default;
}

ICHECK(axis.IsPrimal());
auto factor = new_layout.FactorOf(axis);

if (factor == -1) {
new_begin.push_back(begin[i]);
new_end.push_back(end[i]);
Expand All @@ -2636,10 +2631,7 @@ InferCorrectLayoutOutput StridedSliceInferCorrectLayout(
} else {
for (size_t i = 0; i < begin.size(); i++) {
const LayoutAxis& axis = layout[i];
if (!axis.IsPrimal()) {
// original layout that contains splitted axes is not supported
return out_default;
}
ICHECK(axis.IsPrimal());
auto factor = new_layout.FactorOf(axis);
if (factor == -1) {
new_begin.push_back(IntImm(begin[i]->dtype, begin[i]));
Expand Down
74 changes: 50 additions & 24 deletions tests/python/relay/test_pass_alter_op_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -1397,28 +1397,54 @@ def expected():
assert tvm.ir.structural_equal(a, b)


def test_conv2d_strided_slice_packed_to_unpacked():
"""We do not support propagating through packed to unpacked layout"""
x_shape = (1, 1, 1, 1, 4)
w_shape = (9, 1, 3, 3, 4, 4)

def before():
x = relay.var("x", shape=x_shape)
weight = relay.var("weight", shape=w_shape)
y = relay.nn.conv2d(
x,
weight,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW4c",
kernel_layout="OIHW4i4o",
)
y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8])
return relay.Function([x, weight], y)

def expected():
x = relay.var("x", shape=x_shape)
weight = relay.var("weight", shape=w_shape)
x_nchw = relay.layout_transform(x, src_layout="NCHW4c", dst_layout="NCHW")
weight_oihw = relay.layout_transform(weight, src_layout="OIHW4i4o", dst_layout="OIHW")
y = relay.nn.conv2d(
x_nchw,
weight_oihw,
kernel_size=(3, 3),
padding=(1, 1),
data_layout="NCHW",
kernel_layout="OIHW",
)
y = relay.layout_transform(y, src_layout="NCHW", dst_layout="NCHW4c")
y = relay.strided_slice(y, begin=[0, 0], end=[1, -1], strides=[1, 8])
return relay.Function([x, weight], y)

def alter_conv2d(attrs, inputs, tinfos, out_type):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs["data_layout"] = "NCHW"
new_attrs["kernel_layout"] = "OIHW"
return relay.nn.conv2d(data, weight, **new_attrs)

with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
a = run_opt_pass(before(), transform.AlterOpLayout())
b = run_opt_pass(expected(), transform.InferType())
assert tvm.ir.structural_equal(a, b)


if __name__ == "__main__":
test_alter_op()
test_alter_return_none()
test_alter_layout()
test_alter_layout_dual_path()
test_alter_layout_lrn()
test_alter_layout_resnet()
test_alter_layout_broadcast_op()
test_alter_layout_broadcast_scalar_op()
test_alter_layout_scalar()
test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op()
test_alter_layout_strided_slice()
test_alter_layout_depthwise_conv2d()
test_alter_layout_prelu()
test_alter_layout_pad()
test_alter_layout_pool()
test_alter_layout_sum()
test_alter_layout_nhwc_arm()
test_alter_layout_nhwc_int8_aarch64()
test_alter_op_with_global_var()
test_alter_op_dense()
test_alter_layout_strided_slice_axes_nhwc()
test_not_inplace_modify()
test_alter_op_dense_packed_data()
pytest.main([__file__])

0 comments on commit 586fa42

Please sign in to comment.