Skip to content

Commit

Permalink
add torchdata version check to avoid "in_order" error (#3344)
Browse files Browse the repository at this point in the history
  • Loading branch information
faaany authored Jan 15, 2025
1 parent f0b0305 commit 828aae4
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import math
from contextlib import suppress
from typing import Callable, List, Optional, Union

import torch
from packaging import version
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler

from .logging import get_logger
Expand All @@ -25,6 +27,7 @@
RNGType,
broadcast,
broadcast_object_list,
compare_versions,
concatenate,
find_batch_size,
get_data_structure,
Expand Down Expand Up @@ -415,6 +418,9 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, *
"StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it."
)
if use_stateful_dataloader:
torchdata_version = version.parse(importlib.metadata.version("torchdata"))
if compare_versions(torchdata_version, "<", "0.11") and is_torch_version(">=", "2.6.0"):
kwargs.pop("in_order")
self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
else:
self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
Expand Down

0 comments on commit 828aae4

Please sign in to comment.