From 4c8cd19272900d57a5da0d0fe70399d387636ac7 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 20 Sep 2020 16:37:21 -0700
Subject: [PATCH 1/8] Split out changes from #3563 to make that PR easier to
 review.

This formats the file according to the Black formatter
---
 pytorch_lightning/core/hooks.py | 46 ++++++++++-----------------------
 1 file changed, 13 insertions(+), 33 deletions(-)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index b4cfd50819ffc..0694b9923b1fc 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, List, Union
+from typing import Any, Union, List
 
 import torch
-from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
+from pytorch_lightning.utilities import move_data_to_device, AMPType, rank_zero_warn
 
 try:
     from apex import amp
@@ -28,6 +28,7 @@
 
 
 class ModelHooks:
+
     def setup(self, stage: str):
         """
         Called at the beginning of fit and test.
@@ -112,9 +113,7 @@ def on_pretrain_routine_end(self) -> None:
         """
         # do something at the end of the pretrain routine
 
-    def on_train_batch_start(
-        self, batch: Any, batch_idx: int, dataloader_idx: int
-    ) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """
         Called in the training loop before anything happens for that batch.
 
@@ -127,9 +126,7 @@ def on_train_batch_start(
         """
         # do something when the batch starts
 
-    def on_train_batch_end(
-        self, batch: Any, batch_idx: int, dataloader_idx: int
-    ) -> None:
+    def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """
         Called in the training loop after the batch.
 
@@ -140,9 +137,7 @@ def on_train_batch_end(
         """
         # do something when the batch ends
 
-    def on_validation_batch_start(
-        self, batch: Any, batch_idx: int, dataloader_idx: int
-    ) -> None:
+    def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """
         Called in the validation loop before anything happens for that batch.
 
@@ -153,9 +148,7 @@ def on_validation_batch_start(
         """
         # do something when the batch starts
 
-    def on_validation_batch_end(
-        self, batch: Any, batch_idx: int, dataloader_idx: int
-    ) -> None:
+    def on_validation_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """
         Called in the validation loop after the batch.
 
@@ -166,9 +159,7 @@ def on_validation_batch_end(
         """
         # do something when the batch ends
 
-    def on_test_batch_start(
-        self, batch: Any, batch_idx: int, dataloader_idx: int
-    ) -> None:
+    def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """
         Called in the test loop before anything happens for that batch.
 
@@ -179,9 +170,7 @@ def on_test_batch_start(
         """
         # do something when the batch starts
 
-    def on_test_batch_end(
-        self, batch: Any, batch_idx: int, dataloader_idx: int
-    ) -> None:
+    def on_test_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
         """
         Called in the test loop after the batch.
 
@@ -299,9 +288,7 @@ def on_after_backward(self):
 
         """
 
-    def backward(
-        self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int
-    ) -> None:
+    def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
         """
         Override backward with your own implementation if you need to.
 
@@ -324,13 +311,7 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
         """
         loss.backward()
 
-    def amp_scale_loss(
-        self,
-        unscaled_loss: Tensor,
-        optimizer: Optimizer,
-        optimizer_idx: int,
-        amp_backend: AMPType,
-    ):
+    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: AMPType):
         if amp_backend == AMPType.NATIVE:
             scaled_loss = self.trainer.scaler.scale(unscaled_loss)
         else:
@@ -340,6 +321,7 @@ def amp_scale_loss(
 
 
 class DataHooks:
+
     def prepare_data(self) -> None:
         """
         Use this to download and prepare data.
@@ -430,9 +412,7 @@ def train_dataloader(self):
             return loader
 
         """
-        rank_zero_warn(
-            "`train_dataloader` must be implemented to be used with the Lightning Trainer"
-        )
+        rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
 
     def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
         r"""
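Note on patch 1: the hunks above change formatting only; every hook keeps the
`(batch, batch_idx, dataloader_idx)` signature. For context, a user-side
override of one of these hooks could look like the sketch below, using a
hypothetical LightningModule that is not part of this patch:

    from typing import Any

    import pytorch_lightning as pl


    class MyModel(pl.LightningModule):
        def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
            # invoked by the training loop right before this batch is processed
            print(f"starting train batch {batch_idx} from dataloader {dataloader_idx}")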
From cc636ae2b3136d09ef0398c2d5ef44a75b8839df Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 27 Sep 2020 16:01:32 -0700
Subject: [PATCH 2/8] Store a reference to the trainer on the datamodule

Fixes #3682
---
 pytorch_lightning/core/hooks.py                  | 46 +++++++++++++------
 .../trainer/connectors/data_connector.py         |  3 +-
 tests/core/test_datamodules.py                   | 25 ++++++++++
 3 files changed, 60 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 0694b9923b1fc..b4cfd50819ffc 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Union, List
+from typing import Any, List, Union
 
 import torch
+from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
-from pytorch_lightning.utilities import move_data_to_device, AMPType, rank_zero_warn
 
 try:
     from apex import amp
@@ -28,7 +28,6 @@
 
 
 class ModelHooks:
-
     def setup(self, stage: str):
         """
         Called at the beginning of fit and test.
@@ -113,7 +112,9 @@ def on_pretrain_routine_end(self) -> None:
         """
         # do something at the end of the pretrain routine
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the training loop before anything happens for that batch.
 
@@ -126,7 +127,9 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int)
         """
         # do something when the batch starts
 
-    def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the training loop after the batch.
 
@@ -137,7 +140,9 @@ def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) ->
         """
         # do something when the batch ends
 
-    def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_validation_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the validation loop before anything happens for that batch.
 
@@ -148,7 +153,9 @@ def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
         """
         # do something when the batch starts
 
-    def on_validation_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_validation_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the validation loop after the batch.
 
@@ -159,7 +166,9 @@ def on_validation_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: in
         """
         # do something when the batch ends
 
-    def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_test_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the test loop before anything happens for that batch.
 
@@ -170,7 +179,9 @@ def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -
         """
         # do something when the batch starts
 
-    def on_test_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_test_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the test loop after the batch.
 
@@ -288,7 +299,9 @@ def on_after_backward(self):
 
         """
 
-    def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
+    def backward(
+        self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int
+    ) -> None:
         """
         Override backward with your own implementation if you need to.
 
@@ -311,7 +324,13 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
         """
         loss.backward()
 
-    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: AMPType):
+    def amp_scale_loss(
+        self,
+        unscaled_loss: Tensor,
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        amp_backend: AMPType,
+    ):
         if amp_backend == AMPType.NATIVE:
             scaled_loss = self.trainer.scaler.scale(unscaled_loss)
         else:
@@ -321,7 +340,6 @@ def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: A
 
 
 class DataHooks:
-
     def prepare_data(self) -> None:
         """
         Use this to download and prepare data.
@@ -412,7 +430,9 @@ def train_dataloader(self):
             return loader
 
         """
-        rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn(
+            "`train_dataloader` must be implemented to be used with the Lightning Trainer"
+        )
 
     def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
         r"""
diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index fb9bdbe821691..c5cea5cfc2380 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -101,7 +101,7 @@ def attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None,
         if test_dataloaders is not None:
             model.test_dataloader = _PatchDataLoader(test_dataloaders)
 
-    def attach_datamodule(self, model, datamodule, stage):
+    def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], stage: str) -> None:
 
         # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
         datamodule = datamodule or getattr(model, 'datamodule', None)
@@ -122,6 +122,7 @@ def attach_datamodule(self, model, datamodule, stage):
             model.transfer_batch_to_device = datamodule.transfer_batch_to_device
 
         self.trainer.datamodule = datamodule
+        datamodule.trainer = self.trainer
 
 
 class _PatchDataLoader(object):
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 2fa9c5cf764e2..3fe7aaf286c33 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -254,6 +254,31 @@ def test_full_loop(tmpdir):
     result = result[0]
     assert result['test_acc'] > 0.8
 
+def test_trainer_attached_to_dm(tmpdir):
+    reset_seed()
+
+    dm = TrialMNISTDataModule(tmpdir)
+
+    model = EvalModelTemplate()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=3,
+        weights_summary=None,
+        deterministic=True,
+    )
+
+    # fit model
+    result = trainer.fit(model, dm)
+    assert result == 1
+    assert dm.trainer is not None
+
+    # test
+    result = trainer.test(datamodule=dm)
+    result = result[0]
+    assert dm.trainer is not None
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires multi-GPU machine")
 def test_full_loop_single_gpu(tmpdir):

From fd95161a3d11ebe8c1a36ae468b326539f9d9505 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 27 Sep 2020 16:12:45 -0700
Subject: [PATCH 3/8] Update data_connector.py

---
 pytorch_lightning/trainer/connectors/data_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index c5cea5cfc2380..4836e1352ea35 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -14,7 +14,7 @@
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-from typing import List, Union
+from typing import List, Union, Optional
 from torch.utils.data import DataLoader
 
 from pytorch_lightning.utilities.model_utils import is_overridden
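Note on patch 2: once `attach_datamodule` sets `datamodule.trainer`, a
datamodule's hooks can read trainer state during fit/test, which is what the
new test asserts via `dm.trainer is not None`. A minimal sketch of what this
enables, using a hypothetical datamodule and assuming the patched connector
has already attached it:

    from pytorch_lightning import LightningDataModule


    class MyDataModule(LightningDataModule):
        def setup(self, stage=None):
            # DataConnector.attach_datamodule points self.trainer back at the
            # running Trainer before the setup hook fires during fit/test
            if getattr(self, 'trainer', None) is not None:
                print(f'setup running under a trainer with max_epochs={self.trainer.max_epochs}')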
From 6513166220f65abf04cc1ca6c1437b120f8b218d Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 27 Sep 2020 16:13:11 -0700
Subject: [PATCH 4/8] Update data_connector.py

---
 pytorch_lightning/trainer/connectors/data_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py
index 4836e1352ea35..bab052fc32c5b 100644
--- a/pytorch_lightning/trainer/connectors/data_connector.py
+++ b/pytorch_lightning/trainer/connectors/data_connector.py
@@ -14,7 +14,7 @@
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-from typing import List, Union, Optional
+from typing import List, Optional, Union
 from torch.utils.data import DataLoader
 
 from pytorch_lightning.utilities.model_utils import is_overridden

From 0f3c67021c86e7a0c5fc06218dcee7dc678107c5 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 27 Sep 2020 16:17:58 -0700
Subject: [PATCH 5/8] Update test_datamodules.py

---
 tests/core/test_datamodules.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 3fe7aaf286c33..357412acdde88 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -254,6 +254,7 @@ def test_full_loop(tmpdir):
     result = result[0]
     assert result['test_acc'] > 0.8
 
+
 def test_trainer_attached_to_dm(tmpdir):
     reset_seed()
 

From 6e4b7831562c8f9efa334ea4e47ffeebe6732d99 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 27 Sep 2020 21:44:44 -0700
Subject: [PATCH 6/8] Support more storage backends in trainer.test using best
 weights

Similar to #3692
---
 pytorch_lightning/trainer/trainer.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index b7d76f56646cf..f9785337c2588 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -15,10 +15,8 @@
 import os
 import warnings
 from typing import Dict, Iterable, List, Optional, Union
-
 import torch
 from torch.utils.data import DataLoader
-
 from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule
@@ -53,6 +51,7 @@
 from pytorch_lightning.trainer.connectors.precision_connector import PrecisionConnector
 from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
+from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer import docstrings
 from pytorch_lightning.trainer.properties import TrainerProperties
@@ -569,7 +568,9 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
             )
             return {}
 
-        ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
+        fs = get_filesystem(ckpt_path)
+        with fs.open(ckpt_path) as f:
+            ckpt = torch.load(f, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])
 
         # attach dataloaders
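Note on patch 6: routing the checkpoint read through `get_filesystem` lets
`ckpt_path` point at any storage backend fsspec understands, not only the
local disk. Roughly, the new code path behaves like the sketch below (the
s3:// path is hypothetical; fsspec opens files in binary mode by default,
which is what `torch.load` expects):

    import torch

    from pytorch_lightning.utilities.cloud_io import get_filesystem

    ckpt_path = 's3://my-bucket/checkpoints/best.ckpt'  # hypothetical remote checkpoint
    fs = get_filesystem(ckpt_path)  # filesystem implementation chosen from the URL scheme
    with fs.open(ckpt_path) as f:  # binary mode by default
        ckpt = torch.load(f, map_location=lambda storage, loc: storage)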
From 214bfb5d99fc9c8c1fbbd536ccee23b3bdfd783f Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 27 Sep 2020 21:47:38 -0700
Subject: [PATCH 7/8] Update trainer.py

---
 pytorch_lightning/trainer/trainer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index f9785337c2588..103c8d99503ed 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -15,8 +15,10 @@
 import os
 import warnings
 from typing import Dict, Iterable, List, Optional, Union
+
 import torch
 from torch.utils.data import DataLoader
+
 from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.lightning import LightningModule

From 3fe2a5ec6a2ec52d78846fee2b08596683f79166 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 27 Sep 2020 21:51:38 -0700
Subject: [PATCH 8/8] Update trainer.py

use cloud_io load directly
---
 pytorch_lightning/trainer/trainer.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 103c8d99503ed..fa96550e7bbf3 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -53,7 +53,7 @@
 from pytorch_lightning.trainer.connectors.precision_connector import PrecisionConnector
 from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
-from pytorch_lightning.utilities.cloud_io import get_filesystem
+from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer import docstrings
 from pytorch_lightning.trainer.properties import TrainerProperties
@@ -570,9 +570,7 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
             )
             return {}
 
-        fs = get_filesystem(ckpt_path)
-        with fs.open(ckpt_path) as f:
-            ckpt = torch.load(f, map_location=lambda storage, loc: storage)
+        ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])
 
         # attach dataloaders
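Note on patch 8: `pytorch_lightning.utilities.cloud_io.load` wraps the same
fsspec-based read, so the final revision swaps the hand-rolled `fs.open` for
that shared helper. The call is equivalent to the patch 6 version; a sketch
with a hypothetical path:

    from pytorch_lightning.utilities.cloud_io import load as pl_load

    ckpt = pl_load(
        's3://my-bucket/checkpoints/best.ckpt',  # hypothetical; local paths work too
        map_location=lambda storage, loc: storage,
    )
    # ckpt['state_dict'] can then be passed to model.load_state_dict(...)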