From 5d70b008668ba244dad598d1ce3e8a79e2d6cede Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Tue, 9 Apr 2019 13:20:56 +0800
Subject: [PATCH] [Relay] InferCorrectLayout for strided_slice &
 min_num_branches option in CombineParallelConv2D (#2961)

* [Relay] InferCorrectLayout for strided_slice

* Add min_num_branches option to CombineParallelConv2D

* Return undef if original layout contains split axes
---
 python/tvm/relay/ir_pass.py                    |  9 ++-
 src/relay/op/tensor/transform.cc               | 61 ++++++++++++++++++-
 src/relay/pass/combine_parallel_conv2d.cc      | 15 ++++-
 .../python/relay/test_pass_alter_op_layout.py  | 43 +++++++++++++
 .../test_pass_combine_parallel_conv2d.py       |  8 +--
 5 files changed, 125 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py
index 8eb0adc3da1a..b3d323b2aed6 100644
--- a/python/tvm/relay/ir_pass.py
+++ b/python/tvm/relay/ir_pass.py
@@ -722,20 +722,23 @@ def fuse_ops(expr, opt_level=1):
     return _ir_pass.FuseOps(expr, opt_level)
 
 
-def combine_parallel_conv2d(expr):
-    """Fold multiple conv2d into one.
+def combine_parallel_conv2d(expr, min_num_branches=3):
+    """Combine multiple conv2d into one.
 
     Parameters
     ----------
     expr : tvm.relay.Expr
         The input expression.
 
+    min_num_branches : int
+        The minimum number of parallel branches required for the transformation to be applied.
+
     Returns
     -------
     transformed_expr : tvm.relay.Expr
         Transformed expression
     """
-    return _ir_pass.CombineParallelConv2D(expr)
+    return _ir_pass.CombineParallelConv2D(expr, min_num_branches)
 
 
 def alter_op_layout(expr):
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 15eaceb41a2d..f86156bdbddc 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -1722,6 +1722,64 @@ bool StridedSliceRel(const Array<Type>& types,
 }
 
 
+Array<Array<Layout> > StridedSliceInferCorrectLayout(
+    const Attrs& attrs,
+    const Array<Layout>& new_in_layouts,
+    const Array<Layout>& old_in_layouts,
+    const Array<Array<IndexExpr>>& old_in_shapes) {
+  CHECK(old_in_layouts.defined());
+  CHECK_EQ(old_in_layouts.size(), 1);
+  CHECK(old_in_shapes.defined());
+  CHECK_EQ(old_in_shapes.size(), 1);
+
+  auto layout = old_in_layouts[0];
+  if (layout.defined() && new_in_layouts.defined()) {
+    CHECK_EQ(new_in_layouts.size(), 1);
+    auto new_layout = new_in_layouts[0];
+    auto shape = old_in_shapes[0];
+
+    // NOTE: Discard "const" qualifier here.
+    auto *params = const_cast<StridedSliceAttrs*>(attrs.as<StridedSliceAttrs>());
+
+    Array<Integer> new_begin, new_end;
+
+    for (size_t i = 0; i < params->begin.size(); i++) {
+      const LayoutAxis& axis = layout[i];
+      if (!axis.IsPrimal()) {
+        // original layout that contains split axes is not supported
+        return {{Layout::Undef()}, {Layout::Undef()}};
+      }
+      auto factor = new_layout.FactorOf(axis);
+      if (factor == -1) {
+        new_begin.push_back(params->begin[i]);
+        new_end.push_back(params->end[i]);
+      } else {
+        if (params->strides.defined() && i < params->strides.size()) {
+          auto stride = params->strides[i];
+          // arbitrary stride is not supported
+          if (stride.defined() && stride->value != 1) {
+            return {{Layout::Undef()}, {Layout::Undef()}};
+          }
+        }
+        int64_t begin = params->begin[i].defined() ? params->begin[i]->value : 0;
+        int64_t end = params->end[i].defined() ? params->end[i]->value :
+                      shape[i].as<IntImm>()->value;
+        if (begin % factor || end % factor) {
+          // transform to original layout
+          return {{Layout::Undef()}, {Layout::Undef()}};
+        }
+        new_begin.push_back(tvm::Integer(begin / factor));
+        new_end.push_back(tvm::Integer(end / factor));
+      }
+    }
+    layout = new_layout;
+    params->begin = new_begin;
+    params->end = new_end;
+  }
+  return {{layout}, {layout}};
+}
+
+
 // Positional relay function to create StridedSlice operator used by frontend FFI.
 Expr MakeStridedSlice(Expr data,
                       Array<Integer> begin,
@@ -1783,7 +1841,8 @@ Examples::
 .set_attrs_type_key("relay.attrs.StridedSliceAttrs")
 .add_type_rel("StridedSlice", StridedSliceRel)
 .set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
-.set_attr<TOpPattern>("TOpPattern", kInjective);
+.set_attr<TOpPattern>("TOpPattern", kInjective)
+.set_attr<FInferCorrectLayout>("FInferCorrectLayout", StridedSliceInferCorrectLayout);
 
 
 // relay.split
diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc
index cb53698762ad..cd7a852bcad7 100644
--- a/src/relay/pass/combine_parallel_conv2d.cc
+++ b/src/relay/pass/combine_parallel_conv2d.cc
@@ -159,10 +159,15 @@ class BranchGroupFinder : private ExprVisitor {
 
 class ParallelConv2DCombiner {
  public:
+  explicit ParallelConv2DCombiner(uint64_t min_num_branches) : min_num_branches_(min_num_branches) {
+  }
+
   Expr Combine(const Expr& expr) {
     auto groups = BranchGroupFinder().Find(expr);
     for (const Group& group : groups) {
-      if (group.size() < 2) continue;
+      if (group.size() < min_num_branches_) {
+        continue;
+      }
       CombineBranches(group);
     }
     return ExprSubst(expr, std::move(subst_map_));
@@ -170,6 +175,7 @@ class ParallelConv2DCombiner {
 
  private:
   std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;
+  uint64_t min_num_branches_;
 
   std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
     int64_t num_filters = 0;  // number of filters of the transformed weight
@@ -343,11 +349,14 @@ class ParallelConv2DCombiner {
   }
 };
 
-Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); }
+/*! \brief Combine parallel conv2d if number of branches >= min_num_branches */
+Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
+  return ParallelConv2DCombiner(min_num_branches).Combine(expr);
+}
 
 TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = CombineParallelConv2D(args[0]);
+    *ret = CombineParallelConv2D(args[0], args[1]);
   });
 
 }  // namespace relay
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 0f21288245d9..f7a1c83ddff1 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -472,6 +472,48 @@ def expected():
     assert(alpha_equal(a, b))
 
 
+def test_alter_layout_strided_slice():
+    """Test rewriting strided_slice during alter_op_layout"""
+    def before():
+        x = relay.var("x", shape=(1, 32, 28, 28))
+        weight = relay.var('weight', shape=(32, 32, 3, 3))
+        y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
+        y = relay.strided_slice(y, begin=[0, 16], end=[None, None])
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    @register_alter_op_layout("nn.conv2d", level=109)
+    def alter_conv2d(attrs, inputs, tinfos):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs['data_layout'] = 'NCHW4c'
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 32, 28, 28))
+        weight = relay.var("weight")
+        x = relay.layout_transform(x, "NCHW", "NCHW4c")
+        y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
+                            data_layout="NCHW4c")
+        y = relay.strided_slice(y, begin=[0, 4], end=[None, 8])
+        y = relay.layout_transform(y, "NCHW4c", "NCHW")
+        y = relay.Function(free_vars(y), y)
+        return y
+
+    a = before()
+    a = infer_type(a)
+    a = canonicalize_ops(a)
+    a = infer_type(a)
+
+    a = alter_op_layout(a)
+    a = infer_type(a)
+
+    b = expected()
+    b = infer_type(b)
+
+    assert(alpha_equal(a, b))
+
+
 if __name__ == "__main__":
     test_alter_op()
     test_alter_return_none()
@@ -482,3 +524,4 @@ def expected():
     test_alter_layout_scalar()
     test_alter_layout_concatenate()
     test_alter_layout_nchw_upsamping_op()
+    test_alter_layout_strided_slice()
diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py
index 0d6e1e39b509..3bb656b2bda5 100644
--- a/tests/python/relay/test_pass_combine_parallel_conv2d.py
+++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -55,7 +55,7 @@ def check(x_shape, channels1, channels2, channels3, channels4):
         y_before = before(x, w1, w2, w3, w4)
         y = relay.ir_pass.infer_type(y_before)
-        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
         y = relay.ir_pass.infer_type(y)
         y_expected = expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4)
         y_expected = relay.ir_pass.infer_type(y_expected)
@@ -102,7 +102,7 @@ def check(x_shape, channels1, channels2):
         bias = relay.var("bias", shape=(channels2, 1, 1))
         y_before = before(x, w1, w2, scale1, scale2, bias)
         y = relay.ir_pass.infer_type(y_before)
-        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
         y = relay.ir_pass.infer_type(y)
         y_expected = expected(x, w1, w2, scale1, scale2, bias, channels1, channels2)
         y_expected = relay.ir_pass.infer_type(y_expected)
@@ -142,7 +142,7 @@ def check(x_shape, channels1, channels2):
         scale2 = relay.var("scale2", shape=(1,))
         y_before = before(x, w1, w2, scale1, scale2)
         y = relay.ir_pass.infer_type(y_before)
-        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
         y = relay.ir_pass.infer_type(y)
         y_expected = expected(x, w1, w2, scale1, scale2, channels1, channels2)
         y_expected = relay.ir_pass.infer_type(y_expected)
@@ -179,7 +179,7 @@ def check(x_shape, repeat):
         w = relay.var("w", shape=(out_c, in_c, 1, 1))
         y_before = before(x, w, repeat)
         y = relay.ir_pass.infer_type(y_before)
-        y = relay.ir_pass.combine_parallel_conv2d(y)
+        y = relay.ir_pass.combine_parallel_conv2d(y, min_num_branches=2)
         y = relay.ir_pass.infer_type(y)
         y_expected = expected(x, w, out_c, repeat)
         y_expected = relay.ir_pass.infer_type(y_expected)
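
Note: the rule implemented by StridedSliceInferCorrectLayout above can be summarized with a
small Python sketch (illustrative only, not part of the patch; the helper name rescale_slice
is hypothetical). When a sliced axis is split by a factor, e.g. the channel axis of NCHW
becoming NCHW4c, the begin/end indices are kept in the new layout only if the stride is 1 and
both bounds divide evenly by the split factor; otherwise the op falls back to the original
layout (the Layout::Undef return in the pass).

    def rescale_slice(begin, end, stride, factor):
        """Rescale slice bounds onto the outer part of a split axis.

        Returns None when the slice cannot be expressed in the new layout,
        mirroring the Layout::Undef fallback in the C++ pass.
        """
        if stride != 1:
            return None  # arbitrary strides are not supported
        if begin % factor or end % factor:
            return None  # bounds must align with the split factor
        return begin // factor, end // factor

    # Mirrors test_alter_layout_strided_slice: slicing channels [16, 32) of a
    # 32-channel NCHW tensor becomes [4, 8) on the outer channel axis of NCHW4c
    # (end=None defaults to the axis extent, 32, before rescaling).
    assert rescale_slice(16, 32, stride=1, factor=4) == (4, 8)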
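
Note: with the default min_num_branches=3, a function with only two parallel conv2d branches
is left untouched, which is why the updated tests pass min_num_branches=2 explicitly. A
minimal usage sketch against the relay.ir_pass API as it appears in this patch (shapes and
variable names are illustrative):

    from tvm import relay

    # Two parallel conv2d branches sharing the same input.
    x = relay.var("x", shape=(1, 4, 16, 16))
    w1 = relay.var("w1", shape=(8, 4, 1, 1))
    w2 = relay.var("w2", shape=(8, 4, 1, 1))
    y1 = relay.nn.conv2d(x, w1, channels=8, kernel_size=(1, 1))
    y2 = relay.nn.conv2d(x, w2, channels=8, kernel_size=(1, 1))
    y = relay.concatenate((y1, y2), axis=1)
    f = relay.Function(relay.ir_pass.free_vars(y), y)

    f = relay.ir_pass.infer_type(f)
    # The default threshold (3) would keep the two branches separate; lowering it
    # to 2 lets the pass merge them into a single conv2d whose output is split back
    # with strided_slice.
    f = relay.ir_pass.combine_parallel_conv2d(f, min_num_branches=2)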