Moving DataPipe buffers from __iter__ to instance (self)
ghstack-source-id: c4a55ad83139e8bb7d32881d2bc5675b0392197b
Pull Request resolved: #388
NivekT committed May 19, 2022
1 parent 5e0d1e4 commit f3f533e
Showing 2 changed files with 62 additions and 15 deletions.
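
Both files apply the same pattern: the iteration buffer becomes an instance attribute so that it can be cleared by a `reset()` hook and deliberately left out of the state returned by `__getstate__`. Below is a minimal sketch of that pattern for a hypothetical buffered DataPipe; the `BufferedExampleIterDataPipe` class, its sliding-buffer logic, and its argument names are illustrative only and are not part of this commit.

```python
from collections import deque
from typing import Iterator

from torch.utils.data import IterDataPipe


class BufferedExampleIterDataPipe(IterDataPipe):
    """Hypothetical datapipe showing the buffer-on-self pattern adopted in this commit."""

    def __init__(self, source_datapipe: IterDataPipe, buffer_size: int = 8) -> None:
        self.source_datapipe = source_datapipe
        self.buffer_size = buffer_size
        self.buffer: deque = deque()  # state lives on the instance, not inside __iter__

    def __iter__(self) -> Iterator:
        # A simple sliding buffer: hold up to buffer_size items before yielding.
        for item in self.source_datapipe:
            self.buffer.append(item)
            if len(self.buffer) >= self.buffer_size:
                yield self.buffer.popleft()
        while self.buffer:
            yield self.buffer.popleft()

    def reset(self) -> None:
        # Hook for the DataPipe machinery: drop any partially filled buffer.
        self.buffer = deque()

    def __getstate__(self):
        if IterDataPipe.getstate_hook is not None:
            return IterDataPipe.getstate_hook(self)
        # Serialize only the construction arguments, never the runtime buffer.
        return (self.source_datapipe, self.buffer_size)

    def __setstate__(self, state):
        (self.source_datapipe, self.buffer_size) = state
        self.buffer = deque()

    def __del__(self):
        self.buffer.clear()
```
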
46 changes: 39 additions & 7 deletions torchdata/datapipes/iter/util/combining.py
```diff
@@ -76,14 +76,14 @@ def __init__(
         if buffer_size is not None and buffer_size <= 0:
             raise ValueError("'buffer_size' is required to be either None or a positive integer.")
         self.buffer_size: int = buffer_size
+        self.buffer: OrderedDict = OrderedDict()
 
     def __iter__(self) -> Iterator:
-        buffer: OrderedDict = OrderedDict()
         ref_it = iter(self.ref_datapipe)
         warn_once_flag = True
         for data in self.source_datapipe:
             key = self.key_fn(data)
-            while key not in buffer:
+            while key not in self.buffer:
                 try:
                     ref_data = next(ref_it)
                 except StopIteration:
@@ -92,18 +92,18 @@ def __iter__(self) -> Iterator:
                         "Please consider increasing the buffer size."
                     )
                 ref_key = self.ref_key_fn(ref_data)
-                if ref_key in buffer:
+                if ref_key in self.buffer:
                     raise ValueError("Duplicate key is found in reference DataPipe")
-                if self.buffer_size is not None and len(buffer) > self.buffer_size:
+                if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
                     if warn_once_flag:
                         warn_once_flag = False
                         warnings.warn(
                             "Buffer reaches the upper limit, so reference key-data pair begins to "
                             "be removed from buffer in FIFO order. Please consider increase buffer size."
                         )
-                    buffer.popitem(last=False)
-                buffer[ref_key] = ref_data
-            res = self.merge_fn(data, buffer.pop(key)) if self.merge_fn else (data, buffer.pop(key))
+                    self.buffer.popitem(last=False)
+                self.buffer[ref_key] = ref_data
+            res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
             if self.keep_key:
                 yield key, res
             else:
@@ -112,6 +112,38 @@ def __iter__(self) -> Iterator:
     def __len__(self) -> int:
         return len(self.source_datapipe)
 
+    def reset(self) -> None:
+        self.buffer = OrderedDict()
+
+    def __getstate__(self):
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(self)
+        state = (
+            self.source_datapipe,
+            self.ref_datapipe,
+            self.key_fn,
+            self.ref_key_fn,
+            self.keep_key,
+            self.merge_fn,
+            self.buffer_size,
+        )
+        return state
+
+    def __setstate__(self, state):
+        (
+            self.source_datapipe,
+            self.ref_datapipe,
+            self.key_fn,
+            self.ref_key_fn,
+            self.keep_key,
+            self.merge_fn,
+            self.buffer_size,
+        ) = state
+        self.buffer = OrderedDict()
+
+    def __del__(self):
+        self.buffer.clear()
+
 
 @functional_datapipe("zip_with_map")
 class MapKeyZipperIterDataPipe(IterDataPipe[T_co]):
```
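
A usage-level sketch of what the new `__getstate__`/`__setstate__` give the key zipper above: pickling a partially consumed pipeline drops the buffered reference data, and the deserialized copy starts with a fresh buffer. The `zip_with_iter` functional name and the `IterableWrapper` import path are assumptions about the surrounding library, not something this diff shows.

```python
import pickle
from operator import itemgetter

from torchdata.datapipes.iter import IterableWrapper  # assumed import path

source = IterableWrapper([("a", 1), ("b", 2), ("c", 3)])
ref = IterableWrapper([("b", 20), ("a", 10), ("c", 30)])

# Assumed functional form of the datapipe patched above; keyword names follow
# the attributes shown in the diff.
dp = source.zip_with_iter(
    ref_datapipe=ref,
    key_fn=itemgetter(0),
    ref_key_fn=itemgetter(0),
)

it = iter(dp)
print(next(it))  # (('a', 1), ('a', 10)); ('b', 20) is now parked in dp.buffer

# __getstate__ serializes only the construction arguments, so the buffered
# ('b', 20) entry is not pickled; __setstate__ rebuilds an empty OrderedDict
# and the clone starts matching from scratch.
clone = pickle.loads(pickle.dumps(dp))
print(list(clone))  # all three matched pairs, starting over from the beginning
```
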
31 changes: 23 additions & 8 deletions torchdata/datapipes/iter/util/paragraphaggregator.py
```diff
@@ -46,22 +46,37 @@ def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Call
         self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe
         _check_lambda_fn(joiner)
         self.joiner: Callable = joiner
+        self.buffer: List = []
 
     def __iter__(self) -> Iterator[Tuple[str, str]]:
-        buffer = []
         prev_filename = None
         for filename, line in self.source_datapipe:
             if prev_filename is None:
                 prev_filename = filename
             if line and prev_filename == filename:
-                buffer.append(line)
+                self.buffer.append(line)
             else:
-                if buffer:
-                    yield prev_filename, self.joiner(buffer)  # type: ignore[misc]
+                if self.buffer:
+                    yield prev_filename, self.joiner(self.buffer)  # type: ignore[misc]
                 if line:
-                    buffer = [line]
+                    self.buffer = [line]
                 else:
-                    buffer = []
+                    self.buffer = []
                 prev_filename = filename
-        if buffer:
-            yield prev_filename, self.joiner(buffer)  # type: ignore[misc]
+        if self.buffer:
+            yield prev_filename, self.joiner(self.buffer)  # type: ignore[misc]
+
+    def reset(self) -> None:
+        self.buffer = []
+
+    def __getstate__(self):
+        if IterDataPipe.getstate_hook is not None:
+            return IterDataPipe.getstate_hook(self)
+        state = (self.source_datapipe, self.joiner)
+        return state
+
+    def __setstate__(self, state):
+        (self.source_datapipe, self.joiner) = state
+
+    def __del__(self):
+        self.buffer.clear()
```
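
A similar sketch for the paragraph aggregator: the partially built paragraph is now visible as `dp.buffer` on the instance, and `reset()` recreates it. The `lines_to_paragraphs` functional name and the `IterableWrapper` import are assumptions; `reset()` is called by hand here only to show its effect, since in practice it is a hook for the DataPipe graph machinery.

```python
from torchdata.datapipes.iter import IterableWrapper  # assumed import path

lines = IterableWrapper([
    ("doc1.txt", "Hello"),
    ("doc1.txt", "world!"),
    ("doc2.txt", "Single line."),
])

# Assumed functional form of the ParagraphAggregator patched above.
dp = lines.lines_to_paragraphs(joiner=" ".join)

it = iter(dp)
print(next(it))   # ('doc1.txt', 'Hello world!')

# The generator is paused at the yield, so the lines of the paragraph just
# emitted are still sitting in the instance-level buffer.
print(dp.buffer)  # ['Hello', 'world!']

# reset() recreates the buffer, discarding any partially accumulated state.
dp.reset()
print(dp.buffer)  # []
```
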
