-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-737]Add last batch handle for imageiter #12131
Changes from all commits
9263a56
c3a7dc5
79f3002
0eb4573
c214f5b
a66452c
0f47661
10aed35
94c96a6
3ffbb06
fcd7310
7dc8150
0802ae7
4ed6756
a4b207d
c384e9a
bc48833
d7adf64
e4a16db
ab6e601
5f8a659
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1059,14 +1059,21 @@ class ImageIter(io.DataIter): | |
Label name for provided symbols. | ||
dtype : str | ||
Label data type. Default: float32. Other options: int32, int64, float64 | ||
last_batch_handle : str, optional | ||
How to handle the last batch. | ||
This parameter can be 'pad'(default), 'discard' or 'roll_over'. | ||
If 'pad', the last batch will be padded with data starting from the begining | ||
If 'discard', the last batch will be discarded | ||
If 'roll_over', the remaining elements will be rolled over to the next iteration | ||
kwargs : ... | ||
More arguments for creating augmenter. See mx.image.CreateAugmenter. | ||
""" | ||
|
||
def __init__(self, batch_size, data_shape, label_width=1, | ||
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None, | ||
shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None, | ||
data_name='data', label_name='softmax_label', dtype='float32', **kwargs): | ||
data_name='data', label_name='softmax_label', dtype='float32', | ||
last_batch_handle='pad', **kwargs): | ||
super(ImageIter, self).__init__() | ||
assert path_imgrec or path_imglist or (isinstance(imglist, list)) | ||
assert dtype in ['int32', 'float32', 'int64', 'float64'], dtype + ' label not supported' | ||
|
@@ -1129,7 +1136,6 @@ def __init__(self, batch_size, data_shape, label_width=1, | |
self.batch_size = batch_size | ||
self.data_shape = data_shape | ||
self.label_width = label_width | ||
|
||
self.shuffle = shuffle | ||
if self.imgrec is None: | ||
self.seq = imgkeys | ||
|
@@ -1149,22 +1155,49 @@ def __init__(self, batch_size, data_shape, label_width=1, | |
else: | ||
self.auglist = aug_list | ||
self.cur = 0 | ||
self._allow_read = True | ||
self.last_batch_handle = last_batch_handle | ||
self.num_image = len(self.seq) if self.seq is not None else None | ||
self._cache_data = None | ||
self._cache_label = None | ||
self._cache_idx = None | ||
self.reset() | ||
|
||
def reset(self): | ||
"""Resets the iterator to the beginning of the data.""" | ||
if self.shuffle: | ||
if self.seq is not None and self.shuffle: | ||
random.shuffle(self.seq) | ||
if self.last_batch_handle != 'roll_over' or \ | ||
self._cache_data is None: | ||
if self.imgrec is not None: | ||
self.imgrec.reset() | ||
self.cur = 0 | ||
if self._allow_read is False: | ||
self._allow_read = True | ||
|
||
def hard_reset(self): | ||
"""Resets the iterator and ignore roll over data""" | ||
if self.seq is not None and self.shuffle: | ||
random.shuffle(self.seq) | ||
if self.imgrec is not None: | ||
self.imgrec.reset() | ||
self.cur = 0 | ||
self._allow_read = True | ||
self._cache_data = None | ||
self._cache_label = None | ||
self._cache_idx = None | ||
|
||
def next_sample(self): | ||
"""Helper function for reading in next sample.""" | ||
if self._allow_read is False: | ||
raise StopIteration | ||
if self.seq is not None: | ||
if self.cur >= len(self.seq): | ||
if self.cur < self.num_image: | ||
idx = self.seq[self.cur] | ||
else: | ||
if self.last_batch_handle != 'discard': | ||
self.cur = 0 | ||
raise StopIteration | ||
idx = self.seq[self.cur] | ||
self.cur += 1 | ||
if self.imgrec is not None: | ||
s = self.imgrec.read_idx(idx) | ||
|
@@ -1179,17 +1212,16 @@ def next_sample(self): | |
else: | ||
s = self.imgrec.read() | ||
if s is None: | ||
if self.last_batch_handle != 'discard': | ||
self.imgrec.reset() | ||
raise StopIteration | ||
header, img = recordio.unpack(s) | ||
return header.label, img | ||
|
||
def next(self): | ||
"""Returns the next batch of data.""" | ||
def _batchify(self, batch_data, batch_label, start=0): | ||
"""Helper function for batchifying data""" | ||
i = start | ||
batch_size = self.batch_size | ||
c, h, w = self.data_shape | ||
batch_data = nd.empty((batch_size, c, h, w)) | ||
batch_label = nd.empty(self.provide_label[0][1]) | ||
i = 0 | ||
try: | ||
while i < batch_size: | ||
label, s = self.next_sample() | ||
|
@@ -1207,8 +1239,47 @@ def next(self): | |
except StopIteration: | ||
if not i: | ||
raise StopIteration | ||
return i | ||
|
||
return io.DataBatch([batch_data], [batch_label], batch_size - i) | ||
def next(self): | ||
"""Returns the next batch of data.""" | ||
batch_size = self.batch_size | ||
c, h, w = self.data_shape | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we assuming always channels_first? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes we are |
||
# if last batch data is rolled over | ||
if self._cache_data is not None: | ||
# check both the data and label have values | ||
assert self._cache_label is not None, "_cache_label didn't have values" | ||
assert self._cache_idx is not None, "_cache_idx didn't have values" | ||
batch_data = self._cache_data | ||
batch_label = self._cache_label | ||
i = self._cache_idx | ||
# clear the cache data | ||
else: | ||
batch_data = nd.empty((batch_size, c, h, w)) | ||
batch_label = nd.empty(self.provide_label[0][1]) | ||
i = self._batchify(batch_data, batch_label) | ||
# calculate the padding | ||
pad = batch_size - i | ||
# handle padding for the last batch | ||
if pad != 0: | ||
if self.last_batch_handle == 'discard': | ||
raise StopIteration | ||
# if the option is 'roll_over', throw StopIteration and cache the data | ||
elif self.last_batch_handle == 'roll_over' and \ | ||
self._cache_data is None: | ||
self._cache_data = batch_data | ||
self._cache_label = batch_label | ||
self._cache_idx = i | ||
raise StopIteration | ||
else: | ||
_ = self._batchify(batch_data, batch_label, i) | ||
if self.last_batch_handle == 'pad': | ||
self._allow_read = False | ||
else: | ||
self._cache_data = None | ||
self._cache_label = None | ||
self._cache_idx = None | ||
return io.DataBatch([batch_data], [batch_label], pad=pad) | ||
|
||
def check_data_shape(self, data_shape): | ||
"""Checks if the input data shape is valid""" | ||
|
@@ -1228,9 +1299,9 @@ def imdecode(self, s): | |
def locate(): | ||
"""Locate the image file/index if decode fails.""" | ||
if self.seq is not None: | ||
idx = self.seq[self.cur - 1] | ||
idx = self.seq[(self.cur % self.num_image) - 1] | ||
else: | ||
idx = self.cur - 1 | ||
idx = (self.cur % self.num_image) - 1 | ||
if self.imglist is not None: | ||
_, fname = self.imglist[idx] | ||
msg = "filename: {}".format(fname) | ||
|
@@ -1245,7 +1316,6 @@ def locate(): | |
|
||
def read_image(self, fname): | ||
"""Reads an input image `fname` and returns the decoded raw bytes. | ||
|
||
Example usage: | ||
---------- | ||
>>> dataIter.read_image('Face.jpg') # returns decoded raw bytes. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,7 +80,6 @@ def test_imread_vs_imdecode(self): | |
image_read = mx.img.image.imread(img) | ||
same(image.asnumpy(), image_read.asnumpy()) | ||
|
||
|
||
def test_imdecode(self): | ||
try: | ||
import cv2 | ||
|
@@ -130,29 +129,81 @@ def test_color_normalize(self): | |
mx.nd.array(mean), mx.nd.array(std)) | ||
assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3) | ||
|
||
|
||
def test_imageiter(self): | ||
def check_imageiter(dtype='float32'): | ||
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES] | ||
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list, | ||
path_root='', dtype=dtype) | ||
for _ in range(3): | ||
for batch in test_iter: | ||
pass | ||
test_iter.reset() | ||
|
||
# test with list file | ||
fname = './data/test_imageiter.lst' | ||
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \ | ||
for k, x in enumerate(TestImage.IMAGES)] | ||
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) | ||
for k, x in enumerate(TestImage.IMAGES)] | ||
with open(fname, 'w') as f: | ||
for line in file_list: | ||
f.write(line + '\n') | ||
|
||
test_list = ['imglist', 'path_imglist'] | ||
|
||
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname, | ||
path_root='', dtype=dtype) | ||
for batch in test_iter: | ||
pass | ||
for test in test_list: | ||
imglist = im_list if test == 'imglist' else None | ||
path_imglist = fname if test == 'path_imglist' else None | ||
|
||
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist, | ||
path_imglist=path_imglist, path_root='', dtype=dtype) | ||
# test batch data shape | ||
for _ in range(3): | ||
for batch in test_iter: | ||
assert batch.data[0].shape == (2, 3, 224, 224) | ||
test_iter.reset() | ||
# test last batch handle(discard) | ||
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, | ||
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard') | ||
i = 0 | ||
for batch in test_iter: | ||
i += 1 | ||
assert i == 5 | ||
# test last_batch_handle(pad) | ||
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, | ||
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad') | ||
i = 0 | ||
for batch in test_iter: | ||
if i == 0: | ||
first_three_data = batch.data[0][:2] | ||
if i == 5: | ||
last_three_data = batch.data[0][1:] | ||
i += 1 | ||
assert i == 6 | ||
assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy()) | ||
# test last_batch_handle(roll_over) | ||
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, | ||
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over') | ||
i = 0 | ||
for batch in test_iter: | ||
if i == 0: | ||
first_image = batch.data[0][0] | ||
i += 1 | ||
assert i == 5 | ||
test_iter.reset() | ||
first_batch_roll_over = test_iter.next() | ||
assert np.array_equal( | ||
first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy()) | ||
assert first_batch_roll_over.pad == 2 | ||
# test iteratopr work properly after calling reset several times when last_batch_handle is roll_over | ||
for _ in test_iter: | ||
pass | ||
test_iter.reset() | ||
first_batch_roll_over_twice = test_iter.next() | ||
assert np.array_equal( | ||
first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy()) | ||
assert first_batch_roll_over_twice.pad == 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. assert second epoch with size 6? |
||
# we've called next once | ||
i = 1 | ||
for _ in test_iter: | ||
i += 1 | ||
# test the third epoch with size 6 | ||
assert i == 6 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. test epoch size |
||
# test shuffle option for sanity test | ||
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add shuffle test case |
||
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad') | ||
for _ in test_iter: | ||
pass | ||
|
||
for dtype in ['int32', 'float32', 'int64', 'float64']: | ||
check_imageiter(dtype) | ||
|
@@ -183,7 +234,6 @@ def test_augmenters(self): | |
for batch in test_iter: | ||
pass | ||
|
||
|
||
def test_image_detiter(self): | ||
im_list = [_generate_objects() + [x] for x in TestImage.IMAGES] | ||
det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='') | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: just self._allow_read = True would work here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, it works as well. Only iter with 'pad' would change this flag though.