From 5028134614243820e21566ea2d13cb0a5d769e8f Mon Sep 17 00:00:00 2001
From: Mark Saroufim
Date: Tue, 17 May 2022 02:34:23 -0700
Subject: [PATCH] added `filelock` in `IoPathSaverIterDataPipe` (#413)

Summary:
Please read through our [contribution guide](https://github.com/pytorch/data/blob/main/CONTRIBUTING.md) prior to creating your pull request.

- Note that there is a section on requirements related to adding a new DataPipe.

Fixes https://github.com/pytorch/data/issues/397

This is an identical PR to https://github.com/pytorch/data/pull/395, which I accidentally deleted during a bad force push while trying to resolve a rebase.

Pull Request resolved: https://github.com/pytorch/data/pull/413

Reviewed By: ejguan

Differential Revision: D36433231

Pulled By: msaroufim

fbshipit-source-id: 0f71790eaded445f8e42911db4bea4c468175520
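For reference, the heart of the change is the lock-then-check-then-write pattern below. This is a minimal standalone sketch rather than the DataPipe itself: it assumes `iopath` is installed, the output path is hypothetical, and plain `open()` stands in for the `PathManager` that `IoPathSaverIterDataPipe` actually uses.

```python
import os

import iopath

path = "out/1.txt"  # hypothetical output path; the parent directory must exist

# `iopath.file_lock` takes an exclusive cross-process lock keyed on the path,
# so concurrent DataLoader workers cannot interleave writes to the same file.
with iopath.file_lock(path):
    # Only the first worker to acquire the lock performs the write.
    if not os.path.exists(path):
        with open(path, "wb") as f:  # stand-in for PathManager.open()
            f.write(b"DATA1")
```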
---
 test/test_local_io.py                   | 79 ++++++++++++++++++-------
 torchdata/datapipes/iter/load/iopath.py |  7 ++-
 2 files changed, 62 insertions(+), 24 deletions(-)

diff --git a/test/test_local_io.py b/test/test_local_io.py
index f5a0cda61..6b62939d4 100644
--- a/test/test_local_io.py
+++ b/test/test_local_io.py
@@ -15,6 +15,7 @@
 import unittest
 import warnings
 import zipfile
+from functools import partial

 from json.decoder import JSONDecodeError

@@ -44,6 +45,7 @@

 try:
     import iopath
+    import torch

     HAS_IOPATH = True
 except ImportError:
@@ -64,6 +66,17 @@
 skipIfNoRarTools = unittest.skipIf(not HAS_RAR_TOOLS, "no rar tools")


+def filepath_fn(temp_dir_name, name: str) -> str:
+    return os.path.join(temp_dir_name, os.path.basename(name))
+
+
+def init_fn(worker_id):
+    info = torch.utils.data.get_worker_info()
+    num_workers = info.num_workers
+    datapipe = info.dataset
+    torch.utils.data.graph_settings.apply_sharding(datapipe, num_workers, worker_id)
+
+
 class TestDataPipeLocalIO(expecttest.TestCase):
     def setUp(self):
         self.temp_dir = create_temp_dir()
@@ -280,29 +293,28 @@ def is_nonempty_json(path_and_stream):
             len(json_dp)

     def test_saver_iterdatapipe(self):
-        def filepath_fn(name: str) -> str:
-            return os.path.join(self.temp_dir.name, os.path.basename(name))
-
         # Functional Test: Saving some data
         name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
         source_dp = IterableWrapper(sorted(name_to_data.items()))
-        saver_dp = source_dp.save_to_disk(filepath_fn=filepath_fn, mode="wb")
+        saver_dp = source_dp.save_to_disk(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb")
         res_file_paths = list(saver_dp)
-        expected_paths = [filepath_fn(name) for name in name_to_data.keys()]
+        expected_paths = [filepath_fn(self.temp_dir.name, name) for name in name_to_data.keys()]
         self.assertEqual(expected_paths, res_file_paths)
         for name in name_to_data.keys():
-            p = filepath_fn(name)
+            p = filepath_fn(self.temp_dir.name, name)
             with open(p) as f:
                 self.assertEqual(name_to_data[name], f.read().encode())

         # Reset Test:
-        saver_dp = Saver(source_dp, filepath_fn=filepath_fn, mode="wb")
+        saver_dp = Saver(source_dp, filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb")
         n_elements_before_reset = 2
         res_before_reset, res_after_reset = reset_after_n_next_calls(saver_dp, n_elements_before_reset)
-        self.assertEqual([filepath_fn("1.txt"), filepath_fn("2.txt")], res_before_reset)
+        self.assertEqual(
+            [filepath_fn(self.temp_dir.name, "1.txt"), filepath_fn(self.temp_dir.name, "2.txt")], res_before_reset
+        )
         self.assertEqual(expected_paths, res_after_reset)
         for name in name_to_data.keys():
-            p = filepath_fn(name)
+            p = filepath_fn(self.temp_dir.name, name)
             with open(p) as f:
                 self.assertEqual(name_to_data[name], f.read().encode())

@@ -582,12 +594,9 @@ def test_decompressor_iterdatapipe(self):
             len(tar_decompress_dp)

     def _write_text_files(self):
-        def filepath_fn(name: str) -> str:
-            return os.path.join(self.temp_dir.name, os.path.basename(name))
-
         name_to_data = {"1.text": b"DATA", "2.text": b"DATA", "3.text": b"DATA"}
         source_dp = IterableWrapper(sorted(name_to_data.items()))
-        saver_dp = source_dp.save_to_disk(filepath_fn=filepath_fn, mode="wb")
+        saver_dp = source_dp.save_to_disk(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb")
         list(saver_dp)

     # TODO(120): this test currently only covers reading from local
@@ -626,35 +635,61 @@ def test_io_path_file_loader_iterdatapipe(self):

     @skipIfNoIoPath
     def test_io_path_saver_iterdatapipe(self):
-        def filepath_fn(name: str) -> str:
-            return os.path.join(self.temp_dir.name, os.path.basename(name))
-
         # Functional Test: Saving some data
         name_to_data = {"1.txt": b"DATA1", "2.txt": b"DATA2", "3.txt": b"DATA3"}
         source_dp = IterableWrapper(sorted(name_to_data.items()))
-        saver_dp = source_dp.save_by_iopath(filepath_fn=filepath_fn, mode="wb")
+        saver_dp = source_dp.save_by_iopath(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb")
         res_file_paths = list(saver_dp)
-        expected_paths = [filepath_fn(name) for name in name_to_data.keys()]
+        expected_paths = [filepath_fn(self.temp_dir.name, name) for name in name_to_data.keys()]
         self.assertEqual(expected_paths, res_file_paths)
         for name in name_to_data.keys():
-            p = filepath_fn(name)
+            p = filepath_fn(self.temp_dir.name, name)
             with open(p) as f:
                 self.assertEqual(name_to_data[name], f.read().encode())

         # Reset Test:
-        saver_dp = IoPathSaver(source_dp, filepath_fn=filepath_fn, mode="wb")
+        saver_dp = IoPathSaver(source_dp, filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="wb")
         n_elements_before_reset = 2
         res_before_reset, res_after_reset = reset_after_n_next_calls(saver_dp, n_elements_before_reset)
-        self.assertEqual([filepath_fn("1.txt"), filepath_fn("2.txt")], res_before_reset)
+        self.assertEqual(
+            [filepath_fn(self.temp_dir.name, "1.txt"), filepath_fn(self.temp_dir.name, "2.txt")], res_before_reset
+        )
         self.assertEqual(expected_paths, res_after_reset)
         for name in name_to_data.keys():
-            p = filepath_fn(name)
+            p = filepath_fn(self.temp_dir.name, name)
             with open(p) as f:
                 self.assertEqual(name_to_data[name], f.read().encode())

         # __len__ Test: returns the length of source DataPipe
         self.assertEqual(3, len(saver_dp))

+    @skipIfNoIoPath
+    def test_io_path_saver_file_lock(self):
+        # Same filename with different data; the duplicate dict keys collapse to two entries
+        name_to_data = {"1.txt": b"DATA1", "1.txt": b"DATA2", "2.txt": b"DATA3", "2.txt": b"DATA4"}  # noqa: F601
+
+        # Add sharding_filter to shard the data across 2 workers
+        source_dp = IterableWrapper(list(name_to_data.items())).sharding_filter()
+
+        # Use append mode so a duplicated write would show up as an extra line
+        saver_dp = source_dp.save_by_iopath(filepath_fn=partial(filepath_fn, self.temp_dir.name), mode="ab")
+
+        import torch.utils.data.graph_settings
+
+        from torch.utils.data import DataLoader
+
+        num_workers = 2
+        line_lengths = []
+        dl = DataLoader(saver_dp, num_workers=num_workers, worker_init_fn=init_fn, multiprocessing_context="spawn")
+        for filename in dl:
+            with open(filename[0]) as f:
+                lines = f.readlines()
+                x = len(lines)
+                line_lengths.append(x)
+                self.assertEqual(x, 1)
+
+        self.assertEqual(num_workers, len(line_lengths))
+
     def _write_test_rar_files(self):
         # `rarfile` can only read but not write .rar archives so we use to system utilities
         rar_archive_name = os.path.join(self.temp_dir.name, "test_rar")
diff --git a/torchdata/datapipes/iter/load/iopath.py b/torchdata/datapipes/iter/load/iopath.py
index 57feea52a..2ce8b2b19 100644
--- a/torchdata/datapipes/iter/load/iopath.py
+++ b/torchdata/datapipes/iter/load/iopath.py
@@ -135,6 +135,7 @@ def __len__(self) -> int:

 @functional_datapipe("save_by_iopath")
 class IoPathSaverIterDataPipe(IterDataPipe[str]):
+
     r"""
     Takes in a DataPipe of tuples of metadata and data, saves the data
     to the target path which is generated by the ``filepath_fn`` and metadata, and yields the resulting path
@@ -183,8 +184,10 @@ def __init__(

     def __iter__(self) -> Iterator[str]:
         for meta, data in self.source_datapipe:
             filepath = meta if self.filepath_fn is None else self.filepath_fn(meta)
-            with self.pathmgr.open(filepath, self.mode) as f:
-                f.write(data)
+            with iopath.file_lock(filepath):
+                if not os.path.exists(filepath):
+                    with self.pathmgr.open(filepath, self.mode) as f:
+                        f.write(data)
             yield filepath

     def register_handler(self, handler, allow_override=False):
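Usage note: with this change, `save_by_iopath` writes each file under a per-path lock and skips paths that already exist, which is why the new test can append from two spawned DataLoader workers and still observe exactly one line per file. A rough sketch of a call site follows (the output directory is hypothetical; as in the test refactor above, `filepath_fn` is defined at module level and bound with `functools.partial` so it remains picklable under the "spawn" multiprocessing context):

```python
import os
from functools import partial

from torchdata.datapipes.iter import IterableWrapper


def filepath_fn(dir_name: str, name: str) -> str:
    # Map each record's name to a file inside dir_name
    return os.path.join(dir_name, os.path.basename(name))


source_dp = IterableWrapper([("1.txt", b"DATA1"), ("2.txt", b"DATA2")])
# "/tmp/out" is a hypothetical, pre-existing output directory
saver_dp = source_dp.save_by_iopath(filepath_fn=partial(filepath_fn, "/tmp/out"), mode="wb")
saved_paths = list(saver_dp)  # writes happen lazily here, one lock per path
```

Note that the `os.path.exists` guard also means an existing file is never overwritten, regardless of `mode`.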