Skip to content

Commit

Permalink
#pygrain Fix segfault when starting multiple mp_prefetches concurrently.
Browse files Browse the repository at this point in the history
Parallel calls to `start_prefetch` result in concurrent calls to `SharedMemoryArray.enable_async_del`, which attempt to concurrently modify class-level state. Before the fix, the added unit test segfaults about 10% of the time, and the repro in experimental/ segfaults consistently. Using a higher number of iterators in the unit test results in Forge OOMs. After the fix, the test passes with --runs_per_test=1000.

PiperOrigin-RevId: 712578730
  • Loading branch information
aaudiber authored and copybara-github committed Jan 6, 2025
1 parent 3099aec commit bc177d6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 3 deletions.
17 changes: 17 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from concurrent import futures
import dataclasses
import sys
import time
Expand Down Expand Up @@ -661,6 +662,22 @@ def test_fails_with_negative_prefetch_buffer_size(self):
):
prefetch.ThreadPrefetchIterDataset(self.ds, prefetch_buffer_size=-1)

def test_concurrent_start_prefetch(self):
  """Starting prefetch on many mp_prefetch iterators in parallel must not crash."""
  iterator_count = 10  # Can't set this much higher without Forge OOMing.

  def _build_iterator(seed):
    source = dataset.MapDataset.source([seed])
    prefetched = source.to_iter_dataset().mp_prefetch(
        options=options.MultiprocessingOptions(num_workers=1)
    )
    return iter(prefetched)

  iterators = [_build_iterator(i) for i in range(iterator_count)]
  # Kick off all prefetches concurrently; before the fix this segfaulted.
  with futures.ThreadPoolExecutor(max_workers=iterator_count) as executor:
    pending = [executor.submit(it.start_prefetch) for it in iterators]
    del pending  # Executor shutdown on context exit waits for completion.
  for it in iterators:
    _ = next(it)


# Run the test suite when executed directly (absl handles flag parsing).
if __name__ == '__main__':
  absltest.main()
8 changes: 5 additions & 3 deletions grain/_src/python/shared_memory_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class SharedMemoryArray(np.ndarray):
the memory will not be freed.
"""

# Guards lazy initialization of the class-level pool/semaphore below, so
# concurrent enable_async_del() calls cannot race while creating them.
_lock: threading.Lock = threading.Lock()
# Created on first enable_async_del() call; None until then.
_unlink_thread_pool: pool.ThreadPool | None = None
_unlink_semaphore: threading.Semaphore | None = None

Expand Down Expand Up @@ -121,9 +122,10 @@ def __reduce_ex__(self, protocol):

@classmethod
def enable_async_del(cls, num_threads: int = 1) -> None:
  """Lazily creates the class-level unlink thread pool and semaphore.

  Safe to call concurrently: initialization happens at most once, under
  `cls._lock`. `num_threads` is honored only by the first call.
  """
  with cls._lock:
    if SharedMemoryArray._unlink_thread_pool is not None:
      return  # Already initialized by an earlier (possibly concurrent) call.
    SharedMemoryArray._unlink_semaphore = threading.Semaphore(num_threads)
    SharedMemoryArray._unlink_thread_pool = pool.ThreadPool(num_threads)

def unlink_on_del(self) -> None:
"""Mark this object responsible for unlinking the shared memory."""
Expand Down

0 comments on commit bc177d6

Please sign in to comment.