diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 4a92aa4617a..98c152ab842 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib import math from contextlib import suppress from typing import Callable, List, Optional, Union import torch +from packaging import version from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler from .logging import get_logger @@ -25,6 +27,7 @@ RNGType, broadcast, broadcast_object_list, + compare_versions, concatenate, find_batch_size, get_data_structure, @@ -415,6 +418,9 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, * "StatefulDataLoader is not available. Please install torchdata version 0.8.0 or higher to use it." ) if use_stateful_dataloader: + torchdata_version = version.parse(importlib.metadata.version("torchdata")) + if compare_versions(torchdata_version, "<", "0.11") and is_torch_version(">=", "2.6.0"): + kwargs.pop("in_order") self.base_dataloader = StatefulDataLoader(dataset, batch_sampler=batch_sampler, **kwargs) else: self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)