# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=ungrouped-imports
"""Dataset generator."""
__all__ = ['DataLoader']
import pickle
import logging
import io
import sys
import signal
import multiprocessing
import multiprocessing.queues
from multiprocessing.reduction import ForkingPickler
from multiprocessing.pool import ThreadPool
import threading
import numpy as np
try:
import multiprocessing.resource_sharer
except ImportError:
pass
from . import sampler as _sampler
from . import batchify as _batchify
from ... import ndarray as nd, context
from ...util import is_np_shape, is_np_array, set_np
from ... import numpy as _mx_np # pylint: disable=reimported
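# Custom pickling for NDArray: reduce_ndarray exports the array as a shared-memory
# handle (on Linux the array is first moved to the 'cpu_shared' context and its file
# descriptor duplicated), and rebuild_ndarray reconstructs the NDArray from that handle
# on the receiving side, so batches cross process boundaries without a full copy.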
if sys.platform == 'darwin' or sys.platform == 'win32':
def rebuild_ndarray(*args):
"""Rebuild ndarray from pickled shared memory"""
# pylint: disable=no-value-for-parameter
return nd.NDArray(nd.ndarray._new_from_shared_mem(*args))
def reduce_ndarray(data):
"""Reduce ndarray to shared memory handle"""
return rebuild_ndarray, data._to_shared_mem()
else:
def rebuild_ndarray(pid, fd, shape, dtype):
"""Rebuild ndarray from pickled shared memory"""
# pylint: disable=no-value-for-parameter
fd = fd.detach()
return nd.NDArray(nd.ndarray._new_from_shared_mem(pid, fd, shape, dtype))
def reduce_ndarray(data):
"""Reduce ndarray to shared memory handle"""
# keep a local ref before duplicating fd
data = data.as_in_context(context.Context('cpu_shared', 0))
pid, fd, shape, dtype = data._to_shared_mem()
fd = multiprocessing.reduction.DupFd(fd)
return rebuild_ndarray, (pid, fd, shape, dtype)
ForkingPickler.register(nd.NDArray, reduce_ndarray)
if sys.platform == 'darwin' or sys.platform == 'win32':
def rebuild_np_ndarray(*args):
"""Rebuild ndarray from pickled shared memory"""
# pylint: disable=no-value-for-parameter
return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(*args))
def reduce_np_ndarray(data):
"""Reduce ndarray to shared memory handle"""
return rebuild_np_ndarray, data._to_shared_mem()
else:
def rebuild_np_ndarray(pid, fd, shape, dtype):
"""Rebuild ndarray from pickled shared memory"""
# pylint: disable=no-value-for-parameter
fd = fd.detach()
return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(pid, fd, shape, dtype))
def reduce_np_ndarray(data):
"""Reduce ndarray to shared memory handle"""
# keep a local ref before duplicating fd
data = data.as_in_context(context.Context('cpu_shared', 0))
pid, fd, shape, dtype = data._to_shared_mem()
fd = multiprocessing.reduction.DupFd(fd)
return rebuild_np_ndarray, (pid, fd, shape, dtype)
ForkingPickler.register(_mx_np.ndarray, reduce_np_ndarray)
class ConnectionWrapper(object):
"""Connection wrapper for multiprocessing that supports sending
NDArray via shared memory."""
def __init__(self, conn):
self._conn = conn
def send(self, obj):
"""Send object"""
buf = io.BytesIO()
ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
self.send_bytes(buf.getvalue())
def recv(self):
"""Receive object"""
buf = self.recv_bytes()
return pickle.loads(buf)
def __getattr__(self, name):
"""Emmulate conn"""
attr = self.__dict__.get('_conn', None)
return getattr(attr, name)
class Queue(multiprocessing.queues.Queue):
"""Wrapper for multiprocessing queue that dumps NDArray with shared memory."""
def __init__(self, *args, **kwargs):
super().__init__(*args, ctx=multiprocessing.get_context(), **kwargs)
self._reader = ConnectionWrapper(self._reader)
self._writer = ConnectionWrapper(self._writer)
self._send = self._writer.send
self._recv = self._reader.recv
class SimpleQueue(multiprocessing.queues.SimpleQueue):
"""Wrapper for multiprocessing SimpleQueue that dumps NDArray with shared memory.
SimpleQueue doesn't use threading internally.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, ctx=multiprocessing.get_context(), **kwargs)
self._reader = ConnectionWrapper(self._reader)
self._writer = ConnectionWrapper(self._writer)
self._send = self._writer.send
self._recv = self._reader.recv
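# For example, given samples [(img0, label0), (img1, label1)], default_batchify_fn
# below returns [stack(img0, img1), stack(label0, label1)], i.e. one stacked array
# per field, recursing through tuples and converting anything else via numpy.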
def default_batchify_fn(data):
"""Collate data into batch."""
if isinstance(data[0], nd.NDArray):
return _mx_np.stack(data) if is_np_array() else nd.stack(*data)
elif isinstance(data[0], tuple):
data = zip(*data)
return [default_batchify_fn(i) for i in data]
else:
data = np.asarray(data)
array_fn = _mx_np.array if is_np_array() else nd.array
return array_fn(data, dtype=data.dtype)
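# default_mp_batchify_fn is the multiprocessing variant: it allocates the output
# batch in the 'cpu_shared' context so worker processes can return batches to the
# main process through shared memory instead of pickling the raw data.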
def default_mp_batchify_fn(data):
"""Collate data into batch. Use shared memory for stacking."""
if isinstance(data[0], nd.NDArray):
empty_fn = _mx_np.empty if is_np_array() else nd.empty
out = empty_fn((len(data),) + data[0].shape, dtype=data[0].dtype,
ctx=context.Context('cpu_shared', 0))
if is_np_array():
return _mx_np.stack(data, out=out)
else:
return nd.stack(*data, out=out)
elif isinstance(data[0], tuple):
data = zip(*data)
return [default_mp_batchify_fn(i) for i in data]
else:
data = np.asarray(data)
array_fn = _mx_np.array if is_np_array() else nd.array
return array_fn(data, dtype=data.dtype,
ctx=context.Context('cpu_shared', 0))
def _as_in_context(data, ctx):
"""Move data into new context."""
if isinstance(data, nd.NDArray):
return data.as_in_context(ctx)
elif isinstance(data, (list, tuple)):
return [_as_in_context(d, ctx) for d in data]
return data
def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn):
"""Worker loop for multiprocessing DataLoader."""
while True:
idx, samples = key_queue.get()
if idx is None:
break
batch = batchify_fn([dataset[i] for i in samples])
data_queue.put((idx, batch))
def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False,
pin_device_id=0, data_buffer_lock=None):
"""Fetcher loop for fetching data from queue and put in reorder dict."""
while True:
idx, batch = data_queue.get()
if idx is None:
break
if pin_memory:
batch = _as_in_context(batch, context.cpu_pinned(pin_device_id))
else:
batch = _as_in_context(batch, context.cpu())
if data_buffer_lock is not None:
with data_buffer_lock:
data_buffer[idx] = batch
else:
data_buffer[idx] = batch
class _MultiWorkerIterV1(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, num_workers, dataset, batchify_fn, batch_sampler,
pin_memory=False, pin_device_id=0, worker_fn=worker_loop_v1):
assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
self._num_workers = num_workers
self._dataset = dataset
self._batchify_fn = batchify_fn
self._batch_sampler = batch_sampler
self._key_queue = Queue()
self._data_queue = SimpleQueue()
self._data_buffer = {}
self._data_buffer_lock = threading.Lock()
self._rcvd_idx = 0
self._sent_idx = 0
self._iter = iter(self._batch_sampler)
self._shutdown = False
workers = []
for _ in range(self._num_workers):
worker = multiprocessing.Process(
target=worker_fn,
args=(self._dataset, self._key_queue, self._data_queue, self._batchify_fn))
worker.daemon = True
worker.start()
workers.append(worker)
self._workers = workers
self._fetcher = threading.Thread(
target=fetcher_loop_v1,
args=(self._data_queue, self._data_buffer, pin_memory,
pin_device_id, self._data_buffer_lock))
self._fetcher.daemon = True
self._fetcher.start()
# pre-fetch
for _ in range(2 * self._num_workers):
self._push_next()
def __len__(self):
return len(self._batch_sampler)
def __del__(self):
self.shutdown()
def _push_next(self):
"""Assign next batch workload to workers."""
r = next(self._iter, None)
if r is None:
return
self._key_queue.put((self._sent_idx, r))
self._sent_idx += 1
def __next__(self):
assert not self._shutdown, "calling __next__ after shutdown is forbidden"
if self._rcvd_idx == self._sent_idx:
assert not self._data_buffer, "Data buffer should be empty at this moment"
self.shutdown()
raise StopIteration
while True:
if self._rcvd_idx in self._data_buffer:
with self._data_buffer_lock:
batch = self._data_buffer.pop(self._rcvd_idx)
self._rcvd_idx += 1
self._push_next()
return batch
def next(self):
return self.__next__()
def __iter__(self):
return self
def shutdown(self):
"""Shutdown internal workers by pushing terminate signals."""
if not self._shutdown:
# send shutdown signal to the fetcher and join data queue first
# Remark: the fetcher thread needs to be joined prior to the workers,
# otherwise the fetcher may fail to get data
self._data_queue.put((None, None))
self._fetcher.join()
# send shutdown signal to all worker processes
for _ in range(self._num_workers):
self._key_queue.put((None, None))
# force shut down any alive worker processes
for w in self._workers:
if w.is_alive():
w.terminate()
self._shutdown = True
class DataLoaderV1(object):
"""Loads data from a dataset and returns mini-batches of data.
Parameters
----------
dataset : Dataset
Source dataset. Note that numpy and mxnet arrays can be directly used
as a Dataset.
batch_size : int
Size of mini-batch.
shuffle : bool
Whether to shuffle the samples.
sampler : Sampler
The sampler to use. Either specify sampler or shuffle, not both.
last_batch : {'keep', 'discard', 'rollover'}
How to handle the last batch if batch_size does not evenly divide
`len(dataset)`:
- ``keep`` - A batch with fewer samples than previous batches is returned.
- ``discard`` - The last batch is discarded if it is incomplete.
- ``rollover`` - The remaining samples are rolled over to the next epoch.
batch_sampler : Sampler
A sampler that returns mini-batches. Do not specify batch_size,
shuffle, sampler, and last_batch if batch_sampler is specified.
batchify_fn : callable
Callback function to allow users to specify how to merge samples
into a batch. Defaults to ``default_batchify_fn``.
.. code-block:: python
def default_batchify_fn(data):
if isinstance(data[0], nd.NDArray):
return nd.stack(*data)
elif isinstance(data[0], tuple):
data = zip(*data)
return [default_batchify_fn(i) for i in data]
else:
data = np.asarray(data)
return nd.array(data, dtype=data.dtype)
num_workers : int, default 0
The number of multiprocessing workers to use for data preprocessing.
pin_memory : boolean, default False
If ``True``, the dataloader will copy NDArrays into pinned memory
before returning them. Copying from CPU pinned memory to GPU is faster
than from normal CPU memory.
pin_device_id : int, default 0
The device id to use for allocating pinned memory if pin_memory is ``True``
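Examples
--------
A minimal usage sketch; the ArrayDataset and shapes below are illustrative
assumptions, not part of this module.
.. code-block:: python
    from mxnet import nd
    from mxnet.gluon.data import ArrayDataset
    dataset = ArrayDataset(nd.random.uniform(shape=(10, 3)), nd.arange(10))
    loader = DataLoaderV1(dataset, batch_size=4, shuffle=True, last_batch='discard')
    for data, label in loader:
        pass  # data has shape (4, 3), label has shape (4,)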
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False, pin_device_id=0):
self._dataset = dataset
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id
if batch_sampler is None:
if batch_size is None:
raise ValueError("batch_size must be specified unless " \
"batch_sampler is specified")
if sampler is None:
if shuffle:
sampler = _sampler.RandomSampler(len(dataset))
else:
sampler = _sampler.SequentialSampler(len(dataset))
elif shuffle:
raise ValueError("shuffle must not be specified if sampler is specified")
batch_sampler = _sampler.BatchSampler(
sampler, batch_size, last_batch if last_batch else 'keep')
elif batch_size is not None or shuffle or sampler is not None or \
last_batch is not None:
raise ValueError("batch_size, shuffle, sampler and last_batch must " \
"not be specified if batch_sampler is specified.")
self._batch_sampler = batch_sampler
self._num_workers = num_workers if num_workers >= 0 else 0
if batchify_fn is None:
if num_workers > 0:
self._batchify_fn = _batchify.Stack(use_shared_mem=True)
else:
self._batchify_fn = _batchify.Stack()
else:
self._batchify_fn = batchify_fn
def __iter__(self):
if self._num_workers == 0:
def same_process_iter():
for batch in self._batch_sampler:
ret = self._batchify_fn([self._dataset[idx] for idx in batch])
if self._pin_memory:
ret = _as_in_context(ret, context.cpu_pinned(self._pin_device_id))
yield ret
return same_process_iter()
# multi-worker
return _MultiWorkerIterV1(self._num_workers, self._dataset,
self._batchify_fn, self._batch_sampler,
self._pin_memory, self._pin_device_id)
def __len__(self):
return len(self._batch_sampler)
def _thread_worker_initializer(active_shape, active_array):
"""Initializer for ThreadPool."""
set_np(shape=active_shape, array=active_array)
_worker_dataset = None
def _worker_initializer(dataset, active_shape, active_array):
"""Initialier for processing pool."""
# global dataset is per-process based and only available in worker processes
# this is only necessary to handle MXIndexedRecordIO because otherwise dataset
# can be passed as argument
global _worker_dataset
_worker_dataset = dataset
set_np(shape=active_shape, array=active_array)
def _worker_fn(samples, batchify_fn, dataset=None):
"""Function for processing data in worker process."""
# pylint: disable=unused-argument
# each worker process has to fork a new MXIndexedRecordIO handle
# keeping the dataset as a global variable saves a lot of overhead and is safe in the new process
global _worker_dataset
batch = batchify_fn([_worker_dataset[i] for i in samples])
buf = io.BytesIO()
ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
return buf.getvalue()
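# Unlike _worker_fn above, _thread_worker_fn below runs in the same address space as
# the main process, so it returns the batch object directly instead of pickling it
# through shared memory.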
def _thread_worker_fn(samples, batchify_fn, dataset):
"""Threadpool worker function for processing data."""
return batchify_fn([dataset[i] for i in samples])
class _MultiWorkerIter(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None,
data_loader=None, timeout=120):
self._worker_pool = worker_pool
self._batchify_fn = batchify_fn
self._batch_sampler = batch_sampler
self._data_buffer = {}
self._rcvd_idx = 0
self._sent_idx = 0
self._iter = iter(self._batch_sampler)
self._worker_fn = worker_fn
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id
self._dataset = dataset
self._data_loader = data_loader
self._timeout = timeout
# pre-fetch
for _ in range(prefetch):
self._push_next()
def __len__(self):
return len(self._batch_sampler)
def _push_next(self):
"""Assign next batch workload to workers."""
r = next(self._iter, None)
if r is None:
return
async_ret = self._worker_pool.apply_async(
self._worker_fn, (r, self._batchify_fn, self._dataset))
self._data_buffer[self._sent_idx] = async_ret
self._sent_idx += 1
def __next__(self):
self._push_next()
if self._rcvd_idx == self._sent_idx:
assert not self._data_buffer, "Data buffer should be empty at this moment"
raise StopIteration
assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
ret = self._data_buffer.pop(self._rcvd_idx)
try:
if self._dataset is None:
batch = pickle.loads(ret.get(self._timeout))
else:
batch = ret.get(self._timeout)
if self._pin_memory:
batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
self._rcvd_idx += 1
return batch
except multiprocessing.context.TimeoutError:
msg = '''Worker timed out after {} seconds. This might be caused by \n
- Slow transform. Please increase timeout to allow slower data loading in each worker.
'''.format(self._timeout)
if not isinstance(self._worker_pool, multiprocessing.pool.ThreadPool):
msg += '''- Insufficient shared_memory if `timeout` is large enough.
Please consider reducing `num_workers` or increasing the shared memory available to the system.
'''
print(msg)
raise
except Exception:
self._worker_pool.terminate()
raise
def next(self):
return self.__next__()
def __iter__(self):
return self
class DataLoader(object):
"""Loads data from a dataset and returns mini-batches of data.
Parameters
----------
dataset : Dataset
Source dataset. Note that numpy and mxnet arrays can be directly used
as a Dataset.
batch_size : int
Size of mini-batch.
shuffle : bool
Whether to shuffle the samples.
sampler : Sampler
The sampler to use. Either specify sampler or shuffle, not both.
last_batch : {'keep', 'discard', 'rollover'}
How to handle the last batch if batch_size does not evenly divide
``len(dataset)``.
keep - A batch with fewer samples than previous batches is returned.
discard - The last batch is discarded if it is incomplete.
rollover - The remaining samples are rolled over to the next epoch.
batch_sampler : Sampler
A sampler that returns mini-batches. Do not specify batch_size,
shuffle, sampler, and last_batch if batch_sampler is specified.
batchify_fn : callable
Callback function to allow users to specify how to merge samples
into a batch. Defaults to `gluon.data.batchify.Stack()`.
.. code-block:: python
def default_batchify_fn(data):
if isinstance(data[0], nd.NDArray):
return nd.stack(*data)
elif isinstance(data[0], np.ndarray):
return np.stack(data)
elif isinstance(data[0], tuple):
data = zip(*data)
return [default_batchify_fn(i) for i in data]
else:
data = np.asarray(data)
return np.array(data, dtype=data.dtype)
num_workers : int, default 0
The number of multiprocessing workers to use for data preprocessing.
pin_memory : boolean, default False
If ``True``, the dataloader will copy NDArrays into pinned memory
before returning them. Copying from CPU pinned memory to GPU is faster
than from normal CPU memory.
pin_device_id : int, default 0
The device id to use for allocating pinned memory if pin_memory is ``True``
prefetch : int, default is `num_workers * 2`
The number of batches to prefetch; only used when `num_workers` > 0.
If `prefetch` > 0, worker processes prefetch that many batches before data is
requested from the iterator.
Note that a larger prefetch count gives smoother warm-up performance but consumes
more shared_memory. A very small number may forfeit the purpose of using multiple
worker processes; try reducing `num_workers` in that case.
Defaults to `num_workers * 2`.
thread_pool : bool, default False
If ``True``, use a thread pool instead of a multiprocessing pool. Using a thread pool
avoids shared memory usage. If the `DataLoader` is more I/O bound or the GIL is not a
bottleneck, the thread pool version may achieve better performance than multiprocessing.
timeout : int, default is 120
The timeout in seconds for each worker to fetch a batch of data. Only modify this number
if you are experiencing timeouts and you know they are caused by slow data loading.
Sometimes full `shared_memory` will cause all workers to hang and trigger the timeout. In these
cases please reduce `num_workers` or increase the system `shared_memory` size instead.
try_nopython : bool or None, default is None
Try to compile the Python data loading pipeline into a pure MXNet C++ implementation. The
benefit is potentially faster iteration, no `shared_memory` usage, and fewer processes managed
by Python. The compilation is not guaranteed to support all use cases, but it will fall back to
Python in case of failure. You can set `try_nopython` to `False` to disable auto-detection of
the compilation feature, or leave it as `None` to let MXNet determine it automatically.
If you set `try_nopython` to `True` and the compilation fails, a RuntimeError is raised with
the failure reason.
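Examples
--------
A minimal usage sketch; the ArrayDataset, shapes, and worker counts below are
illustrative assumptions, not requirements of this class.
.. code-block:: python
    from mxnet import nd
    from mxnet.gluon.data import ArrayDataset, DataLoader
    dataset = ArrayDataset(nd.random.uniform(shape=(100, 3)), nd.arange(100))
    # single-process loading
    for data, label in DataLoader(dataset, batch_size=32, shuffle=True):
        pass
    # multi-process loading with two workers and default prefetching
    loader = DataLoader(dataset, batch_size=32, num_workers=2)
    for data, label in loader:
        pass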
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False, pin_device_id=0,
prefetch=None, thread_pool=False, timeout=120, try_nopython=None):
self._dataset = dataset
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id
self._thread_pool = thread_pool
self._timeout = timeout
self._mx_iter = None
assert timeout > 0, "timeout must be positive, given {}".format(timeout)
if batch_sampler is None:
if batch_size is None:
raise ValueError("batch_size must be specified unless " \
"batch_sampler is specified")
if sampler is None:
if shuffle:
sampler = _sampler.RandomSampler(len(dataset))
else:
sampler = _sampler.SequentialSampler(len(dataset))
elif shuffle:
raise ValueError("shuffle must not be specified if sampler is specified")
batch_sampler = _sampler.BatchSampler(
sampler, batch_size, last_batch if last_batch else 'keep')
elif batch_size is not None or shuffle or sampler is not None or \
last_batch is not None:
raise ValueError("batch_size, shuffle, sampler and last_batch must " \
"not be specified if batch_sampler is specified.")
self._batch_sampler = batch_sampler
self._num_workers = num_workers if num_workers >= 0 else 0
self._worker_pool = None
self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
if batchify_fn is None:
if num_workers > 0:
self._batchify_fn = _batchify.Stack(use_shared_mem=True)
else:
self._batchify_fn = _batchify.Stack()
else:
self._batchify_fn = batchify_fn
if num_workers > 0 and (try_nopython or try_nopython is None):
# check for capability to use mx backend threadedLoader
use_mx_iter, mx_iter_args = _check_mx_loader_capability(
self._dataset, self._batch_sampler, self._batchify_fn)
if not use_mx_iter:
if try_nopython:
raise RuntimeError(mx_iter_args)
else:
use_mx_iter = False
if use_mx_iter:
logging.info("Using MXNet backend ThreadedDataLoader with %s workers "
"instead of python dataloader.", self._num_workers)
self._mx_iter = _MXThreadedDataLoader(
num_workers=self._num_workers,
pin_memory=self._pin_memory,
pin_device_id=self._pin_device_id,
prefetch=self._prefetch, **mx_iter_args)
else:
nd.waitall()
import gc
gc.collect()
nd.waitall()
if self._num_workers > 0:
if self._thread_pool:
self._worker_pool = ThreadPool(self._num_workers,
initializer=_thread_worker_initializer,
initargs=(is_np_shape(), is_np_array()))
else:
# set ignore keyboard interrupt signal before forking processes
original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
self._worker_pool = multiprocessing.Pool(
self._num_workers, initializer=_worker_initializer,
initargs=[self._dataset, is_np_shape(), is_np_array()])
# resume keyboard interrupt signal in main process
signal.signal(signal.SIGINT, original_sigint_handler)
def __iter__(self):
if self._mx_iter is not None:
return iter(self._mx_iter)
if self._num_workers == 0:
def same_process_iter():
for batch in self._batch_sampler:
ret = self._batchify_fn([self._dataset[idx] for idx in batch])
if self._pin_memory:
ret = _as_in_context(ret, context.cpu_pinned(self._pin_device_id))
yield ret
return same_process_iter()
# multi-worker
return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
prefetch=self._prefetch,
dataset=self._dataset if self._thread_pool else None,
data_loader=self, timeout=self._timeout)
def __len__(self):
return len(self._batch_sampler)
def __del__(self):
if self._worker_pool:
# manually terminate due to a bug where the pool is not automatically terminated
# https://bugs.python.org/issue34172
assert isinstance(self._worker_pool, multiprocessing.pool.Pool)
self._worker_pool.terminate()
def _check_mx_loader_capability(dataset, batch_sampler, batchify_fn):
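"""Check whether the dataset, batchify_fn and batch sampler can be handed to the
C++ ThreadedDataLoader backend. Returns ``(True, kwargs)`` with the converted
arguments on success, or ``(False, reason)`` where ``reason`` is a human-readable string."""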
from ._internal import MXDataset, MXSampler
from ._internal import MXBatchifyFunction
mx_loader_args = {}
error_template = "MXNet backend loader compatibility: " \
"[dataset - {}][batchify_fn - {}][batch sampler - {}]"
# supported dataset
if isinstance(dataset, MXDataset):
mx_loader_args['dataset'] = dataset
elif hasattr(dataset, '__mx_handle__'):
try:
mx_loader_args['dataset'] = dataset.__mx_handle__()
except NotImplementedError:
return False, error_template.format('fail', 'unknown', 'unknown')
else:
return False, error_template.format('fail', 'unknown', 'unknown')
# supported batchify functions
if hasattr(batchify_fn, '__mx_handle__'):
mx_loader_args['batchify_fn'] = batchify_fn.__mx_handle__()
elif isinstance(batchify_fn, MXBatchifyFunction):
mx_loader_args['batchify_fn'] = batchify_fn
else:
return False, error_template.format('pass', 'fail', 'unknown')
# supported sampler
if isinstance(batch_sampler, _sampler.BatchSampler):
if isinstance(batch_sampler._sampler, _sampler.SequentialSampler):
mx_loader_args['batch_sampler'] = MXSampler(
'SequentialSampler', length=batch_sampler._sampler._length,
start=batch_sampler._sampler._start,
batch_size=batch_sampler._batch_size,
last_batch=batch_sampler._last_batch)
elif isinstance(batch_sampler._sampler, _sampler.RandomSampler):
mx_loader_args['batch_sampler'] = MXSampler(
'RandomSampler', length=batch_sampler._sampler._length,
batch_size=batch_sampler._batch_size,
last_batch=batch_sampler._last_batch)
else:
return False, error_template.format('pass', 'pass', 'fail')
elif isinstance(batch_sampler, MXSampler):
mx_loader_args['batch_sampler'] = batch_sampler
else:
return False, error_template.format('pass', 'pass', 'fail')
# all good
return True, mx_loader_args
class _MXThreadedDataLoader(object):
"""MXNet internal C++ threaded Data Iterator in form of DataLoader
parameters
----------
dataset : Dataset
Source dataset. Note that numpy and mxnet arrays can be directly used
as a Dataset.
batch_sampler : Sampler
A sampler that returns mini-batches.
batchify_fn : callable
Callback function to allow users to specify how to merge samples
into a batch. Defaults to `gluon.data.batchify.Stack()`.
num_workers : int, default 0
The number of multiprocessing workers to use for data preprocessing.
pin_memory : boolean, default False
If ``True``, the dataloader will copy NDArrays into pinned memory
before returning them. Copying from CPU pinned memory to GPU is faster
than from normal CPU memory.
pin_device_id : int, default 0
The device id to use for allocating pinned memory if pin_memory is ``True``
prefetch : int, default is `num_workers * 2`
The number of batches to prefetch; only used when `num_workers` > 0.
If `prefetch` > 0, worker threads prefetch that many batches before data is
requested from the iterator.
Note that a larger prefetch count gives smoother warm-up performance but consumes
more shared_memory. A very small number may forfeit the purpose of using multiple
workers; try reducing `num_workers` in that case.
Defaults to `num_workers * 2`; the maximum prefetch size is `16`.
"""
def __init__(self, dataset, batch_sampler, batchify_fn,
num_workers=0, pin_memory=False, pin_device_id=0,
prefetch=4):
from ._internal import MXDataset, MXSampler, MXBatchifyFunction
from ...io.io import ThreadedDataLoader
assert isinstance(dataset, MXDataset)
assert isinstance(batch_sampler, MXSampler)
assert isinstance(batchify_fn, MXBatchifyFunction)
self._dataset = dataset
self._batch_sampler = batch_sampler
self._batchify_fn = batchify_fn
if num_workers == 0:
num_workers = 1 # different convention for single thread
if prefetch == 0:
prefetch = 1 # at least one buffer required
pin_device_id = pin_device_id if pin_memory else -1
ctx = 'cpu_pinned' if pin_memory else 'cpu'
self._iter = ThreadedDataLoader(num_workers=num_workers, dataset=dataset,
sampler=batch_sampler, batchify_fn=batchify_fn,
prefetch_buffer=prefetch, ctx=ctx,
device_id=pin_device_id)
def __iter__(self):
while self._iter.iter_next():
self._iter.first_batch = None
items = self._iter.getitems()
pad = self._iter.getpad()
if pad > 0:
items = tuple([x[:-pad] for x in items])
if len(items) < 2:
items = items[0]
yield items
self._iter.reset()
def __len__(self):
return len(self._iter)