diff --git a/CHANGELOG.md b/CHANGELOG.md index 3055b15011a2f..f8faf728c2b66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,7 +81,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `LightningModule.model_size` ([#8343](https://github.com/PyTorchLightning/pytorch-lightning/pull/8343)) -- +- Deprecated `DataModule` properties: `train_transforms`, `val_transforms`, `test_transforms`, `size`, `dims` ([#8851](https://github.com/PyTorchLightning/pytorch-lightning/pull/8851)) - diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index b4273d550fd09..60cbaaaf758e9 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -70,6 +70,20 @@ def teardown(self): def __init__(self, train_transforms=None, val_transforms=None, test_transforms=None, dims=None): super().__init__() + if train_transforms is not None: + rank_zero_deprecation( + "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + if val_transforms is not None: + rank_zero_deprecation( + "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + if test_transforms is not None: + rank_zero_deprecation( + "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) + if dims is not None: + rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.") self._train_transforms = train_transforms self._val_transforms = val_transforms self._test_transforms = test_transforms @@ -95,55 +109,94 @@ def __init__(self, train_transforms=None, val_transforms=None, test_transforms=N def train_transforms(self): """ Optional transforms (or collection of transforms) you can apply to train dataset + + .. deprecated:: v1.5 + Will be removed in v1.7.0. """ + + rank_zero_deprecation( + "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) return self._train_transforms @train_transforms.setter def train_transforms(self, t): + rank_zero_deprecation( + "DataModule property `train_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) self._train_transforms = t @property def val_transforms(self): """ Optional transforms (or collection of transforms) you can apply to validation dataset + + .. deprecated:: v1.5 + Will be removed in v1.7.0. """ + + rank_zero_deprecation( + "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) return self._val_transforms @val_transforms.setter def val_transforms(self, t): + rank_zero_deprecation( + "DataModule property `val_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) self._val_transforms = t @property def test_transforms(self): """ Optional transforms (or collection of transforms) you can apply to test dataset + + .. deprecated:: v1.5 + Will be removed in v1.7.0. """ + + rank_zero_deprecation( + "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) return self._test_transforms @test_transforms.setter def test_transforms(self, t): + rank_zero_deprecation( + "DataModule property `test_transforms` was deprecated in v1.5 and will be removed in v1.7." + ) self._test_transforms = t @property def dims(self): """ A tuple describing the shape of your data. Extra functionality exposed in ``size``. + + .. deprecated:: v1.5 + Will be removed in v1.7.0. """ + rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.") return self._dims @dims.setter def dims(self, d): + rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.") self._dims = d def size(self, dim=None) -> Union[Tuple, int]: """ Return the dimension of each input either as a tuple or list of tuples. You can index this just as you would with a torch tensor. + + .. deprecated:: v1.5 + Will be removed in v1.7.0. """ if dim is not None: return self.dims[dim] + rank_zero_deprecation("DataModule property `size` was deprecated in v1.5 and will be removed in v1.7.") return self.dims @property diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 38b1e50150063..d836f1427a110 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -15,8 +15,10 @@ import pytest +from pytorch_lightning import LightningDataModule from tests.deprecated_api import _soft_unimport_module from tests.helpers import BoringModel +from tests.helpers.datamodules import MNISTDataModule def test_v1_7_0_deprecated_lightning_module_summarize(tmpdir): @@ -46,3 +48,35 @@ def test_v1_7_0_deprecated_model_size(): match="LightningModule.model_size` property was deprecated in v1.5 and will be removed in v1.7" ): _ = model.model_size + + +def test_v1_7_0_datamodule_transform_properties(tmpdir): + dm = MNISTDataModule() + with pytest.deprecated_call(match=r"DataModule property `train_transforms` was deprecated in v1.5"): + dm.train_transforms = "a" + with pytest.deprecated_call(match=r"DataModule property `val_transforms` was deprecated in v1.5"): + dm.val_transforms = "b" + with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"): + dm.test_transforms = "c" + with pytest.deprecated_call(match=r"DataModule property `train_transforms` was deprecated in v1.5"): + _ = LightningDataModule(train_transforms="a") + with pytest.deprecated_call(match=r"DataModule property `val_transforms` was deprecated in v1.5"): + _ = LightningDataModule(val_transforms="b") + with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"): + _ = LightningDataModule(test_transforms="c") + with pytest.deprecated_call(match=r"DataModule property `test_transforms` was deprecated in v1.5"): + _ = LightningDataModule(test_transforms="c", dims=(1, 1, 1)) + + +def test_v1_7_0_datamodule_size_property(tmpdir): + dm = MNISTDataModule() + with pytest.deprecated_call(match=r"DataModule property `size` was deprecated in v1.5"): + dm.size() + + +def test_v1_7_0_datamodule_dims_property(tmpdir): + dm = MNISTDataModule() + with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"): + _ = dm.dims + with pytest.deprecated_call(match=r"DataModule property `dims` was deprecated in v1.5"): + _ = LightningDataModule(dims=(1, 1, 1))