diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 1f1d18e240cd..4d341c76043a 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -194,6 +194,34 @@ def _mx_slice(inputs, attrs): return _op.strided_slice(inputs[0], **new_attrs) +def _mx_slice_axis(inputs, attrs): + assert len(inputs) == 1 + shape = ir_pass.infer_type(inputs[0]).checked_type.shape + axis = attrs.get_int("axis") + ax_beg = attrs.get_int("begin") + ax_end = attrs.get_str("end") + if ax_end == "None": + ax_end = int(shape[axis]) + else: + ax_end = int(ax_end) + if ax_beg < 0: + ax_beg += int(shape[axis]) + if ax_end < 0: + ax_end += int(shape[axis]) + assert ax_beg >= 0 and ax_beg < int(shape[axis]) + assert ax_end > ax_beg and ax_end <= int(shape[axis]) + begin = [] + end = [] + for i, dim in enumerate(shape): + if i != axis: + begin.append(0) + end.append(dim) + else: + begin.append(ax_beg) + end.append(ax_end) + return _op.strided_slice(inputs[0], begin, end) + + def _mx_split(inputs, attrs): axis = attrs.get_int("axis", 1) new_attrs = {} @@ -423,6 +451,7 @@ def _mx_roi_align(inputs, attrs): "BatchNorm_v1" : _mx_batch_norm, "LRN" : _mx_lrn, "slice" : _mx_slice, + "slice_axis" : _mx_slice_axis, "SliceChannel" : _mx_split, "split" : _mx_split, "expand_dims" : _mx_expand_dims, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index ee47d72046ed..7f53aa8a0155 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -337,6 +337,23 @@ def test_forward_scalar_ops(): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) +def test_forward_slice_axis(): + def verify(shape, axis, begin, end): + data_np = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.slice_axis(mx.nd.array(data_np), axis, begin, end) + mx_sym = mx.sym.slice_axis(mx.sym.var("data"), axis, begin, end) + new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape}) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(new_sym)(data_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + verify((3, 4), 0, 1, 2) + verify((3, 4), 0, 1, None) + verify((3, 4), 1, 0, 2) + verify((3, 4), 1, -3, -1) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -363,3 +380,4 @@ def test_forward_scalar_ops(): test_forward_broadcast_ops() test_forward_elemwise_ops() test_forward_scalar_ops() + test_forward_slice_axis() \ No newline at end of file