From bc177d6083a4e767970f7fcaae03cc90213d402b Mon Sep 17 00:00:00 2001
From: Andrew Audibert
Date: Mon, 6 Jan 2025 10:35:13 -0800
Subject: [PATCH] #pygrain Fix segfault when starting multiple mp_prefetches
 concurrently.

Parallel calls to `start_prefetch` result in concurrent calls to
`SharedMemoryArray.enable_async_del`, which attempt to concurrently modify
class-level state.

Before the fix, the added unit test segfaults about 10% of the time, and the
repro in experimental/ segfaults consistently. Using a higher number of
iterators in the unit test results in Forge OOMs.

After the fix, the test passes with --runs_per_test=1000.

PiperOrigin-RevId: 712578730
---
 .../dataset/transformations/prefetch_test.py | 17 +++++++++++++++++
 grain/_src/python/shared_memory_array.py     |  8 +++++---
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py
index db76c0a7..9bcd1eff 100644
--- a/grain/_src/python/dataset/transformations/prefetch_test.py
+++ b/grain/_src/python/dataset/transformations/prefetch_test.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from concurrent import futures
 import dataclasses
 import sys
 import time
@@ -661,6 +662,22 @@ def test_fails_with_negative_prefetch_buffer_size(self):
     ):
       prefetch.ThreadPrefetchIterDataset(self.ds, prefetch_buffer_size=-1)
 
+  def test_concurrent_start_prefetch(self):
+    num_iters = 10  # Can't set this much higher without Forge OOMing.
+
+    def make_iter(i):
+      ds = dataset.MapDataset.source([i])
+      ds = ds.to_iter_dataset()
+      ds = ds.mp_prefetch(options=options.MultiprocessingOptions(num_workers=1))
+      return ds.__iter__()
+
+    iters = [make_iter(i) for i in range(num_iters)]
+    with futures.ThreadPoolExecutor(max_workers=num_iters) as executor:
+      for it in iters:
+        executor.submit(it.start_prefetch)
+    for it in iters:
+      _ = next(it)
+
 
 if __name__ == '__main__':
   absltest.main()

diff --git a/grain/_src/python/shared_memory_array.py b/grain/_src/python/shared_memory_array.py
index d65927fd..16e7333f 100644
--- a/grain/_src/python/shared_memory_array.py
+++ b/grain/_src/python/shared_memory_array.py
@@ -58,6 +58,7 @@ class SharedMemoryArray(np.ndarray):
   the memory will not be freed.
   """
 
+  _lock: threading.Lock = threading.Lock()
   _unlink_thread_pool: pool.ThreadPool | None = None
   _unlink_semaphore: threading.Semaphore | None = None
 
@@ -121,9 +122,10 @@ def __reduce_ex__(self, protocol):
 
   @classmethod
   def enable_async_del(cls, num_threads: int = 1) -> None:
-    if not SharedMemoryArray._unlink_thread_pool:
-      SharedMemoryArray._unlink_thread_pool = pool.ThreadPool(num_threads)
-      SharedMemoryArray._unlink_semaphore = threading.Semaphore(num_threads)
+    with cls._lock:
+      if not SharedMemoryArray._unlink_thread_pool:
+        SharedMemoryArray._unlink_thread_pool = pool.ThreadPool(num_threads)
+        SharedMemoryArray._unlink_semaphore = threading.Semaphore(num_threads)
 
   def unlink_on_del(self) -> None:
     """Mark this object responsible for unlinking the shared memory."""
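
Note (illustrative sketch, not part of the patch): the bug is a classic
unsynchronized check-then-create on class-level state. Two threads calling
`SharedMemoryArray.enable_async_del` concurrently can both observe
`_unlink_thread_pool` as unset and both initialize it, and here the race
surfaced as a segfault rather than a clean Python error. The standalone
sketch below distills the pattern and the fix; `Resource`, `enable_racy`,
and `enable_safe` are hypothetical stand-ins, and only the lock-guarded
check-then-create mirrors the actual change.

    import threading
    from multiprocessing import pool

    class Resource:
      # The lock is created at class-definition time, so it already exists
      # before any thread can call the classmethods below; the synchronizer
      # itself cannot race.
      _lock: threading.Lock = threading.Lock()
      _pool: pool.ThreadPool | None = None

      @classmethod
      def enable_racy(cls, num_threads: int = 1) -> None:
        # Buggy pattern (pre-fix): the check and the assignment are separate
        # steps, so two threads can interleave between them and both create
        # (or overwrite) the pool.
        if not cls._pool:
          cls._pool = pool.ThreadPool(num_threads)

      @classmethod
      def enable_safe(cls, num_threads: int = 1) -> None:
        # Fixed pattern: the class-level lock serializes check-then-create,
        # so initialization happens exactly once regardless of how many
        # threads call this concurrently.
        with cls._lock:
          if not cls._pool:
            cls._pool = pool.ThreadPool(num_threads)

    # Hammer the safe variant from many threads; it initializes exactly once.
    threads = [threading.Thread(target=Resource.enable_safe) for _ in range(32)]
    for t in threads:
      t.start()
    for t in threads:
      t.join()
    assert Resource._pool is not None

A double-checked variant (testing outside the lock before taking it) would
also work, but per the commit description `enable_async_del` is reached
during prefetch startup rather than on a hot path, so the plain lock is the
simpler choice.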