From 344f3a984aeec39e2ce0a27a82290897307de751 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 1 Feb 2021 15:28:17 +0100
Subject: [PATCH] Refactor access to trainer attributes in LightningModule (#5730)

* rank access
* tests for property
* weekref
* logger
* changelog
* torchscript
* changelog
* chlog
* .
* amp
* yapf
* flake8

Co-authored-by: Jirka Borovec
---
 CHANGELOG.md                                  |  3 ++
 pytorch_lightning/core/lightning.py           | 21 ++++++--
 .../trainer/connectors/model_connector.py     |  6 +--
 pytorch_lightning/tuner/tuning.py             |  2 -
 tests/core/test_lightning_module.py           | 54 ++++++++++++++++++-
 tests/trainer/test_lr_finder.py               |  2 +
 6 files changed, 78 insertions(+), 10 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index dc381b3983753..d55a33f960589 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -107,6 +107,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))
 
+- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))
+
+
 - Refactored Accelerators and Plugins
   * Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
   * Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index c453bd5d607d6..965dba8ad3a30 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -67,6 +67,9 @@ class LightningModule(
         "current_epoch",
         "global_step",
         "running_stage",
+        "global_rank",
+        "local_rank",
+        "logger",
     ] + DeviceDtypeModuleMixin.__jit_unused_properties__
 
     def __init__(self, *args, **kwargs):
@@ -83,9 +86,6 @@ def __init__(self, *args, **kwargs):
         #: Pointer to the trainer object
         self.trainer = None
 
-        #: Pointer to the logger object
-        self.logger = None
-
         self._distrib_type = None
         self._device_type = None
 
@@ -132,6 +132,16 @@ def global_step(self) -> int:
         """Total training batches seen across all epochs"""
         return self.trainer.global_step if self.trainer else 0
 
+    @property
+    def global_rank(self) -> int:
+        """ The index of the current process across all nodes and devices. """
+        return self.trainer.global_rank if self.trainer else 0
+
+    @property
+    def local_rank(self) -> int:
+        """ The index of the current process within a single node. """
+        return self.trainer.local_rank if self.trainer else 0
+
     @example_input_array.setter
     def example_input_array(self, example: Any) -> None:
         self._example_input_array = example
@@ -163,6 +173,11 @@ def automatic_optimization(self) -> bool:
     def automatic_optimization(self, automatic_optimization: bool) -> None:
         self._automatic_optimization = automatic_optimization
 
+    @property
+    def logger(self):
+        """ Reference to the logger object in the Trainer. """
+        return self.trainer.logger if self.trainer else None
+
     def print(self, *args, **kwargs) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.
diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py
index a3759d1075ee5..673e8765ed51f 100644
--- a/pytorch_lightning/trainer/connectors/model_connector.py
+++ b/pytorch_lightning/trainer/connectors/model_connector.py
@@ -17,6 +17,7 @@
 Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
 
 """
+from weakref import proxy
 
 
 class ModelConnector:
@@ -30,8 +31,7 @@ def copy_trainer_model_properties(self, model):
         self.trainer.train_loop.automatic_optimization = automatic_optimization
 
         for m in [model, ref_model]:
-            m.trainer = self.trainer
-            m.logger = self.trainer.logger
+            m.trainer = proxy(self.trainer)
             m._device_type = str(self.trainer._device_type)
             m._distrib_type = str(self.trainer._distrib_type)
             m.use_amp = self.trainer.amp_backend is not None
@@ -39,8 +39,6 @@ def copy_trainer_model_properties(self, model):
             m.tpu_local_core_rank = self.trainer.tpu_local_core_rank
             m.tpu_global_core_rank = self.trainer.tpu_global_core_rank
             m.precision = self.trainer.precision
-            m.global_rank = self.trainer.global_rank
-            m.local_rank = self.trainer.local_rank
 
     def get_model(self):
         return self._get_reference_model(self.trainer.model)
diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index dae3fed868520..0567399970ae7 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -50,12 +50,10 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule):
                 val_dataloaders=val_dataloaders,
                 datamodule=datamodule,
             )
-            model.logger = self.trainer.logger  # reset logger binding
 
         # Run learning rate finder:
         if self.trainer.auto_lr_find:
             self.internal_find_lr(model)
-            model.logger = self.trainer.logger  # reset logger binding
 
     def scale_batch_size(
         self,
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index 9cea8cf28c07f..17d25b6c9b75a 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -11,16 +11,68 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import patch
+from unittest.mock import patch, Mock
 
 import pytest
 from torch.optim import Adam, SGD
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import BoringModel
 
 
+def test_property_current_epoch():
+    """ Test that the current_epoch in LightningModule is accessible via the Trainer. """
+    model = BoringModel()
+    assert model.current_epoch == 0
+
+    trainer = Mock(current_epoch=123)
+    model.trainer = trainer
+    assert model.current_epoch == 123
+
+
+def test_property_global_step():
+    """ Test that the global_step in LightningModule is accessible via the Trainer. """
+    model = BoringModel()
+    assert model.global_step == 0
+
+    trainer = Mock(global_step=123)
+    model.trainer = trainer
+    assert model.global_step == 123
+
+
+def test_property_global_rank():
+    """ Test that the global rank in LightningModule is accessible via the Trainer. """
+    model = BoringModel()
+    assert model.global_rank == 0
+
+    trainer = Mock(global_rank=123)
+    model.trainer = trainer
+    assert model.global_rank == 123
+
+
+def test_property_local_rank():
+    """ Test that the local rank in LightningModule is accessible via the Trainer. """
+    model = BoringModel()
+    assert model.local_rank == 0
+
+    trainer = Mock(local_rank=123)
+    model.trainer = trainer
+    assert model.local_rank == 123
+
+
+def test_property_logger(tmpdir):
+    """ Test that the logger in LightningModule is accessible via the Trainer. """
+    model = BoringModel()
+    assert model.logger is None
+
+    logger = TensorBoardLogger(tmpdir)
+    trainer = Mock(logger=logger)
+    model.trainer = trainer
+    assert model.logger == logger
+
+
 def test_automatic_optimization(tmpdir):
     class TestModel(BoringModel):
         def optimizer_step(self, *_, **__):
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index 3b59095fcf393..228246fb18e4d 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -90,6 +90,8 @@ def test_trainer_reset_correctly(tmpdir):
         assert attributes_before[key] == attributes_after[key], \
             f'Attribute {key} was not reset correctly after learning rate finder'
 
+    assert model.trainer == trainer
+
 
 @pytest.mark.parametrize('use_hparams', [False, True])
 def test_trainer_arg_bool(tmpdir, use_hparams):