Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving DataPipe buffers from __iter__ to instance (self) #388

Closed
wants to merge 8 commits into from
48 changes: 41 additions & 7 deletions torchdata/datapipes/iter/util/combining.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ def __init__(
if buffer_size is not None and buffer_size <= 0:
raise ValueError("'buffer_size' is required to be either None or a positive integer.")
self.buffer_size: int = buffer_size
self.buffer: OrderedDict = OrderedDict()

def __iter__(self) -> Iterator:
buffer: OrderedDict = OrderedDict()
ref_it = iter(self.ref_datapipe)
warn_once_flag = True
for data in self.source_datapipe:
key = self.key_fn(data)
while key not in buffer:
while key not in self.buffer:
try:
ref_data = next(ref_it)
except StopIteration:
Expand All @@ -92,26 +92,60 @@ def __iter__(self) -> Iterator:
"Please consider increasing the buffer size."
)
ref_key = self.ref_key_fn(ref_data)
if ref_key in buffer:
if ref_key in self.buffer:
raise ValueError("Duplicate key is found in reference DataPipe")
if self.buffer_size is not None and len(buffer) > self.buffer_size:
if self.buffer_size is not None and len(self.buffer) > self.buffer_size:
if warn_once_flag:
warn_once_flag = False
warnings.warn(
"Buffer reaches the upper limit, so reference key-data pair begins to "
"be removed from buffer in FIFO order. Please consider increase buffer size."
)
buffer.popitem(last=False)
buffer[ref_key] = ref_data
res = self.merge_fn(data, buffer.pop(key)) if self.merge_fn else (data, buffer.pop(key))
self.buffer.popitem(last=False)
self.buffer[ref_key] = ref_data
res = self.merge_fn(data, self.buffer.pop(key)) if self.merge_fn else (data, self.buffer.pop(key))
if self.keep_key:
yield key, res
else:
yield res
if self.buffer:
self.buffer.clear()
Copy link
Contributor Author

@NivekT NivekT May 18, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interestingly enough, the domain CI test for Vision started to fail on a previous commit of this PR. It raises a "ResourceWarning", as you can see in the link below.

https://github.com/pytorch/data/runs/6498262553?check_suite_focus=true

Perhaps we need to clear the buffer when the iterator is exhausted? I am adding this line to see whether it resolves the issue. Let me know if you have any other thoughts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tried it, and clearing the buffer at the end of __iter__ doesn't help.

@pmeier Do you think this is related to pytest or something that is introduced in this PR?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference, here is the thread where we discussed a similar issue:
pytorch/vision#5801

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am going to ignore the warning and land this based on the previous investigation.


def __len__(self) -> int:
    """Return ``len(self.source_datapipe)``."""
    source = self.source_datapipe
    return len(source)

def reset(self) -> None:
    """Rebind ``buffer`` to a new, empty ``OrderedDict``.

    The previous buffer object is replaced, not cleared in place.
    """
    empty_buffer: OrderedDict = OrderedDict()
    self.buffer = empty_buffer

def __getstate__(self):
    """Build the state object used when this DataPipe is pickled.

    If ``IterDataPipe.getstate_hook`` is set, the hook's result for this
    instance is returned unchanged. Otherwise the state is a 7-tuple of the
    configuration attributes; ``buffer`` is not included.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)
    # The order here must match the unpacking order in __setstate__.
    return (
        self.source_datapipe,
        self.ref_datapipe,
        self.key_fn,
        self.ref_key_fn,
        self.keep_key,
        self.merge_fn,
        self.buffer_size,
    )

def __setstate__(self, state):
    """Restore the instance from the 7-tuple produced by ``__getstate__``.

    ``buffer`` is not part of the state, so it is recreated as an empty
    ``OrderedDict`` after the configuration attributes are restored.
    """
    # Unpack in one step so a state of the wrong length fails before any
    # attribute is assigned.
    source, ref, key_fn, ref_key_fn, keep_key, merge_fn, buffer_size = state
    self.source_datapipe = source
    self.ref_datapipe = ref
    self.key_fn = key_fn
    self.ref_key_fn = ref_key_fn
    self.keep_key = keep_key
    self.merge_fn = merge_fn
    self.buffer_size = buffer_size
    self.buffer = OrderedDict()

def __del__(self):
    """Empty the buffer when the instance is finalized.

    ``__init__`` validates ``buffer_size`` and may raise before ``buffer`` is
    assigned. The partially built instance is still finalized, so the
    attribute is looked up defensively instead of raising ``AttributeError``
    from inside ``__del__``.
    """
    buffer = getattr(self, "buffer", None)
    if buffer is not None:
        buffer.clear()


@functional_datapipe("zip_with_map")
class MapKeyZipperIterDataPipe(IterDataPipe[T_co]):
Expand Down
34 changes: 26 additions & 8 deletions torchdata/datapipes/iter/util/paragraphaggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,40 @@ def __init__(self, source_datapipe: IterDataPipe[Tuple[str, T_co]], joiner: Call
self.source_datapipe: IterDataPipe[Tuple[str, T_co]] = source_datapipe
_check_lambda_fn(joiner)
self.joiner: Callable = joiner
self.buffer: List = []

def __iter__(self) -> Iterator[Tuple[str, str]]:
buffer = []
prev_filename = None
for filename, line in self.source_datapipe:
if prev_filename is None:
prev_filename = filename
if line and prev_filename == filename:
buffer.append(line)
self.buffer.append(line)
else:
if buffer:
yield prev_filename, self.joiner(buffer) # type: ignore[misc]
if self.buffer:
yield prev_filename, self.joiner(self.buffer) # type: ignore[misc]
if line:
buffer = [line]
self.buffer = [line]
else:
buffer = []
self.buffer = []
prev_filename = filename
if buffer:
yield prev_filename, self.joiner(buffer) # type: ignore[misc]
if self.buffer:
try:
yield prev_filename, self.joiner(self.buffer) # type: ignore[misc]
finally:
self.buffer.clear()

def reset(self) -> None:
    """Rebind ``buffer`` to a new, empty list.

    The previous list object is replaced, not cleared in place.
    """
    empty_buffer = []
    self.buffer = empty_buffer

def __getstate__(self):
    """Build the state object used when this DataPipe is pickled.

    If ``IterDataPipe.getstate_hook`` is set, the hook's result for this
    instance is returned unchanged. Otherwise the state is the pair
    ``(source_datapipe, joiner)``; ``buffer`` is not included.
    """
    hook = IterDataPipe.getstate_hook
    if hook is not None:
        return hook(self)
    return (self.source_datapipe, self.joiner)

def __setstate__(self, state):
    """Restore the instance from the ``(source_datapipe, joiner)`` state pair.

    ``buffer`` is not part of the pickled state, so it is recreated empty
    here. Without this, a restored instance has no ``buffer`` attribute and
    fails with ``AttributeError`` as soon as ``__iter__`` or ``__del__``
    touches ``self.buffer``.
    """
    (self.source_datapipe, self.joiner) = state
    self.buffer = []

def __del__(self):
    """Empty the buffer when the instance is finalized.

    The attribute can be missing on a partially initialised instance
    (finalisation may happen before ``buffer`` has ever been assigned), so it
    is looked up defensively instead of raising ``AttributeError`` from
    inside ``__del__``.
    """
    buffer = getattr(self, "buffer", None)
    if buffer is not None:
        buffer.clear()