Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-737]Add last batch handle for imageiter #12131

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
9263a56
Add last_batch_handle on ImageIter
stu1130 Aug 10, 2018
c3a7dc5
Test the last_batch_handle for both of imglist and path_imglist
stu1130 Aug 10, 2018
79f3002
fix the filename typo
stu1130 Aug 10, 2018
0eb4573
change the name of parameter from 'last_batch_handle' to 'last_batch'
stu1130 Aug 14, 2018
c214f5b
1. change the name of the parameter from 'last_batch_handle' to 'last…
stu1130 Aug 14, 2018
a66452c
fix character '\xe2' encoding issue
stu1130 Aug 14, 2018
0f47661
fix roll_over bug happened when calling reset() several times
stu1130 Aug 14, 2018
10aed35
unify the logic of how to deal with 'discard', 'pad', ''roll_over' fo…
stu1130 Aug 14, 2018
94c96a6
remove the test case used locally
stu1130 Aug 14, 2018
3ffbb06
fix the bad indentation
stu1130 Aug 14, 2018
fcd7310
assert batch data shape
stu1130 Aug 15, 2018
7dc8150
1. delete the piece of code that handle the discard when initializati…
stu1130 Aug 15, 2018
0802ae7
update the test case when we call reset several times, the pad and da…
stu1130 Aug 15, 2018
4ed6756
delete logs we don't need
stu1130 Aug 15, 2018
a4b207d
refine some code comment
stu1130 Aug 16, 2018
c384e9a
change the roll_over behavior
stu1130 Aug 16, 2018
bc48833
change the unit test according to the latest roll_over behavior
stu1130 Aug 16, 2018
d7adf64
fix hard_reset bug which misses to clear the cache data
stu1130 Aug 16, 2018
e4a16db
refine minor variable name
stu1130 Aug 17, 2018
ab6e601
assert second epoch size and add shuffle test case for sanity check
stu1130 Aug 17, 2018
5f8a659
check the third epoch instead of second epoch
stu1130 Aug 17, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 85 additions & 15 deletions python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,14 +1059,21 @@ class ImageIter(io.DataIter):
Label name for provided symbols.
dtype : str
Label data type. Default: float32. Other options: int32, int64, float64
last_batch_handle : str, optional
How to handle the last batch.
This parameter can be 'pad' (default), 'discard' or 'roll_over'.
If 'pad', the last batch will be padded with data starting from the beginning.
If 'discard', the last batch will be discarded.
If 'roll_over', the remaining elements will be rolled over to the next iteration.
kwargs : ...
More arguments for creating augmenter. See mx.image.CreateAugmenter.
"""

def __init__(self, batch_size, data_shape, label_width=1,
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
data_name='data', label_name='softmax_label', dtype='float32', **kwargs):
data_name='data', label_name='softmax_label', dtype='float32',
last_batch_handle='pad', **kwargs):
super(ImageIter, self).__init__()
assert path_imgrec or path_imglist or (isinstance(imglist, list))
assert dtype in ['int32', 'float32', 'int64', 'float64'], dtype + ' label not supported'
Expand Down Expand Up @@ -1129,7 +1136,6 @@ def __init__(self, batch_size, data_shape, label_width=1,
self.batch_size = batch_size
self.data_shape = data_shape
self.label_width = label_width

self.shuffle = shuffle
if self.imgrec is None:
self.seq = imgkeys
Expand All @@ -1149,22 +1155,49 @@ def __init__(self, batch_size, data_shape, label_width=1,
else:
self.auglist = aug_list
self.cur = 0
self._allow_read = True
self.last_batch_handle = last_batch_handle
self.num_image = len(self.seq) if self.seq is not None else None
self._cache_data = None
self._cache_label = None
self._cache_idx = None
self.reset()

def reset(self):
"""Resets the iterator to the beginning of the data."""
if self.shuffle:
if self.seq is not None and self.shuffle:
random.shuffle(self.seq)
if self.last_batch_handle != 'roll_over' or \
self._cache_data is None:
if self.imgrec is not None:
self.imgrec.reset()
self.cur = 0
if self._allow_read is False:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: just self._allow_read = True would work here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it works as well. Only iter with 'pad' would change this flag though.

self._allow_read = True

def hard_reset(self):
"""Resets the iterator and ignore roll over data"""
if self.seq is not None and self.shuffle:
random.shuffle(self.seq)
if self.imgrec is not None:
self.imgrec.reset()
self.cur = 0
self._allow_read = True
self._cache_data = None
self._cache_label = None
self._cache_idx = None

def next_sample(self):
"""Helper function for reading in next sample."""
if self._allow_read is False:
raise StopIteration
if self.seq is not None:
if self.cur >= len(self.seq):
if self.cur < self.num_image:
idx = self.seq[self.cur]
else:
if self.last_batch_handle != 'discard':
self.cur = 0
raise StopIteration
idx = self.seq[self.cur]
self.cur += 1
if self.imgrec is not None:
s = self.imgrec.read_idx(idx)
Expand All @@ -1179,17 +1212,16 @@ def next_sample(self):
else:
s = self.imgrec.read()
if s is None:
if self.last_batch_handle != 'discard':
self.imgrec.reset()
raise StopIteration
header, img = recordio.unpack(s)
return header.label, img

def next(self):
"""Returns the next batch of data."""
def _batchify(self, batch_data, batch_label, start=0):
"""Helper function for batchifying data"""
i = start
batch_size = self.batch_size
c, h, w = self.data_shape
batch_data = nd.empty((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
i = 0
try:
while i < batch_size:
label, s = self.next_sample()
Expand All @@ -1207,8 +1239,47 @@ def next(self):
except StopIteration:
if not i:
raise StopIteration
return i

return io.DataBatch([batch_data], [batch_label], batch_size - i)
def next(self):
"""Returns the next batch of data."""
batch_size = self.batch_size
c, h, w = self.data_shape
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we assuming always channels_first?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes we are

# if last batch data is rolled over
if self._cache_data is not None:
# check both the data and label have values
assert self._cache_label is not None, "_cache_label didn't have values"
assert self._cache_idx is not None, "_cache_idx didn't have values"
batch_data = self._cache_data
batch_label = self._cache_label
i = self._cache_idx
# clear the cache data
else:
batch_data = nd.empty((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
i = self._batchify(batch_data, batch_label)
# calculate the padding
pad = batch_size - i
# handle padding for the last batch
if pad != 0:
if self.last_batch_handle == 'discard':
raise StopIteration
# if the option is 'roll_over', throw StopIteration and cache the data
elif self.last_batch_handle == 'roll_over' and \
self._cache_data is None:
self._cache_data = batch_data
self._cache_label = batch_label
self._cache_idx = i
raise StopIteration
else:
_ = self._batchify(batch_data, batch_label, i)
if self.last_batch_handle == 'pad':
self._allow_read = False
else:
self._cache_data = None
self._cache_label = None
self._cache_idx = None
return io.DataBatch([batch_data], [batch_label], pad=pad)

def check_data_shape(self, data_shape):
"""Checks if the input data shape is valid"""
Expand All @@ -1228,9 +1299,9 @@ def imdecode(self, s):
def locate():
"""Locate the image file/index if decode fails."""
if self.seq is not None:
idx = self.seq[self.cur - 1]
idx = self.seq[(self.cur % self.num_image) - 1]
else:
idx = self.cur - 1
idx = (self.cur % self.num_image) - 1
if self.imglist is not None:
_, fname = self.imglist[idx]
msg = "filename: {}".format(fname)
Expand All @@ -1245,7 +1316,6 @@ def locate():

def read_image(self, fname):
"""Reads an input image `fname` and returns the decoded raw bytes.

Example usage:
----------
>>> dataIter.read_image('Face.jpg') # returns decoded raw bytes.
Expand Down
84 changes: 67 additions & 17 deletions tests/python/unittest/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def test_imread_vs_imdecode(self):
image_read = mx.img.image.imread(img)
same(image.asnumpy(), image_read.asnumpy())


def test_imdecode(self):
try:
import cv2
Expand Down Expand Up @@ -130,29 +129,81 @@ def test_color_normalize(self):
mx.nd.array(mean), mx.nd.array(std))
assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3)


def test_imageiter(self):
def check_imageiter(dtype='float32'):
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
path_root='', dtype=dtype)
for _ in range(3):
for batch in test_iter:
pass
test_iter.reset()

# test with list file
fname = './data/test_imageiter.lst'
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \
for k, x in enumerate(TestImage.IMAGES)]
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
for line in file_list:
f.write(line + '\n')

test_list = ['imglist', 'path_imglist']

test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname,
path_root='', dtype=dtype)
for batch in test_iter:
pass
for test in test_list:
imglist = im_list if test == 'imglist' else None
path_imglist = fname if test == 'path_imglist' else None

test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype)
# test batch data shape
for _ in range(3):
for batch in test_iter:
assert batch.data[0].shape == (2, 3, 224, 224)
test_iter.reset()
# test last batch handle(discard)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard')
i = 0
for batch in test_iter:
i += 1
assert i == 5
# test last_batch_handle(pad)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
i = 0
for batch in test_iter:
if i == 0:
first_three_data = batch.data[0][:2]
if i == 5:
last_three_data = batch.data[0][1:]
i += 1
assert i == 6
assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy())
# test last_batch_handle(roll_over)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over')
i = 0
for batch in test_iter:
if i == 0:
first_image = batch.data[0][0]
i += 1
assert i == 5
test_iter.reset()
first_batch_roll_over = test_iter.next()
assert np.array_equal(
first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
assert first_batch_roll_over.pad == 2
# test that the iterator works properly after calling reset several times when last_batch_handle is roll_over
for _ in test_iter:
pass
test_iter.reset()
first_batch_roll_over_twice = test_iter.next()
assert np.array_equal(
first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
assert first_batch_roll_over_twice.pad == 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert second epoch with size 6?
Also can you add test for shuffle=True, just for sanity test, no value assert is required for shuffle mode.

# we've called next once
i = 1
for _ in test_iter:
i += 1
# test the third epoch with size 6
assert i == 6
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test epoch size

# test shuffle option for sanity test
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add shuffle test case

path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
for _ in test_iter:
pass

for dtype in ['int32', 'float32', 'int64', 'float64']:
check_imageiter(dtype)
Expand Down Expand Up @@ -183,7 +234,6 @@ def test_augmenters(self):
for batch in test_iter:
pass


def test_image_detiter(self):
im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
Expand Down