deprecate hpc_load() and integrate it with restore() #7955

Merged · 15 commits · Jun 14, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -186,6 +186,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated default value of `monitor` argument in EarlyStopping callback to enforce `monitor` as a required argument ([#7907](https://github.com/PyTorchLightning/pytorch-lightning/pull/7907))


- Deprecated the use of `CheckpointConnector.hpc_load()` in favor of `CheckpointConnector.restore()` ([#7652](https://github.com/PyTorchLightning/pytorch-lightning/pull/7652))


### Removed

- Removed `ProfilerConnector` ([#7654](https://github.com/PyTorchLightning/pytorch-lightning/pull/7654))
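The deprecation entry above amounts to a one-line swap for downstream code that called the connector directly, as the updated test helpers further down in this diff show. A minimal migration sketch, assuming an already fitted `Trainer`; the helper name `resume_from_hpc_checkpoint` is hypothetical:

```python
def resume_from_hpc_checkpoint(trainer, ckpt_dir: str) -> None:
    """Hypothetical helper illustrating the migration away from hpc_load()."""
    checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_dir)

    # Before this PR (deprecated in v1.4, scheduled for removal in v1.6):
    # trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=False)

    # After this PR: restore() performs the same model and training-state
    # restoration, so the `on_gpu` flag is no longer needed.
    trainer.checkpoint_connector.restore(checkpoint_path)
```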
47 changes: 23 additions & 24 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -21,9 +21,14 @@

import pytorch_lightning
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, DeviceType, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities import (
    _OMEGACONF_AVAILABLE,
    DeviceType,
    rank_zero_deprecation,
    rank_zero_info,
    rank_zero_warn,
)
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS

@@ -45,7 +50,7 @@ def hpc_resume_path(self) -> Optional[str]:
    dir_path_hpc = str(self.trainer.weights_save_path)
    max_version = self.max_ckpt_version_in_folder(dir_path_hpc, "hpc_ckpt_")
    if max_version is not None:
        return f"{dir_path_hpc}/hpc_ckpt_{max_version}.ckpt"
        return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")

def resume_start(self) -> None:
    """
@@ -129,6 +134,10 @@ def restore_model_state(self, model: LightningModule, checkpoint) -> None:
    # hook: give user access to checkpoint if needed.
    model.on_load_checkpoint(checkpoint)

    # call hpc specific hook
    if self.hpc_resume_path is not None:
        model.on_hpc_load(self._loaded_checkpoint)

    # restore model state_dict
    self.trainer.training_type_plugin.load_model_state_dict(checkpoint)

@@ -248,6 +257,7 @@ def restore_lr_schedulers(self) -> None:
# ----------------------------------
# PRIVATE OPS
# ----------------------------------

def hpc_save(self, folderpath: str, logger):
    # make sure the checkpoint folder exists
    folderpath = str(folderpath)  # because the tests pass a path object
@@ -365,29 +375,18 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:

    return checkpoint

def hpc_load(self, checkpoint_path: str, on_gpu: bool):
    """
    Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.
    All restored states are listed in return value description of `dump_checkpoint`.
    """

    # read a checkpoint dictionary object from the 'PyTorch-Lightning checkpoint' file at `checkpoint_path`
    checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

    # acquire the model
    model = self.trainer.lightning_module

    # restore model and datamodule state
    self.restore_model_state(model, checkpoint)

    if self.trainer.root_gpu is not None:
        model.cuda(self.trainer.root_gpu)

    # restore training state
    self.restore_training_state(checkpoint)

    # call hpc specific hook
    model.on_hpc_load(checkpoint)
def hpc_load(self, checkpoint_path: str) -> None:
    """
    Attempts to restore the full training and model state from an HPC checkpoint file.

    .. deprecated::v1.4
        Will be removed in v1.6. Use :meth:`restore` instead.
    """
    rank_zero_deprecation(
        "`CheckpointConnector.hpc_load()` was deprecated in v1.4 and will be removed in v1.6."
        " Use `CheckpointConnector.restore()` instead."
    )
    self.restore(checkpoint_path)

def max_ckpt_version_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
    """List up files in `dir_path` with `name_key`, then yield maximum suffix number.
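A behavioural note on the `restore_model_state()` hunk above: `model.on_hpc_load()` now fires whenever `hpc_resume_path` detects an `hpc_ckpt_*.ckpt`, so models that override the hook keep working without the deprecated `hpc_load()`. A minimal sketch of that flow, mirroring the new `test_hpc_hook_calls` below; `HPCAwareModel` and `run_twice` are hypothetical names and `BoringModel` comes from the test helpers:

```python
from unittest.mock import Mock

from pytorch_lightning import Trainer
from tests.helpers import BoringModel


class HPCAwareModel(BoringModel):
    """Hypothetical model that records when the HPC-load hook fires."""

    def on_hpc_load(self, checkpoint):
        # With this change, restore_model_state() calls this hook whenever an
        # hpc_ckpt_*.ckpt was auto-detected via hpc_resume_path.
        assert "state_dict" in checkpoint
        self.restored_from_hpc = True


def run_twice(tmpdir):
    model = HPCAwareModel()
    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, checkpoint_callback=False, logger=False)
    trainer.fit(model)
    trainer.checkpoint_connector.hpc_save(tmpdir, logger=Mock())  # writes hpc_ckpt_1.ckpt

    # A second run in the same directory picks the HPC checkpoint up automatically
    # and triggers on_hpc_load() through the regular restore path.
    trainer = Trainer(default_root_dir=tmpdir, max_steps=1, checkpoint_callback=False, logger=False)
    trainer.fit(model)
    assert model.restored_from_hpc
```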
13 changes: 13 additions & 0 deletions tests/deprecated_api/test_remove_1-4.py
@@ -66,3 +66,16 @@ def training_step(self, batch, batch_idx):

    with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
        trainer.fit(TestModel())


def test_v1_4_0_deprecated_hpc_load(tmpdir):
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
    )
    trainer.fit(model)
    trainer.checkpoint_connector.hpc_save(tmpdir, trainer.logger)
    checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(str(tmpdir))
    with pytest.deprecated_call(match=r"`CheckpointConnector.hpc_load\(\)` was deprecated in v1.4"):
        trainer.checkpoint_connector.hpc_load(checkpoint_path)
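Outside of pytest, downstream projects can escalate this deprecation to a hard error to flush out remaining `hpc_load()` call sites; a small sketch, assuming the message is emitted through Python's standard `warnings` machinery (which is what `pytest.deprecated_call` above relies on):

```python
import warnings

# Turn the hpc_load() deprecation into an exception so leftover call sites fail
# loudly instead of silently delegating to restore().
warnings.filterwarnings(
    "error",
    message=r".*CheckpointConnector\.hpc_load\(\).*was deprecated",
    category=DeprecationWarning,
)
```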
2 changes: 1 addition & 1 deletion tests/helpers/pipelines.py
@@ -91,7 +91,7 @@ def run_model_test(
    trainer.checkpoint_connector.hpc_save(save_dir, logger)
    # test HPC loading
    checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(save_dir)
    trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
    trainer.checkpoint_connector.restore(checkpoint_path)


@torch.no_grad()
2 changes: 1 addition & 1 deletion tests/models/data/horovod/train_default_model.py
@@ -87,7 +87,7 @@ def training_epoch_end(self, outputs) -> None:
    trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
    # test HPC loading
    checkpoint_path = trainer.checkpoint_connector.get_max_ckpt_path_from_folder(ckpt_path)
    trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)
    trainer.checkpoint_connector.restore(checkpoint_path)

    if on_gpu:
        trainer = Trainer(gpus=1, accelerator='horovod', max_epochs=1)
13 changes: 13 additions & 0 deletions tests/trainer/connectors/test_callback_connector.py
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import Mock

155 changes: 155 additions & 0 deletions tests/trainer/connectors/test_checkpoint_connector.py
@@ -0,0 +1,155 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest.mock import Mock

import torch

from pytorch_lightning import Trainer
from tests.helpers import BoringModel


class HPCHookdedModel(BoringModel):

    def __init__(self):
        super().__init__()
        self.hpc_save_called = 0
        self.hpc_load_called = 0

    def on_hpc_save(self, checkpoint):
        assert "state_dict" in checkpoint
        self.hpc_save_called += 1

    def on_hpc_load(self, checkpoint):
        assert "state_dict" in checkpoint
        self.hpc_load_called += 1


def test_hpc_hook_calls(tmpdir):
    model = HPCHookdedModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        checkpoint_callback=False,
        logger=False,
    )
    trainer.fit(model)
    connector = trainer.checkpoint_connector
    connector.hpc_save(tmpdir, logger=Mock())
    assert model.hpc_save_called == 1
    assert model.hpc_load_called == 0

    # new training run, restore from hpc checkpoint file automatically
    assert set(os.listdir(tmpdir)) == {"hpc_ckpt_1.ckpt"}
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        checkpoint_callback=False,
        logger=False,
    )
    trainer.fit(model)
    assert model.hpc_save_called == 1
    assert model.hpc_load_called == 1


def test_preloaded_checkpoint_lifecycle(tmpdir):
    """ Tests that the preloaded checkpoint contents get cleared from memory when they are not required anymore. """
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
    )
    trainer.fit(model)

    connector = trainer.checkpoint_connector

    assert not trainer.resume_from_checkpoint
    assert not connector.resume_checkpoint_path
    assert not connector._loaded_checkpoint

    connector.resume_start()
    assert not connector.resume_checkpoint_path
    assert not connector._loaded_checkpoint
    connector.resume_end()
    assert not connector.resume_checkpoint_path
    assert not connector._loaded_checkpoint

    ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path)
    connector = trainer.checkpoint_connector
    connector.resume_start()
    assert connector.resume_checkpoint_path == ckpt_path
    assert connector._loaded_checkpoint
    assert isinstance(connector._loaded_checkpoint, dict)
    connector.resume_end()
    assert not connector.resume_checkpoint_path
    assert not connector._loaded_checkpoint


def test_hpc_restore_attempt(tmpdir):
    """ Test that restore() attempts to restore the hpc_ckpt with highest priority. """
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        checkpoint_callback=False,
        logger=False,
    )
    trainer.fit(model)

    hpc_ckpt_path = tmpdir / "hpc_ckpt_3.ckpt"
    trainer.save_checkpoint(hpc_ckpt_path)
    assert os.listdir(tmpdir) == ["hpc_ckpt_3.ckpt"]

    # set weights to zero
    for param in model.parameters():
        torch.nn.init.constant_(param, 0)

    # case 1: restore hpc first, no explicit resume path provided
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=2,
        checkpoint_callback=False,
        logger=False,
    )
    trainer.fit(model)

    for param in model.parameters():
        assert param.abs().sum() > 0
        torch.nn.init.constant_(param, 0)

    # case 2: explicit resume path provided, restore hpc anyway
    trainer = Trainer(default_root_dir=tmpdir, max_steps=3, resume_from_checkpoint="not existing")
    trainer.fit(model)

    for param in model.parameters():
        assert param.abs().sum() > 0


def test_hpc_max_ckpt_version(tmpdir):
    """ Test that the CheckpointConnector is able to find the hpc checkpoint file with the highest version. """
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
    )
    trainer.fit(model)
    trainer.save_checkpoint(tmpdir / "hpc_ckpt.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_0.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")

    assert trainer.checkpoint_connector.hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
    assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33
    assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None