[Release 0.6.0] Batch cherry-pick #1057

Merged
merged 5 commits into from
Feb 28, 2023
11 changes: 2 additions & 9 deletions CMakeLists.txt
@@ -29,19 +29,12 @@ string(FIND "${CMAKE_CXX_FLAGS}" "-std=c++" env_cxx_standard)
if(env_cxx_standard GREATER -1)
message(
WARNING "C++ standard version definition detected in environment variable."
"PyTorch requires -std=c++14. Please remove -std=c++ settings in your environment.")
"PyTorch requires -std=c++17. Please remove -std=c++ settings in your environment.")
endif()

-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
set(CMAKE_C_STANDARD 11)

-# https://developercommunity.visualstudio.com/t/VS-16100-isnt-compatible-with-CUDA-11/1433342
-if(MSVC)
-if(USE_CUDA)
-set(CMAKE_CXX_STANDARD 17)
-endif()
-endif()


set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
8 changes: 4 additions & 4 deletions CONTRIBUTING.md
@@ -86,8 +86,8 @@ the checks automatically before every `git commit`, you can install them with `p

When adding a new DataPipe, there are few things that need to be done to ensure it is working and documented properly.

-1. Naming - please following the naming convention as
-[described here](https://github.com/pytorch/data/blob/main/docs/source/tutorial.rst#naming).
+1. Naming - please follow the naming convention as
+[described here](https://pytorch.org/data/main/dp_tutorial.html#naming).
2. Testing - please add unit tests to ensure that the DataPipe is functioning properly. Here are the
[test requirements](https://github.com/pytorch/data/issues/106) that we have.
- One test that is commonly missed is the serialization test. Please add the new DataPipe to
@@ -98,14 +98,14 @@ When adding a new DataPipe, there are few things that need to be done to ensure
3. Documentation - ensure that the DataPipe has docstring, usage example, and that it is added to the right category of
the right RST file to be rendered.
- If your DataPipe has a functional form (i.e. `@functional_datapipe(...)`), include at the
-[end of the first sentence](https://github.com/pytorch/data/blob/main/torchdata/datapipes/iter/util/combining.py#L119)
+[end of the first sentence](https://github.com/pytorch/data/blob/main/torchdata/datapipes/iter/util/combining.py#L25)
of your docstring. This will make sure it correctly shows up in the
[summary table](https://pytorch.org/data/main/torchdata.datapipes.iter.html#archive-datapipes) of our
documentation.
4. Import - import the DataPipe in the correct `__init__.py` file.
5. Interface - if the DataPipe has a functional form, make sure that is generated properly by `gen_pyi.py` into the
relevant interface file.
-- You can re-generate the pyi files by re-running `python setup.py develop`, then you can examine the new outputs.
+- You can re-generate the pyi files by re-running `pip install -e .`, then you can examine the new outputs.

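As a rough sketch of the conventions above, a minimal custom ``IterDataPipe`` might look like the following; the class name, functional name, and behavior are hypothetical, chosen only to illustrate the naming, functional form, and docstring requirements:

```python
from typing import Iterator

from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe


@functional_datapipe("prepend_header")  # enables the functional form `dp.prepend_header(...)`
class HeaderPrependerIterDataPipe(IterDataPipe[str]):
    r"""
    Prepends a header line to the source DataPipe (functional name: ``prepend_header``).

    Args:
        source_datapipe: Iterable DataPipe of strings
        header: line to yield before the source items
    """

    def __init__(self, source_datapipe: IterDataPipe[str], header: str) -> None:
        self.source_datapipe = source_datapipe
        self.header = header

    def __iter__(self) -> Iterator[str]:
        yield self.header
        yield from self.source_datapipe
```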
## Contributor License Agreement ("CLA")

2 changes: 1 addition & 1 deletion docs/source/dlv2_tutorial.rst
@@ -19,7 +19,7 @@ Here is an example of a ``DataPipe`` graph:
datapipe = IterableWrapper(["./train1.csv", "./train2.csv"])
datapipe = datapipe.open_files(encoding="utf-8").parse_csv()
datapipe = datapipe.shuffle().sharding_filter()
-datapipe = datapiep.map(fn).batch(8)
+datapipe = datapipe.map(fn).batch(8)

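To actually run a graph like the one above, it can be handed to ``DataLoader2``; below is a minimal sketch, assuming the ``torchdata.dataloader2`` API in this release, with a small stand-in pipeline instead of the CSV graph:

```python
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

if __name__ == "__main__":
    # Small stand-in for the CSV pipeline shown above.
    datapipe = IterableWrapper(range(100)).shuffle().sharding_filter().batch(8)

    # Two worker processes; `sharding_filter` keeps each worker on its own shard.
    rs = MultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(datapipe, reading_service=rs)

    for batch in dl:
        pass  # training step goes here

    dl.shutdown()
```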
Multiprocessing
----------------
10 changes: 5 additions & 5 deletions docs/source/dp_tutorial.rst
@@ -67,9 +67,9 @@ Working with DataLoader
---------------------------------------------

In this section, we will demonstrate how you can use ``DataPipe`` with ``DataLoader``.
-For the most part, you should be able to use it just by passing ``dataset=datapipe`` as an input arugment
+For the most part, you should be able to use it just by passing ``dataset=datapipe`` as an input argument
into the ``DataLoader``. For detailed documentation related to ``DataLoader``,
-please visit `this page <https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading>`_.
+please visit `this PyTorch Core page <https://pytorch.org/docs/stable/data.html#single-and-multi-process-data-loading>`_.


Please refer to `this page <dlv2_tutorial.html>`_ about using ``DataPipe`` with ``DataLoader2``.
@@ -102,7 +102,7 @@ pass defined functions to DataPipes rather than lambda functions because the for
def filter_for_data(filename):
return "sample_data" in filename and filename.endswith(".csv")

-def row_processer(row):
+def row_processor(row):
return {"label": np.array(row[0], np.int32), "data": np.array(row[1:], dtype=np.float64)}

def build_datapipes(root_dir="."):
@@ -112,7 +112,7 @@ pass defined functions to DataPipes rather than lambda functions because the for
datapipe = datapipe.parse_csv(delimiter=",", skip_lines=1)
# Shuffle will happen as long as you do NOT set `shuffle=False` later in the DataLoader
datapipe = datapipe.shuffle()
-datapipe = datapipe.map(row_processer)
+datapipe = datapipe.map(row_processor)
return datapipe

Lastly, we will put everything together in ``'__main__'`` and pass the DataPipe into the DataLoader. Note that
@@ -180,7 +180,7 @@ Note:
- Place ``ShardingFilter`` (``datapipe.sharding_filter``) as early as possible in the pipeline, especially before expensive
operations such as decoding, in order to avoid repeating these expensive operations across worker/distributed processes.
- For the data source that needs to be sharded, it is crucial to add ``Shuffler`` before ``ShardingFilter``
-to ensure data are globally shuffled before splitted into shards. Otherwise, each worker process would
+to ensure data are globally shuffled before being split into shards. Otherwise, each worker process would
always process the same shard of data for all epochs. And, it means each batch would only consist of data
from the same shard, which leads to low accuracy during training. However, it doesn't apply to the data
source that has already been sharded for each multi-/distributed process, since ``ShardingFilter`` is no
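A small sketch of the ordering recommended in this note, with ``Shuffler`` placed before ``ShardingFilter`` and expensive work after it; file names and the decode function are illustrative:

```python
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper


def expensive_decode(path):
    # placeholder for costly per-sample work such as decoding
    return path


if __name__ == "__main__":
    dp = IterableWrapper([f"./sample_data{i}.csv" for i in range(8)])
    dp = dp.shuffle()              # global shuffle happens BEFORE the data is split into shards
    dp = dp.sharding_filter()      # each DataLoader worker then keeps only its own shard
    dp = dp.map(expensive_decode)  # expensive work runs after sharding, once per sample

    dl = DataLoader(dp, batch_size=2, num_workers=2)
    for batch in dl:
        pass
```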
12 changes: 7 additions & 5 deletions docs/source/reading_service.rst
@@ -1,5 +1,7 @@
:tocdepth: 3

+.. currentmodule:: torchdata.datapipes.iter

ReadingService
===============

@@ -13,9 +15,9 @@ Dynamic Sharding

Dynamic sharding is achieved by ``MultiProcessingReadingService`` and ``DistributedReadingService`` to shard the pipeline based on the information of corresponding multiprocessing and distributed workers. And, TorchData offers two types of ``DataPipe`` letting users to define the sharding place within the pipeline.

-- ``sharding_filter``: When the pipeline is replicable, each distributed/multiprocessing worker loads data from one replica of the ``DataPipe`` graph, and skip the data not blonged to the corresponding worker at the place of ``sharding_filter``.
+- ``sharding_filter`` (:class:`ShardingFilter`): When the pipeline is replicable, each distributed/multiprocessing worker loads data from its own replica of the ``DataPipe`` graph, while skipping samples that do not belong to the corresponding worker at the point where ``sharding_filter`` is placed.

-- ``sharding_round_robin_dispatch``: When there is any ``sharding_round_robin_dispatch`` ``DataPipe`` in the pipeline, that branch will be treated as a non-replicable branch. Then, a single dispatching process will be created to load data from the non-repliable branch and distributed data to the subsequent worker processes.
+- ``sharding_round_robin_dispatch`` (:class:`ShardingRoundRobinDispatcher`): When there is any ``sharding_round_robin_dispatch`` ``DataPipe`` in the pipeline, that branch (i.e. all DataPipes prior to ``sharding_round_robin_dispatch``) will be treated as a non-replicable branch (in the context of multiprocessing). A single dispatching process will be created to load data from the non-replicable branch and distributed data to the subsequent worker processes.

The following is an example of having two types of sharding strategies in the pipeline.

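A rough sketch of a pipeline that combines both strategies, following the pattern added to the ``sharding.py`` docstring later in this PR:

```python
from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
from torchdata.datapipes.iter import IterableWrapper

dp = IterableWrapper(range(1000))
# Distributed sharding: each node keeps its own portion of the data.
dp = dp.shuffle().sharding_filter()
# Everything before this point runs in a single dispatching process per node;
# samples are then handed to the worker processes in a round-robin fashion.
dp = dp.sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
# Runs inside each worker process.
dp = dp.map(lambda x: x + 1)
```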
@@ -116,12 +118,14 @@ When multiprocessing takes place, the graph becomes:

``Client`` in the graph is a ``DataPipe`` that send request and receive response from multiprocessing queues.

+.. module:: torchdata.dataloader2

Determinism
^^^^^^^^^^^^

In ``DataLoader2``, a ``SeedGenerator`` becomes a single source of randomness and each ``ReadingService`` would access to it via ``initialize_iteration()`` and generate corresponding random seeds for random ``DataPipe`` operations.

-In order to make sure that the Dataset shards are mutually exclusive and collectively exhaunsitve on multiprocessing processes and distributed nodes, ``MultiProcessingReadingService`` and ``DistributedReadingService`` would help ``DataLoader2`` to synchronize random states for any random ``DataPipe`` operation prior to ``sharding_filter`` or ``sharding_round_robin_dispatch``. For the remaining ``DataPipe`` operations after sharding, unique random states are generated based on the distributed rank and worker process id by each ``ReadingService``, in order to perform different random transformations.
+In order to make sure that the Dataset shards are mutually exclusive and collectively exhaustive on multiprocessing processes and distributed nodes, ``MultiProcessingReadingService`` and ``DistributedReadingService`` would help :class:`DataLoader2` to synchronize random states for any random ``DataPipe`` operation prior to ``sharding_filter`` or ``sharding_round_robin_dispatch``. For the remaining ``DataPipe`` operations after sharding, unique random states are generated based on the distributed rank and worker process id by each ``ReadingService``, in order to perform different random transformations.

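A sketch of how this looks from user code; ``DataLoader2.seed`` is assumed to be the seeding entry point here, so treat the exact call as illustrative rather than definitive:

```python
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

if __name__ == "__main__":
    dp = IterableWrapper(range(1000)).shuffle().sharding_filter()
    rs = MultiProcessingReadingService(num_workers=2)
    dl = DataLoader2(dp, reading_service=rs)
    dl.seed(123)  # assumed API: one seed feeds the SeedGenerator, keeping the pre-sharding shuffle in sync
    epoch = list(dl)
    dl.shutdown()
```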
Graph Mode
^^^^^^^^^^^
@@ -131,8 +135,6 @@ This also allows easier transition of data-preprocessing pipeline from research
Extend ReadingService
----------------------

-.. module:: torchdata.dataloader2

The followings are interfaces for custom ``ReadingService``.

.. autoclass:: ReadingServiceInterface
1 change: 1 addition & 0 deletions docs/source/torchdata.datapipes.iter.rst
@@ -187,6 +187,7 @@ A miscellaneous set of DataPipes with different functionalities.
Prefetcher
RandomSplitter
ShardingFilter
+ShardingRoundRobinDispatcher

Selecting DataPipes
-------------------------
5 changes: 5 additions & 0 deletions test/test_iterdatapipe.py
@@ -261,6 +261,11 @@ def odd_even_bug(i: int) -> int:
with self.assertRaisesRegex(KeyError, "is not a valid key in the given MapDataPipe"):
next(it)

+# Functional test: ensure that keep_key option works
+result_dp = source_dp.zip_with_map(map_dp, odd_even, keep_key=True)
+expected_res_keep_key = [(key, (i, odd_even_string(i))) for i, key in zip(range(10), [0, 1] * 5)]
+self.assertEqual(expected_res_keep_key, list(result_dp))

# Reset Test:
n_elements_before_reset = 4
result_dp = source_dp.zip_with_map(map_dp, odd_even)
23 changes: 11 additions & 12 deletions torchdata/dataloader2/dataloader2.py
@@ -100,8 +100,6 @@ def resume(self) -> None:
Restarts the threads within ``DataLoader2`` and allows it to yield additional batches.
"""
self.dataloader._resume()
-if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "resume"):
-self.dataloader._datapipe_iter.resume() # type: ignore[attr-defined]

def limit(self, num_batches: Optional[int]) -> None:
"""
@@ -120,8 +118,7 @@ def limit(self, num_batches: Optional[int]) -> None:
"""
self.limit_counter = 0
self.limit_threshold = num_batches
-if self.dataloader._datapipe_iter and hasattr(self.dataloader._datapipe_iter, "limit"):
-self.dataloader._datapipe_iter.limit(num_batches) # type: ignore[attr-defined]
+self.dataloader._limit(num_batches)

def __getattr__(self, name):
"""
@@ -339,11 +336,8 @@ def _pause(self):
if hasattr(self.reading_service, "_pause"):
self._is_paused = True
self.reading_service._pause()
-# TODO: the condition should be `else` once `self._datapipe_iter.pause/limit()` is no longer used
-elif self._datapipe_iter is None or not (
-hasattr(self._datapipe_iter, "limit") or hasattr(self._datapipe_iter, "pause")
-):
-warnings.warn("ReadingService doesn't support pause.")
+else:
+warnings.warn("ReadingService doesn't support `pause`.")

def _resume(self):
if hasattr(self.reading_service, "_resume"):
@@ -352,6 +346,11 @@ def _resume(self):
else:
self.reading_service._resume()
self._is_paused = False
-# TODO: the condition should be `else` once `self._datapipe_iter.resume()` is no longer used
-elif self._datapipe_iter is None or not hasattr(self._datapipe_iter, "resume"):
-warnings.warn("ReadingService doesn't support resume.")
+else:
+warnings.warn("ReadingService doesn't support `resume`.")

+def _limit(self, num_batches: Optional[int]) -> None:
+if hasattr(self.reading_service, "_limit"):
+self.reading_service._limit(num_batches)
+else:
+warnings.warn("ReadingService doesn't support `limit`.")
7 changes: 7 additions & 0 deletions torchdata/dataloader2/reading_service.py
@@ -406,6 +406,13 @@ def _resume(self):
if self.main_prefetch_cnt > 0 and self.num_workers > 0:
self._main_prefetch_datapipe.resume() # type: ignore[union-attr]

+def _limit(self, num_batches: Optional[int]) -> None:
+"""
+For this ReadingService, `DataLoader2Iterator` and `DataLoader2` should sufficiently handle
+the limit operation, such that nothing needs to be done here.
+"""
+pass


class DistributedReadingService(ReadingServiceInterface):
r"""
39 changes: 29 additions & 10 deletions torchdata/datapipes/iter/util/combining.py
@@ -174,20 +174,33 @@ class MapKeyZipperIterDataPipe(IterDataPipe[T_co]):
from ``map_datapipe``
map_datapipe: MapDataPipe that takes a key from ``key_fn``, and returns an item
key_fn: Function that maps each item from ``source_iterdatapipe`` to a key that exists in ``map_datapipe``
+keep_key: Option to yield the matching key along with the items in a tuple,
+resulting in ``(key, merge_fn(item1, item2))``.
merge_fn: Function that combines the item from ``source_iterdatapipe`` and the matching item
from ``map_datapipe``, by default a tuple is created

Example:
->>> from torchdata.datapipes.iter import IterableWrapper
->>> from torchdata.datapipes.map import SequenceWrapper
->>> from operator import itemgetter
->>> def merge_fn(tuple_from_iter, value_from_map):
->>> return tuple_from_iter[0], tuple_from_iter[1] + value_from_map
->>> dp1 = IterableWrapper([('a', 1), ('b', 2), ('c', 3)])
->>> mapdp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
->>> res_dp = dp1.zip_with_map(map_datapipe=mapdp, key_fn=itemgetter(0), merge_fn=merge_fn)
->>> list(res_dp)

+.. testsetup::

+from operator import itemgetter

+.. testcode::

+from torchdata.datapipes.iter import IterableWrapper
+from torchdata.datapipes.map import SequenceWrapper

+def merge_fn(tuple_from_iter, value_from_map):
+return tuple_from_iter[0], tuple_from_iter[1] + value_from_map
+dp1 = IterableWrapper([('a', 1), ('b', 2), ('c', 3)])
+mapdp = SequenceWrapper({'a': 100, 'b': 200, 'c': 300, 'd': 400})
+res_dp = dp1.zip_with_map(map_datapipe=mapdp, key_fn=itemgetter(0), merge_fn=merge_fn)
+print(list(res_dp))

+.. testoutput::

+[('a', 101), ('b', 202), ('c', 303)]

"""

def __init__(
Expand All @@ -196,6 +209,7 @@ def __init__(
map_datapipe: MapDataPipe,
key_fn: Callable,
merge_fn: Optional[Callable] = None,
+keep_key: bool = False,
):
if not isinstance(map_datapipe, MapDataPipe):
raise TypeError(f"map_datapipe must be a MapDataPipe, but its type is {type(map_datapipe)} instead.")
Expand All @@ -206,6 +220,7 @@ def __init__(
if merge_fn is not None:
_check_unpickable_fn(merge_fn)
self.merge_fn: Optional[Callable] = merge_fn
+self.keep_key = keep_key

def __iter__(self) -> Iterator:
for item in self.source_iterdatapipe:
@@ -214,7 +229,11 @@ def __iter__(self) -> Iterator:
map_item = self.map_datapipe[key]
except (KeyError, IndexError):
raise KeyError(f"key_fn maps {item} to {key}, which is not a valid key in the given MapDataPipe.")
-yield self.merge_fn(item, map_item) if self.merge_fn else (item, map_item)
+res = self.merge_fn(item, map_item) if self.merge_fn else (item, map_item)
+if self.keep_key:
+yield key, res
+else:
+yield res

def __len__(self) -> int:
return len(self.source_iterdatapipe)
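Tying the new ``keep_key`` flag to the ``__iter__`` change above, a short sketch of its behaviour when no ``merge_fn`` is given (each element becomes ``(key, (item, map_item))``):

```python
from operator import itemgetter

from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.map import SequenceWrapper

dp1 = IterableWrapper([("a", 1), ("b", 2), ("c", 3)])
mapdp = SequenceWrapper({"a": 100, "b": 200, "c": 300})
res_dp = dp1.zip_with_map(map_datapipe=mapdp, key_fn=itemgetter(0), keep_key=True)
print(list(res_dp))
# [('a', (('a', 1), 100)), ('b', (('b', 2), 200)), ('c', (('c', 3), 300))]
```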
26 changes: 22 additions & 4 deletions torchdata/datapipes/iter/util/sharding.py
@@ -16,17 +16,35 @@
@functional_datapipe("sharding_round_robin_dispatch")
class ShardingRoundRobinDispatcherIterDataPipe(IterDataPipe):
r"""
-Wrapper that indicates the prior ``DataPipe`` graph is non-replicable and will be
-iterated in a separate dispatching process to coordinate data to worker processes
-in a round-robin manner, when multiprocessing takes place
+Wrapper that indicates the prior section of ``DataPipe`` graph is non-replicable and will be
+iterated in a separate, single dispatching process to distribute data to worker processes
+in a round-robin manner when multiprocessing is being used.
(functional name: ``sharding_round_robin_dispatch``).

Args:
source_datapipe: Iterable DataPipe that will be sharded
sharding_group_filter: Optional ``SHARDING_PRIORITIES`` value

Note:
- ``sharding_group_filter`` only accepts ``SHARDING_PRIORITIES.MULTIPROCESSING`` for now
- ``sharding_group_filter`` only accepts ``SHARDING_PRIORITIES.MULTIPROCESSING`` for now
+- When using distributed training, you can add a ``sharding_filter()`` prior to this DataPipe
+to distribute samples among worker nodes.

+Examples:
+>>> # xdoctest: +SKIP
+>>> from torchdata.datapipes.iter import IterableWrapper
+>>> from torch.utils.data.datapipes.iter.sharding import SHARDING_PRIORITIES
+>>> dp = IterableWrapper(range(10))
+>>> # `.shuffle()` will be executed in a single dispatching processing, then the samples are distributed
+>>> # to worker processes
+>>> dp = dp.shuffle().sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
+>>> # `.map()` will be executed within each worker process
+>>> dp = dp.map(lambda x: x + 1)
+>>> # Distributed case: the 10 samples will be distributed among the nodes
+>>> dp = IterableWrapper(range(10)).sharding_filter()
+>>> # `.map()` will be executed in a single dispatching processing in each node
+>>> # You may apply further transformation after within each worker process
+>>> dp = dp.map(lambda x: x + 1).sharding_round_robin_dispatch(SHARDING_PRIORITIES.MULTIPROCESSING)
"""

def __init__(self, source_datapipe: IterDataPipe, sharding_group_filter: Optional[SHARDING_PRIORITIES] = None):