From 9263a56431b153284a35777c3ec6f693f1fb4273 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Fri, 10 Aug 2018 14:41:30 -0700 Subject: [PATCH 01/21] Add last_batch_handle on ImageIter --- python/mxnet/image/image.py | 108 +++++++++++++++++++++++++++++------- 1 file changed, 88 insertions(+), 20 deletions(-) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index c2a1906646fe..754ccb42e7fb 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1047,7 +1047,7 @@ class ImageIter(io.DataIter): path_imgidx : str Path to image index file. Needed for partition and shuffling when using .rec source. shuffle : bool - Whether to shuffle all images at the start of each iteration or not. + Whether to shuffle all images or not. Can be slow for HDD. part_index : int Partition index. @@ -1059,6 +1059,9 @@ class ImageIter(io.DataIter): Label name for provided symbols. dtype : str Label data type. Default: float32. Other options: int32, int64, float64 + last_batch_hanle : str, optional + How to handle the last batch. This parameter can be ‘pad’, ‘discard’ or ‘roll_over’. + 'discard' is not support when reading from record file(.rec) withouting shuffle(=False) kwargs : ... More arguments for creating augmenter. See mx.image.CreateAugmenter. """ @@ -1066,9 +1069,11 @@ class ImageIter(io.DataIter): def __init__(self, batch_size, data_shape, label_width=1, path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None, shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, - data_name='data', label_name='softmax_label', dtype='float32', **kwargs): + data_name='data', label_name='softmax_label', dtype='float32', last_batch_handle='pad', **kwargs): super(ImageIter, self).__init__() assert path_imgrec or path_imglist or (isinstance(imglist, list)) + # Reading from record file without shuffle didn't support 'discard' for last_batch_handle + assert not path_imgrec or shuffle or last_batch_handle != 'discard' assert dtype in ['int32', 'float32', 'int64', 'float64'], dtype + ' label not supported' num_threads = os.environ.get('MXNET_CPU_WORKER_NTHREADS', 1) logging.info('Using %s threads for decoding...', str(num_threads)) @@ -1084,6 +1089,7 @@ def __init__(self, batch_size, data_shape, label_width=1, else: self.imgrec = recordio.MXRecordIO(path_imgrec, 'r') # pylint: disable=redefined-variable-type self.imgidx = None + self.path_imgrec = path_imgrec else: self.imgrec = None @@ -1130,9 +1136,11 @@ def __init__(self, batch_size, data_shape, label_width=1, self.data_shape = data_shape self.label_width = label_width - self.shuffle = shuffle if self.imgrec is None: self.seq = imgkeys + # shuffle + if shuffle: + random.shuffle(self.seq) elif shuffle or num_parts > 1: assert self.imgidx is not None self.seq = self.imgidx @@ -1149,22 +1157,53 @@ def __init__(self, batch_size, data_shape, label_width=1, else: self.auglist = aug_list self.cur = 0 + self.is_iterated_over = False + self._imgrec = None + # handle the last batch + if self.seq and last_batch_handle == 'discard': + new_seq_n = len(self.seq) - len(self.seq) % batch_size + self.seq = self.seq[:new_seq_n] + + self.last_batch_handle = last_batch_handle + self.num_image = len(self.seq) if self.seq is not None else None self.reset() def reset(self): """Resets the iterator to the beginning of the data.""" - if self.shuffle: - random.shuffle(self.seq) + if self.last_batch_handle == 'roll_over' and \ + self.seq and \ + self.cur > self.num_image: + assert self.num_image is not None, 'imgrec without shuffle is not supported' + self.cur = (self.cur % self.num_image) % self.batch_size + elif self.last_batch_handle == 'roll_over' and \ + self.is_iterated_over: + + if self._imgrec is not None: + self.imgrec = self._imgrec + self.is_iterated_over = False + else: + self.cur = 0 + self.is_iterated_over = False + if self.imgrec is not None: + self.imgrec.reset() + + def hard_reset(self): + """Resets the iterator and ignore roll over data""" if self.imgrec is not None: self.imgrec.reset() self.cur = 0 + self.is_iterated_over = False - def next_sample(self): + def next_sample(self, imgrec=None): """Helper function for reading in next sample.""" if self.seq is not None: - if self.cur >= len(self.seq): + if self.cur < self.num_image: + idx = self.seq[self.cur] + elif self.num_image % self.batch_size != 0 and \ + self.cur < self.batch_size * ((self.num_image // self.batch_size) + 1): + idx = self.seq[self.cur % self.num_image] + else: raise StopIteration - idx = self.seq[self.cur] self.cur += 1 if self.imgrec is not None: s = self.imgrec.read_idx(idx) @@ -1177,22 +1216,20 @@ def next_sample(self): label, fname = self.imglist[idx] return label, self.read_image(fname) else: - s = self.imgrec.read() + imgrec = imgrec if imgrec is not None else self.imgrec + s = imgrec.read() if s is None: - raise StopIteration + self.is_iterated_over = True + raise StopIteration header, img = recordio.unpack(s) return header.label, img - def next(self): - """Returns the next batch of data.""" + def iterate(self, batch_data, batch_label, start=0, imgrec=None): + i = start batch_size = self.batch_size - c, h, w = self.data_shape - batch_data = nd.empty((batch_size, c, h, w)) - batch_label = nd.empty(self.provide_label[0][1]) - i = 0 try: while i < batch_size: - label, s = self.next_sample() + label, s = self.next_sample(imgrec) data = self.imdecode(s) try: self.check_valid_image(data) @@ -1207,8 +1244,31 @@ def next(self): except StopIteration: if not i: raise StopIteration + return i - return io.DataBatch([batch_data], [batch_label], batch_size - i) + def next(self): + """Returns the next batch of data.""" + batch_size = self.batch_size + c, h, w = self.data_shape + batch_data = nd.empty((batch_size, c, h, w)) + batch_label = nd.empty(self.provide_label[0][1]) + index = 0 + i = self.iterate(batch_data, batch_label) + + # handle padding for sequential read + if self.seq is None: + pad = batch_size - i + # pad the last batch by creating a new MXRecordIO + self._imgrec = recordio.MXRecordIO(self.path_imgrec, 'r') + if pad != 0: + _ = self.iterate(batch_data, batch_label, i, self._imgrec) + + if self.last_batch_handle != 'roll_over': + self._imgrec.close() + else: + pad = self.getpad() + + return io.DataBatch([batch_data], [batch_label], pad=pad) def check_data_shape(self, data_shape): """Checks if the input data shape is valid""" @@ -1228,9 +1288,9 @@ def imdecode(self, s): def locate(): """Locate the image file/index if decode fails.""" if self.seq is not None: - idx = self.seq[self.cur - 1] + idx = self.seq[(self.cur % self.num_image) - 1] else: - idx = self.cur - 1 + idx = (self.cur % self.num_image) - 1 if self.imglist is not None: _, fname = self.imglist[idx] msg = "filename: {}".format(fname) @@ -1263,3 +1323,11 @@ def augmentation_transform(self, data): def postprocess_data(self, datum): """Final postprocessing step before image is loaded into the batch.""" return nd.transpose(datum, axes=(2, 0, 1)) + + def getpad(self): + if self.last_batch_handle == 'pad' and \ + hasattr(self, 'num_image') and \ + self.cur >= self.num_image: + return self.cur - self.num_image + else: + return 0 From c3a7dc5e4d5842e3117465048d9e342a141af880 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Fri, 10 Aug 2018 14:43:44 -0700 Subject: [PATCH 02/21] Test the last_batch_handle for both of imglist and path_imglist --- tests/python/unittest/test_image.py | 63 +++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 17 deletions(-) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 9eec1835c1f9..44f5fd16cbea 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -130,29 +130,59 @@ def test_color_normalize(self): mx.nd.array(mean), mx.nd.array(std)) assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3) - def test_imageiter(self): def check_imageiter(dtype='float32'): im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES] - test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list, - path_root='', dtype=dtype) - for _ in range(3): - for batch in test_iter: - pass - test_iter.reset() - - # test with list file fname = './data/test_imageiter.lst' - file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \ - for k, x in enumerate(TestImage.IMAGES)] + file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) + for k, x in enumerate(TestImage.IMAGES)] with open(fname, 'w') as f: for line in file_list: f.write(line + '\n') + + test_list = ['imglist', 'path_imglist'] + + for test in test_list: + imglist = im_list if test == 'imglist' else None + path_imglist = fname if test == 'path_imglist' else None + + test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist, + path_imglist=path_imglist, path_root='', dtype=dtype) + for _ in range(3): + for batch in test_iter: + pass + test_iter.reset() + # test last batch handle(discard) + test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard') + i = 0 + for batch in test_iter: + i += 1 + assert i == 5 - test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname, - path_root='', dtype=dtype) - for batch in test_iter: - pass + # test last_batch_handle(pad) + test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad') + i = 0 + for batch in test_iter: + if i == 0: + first_three_data = batch.data[0][:2] + if i == 5: + last_three_data = batch.data[0][1:] + i += 1 + assert i == 6 + assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy()) + # test last_batch_handle(roll_over) + test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over') + i = 0 + for batch in test_iter: + if i == 0: + second_image = batch.data[0][2] + i += 1 + test_iter.reset() + assert np.array_equal( + test_iter.next().data[0][0].asnumpy(), second_image.asnumpy()) for dtype in ['int32', 'float32', 'int64', 'float64']: check_imageiter(dtype) @@ -183,7 +213,6 @@ def test_augmenters(self): for batch in test_iter: pass - def test_image_detiter(self): im_list = [_generate_objects() + [x] for x in TestImage.IMAGES] det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='') @@ -196,7 +225,7 @@ def test_image_detiter(self): det_iter = val_iter.sync_label_shape(det_iter) # test file list - fname = './data/test_imagedetiter.lst' + fname = './data/test_imageiter.lst' im_list = [[k] + _generate_objects() + [x] for k, x in enumerate(TestImage.IMAGES)] with open(fname, 'w') as f: for line in im_list: From 79f300224a90ab00cd77bd9239771a3e2d3f4a17 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Fri, 10 Aug 2018 15:46:49 -0700 Subject: [PATCH 03/21] fix the filename typo --- tests/python/unittest/test_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 44f5fd16cbea..6eb845043f04 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -225,7 +225,7 @@ def test_image_detiter(self): det_iter = val_iter.sync_label_shape(det_iter) # test file list - fname = './data/test_imageiter.lst' + fname = './data/test_imagedetiter.lst' im_list = [[k] + _generate_objects() + [x] for k, x in enumerate(TestImage.IMAGES)] with open(fname, 'w') as f: for line in im_list: From 0eb45732ca0b7abe7f4b2a0a411f83d89e493ded Mon Sep 17 00:00:00 2001 From: stu1130 Date: Mon, 13 Aug 2018 17:27:04 -0700 Subject: [PATCH 04/21] change the name of parameter from 'last_batch_handle' to 'last_batch' --- tests/python/unittest/test_image.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 6eb845043f04..0f3efce7fad9 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -154,15 +154,15 @@ def check_imageiter(dtype='float32'): test_iter.reset() # test last batch handle(discard) test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, - path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard') + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch='discard') i = 0 for batch in test_iter: i += 1 - assert i == 5 + assert i == 5 - # test last_batch_handle(pad) + # test last_batch(pad) test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, - path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad') + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch='pad') i = 0 for batch in test_iter: if i == 0: @@ -172,9 +172,9 @@ def check_imageiter(dtype='float32'): i += 1 assert i == 6 assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy()) - # test last_batch_handle(roll_over) + # test last_batch(roll_over) test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, - path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over') + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch='roll_over') i = 0 for batch in test_iter: if i == 0: From c214f5bb4b71e630a563c274e8170a5b04ac755a Mon Sep 17 00:00:00 2001 From: stu1130 Date: Mon, 13 Aug 2018 17:29:42 -0700 Subject: [PATCH 05/21] 1. change the name of the parameter from 'last_batch_handle' to 'last_batch' 2. remove the extra RecordIO and refactor the code accordingly 3. support 'discard' for sequentially reading --- python/mxnet/image/image.py | 96 ++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 50 deletions(-) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index 754ccb42e7fb..f439b7d781d8 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1059,9 +1059,12 @@ class ImageIter(io.DataIter): Label name for provided symbols. dtype : str Label data type. Default: float32. Other options: int32, int64, float64 - last_batch_hanle : str, optional - How to handle the last batch. This parameter can be ‘pad’, ‘discard’ or ‘roll_over’. - 'discard' is not support when reading from record file(.rec) withouting shuffle(=False) + last_batch : str, optional + How to handle the last batch. + This parameter can be ‘pad’(default), ‘discard’ or ‘roll_over’. + If 'pad', the last batch will be padded with data starting from the begining + If 'discard', the last batch will be discarded + If 'roll_over', the remaining elements will be rolled over to the next iteration kwargs : ... More arguments for creating augmenter. See mx.image.CreateAugmenter. """ @@ -1069,11 +1072,10 @@ class ImageIter(io.DataIter): def __init__(self, batch_size, data_shape, label_width=1, path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None, shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, - data_name='data', label_name='softmax_label', dtype='float32', last_batch_handle='pad', **kwargs): + data_name='data', label_name='softmax_label', dtype='float32', + last_batch='pad', **kwargs): super(ImageIter, self).__init__() assert path_imgrec or path_imglist or (isinstance(imglist, list)) - # Reading from record file without shuffle didn't support 'discard' for last_batch_handle - assert not path_imgrec or shuffle or last_batch_handle != 'discard' assert dtype in ['int32', 'float32', 'int64', 'float64'], dtype + ' label not supported' num_threads = os.environ.get('MXNET_CPU_WORKER_NTHREADS', 1) logging.info('Using %s threads for decoding...', str(num_threads)) @@ -1089,7 +1091,6 @@ def __init__(self, batch_size, data_shape, label_width=1, else: self.imgrec = recordio.MXRecordIO(path_imgrec, 'r') # pylint: disable=redefined-variable-type self.imgidx = None - self.path_imgrec = path_imgrec else: self.imgrec = None @@ -1138,7 +1139,7 @@ def __init__(self, batch_size, data_shape, label_width=1, if self.imgrec is None: self.seq = imgkeys - # shuffle + # shuffle if shuffle: random.shuffle(self.seq) elif shuffle or num_parts > 1: @@ -1157,44 +1158,39 @@ def __init__(self, batch_size, data_shape, label_width=1, else: self.auglist = aug_list self.cur = 0 - self.is_iterated_over = False - self._imgrec = None + self._is_allowed_reading = True + self._cached_data = None # handle the last batch - if self.seq and last_batch_handle == 'discard': + if self.seq and last_batch == 'discard': new_seq_n = len(self.seq) - len(self.seq) % batch_size self.seq = self.seq[:new_seq_n] - self.last_batch_handle = last_batch_handle + self.last_batch = last_batch self.num_image = len(self.seq) if self.seq is not None else None self.reset() def reset(self): """Resets the iterator to the beginning of the data.""" - if self.last_batch_handle == 'roll_over' and \ - self.seq and \ + if self.seq is not None: + if self.last_batch == 'roll_over' and \ self.cur > self.num_image: - assert self.num_image is not None, 'imgrec without shuffle is not supported' - self.cur = (self.cur % self.num_image) % self.batch_size - elif self.last_batch_handle == 'roll_over' and \ - self.is_iterated_over: - - if self._imgrec is not None: - self.imgrec = self._imgrec - self.is_iterated_over = False + self.cur = (self.cur % self.num_image) % self.batch_size + else: + self.cur = 0 else: - self.cur = 0 - self.is_iterated_over = False - if self.imgrec is not None: + if self.last_batch != 'roll_over' or \ + self._is_allowed_reading is True: self.imgrec.reset() + self._is_allowed_reading = True def hard_reset(self): """Resets the iterator and ignore roll over data""" if self.imgrec is not None: self.imgrec.reset() self.cur = 0 - self.is_iterated_over = False + self._is_allowed_reading = True - def next_sample(self, imgrec=None): + def next_sample(self): """Helper function for reading in next sample.""" if self.seq is not None: if self.cur < self.num_image: @@ -1216,20 +1212,23 @@ def next_sample(self, imgrec=None): label, fname = self.imglist[idx] return label, self.read_image(fname) else: - imgrec = imgrec if imgrec is not None else self.imgrec - s = imgrec.read() + if self._is_allowed_reading is False: + raise StopIteration + s = self.imgrec.read() if s is None: - self.is_iterated_over = True - raise StopIteration + if self.last_batch != 'discard': + self.imgrec.reset() + raise StopIteration header, img = recordio.unpack(s) return header.label, img - def iterate(self, batch_data, batch_label, start=0, imgrec=None): + def iterate(self, batch_data, batch_label, start=0): + """Helper function for iterate a batch of data""" i = start batch_size = self.batch_size try: while i < batch_size: - label, s = self.next_sample(imgrec) + label, s = self.next_sample() data = self.imdecode(s) try: self.check_valid_image(data) @@ -1252,22 +1251,19 @@ def next(self): c, h, w = self.data_shape batch_data = nd.empty((batch_size, c, h, w)) batch_label = nd.empty(self.provide_label[0][1]) - index = 0 i = self.iterate(batch_data, batch_label) - - # handle padding for sequential read + # calculate the padding if self.seq is None: pad = batch_size - i - # pad the last batch by creating a new MXRecordIO - self._imgrec = recordio.MXRecordIO(self.path_imgrec, 'r') - if pad != 0: - _ = self.iterate(batch_data, batch_label, i, self._imgrec) - - if self.last_batch_handle != 'roll_over': - self._imgrec.close() else: - pad = self.getpad() - + pad = self._getpad() + # handle padding for sequential read + if self.seq is None and pad != 0: + if self.last_batch == 'discard': + raise StopIteration + else: + _ = self.iterate(batch_data, batch_label, i) + self._is_allowed_reading = False return io.DataBatch([batch_data], [batch_label], pad=pad) def check_data_shape(self, data_shape): @@ -1305,7 +1301,6 @@ def locate(): def read_image(self, fname): """Reads an input image `fname` and returns the decoded raw bytes. - Example usage: ---------- >>> dataIter.read_image('Face.jpg') # returns decoded raw bytes. @@ -1324,10 +1319,11 @@ def postprocess_data(self, datum): """Final postprocessing step before image is loaded into the batch.""" return nd.transpose(datum, axes=(2, 0, 1)) - def getpad(self): - if self.last_batch_handle == 'pad' and \ - hasattr(self, 'num_image') and \ - self.cur >= self.num_image: + def _getpad(self): + """Helpe function for getting padding number""" + if self.last_batch in ['pad', 'roll_over'] and \ + self.num_image is not None and \ + self.cur >= self.num_image: return self.cur - self.num_image else: return 0 From a66452cadea4834c429fa722b67ec834ca3dd654 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Tue, 14 Aug 2018 09:34:17 -0700 Subject: [PATCH 06/21] fix character '\xe2' encoding issue --- python/mxnet/image/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index f439b7d781d8..b1c5ec6f70cf 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1061,7 +1061,7 @@ class ImageIter(io.DataIter): Label data type. Default: float32. Other options: int32, int64, float64 last_batch : str, optional How to handle the last batch. - This parameter can be ‘pad’(default), ‘discard’ or ‘roll_over’. + This parameter can be 'pad'(default), 'discard' or 'roll_over'. If 'pad', the last batch will be padded with data starting from the begining If 'discard', the last batch will be discarded If 'roll_over', the remaining elements will be rolled over to the next iteration From 0f47661d2e60f0158c2cdf37601b347ed3182589 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Tue, 14 Aug 2018 13:28:22 -0700 Subject: [PATCH 07/21] fix roll_over bug happened when calling reset() several times --- tests/python/unittest/test_image.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 0f3efce7fad9..32b2c5182528 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -80,7 +80,6 @@ def test_imread_vs_imdecode(self): image_read = mx.img.image.imread(img) same(image.asnumpy(), image_read.asnumpy()) - def test_imdecode(self): try: import cv2 @@ -128,7 +127,16 @@ def test_color_normalize(self): src = np.random.rand(height, width, 3) * 255. mx_result = mx.image.color_normalize(mx.nd.array(src), mx.nd.array(mean), mx.nd.array(std)) - assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3) + assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3) + + def test_test(self): + data_iter = mx.image.ImageIter(batch_size=4, data_shape=( + 3, 227, 227), path_imgrec='/Users/leecheng/data/caltech.rec', shuffle=True, path_imgidx='/Users/leecheng/data/caltech.idx', last_batch='discard') + for batch in data_iter: + pass + data_iter.reset() + for batch in data_iter: + pass def test_imageiter(self): def check_imageiter(dtype='float32'): @@ -182,7 +190,7 @@ def check_imageiter(dtype='float32'): i += 1 test_iter.reset() assert np.array_equal( - test_iter.next().data[0][0].asnumpy(), second_image.asnumpy()) + test_iter.next().data[0][0].asnumpy(), second_image.asnumpy()), 'failed in {}'.format(test) for dtype in ['int32', 'float32', 'int64', 'float64']: check_imageiter(dtype) From 10aed35d52448fb8bb7c7b77e5b54d6ad845082b Mon Sep 17 00:00:00 2001 From: stu1130 Date: Tue, 14 Aug 2018 13:32:10 -0700 Subject: [PATCH 08/21] unify the logic of how to deal with 'discard', 'pad', ''roll_over' for both sequential read and random access --- python/mxnet/image/image.py | 50 +++++++++++++------------------------ 1 file changed, 18 insertions(+), 32 deletions(-) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index b1c5ec6f70cf..d38686ff66b2 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1171,16 +1171,11 @@ def __init__(self, batch_size, data_shape, label_width=1, def reset(self): """Resets the iterator to the beginning of the data.""" - if self.seq is not None: - if self.last_batch == 'roll_over' and \ - self.cur > self.num_image: - self.cur = (self.cur % self.num_image) % self.batch_size - else: + if self.last_batch != 'roll_over' or \ + self._is_allowed_reading is True: + if self.imgrec is not None: + self.imgrec.reset() self.cur = 0 - else: - if self.last_batch != 'roll_over' or \ - self._is_allowed_reading is True: - self.imgrec.reset() self._is_allowed_reading = True def hard_reset(self): @@ -1192,13 +1187,14 @@ def hard_reset(self): def next_sample(self): """Helper function for reading in next sample.""" + if self._is_allowed_reading is False: + raise StopIteration if self.seq is not None: if self.cur < self.num_image: idx = self.seq[self.cur] - elif self.num_image % self.batch_size != 0 and \ - self.cur < self.batch_size * ((self.num_image // self.batch_size) + 1): - idx = self.seq[self.cur % self.num_image] else: + if self.last_batch != 'discard': + self.cur = 0 raise StopIteration self.cur += 1 if self.imgrec is not None: @@ -1212,8 +1208,6 @@ def next_sample(self): label, fname = self.imglist[idx] return label, self.read_image(fname) else: - if self._is_allowed_reading is False: - raise StopIteration s = self.imgrec.read() if s is None: if self.last_batch != 'discard': @@ -1253,17 +1247,17 @@ def next(self): batch_label = nd.empty(self.provide_label[0][1]) i = self.iterate(batch_data, batch_label) # calculate the padding - if self.seq is None: - pad = batch_size - i - else: - pad = self._getpad() + pad = batch_size - i # handle padding for sequential read - if self.seq is None and pad != 0: - if self.last_batch == 'discard': - raise StopIteration - else: + if pad != 0: + if self.seq is not None: _ = self.iterate(batch_data, batch_label, i) - self._is_allowed_reading = False + else: + if self.last_batch == 'discard': + raise StopIteration + else: + _ = self.iterate(batch_data, batch_label, i) + self._is_allowed_reading = False return io.DataBatch([batch_data], [batch_label], pad=pad) def check_data_shape(self, data_shape): @@ -1318,12 +1312,4 @@ def augmentation_transform(self, data): def postprocess_data(self, datum): """Final postprocessing step before image is loaded into the batch.""" return nd.transpose(datum, axes=(2, 0, 1)) - - def _getpad(self): - """Helpe function for getting padding number""" - if self.last_batch in ['pad', 'roll_over'] and \ - self.num_image is not None and \ - self.cur >= self.num_image: - return self.cur - self.num_image - else: - return 0 + From 94c96a6e119a11be461ac2f965b336bf9f40f0e8 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Tue, 14 Aug 2018 13:50:10 -0700 Subject: [PATCH 09/21] remove the test case used locally --- tests/python/unittest/test_image.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 32b2c5182528..a97be5f38c39 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -127,16 +127,7 @@ def test_color_normalize(self): src = np.random.rand(height, width, 3) * 255. mx_result = mx.image.color_normalize(mx.nd.array(src), mx.nd.array(mean), mx.nd.array(std)) - assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3) - - def test_test(self): - data_iter = mx.image.ImageIter(batch_size=4, data_shape=( - 3, 227, 227), path_imgrec='/Users/leecheng/data/caltech.rec', shuffle=True, path_imgidx='/Users/leecheng/data/caltech.idx', last_batch='discard') - for batch in data_iter: - pass - data_iter.reset() - for batch in data_iter: - pass + assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3) def test_imageiter(self): def check_imageiter(dtype='float32'): From 3ffbb067112fe36528ad2306e30f119d5d57f85d Mon Sep 17 00:00:00 2001 From: stu1130 Date: Tue, 14 Aug 2018 13:50:44 -0700 Subject: [PATCH 10/21] fix the bad indentation --- python/mxnet/image/image.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index d38686ff66b2..5bbdd557bbdb 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1173,9 +1173,9 @@ def reset(self): """Resets the iterator to the beginning of the data.""" if self.last_batch != 'roll_over' or \ self._is_allowed_reading is True: - if self.imgrec is not None: - self.imgrec.reset() - self.cur = 0 + if self.imgrec is not None: + self.imgrec.reset() + self.cur = 0 self._is_allowed_reading = True def hard_reset(self): @@ -1312,4 +1312,3 @@ def augmentation_transform(self, data): def postprocess_data(self, datum): """Final postprocessing step before image is loaded into the batch.""" return nd.transpose(datum, axes=(2, 0, 1)) - From fcd7310cefeda4d955ed57a1f5eb0031c242aea6 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Wed, 15 Aug 2018 09:16:38 -0700 Subject: [PATCH 11/21] assert batch data shape --- tests/python/unittest/test_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index a97be5f38c39..0a98b3c1c60c 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -149,7 +149,7 @@ def check_imageiter(dtype='float32'): path_imglist=path_imglist, path_root='', dtype=dtype) for _ in range(3): for batch in test_iter: - pass + assert batch.data[0].shape == (2, 3, 224, 224) test_iter.reset() # test last batch handle(discard) test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, From 7dc815068e242c8c525e7a3c100c65cc24b290da Mon Sep 17 00:00:00 2001 From: stu1130 Date: Wed, 15 Aug 2018 10:14:01 -0700 Subject: [PATCH 12/21] 1. delete the piece of code that handle the discard when initialization. Now we check it in next() function 2. move shuffle back to reset function --- python/mxnet/image/image.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index 5bbdd557bbdb..c834487a0091 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1047,7 +1047,7 @@ class ImageIter(io.DataIter): path_imgidx : str Path to image index file. Needed for partition and shuffling when using .rec source. shuffle : bool - Whether to shuffle all images or not. + Whether to shuffle all images at the start of each iteration or not. Can be slow for HDD. part_index : int Partition index. @@ -1136,12 +1136,9 @@ def __init__(self, batch_size, data_shape, label_width=1, self.batch_size = batch_size self.data_shape = data_shape self.label_width = label_width - + self.shuffle = shuffle if self.imgrec is None: self.seq = imgkeys - # shuffle - if shuffle: - random.shuffle(self.seq) elif shuffle or num_parts > 1: assert self.imgidx is not None self.seq = self.imgidx @@ -1159,18 +1156,14 @@ def __init__(self, batch_size, data_shape, label_width=1, self.auglist = aug_list self.cur = 0 self._is_allowed_reading = True - self._cached_data = None - # handle the last batch - if self.seq and last_batch == 'discard': - new_seq_n = len(self.seq) - len(self.seq) % batch_size - self.seq = self.seq[:new_seq_n] - self.last_batch = last_batch self.num_image = len(self.seq) if self.seq is not None else None self.reset() def reset(self): """Resets the iterator to the beginning of the data.""" + if self.seq is not None and self.shuffle: + random.shuffle(self.seq) if self.last_batch != 'roll_over' or \ self._is_allowed_reading is True: if self.imgrec is not None: @@ -1216,8 +1209,8 @@ def next_sample(self): header, img = recordio.unpack(s) return header.label, img - def iterate(self, batch_data, batch_label, start=0): - """Helper function for iterate a batch of data""" + def _batchify(self, batch_data, batch_label, start=0): + """Helper function for batchifying data""" i = start batch_size = self.batch_size try: @@ -1245,18 +1238,15 @@ def next(self): c, h, w = self.data_shape batch_data = nd.empty((batch_size, c, h, w)) batch_label = nd.empty(self.provide_label[0][1]) - i = self.iterate(batch_data, batch_label) + i = self._batchify(batch_data, batch_label) # calculate the padding pad = batch_size - i - # handle padding for sequential read + # handle padding for 'pad' and 'roll_over' for the last batch if pad != 0: - if self.seq is not None: - _ = self.iterate(batch_data, batch_label, i) - else: - if self.last_batch == 'discard': - raise StopIteration - else: - _ = self.iterate(batch_data, batch_label, i) + if self.last_batch == 'discard': + raise StopIteration + # pad the rest of the data + _ = self._batchify(batch_data, batch_label, i) self._is_allowed_reading = False return io.DataBatch([batch_data], [batch_label], pad=pad) From 0802ae7fe13b722ff7de60506ebca1075a0f8289 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Wed, 15 Aug 2018 11:31:48 -0700 Subject: [PATCH 13/21] update the test case when we call reset several times, the pad and data are not correct --- tests/python/unittest/test_image.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 0a98b3c1c60c..f2fe6e12c50c 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -147,6 +147,7 @@ def check_imageiter(dtype='float32'): test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist, path_imglist=path_imglist, path_root='', dtype=dtype) + # test batch data shape for _ in range(3): for batch in test_iter: assert batch.data[0].shape == (2, 3, 224, 224) @@ -158,7 +159,6 @@ def check_imageiter(dtype='float32'): for batch in test_iter: i += 1 assert i == 5 - # test last_batch(pad) test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, path_imglist=path_imglist, path_root='', dtype=dtype, last_batch='pad') @@ -177,11 +177,26 @@ def check_imageiter(dtype='float32'): i = 0 for batch in test_iter: if i == 0: - second_image = batch.data[0][2] + first_image = batch.data[0][0] + logging.info(first_image) + third_image = batch.data[0][2] i += 1 test_iter.reset() assert np.array_equal( - test_iter.next().data[0][0].asnumpy(), second_image.asnumpy()), 'failed in {}'.format(test) + test_iter.next().data[0][0].asnumpy(), third_image.asnumpy()), 'failed in {}'.format(test) + # test iteratopr work properly after calling reset when last_batch is roll_over + i = 0 + for batch in test_iter: + # the last one batch + if i == 3: + assert batch.pad == 1 + logging.info(first_image) + assert np.array_equal( + batch.data[0][2].asnumpy(), first_image.asnumpy()) + else: + assert batch.pad == 0 + i += 1 + assert i == 4 for dtype in ['int32', 'float32', 'int64', 'float64']: check_imageiter(dtype) From 4ed67563727b0a03ec08b3197a3f9880fbccb7cb Mon Sep 17 00:00:00 2001 From: stu1130 Date: Wed, 15 Aug 2018 14:19:33 -0700 Subject: [PATCH 14/21] delete logs we don't need --- tests/python/unittest/test_image.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index f2fe6e12c50c..0349cef7dea1 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -178,19 +178,17 @@ def check_imageiter(dtype='float32'): for batch in test_iter: if i == 0: first_image = batch.data[0][0] - logging.info(first_image) third_image = batch.data[0][2] i += 1 test_iter.reset() assert np.array_equal( - test_iter.next().data[0][0].asnumpy(), third_image.asnumpy()), 'failed in {}'.format(test) + test_iter.next().data[0][0].asnumpy(), third_image.asnumpy()) # test iteratopr work properly after calling reset when last_batch is roll_over i = 0 for batch in test_iter: # the last one batch if i == 3: assert batch.pad == 1 - logging.info(first_image) assert np.array_equal( batch.data[0][2].asnumpy(), first_image.asnumpy()) else: From a4b207d5fbc341f2ec4232e1c49d456de2bba5db Mon Sep 17 00:00:00 2001 From: stu1130 Date: Wed, 15 Aug 2018 23:00:31 -0700 Subject: [PATCH 15/21] refine some code comment --- python/mxnet/image/image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index c834487a0091..d836447bb6f7 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1241,7 +1241,7 @@ def next(self): i = self._batchify(batch_data, batch_label) # calculate the padding pad = batch_size - i - # handle padding for 'pad' and 'roll_over' for the last batch + # handle padding of 'pad' and 'roll_over' for the last batch if pad != 0: if self.last_batch == 'discard': raise StopIteration From c384e9aed59c8ab47e57ca80e7ececf6c3a3f4dd Mon Sep 17 00:00:00 2001 From: stu1130 Date: Thu, 16 Aug 2018 13:50:34 -0700 Subject: [PATCH 16/21] change the roll_over behavior --- python/mxnet/image/image.py | 58 +++++++++++++++++++++++++++---------- 1 file changed, 42 insertions(+), 16 deletions(-) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index d836447bb6f7..518881a43a04 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1059,7 +1059,7 @@ class ImageIter(io.DataIter): Label name for provided symbols. dtype : str Label data type. Default: float32. Other options: int32, int64, float64 - last_batch : str, optional + last_batch_handle : str, optional How to handle the last batch. This parameter can be 'pad'(default), 'discard' or 'roll_over'. If 'pad', the last batch will be padded with data starting from the begining @@ -1073,7 +1073,7 @@ def __init__(self, batch_size, data_shape, label_width=1, path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None, shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, data_name='data', label_name='softmax_label', dtype='float32', - last_batch='pad', **kwargs): + last_batch_handle='pad', **kwargs): super(ImageIter, self).__init__() assert path_imgrec or path_imglist or (isinstance(imglist, list)) assert dtype in ['int32', 'float32', 'int64', 'float64'], dtype + ' label not supported' @@ -1156,20 +1156,24 @@ def __init__(self, batch_size, data_shape, label_width=1, self.auglist = aug_list self.cur = 0 self._is_allowed_reading = True - self.last_batch = last_batch + self.last_batch_handle = last_batch_handle self.num_image = len(self.seq) if self.seq is not None else None + self._cache_data = None + self._cache_label = None + self._cache_idx = None self.reset() def reset(self): """Resets the iterator to the beginning of the data.""" if self.seq is not None and self.shuffle: random.shuffle(self.seq) - if self.last_batch != 'roll_over' or \ - self._is_allowed_reading is True: + if self.last_batch_handle != 'roll_over' or \ + self._cache_data is None: if self.imgrec is not None: self.imgrec.reset() self.cur = 0 - self._is_allowed_reading = True + if self._is_allowed_reading is False: + self._is_allowed_reading = True def hard_reset(self): """Resets the iterator and ignore roll over data""" @@ -1186,7 +1190,7 @@ def next_sample(self): if self.cur < self.num_image: idx = self.seq[self.cur] else: - if self.last_batch != 'discard': + if self.last_batch_handle != 'discard': self.cur = 0 raise StopIteration self.cur += 1 @@ -1203,7 +1207,7 @@ def next_sample(self): else: s = self.imgrec.read() if s is None: - if self.last_batch != 'discard': + if self.last_batch_handle != 'discard': self.imgrec.reset() raise StopIteration header, img = recordio.unpack(s) @@ -1236,18 +1240,40 @@ def next(self): """Returns the next batch of data.""" batch_size = self.batch_size c, h, w = self.data_shape - batch_data = nd.empty((batch_size, c, h, w)) - batch_label = nd.empty(self.provide_label[0][1]) - i = self._batchify(batch_data, batch_label) + # if last batch data is rolled over + if self._cache_data is not None: + # check both the data and label have values + assert self._cache_label is not None, "_cache_label didn't have values" + assert self._cache_idx is not None, "_cache_idx didn't have values" + batch_data = self._cache_data + batch_label = self._cache_label + i = self._cache_idx + # clear the cache data + else: + batch_data = nd.empty((batch_size, c, h, w)) + batch_label = nd.empty(self.provide_label[0][1]) + i = self._batchify(batch_data, batch_label) # calculate the padding pad = batch_size - i - # handle padding of 'pad' and 'roll_over' for the last batch + # handle padding for the last batch if pad != 0: - if self.last_batch == 'discard': + if self.last_batch_handle == 'discard': + raise StopIteration + # if the option is 'roll_over', throw StopIteration and cache the data + elif self.last_batch_handle == 'roll_over' and \ + self._cache_data is None: + self._cache_data = batch_data + self._cache_label = batch_label + self._cache_idx = i raise StopIteration - # pad the rest of the data - _ = self._batchify(batch_data, batch_label, i) - self._is_allowed_reading = False + else: + _ = self._batchify(batch_data, batch_label, i) + if self.last_batch_handle == 'pad': + self._is_allowed_reading = False + else: + self._cache_data = None + self._cache_label = None + self._cache_idx = None return io.DataBatch([batch_data], [batch_label], pad=pad) def check_data_shape(self, data_shape): From bc4883341fd728fffd6d6eb62fb8789585cc0eee Mon Sep 17 00:00:00 2001 From: stu1130 Date: Thu, 16 Aug 2018 13:51:45 -0700 Subject: [PATCH 17/21] change the unit test according to the latest roll_over behavior --- tests/python/unittest/test_image.py | 36 ++++++++++++++--------------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 0349cef7dea1..affefca2c895 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -154,14 +154,14 @@ def check_imageiter(dtype='float32'): test_iter.reset() # test last batch handle(discard) test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, - path_imglist=path_imglist, path_root='', dtype=dtype, last_batch='discard') + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard') i = 0 for batch in test_iter: i += 1 assert i == 5 - # test last_batch(pad) + # test last_batch_handle(pad) test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, - path_imglist=path_imglist, path_root='', dtype=dtype, last_batch='pad') + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad') i = 0 for batch in test_iter: if i == 0: @@ -171,30 +171,28 @@ def check_imageiter(dtype='float32'): i += 1 assert i == 6 assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy()) - # test last_batch(roll_over) + # test last_batch_handle(roll_over) test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, - path_imglist=path_imglist, path_root='', dtype=dtype, last_batch='roll_over') + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over') i = 0 for batch in test_iter: if i == 0: first_image = batch.data[0][0] - third_image = batch.data[0][2] i += 1 + assert i == 5 test_iter.reset() + first_batch_roll_over = test_iter.next() assert np.array_equal( - test_iter.next().data[0][0].asnumpy(), third_image.asnumpy()) - # test iteratopr work properly after calling reset when last_batch is roll_over - i = 0 - for batch in test_iter: - # the last one batch - if i == 3: - assert batch.pad == 1 - assert np.array_equal( - batch.data[0][2].asnumpy(), first_image.asnumpy()) - else: - assert batch.pad == 0 - i += 1 - assert i == 4 + first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy()) + assert first_batch_roll_over.pad == 2 + # test iteratopr work properly after calling reset several times when last_batch_handle is roll_over + for _ in test_iter: + pass + test_iter.reset() + first_batch_roll_over_twice = test_iter.next() + assert np.array_equal( + first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy()) + assert first_batch_roll_over_twice.pad == 1 for dtype in ['int32', 'float32', 'int64', 'float64']: check_imageiter(dtype) From d7adf64b7978c085036f64ec15b56a89d8212d9c Mon Sep 17 00:00:00 2001 From: stu1130 Date: Thu, 16 Aug 2018 14:13:28 -0700 Subject: [PATCH 18/21] fix hard_reset bug which misses to clear the cache data --- python/mxnet/image/image.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index 518881a43a04..41eb4d5ac016 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1177,10 +1177,15 @@ def reset(self): def hard_reset(self): """Resets the iterator and ignore roll over data""" + if self.seq is not None and self.shuffle: + random.shuffle(self.seq) if self.imgrec is not None: self.imgrec.reset() self.cur = 0 self._is_allowed_reading = True + self._cache_data = None + self._cache_label = None + self._cache_idx = None def next_sample(self): """Helper function for reading in next sample.""" From e4a16db559c0cc727a39604f58b06e77a20fbf42 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Fri, 17 Aug 2018 10:49:32 -0700 Subject: [PATCH 19/21] refine minor variable name --- python/mxnet/image/image.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py index 41eb4d5ac016..24f5309d136b 100644 --- a/python/mxnet/image/image.py +++ b/python/mxnet/image/image.py @@ -1155,7 +1155,7 @@ def __init__(self, batch_size, data_shape, label_width=1, else: self.auglist = aug_list self.cur = 0 - self._is_allowed_reading = True + self._allow_read = True self.last_batch_handle = last_batch_handle self.num_image = len(self.seq) if self.seq is not None else None self._cache_data = None @@ -1172,8 +1172,8 @@ def reset(self): if self.imgrec is not None: self.imgrec.reset() self.cur = 0 - if self._is_allowed_reading is False: - self._is_allowed_reading = True + if self._allow_read is False: + self._allow_read = True def hard_reset(self): """Resets the iterator and ignore roll over data""" @@ -1182,14 +1182,14 @@ def hard_reset(self): if self.imgrec is not None: self.imgrec.reset() self.cur = 0 - self._is_allowed_reading = True + self._allow_read = True self._cache_data = None self._cache_label = None self._cache_idx = None def next_sample(self): """Helper function for reading in next sample.""" - if self._is_allowed_reading is False: + if self._allow_read is False: raise StopIteration if self.seq is not None: if self.cur < self.num_image: @@ -1274,7 +1274,7 @@ def next(self): else: _ = self._batchify(batch_data, batch_label, i) if self.last_batch_handle == 'pad': - self._is_allowed_reading = False + self._allow_read = False else: self._cache_data = None self._cache_label = None From ab6e6011f0f25a8be358c6a59f315847f62f7bf6 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Fri, 17 Aug 2018 10:51:28 -0700 Subject: [PATCH 20/21] assert second epoch size and add shuffle test case for sanity check --- tests/python/unittest/test_image.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index affefca2c895..69a636edc632 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -186,13 +186,22 @@ def check_imageiter(dtype='float32'): first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy()) assert first_batch_roll_over.pad == 2 # test iteratopr work properly after calling reset several times when last_batch_handle is roll_over + # we've called next once + i = 1 for _ in test_iter: - pass + i += 1 + # test second epoch with size 5 + assert i == 5 test_iter.reset() first_batch_roll_over_twice = test_iter.next() assert np.array_equal( first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy()) assert first_batch_roll_over_twice.pad == 1 + # test shuffle option for sanity test + test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True, + path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad') + for _ in test_iter: + pass for dtype in ['int32', 'float32', 'int64', 'float64']: check_imageiter(dtype) From 5f8a6590031f6a67293730987adc63310198e7a5 Mon Sep 17 00:00:00 2001 From: stu1130 Date: Fri, 17 Aug 2018 11:08:18 -0700 Subject: [PATCH 21/21] check the third epoch instead of second epoch --- tests/python/unittest/test_image.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py index 69a636edc632..0df08af317aa 100644 --- a/tests/python/unittest/test_image.py +++ b/tests/python/unittest/test_image.py @@ -186,17 +186,19 @@ def check_imageiter(dtype='float32'): first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy()) assert first_batch_roll_over.pad == 2 # test iteratopr work properly after calling reset several times when last_batch_handle is roll_over - # we've called next once - i = 1 for _ in test_iter: - i += 1 - # test second epoch with size 5 - assert i == 5 + pass test_iter.reset() first_batch_roll_over_twice = test_iter.next() assert np.array_equal( first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy()) assert first_batch_roll_over_twice.pad == 1 + # we've called next once + i = 1 + for _ in test_iter: + i += 1 + # test the third epoch with size 6 + assert i == 6 # test shuffle option for sanity test test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True, path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')