Skip to content

Commit

Permalink
[MXNET-737]Add last batch handle for imageiter (apache#12131)
Browse files Browse the repository at this point in the history
[MXNET-737]Add last batch handle for imageiter
  • Loading branch information
stu1130 authored and sandeep-krishnamurthy committed Aug 18, 2018
1 parent 726bd60 commit 356be33
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 32 deletions.
100 changes: 85 additions & 15 deletions python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,14 +1059,21 @@ class ImageIter(io.DataIter):
Label name for provided symbols.
dtype : str
Label data type. Default: float32. Other options: int32, int64, float64
last_batch_handle : str, optional
How to handle the last batch.
This parameter can be 'pad'(default), 'discard' or 'roll_over'.
If 'pad', the last batch will be padded with data starting from the beginning
If 'discard', the last batch will be discarded
If 'roll_over', the remaining elements will be rolled over to the next iteration
kwargs : ...
More arguments for creating augmenter. See mx.image.CreateAugmenter.
"""

def __init__(self, batch_size, data_shape, label_width=1,
path_imgrec=None, path_imglist=None, path_root=None, path_imgidx=None,
shuffle=False, part_index=0, num_parts=1, aug_list=None, imglist=None,
data_name='data', label_name='softmax_label', dtype='float32', **kwargs):
data_name='data', label_name='softmax_label', dtype='float32',
last_batch_handle='pad', **kwargs):
super(ImageIter, self).__init__()
assert path_imgrec or path_imglist or (isinstance(imglist, list))
assert dtype in ['int32', 'float32', 'int64', 'float64'], dtype + ' label not supported'
Expand Down Expand Up @@ -1129,7 +1136,6 @@ def __init__(self, batch_size, data_shape, label_width=1,
self.batch_size = batch_size
self.data_shape = data_shape
self.label_width = label_width

self.shuffle = shuffle
if self.imgrec is None:
self.seq = imgkeys
Expand All @@ -1149,22 +1155,49 @@ def __init__(self, batch_size, data_shape, label_width=1,
else:
self.auglist = aug_list
self.cur = 0
self._allow_read = True
self.last_batch_handle = last_batch_handle
self.num_image = len(self.seq) if self.seq is not None else None
self._cache_data = None
self._cache_label = None
self._cache_idx = None
self.reset()

def reset(self):
"""Resets the iterator to the beginning of the data."""
if self.shuffle:
if self.seq is not None and self.shuffle:
random.shuffle(self.seq)
if self.last_batch_handle != 'roll_over' or \
self._cache_data is None:
if self.imgrec is not None:
self.imgrec.reset()
self.cur = 0
if self._allow_read is False:
self._allow_read = True

def hard_reset(self):
    """Restart iteration from scratch, dropping any rolled-over data.

    Reshuffles the sample sequence when one exists and shuffling is
    enabled, rewinds the record reader when one is present, moves the
    cursor back to zero, re-enables reading, and clears the cached data,
    label and index.
    """
    sequence = self.seq
    if sequence is not None and self.shuffle:
        random.shuffle(sequence)
    reader = self.imgrec
    if reader is not None:
        reader.reset()
    # Put the cursor and the read flag back to their start-of-epoch values.
    self.cur, self._allow_read = 0, True
    # Forget whatever partial batch had been cached.
    self._cache_data = self._cache_label = self._cache_idx = None

def next_sample(self):
"""Helper function for reading in next sample."""
if self._allow_read is False:
raise StopIteration
if self.seq is not None:
if self.cur >= len(self.seq):
if self.cur < self.num_image:
idx = self.seq[self.cur]
else:
if self.last_batch_handle != 'discard':
self.cur = 0
raise StopIteration
idx = self.seq[self.cur]
self.cur += 1
if self.imgrec is not None:
s = self.imgrec.read_idx(idx)
Expand All @@ -1179,17 +1212,16 @@ def next_sample(self):
else:
s = self.imgrec.read()
if s is None:
if self.last_batch_handle != 'discard':
self.imgrec.reset()
raise StopIteration
header, img = recordio.unpack(s)
return header.label, img

def next(self):
"""Returns the next batch of data."""
def _batchify(self, batch_data, batch_label, start=0):
"""Helper function for batchifying data"""
i = start
batch_size = self.batch_size
c, h, w = self.data_shape
batch_data = nd.empty((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
i = 0
try:
while i < batch_size:
label, s = self.next_sample()
Expand All @@ -1207,8 +1239,47 @@ def next(self):
except StopIteration:
if not i:
raise StopIteration
return i

return io.DataBatch([batch_data], [batch_label], batch_size - i)
def next(self):
"""Returns the next batch of data.

If a partial batch was cached by an earlier call, it is reused as the
starting point; otherwise fresh data/label arrays are allocated and filled
through ``_batchify``. A batch that is still short is then handled
according to ``last_batch_handle``.

Returns
-------
io.DataBatch
    Built from the data array, the label array and ``pad``. ``pad`` is the
    number of slots that were empty before any top-up, and is not
    recomputed after the top-up.

Raises
------
StopIteration
    When the batch is short and ``last_batch_handle`` is 'discard', or is
    'roll_over' with nothing cached yet (the partial batch is cached before
    raising). May also propagate from ``_batchify``.
"""
batch_size = self.batch_size
c, h, w = self.data_shape
# if last batch data is rolled over
if self._cache_data is not None:
# check both the data and label have values
assert self._cache_label is not None, "_cache_label didn't have values"
assert self._cache_idx is not None, "_cache_idx didn't have values"
# resume from the cached partial batch: i is the number of slots already filled
batch_data = self._cache_data
batch_label = self._cache_label
i = self._cache_idx
# the cache is not cleared here; that happens below, after the batch is topped up
else:
batch_data = nd.empty((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
# i is the value returned by _batchify for the freshly allocated arrays
i = self._batchify(batch_data, batch_label)
# calculate the padding (0 means the batch is already full)
pad = batch_size - i
# handle padding for the last batch
if pad != 0:
if self.last_batch_handle == 'discard':
# a short batch is dropped entirely
raise StopIteration
# if the option is 'roll_over', throw StopIteration and cache the data
elif self.last_batch_handle == 'roll_over' and \
self._cache_data is None:
self._cache_data = batch_data
self._cache_label = batch_label
self._cache_idx = i
raise StopIteration
else:
# top up the remaining slots starting at index i; the returned count is not needed
_ = self._batchify(batch_data, batch_label, i)
if self.last_batch_handle == 'pad':
# block further reads: next_sample raises StopIteration while this flag is False
self._allow_read = False
else:
# any other mode reaching this point (in practice 'roll_over' with a
# cached batch that was just reused): drop the cache
self._cache_data = None
self._cache_label = None
self._cache_idx = None
return io.DataBatch([batch_data], [batch_label], pad=pad)

def check_data_shape(self, data_shape):
"""Checks if the input data shape is valid"""
Expand All @@ -1228,9 +1299,9 @@ def imdecode(self, s):
def locate():
"""Locate the image file/index if decode fails."""
if self.seq is not None:
idx = self.seq[self.cur - 1]
idx = self.seq[(self.cur % self.num_image) - 1]
else:
idx = self.cur - 1
idx = (self.cur % self.num_image) - 1
if self.imglist is not None:
_, fname = self.imglist[idx]
msg = "filename: {}".format(fname)
Expand All @@ -1245,7 +1316,6 @@ def locate():

def read_image(self, fname):
"""Reads an input image `fname` and returns the decoded raw bytes.
Example usage:
----------
>>> dataIter.read_image('Face.jpg') # returns decoded raw bytes.
Expand Down
84 changes: 67 additions & 17 deletions tests/python/unittest/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def test_imread_vs_imdecode(self):
image_read = mx.img.image.imread(img)
same(image.asnumpy(), image_read.asnumpy())


def test_imdecode(self):
try:
import cv2
Expand Down Expand Up @@ -130,29 +129,81 @@ def test_color_normalize(self):
mx.nd.array(mean), mx.nd.array(std))
assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, atol=1e-3)


def test_imageiter(self):
def check_imageiter(dtype='float32'):
im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=im_list,
path_root='', dtype=dtype)
for _ in range(3):
for batch in test_iter:
pass
test_iter.reset()

# test with list file
fname = './data/test_imageiter.lst'
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x]) \
for k, x in enumerate(TestImage.IMAGES)]
file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
for k, x in enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
for line in file_list:
f.write(line + '\n')

test_list = ['imglist', 'path_imglist']

test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, path_imglist=fname,
path_root='', dtype=dtype)
for batch in test_iter:
pass
for test in test_list:
imglist = im_list if test == 'imglist' else None
path_imglist = fname if test == 'path_imglist' else None

test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype)
# test batch data shape
for _ in range(3):
for batch in test_iter:
assert batch.data[0].shape == (2, 3, 224, 224)
test_iter.reset()
# test last batch handle(discard)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='discard')
i = 0
for batch in test_iter:
i += 1
assert i == 5
# test last_batch_handle(pad)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
i = 0
for batch in test_iter:
if i == 0:
first_three_data = batch.data[0][:2]
if i == 5:
last_three_data = batch.data[0][1:]
i += 1
assert i == 6
assert np.array_equal(first_three_data.asnumpy(), last_three_data.asnumpy())
# test last_batch_handle(roll_over)
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='roll_over')
i = 0
for batch in test_iter:
if i == 0:
first_image = batch.data[0][0]
i += 1
assert i == 5
test_iter.reset()
first_batch_roll_over = test_iter.next()
assert np.array_equal(
first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
assert first_batch_roll_over.pad == 2
# test that the iterator works properly after calling reset several times when last_batch_handle is roll_over
for _ in test_iter:
pass
test_iter.reset()
first_batch_roll_over_twice = test_iter.next()
assert np.array_equal(
first_batch_roll_over_twice.data[0][2].asnumpy(), first_image.asnumpy())
assert first_batch_roll_over_twice.pad == 1
# we've called next once
i = 1
for _ in test_iter:
i += 1
# test the third epoch with size 6
assert i == 6
# test shuffle option for sanity test
test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, shuffle=True,
path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
for _ in test_iter:
pass

for dtype in ['int32', 'float32', 'int64', 'float64']:
check_imageiter(dtype)
Expand Down Expand Up @@ -183,7 +234,6 @@ def test_augmenters(self):
for batch in test_iter:
pass


def test_image_detiter(self):
im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
Expand Down

0 comments on commit 356be33

Please sign in to comment.