Skip to content

Commit

Permalink
Remove __new__
Browse files: browse the repository at this point in the history
  • Loading branch information
ejguan committed Apr 24, 2023
1 parent 8d796f1 commit 303a947
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 26 deletions.
7 changes: 4 additions & 3 deletions test/dataloader2/test_dataloader2.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
communication,
DataLoader2,
DistributedReadingService,
InProcessReadingService,
MultiProcessingReadingService,
ReadingServiceInterface,
SequentialReadingService,
Expand Down Expand Up @@ -223,8 +224,8 @@ def _get_mp_reading_service():
return MultiProcessingReadingService(num_workers=2)

@staticmethod
def _get_mp_reading_service_zero_workers():
return MultiProcessingReadingService(num_workers=0)
def _get_in_process_reading_service():
return InProcessReadingService()

def _collect_data(self, datapipe, reading_service_gen):
dl: DataLoader2 = DataLoader2(datapipe, reading_service=reading_service_gen())
Expand All @@ -247,7 +248,7 @@ def test_dataloader2_batch_collate(self) -> None:

reading_service_generators = (
self._get_mp_reading_service,
self._get_mp_reading_service_zero_workers,
self._get_in_process_reading_service,
)
for reading_service_gen in reading_service_generators:
actual = self._collect_data(dp, reading_service_gen=reading_service_gen)
Expand Down
6 changes: 0 additions & 6 deletions test/dataloader2/test_mprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,6 @@ class TestMultiProcessingReadingService(TestCase):
`pause`, `resume`, `snapshot`.
"""

def test_zero_worker(self) -> None:
rs = MultiProcessingReadingService(
num_workers=0,
)
self.assertTrue(isinstance(rs, InProcessReadingService))

@mp_ctx_parametrize
@parametrize("dp_fn", [subtest(_non_dispatching_dp, "non_dispatch"), subtest(_dispatching_dp, "dispatch")])
@parametrize("main_prefetch", [0, 10])
Expand Down
20 changes: 3 additions & 17 deletions torchdata/dataloader2/reading_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class MultiProcessingReadingService(ReadingServiceInterface):
process and eventually return the result to the main process.
Args:
num_workers (int, optional): How many subprocesses to use for data loading.
num_workers (int): How many subprocesses to use for data loading.
multiprocessing_context (str, optional): Multiprocessing starting method.
If method is None then the default context is returned.
Otherwise, method should be 'fork', 'spawn'.
Expand Down Expand Up @@ -256,30 +256,16 @@ class MultiProcessingReadingService(ReadingServiceInterface):
_mp: bool
_finalized: bool = False

def __new__(
cls,
num_workers: int = 0,
multiprocessing_context: Optional[str] = None,
worker_prefetch_cnt: int = 10,
main_prefetch_cnt: int = 10,
worker_init_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] = None,
worker_reset_fn: Optional[Callable[[DataPipe, WorkerInfo, SeedGenerator], DataPipe]] = None,
):
if num_workers == 0:
warnings.warn(f"`InProcessReadingService` is used when {num_workers=}")
return InProcessReadingService(worker_init_fn, worker_reset_fn)
return super().__new__(cls)

def __init__(
self,
num_workers: int = 0,
num_workers: int,
multiprocessing_context: Optional[str] = None,
worker_prefetch_cnt: int = 10,
main_prefetch_cnt: int = 10,
worker_init_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] = None,
worker_reset_fn: Optional[Callable[[DataPipe, WorkerInfo, SeedGenerator], DataPipe]] = None,
) -> None:
assert num_workers > 0
assert num_workers > 0, "Please use `InProcessReadingService` for num_workers=0"
self.num_workers = num_workers
if multiprocessing_context is not None:
_all_start_methods = mp.get_all_start_methods()
Expand Down

0 comments on commit 303a947

Please sign in to comment.