Refactor access to trainer attributes in LightningModule (#5730)
* rank access

* tests for property

* weakref

* logger

* changelog

* torchscript

* changelog

* chlog

* .

* amp

* yapf

* flake8

Co-authored-by: Jirka Borovec <[email protected]>
awaelchli and Borda authored Feb 1, 2021
1 parent 9064b83 commit 344f3a9
Showing 6 changed files with 78 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -107,6 +107,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))


- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))


- Refactored Accelerators and Plugins
* Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
* Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))
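
The entry above turns `global_rank`, `local_rank`, and `logger` into read-only properties derived from the attached Trainer. A minimal sketch of that pattern (illustrative only, not the Lightning source):

```python
# Minimal sketch of the read-only property pattern described in the changelog
# entry above; the names mirror the diff, but this is not the Lightning source.
class Module:
    def __init__(self):
        self.trainer = None  # attached later by the Trainer

    @property
    def global_rank(self) -> int:
        # Derived from the attached trainer, with a safe default when detached.
        return self.trainer.global_rank if self.trainer else 0


m = Module()
print(m.global_rank)  # 0: no trainer attached yet

try:
    m.global_rank = 5
except AttributeError as err:
    print(err)  # read-only: a property without a setter rejects assignment
```
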
21 changes: 18 additions & 3 deletions pytorch_lightning/core/lightning.py
@@ -67,6 +67,9 @@ class LightningModule(
"current_epoch",
"global_step",
"running_stage",
"global_rank",
"local_rank",
"logger",
] + DeviceDtypeModuleMixin.__jit_unused_properties__

def __init__(self, *args, **kwargs):
@@ -83,9 +86,6 @@ def __init__(self, *args, **kwargs):
#: Pointer to the trainer object
self.trainer = None

#: Pointer to the logger object
self.logger = None

self._distrib_type = None
self._device_type = None

@@ -132,6 +132,16 @@ def global_step(self) -> int:
"""Total training batches seen across all epochs"""
return self.trainer.global_step if self.trainer else 0

@property
def global_rank(self) -> int:
""" The index of the current process across all nodes and devices. """
return self.trainer.global_rank if self.trainer else 0

@property
def local_rank(self) -> int:
""" The index of the current process within a single node. """
return self.trainer.local_rank if self.trainer else 0

@example_input_array.setter
def example_input_array(self, example: Any) -> None:
self._example_input_array = example
@@ -163,6 +173,11 @@ def automatic_optimization(self) -> bool:
def automatic_optimization(self, automatic_optimization: bool) -> None:
self._automatic_optimization = automatic_optimization

@property
def logger(self):
""" Reference to the logger object in the Trainer. """
return self.trainer.logger if self.trainer else None

def print(self, *args, **kwargs) -> None:
r"""
Prints only from process 0. Use this in any distributed mode to log only once.
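
The three new names are also added to `__jit_unused_properties__`, a list that TorchScript checks so these properties are skipped during scripting (they reach into the non-scriptable trainer). A hedged sketch of that mechanism with a toy module; exact behavior depends on the PyTorch version:

```python
# Hedged sketch of why the new properties are listed in __jit_unused_properties__:
# TorchScript skips properties named in this list when scripting a module, which
# matters here because they reach into the non-scriptable trainer object.
# Illustrative only; not Lightning code, and behavior may vary across versions.
import torch
from torch import nn


class TinyModule(nn.Module):
    __jit_unused_properties__ = ["local_rank"]

    def __init__(self):
        super().__init__()
        self.trainer = None  # would not be scriptable if the property were compiled
        self.linear = nn.Linear(4, 4)

    @property
    def local_rank(self) -> int:
        return self.trainer.local_rank if self.trainer else 0

    def forward(self, x):
        return self.linear(x)


scripted = torch.jit.script(TinyModule())
print(scripted(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```
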
6 changes: 2 additions & 4 deletions pytorch_lightning/trainer/connectors/model_connector.py
@@ -17,6 +17,7 @@
Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
"""
from weakref import proxy


class ModelConnector:
@@ -30,17 +31,14 @@ def copy_trainer_model_properties(self, model):
self.trainer.train_loop.automatic_optimization = automatic_optimization

for m in [model, ref_model]:
m.trainer = self.trainer
m.logger = self.trainer.logger
m.trainer = proxy(self.trainer)
m._device_type = str(self.trainer._device_type)
m._distrib_type = str(self.trainer._distrib_type)
m.use_amp = self.trainer.amp_backend is not None
m.testing = self.trainer.testing
m.tpu_local_core_rank = self.trainer.tpu_local_core_rank
m.tpu_global_core_rank = self.trainer.tpu_global_core_rank
m.precision = self.trainer.precision
m.global_rank = self.trainer.global_rank
m.local_rank = self.trainer.local_rank

def get_model(self):
return self._get_reference_model(self.trainer.model)
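
In the connector, the model now holds `proxy(self.trainer)` instead of the trainer itself, which avoids a strong trainer/model reference cycle while still forwarding attribute access and comparisons to the real trainer (the `assert model.trainer == trainer` added in `test_lr_finder.py` below relies on that forwarding). A small standalone sketch of `weakref.proxy` semantics with toy classes:

```python
# Standalone sketch of the weakref.proxy semantics used above (toy classes,
# not Lightning's): the proxy behaves like the trainer but does not own it.
import weakref


class Trainer:
    global_rank = 0


class Model:
    trainer = None


trainer = Trainer()
model = Model()
model.trainer = weakref.proxy(trainer)

print(model.trainer.global_rank)  # 0: attribute access is forwarded to the trainer
print(model.trainer == trainer)   # True: comparisons are forwarded to the referent

del trainer  # the proxy holds no strong reference, so the trainer can be freed
try:
    model.trainer.global_rank
except ReferenceError:
    print("trainer was garbage collected; the proxy is now dead")
```
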
2 changes: 0 additions & 2 deletions pytorch_lightning/tuner/tuning.py
@@ -50,12 +50,10 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule):
val_dataloaders=val_dataloaders,
datamodule=datamodule,
)
model.logger = self.trainer.logger # reset logger binding

# Run learning rate finder:
if self.trainer.auto_lr_find:
self.internal_find_lr(model)
model.logger = self.trainer.logger # reset logger binding

def scale_batch_size(
self,
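
The two `model.logger = self.trainer.logger` re-binding lines become unnecessary: `logger` is now looked up through `model.trainer` on every access, so it cannot go stale when a tuner swaps the trainer's logger. A toy illustration of that design choice (not Lightning code):

```python
# Toy illustration (not Lightning code) of why the explicit "reset logger
# binding" lines can be dropped: a property resolves the logger dynamically.
class Trainer:
    def __init__(self, logger):
        self.logger = logger


class Model:
    def __init__(self):
        self.trainer = None

    @property
    def logger(self):
        return self.trainer.logger if self.trainer else None


model = Model()
model.trainer = Trainer(logger="tensorboard")
print(model.logger)  # 'tensorboard'

model.trainer.logger = "csv"  # e.g. a tuner temporarily swaps the logger
print(model.logger)           # 'csv' -- stays in sync without any manual reset
```
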
54 changes: 53 additions & 1 deletion tests/core/test_lightning_module.py
@@ -11,16 +11,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
from unittest.mock import patch, Mock

import pytest
from torch.optim import Adam, SGD

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel


def test_property_current_epoch():
""" Test that the current_epoch in LightningModule is accessible via the Trainer. """
model = BoringModel()
assert model.current_epoch == 0

trainer = Mock(current_epoch=123)
model.trainer = trainer
assert model.current_epoch == 123


def test_property_global_step():
""" Test that the global_step in LightningModule is accessible via the Trainer. """
model = BoringModel()
assert model.global_step == 0

trainer = Mock(global_step=123)
model.trainer = trainer
assert model.global_step == 123


def test_property_global_rank():
""" Test that the global rank in LightningModule is accessible via the Trainer. """
model = BoringModel()
assert model.global_rank == 0

trainer = Mock(global_rank=123)
model.trainer = trainer
assert model.global_rank == 123


def test_property_local_rank():
""" Test that the local rank in LightningModule is accessible via the Trainer. """
model = BoringModel()
assert model.local_rank == 0

trainer = Mock(local_rank=123)
model.trainer = trainer
assert model.local_rank == 123


def test_property_logger(tmpdir):
""" Test that the logger in LightningModule is accessible via the Trainer. """
model = BoringModel()
assert model.logger is None

logger = TensorBoardLogger(tmpdir)
trainer = Mock(logger=logger)
model.trainer = trainer
assert model.logger == logger


def test_automatic_optimization(tmpdir):
class TestModel(BoringModel):
def optimizer_step(self, *_, **__):
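
The four Mock-based property tests above follow the same pattern; a possible consolidation (not part of this commit) with `pytest.mark.parametrize` would look roughly like this:

```python
# Hypothetical consolidation of the Mock-based property tests above; the four
# separate tests in the diff cover the same behavior and remain authoritative.
import pytest
from unittest.mock import Mock

from tests.base import BoringModel


@pytest.mark.parametrize("name", ["current_epoch", "global_step", "global_rank", "local_rank"])
def test_trainer_int_properties(name):
    model = BoringModel()
    assert getattr(model, name) == 0     # detached module falls back to 0

    model.trainer = Mock(**{name: 123})  # mock trainer exposing the attribute
    assert getattr(model, name) == 123   # the property reads through the trainer
```
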
2 changes: 2 additions & 0 deletions tests/trainer/test_lr_finder.py
@@ -90,6 +90,8 @@ def test_trainer_reset_correctly(tmpdir):
assert attributes_before[key] == attributes_after[key], \
f'Attribute {key} was not reset correctly after learning rate finder'

assert model.trainer == trainer


@pytest.mark.parametrize('use_hparams', [False, True])
def test_trainer_arg_bool(tmpdir, use_hparams):
