From b99fdc109fba9af4fcb5604f628cb3bdf91253d4 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Thu, 28 Feb 2019 15:27:04 -0800 Subject: [PATCH 1/2] Add slice axis op in mxnet converter --- python/tvm/relay/frontend/mxnet.py | 29 +++++++++++++++++++++ tests/python/frontend/mxnet/test_forward.py | 18 +++++++++++++ 2 files changed, 47 insertions(+) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 9ef5f626393a..a3cfe4f09426 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -187,6 +187,34 @@ def _mx_slice(inputs, attrs): return _op.strided_slice(inputs[0], **new_attrs) +def _mx_slice_axis(inputs, attrs): + assert len(inputs) == 1 + shape = ir_pass.infer_type(inputs[0]).checked_type.shape + axis = attrs.get_int("axis") + ax_beg = attrs.get_int("begin") + ax_end = attrs.get_str("end") + if ax_end == "None": + ax_end = int(shape[axis]) + else: + ax_end = int(ax_end) + if ax_beg < 0: + ax_beg += int(shape[axis]) + if ax_end < 0: + ax_end += int(shape[axis]) + assert ax_beg >= 0 and ax_beg < int(shape[axis]) + assert ax_end > ax_beg and ax_end <= int(shape[axis]) + begin = [] + end = [] + for i in range(len(shape)): + if i != axis: + begin.append(0) + end.append(shape[i]) + else: + begin.append(ax_beg) + end.append(ax_end) + return _op.strided_slice(inputs[0], begin, end) + + def _mx_split(inputs, attrs): axis = attrs.get_int("axis", 1) new_attrs = {} @@ -384,6 +412,7 @@ def _mx_roi_align(inputs, attrs): "BatchNorm_v1" : _mx_batch_norm, "LRN" : _mx_lrn, "slice" : _mx_slice, + "slice_axis" : _mx_slice_axis, "SliceChannel" : _mx_split, "split" : _mx_split, "expand_dims" : _mx_expand_dims, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 671316079308..c3c0ffd450db 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -257,6 +257,23 @@ def verify(start, stop, step): verify(20, 1, -1.5) +def test_forward_slice_axis(): + def verify(shape, axis, begin, end): + data_np = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.slice_axis(mx.nd.array(data_np), axis, begin, end) + mx_sym = mx.sym.slice_axis(mx.sym.var("data"), axis, begin, end) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(data_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((3, 4), 0, 1, 2) + verify((3, 4), 0, 1, None) + verify((3, 4), 1, 0, 2) + verify((3, 4), 1, -3, -1) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -280,3 +297,4 @@ def verify(start, stop, step): test_forward_argmin() test_forward_where() test_forward_arange() + test_forward_slice_axis() From c0c6bde60b519bb43aad3f9a9effdd4516e66a9e Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Thu, 28 Feb 2019 16:00:47 -0800 Subject: [PATCH 2/2] Fix lint --- python/tvm/relay/frontend/mxnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index a3cfe4f09426..6c1d6a62dbce 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -205,10 +205,10 @@ def _mx_slice_axis(inputs, attrs): assert ax_end > ax_beg and ax_end <= int(shape[axis]) begin = [] end = [] - for i in range(len(shape)): + for i, dim in enumerate(shape): if i != axis: begin.append(0) - end.append(shape[i]) + end.append(dim) else: begin.append(ax_beg) end.append(ax_end)