Update on "Adding lock mechanism to prevent on_disk_cache downloading…
Browse files Browse the repository at this point in the history
… twice"


Fixes #144

[ghstack-poisoned]
VitalyFedyunin committed May 17, 2022
1 parent 7cbc3ff commit 1e588b8
Showing 2 changed files with 93 additions and 32 deletions.
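
For orientation, a minimal sketch of the usual on_disk_cache / end_caching pipeline that the new lock protects; the URL and cache path are placeholders, not part of this commit:

import os

from torchdata.datapipes.iter import HttpReader, IterableWrapper

url_dp = IterableWrapper(["https://example.com/data.csv"])
cache_dp = url_dp.on_disk_cache(filepath_fn=lambda url: os.path.join("/tmp/cache", os.path.basename(url)))
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", same_filepath_fn=True)

Before this change, several DataLoader workers iterating cache_dp could all see a cache miss and download the same file; the .promise lock introduced below makes one process download while the others wait for the finished file.
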
14 changes: 11 additions & 3 deletions test/test_remote_io.py
@@ -14,6 +14,7 @@
 import torchdata

 from _utils._common_utils_for_test import check_hash_fn, create_temp_dir
+from torch.utils.data import DataLoader

 from torchdata.datapipes.iter import (
     EndOnDiskCacheHolder,
@@ -143,8 +144,9 @@ def _read_and_decode(x):

         cached_it = iter(file_cache_dp)
         for expected_csv_path in _gen_filepath_fn(expected_file_name):
-            # File doesn't exist on disk
-            self.assertFalse(os.path.exists(expected_csv_path))
+            # Check disabled due to some elements of prefetching inside of on_disk_cache
+            # self.assertFalse(os.path.exists(expected_csv_path))

             csv_path = next(cached_it)

@@ -167,15 +169,21 @@ def _read_and_decode(x):
         cached_it = iter(file_cache_dp)
         for i in range(3):
             expected_csv_path = os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")

-            # File doesn't exist on disk
-            self.assertFalse(os.path.exists(expected_csv_path))
+            # Check disabled due to some elements of prefetching inside of on_disk_cache
+            # self.assertFalse(os.path.exists(expected_csv_path))

             csv_path = next(cached_it)

             # File is cached to disk
             self.assertTrue(os.path.exists(expected_csv_path))
             self.assertEqual(expected_csv_path, csv_path)

+        dl = DataLoader(file_cache_dp, num_workers=3, multiprocessing_context="fork", batch_size=1)
+        expected = [[os.path.join(self.temp_dir.name, root_dir, f"{i}.csv")] for i in range(3)] * 3
+        self.assertEqual(sorted(expected), sorted(list(dl)))
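
A note on the new assertion: every DataLoader worker runs the whole datapipe, so with num_workers=3 each of the three cached files is yielded once per worker, nine single-element batches in total. The arithmetic as a standalone sketch (paths illustrative):

# 3 files x 3 workers -> 9 batches of one path each
paths = [f"{i}.csv" for i in range(3)]
expected = [[p] for p in paths] * 3
assert len(expected) == 9

The lock added in cacheholder.py is what keeps those three workers from downloading each file three times.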

     def test_s3_io_iterdatapipe(self):
         # sanity test
         file_urls = ["s3://ai2-public-datasets"]
111 changes: 82 additions & 29 deletions torchdata/datapipes/iter/util/cacheholder.py
@@ -196,25 +196,33 @@ def _cache_check_fn(data, filepath_fn, hash_dict, hash_type, extra_check_fn):
     ]

     for filepath in filepaths:
-        create_promise = False
-        promise_filepath = filepath + ".promise"
-        if not os.path.exists(promise_filepath):
-            if not os.path.exists(filepath):
-                create_promise = True
-                result = False
-            elif hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
-                create_promise = True
-                result = False
-            elif extra_check_fn is not None and not extra_check_fn(filepath):
-                create_promise = True
-                result = False
-
-            if create_promise:
+        cached_file_exists = True
+        if not os.path.exists(filepath):
+            cached_file_exists = False
+        elif hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
+            cached_file_exists = False
+        elif extra_check_fn is not None and not extra_check_fn(filepath):
+            cached_file_exists = False
+
+        if not cached_file_exists:
+            promise_filepath = filepath + ".promise"
             dirname = os.path.dirname(promise_filepath)
             if not os.path.exists(dirname):
                 os.makedirs(dirname)
-            with portalocker.Lock(promise_filepath, "w") as fh:
-                fh.write("!")
+
+            with portalocker.Lock(promise_filepath, "a+", flags=portalocker.LockFlags.EXCLUSIVE) as promise_fh:
+                promise_fh.seek(0)
+                data = promise_fh.read()
+                # TODO(VitalyFedyunin): Potentially there is an old .promise file from a previous failed run; we
+                # need to somehow propagate a unique session id for the dataloader, save and compare it here,
+                # raising an error
+                file_exists = len(data) > 0
+                if not file_exists:
+                    result = False
+                    promise_fh.seek(0)
+                    promise_fh.write("[dataloader session uid]")
+                    promise_fh.truncate()
+                    promise_fh.flush()

     return result
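
The "a+" mode plus the exclusive portalocker lock serializes the race: every process that misses the cache opens the same promise file, but only the first reads it as empty and claims the download. A minimal sketch of that claim step, assuming the same portalocker API as the diff; try_claim is an illustrative name, not part of the commit:

import portalocker


def try_claim(filepath):
    # Returns True if the caller won the right to download `filepath`.
    promise = filepath + ".promise"
    with portalocker.Lock(promise, "a+", flags=portalocker.LockFlags.EXCLUSIVE) as fh:
        fh.seek(0)
        if fh.read():        # non-empty: another process already promised
            return False
        fh.write("claimed")  # mark the promise as taken
        return True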

@@ -249,28 +257,73 @@ def _read_str(fd):
return "".join(fd)


def _wait_promise_fn(timeout, filename):
def _find_promise_file(filename):
promise_filename = filename + ".promise"
while not os.path.exists(promise_filename):
dirname = os.path.dirname(promise_filename)
if dirname == os.path.dirname(dirname):
promise_filename = filename + ".promise"
break
promise_filename = dirname + ".promise"
return promise_filename
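
_find_promise_file walks toward the filesystem root because a single promise can cover many output files (for example, an archive expanded into a directory). If the walk reaches the root without a hit, it falls back to the file's own promise name. A dry run with illustrative paths:

import os
import tempfile

root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "archive"))
open(os.path.join(root, "archive.promise"), "w").close()

member = os.path.join(root, "archive", "part-0.csv")
# _find_promise_file(member) tries, in order:
#   <root>/archive/part-0.csv.promise -> missing
#   <root>/archive.promise            -> exists, so it is returned
# With no promise anywhere along the way, it would return member + ".promise".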


+def _is_promise_pending(promise_filename):
+    try:
+        with portalocker.Lock(promise_filename, "r") as promise_fh:
+            data = promise_fh.read()
+            file_exists = len(data) > 0
+    except FileNotFoundError:
+        return False
+    return file_exists
+
+
+def _wait_promise_fn(timeout, filename):
+    promise_filename = _find_promise_file(filename)
     start = time.time()
-    while os.path.exists(promise_filename):
+    while _is_promise_pending(promise_filename):
         time.sleep(0.01)
         if time.time() - start > timeout:
             raise Exception(
                 f"OnDiskCache Exception: {filename} expected to be written by different process, "
-                + f"but file is ready in {timeout} seconds."
+                + f"but file is not ready in {timeout} seconds."
             )
     return filename


-def _promise_fulfilled_fn(filename):
-    promise_filename = filename + ".promise"
-    if os.path.exists(promise_filename):
-        os.unlink(promise_filename)
-    else:
-        warnings.warn(
-            f"Attempt to mark {promise_filename} promise as fulfilled failed. Potentially mismatching filename functions of on_disk_cache and end_cache."
-        )
-    return filename
+class _FulfilledPromisesIterDataPipe(IterDataPipe):
+    def __init__(self, source_datapipe):
+        self.source_datapipe = source_datapipe
+
+    @staticmethod
+    def _del_promise_file(promise_filename, filename):
+        if os.path.exists(promise_filename):
+            os.unlink(promise_filename)
+        else:
+            warnings.warn(
+                f"Attempt to mark {promise_filename} promise (base of file {filename}) as fulfilled failed. Potentially mismatching filename functions of on_disk_cache and end_cache."
+            )
+
+    def __iter__(self):
+        old_promise_filename = None
+        old_filename = None
+        first_entry = True
+        buffer = []
+        for filename in self.source_datapipe:
+            promise_filename = _find_promise_file(filename)
+            if not first_entry:
+                buffer.append(old_filename)
+                if old_promise_filename != promise_filename:
+                    self._del_promise_file(old_promise_filename, old_filename)
+                    yield from buffer
+                    buffer = []
+            old_promise_filename = promise_filename
+            old_filename = filename
+            first_entry = False
+        if not first_entry:
+            buffer.append(old_filename)
+            self._del_promise_file(old_promise_filename, old_filename)
+            yield from buffer
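
The buffering in __iter__ is the subtle part: filenames are held back until every file covered by the same promise has been written, the promise is deleted, and only then is the group yielded, so a waiting process never observes the promise disappearing while sibling files are still missing. A dry run of just the grouping, with promise_of standing in for _find_promise_file:

def grouped(filenames, promise_of):
    # Mirrors _FulfilledPromisesIterDataPipe.__iter__ without touching disk.
    old_promise, old_name, first, buf = None, None, True, []
    for name in filenames:
        promise = promise_of(name)
        if not first:
            buf.append(old_name)
            if old_promise != promise:
                print("delete", old_promise)
                yield from buf
                buf = []
        old_promise, old_name, first = promise, name, False
    if not first:
        buf.append(old_name)
        print("delete", old_promise)
        yield from buf


out = list(grouped(["a/1.csv", "a/2.csv", "b.csv"], lambda f: f.split("/")[0] + ".promise"))
# Prints "delete a.promise" then "delete b.csv.promise";
# out == ["a/1.csv", "a/2.csv", "b.csv"]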


@functional_datapipe("end_caching")
@@ -340,7 +393,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals
             todo_dp = todo_dp.check_hash(_hash_dict, _hash_type)

         todo_dp = todo_dp.save_to_disk(mode=mode)
-        todo_dp = todo_dp.map(_promise_fulfilled_fn)
+        todo_dp = _FulfilledPromisesIterDataPipe(todo_dp)

         return cached_dp.concat(todo_dp)
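
Schematically, the todo branch after this change:

# cache miss -> filepath_fn map -> (optional hash check) -> save_to_disk
#   -> _FulfilledPromisesIterDataPipe (delete .promise, release waiters)
# Cache hits flow through cached_dp; concat serves both branches.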

