Adding lock mechanism to prevent on_disk_cache downloading twice #409
```diff
@@ -1,2 +1,3 @@
 urllib3 >= 1.25
 requests
+portalocker >= 2.0.0
```
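The new `portalocker` dependency provides the cross-process file locks this PR builds on. A minimal sketch of the pattern used below (the file name is illustrative):

```python
import portalocker

# Take an exclusive, cross-process lock on a sentinel file. A second process
# locking the same path blocks until the first holder releases it (here, on
# leaving the `with` block).
with portalocker.Lock("download.promise", "a+", flags=portalocker.LockFlags.EXCLUSIVE) as fh:
    fh.seek(0)
    contents = fh.read()  # empty string on the very first acquisition
```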
```diff
@@ -4,14 +4,19 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 import functools
 import hashlib
 import inspect
 import os.path
 import sys
+import time
 import warnings
 
 from collections import deque
 from functools import partial
-from typing import Callable, Deque, Dict, Iterator, Optional, TypeVar
+from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, TypeVar
+
+import portalocker
+
 from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE
@@ -26,6 +31,9 @@
 
 T_co = TypeVar("T_co", covariant=True)
 
+PROMISE_FILE_DELETE_TIMEOUT = 30
+PROMISE_FILE_DELETE_RETRY_INTERVAL = 0.005
+
 
 @functional_datapipe("in_memory_cache")
 class InMemoryCacheHolderIterDataPipe(IterDataPipe[T_co]):
```
```diff
@@ -106,6 +114,9 @@ def _hash_check(filepath, hash_dict, hash_type):
     else:
         hash_func = hashlib.md5()
 
+    # with portalocker.Lock(filepath, "rb", flags=portalocker.LockFlags.SHARED) as f:
+    # TODO(VitalyFedyunin): Line above will require all readers (Win) to obtain proper locks,
+    # I'm putting it on hold as we need to modify PyTorch core codebase heavily.
     with open(filepath, "rb") as f:
         chunk = f.read(1024 ** 2)
         while chunk:
@@ -115,6 +126,10 @@ def _hash_check(filepath, hash_dict, hash_type):
     return hash_func.hexdigest() == hash_dict[filepath]
 
 
+def _promise_filename(filename):
+    return filename + ".promise"
+
+
 @functional_datapipe("on_disk_cache")
 class OnDiskCacheHolderIterDataPipe(IterDataPipe):
     """
@@ -145,7 +160,7 @@ class OnDiskCacheHolderIterDataPipe(IterDataPipe):
         >>> hash_dict = {"expected_filepath": expected_MD5_hash}
         >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
         >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
-        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
+        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
     """
 
     _temp_dict: Dict = {}
```
```diff
@@ -184,22 +199,42 @@ def __add__(self, other_datapipe):
     @staticmethod
     def _cache_check_fn(data, filepath_fn, hash_dict, hash_type, extra_check_fn):
         filepaths = data if filepath_fn is None else filepath_fn(data)
+        result = True
         if not isinstance(filepaths, (list, tuple)):
             filepaths = [
                 filepaths,
             ]
 
         for filepath in filepaths:
+            cached_file_exists = True
             if not os.path.exists(filepath):
-                return False
-
-            if hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
-                return False
-
-            if extra_check_fn is not None and not extra_check_fn(filepath):
-                return False
-
-        return True
+                cached_file_exists = False
+            elif hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
+                cached_file_exists = False
+            elif extra_check_fn is not None and not extra_check_fn(filepath):
+                cached_file_exists = False
+
+            if not cached_file_exists:
+                promise_filepath = _promise_filename(filepath)
+                dirname = os.path.dirname(promise_filepath)
+                if not os.path.exists(dirname):
+                    os.makedirs(dirname)
+
+                with portalocker.Lock(promise_filepath, "a+", flags=portalocker.LockFlags.EXCLUSIVE) as promise_fh:
+                    promise_fh.seek(0)
+                    data = promise_fh.read()
+                    # TODO(VitalyFedyunin): Potentially there is old .promise file from previous failed run, we
+                    # need to somehow propagate uniq session id for dataloader, save and compare it here,
+                    # raising error
+                    file_exists = len(data) > 0
+                    if not file_exists:
+                        result = False
+                        promise_fh.seek(0)
+                        promise_fh.write("[dataloader session uid]")
+                        promise_fh.truncate()
+                        promise_fh.flush()
+
+        return result
```
Review comments on lines +229 to +237:

Reviewer: When the cached op is 1-to-n, like decompression from an archive, if any decompressed file is missing or has an incorrect hash, we can directly return. So, I think we should lock over `data`. WDYT?

Author: Unfortunately, `data` could be a URL or something else, so it is hard to lock on it. But this situation is covered. Imagine `data` generates two file names: file1 and file2. The initial pass (empty FS) will add two locks, file1.promise and file2.promise, and will go the 'False' route. The second (and every subsequent) pass will see that the files are missing, but will fail to create the promises and go into the 'file exists' route, which leads them to wait for file1.promise and file2.promise to disappear.

Reviewer: If it is a URL, is it possible to create and lock …

Reviewer: Oh, I see.
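To make the two-pass behavior described above concrete, here is a minimal standalone sketch of the claim step, independent of the DataPipe machinery. The helper name `try_claim` and the file paths are illustrative, not part of the PR:

```python
import os

import portalocker


def try_claim(filepath):
    """Return True if this process is the first to promise `filepath` and
    should produce it, False if another process has already claimed it."""
    promise = filepath + ".promise"
    dirname = os.path.dirname(promise)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    # The exclusive lock serializes concurrent claimants of the same promise file.
    with portalocker.Lock(promise, "a+", flags=portalocker.LockFlags.EXCLUSIVE) as fh:
        fh.seek(0)
        if fh.read():      # non-empty: another process already claimed it
            return False   # caller takes the "file exists" route and waits
        fh.write("[dataloader session uid]")  # claim the promise
        return True        # caller takes the "False" route and downloads
```

In `_cache_check_fn` terms, `try_claim` returning True corresponds to `result = False` (cache miss, this worker downloads), while False corresponds to waiting in `_wait_promise_fn` until the promise file disappears.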
```diff
 
     def _end_caching(self):
         filepath_fn, hash_dict, hash_type, extra_check_fn = OnDiskCacheHolderIterDataPipe._temp_dict.pop(self)
```
```diff
@@ -232,6 +267,82 @@ def _read_str(fd):
     return "".join(fd)
 
 
+def _find_promise_file(filename):
+    promise_filename = _promise_filename(filename)
+    while not os.path.exists(promise_filename):
+        dirname = os.path.dirname(promise_filename)
+        if dirname == os.path.dirname(dirname):
+            promise_filename = _promise_filename(filename)
+            break
+        promise_filename = _promise_filename(dirname)
+    return promise_filename
+
+
+def _is_promise_pending(promise_filename):
+    return os.path.exists(promise_filename)
+
+
+def _wait_promise_fn(timeout, filename):
+    promise_filename = _find_promise_file(filename)
+    start = time.time()
+    while _is_promise_pending(promise_filename):
+        time.sleep(0.01)
+        if time.time() - start > timeout:
+            raise Exception(
+                f"OnDiskCache Exception: {filename} expected to be written by different process, "
+                + f"but file is not ready in {timeout} seconds."
+            )
+    return filename
+
+
+class _FulfilledPromisesIterDataPipe(IterDataPipe):
+    def __init__(self, source_datapipe):
+        self.source_datapipe = source_datapipe
+
+    @staticmethod
+    def _del_promise_file(promise_filename, filename):
+        if os.path.exists(promise_filename):
+            retry = True
+            start = time.time()
+            while retry:
+                retry = False
+                try:
+                    os.unlink(promise_filename)
+                except Exception as e:
+                    # Workaround for Windows not letting us delete the file while it is open by another process
+                    retry = True
+                    if time.time() - start > PROMISE_FILE_DELETE_TIMEOUT:
+                        raise Exception("Timeout while trying to recover from the ", type(e), e)
+                    time.sleep(PROMISE_FILE_DELETE_RETRY_INTERVAL)
+        else:
+            warnings.warn(
+                f"Attempt to mark {promise_filename} promise (base of file {filename}) as fulfilled failed. Potentially mismatching filename functions of on_disk_cache and end_cache."
+            )
+
+    def __iter__(self):
+        old_promise_filename = None
+        old_filename = None
+        first_entry = True
+        # TODO(VitalyFedyunin): Limit buffer size here. It only contains file names from the archive,
+        # but better be safe than sorry.
+        buffer: List[Any] = []
+        for filename in self.source_datapipe:
+            promise_filename = _find_promise_file(filename)
+            if not first_entry:
+                buffer.append(old_filename)
+                if old_promise_filename != promise_filename:
+                    self._del_promise_file(old_promise_filename, old_filename)
+                    yield from buffer
+                    buffer = []
+            old_promise_filename = promise_filename
+            old_filename = filename
+            first_entry = False
+        if not first_entry:
+            buffer.append(old_filename)
+            self._del_promise_file(old_promise_filename, old_filename)
+            yield from buffer
+
+
 @functional_datapipe("end_caching")
 class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
     """
```
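`_find_promise_file` covers the 1-to-n case raised in the review thread above: an extracted file usually has no promise of its own, so the helper walks up the directory tree until it finds the pending promise of a parent (the archive), falling back to the file's own promise name once it reaches the filesystem root. An illustrative trace, assuming hypothetical paths where only `/cache/archive.promise` exists:

```python
_find_promise_file("/cache/archive/sub/file.txt")
# checks "/cache/archive/sub/file.txt.promise"  -> missing, walk up
# checks "/cache/archive/sub.promise"           -> missing, walk up
# checks "/cache/archive.promise"               -> exists, returned
```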
```diff
@@ -248,6 +359,7 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
         same_filepath_fn: Set to ``True`` to use same ``filepath_fn`` from the ``OnDiskCacheHolder``.
         skip_read: Boolean value to skip reading the file handle from ``datapipe``.
             By default, reading is enabled and reading function is created based on the ``mode``.
+        timeout: Integer value of seconds to wait for uncached item to be written to disk
 
     Example:
         >>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
@@ -259,10 +371,10 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
         >>> # You must call ``.on_disk_cache`` at some point before ``.end_caching``
         >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
         >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
-        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
+        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
     """
 
-    def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False):
+    def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False, timeout=300):
         if filepath_fn is not None and same_filepath_fn:
             raise ValueError("`filepath_fn` is mutually exclusive with `same_filepath_fn`")
```
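With the new `timeout` keyword (default: 300 seconds), a caller that expects slow producers can wait longer for another process to finish writing a cache entry; for example, extending the docstring's pipeline:

```python
>>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn, timeout=600)
```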
```diff
@@ -276,6 +388,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals
 
         _filepath_fn, _hash_dict, _hash_type, _ = OnDiskCacheHolderIterDataPipe._temp_dict[cache_holder]
         cached_dp = cache_holder._end_caching()
+        cached_dp = cached_dp.map(functools.partial(_wait_promise_fn, timeout))
         cached_dp = FileLister(cached_dp, recursive=True)
 
         if same_filepath_fn:
@@ -297,6 +410,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals
             todo_dp = todo_dp.check_hash(_hash_dict, _hash_type)
 
         todo_dp = todo_dp.save_to_disk(mode=mode)
+        todo_dp = _FulfilledPromisesIterDataPipe(todo_dp)
 
         return cached_dp.concat(todo_dp)
```
Review comments:

Reviewer: Just want to verify, `len(result)` should be 1, right?

Author: Nope. I'm creating one additional file inside of `_slow_fn`, so it would be the 'downloaded' file and the 'pid' file.
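For context, a sketch of what such a test helper could look like; this is hypothetical (the PR's actual `_slow_fn` may differ) but matches the description: it drops a pid-named marker file before returning, so the cache directory ends up with both the 'downloaded' file and the 'pid' file, hence `len(result) == 2`:

```python
import os
import time


def _slow_fn(tmpdirname, x):
    # Record which process handled this item (the extra "pid" file).
    with open(os.path.join(tmpdirname, str(os.getpid())), "w") as pid_fh:
        pid_fh.write("anything")
    time.sleep(2)  # keep the promise pending long enough for workers to overlap
    return x  # the "downloaded" payload (arbitrary here)
```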