Adding a lock mechanism to prevent on_disk_cache from downloading twice
ghstack-source-id: 2fb276289021ba8d47f60bc7480f6bddcc58b0fd
Pull Request resolved: #409
VitalyFedyunin committed May 16, 2022
1 parent 3c77696 commit 1233eac
Showing 8 changed files with 157 additions and 11 deletions.
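
The locking scheme added below pairs every cache target with a sidecar '<filepath>.promise' marker: when _cache_check_fn finds a file missing (or failing its hash or extra check), it creates the marker under a portalocker lock and sends the element to the download branch, while end_caching maps _wait_promise_fn over the already-cached branch so other workers poll until the marker disappears, and maps _promise_fulfilled_fn after save_to_disk so the marker is removed only once the file is fully written. A minimal standalone sketch of that pattern, using hypothetical names (claim_or_wait, download_fn) that are not part of this commit:

import os
import time

import portalocker  # third-party file-locking package; the new dependency used below


def claim_or_wait(filepath, download_fn, poll_interval=0.01):
    # Sketch of the promise-file pattern: the first process to observe the
    # cache miss creates '<filepath>.promise' and downloads; everyone else
    # polls until the marker disappears and then reuses the cached file.
    promise = filepath + ".promise"
    if not os.path.exists(filepath) and not os.path.exists(promise):
        with portalocker.Lock(promise, "w") as fh:  # claim the download
            fh.write("!")
        try:
            download_fn(filepath)  # e.g. fetch a URL into filepath
        finally:
            os.unlink(promise)  # fulfil the promise for any waiting process
    else:
        while os.path.exists(promise):  # another process owns the download
            time.sleep(poll_interval)
    return filepath

The DataPipe changes in cacheholder.py below implement this same split across the cached and to-be-downloaded branches of end_caching.
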
34 changes: 34 additions & 0 deletions test/test_local_io.py
@@ -12,6 +12,8 @@
import os
import subprocess
import tarfile
import tempfile
import time
import unittest
import warnings
import zipfile
@@ -33,15 +35,19 @@
IoPathFileOpener,
IoPathSaver,
IterableWrapper,
IterDataPipe,
JsonParser,
RarArchiveLoader,
Saver,
StreamReader,
TarArchiveLoader,
WebDataset,
XzFileLoader,
ZipArchiveLoader,
)

from torch.utils.data import DataLoader

try:
import iopath

@@ -64,6 +70,10 @@
skipIfNoRarTools = unittest.skipIf(not HAS_RAR_TOOLS, "no rar tools")


def _unbatch(x):
return x[0]


class TestDataPipeLocalIO(expecttest.TestCase):
def setUp(self):
self.temp_dir = create_temp_dir()
@@ -590,6 +600,30 @@ def filepath_fn(name: str) -> str:
saver_dp = source_dp.save_to_disk(filepath_fn=filepath_fn, mode="wb")
list(saver_dp)

def test_disk_cache_locks(self):
with tempfile.TemporaryDirectory() as tmpdirname:
file_name = os.path.join(tmpdirname, 'test.bin')
dp = IterableWrapper([file_name])
dp = dp.on_disk_cache(filepath_fn=lambda x: x)

def _slow_fn(x):
with open(os.path.join(tmpdirname, str(os.getpid())), 'w') as pid_fh:
pid_fh.write('anything')
time.sleep(2)
return (x, 'str')
dp = dp.map(_slow_fn)
dp = dp.end_caching(mode="t", filepath_fn=lambda x: x)
dp = FileOpener(dp)
dp = StreamReader(dp)
dl = DataLoader(dp, num_workers=10, multiprocessing_context="spawn", batch_size=1, collate_fn=_unbatch)
result = list(dl)
all_files = []
for (_, _, filenames) in os.walk(tmpdirname):
all_files += filenames
# We expect exactly two files: the cached 'test.bin' and a single pid
# marker, i.e. only one of the ten workers actually executed _slow_fn
self.assertEqual(2, len(all_files))
self.assertEqual('str', result[0][1])

# TODO(120): this test currently only covers reading from local
# filesystem. It needs to be modified once test data can be stored on
# gdrive/s3/onedrive
Empty file added test_datapipe.py
Empty file added test_fsspec.py
Empty file added test_remote_io.py
Empty file added todo.py
43 changes: 33 additions & 10 deletions torchdata/datapipes/iter/util/cacheholder.py
@@ -7,7 +7,9 @@
import hashlib
import inspect
import os.path
import portalocker
import sys
import time

from collections import deque
from functools import partial
@@ -106,7 +108,7 @@ def _hash_check(filepath, hash_dict, hash_type):
else:
hash_func = hashlib.md5()

with open(filepath, "rb") as f:
with portalocker.Lock(filepath, "rb") as f:
chunk = f.read(1024 ** 2)
while chunk:
hash_func.update(chunk)
@@ -145,7 +147,7 @@ class OnDiskCacheHolderIterDataPipe(IterDataPipe):
>>> hash_dict = {"expected_filepath": expected_MD5_hash}
>>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
>>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
>>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
>>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
"""

_temp_dict: Dict = {}
@@ -184,22 +186,31 @@ def __add__(self, other_datapipe):
@staticmethod
def _cache_check_fn(data, filepath_fn, hash_dict, hash_type, extra_check_fn):
filepaths = data if filepath_fn is None else filepath_fn(data)
result = True
if not isinstance(filepaths, (list, tuple)):
filepaths = [
filepaths,
]

for filepath in filepaths:
if not os.path.exists(filepath):
return False
promise_filepath = filepath + '.promise'
if not os.path.exists(promise_filepath):
if not os.path.exists(filepath):
with portalocker.Lock(promise_filepath, 'w') as fh:
fh.write('!')
result = False

if hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
return False
elif hash_dict is not None and not _hash_check(filepath, hash_dict, hash_type):
with portalocker.Lock(promise_filepath, 'w') as fh:
fh.write('!')
result = False

if extra_check_fn is not None and not extra_check_fn(filepath):
return False
elif extra_check_fn is not None and not extra_check_fn(filepath):
with portalocker.Lock(promise_filepath, 'w') as fh:
fh.write('!')
result = False

return True
return result

def _end_caching(self):
filepath_fn, hash_dict, hash_type, extra_check_fn = OnDiskCacheHolderIterDataPipe._temp_dict.pop(self)
@@ -231,6 +242,16 @@ def _read_bytes(fd):
def _read_str(fd):
return "".join(fd)

def _wait_promise_fn(filename):
promise_filename = filename + '.promise'
while os.path.exists(promise_filename):
time.sleep(0.01)
return filename

def _promise_fulfilled_fn(filename):
promise_filename = filename + '.promise'
os.unlink(promise_filename)
return filename

@functional_datapipe("end_caching")
class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
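
The _wait_promise_fn / _promise_fulfilled_fn helpers added above are the wait/fulfil pair: the first polls every 10 ms until no '.promise' marker exists for a file, the second unlinks the marker once the freshly written file is complete. A two-process sketch of that hand-off (the multiprocessing driver and the file name are illustrative, not taken from this diff):

import multiprocessing as mp
import os
import time


def downloader(path):
    with open(path + ".promise", "w") as fh:  # claim the work
        fh.write("!")
    time.sleep(1)  # stand-in for a slow download
    with open(path, "w") as fh:
        fh.write("payload")
    os.unlink(path + ".promise")  # what _promise_fulfilled_fn does


def consumer(path, out):
    while os.path.exists(path + ".promise"):  # what _wait_promise_fn does
        time.sleep(0.01)
    with open(path) as fh:
        out.put(fh.read())


if __name__ == "__main__":
    path = "cached.bin"  # illustrative cache target
    out = mp.Queue()
    d = mp.Process(target=downloader, args=(path,))
    d.start()
    time.sleep(0.1)  # give the downloader time to create the promise first
    c = mp.Process(target=consumer, args=(path, out))
    c.start()
    d.join()
    c.join()
    print(out.get())  # prints "payload" only after the promise is fulfilled
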
@@ -259,7 +280,7 @@ class EndOnDiskCacheHolderIterDataPipe(IterDataPipe):
>>> # You must call ``.on_disk_cache`` at some point before ``.end_caching``
>>> cache_dp = url.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=_hash_dict, hash_type="md5")
>>> # You must call ``.end_caching`` at a later point to stop tracing and save the results to local files.
>>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb". filepath_fn=_filepath_fn)
>>> cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
"""

def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=False, skip_read=False):
@@ -276,6 +297,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals

_filepath_fn, _hash_dict, _hash_type, _ = OnDiskCacheHolderIterDataPipe._temp_dict[cache_holder]
cached_dp = cache_holder._end_caching()
cached_dp = cached_dp.map(_wait_promise_fn)
cached_dp = FileLister(cached_dp, recursive=True)

if same_filepath_fn:
@@ -297,6 +319,7 @@ def __new__(cls, datapipe, mode="wb", filepath_fn=None, *, same_filepath_fn=Fals
todo_dp = todo_dp.check_hash(_hash_dict, _hash_type)

todo_dp = todo_dp.save_to_disk(mode=mode)
todo_dp = todo_dp.map(_promise_fulfilled_fn)

return cached_dp.concat(todo_dp)
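
With these two map calls wired in, the user-facing pipeline is unchanged from the docstring above; a usage sketch in that spirit (the URL, cache directory, and MD5 value are placeholders):

import os

from torchdata.datapipes.iter import HttpReader, IterableWrapper


def _filepath_fn(url):
    # Map each URL onto a path inside a local cache directory (illustrative).
    return os.path.join("cache_dir", os.path.basename(url))


url_dp = IterableWrapper(["https://path/to/filename"])
hash_dict = {_filepath_fn("https://path/to/filename"): "expected_MD5_hash"}

cache_dp = url_dp.on_disk_cache(filepath_fn=_filepath_fn, hash_dict=hash_dict, hash_type="md5")
# Even with many DataLoader workers, only the first worker to miss the cache
# downloads a given file; the rest wait on its '.promise' marker.
cache_dp = HttpReader(cache_dp).end_caching(mode="wb", filepath_fn=_filepath_fn)
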

3 changes: 2 additions & 1 deletion torchdata/datapipes/iter/util/saver.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
import portalocker

from typing import Any, Callable, Iterator, Optional, Tuple, Union

@@ -56,7 +57,7 @@ def __iter__(self) -> Iterator[str]:
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
os.makedirs(dirname)
with open(filepath, self.mode) as f:
with portalocker.Lock(filepath, self.mode) as f:
f.write(data)
yield filepath
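
Both file accesses that the cache depends on now go through portalocker.Lock instead of plain open: _hash_check in cacheholder.py holds the lock while reading, and Saver holds it while writing, so a hash is never computed from a file that another locked writer is still producing. A small sketch of the two call shapes (the function names and arguments are illustrative):

import portalocker


def locked_write(filepath, data, mode="wb"):
    # As in saver.py: other portalocker users block until this write finishes.
    with portalocker.Lock(filepath, mode) as f:
        f.write(data)


def locked_read(filepath, mode="rb"):
    # As in _hash_check: waits for any locked writer on the same file.
    with portalocker.Lock(filepath, mode) as f:
        return f.read()
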

88 changes: 88 additions & 0 deletions util/todo.py
@@ -0,0 +1,88 @@
from github import Github # pip install PyGithub
import sys
import tempfile
import shutil
import os
import re

file_name = sys.argv[1]

GITHUB_KEY = "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"

def get_git_branch_hash():
stream = os.popen("git rev-parse origin/main")
# output =
return stream.read().rstrip()

# def find_owner(file_name, line_number):
# command = "git blame {file_name}".format(file_name=file_name)
# print(command)
# stream = os.popen(command)
# for line_n, line in enumerate(stream.readlines()):
# print(line)
# if line_n == line_number:
# print("I blame". line)

def generate_issue_id(id_or_name, title, file_name, line_number):
git_branch_hash = get_git_branch_hash()
# print(git_branch_hash)
match = re.match(r'\((\d+)\)', id_or_name)
if match:
return int(match.group(1))
match = re.match(r'\((.*)\)', id_or_name)
if match:
cc = "cc @{}".format(match.group(1))
else:
cc = ""

# find_owner(file_name, line_number)
# name = match.group(1)
g = Github(GITHUB_KEY)
repo = g.get_repo("pytorch/data")

label_todo = repo.get_label("todo")
# label_porting = repo.get_label("topic: porting" )
# label_operators = repo.get_label("module: operators" )
# label_be = repo.get_label("better-engineering" )

labels = [label_todo]

body = """
This issue is generated from the TODO line
https://github.com/pytorch/data/blob/{git_branch_hash}/{file_name}#L{line_number}
{cc}
""".format(cc = cc, git_branch_hash= git_branch_hash, line_number=line_number+1,file_name=file_name)
# print(body)
# print(title)
title = "[TODO] {}".format(title)
issue = repo.create_issue(title=title, body=body, labels=labels)
print(issue)
# die
return issue.number

def update_file(file_name):
try:
f = tempfile.NamedTemporaryFile(delete=False)
shutil.copyfile(file_name, f.name)
with open(f.name, "r") as f_inp:
with open(file_name, "w") as f_out:
for line_number, line in enumerate(f_inp.readlines()):
if not re.search(r'ignore-todo', line, re.IGNORECASE):
match = re.search(r'(.*?)#\s*todo\s*(\([^)]+\)){0,1}:{0,1}\s*(.*)', line, re.IGNORECASE)
# print(line)
if match:
prefix = match.group(1)
text = match.group(3)
issue_id = generate_issue_id(str(match.group(2)),text, file_name, line_number)
line = "{}# TODO({}): {}\n".format(prefix, issue_id, text) # ignore-todo
f_out.write(line)
except Exception as e:
shutil.copyfile(f.name, file_name)
print(e)
finally:
os.unlink(f.name)
file_name = os.path.normpath(file_name)
# print('processing ', file_name)
update_file(file_name)
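
For reference, util/todo.py scans a file for TODO comments, files a GitHub issue for each new one (labelled "todo", cc'ing the name in parentheses, and linking back to the source line), and rewrites the comment to carry the issue number; TODOs that already carry a numeric id keep it. A before/after illustration (the issue number 1234 is invented):

# Before running util/todo.py:
print("loading")  # TODO(username): switch to the logging module

# After (assuming the newly created issue was assigned number 1234):
print("loading")  # TODO(1234): switch to the logging module
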

