diff --git a/CHANGELOG.md b/CHANGELOG.md
index bab60910c7a30..8b5f689242d84 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -301,6 +301,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed metric objects passed directly to `self.log` not being reset correctly ([#7055](https://github.com/PyTorchLightning/pytorch-lightning/pull/7055))
 
+- Fixed `CombinedLoader` in distributed settings for validation / testing ([#7102](https://github.com/PyTorchLightning/pytorch-lightning/pull/7102))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index a361f6e6203c2..51ed858fa9b22 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -108,12 +108,15 @@ def auto_add_worker_init_fn(self, dataloader: DataLoader) -> None:
         dataloader.worker_init_fn = partial(pl_worker_init_function, rank=self.global_rank)
 
     def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:
-        # don't do anything if it's not a dataloader
         is_dataloader = isinstance(dataloader, DataLoader)
         # don't manipulate iterable datasets
         is_iterable_ds = has_iterable_dataset(dataloader)
 
+        if isinstance(dataloader, CombinedLoader):
+            dataloader.loaders = apply_to_collection(dataloader.loaders, DataLoader, self.auto_add_sampler, shuffle)
+            return dataloader
+
         if not is_dataloader or is_iterable_ds:
             return dataloader
 
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index f884306dc09c8..3cb0b0cb1f11a 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -19,6 +19,8 @@
 import torch
 from torch import Tensor
 from torch.utils.data import Dataset
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import IterableDataset
 
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import get_filesystem
@@ -352,7 +354,7 @@ def __init__(self, loaders: Any, mode: str = 'min_size'):
     @property
     def sampler(self) -> Union[Iterable, Sequence, Mapping]:
         """Return a collections of samplers extracting from loaders."""
-        return apply_to_collection(self.loaders, Iterable, getattr, 'sampler', None, wrong_dtype=(Sequence, Mapping))
+        return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, 'sampler', None)
 
     def _wrap_loaders_max_size_cycle(self) -> Any:
         """
diff --git a/tests/trainer/test_supporters.py b/tests/trainer/test_supporters.py
index 30b984dc896be..6da2436b5eafc 100644
--- a/tests/trainer/test_supporters.py
+++ b/tests/trainer/test_supporters.py
@@ -11,12 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 from collections import Sequence
+from unittest import mock
 
 import pytest
 import torch
-from torch.utils.data import TensorDataset
+from torch.utils.data import DataLoader, TensorDataset
+from torch.utils.data.dataset import Dataset
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import Sampler
 
+from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.supporters import (
     _nested_calc_num_data,
     CombinedDataset,
@@ -25,6 +31,7 @@
     CycleIterator,
     TensorRunningAccum,
 )
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
@@ -237,3 +244,46 @@
 def test_nested_calc_num_data(input_data, compute_func, expected_length):
     calculated_length = _nested_calc_num_data(input_data, compute_func)
     assert calculated_length == expected_length
+
+
+@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"})
+@mock.patch('torch.cuda.device_count', return_value=2)
+@mock.patch('torch.cuda.is_available', return_value=True)
+def test_combined_data_loader_validation_test(cuda_available_mock, device_count_mock, tmpdir):
+    """
+    This test makes sure distributed sampler has been properly injected in dataloaders
+    when using CombinedLoader
+    """
+
+    class CustomDataset(Dataset):
+
+        def __init__(self, data):
+            self.data = data
+
+        def __len__(self):
+            return len(self.data)
+
+        def __getitem__(self, index):
+            return self.data[index]
+
+    dataloader = CombinedLoader({
+        "a": DataLoader(CustomDataset(range(10))),
+        "b": {
+            "c": DataLoader(CustomDataset(range(10))),
+            "d": DataLoader(CustomDataset(range(10)))
+        },
+        "e": [DataLoader(CustomDataset(range(10))),
+              DataLoader(CustomDataset(range(10)))]
+    })
+
+    trainer = Trainer(replace_sampler_ddp=True, accelerator="ddp", gpus=2)
+    dataloader = trainer.auto_add_sampler(dataloader, shuffle=True)
+    _count = 0
+
+    def _assert_distributed_sampler(v):
+        nonlocal _count
+        _count += 1
+        assert isinstance(v, DistributedSampler)
+
+    apply_to_collection(dataloader.sampler, Sampler, _assert_distributed_sampler)
+    assert _count == 5
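
Below is a minimal usage sketch (not part of the patch) of the behavior this change enables: a `CombinedLoader` returned from `val_dataloader` now has each inner `DataLoader` recursively wrapped with a `DistributedSampler` under DDP, instead of being skipped by `auto_add_sampler`. The `LitModel` name and the random tensor data are illustrative assumptions, not taken from the PR.

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule
from pytorch_lightning.trainer.supporters import CombinedLoader


class LitModel(LightningModule):
    # Hypothetical module; only the dataloader hook relevant to this fix is shown.

    def val_dataloader(self):
        # An arbitrarily nested collection of loaders. Before this patch,
        # auto_add_sampler returned the CombinedLoader untouched, so the inner
        # DataLoaders kept their default samplers in distributed runs.
        loaders = {
            "a": DataLoader(TensorDataset(torch.rand(10, 2))),
            "b": [DataLoader(TensorDataset(torch.rand(10, 2)))],
        }
        return CombinedLoader(loaders, mode="min_size")

# With e.g. Trainer(accelerator="ddp", gpus=2, replace_sampler_ddp=True),
# validation now sees a DistributedSampler on every inner DataLoader,
# which is what the new test asserts via apply_to_collection.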