diff --git a/torchdata/datapipes/iter/util/combining.py b/torchdata/datapipes/iter/util/combining.py
index 8d1b435b3..7c094bc57 100644
--- a/torchdata/datapipes/iter/util/combining.py
+++ b/torchdata/datapipes/iter/util/combining.py
@@ -76,14 +76,14 @@ def __init__(
         if buffer_size is not None and buffer_size <= 0:
             raise ValueError("'buffer_size' is required to be either None or a positive integer.")
         self.buffer_size: int = buffer_size
+        self.buffer: OrderedDict = OrderedDict()
 
     def __iter__(self) -> Iterator:
-        buffer: OrderedDict = OrderedDict()
         ref_it = iter(self.ref_datapipe)
         warn_once_flag = True
         for data in self.source_datapipe:
             key = self.key_fn(data)
-            while key not in buffer:
+            while key not in self.buffer:
                 try:
                     ref_data = next(ref_it)
                 except StopIteration:
@@ -92,18 +92,18 @@ def __iter__(self) -> Iterator:
                         "Please consider increasing the buffer size."
                     )
                 ref_key = self.ref_key_fn(ref_data)
-                if ref_key in buffer:
+                if ref_key in self.buffer:
                     raise ValueError("Duplicate key is found in reference DataPipe")
-                if self.buffer_size is not None and len(buffer) > self.buffer_size:
+                if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
                     if warn_once_flag:
                         warn_once_flag = False
                         warnings.warn(
                             "Buffer reaches the upper limit, so reference key-data pair begins to "
                             "be removed from buffer in FIFO order. Please consider increase buffer size."
                         )
-                    buffer.popitem(last=False)
-                buffer[ref_key] = ref_data
-            res = self.merge_fn(data, buffer.pop(key)) if self.merge_fn else (data, buffer.pop(key))
+                    self.buffer.popitem(last=False)
+                self.buffer[ref_key] = ref_data
+            res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
             if self.keep_key:
                 yield key, res
             else:
@@ -112,6 +112,38 @@ def __iter__(self) -> Iterator:
     def __len__(self) -> int:
         return len(self.source_datapipe)
 
+    def reset(self) -> None:
+        self.buffer = OrderedDict()
+
+    def __getstate__(self):
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(self)
+        state = (
+            self.source_datapipe,
+            self.ref_datapipe,
+            self.key_fn,
+            self.ref_key_fn,
+            self.keep_key,
+            self.merge_fn,
+            self.buffer_size,
+        )
+        return state
+
+    def __setstate__(self, state):
+        (
+            self.source_datapipe,
+            self.ref_datapipe,
+            self.key_fn,
+            self.ref_key_fn,
+            self.keep_key,
+            self.merge_fn,
+            self.buffer_size,
+        ) = state
+        self.buffer = OrderedDict()
+
+    def __del__(self):
+        self.buffer.clear()
+
 
 @functional_datapipe("zip_with_map")
 class MapKeyZipperIterDataPipe(IterDataPipe[T_co]):
diff --git a/torchdata/datapipes/iter/util/paragraphaggregator.py b/torchdata/datapipes/iter/util/paragraphaggregator.py
index f258c0ee8..696ba33fa 100644
--- a/torchdata/datapipes/iter/util/paragraphaggregator.py
+++ b/torchdata/datapipes/iter/util/paragraphaggregator.py
@@ -46,22 +46,38 @@ def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Call
         self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe
         _check_lambda_fn(joiner)
         self.joiner: Callable = joiner
+        self.buffer: List = []
 
     def __iter__(self) -> Iterator[Tuple[str, str]]:
-        buffer = []
         prev_filename = None
         for filename, line in self.source_datapipe:
             if prev_filename is None:
                 prev_filename = filename
             if line and prev_filename == filename:
-                buffer.append(line)
+                self.buffer.append(line)
             else:
-                if buffer:
-                    yield prev_filename, self.joiner(buffer)  # type: ignore[misc]
+                if self.buffer:
+                    yield prev_filename, self.joiner(self.buffer)  # type: ignore[misc]
                 if line:
-                    buffer = [line]
+                    self.buffer = [line]
                 else:
-                    buffer = []
+                    self.buffer = []
                 prev_filename = filename
-        if buffer:
-            yield prev_filename, self.joiner(buffer)  # type: ignore[misc]
+        if self.buffer:
+            yield prev_filename, self.joiner(self.buffer)  # type: ignore[misc]
+
+    def reset(self) -> None:
+        self.buffer = []
+
+    def __getstate__(self):
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(self)
+        state = (self.source_datapipe, self.joiner)
+        return state
+
+    def __setstate__(self, state):
+        (self.source_datapipe, self.joiner) = state
+        self.buffer = []
+
+    def __del__(self):
+        self.buffer.clear()
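Note on the pattern this diff introduces: both DataPipes now keep their buffer as instance state, exclude it from __getstate__, and rebuild it empty in __setstate__ and reset(), so in-flight buffer contents are never carried through serialization. A minimal standalone sketch of that pattern (illustrative only, not torchdata code; the BufferedExample class below is hypothetical):

    import pickle
    from collections import OrderedDict

    class BufferedExample:
        # Mirrors the reset/__getstate__/__setstate__ pattern added in this diff.
        def __init__(self, source):
            self.source = source
            self.buffer: OrderedDict = OrderedDict()  # transient working state

        def __getstate__(self):
            return (self.source,)  # buffer deliberately excluded from the pickled state

        def __setstate__(self, state):
            (self.source,) = state
            self.buffer = OrderedDict()  # always restored empty

        def reset(self) -> None:
            self.buffer = OrderedDict()

    dp = BufferedExample([1, 2, 3])
    dp.buffer["k"] = "in-flight value"
    clone = pickle.loads(pickle.dumps(dp))
    assert clone.buffer == OrderedDict()  # buffered data does not survive a pickle round-trip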