Add a lock mechanism to prevent on_disk_cache from downloading twice
ghstack-source-id: dd94ca8f4abac335e1c7ab72fa512660c333c7e8
Pull Request resolved: #409
VitalyFedyunin committed May 16, 2022
1 parent 3c77696 commit e21c304
Showing 4 changed files with 69 additions and 11 deletions.
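
Before the per-file diffs, here is a minimal standalone sketch (not part of the commit; `fetch_cached` and `download_fn` are hypothetical names) of the promise-file pattern the change introduces: the first worker to miss the cache publishes a `<file>.promise` marker and does the slow download, while every other worker waits for the marker to disappear instead of downloading again. Note the check-then-create window between `os.path.exists` and taking the lock also exists in the committed `_cache_check_fn`; the lock only serializes writers of the marker file itself.

import os
import time

import portalocker  # the same cross-platform file-locking package the commit adopts


def fetch_cached(filepath, download_fn, poll_interval=0.01):
    # Hypothetical helper, not part of the commit: fetch `filepath` once
    # even when many processes ask for it concurrently.
    promise = filepath + '.promise'
    if not os.path.exists(filepath) and not os.path.exists(promise):
        # First worker: publish the promise, do the slow work, then fulfill it.
        with portalocker.Lock(promise, 'w') as fh:
            fh.write('!')
        try:
            download_fn(filepath)
        finally:
            os.unlink(promise)  # fulfilling the promise unblocks the waiters
    else:
        # Late workers: poll until the promise file is gone, mirroring _wait_promise_fn.
        while os.path.exists(promise):
            time.sleep(poll_interval)
    return filepath
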
34 changes: 34 additions & 0 deletions test/test_local_io.py
@@ -12,6 +12,8 @@
import os
import subprocess
import tarfile
+import tempfile
+import time
import unittest
import warnings
import zipfile
@@ -33,15 +35,19 @@
    IoPathFileOpener,
    IoPathSaver,
    IterableWrapper,
+    IterDataPipe,
    JsonParser,
    RarArchiveLoader,
    Saver,
    StreamReader,
    TarArchiveLoader,
    WebDataset,
    XzFileLoader,
    ZipArchiveLoader,
)

+from torch.utils.data import DataLoader
+
try:
    import iopath

@@ -64,6 +70,10 @@
skipIfNoRarTools = unittest.skipIf(not HAS_RAR_TOOLS, "no rar tools")


+def _unbatch(x):
+    return x[0]
+
+
class TestDataPipeLocalIO(expecttest.TestCase):
    def setUp(self):
        self.temp_dir = create_temp_dir()
@@ -590,6 +600,30 @@ def filepath_fn(name: str) -> str:
        saver_dp = source_dp.save_to_disk(filepath_fn=filepath_fn, mode="wb")
        list(saver_dp)

+    def test_disk_cache_locks(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            file_name = os.path.join(tmpdirname, 'test.bin')
+            dp = IterableWrapper([file_name])
+            dp = dp.on_disk_cache(filepath_fn=lambda x: x)
+
+            def _slow_fn(x):
+                with open(os.path.join(tmpdirname, str(os.getpid())), 'w') as pid_fh:
+                    pid_fh.write('anything')
+                time.sleep(2)
+                return (x, 'str')
+            dp = dp.map(_slow_fn)
+            dp = dp.end_caching(mode="t", filepath_fn=lambda x: x)
+            dp = FileOpener(dp)
+            dp = StreamReader(dp)
+            dl = DataLoader(dp, num_workers=10, multiprocessing_context="spawn", batch_size=1, collate_fn=_unbatch)
+            result = list(dl)
+            all_files = []
+            for (_, _, filenames) in os.walk(tmpdirname):
+                all_files += filenames
+            # We expect exactly two files: the per-pid marker and the 'downloaded' one
+            self.assertEqual(2, len(all_files))
+            self.assertEqual('str', result[0][1])

    # TODO(120): this test currently only covers reading from local
    # filesystem. It needs to be modified once test data can be stored on
    # gdrive/s3/onedrive
Empty file added todo.py
43 changes: 33 additions & 10 deletions torchdata/datapipes/iter/util/cacheholder.py
@@ -7,7 +7,9 @@
import hashlib
import inspect
import os.path
+import portalocker
import sys
+import time

from collections import deque
from functools import partial
@@ -106,7 +108,7 @@ def _hash_check(filepath, hash_dict, hash_type):
    else:
        hash_func = hashlib.md5()

-    with open(filepath, "rb") as f:
+    with portalocker.Lock(filepath, "rb") as f:
        chunk = f.read(1024 ** 2)
        while chunk:
            hash_func.update(chunk)
@@ -145,7 +147,7 @@ class OnDiskCacheHolderIterDataPipe(IterDataPipe):
        >>> hash_dict = {"expected_filepath": expected_MD5_hash}
        >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
        >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
-        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
+        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
    """

    _temp_dict: Dict = {}
@@ -184,22 +186,31 @@ def __add__(self, other_datapipe):
    @staticmethod
    def _cache_check_fn(data, filepath_fn, hash_dict, hash_type, extra_check_fn):
        filepaths = data if filepath_fn is None else filepath_fn(data)
+        result = True
        if not isinstance(filepaths, (list, tuple)):
            filepaths = [
                filepaths,
            ]

        for filepath in filepaths:
-            if not os.path.exists(filepath):
-                return False
+            promise_filepath = filepath + '.promise'
+            if not os.path.exists(promise_filepath):
+                if not os.path.exists(filepath):
+                    with portalocker.Lock(promise_filepath, 'w') as fh:
+                        fh.write('!')
+                    result = False

-            if hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
-                return False
+                elif hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
+                    with portalocker.Lock(promise_filepath, 'w') as fh:
+                        fh.write('!')
+                    result = False

-            if extra_check_fn is not None and not extra_check_fn(filepath):
-                return False
+                elif extra_check_fn is not None and not extra_check_fn(filepath):
+                    with portalocker.Lock(promise_filepath, 'w') as fh:
+                        fh.write('!')
+                    result = False

-        return True
+        return result

    def _end_caching(self):
        filepath_fn, hash_dict, hash_type, extra_check_fn = OnDiskCacheHolderIterDataPipe._temp_dict.pop(self)
@@ -231,6 +242,16 @@ def _read_bytes(fd):
def _read_str(fd):
    return "".join(fd)

+def _wait_promise_fn(filename):
+    promise_filename = filename + '.promise'
+    while os.path.exists(promise_filename):
+        time.sleep(0.01)
+    return filename
+
+def _promise_fulfilled_fn(filename):
+    promise_filename = filename + '.promise'
+    os.unlink(promise_filename)
+    return filename

@functional_datapipe("end_caching")
class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
@@ -259,7 +280,7 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
        >>> # You must call ``.on_disk_cache`` at some point before ``.end_caching``
        >>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
        >>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
-        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
+        >>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
    """

    def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False):
@@ -276,6 +297,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals

        _filepath_fn, _hash_dict, _hash_type, _ = OnDiskCacheHolderIterDataPipe._temp_dict[cache_holder]
        cached_dp = cache_holder._end_caching()
+        cached_dp = cached_dp.map(_wait_promise_fn)
        cached_dp = FileLister(cached_dp, recursive=True)

        if same_filepath_fn:
@@ -297,6 +319,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals
            todo_dp = todo_dp.check_hash(_hash_dict, _hash_type)

        todo_dp = todo_dp.save_to_disk(mode=mode)
+        todo_dp = todo_dp.map(_promise_fulfilled_fn)

        return cached_dp.concat(todo_dp)

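Taken together, the two halves of end_caching now cooperate through the marker file: items that pass `_cache_check_fn` travel the `cached_dp` branch and block inside `_wait_promise_fn` until any in-flight download unlinks its `.promise`, while cache misses travel `todo_dp` and clear their marker via `_promise_fulfilled_fn` right after `save_to_disk`. A hedged sketch of the user-facing pipeline this protects, modeled on the docstring example above (the URL and `_filepath_fn` are illustrative only):

import os

from torchdata.datapipes.iter import HttpReader, IterableWrapper


def _filepath_fn(url):
    # Illustrative cache layout: one file per URL under ./cache
    return os.path.join("cache", os.path.basename(url))


url_dp = IterableWrapper(["https://example.com/data.bin"])  # placeholder URL
cache_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn)
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
# Iterating this pipeline from several DataLoader workers now downloads
# data.bin once; the other workers wait on cache/data.bin.promise instead.
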
3 changes: 2 additions & 1 deletion torchdata/datapipes/iter/util/saver.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
+import portalocker

from typing import Any, Callable, Iterator, Optional, Tuple, Union

@@ -56,7 +57,7 @@ def __iter__(self) -> Iterator[str]:
            dirname = os.path.dirname(filepath)
            if not os.path.exists(dirname):
                os.makedirs(dirname)
-            with open(filepath, self.mode) as f:
+            with portalocker.Lock(filepath, self.mode) as f:
                f.write(data)
            yield filepath

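For reference, `portalocker.Lock` is a context manager that opens the file and holds an exclusive OS-level lock for the duration of the block, retrying for a default timeout if another process already holds it, so concurrent savers cannot interleave writes to the same target. A minimal sketch of the call pattern adopted above (the file name is illustrative):

import portalocker

# Waits (up to portalocker's default timeout) for other holders to release,
# then opens the file; the lock is dropped and the handle closed on exit.
with portalocker.Lock("out.bin", "wb") as f:
    f.write(b"payload")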
