Skip to content

Commit

Permalink
Properly set LightningModule.device after model replacement (#7188)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholí <[email protected]>
  • Loading branch information
tchaton and carmocca authored Apr 23, 2021
1 parent 8439aea commit f58865a
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed parsing for pre-release package versions ([#6999](https://github.com/PyTorchLightning/pytorch-lightning/pull/6999))


- Fixed resetting device after `fitting/evaluating/predicting` ([#7188](https://github.com/PyTorchLightning/pytorch-lightning/pull/7188))


## [1.2.7] - 2021-04-06

### Fixed
Expand Down
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ def on_train_start(self) -> None:
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()

def on_train_end(self) -> None:
def teardown(self) -> None:
self.lightning_module.cpu()

# clean up memory
self.model.cpu()
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()

Expand Down
6 changes: 0 additions & 6 deletions pytorch_lightning/utilities/device_dtype_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

from typing import Optional, Union

import torch
from torch.nn import Module

Expand Down Expand Up @@ -47,11 +46,6 @@ def device(self) -> Union[str, torch.device]:

return device

@device.setter
def device(self, new_device: Union[str, torch.device]):
# Necessary to avoid infinite recursion
raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')

@parameter_validation
def to(self, *args, **kwargs) -> Module:
"""Moves and/or casts the parameters and buffers.
Expand Down
23 changes: 23 additions & 0 deletions tests/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
from unittest.mock import Mock

import pytest
import torch
from torch import nn
from torch.optim import Adam, SGD

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


def test_property_current_epoch():
Expand Down Expand Up @@ -358,3 +360,24 @@ def configure_optimizers(self):
)

trainer.fit(model)


@RunIf(min_gpus=1)
def test_device_placement(tmpdir):

model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=1)
trainer.fit(model)

def assert_device(device: torch.device) -> None:
assert model.device == device
for p in model.parameters():
assert p.device == device

assert_device(torch.device("cpu"))
model.to(torch.device("cuda:0"))
assert_device(torch.device("cuda:0"))
trainer.test(model)
assert_device(torch.device("cpu"))
trainer.predict(model, dataloaders=model.train_dataloader())
assert_device(torch.device("cpu"))

0 comments on commit f58865a

Please sign in to comment.