Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Alicia1529 authored and Ubuntu committed Feb 19, 2020
1 parent d5ad306 commit 8ac452c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3095,9 +3095,9 @@ def _get_index_range(start, stop, length, step=1):
elif start < 0:
start += length
if start < 0:
raise IndexError('Slicing start %d exceeds limit of %d' % (start-length, length))
start = 0
elif start >= length:
raise IndexError('Slicing start %d exceeds limit of %d' % (start, length))
start = length

if stop is None:
if step > 0:
Expand All @@ -3110,9 +3110,9 @@ def _get_index_range(start, stop, length, step=1):
elif stop < 0:
stop += length
if stop < 0:
raise IndexError('Slicing stop %d exceeds limit of %d' % (stop-length, length))
stop = 0
elif stop > length:
raise IndexError('Slicing stop %d exceeds limit of %d' % (stop, length))
stop = length

return start, stop, step

Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,21 @@ def test_getitem(np_array, index):
mx_indexed_array = mx_indexed_array.asnumpy()
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)

def test_getitem_slice_bound():
mx_array = np.arange(10)
np_array = mx_array.asnumpy()
assert_almost_equal(mx_array[100:], np_array[100:])
assert_almost_equal(mx_array[:100], np_array[:100])
assert_almost_equal(mx_array[-100:], np_array[-100:])
assert_almost_equal(mx_array[:-100], np_array[:-100])

mx_array = np.arange(81).reshape(3, 3, 3, 3)
np_array = mx_array.asnumpy()
assert_almost_equal(mx_array[100:], np_array[100:])
assert_almost_equal(mx_array[:100], np_array[:100])
assert_almost_equal(mx_array[-100:], np_array[-100:])
assert_almost_equal(mx_array[:-100], np_array[:-100])

def test_setitem(np_array, index):
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
if np_value is not None:
Expand Down Expand Up @@ -986,6 +1001,7 @@ def test_setitem_autograd(np_array, index):
test_setitem(np_array, index)
test_getitem_autograd(np_array, index)
test_setitem_autograd(np_array, index)
test_getitem_slice_bound()


@with_seed()
Expand Down

0 comments on commit 8ac452c

Please sign in to comment.