added filelock in IoPathSaverIterDataPipe (#413)
Summary:
Please read through our [contribution guide](https://github.com/pytorch/data/blob/main/CONTRIBUTING.md) prior to
creating your pull request.

- Note that there is a section on requirements related to adding a new DataPipe.

Fixes #397

This is an identical PR to #395, which I accidentally deleted during a bad force push while trying to resolve a rebase.

Pull Request resolved: #413

Reviewed By: ejguan

Differential Revision: D36433231

Pulled By: msaroufim

fbshipit-source-id: 0f71790eaded445f8e42911db4bea4c468175520
msaroufim authored and facebook-github-bot committed May 17, 2022
1 parent 2c7a913 commit 5028134
Showing 2 changed files with 62 additions and 24 deletions.
79 changes: 57 additions & 22 deletions test/test_local_io.py
@@ -15,6 +15,7 @@
import unittest
import warnings
import zipfile
from functools import partial

from json.decoder import JSONDecodeError

@@ -44,6 +45,7 @@

try:
    import iopath
    import torch

    HAS_IOPATH = True
except ImportError:
@@ -64,6 +66,17 @@
skipIfNoRarTools = unittest.skipIf(not HAS_RAR_TOOLS, "no rar tools")


def filepath_fn(temp_dir_name, name: str) -> str:
    return os.path.join(temp_dir_name, os.path.basename(name))


def init_fn(worker_id):
    info = torch.utils.data.get_worker_info()
    num_workers = info.num_workers
    datapipe = info.dataset
    torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)


class TestDataPipeLocalIO(expecttest.TestCase):
    def setUp(self):
        self.temp_dir = create_temp_dir()
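A note on the two module-level helpers added above: pulling `filepath_fn` out of the individual tests makes it a top-level function that can be bound with `functools.partial` and pickled, which (presumably) is what the spawn-based multi-worker test further down requires, since locally nested functions cannot be pickled. A minimal sketch of that binding, with an illustrative directory:

import os
from functools import partial


def filepath_fn(temp_dir_name: str, name: str) -> str:
    # Keep only the basename and place it under the supplied directory.
    return os.path.join(temp_dir_name, os.path.basename(name))


# "/tmp/demo" is an illustrative path, not one used by the test suite.
fn = partial(filepath_fn, "/tmp/demo")
print(fn("s3://bucket/1.txt"))  # -> /tmp/demo/1.txt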
@@ -280,29 +293,28 @@ def is_nonempty_json(path_and_stream):
            len(json_dp)

    def test_saver_iterdatapipe(self):
        def filepath_fn(name: str) -> str:
            return os.path.join(self.temp_dir.name, os.path.basename(name))

        # Functional Test: Saving some data
        name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
        source_dp = IterableWrapper(sorted(name_to_data.items()))
        saver_dp = source_dp.save_to_disk(filepath_fn=filepath_fn, mode="wb")
        saver_dp = source_dp.save_to_disk(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb")
        res_file_paths = list(saver_dp)
        expected_paths = [filepath_fn(name) for name in name_to_data.keys()]
        expected_paths = [filepath_fn(self.temp_dir.name, name) for name in name_to_data.keys()]
        self.assertEqual(expected_paths, res_file_paths)
        for name in name_to_data.keys():
            p = filepath_fn(name)
            p = filepath_fn(self.temp_dir.name, name)
            with open(p) as f:
                self.assertEqual(name_to_data[name], f.read().encode())

        # Reset Test:
        saver_dp = Saver(source_dp, filepath_fn=filepath_fn, mode="wb")
        saver_dp = Saver(source_dp, filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb")
        n_elements_before_reset = 2
        res_before_reset, res_after_reset = reset_after_n_next_calls(saver_dp, n_elements_before_reset)
        self.assertEqual([filepath_fn("1.txt"), filepath_fn("2.txt")], res_before_reset)
        self.assertEqual(
            [filepath_fn(self.temp_dir.name, "1.txt"), filepath_fn(self.temp_dir.name, "2.txt")], res_before_reset
        )
        self.assertEqual(expected_paths, res_after_reset)
        for name in name_to_data.keys():
            p = filepath_fn(name)
            p = filepath_fn(self.temp_dir.name, name)
            with open(p) as f:
                self.assertEqual(name_to_data[name], f.read().encode())

@@ -582,12 +594,9 @@ def test_decompressor_iterdatapipe(self):
            len(tar_decompress_dp)

    def _write_text_files(self):
        def filepath_fn(name: str) -> str:
            return os.path.join(self.temp_dir.name, os.path.basename(name))

        name_to_data = {"1.text": b"DATA", "2.text": b"DATA", "3.text": b"DATA"}
        source_dp = IterableWrapper(sorted(name_to_data.items()))
        saver_dp = source_dp.save_to_disk(filepath_fn=filepath_fn, mode="wb")
        saver_dp = source_dp.save_to_disk(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb")
        list(saver_dp)

    # TODO(120): this test currently only covers reading from local
@@ -626,35 +635,61 @@ def test_io_path_file_loader_iterdatapipe(self):

    @skipIfNoIoPath
    def test_io_path_saver_iterdatapipe(self):
        def filepath_fn(name: str) -> str:
            return os.path.join(self.temp_dir.name, os.path.basename(name))

        # Functional Test: Saving some data
        name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
        source_dp = IterableWrapper(sorted(name_to_data.items()))
        saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
        saver_dp = source_dp.save_by_iopath(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb")
        res_file_paths = list(saver_dp)
        expected_paths = [filepath_fn(name) for name in name_to_data.keys()]
        expected_paths = [filepath_fn(self.temp_dir.name, name) for name in name_to_data.keys()]
        self.assertEqual(expected_paths, res_file_paths)
        for name in name_to_data.keys():
            p = filepath_fn(name)
            p = filepath_fn(self.temp_dir.name, name)
            with open(p) as f:
                self.assertEqual(name_to_data[name], f.read().encode())

        # Reset Test:
        saver_dp = IoPathSaver(source_dp, filepath_fn=filepath_fn, mode="wb")
        saver_dp = IoPathSaver(source_dp, filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb")
        n_elements_before_reset = 2
        res_before_reset, res_after_reset = reset_after_n_next_calls(saver_dp, n_elements_before_reset)
        self.assertEqual([filepath_fn("1.txt"), filepath_fn("2.txt")], res_before_reset)
        self.assertEqual(
            [filepath_fn(self.temp_dir.name, "1.txt"), filepath_fn(self.temp_dir.name, "2.txt")], res_before_reset
        )
        self.assertEqual(expected_paths, res_after_reset)
        for name in name_to_data.keys():
            p = filepath_fn(name)
            p = filepath_fn(self.temp_dir.name, name)
            with open(p) as f:
                self.assertEqual(name_to_data[name], f.read().encode())

        # __len__ Test: returns the length of source DataPipe
        self.assertEqual(3, len(saver_dp))

    @skipIfNoIoPath
    def test_io_path_saver_file_lock(self):
        # Same filename with different name
        name_to_data = {"1.txt": b"DATA1", "1.txt": b"DATA2", "2.txt": b"DATA3", "2.txt": b"DATA4"}  # noqa: F601

        # Add sharding_filter to shard data into 2
        source_dp = IterableWrapper(list(name_to_data.items())).sharding_filter()

        # Use appending as the mode
        saver_dp = source_dp.save_by_iopath(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="ab")

        import torch.utils.data.graph_settings

        from torch.utils.data import DataLoader

        num_workers = 2
        line_lengths = []
        dl = DataLoader(saver_dp, num_workers=num_workers, worker_init_fn=init_fn, multiprocessing_context="spawn")
        for filename in dl:
            with open(filename[0]) as f:
                lines = f.readlines()
                x = len(lines)
                line_lengths.append(x)
                self.assertEqual(x, 1)

        self.assertEqual(num_workers, len(line_lengths))

    def _write_test_rar_files(self):
        # `rarfile` can only read but not write .rar archives so we use system utilities
        rar_archive_name = os.path.join(self.temp_dir.name, "test_rar")
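The new `test_io_path_saver_file_lock` above exercises the lock from two spawned DataLoader workers: the pipeline is sharded per worker via `worker_init_fn`, the saver runs in append mode, and the test then checks that every produced file contains exactly one line, i.e. that no file was appended to twice. A rough standalone sketch of that setup (the temp directory and item list are illustrative, and it assumes torchdata with iopath is installed):

import os
from functools import partial

import torch.utils.data.graph_settings
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper


def filepath_fn(temp_dir_name, name):
    return os.path.join(temp_dir_name, os.path.basename(name))


def init_fn(worker_id):
    # Shard the pipeline so each worker only processes its own slice of the items.
    info = torch.utils.data.get_worker_info()
    torch.utils.data.graph_settings.apply_sharding(info.dataset, info.num_workers, worker_id)


if __name__ == "__main__":
    tmp_dir = "/tmp/iopath_saver_demo"  # illustrative location
    os.makedirs(tmp_dir, exist_ok=True)
    # Repeated file names so both workers end up targeting the same paths.
    items = [("1.txt", b"DATA1"), ("1.txt", b"DATA2"), ("2.txt", b"DATA3"), ("2.txt", b"DATA4")]
    dp = (
        IterableWrapper(items)
        .sharding_filter()
        .save_by_iopath(filepath_fn=partial(filepath_fn, tmp_dir), mode="ab")
    )
    dl = DataLoader(dp, num_workers=2, worker_init_fn=init_fn, multiprocessing_context="spawn")
    for batch in dl:
        with open(batch[0]) as f:
            # With the lock and existence check, a second append never happens.
            assert len(f.readlines()) == 1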
7 changes: 5 additions & 2 deletions torchdata/datapipes/iter/load/iopath.py
@@ -135,6 +135,7 @@ def __len__(self) -> int:

@functional_datapipe("save_by_iopath")
class IoPathSaverIterDataPipe(IterDataPipe[str]):

r"""
Takes in a DataPipe of tuples of metadata and data, saves the data
to the target path which is generated by the ``filepath_fn`` and metadata, and yields the resulting path
@@ -183,8 +184,10 @@ def __init__(
    def __iter__(self) -> Iterator[str]:
        for meta, data in self.source_datapipe:
            filepath = meta if self.filepath_fn is None else self.filepath_fn(meta)
            with self.pathmgr.open(filepath, self.mode) as f:
                f.write(data)
            with iopath.file_lock(filepath):
                if not os.path.exists(filepath):
                    with self.pathmgr.open(filepath, self.mode) as f:
                        f.write(data)
            yield filepath

    def register_handler(self, handler, allow_override=False):
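The substantive change to `IoPathSaverIterDataPipe.__iter__` above wraps the write in `iopath.file_lock` and skips it when the target file already exists, so several workers mapping items onto the same path perform the write at most once; note that `yield filepath` stays outside the `if`, so a path is still yielded for every input item even when its write is skipped. A minimal sketch of the same pattern in isolation (the `save_once` helper and the plain `open()` call are illustrative; the DataPipe itself writes through its iopath `PathManager`):

import os

import iopath


def save_once(filepath: str, data: bytes, mode: str = "wb") -> str:
    # Cross-process lock keyed on the target path; assumes iopath is installed.
    with iopath.file_lock(filepath):
        # Another worker may have produced the file while we waited for the lock.
        if not os.path.exists(filepath):
            with open(filepath, mode) as f:
                f.write(data)
    return filepath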
