Skip to content

Commit

Permalink
python/pytorch: Fix WorkerSessionManager returning None when using sa…
Browse files Browse the repository at this point in the history
…mplers with no workers

When using a dynamic batch sampler without a dataloader (e.g to just get the batch indices), we trigger a bug where session manager returns None.

Signed-off-by: Soham Manoli <[email protected]>
  • Loading branch information
msoham123 committed Aug 21, 2024
1 parent 40d6580 commit 84b335a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python/aistore/pytorch/worker_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,13 @@ def __init__(self, session_manager: SessionManager):
@property
def session(self):
"""
Returns: Active request session acquired for a specific Pytorch dataloader worker
Returns an active request session acquired for a specific Pytorch dataloader worker.
"""
# sessions are not thread safe, so we must return different sessions for each worker
worker_info = get_worker_info()
if worker_info is None:
if self._session is None:
self._session = self._create_session()
return self._session
# if we only have one session but multiple workers, create more
if worker_info.id not in self._worker_sessions:
Expand Down
9 changes: 9 additions & 0 deletions python/tests/integration/pytorch/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,12 @@ def test_dynamic_sampler(self):

# Test that all objects are included in our batch
self.assertEqual(num_objects, NUM_OBJECTS)

def test_sampler_no_workers(self):
sampler = DynamicBatchSampler(
data_source=self.dataset,
max_batch_size=MAX_BATCH_SIZE,
)

for _ in sampler:
continue

0 comments on commit 84b335a

Please sign in to comment.