From a081fe61ec2c60eb2baf3aa170907cbe2cbd4323 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Fri, 6 Dec 2019 16:39:01 -0800 Subject: [PATCH] Fix NDArrayIter cant pad when size is large --- python/mxnet/io/io.py | 44 +++++++++++++++++--------------- tests/python/unittest/test_io.py | 6 ++--- 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/python/mxnet/io/io.py b/python/mxnet/io/io.py index dcf964df976a..e36665e61c42 100644 --- a/python/mxnet/io/io.py +++ b/python/mxnet/io/io.py @@ -36,7 +36,7 @@ from ..ndarray.sparse import CSRNDArray from ..ndarray import _ndarray_cls from ..ndarray import array -from ..ndarray import concat +from ..ndarray import concat, tile from .utils import _init_data, _has_instance, _getdata_by_idx @@ -709,23 +709,27 @@ def _getdata(self, data_source, start=None, end=None): def _concat(self, first_data, second_data): """Helper function to concat two NDArrays.""" + if (not first_data) or (not second_data): + return first_data if first_data else second_data assert len(first_data) == len( second_data), 'data source should contain the same size' - if first_data and second_data: - return [ - concat( - first_data[x], - second_data[x], - dim=0 - ) for x in range(len(first_data)) - ] - elif (not first_data) and (not second_data): + return [ + concat( + first_data[i], + second_data[i], + dim=0 + ) for i in range(len(first_data)) + ] + + def _tile(self, data, repeats): + if not data: return [] - else: - return [ - first_data[0] if first_data else second_data[0] - for x in range(len(first_data)) - ] + res = [] + for datum in data: + reps = [1] * len(datum.shape) + reps[0] = repeats + res.append(tile(datum, reps)) + return res def _batchify(self, data_source): """Load data from underlying arrays, internal use only.""" @@ -749,12 +753,10 @@ def _batchify(self, data_source): pad = self.batch_size - self.num_data + self.cursor first_data = self._getdata(data_source, start=self.cursor) if pad > self.num_data: - while True: - if pad <= self.num_data: - break - second_data = self._getdata(data_source, end=self.num_data) - pad -= self.num_data - second_data = self._concat(second_data, self._getdata(data_source, end=pad)) + repeats = pad // self.num_data + second_data = self._tile(self._getdata(data_source, end=self.num_data), repeats) + if pad % self.num_data != 0: + second_data = self._concat(second_data, self._getdata(data_source, end=pad % self.num_data)) else: second_data = self._getdata(data_source, end=pad) return self._concat(first_data, second_data) diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 2a806efc9034..a13addb0adca 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -198,11 +198,11 @@ def _test_shuffle(data, labels=None): assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]]) i += 1 -# fixes the issue https://github.com/apache/incubator-mxnet/issues/15535 + def _test_corner_case(): data = np.arange(10) - data_iter = mx.io.NDArrayIter(data=data, batch_size=25, shuffle=False, last_batch_handle='pad') - expect = np.concatenate((np.tile(data, 2), np.arange(5))) + data_iter = mx.io.NDArrayIter(data=data, batch_size=205, shuffle=False, last_batch_handle='pad') + expect = np.concatenate((np.tile(data, 20), np.arange(5))) assert np.array_equal(data_iter.next().data[0].asnumpy(), expect)