Skip to content

Commit

Permalink
Change iterator over multiple Queue wrappers to request all protocols…
Browse files Browse the repository at this point in the history
… simultaneously

ghstack-source-id: ddf5a50b596b93f986e498f48f2952e689495269
Pull Request resolved: #769
  • Loading branch information
VitalyFedyunin committed Sep 21, 2022
1 parent 86df1a0 commit 0c968c6
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 24 deletions.
12 changes: 11 additions & 1 deletion torchdata/dataloader2/communication/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,17 @@ class NotAvailable(Exception):
class InvalidStateResetRequired(Exception):
"""
Returned by DataPipe when it is expecting to get a reset request,
for example RouterDataPipe expecting all workers to request reset'
for example RouterDataPipe expecting all workers to request reset.
"""

pass


class TerminateRequired(Exception):
"""
Returned by DataPipe when it is expecting to get a terminate request,
for example when it got a terminate request from another source and is in
the process of stopping.
"""

pass
Expand Down
63 changes: 40 additions & 23 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@

import functools
import multiprocessing as mp
import time
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional
from typing import Callable, List, Optional

import torch
from torch.utils.data import DataLoader
from torch.utils.data.graph import DataPipe

from torchdata.dataloader2 import communication
from torchdata.datapipes.iter import IterableWrapper
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe


class ReadingServiceInterface(ABC):
Expand Down Expand Up @@ -99,27 +98,46 @@ def _collate_no_op(batch):
return batch[0]


class _IterateQueueDataPipes:
class _IterateQueueDataPipes(IterDataPipe):
# Iterates over several QueueWrapper datapipes, visiting them in index order
# and yielding the values they respond with.
def __init__(self, datapipes):
# TODO(VitalyFedyunin): Consider combining _IterateQueueDataPipes and QueueWrapper
# into one class, which supports any number of queues.
self.datapipes = datapipes
# Only QueueWrapper sources are accepted; anything else is rejected up front.
for dp in self.datapipes:
if not isinstance(dp, communication.iter.QueueWrapper):
raise Exception("Source datapipes should be an instance of iter.QueueWrapper")

def __iter__(self):
# TODO(612): This is slow as it does not send data requests ahead.
exclude_datapipes: List[Any] = []
while len(exclude_datapipes) < len(self.datapipes):
for dp in self.datapipes:
if dp not in exclude_datapipes:
forever = True
while forever:
try:
value = dp.nonblocking_next()
yield value
forever = False
except StopIteration:
exclude_datapipes.append(dp)
forever = False
except communication.iter.NotAvailable:
time.sleep(0.001)
# Bookkeeping: number of source pipes, a per-pipe "exhausted" flag, and the
# count of pipes that have already reported the end of their data.
total_pipes = len(self.datapipes)
disabled_pipe = [False] * len(self.datapipes)
cnt_disabled_pipes = 0

# Send the first request to every pipe before reading any response.
for idx in range(total_pipes):
self.datapipes[idx].protocol.request_next()

# Keep cycling over the pipes until each one has answered with StopIterationResponse.
while cnt_disabled_pipes < total_pipes:
for idx in range(total_pipes):
if not disabled_pipe[idx]:
# NOTE(review): block=True presumably waits until the response is available - confirm.
response = self.datapipes[idx].protocol.get_response_next(block=True)
if isinstance(response, communication.messages.StopIterationResponse):
# This pipe is finished: mark it so it is neither read nor requested again.
disabled_pipe[idx] = True
cnt_disabled_pipes += 1
continue
if isinstance(response, communication.messages.InvalidStateResponse):
raise communication.iter.InvalidStateResetRequired
if isinstance(response, communication.messages.TerminateResponse):
raise communication.iter.TerminateRequired
# Issue the follow-up request for this pipe before yielding the current value.
self.datapipes[idx].protocol.request_next()
yield response.value

def reset(self):
# Collect all existing requests results to clear queues
for dp in self.datapipes:
if dp.protocol.waiting_for_response():
dp.protocol.get_response_next(block=True)
# NonBlocking DataPipes do not reset automatically, have to do it manually
for dp in self.datapipes:
dp.reset_iterator()


class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
Expand Down Expand Up @@ -163,11 +181,10 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
)
self.datapipes.append(local_datapipe)

return IterableWrapper(_IterateQueueDataPipes(self.datapipes), deepcopy=False) # type: ignore[return-value]
return _IterateQueueDataPipes(self.datapipes) # type: ignore[return-value]

def initialize_iteration(self) -> None:
for dp in self.datapipes:
dp.reset_iterator()
pass

def __del__(self):
self.finalize()
Expand Down

0 comments on commit 0c968c6

Please sign in to comment.