Skip to content

Commit

Permalink
add converter for MXNet slice in Relay
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Feb 24, 2019
1 parent 411c973 commit 1939bec
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
17 changes: 17 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,22 @@ def _mx_batch_norm(inputs, attrs):
return _op.nn.batch_norm(*inputs, **new_attrs)


def _mx_slice(inputs, attrs):
new_attrs = {}
begin = attrs.get_int_tuple('begin', None)
end = attrs.get_int_tuple('end', None)
stride = attrs.get_int_tuple('step', None)
print(begin, end, stride)
if begin is None or end is None:
raise RuntimeError("begin and end are required parameters.")
if None in begin or None in end:
raise RuntimeError("None in begin or end is not supported yet.")
new_attrs = {'begin': begin, 'end': end}
if stride is not None:
new_attrs['stride'] = stride
return _op.strided_slice(inputs[0], **new_attrs)


def _mx_split(inputs, attrs):
axis = attrs.get_int("axis", 1)
new_attrs = {}
Expand Down Expand Up @@ -368,6 +384,7 @@ def _mx_roi_align(inputs, attrs):
"BatchNorm" : _mx_batch_norm,
"BatchNorm_v1" : _mx_batch_norm,
"LRN" : _mx_lrn,
"slice" : _mx_slice,
"SliceChannel" : _mx_split,
"split" : _mx_split,
"expand_dims" : _mx_expand_dims,
Expand Down
7 changes: 7 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,13 @@ def test_forward_argmin():
mx_sym = mx.sym.argmin(data, axis=0)
verify_mxnet_frontend_impl(mx_sym, (5, 4), (4,))

def test_forward_slice():
data = mx.sym.var('data')
mx_sym = mx.sym.slice(data, begin=(0, 1), end=(2, 4))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 3))
mx_sym = mx.sym.slice(data, begin=(-1, 1), end=(-3, 4), step=(-1, 2))
verify_mxnet_frontend_impl(mx_sym, (3, 4), (2, 2))

def test_forward_where():
cond = mx.sym.var('cond')
x = mx.sym.var('x')
Expand Down

0 comments on commit 1939bec

Please sign in to comment.