diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index eb1eb419cd02..06e4b00d43fc 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -277,7 +277,6 @@ def default_batchify_fn(data):
 
     num_workers : int, default 0
         The number of multiprocessing workers to use for data preprocessing.
-        `num_workers > 0` is not supported on Windows yet.
     """
     def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                  last_batch=None, batch_sampler=None, batchify_fn=None,
@@ -315,9 +314,10 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
 
     def __iter__(self):
         if self._num_workers == 0:
-            generator = lambda: [(yield self._batchify_fn([self._dataset[idx] for idx in batch]))
-                                 for batch in self._batch_sampler]
-            return generator()
+            def single_worker_iter():
+                for batch in self._batch_sampler:
+                    yield self._batchify_fn([self._dataset[idx] for idx in batch])
+            return single_worker_iter()
 
         # multi-worker
         return _MultiWorkerIter(self._num_workers, self._dataset,
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 043804487b5e..23d57eef3c8f 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -123,106 +123,99 @@ def __len__(self):
     def __getitem__(self, key):
         return mx.nd.full((10,), key)
 
-@unittest.skip("Somehow fails with MKL. Cannot reproduce locally")
 def test_multi_worker():
     data = Dataset()
     loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5)
     for i, batch in enumerate(loader):
-        assert (batch.asnumpy() == i).all()
+        assert (batch.asnumpy().astype('int32') == i).all()
+
+class _Dummy(Dataset):
+    """Dummy dataset for randomized shape arrays."""
+    def __init__(self, random_shape):
+        self.random_shape = random_shape
+
+    def __getitem__(self, idx):
+        key = idx
+        if self.random_shape:
+            out = np.random.uniform(size=(random.randint(1000, 1100), 40))
+            labels = np.random.uniform(size=(random.randint(10, 15)))
+        else:
+            out = np.random.uniform(size=(1000, 40))
+            labels = np.random.uniform(size=(10))
+        return key, out, labels
+
+    def __len__(self):
+        return 50
+
+def _batchify_list(data):
+    """
+    Return a list of NDArray without stack/concat/pad.
+    """
+    if isinstance(data, (tuple, list)):
+        return list(data)
+    if isinstance(data, mx.nd.NDArray):
+        return [data]
+    return data
+
+def _batchify(data):
+    """
+    Collate data into a batch. Uses shared memory for stacking.
+
+    :param data: a list of arrays, with layout 'NTC'.
+    :return: x and x's unpadded lengths if labels are not supplied, or
+        x, x's unpadded lengths, y, and y's unpadded lengths otherwise.
+ """ + + # input layout is NTC + keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \ + [item[2] for item in data] + + if len(data) > 1: + max_data_len = max([seq.shape[0] for seq in inputs]) + max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels]) + else: + max_data_len = inputs[0].shape[0] + max_labels_len = 0 if not labels else labels[0].shape[0] + + x_lens = [item.shape[0] for item in inputs] + y_lens = [item.shape[0] for item in labels] + + for i, seq in enumerate(inputs): + pad_len = max_data_len - seq.shape[0] + inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0) + labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]), + 'constant', constant_values=-1) + + inputs = np.asarray(inputs, dtype=np.float32) + if labels is not None: + labels = np.asarray(labels, dtype=np.float32) + inputs = inputs.transpose((1, 0, 2)) + labels = labels.transpose((1, 0)) + + return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)), + nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \ + if labels is None else ( + nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)), + nd.array(x_lens, ctx=context.Context('cpu_shared', 0)), + nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)), + nd.array(y_lens, ctx=context.Context('cpu_shared', 0))) @with_seed() def test_multi_worker_forked_data_loader(): """ Test should successfully run its course of multi-process/forked data loader without errors """ - class Dummy(Dataset): - def __init__(self, random_shape): - self.random_shape = random_shape - - def __getitem__(self, idx): - key = idx - if self.random_shape: - out = np.random.uniform(size=(random.randint(1000, 1100), 40)) - labels = np.random.uniform(size=(random.randint(10, 15))) - else: - out = np.random.uniform(size=(1000, 40)) - labels = np.random.uniform(size=(10)) - return key, out, labels - - def __len__(self): - return 50 - - def batchify_list(self, data): - """ - return list of ndarray without stack/concat/pad - """ - if isinstance(data, (tuple, list)): - return list(data) - if isinstance(data, mx.nd.NDArray): - return [data] - return data - - def batchify(self, data): - """ - Collate data into batch. Use shared memory for stacking. - - :param data: a list of array, with layout of 'NTC'. - :return either x and x's unpadded lengths, or x, x's unpadded lengths, y and y's unpadded lengths - if labels are not supplied. 
- """ - - # input layout is NTC - keys, inputs, labels = [item[0] for item in data], [item[1] for item in data], \ - [item[2] for item in data] - - if len(data) > 1: - max_data_len = max([seq.shape[0] for seq in inputs]) - max_labels_len = 0 if not labels else max([seq.shape[0] for seq in labels]) - else: - max_data_len = inputs[0].shape[0] - max_labels_len = 0 if not labels else labels[0].shape[0] - - x_lens = [item.shape[0] for item in inputs] - y_lens = [item.shape[0] for item in labels] - - for i, seq in enumerate(inputs): - pad_len = max_data_len - seq.shape[0] - inputs[i] = np.pad(seq, ((0, pad_len), (0, 0)), 'constant', constant_values=0) - labels[i] = np.pad(labels[i], (0, max_labels_len - labels[i].shape[0]), - 'constant', constant_values=-1) - - inputs = np.asarray(inputs, dtype=np.float32) - if labels is not None: - labels = np.asarray(labels, dtype=np.float32) - inputs = inputs.transpose((1, 0, 2)) - labels = labels.transpose((1, 0)) - - return (nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)), - nd.array(x_lens, ctx=context.Context('cpu_shared', 0))) \ - if labels is None else ( - nd.array(inputs, dtype=inputs.dtype, ctx=context.Context('cpu_shared', 0)), - nd.array(x_lens, ctx=context.Context('cpu_shared', 0)), - nd.array(labels, dtype=labels.dtype, ctx=context.Context('cpu_shared', 0)), - nd.array(y_lens, ctx=context.Context('cpu_shared', 0))) - - - # This test is pointless on Windows because Windows doesn't fork - if platform.system() != 'Windows': - data = Dummy(True) - loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify, num_workers=2) - for epoch in range(1): - for i, data in enumerate(loader): - if i % 100 == 0: - print(data) - print('{}:{}'.format(epoch, i)) - - data = Dummy(True) - loader = DataLoader(data, batch_size=40, batchify_fn=data.batchify_list, num_workers=2) - for epoch in range(1): - for i, data in enumerate(loader): - if i % 100 == 0: - print(data) - print('{}:{}'.format(epoch, i)) + data = _Dummy(False) + loader = DataLoader(data, batch_size=40, batchify_fn=_batchify, num_workers=2) + for epoch in range(1): + for i, data in enumerate(loader): + pass + + data = _Dummy(True) + loader = DataLoader(data, batch_size=40, batchify_fn=_batchify_list, num_workers=2) + for epoch in range(1): + for i, data in enumerate(loader): + pass if __name__ == '__main__': import nose