Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix NDArrayIter cant pad when size is large
Browse files Browse the repository at this point in the history
  • Loading branch information
stu1130 committed Dec 7, 2019
1 parent 9c94fdb commit a081fe6
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 24 deletions.
44 changes: 23 additions & 21 deletions python/mxnet/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from ..ndarray.sparse import CSRNDArray
from ..ndarray import _ndarray_cls
from ..ndarray import array
from ..ndarray import concat
from ..ndarray import concat, tile

from .utils import _init_data, _has_instance, _getdata_by_idx

Expand Down Expand Up @@ -709,23 +709,27 @@ def _getdata(self, data_source, start=None, end=None):

def _concat(self, first_data, second_data):
"""Helper function to concat two NDArrays."""
if (not first_data) or (not second_data):
return first_data if first_data else second_data
assert len(first_data) == len(
second_data), 'data source should contain the same size'
if first_data and second_data:
return [
concat(
first_data[x],
second_data[x],
dim=0
) for x in range(len(first_data))
]
elif (not first_data) and (not second_data):
return [
concat(
first_data[i],
second_data[i],
dim=0
) for i in range(len(first_data))
]

def _tile(self, data, repeats):
if not data:
return []
else:
return [
first_data[0] if first_data else second_data[0]
for x in range(len(first_data))
]
res = []
for datum in data:
reps = [1] * len(datum.shape)
reps[0] = repeats
res.append(tile(datum, reps))
return res

def _batchify(self, data_source):
"""Load data from underlying arrays, internal use only."""
Expand All @@ -749,12 +753,10 @@ def _batchify(self, data_source):
pad = self.batch_size - self.num_data + self.cursor
first_data = self._getdata(data_source, start=self.cursor)
if pad > self.num_data:
while True:
if pad <= self.num_data:
break
second_data = self._getdata(data_source, end=self.num_data)
pad -= self.num_data
second_data = self._concat(second_data, self._getdata(data_source, end=pad))
repeats = pad // self.num_data
second_data = self._tile(self._getdata(data_source, end=self.num_data), repeats)
if pad % self.num_data != 0:
second_data = self._concat(second_data, self._getdata(data_source, end=pad % self.num_data))
else:
second_data = self._getdata(data_source, end=pad)
return self._concat(first_data, second_data)
Expand Down
6 changes: 3 additions & 3 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,11 @@ def _test_shuffle(data, labels=None):
assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
i += 1

# fixes the issue https://github.com/apache/incubator-mxnet/issues/15535

def _test_corner_case():
data = np.arange(10)
data_iter = mx.io.NDArrayIter(data=data, batch_size=25, shuffle=False, last_batch_handle='pad')
expect = np.concatenate((np.tile(data, 2), np.arange(5)))
data_iter = mx.io.NDArrayIter(data=data, batch_size=205, shuffle=False, last_batch_handle='pad')
expect = np.concatenate((np.tile(data, 20), np.arange(5)))
assert np.array_equal(data_iter.next().data[0].asnumpy(), expect)


Expand Down

0 comments on commit a081fe6

Please sign in to comment.