From e3207a36d369f6a8fdbd4328ba0a98391e4da11b Mon Sep 17 00:00:00 2001 From: Vitaly Fedyunin Date: Mon, 3 Oct 2022 13:17:50 -0400 Subject: [PATCH] Implementing thread based PrefetcherIterDataPipe ghstack-source-id: 6b9c54f3ba2786e6c3fe7a47cfd0fe3387451635 Pull Request resolved: https://github.com/pytorch/data/pull/770 --- test/test_iterdatapipe.py | 10 ++ torchdata/datapipes/iter/__init__.py | 2 + torchdata/datapipes/iter/util/prefetcher.py | 105 ++++++++++++++++++++ 3 files changed, 117 insertions(+) create mode 100644 torchdata/datapipes/iter/util/prefetcher.py diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 6982a87c4..d5cd7db1f 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -261,6 +261,16 @@ def odd_even_bug(i: int) -> int: result_dp = source_dp.zip_with_map(map_dp, odd_even) self.assertEqual(len(source_dp), len(result_dp)) + def test_prefetcher_iterdatapipe(self) -> None: + source_dp = IterableWrapper(range(50000)) + prefetched_dp = source_dp.prefetch(10) + # check if early termination resets child thread properly + for _, _ in zip(range(100), prefetched_dp): + pass + expected = list(source_dp) + actual = list(prefetched_dp) + self.assertEqual(expected, actual) + def test_repeater_iterdatapipe(self) -> None: import itertools diff --git a/torchdata/datapipes/iter/__init__.py b/torchdata/datapipes/iter/__init__.py index 4a2265d65..09fa0ae9a 100644 --- a/torchdata/datapipes/iter/__init__.py +++ b/torchdata/datapipes/iter/__init__.py @@ -106,6 +106,7 @@ LineReaderIterDataPipe as LineReader, ) from torchdata.datapipes.iter.util.prefetch import FullSyncIterDataPipe as FullSync +from torchdata.datapipes.iter.util.prefetcher import PrefetcherIterDataPipe as Prefetcher from torchdata.datapipes.iter.util.randomsplitter import RandomSplitterIterDataPipe as RandomSplitter from torchdata.datapipes.iter.util.rararchiveloader import RarArchiveLoaderIterDataPipe as RarArchiveLoader from 
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import threading
import time

from collections import deque
from typing import Deque, Optional

from torchdata.dataloader2 import communication

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe

PRODUCER_SLEEP_INTERVAL = 0.0001  # Interval between checks that the buffer has room
CONSUMER_SLEEP_INTERVAL = 0.0001  # Interval between checks that the buffer has items


class _PrefetchData:
    """State shared between the consumer (main thread) and the producer thread.

    ``run_prefetcher`` is the cooperative shutdown flag: either side may set it
    to ``False`` to stop the loop on the other side.
    """

    def __init__(self, source_datapipe, buffer_size: int) -> None:
        self.run_prefetcher: bool = True
        # A deque (not a list) is used deliberately: in CPython, `append` and
        # `popleft` are each atomic, so the producer can append while the
        # consumer pops without losing items.  The previous list-based
        # implementation re-bound `buffer = buffer[1:]` on the consumer side,
        # which silently dropped any item appended between the slice and the
        # re-bind, and copied the whole buffer on every pop.
        self.prefetch_buffer: Deque = deque()
        self.buffer_size: int = buffer_size
        self.source_datapipe = source_datapipe


@functional_datapipe("prefetch")
class PrefetcherIterDataPipe(IterDataPipe):
    """Prefetches elements from the source DataPipe using a background thread
    (functional name: ``prefetch``).

    A daemon thread pulls up to ``buffer_size`` elements ahead of the consumer
    so that source-side latency (e.g. I/O) overlaps with downstream work.

    Args:
        source_datapipe: IterDataPipe from which elements are prefetched.
        buffer_size: maximum number of elements held ahead of the consumer;
            must be a positive integer (default: 10).

    Raises:
        ValueError: if ``buffer_size`` is not a positive integer.
    """

    def __init__(self, source_datapipe, buffer_size: int = 10):
        self.source_datapipe = source_datapipe
        if buffer_size <= 0:
            raise ValueError("'buffer_size' is required to be a positive integer.")
        self.buffer_size = buffer_size
        self.thread: Optional[threading.Thread] = None

    @staticmethod
    def thread_worker(prefetch_data: _PrefetchData) -> None:
        """Producer loop run on the background thread.

        Fills ``prefetch_data.prefetch_buffer`` from the source until the
        source is exhausted, a reset/terminate request arrives, or the main
        thread clears ``run_prefetcher``.
        """
        itr = iter(prefetch_data.source_datapipe)
        stop_iteration = False
        while prefetch_data.run_prefetcher:
            if len(prefetch_data.prefetch_buffer) < prefetch_data.buffer_size and not stop_iteration:
                try:
                    prefetch_data.prefetch_buffer.append(next(itr))
                except StopIteration:
                    stop_iteration = True
                except communication.iter.InvalidStateResetRequired:
                    stop_iteration = True
                except communication.iter.TerminateRequired:
                    prefetch_data.run_prefetcher = False
            elif stop_iteration and len(prefetch_data.prefetch_buffer) == 0:
                # Source exhausted and everything consumed: signal the main
                # thread that iteration is complete.
                prefetch_data.run_prefetcher = False
            else:  # Buffer is full; wait for the main thread to consume items
                # TODO: Calculate sleep interval based on previous consumption speed
                time.sleep(PRODUCER_SLEEP_INTERVAL)

    def __iter__(self):
        # NOTE: `__init__` guarantees buffer_size >= 1, so no bypass branch is
        # needed here (the old `if self.buffer_size < 1` path was unreachable).
        try:
            prefetch_data = _PrefetchData(self.source_datapipe, self.buffer_size)
            self.prefetch_data = prefetch_data
            self.thread = threading.Thread(
                target=PrefetcherIterDataPipe.thread_worker, args=(prefetch_data,), daemon=True
            )
            self.thread.start()
            while prefetch_data.run_prefetcher:
                if len(prefetch_data.prefetch_buffer) > 0:
                    # `popleft` is atomic w.r.t. the producer's `append`, so no
                    # element can be lost between the length check and the pop.
                    yield prefetch_data.prefetch_buffer.popleft()
                else:
                    # TODO: Calculate sleep interval based on previous availability speed
                    time.sleep(CONSUMER_SLEEP_INTERVAL)
        finally:
            # Runs on normal exhaustion, early `break`, or GeneratorExit —
            # always stop and reap the producer thread.
            prefetch_data.run_prefetcher = False
            if self.thread is not None:
                self.thread.join()
                self.thread = None

    def __getstate__(self):
        """
        Getting state in a threading environment requires these operations:
        1) Stopping the producer thread.
        2) Saving the buffer.
        3) Adding a lazy restart of the producer thread when __next__ is called again
           (this guarantees that you only change state of the source_datapipe
           after the entire state of the graph is saved).
        """
        # TODO: Update __getstate__ and __setstate__ to support snapshotting and restoration
        # `buffer_size` must round-trip, otherwise `__iter__` raises
        # AttributeError on the restored object.
        return dict(source_datapipe=self.source_datapipe, buffer_size=self.buffer_size)

    def __setstate__(self, state):
        self.source_datapipe = state["source_datapipe"]
        # `.get` keeps pickles produced by the older format loadable.
        self.buffer_size = state.get("buffer_size", 10)
        # The worker thread is never pickled; start fresh on next iteration.
        self.thread = None

    def reset(self):
        """Stop the producer thread (if running) so iteration can restart cleanly."""
        if self.thread is not None:
            self.prefetch_data.run_prefetcher = False
            self.thread.join()
            self.thread = None

    def reset_iterator(self):
        self.reset()