update doc for dataloader w/ windows unittests
zhreshold committed Jul 13, 2018
1 parent e912314 commit 2ec3c3b
Showing 2 changed files with 87 additions and 94 deletions.
8 changes: 4 additions & 4 deletions python/mxnet/gluon/data/dataloader.py
@@ -277,7 +277,6 @@ def default_batchify_fn(data):
    num_workers : int, default 0
        The number of multiprocessing workers to use for data preprocessing.
        `num_workers > 0` is not supported on Windows yet.
    """
    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                 last_batch=None, batch_sampler=None, batchify_fn=None,
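
As context for the documented limitation: callers that must also run on Windows can simply guard the worker count by platform. A minimal usage sketch (illustrative only, not part of this diff):

import platform
import mxnet as mx
from mxnet.gluon.data import ArrayDataset, DataLoader

# num_workers > 0 spawns worker processes, which this version does not support on Windows.
dataset = ArrayDataset(mx.nd.random.uniform(shape=(100, 10)))
num_workers = 0 if platform.system() == 'Windows' else 2
loader = DataLoader(dataset, batch_size=10, num_workers=num_workers)
for batch in loader:
    pass  # each batch is a (10, 10) NDArray
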
@@ -315,9 +314,10 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,

    def __iter__(self):
        if self._num_workers == 0:
-            generator = lambda: [(yield self._batchify_fn([self._dataset[idx] for idx in batch]))
-                                 for batch in self._batch_sampler]
-            return generator()
+            def single_worker_iter():
+                for batch in self._batch_sampler:
+                    yield self._batchify_fn([self._dataset[idx] for idx in batch])
+            return single_worker_iter()

        # multi-worker
        return _MultiWorkerIter(self._num_workers, self._dataset,
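The replaced lambda smuggled a `yield` into a list comprehension — a construct that Python 3.7 deprecated and Python 3.8 rejects as a SyntaxError — so the named nested generator is both clearer and future-proof. A standalone sketch of the new pattern (illustrative; `sum` stands in for `self._batchify_fn` and a plain list for the batch sampler):

batches = [[0, 1], [2, 3]]
process = sum  # stand-in for self._batchify_fn

def make_iter():
    # equivalent of the new single_worker_iter: a plain nested generator
    def single_worker_iter():
        for batch in batches:
            yield process(batch)
    return single_worker_iter()

assert list(make_iter()) == [1, 5]
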
173 changes: 83 additions & 90 deletions tests/python/unittest/test_gluon_data.py
@@ -123,106 +123,99 @@ def __len__(self):
    def __getitem__(self, key):
        return mx.nd.full((10,), key)

@unittest.skip("Somehow fails with MKL. Cannot reproduce locally")
def test_multi_worker():
    data = Dataset()
    loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5)
    for i, batch in enumerate(loader):
-        assert (batch.asnumpy() == i).all()
+        assert (batch.asnumpy().astype('int32') == i).all()

+class _Dummy(Dataset):
+    """Dummy dataset of arrays with randomized shapes."""
+    def __init__(self, random_shape):
+        self.random_shape = random_shape
+
+    def __getitem__(self, idx):
+        key = idx
+        if self.random_shape:
+            out = np.random.uniform(size=(random.randint(1000, 1100), 40))
+            labels = np.random.uniform(size=(random.randint(10, 15)))
+        else:
+            out = np.random.uniform(size=(1000, 40))
+            labels = np.random.uniform(size=(10))
+        return key, out, labels
+
+    def __len__(self):
+        return 50
+
+def _batchify_list(data):
+    """
+    Return a list of NDArrays, without stack/concat/pad.
+    """
+    if isinstance(data, (tuple, list)):
+        return list(data)
+    if isinstance(data, mx.nd.NDArray):
+        return [data]
+    return data
+
+def _batchify(data):
+    """
+    Collate data into a batch. Uses shared memory for stacking.
+    :param data: a list of arrays, with layout 'NTC'.
+    :return: x and x's unpadded lengths if labels are not supplied;
+        otherwise x, x's unpadded lengths, y, and y's unpadded lengths.
+    """
+
+    # input layout is NTC
+    keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \
+                           [item[2] for item in data]
+
+    if len(data) > 1:
+        max_data_len = max([seq.shape[0] for seq in inputs])
+        max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
+    else:
+        max_data_len = inputs[0].shape[0]
+        max_labels_len = 0 if not labels else labels[0].shape[0]
+
+    x_lens = [item.shape[0] for item in inputs]
+    y_lens = [item.shape[0] for item in labels]
+
+    for i, seq in enumerate(inputs):
+        pad_len = max_data_len - seq.shape[0]
+        inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0)
+        labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
+                           'constant', constant_values=-1)
+
+    inputs = np.asarray(inputs, dtype=np.float32)
+    if labels is not None:
+        labels = np.asarray(labels, dtype=np.float32)
+    inputs = inputs.transpose((1, 0, 2))
+    labels = labels.transpose((1, 0))
+
+    return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
+            nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \
+        if labels is None else (
+            nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
+            nd.array(x_lens, ctx=context.Context('cpu_shared', 0)),
+            nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)),
+            nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))
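
To make the collation concrete, here is a hypothetical call against the helper above (assuming this test file's imports: numpy as np and mxnet's nd/context). Two samples with sequence lengths 5 and 8 are padded to length 8 and returned in TNC layout:

samples = [(0, np.random.uniform(size=(5, 40)), np.random.uniform(size=(3,))),
           (1, np.random.uniform(size=(8, 40)), np.random.uniform(size=(2,)))]
x, x_lens, y, y_lens = _batchify(samples)
assert x.shape == (8, 2, 40)                # (T, N, C) after the transpose
assert x_lens.asnumpy().tolist() == [5, 8]  # unpadded input lengths
assert y.shape == (3, 2)                    # labels padded with -1, then transposed
assert y_lens.asnumpy().tolist() == [3, 2]  # unpadded label lengths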

@with_seed()
def test_multi_worker_forked_data_loader():
    """
    Test that a multi-process (forked) data loader runs to completion without errors.
    """
-    class Dummy(Dataset):
-        def __init__(self, random_shape):
-            self.random_shape = random_shape
-
-        def __getitem__(self, idx):
-            key = idx
-            if self.random_shape:
-                out = np.random.uniform(size=(random.randint(1000, 1100), 40))
-                labels = np.random.uniform(size=(random.randint(10, 15)))
-            else:
-                out = np.random.uniform(size=(1000, 40))
-                labels = np.random.uniform(size=(10))
-            return key, out, labels
-
-        def __len__(self):
-            return 50
-
-        def batchify_list(self, data):
-            """
-            return list of ndarray without stack/concat/pad
-            """
-            if isinstance(data, (tuple, list)):
-                return list(data)
-            if isinstance(data, mx.nd.NDArray):
-                return [data]
-            return data
-
-        def batchify(self, data):
-            """
-            Collate data into batch. Use shared memory for stacking.
-            :param data: a list of array, with layout of 'NTC'.
-            :return either x and x's unpadded lengths, or x, x's unpadded lengths, y and y's unpadded lengths
-                if labels are not supplied.
-            """
-
-            # input layout is NTC
-            keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \
-                                   [item[2] for item in data]
-
-            if len(data) > 1:
-                max_data_len = max([seq.shape[0] for seq in inputs])
-                max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels])
-            else:
-                max_data_len = inputs[0].shape[0]
-                max_labels_len = 0 if not labels else labels[0].shape[0]
-
-            x_lens = [item.shape[0] for item in inputs]
-            y_lens = [item.shape[0] for item in labels]
-
-            for i, seq in enumerate(inputs):
-                pad_len = max_data_len - seq.shape[0]
-                inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0)
-                labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]),
-                                   'constant', constant_values=-1)
-
-            inputs = np.asarray(inputs, dtype=np.float32)
-            if labels is not None:
-                labels = np.asarray(labels, dtype=np.float32)
-            inputs = inputs.transpose((1, 0, 2))
-            labels = labels.transpose((1, 0))
-
-            return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
-                    nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \
-                if labels is None else (
-                    nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)),
-                    nd.array(x_lens, ctx=context.Context('cpu_shared', 0)),
-                    nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)),
-                    nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))


-    # This test is pointless on Windows because Windows doesn't fork
-    if platform.system() != 'Windows':
-        data = Dummy(True)
-        loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify, num_workers=2)
-        for epoch in range(1):
-            for i, data in enumerate(loader):
-                if i % 100 == 0:
-                    print(data)
-                    print('{}:{}'.format(epoch, i))
-
-        data = Dummy(True)
-        loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify_list, num_workers=2)
-        for epoch in range(1):
-            for i, data in enumerate(loader):
-                if i % 100 == 0:
-                    print(data)
-                    print('{}:{}'.format(epoch, i))
+    data = _Dummy(False)
+    loader = DataLoader(data, batch_size=40, batchify_fn=_batchify, num_workers=2)
+    for epoch in range(1):
+        for i, data in enumerate(loader):
+            pass
+
+    data = _Dummy(True)
+    loader = DataLoader(data, batch_size=40, batchify_fn=_batchify_list, num_workers=2)
+    for epoch in range(1):
+        for i, data in enumerate(loader):
+            pass
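
A plausible reason the helpers were hoisted to module level (beyond removing duplication) is pickling: multiprocessing workers started with 'spawn', as on Windows, must pickle batchify_fn, and methods of a class defined inside a test function cannot be pickled. An illustrative check, not part of the commit:

import pickle

def outer():
    class Local:
        def method(self, data):
            return data
    return Local().method

pickle.dumps(_batchify)      # fine: module-level functions pickle by reference
try:
    pickle.dumps(outer())    # fails: bound method of a locally defined class
except (pickle.PicklingError, AttributeError):
    pass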

if __name__ == '__main__':
    import nose
