Adding lock mechanism to prevent on_disk_cache downloading twice #409

Closed
Commits (showing changes from 25 of 32 commits)
03511ed
Adding lock mechanism to prevent on_disk_cache downloading twice
VitalyFedyunin May 16, 2022
10ef06b
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 16, 2022
b914f1d
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 16, 2022
c9aa2c5
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 16, 2022
911afe8
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 16, 2022
1ca1931
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 16, 2022
7cbc3ff
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 16, 2022
1e588b8
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 17, 2022
25df48b
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 17, 2022
bab6bff
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 17, 2022
3fc00e6
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 17, 2022
2e09f36
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 17, 2022
76c230d
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 17, 2022
9e216cc
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 17, 2022
5d8565a
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 17, 2022
7b80773
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
988cd3c
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
b8f619f
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
0a80ea4
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
c6a06d2
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
a604884
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
9002589
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
b115a71
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
755e841
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
b02f56f
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
f4c18b6
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
748d4fc
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
ffa0fa3
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
58c25aa
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
ae05b84
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
3a9fb53
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
070e292
Update on "Adding lock mechanism to prevent on_disk_cache downloading…
VitalyFedyunin May 18, 2022
Files changed
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,2 +1,3 @@
urllib3 >= 1.25
requests
portalocker >= 2.0.0
1 change: 1 addition & 0 deletions setup.py
@@ -109,6 +109,7 @@ def _export_version(version, sha):
"urllib3 >= 1.25",
"requests",
pytorch_package_dep,
"portalocker >= 2.0.0",
]


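Both dependency changes above add portalocker, the cross-platform file-locking package the new caching code builds on. A minimal sketch of the primitive it provides (the filename and marker string below are invented for illustration):

    import portalocker

    # Exclusive, cross-process lock held for the duration of the with-block;
    # a second process entering this blocks until the first one releases it.
    with portalocker.Lock("shared_state.txt", "a+", flags=portalocker.LockFlags.EXCLUSIVE) as fh:
        fh.seek(0)
        if not fh.read():  # first process to get here wins
            fh.write("claimed")
            fh.flush()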
40 changes: 40 additions & 0 deletions test/test_local_io.py
@@ -5,13 +5,16 @@
# LICENSE file in the root directory of this source tree.

import bz2
import functools
import hashlib
import io
import itertools
import lzma
import os
import subprocess
import tarfile
import tempfile
import time
import unittest
import warnings
import zipfile
@@ -21,6 +24,8 @@
import expecttest

from _utils._common_utils_for_test import create_temp_dir, create_temp_files, get_name, reset_after_n_next_calls

from torch.utils.data import DataLoader
from torchdata.datapipes.iter import (
Bz2FileLoader,
CSVDictParser,
@@ -33,9 +38,11 @@
IoPathFileOpener,
IoPathSaver,
IterableWrapper,
IterDataPipe,
JsonParser,
RarArchiveLoader,
Saver,
StreamReader,
TarArchiveLoader,
WebDataset,
XzFileLoader,
@@ -64,6 +71,14 @@
skipIfNoRarTools = unittest.skipIf(not HAS_RAR_TOOLS, "no rar tools")


def _unbatch(x):
return x[0]


def _noop(x):
return x


class TestDataPipeLocalIO(expecttest.TestCase):
def setUp(self):
self.temp_dir = create_temp_dir()
@@ -590,6 +605,31 @@ def filepath_fn(name: str) -> str:
saver_dp = source_dp.save_to_disk(filepath_fn=filepath_fn, mode="wb")
list(saver_dp)

@staticmethod
def _slow_fn(tmpdirname, x):
with open(os.path.join(tmpdirname, str(os.getpid())), "w") as pid_fh:
pid_fh.write("anything")
time.sleep(2)
return (x, "str")

def test_disk_cache_locks(self):
with tempfile.TemporaryDirectory() as tmpdirname:
file_name = os.path.join(tmpdirname, "test.bin")
dp = IterableWrapper([file_name])
dp = dp.on_disk_cache(filepath_fn=_noop)
dp = dp.map(functools.partial(self._slow_fn, tmpdirname))
dp = dp.end_caching(mode="t", filepath_fn=_noop, timeout=120)
dp = FileOpener(dp)
dp = StreamReader(dp)
dl = DataLoader(dp, num_workers=10, multiprocessing_context="spawn", batch_size=1, collate_fn=_unbatch)
result = list(dl)
all_files = []
for (_, _, filenames) in os.walk(tmpdirname):
all_files += filenames
# We expect only two files: the pid file and the 'downloaded' one
self.assertEqual(2, len(all_files))
self.assertEqual("str", result[0][1])
Comment on lines +639 to +640

Contributor: Just want to verify, len(result) should be 1, right?

Contributor Author: Nope. I'm creating one additional file inside of _slow_fn, so it would be the 'downloaded' file and the 'pid' file.

# TODO(120): this test currently only covers reading from local
# filesystem. It needs to be modified once test data can be stored on
# gdrive/s3/onedrive
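The test above drives the lock through ten spawned DataLoader workers. For context, a user-facing pipeline that benefits from the same mechanism would look roughly like this sketch (the URL, cache directory, and 120-second timeout are invented for illustration):

    import os
    from torchdata.datapipes.iter import HttpReader, IterableWrapper

    def _filepath_fn(url):
        return os.path.join("/tmp/cache", os.path.basename(url))  # hypothetical cache layout

    url_dp = IterableWrapper(["https://example.com/archive.tar.gz"])  # made-up URL
    cache_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn)
    # With this PR, concurrent workers coordinate through <target>.promise files, and
    # `timeout` bounds how long a worker waits for another process to finish the write.
    cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn, timeout=120)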
17 changes: 13 additions & 4 deletions test/test_remote_io.py
@@ -13,7 +13,8 @@

import torchdata

from _utils._common_utils_for_test import check_hash_fn, create_temp_dir
from _utils._common_utils_for_test import check_hash_fn, create_temp_dir, IS_WINDOWS
from torch.utils.data import DataLoader

from torchdata.datapipes.iter import (
EndOnDiskCacheHolder,
@@ -143,8 +144,9 @@ def _read_and_decode(x):

cached_it = iter(file_cache_dp)
for expected_csv_path in _gen_filepath_fn(expected_file_name):
# File doesn't exist on disk
self.assertFalse(os.path.exists(expected_csv_path))

# Check disabled due to prefetching inside of on_disk_cache
# self.assertFalse(os.path.exists(expected_csv_path))

csv_path = next(cached_it)

@@ -167,15 +169,22 @@ def _read_and_decode(x):
cached_it = iter(file_cache_dp)
for i in range(3):
expected_csv_path = os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")

# File doesn't exist on disk
self.assertFalse(os.path.exists(expected_csv_path))
# Check disabled due to prefetching inside of on_disk_cache
# self.assertFalse(os.path.exists(expected_csv_path))

csv_path = next(cached_it)

# File is cached to disk
self.assertTrue(os.path.exists(expected_csv_path))
self.assertEqual(expected_csv_path, csv_path)

if not IS_WINDOWS:
dl = DataLoader(file_cache_dp, num_workers=3, multiprocessing_context="fork", batch_size=1)
expected = [[os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")] for i in range(3)] * 3
self.assertEqual(sorted(expected), sorted(list(dl)))

def test_s3_io_iterdatapipe(self):
# sanity test
file_urls = ["s3://ai2-public-datasets"]
4 changes: 3 additions & 1 deletion torchdata/dataloader2/dataloader2.py
@@ -167,7 +167,9 @@ def load_state_dict(self, state: Dict[str, Any]) -> None:
# edge case checking
# iterator has already been created: 1) iterator is just created 2) iterator is created and iter is exhausted
if self._datapipe_iter is not None:
raise RuntimeError("DataLoaderV2 iterator has already been created, `load_state_dict()` can’t be called. Please create a new dataloader in order to use load state dict.")
raise RuntimeError(
"DataLoaderV2 iterator has already been created, `load_state_dict()` can’t be called. Please create a new dataloader in order to use load state dict."
)

serialized_datapipe = state[SERIALIZED_DATAPIPE_KEY_NAME]
reading_service_state = state[READING_SERVICE_STATE_KEY_NAME]
152 changes: 138 additions & 14 deletions torchdata/datapipes/iter/util/cacheholder.py
@@ -4,14 +4,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import hashlib
import inspect
import os.path
import sys
import time
import warnings

from collections import deque
from functools import partial
from typing import Callable, Deque, Dict, Iterator, Optional, TypeVar
from typing import Any, Callable, Deque, Dict, Iterator, List, Optional, TypeVar

import portalocker

from torch.utils.data.datapipes.utils.common import _check_lambda_fn, DILL_AVAILABLE

@@ -26,6 +31,9 @@

T_co = TypeVar("T_co", covariant=True)

PROMISE_FILE_DELETE_TIMEOUT = 30
PROMISE_FILE_DELETE_RETRY_INTERVAL = 0.005


@functional_datapipe("in_memory_cache")
class InMemoryCacheHolderIterDataPipe(IterDataPipe[T_co]):
@@ -106,7 +114,7 @@ def _hash_check(filepath, hash_dict, hash_type):
else:
hash_func = hashlib.md5()

with open(filepath, "rb") as f:
with portalocker.Lock(filepath, "rb", flags=portalocker.LockFlags.EXCLUSIVE) as f:
chunk = f.read(1024 ** 2)
while chunk:
hash_func.update(chunk)
@@ -145,7 +153,7 @@ class OnDiskCacheHolderIterDataPipe(IterDataPipe):
>>> hash_dict = {"expected_filepath": expected_MD5_hash}
>>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
>>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
>>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
>>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
"""

_temp_dict: Dict = {}
@@ -184,22 +192,42 @@ def __add__(self, other_datapipe):
@staticmethod
def _cache_check_fn(data, filepath_fn, hash_dict, hash_type, extra_check_fn):
filepaths = data if filepath_fn is None else filepath_fn(data)
result = True
if not isinstance(filepaths, (list, tuple)):
filepaths = [
filepaths,
]

for filepath in filepaths:
cached_file_exists = True
if not os.path.exists(filepath):
return False

if hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
return False

if extra_check_fn is not None and not extra_check_fn(filepath):
return False

return True
cached_file_exists = False
elif hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
cached_file_exists = False
elif extra_check_fn is not None and not extra_check_fn(filepath):
cached_file_exists = False

if not cached_file_exists:
promise_filepath = filepath + ".promise"
dirname = os.path.dirname(promise_filepath)
if not os.path.exists(dirname):
os.makedirs(dirname)

with portalocker.Lock(promise_filepath, "a+", flags=portalocker.LockFlags.EXCLUSIVE) as promise_fh:
promise_fh.seek(0)
data = promise_fh.read()
# TODO(VitalyFedyunin): Potentially there is an old .promise file from a previous failed run; we
# need to somehow propagate a unique session id for the dataloader, save it, and compare it here,
# raising an error
file_exists = len(data) > 0
if not file_exists:
result = False
promise_fh.seek(0)
promise_fh.write("[dataloader session uid]")
promise_fh.truncate()
promise_fh.flush()

return result
Comment on lines +229 to +237

Contributor: When the cached op is 1-to-n, like decompression from an archive, if any decompressed file is missing or has an incorrect hash, we can directly return False and there is no need to check the other files IMHO.
There is a chance that multiple processes lock different decompressed files of the same archive. Then both processes will run the decompression -> race condition again.

So, I think we should lock over data rather than filepaths (data represents the compressed archive in this case). A process that observes a promise file over data can directly return True.

WDYT?

Contributor Author: Unfortunately, data could be a URL or something else, so it is hard to lock on it.

But this situation is covered.

Imagine data generates two file names: file1 and file2.

The initial pass (empty FS) will add two locks, file1.promise and file2.promise, and will go down the 'False' route.

Now the second (and every subsequent) pass will see that the files are missing but will fail to create the promises and go down the 'file exists' route, which leads them to wait for file1.promise and file2.promise to disappear.

Contributor (@NivekT, May 18, 2022): If it is a URL, is it possible to create and lock root/URL.promise in the file system?

I think we should have a similar lock for HttpReader to prevent multiple processes from downloading the same file? Nevermind

Contributor: Oh, I see. The file_exists flag is used for processes to recognize that this file or its parent archive is going to be processed by another process.
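The handshake described above can be condensed into a self-contained sketch (illustration only, not the PR's exact code; the helper name and return convention are simplified): the first process to take the exclusive lock writes a marker into the .promise file and becomes the producer; every later process finds the marker and instead waits for the .promise file to disappear.

    import os
    import portalocker

    def claim_or_wait(filepath):
        # True: this process should produce `filepath`.
        # False: another process already promised to; wait for filepath + ".promise" to vanish.
        promise = filepath + ".promise"
        os.makedirs(os.path.dirname(promise) or ".", exist_ok=True)
        with portalocker.Lock(promise, "a+", flags=portalocker.LockFlags.EXCLUSIVE) as fh:
            fh.seek(0)
            if fh.read():  # marker already written by another process
                return False
            fh.write("[dataloader session uid]")  # claim the work
            fh.flush()
            return True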


def _end_caching(self):
filepath_fn, hash_dict, hash_type, extra_check_fn = OnDiskCacheHolderIterDataPipe._temp_dict.pop(self)
@@ -232,6 +260,99 @@ def _read_str(fd):
return "".join(fd)


def _find_promise_file(filename):
promise_filename = filename + ".promise"
while not os.path.exists(promise_filename):
dirname = os.path.dirname(promise_filename)
if dirname == os.path.dirname(dirname):
promise_filename = filename + ".promise"
break
promise_filename = dirname + ".promise"
return promise_filename
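For orientation, the walk-up resolves like this for a hypothetical extracted-archive layout (paths invented for illustration):

    # root/archive.promise         <- exists while the archive is still being produced
    # root/archive/member_0.txt    <- filename passed in
    #
    # _find_promise_file("root/archive/member_0.txt") -> "root/archive.promise"
    #
    # If no promise exists anywhere up the tree, the helper falls back to the
    # file's own path, "root/archive/member_0.txt.promise".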


def _is_promise_pending(promise_filename):
return os.path.exists(promise_filename)
# try:
# with portalocker.Lock(promise_filename, "r") as promise_fh:
# data = promise_fh.read()
# file_exists = len(data) > 0
# except FileNotFoundError:
# return False
# except PermissionError:
# return True
# return file_exists


def _wait_promise_fn(timeout, filename):
promise_filename = _find_promise_file(filename)
start = time.time()
while _is_promise_pending(promise_filename):
time.sleep(0.01)
if time.time() - start > timeout:
raise Exception(
f"OnDiskCache Exception: {filename} expected to be written by different process, "
+ f"but file is not ready in {timeout} seconds."
)
return filename


class _FulfilledPromisesIterDataPipe(IterDataPipe):
def __init__(self, source_datapipe):
self.source_datapipe = source_datapipe

@staticmethod
def _del_promise_file(promise_filename, filename):
if os.path.exists(promise_filename):
retry = True
start = time.time()
while retry:
retry = False
try:
# print()
# print()
os.unlink(promise_filename)
# except:
except (PermissionError, Exception) as e:
# Workaround for Windows not letting us delete a file while it is open by another process
retry = True
if time.time() - start > PROMISE_FILE_DELETE_TIMEOUT:
# raise Exception("Timeout while trying to recover from the ", type(e), e)
raise Exception("Timeout while trying to recover from the exception ", type(e))
time.sleep(PROMISE_FILE_DELETE_RETRY_INTERVAL)
# except Exception as e:
# raise Exception("Something else happened while trying to delete promise file ", type(e), e)
# except:
# raise Exception("Unclassified situation")
else:
warnings.warn(
f"Attempt to mark {promise_filename} promise (base of file {filename}) as fulfilled failed. Potentially missmatching filename functions of on_disk_cache and end_cache."
)

def __iter__(self):
old_promise_filename = None
old_filename = None
first_entry = True
# TODO(VitalyFedyunin): Limit buffer size here. It only contains file names from the archive,
# but better safe than sorry.
buffer: List[Any] = []
for filename in self.source_datapipe:
promise_filename = _find_promise_file(filename)
if not first_entry:
buffer.append(old_filename)
if old_promise_filename != promise_filename:
self._del_promise_file(old_promise_filename, old_filename)
yield from buffer
buffer = []
old_promise_filename = promise_filename
old_filename = filename
first_entry = False
if not first_entry:
buffer.append(old_filename)
self._del_promise_file(old_promise_filename, old_filename)
yield from buffer
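A short trace of the buffering above, with invented filenames, shows the intent: downstream consumers never see a path whose covering promise is still pending, because files are held back until their promise can be deleted.

    # input:    root/a/0.txt, root/a/1.txt, root/b/0.txt
    # promises: root/a.promise (covers root/a/*), root/b.promise (covers root/b/*)
    #
    # root/a/0.txt -> remembered, nothing yielded yet
    # root/a/1.txt -> same promise, so a/0.txt is buffered
    # root/b/0.txt -> promise changed: delete root/a.promise, yield a/0.txt and a/1.txt
    # end of input -> delete root/b.promise, yield b/0.txt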


@functional_datapipe("end_caching")
class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
"""
@@ -248,6 +369,7 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
same_filepath_fn: Set to ``True`` to use same ``filepath_fn`` from the ``OnDiskCacheHolder``.
skip_read: Boolean value to skip reading the file handle from ``datapipe``.
By default, reading is enabled and reading function is created based on the ``mode``.
timeout: Integer number of seconds to wait for the uncached item to be written to disk.

Example:
>>> from torchdata.datapipes.iter import IterableWrapper, HttpReader
@@ -259,10 +381,10 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
>>> # You must call ``.on_disk_cache`` at some point before ``.end_caching``
>>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
>>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
>>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
>>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
"""

def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False):
def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False, timeout=300):
if filepath_fn is not None and same_filepath_fn:
raise ValueError("`filepath_fn` is mutually exclusive with `same_filepath_fn`")

@@ -276,6 +398,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals

_filepath_fn, _hash_dict, _hash_type, _ = OnDiskCacheHolderIterDataPipe._temp_dict[cache_holder]
cached_dp = cache_holder._end_caching()
cached_dp = cached_dp.map(functools.partial(_wait_promise_fn, timeout))
cached_dp = FileLister(cached_dp, recursive=True)

if same_filepath_fn:
@@ -297,6 +420,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals
todo_dp = todo_dp.check_hash(_hash_dict, _hash_type)

todo_dp = todo_dp.save_to_disk(mode=mode)
todo_dp = _FulfilledPromisesIterDataPipe(todo_dp)

return cached_dp.concat(todo_dp)
