[Relay] Improve reduction op layout propagation for packed input (apache#9253)

* wip

* fixed packed dim size logic

* fixed test

* formatting

* fix compile warning
masahi authored and ylc committed Jan 13, 2022
1 parent 95e7a8f commit 345c9b9
Showing 2 changed files with 42 additions and 22 deletions.
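
At a high level, a reduction whose input has been put into a packed (blocked) layout such as NCHW16c no longer has to be transformed back to NCHW first: the reduce axes are remapped to cover both the primal axis and its packed inner axis. A minimal Relay sketch of the before/after graphs, mirroring the updated test expectations below (the 32-channel shapes are illustrative only):

from tvm import relay

x = relay.var("x", shape=(1, 2, 8, 8, 16))  # already in NCHW16c (32 channels = 2 blocks of 16c)

# Before this patch: unpack to NCHW, then reduce over the channel axis.
unpacked = relay.layout_transform(x, "NCHW16c", "NCHW")
before = relay.sum(unpacked, axis=[1], keepdims=True)

# After this patch: reduce directly in NCHW16c over the primal C axis (1) and the packed
# 16c axis (4); the inner dim collapses to 1, so the result comes out in NCHW1c.
reduced = relay.sum(x, axis=[1, 4], keepdims=True)
after = relay.layout_transform(reduced, "NCHW1c", "NCHW")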
42 changes: 30 additions & 12 deletions src/relay/op/tensor/reduce.cc
@@ -149,23 +149,41 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
   tvm::Array<tvm::Integer> new_r_axes;
   std::string inferred_in_string = "";
   std::string inferred_out_string = "";
-  int axis_index = 0;
-  for (auto iter_var : layout->axes) {
-    const auto& layout_axis = LayoutAxis::Get(iter_var);
+  auto push_new_axis = [&](const std::string& layout_dim, int axis) {
+    if ((old_r_dims.count(layout_dim) && !params->exclude) ||
+        (!old_r_dims.count(layout_dim) && params->exclude)) {
+      new_r_axes.push_back(tvm::Integer(axis));
+      return true;
+    }
+    return false;
+  };
+  for (size_t axis_index = 0; axis_index < layout->axes.size(); ++axis_index) {
+    const auto& layout_axis = LayoutAxis::Get(layout->axes[axis_index]);
     const std::string& layout_dim = layout_axis.name();
-    // Collect only the primal axis.
     if (layout_axis.IsPrimal()) {
-      if (old_r_dims.count(layout_dim) && !params->exclude) {
-        new_r_axes.push_back(tvm::Integer(axis_index));
-      }
-      if (!old_r_dims.count(layout_dim) && params->exclude) {
-        new_r_axes.push_back(tvm::Integer(axis_index));
-      }
+      push_new_axis(layout_dim, axis_index);
+      inferred_in_string += layout_dim;
       if (!old_r_dims.count(layout_dim) || params->keepdims) {
         inferred_out_string += layout_dim;
       }
-      inferred_in_string += layout_dim;
-      axis_index++;
+    } else {
+      // For example, if the original layout is NCHW, the new layout is NCHW8c, and the original
+      // reduce axes is [1], the new reduce axes become [1, 4].
+      auto primal_dim = layout_axis.ToPrimal().name();
+      auto packed_dim = std::to_string(layout.FactorOf(layout_axis)) + layout_dim;
+      inferred_in_string += packed_dim;
+      if (push_new_axis(primal_dim, axis_index)) {
+        if (params->exclude) {
+          // The primal axis is not reduced, so keep the input packed dim.
+          inferred_out_string += packed_dim;
+        } else {
+          // If the primal axis is part of reduce axes in the original layout, the inner dim
+          // becomes 1 after reduction.
+          inferred_out_string += "1" + layout_dim;
+        }
+      } else {
+        inferred_out_string += packed_dim;
+      }
     }
   }

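To make the remapping rule concrete, here is a standalone Python sketch of the same logic (not TVM code; the function name and the naive layout tokenizer are made up for illustration), applied to the NCHW → NCHW8c example from the comment above:

import re

def remap_reduce_axes(new_layout, old_r_dims, exclude=False, keepdims=False):
    # Tokenize a blocked layout string, e.g. "NCHW8c" -> ["N", "C", "H", "W", "8c"].
    tokens = re.findall(r"\d*[A-Za-z]", new_layout)
    new_r_axes, in_str, out_str = [], "", ""

    def push_new_axis(dim, axis):
        # A dim is reduced if it is in the original reduce dims (flipped when exclude=True).
        if (dim in old_r_dims) != exclude:
            new_r_axes.append(axis)
            return True
        return False

    for axis, tok in enumerate(tokens):
        if tok.isalpha():  # primal axis, e.g. "C"
            push_new_axis(tok, axis)
            in_str += tok
            if tok not in old_r_dims or keepdims:
                out_str += tok
        else:  # packed axis, e.g. "8c", belonging to primal "C"
            primal = tok[-1].upper()
            in_str += tok
            if push_new_axis(primal, axis) and not exclude:
                out_str += "1" + tok[-1]  # reduced primal axis: inner dim becomes 1
            else:
                out_str += tok  # otherwise the packed dim is kept as-is
    return new_r_axes, in_str, out_str

print(remap_reduce_axes("NCHW8c", {"C"}, keepdims=True))
# -> ([1, 4], 'NCHW8c', 'NCHW1c')

Note that with exclude=True the packed dim always survives in the inferred output layout; only a primal axis that is actually reduced collapses its inner dim to 1.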
22 changes: 12 additions & 10 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -507,20 +507,23 @@ def expected():
         bias = relay.layout_transform(bias, src_layout="NHWC", dst_layout="NCHW")
         bias = relay.layout_transform(bias, src_layout="NCHW", dst_layout="NCHW16c")
         add = relay.add(y, bias)
-        y = relay.layout_transform(add, src_layout="NCHW16c", dst_layout="NCHW")
-        mean = relay.mean(y, axis=1, exclude=True)
-        var = relay.variance(y, axis=1, exclude=True)
+        mean = relay.mean(add, axis=[1, 4], exclude=True)
+        var = relay.variance(add, axis=[1, 4], exclude=True)
         denom = relay.const(1.0) / relay.sqrt(var + relay.const(1e-05))
         gamma = relay.var("gamma", shape=(16,))
-        denom = denom * gamma
+        denom_c16c = denom * relay.layout_transform(gamma, src_layout="C", dst_layout="C16c")
+        denom = relay.layout_transform(denom_c16c, src_layout="C16c", dst_layout="C")
         denom_expand1 = relay.expand_dims(denom, axis=1, num_newaxis=2)
         denom_expand2 = relay.expand_dims(denom_expand1, axis=0)
         denom_nchwc16 = relay.layout_transform(
             denom_expand2, src_layout="NCHW", dst_layout="NCHW16c"
         )
         out = add * denom_nchwc16
         beta = relay.var("beta", shape=(16,))
-        numerator = (-mean) * denom + beta
+        numerator_c16c = (-mean) * denom_c16c + relay.layout_transform(
+            beta, src_layout="C", dst_layout="C16c"
+        )
+        numerator = relay.layout_transform(numerator_c16c, src_layout="C16c", dst_layout="C")
         numerator_expand1 = relay.expand_dims(numerator, axis=1, num_newaxis=2)
         numerator_expand2 = relay.expand_dims(numerator_expand1, axis=0)
         numerator_nchwc16 = relay.layout_transform(
@@ -1096,8 +1099,8 @@ def expected_nchw():
         y = relay.nn.conv2d(
             y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
         )
-        ret = relay.layout_transform(y, "NCHW16c", "NCHW")
-        ret = relay.sum(ret, axis=[1], keepdims=True)
+        ret = relay.sum(y, axis=[1, 4], keepdims=True)
+        ret = relay.layout_transform(ret, "NCHW1c", "NCHW")
         y = relay.Function(analysis.free_vars(ret), ret)
         return y

@@ -1126,9 +1129,8 @@ def expected_nhwc():
         y = relay.nn.conv2d(
             y, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1), data_layout="NCHW16c"
         )
-        ret = relay.layout_transform(y, "NCHW16c", "NCHW")
-        ret = relay.sum(ret, axis=[1], keepdims=True)
-        ret = relay.layout_transform(ret, "NCHW", "NHWC")
+        ret = relay.sum(y, axis=[1, 4], keepdims=True)
+        ret = relay.layout_transform(ret, "NCHW1c", "NHWC")
         y = relay.Function(analysis.free_vars(ret), ret)
         return y

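For reference, a minimal way to exercise this path end-to-end, following the conventions of the test file above (the alter hook and the shapes here are illustrative, not part of the patch):

import tvm
from tvm import relay
from tvm.relay import analysis, transform
from tvm.relay.testing.temp_op_attr import TempOpAttr

# Rewrite conv2d to a packed NCHW16c data layout, as the tests above do.
def alter_conv2d(attrs, inputs, tinfos, out_type):
    data, weight = inputs
    new_attrs = dict(attrs)
    new_attrs["data_layout"] = "NCHW16c"
    return relay.nn.conv2d(data, weight, **new_attrs)

x = relay.var("x", shape=(1, 32, 9, 9))
weight = relay.var("weight", shape=(32, 32, 3, 3))
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
y = relay.sum(y, axis=[1], keepdims=True)
func = relay.Function(analysis.free_vars(y), y)

with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
    mod = tvm.IRModule.from_expr(func)
    seq = tvm.transform.Sequential([transform.InferType(), transform.AlterOpLayout()])
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)

# With this patch the sum stays in the packed layout (axis=[1, 4]) and its result is
# transformed back via "NCHW1c" -> "NCHW", matching expected_nchw above.
print(mod["main"])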
