
Save the ResultCollection in the loops state dict #8641

Merged: 35 commits, Aug 2, 2021
Changes from 7 commits

Commits (35)
f2f6858 wip (tchaton, Jul 29, 2021)
4307a05 resolve some issues (tchaton, Jul 29, 2021)
e63c560 add ResultCollection (tchaton, Jul 30, 2021)
bd91665 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 30, 2021)
ac9e9f1 add comments (tchaton, Jul 30, 2021)
6e609f1 Merge branch 'add_support_for_logging' of https://github.com/PyTorchL… (tchaton, Jul 30, 2021)
46078e9 update changelog (tchaton, Jul 30, 2021)
f1d50d2 wip (tchaton, Jul 29, 2021)
aeaeee6 resolve some issues (tchaton, Jul 29, 2021)
a9368e9 add ResultCollection (tchaton, Jul 30, 2021)
c174917 add comments (tchaton, Jul 30, 2021)
3b7370a [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Jul 30, 2021)
a412825 update changelog (tchaton, Jul 30, 2021)
41a5156 Reuse key definition (carmocca, Jul 30, 2021)
df2beae updates on comments (tchaton, Jul 30, 2021)
8386cfc Merge branch 'add_support_for_logging' of https://github.com/PyTorchL… (tchaton, Jul 30, 2021)
d978b2c update (tchaton, Jul 30, 2021)
a493d1f Indentation and comments (carmocca, Jul 30, 2021)
0aa5659 apply comments (tchaton, Jul 30, 2021)
7f00a8b update on comments (tchaton, Aug 1, 2021)
768e4ea resolve tests (tchaton, Aug 1, 2021)
dfbb051 Merge branch 'master' into add_support_for_logging (tchaton, Aug 1, 2021)
7416272 typo (tchaton, Aug 1, 2021)
07888e2 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 1, 2021)
7349ad1 update (tchaton, Aug 1, 2021)
b99905b Merge branch 'add_support_for_logging' of https://github.com/PyTorchL… (tchaton, Aug 1, 2021)
e451594 Update pytorch_lightning/trainer/connectors/checkpoint_connector.py (tchaton, Aug 2, 2021)
ec61374 Refactor test (carmocca, Aug 2, 2021)
d7c72de add comments (tchaton, Aug 2, 2021)
5d18496 nit (carmocca, Aug 2, 2021)
a579b99 update (tchaton, Aug 2, 2021)
0e377dd [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 2, 2021)
5d36135 Merge branch 'master' into add_support_for_logging (tchaton, Aug 2, 2021)
b7ce5ad Merge branch 'master' into add_support_for_logging (tchaton, Aug 2, 2021)
923c74b [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 2, 2021)
Files changed
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -12,6 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


- Added `ResultCollection` state_dict to Loop `state_dict` and support for distributed reload. ([#8641](https://github.com/PyTorchLightning/pytorch-lightning/pull/8641))


-


@@ -28,7 +31,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Load ckpt path when model provided in validate/test/predict ([#8352](https://github.com/PyTorchLightning/pytorch-lightning/pull/8352)))



- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


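Not part of the diff: a rough sketch of the shape of the new `loops` entry in a fault-tolerant checkpoint, to make the changelog line above concrete. The nested key names, for example `epoch_loop._results`, are assumptions inferred from this PR's code rather than an exact dump.

```python
# Illustrative only: approximate layout of checkpoint["loops"] after this PR.
# Key names such as "epoch_loop._results" are assumptions, not an exact dump.
checkpoint_loops = {
    "fit_loop": {
        "state_dict": {},                 # Loop.on_save_checkpoint() payload
        "epoch_loop.batch_progress": {},  # BaseProgress states, one entry per progress attribute
        "epoch_loop._results": {          # new: ResultCollection.state_dict(), synced before saving
            "training_step.tracking": "<ResultMetric state>",
        },
    },
    "validate_loop": {"state_dict": {}},
    # "test_loop" and "predict_loop" follow the same pattern
}

for loop_name, loop_state in checkpoint_loops.items():
    print(loop_name, sorted(loop_state))
```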
42 changes: 39 additions & 3 deletions pytorch_lightning/loops/base.py
@@ -16,8 +16,10 @@
from typing import Any, Dict, Optional

from deprecate import void
from torchmetrics import Metric

import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress, Progress
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -177,21 +179,55 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] =
destination[prefix + k] = v.state_dict()
elif isinstance(v, Loop):
v.state_dict(destination, prefix + k + ".")
elif isinstance(v, ResultCollection):
# sync / unsync metrics
v.sync()
destination[prefix + k] = v.state_dict()
v.unsync()

return destination

def load_state_dict(self, state_dict: Dict, prefix: str = "", restart_progress: bool = True) -> None:
def load_state_dict(
self,
state_dict: Dict,
prefix: str = "",
restart_progress: bool = True,
metrics: Optional[Dict[str, Metric]] = None,
) -> None:
"""Loads the state of this loop and all its children."""
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress)
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics)
for k, v in self.__dict__.items():
if isinstance(v, Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)

def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: bool) -> None:
def _load_from_state_dict(
self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None
) -> None:
for k, v in self.__dict__.items():
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[prefix + k])
if restart_progress:
apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())

elif isinstance(v, ResultCollection):
if isinstance(self.trainer, pl.Trainer) and getattr(self.trainer, "lightning_module", None) is not None:
metric_attributes = {
name: module
for name, module in self.trainer.lightning_module.named_modules()
if isinstance(module, Metric)
}
if metrics:
metric_attributes.update(metrics)

# re-attach metrics references
v.load_state_dict(
state_dict[prefix + k],
metrics=metric_attributes,
sync_fn=self.trainer.training_type_plugin.reduce,
)

if not self.trainer.is_global_zero:
v.reset(metrics=False)

self.on_load_checkpoint(state_dict[prefix + "state_dict"])
self.restarting = True
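Not part of the PR: a minimal sketch of the round trip the two methods above enable, assuming pytorch_lightning at this branch's revision plus torch are installed. The tiny module, trainer flags, and dataset below are illustrative only; `PL_FAULT_TOLERANT_TRAINING=1` is set because loop and result state is only persisted in fault-tolerant mode.

```python
# Minimal sketch (assumptions noted above): any ResultCollection attribute of a
# loop is synced, serialized, and unsynced by `state_dict()`, and gets its
# torchmetrics references re-attached by `load_state_dict()`.
import os

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"


class TinyModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).mean()
        self.log("loss", loss, on_step=True, on_epoch=True)  # fills the loop's ResultCollection
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=0, logger=False, checkpoint_callback=False)
trainer.fit(TinyModule(), DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4))

# Serialize the fit loop (progress counters plus logged results) and restore it in place.
loop_state = trainer.fit_loop.state_dict()
trainer.fit_loop.load_state_dict(loop_state, restart_progress=True)
```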
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -411,7 +411,8 @@ def _share_information_to_prevent_deadlock(self):
self._share_pids()

# remove `PL_DDP_SYNC_TMPDIR` from os.environ
self._sync_dir = os.environ.pop("PL_DDP_SYNC_TMPDIR", None)
# FIXME: Add better support for deadlock detection. Changed TMPDIR at on every trainer.{call_fn}.
self._sync_dir = os.getenv("PL_DDP_SYNC_TMPDIR")

def _share_pids(self):
"""
32 changes: 28 additions & 4 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -15,9 +15,10 @@
import os
import re
from pathlib import Path
from typing import Optional, Union
from typing import Any, Dict, Optional, Union

import torch
from torchmetrics import Metric

import pytorch_lightning as pl
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn
@@ -141,6 +142,12 @@ def restore_model(self) -> None:
# restore model state_dict
self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint)

# reset metrics states on non-rank 0 as the states have been synced on-saving.
if not self.trainer.is_global_zero:
for module in self.trainer.lightning_module.modules():
if isinstance(module, Metric):
module.reset()

def restore_model_weights(self, checkpoint_path: Optional[Union[str, Path]]) -> None:
"""Restore only the model weights."""
checkpoint = self._loaded_checkpoint
@@ -212,7 +219,7 @@ def restore_loops(self) -> None:
" consider using an end of epoch checkpoint."
)

state_dict = self._loaded_checkpoint.get("loops")
state_dict = self._loaded_checkpoint.pop("loops", None)
if state_dict:
self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"])
self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"])
@@ -338,7 +345,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
"epoch": current_epoch,
"global_step": global_step,
"pytorch-lightning_version": pl.__version__,
"state_dict": self.trainer.accelerator.lightning_module_state_dict(),
"state_dict": self._get_lightning_module_state_dict(),
}
if _fault_tolerant_enabled():
checkpoint["loops"] = self._get_loops_state_dict()
@@ -440,7 +447,24 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None:
_checkpoint = self.dump_checkpoint(weights_only)
self.trainer.accelerator.save_checkpoint(_checkpoint, filepath)

def _get_loops_state_dict(self):
def _get_lightning_module_state_dict(self) -> Dict[str, torch.Tensor]:
if _fault_tolerant_enabled():
metrics = [m for m in self.trainer.lightning_module.modules() if isinstance(m, Metric)]
for metric in metrics:
metric.persistent(True)
if not metric._is_synced:
metric.sync()

state_dict = self.trainer.accelerator.lightning_module_state_dict()

if _fault_tolerant_enabled():
for metric in metrics:
if metric._is_synced:
metric.unsync()

return state_dict

def _get_loops_state_dict(self) -> Dict[str, Any]:
return {
"fit_loop": self.trainer.fit_loop.state_dict(),
"validate_loop": self.trainer.validate_loop.state_dict(),
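A standalone sketch (not code from this diff) of the pattern `_get_lightning_module_state_dict` relies on: a torchmetrics state only lands in the module `state_dict` once it is marked persistent, and syncing it beforehand is what lets rank 0 save the reduced value under DDP. This example is single-process, so the `sync()`/`unsync()` pair is omitted, and the default persistence flag may differ across torchmetrics versions.

```python
# Sketch of the persistent-metric-state pattern; single process, no DDP sync.
import torch
from torchmetrics import Metric


class SumMetric(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("total", torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value):
        self.total += value

    def compute(self):
        return self.total


module = torch.nn.Module()
module.metric = SumMetric()   # Metric is an nn.Module, so it registers as a submodule
module.metric(torch.tensor(2.0))

print("before persistent(True):", "metric.total" in module.state_dict())  # typically False
module.metric.persistent(True)
print("after persistent(True):", module.state_dict()["metric.total"])     # tensor(2.)
```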
pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -251,13 +251,14 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({state})"

def __getstate__(self, drop_value: bool = False) -> dict:
skip = ["update", "compute", "_update_signature"]
skip = ["update", "compute", "_update_signature", "_cache"]
if not self.is_tensor and drop_value:
# Avoid serializing ResultMetrics which are passed Metrics
skip.append("value")
d = {k: v for k, v in self.__dict__.items() if k not in skip}
d["meta"] = d["meta"].__getstate__()
d["_class"] = self.__class__.__name__
d["_is_synced"] = False # don't consider the state as synced on reload
return d

def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
@@ -604,6 +605,16 @@ def cpu(self) -> "ResultCollection":
"""Move all data to CPU."""
return self.to(device="cpu")

def sync(self) -> None:
for result_metric in self.result_metrics:
if result_metric.is_tensor and not result_metric._is_synced:
result_metric.sync()

def unsync(self) -> None:
for result_metric in self.result_metrics:
if result_metric.is_tensor and result_metric._is_synced:
result_metric.unsync()

def __str__(self) -> str:
return f"{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})"

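A toy illustration (not PL code) of the serialization choice in `__getstate__` above: transient fields are dropped from the pickled state and the sync flag is always reset, so a reloaded metric never believes it is still synced. The class and attribute names below are made up for the example.

```python
import pickle


class ToyResultMetric:
    def __init__(self):
        self.value = 7
        self._cache = object()     # transient per-process cache, not worth persisting
        self._is_synced = True

    def __getstate__(self):
        skip = {"_cache"}
        state = {k: v for k, v in self.__dict__.items() if k not in skip}
        state["_is_synced"] = False  # never trust a stale sync flag after reload
        return state


restored = pickle.loads(pickle.dumps(ToyResultMetric()))
assert restored.value == 7
assert restored._is_synced is False
assert not hasattr(restored, "_cache")
```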
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1310,7 +1310,7 @@ def _log_device_info(self) -> None:
)

def _on_expection(self):
if not self.is_global_zero or not _fault_tolerant_enabled():
if not _fault_tolerant_enabled():
return
# save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
133 changes: 133 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
@@ -18,6 +18,7 @@
import re
import time
from argparse import Namespace
from contextlib import suppress
from datetime import timedelta
from logging import INFO
from pathlib import Path
@@ -31,12 +32,14 @@
import yaml
from omegaconf import Container, OmegaConf
from torch import optim
from torchmetrics import Metric

import pytorch_lightning as pl
import tests.helpers.utils as tutils
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
@@ -1220,3 +1223,133 @@ def test_trainer_checkpoint_callback_bool(tmpdir):
mc = ModelCheckpoint(dirpath=tmpdir)
with pytest.raises(MisconfigurationException, match="Invalid type provided for checkpoint_callback"):
Trainer(checkpoint_callback=mc)


class DummyMetric(Metric):
def __init__(self):
super().__init__()
self.add_state("sum", torch.tensor(0), dist_reduce_fx=torch.sum)
self.add_state("count", torch.tensor(0), dist_reduce_fx=torch.sum)

def update(self, increment):
self.sum += increment
self.count += 1

def compute(self):
return self.sum // self.count

def __repr__(self) -> str:
return f"{self.__class__.__name__}(sum={self.sum}, count={self.count})"


def result_collection_reload(trainer_kwargs):
num_processes = trainer_kwargs.get("gpus", 1)

class CustomException(Exception):
pass

class ExtendedBoringModel(BoringModel):
def __init__(self):
super().__init__()
self.has_reloaded = False
self.breaking_batch_idx = 3
self.has_validated_sum = False
self.dummy_metric = DummyMetric()

def training_step(self, batch, batch_idx):
assert len(batch) == 1
if self.has_reloaded:
self.log("tracking", batch_idx, on_step=True, on_epoch=True)
self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True)

self.dummy_metric(batch_idx)
self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True)

value = self.trainer.fit_loop._results["training_step.tracking_metric"].value
value_2 = self.trainer.fit_loop._results["training_step.tracking"].value
shift = 0
if num_processes == 2:
shift = 3 if self.trainer.is_global_zero else -3
expected = sum(range(batch_idx + 1)) + shift
assert expected == value == value_2
else:
if batch_idx == self.breaking_batch_idx:
# simulate failure mid epoch
raise CustomException

self.log("tracking", batch_idx, on_step=True, on_epoch=True)
self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True)

self.dummy_metric(batch_idx)
self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True)

value = self.trainer.fit_loop._results["training_step.tracking"].value
assert value == sum(range(batch_idx + 1))

value = self.trainer.fit_loop._results["training_step.tracking_2"]
assert value == sum(range(batch_idx + 1))

return super().training_step(batch, batch_idx)

def on_epoch_end(self) -> None:
if self.has_reloaded:
total = sum(range(5)) * num_processes
metrics = self.trainer.fit_loop._results.metrics(on_step=False)
assert self.trainer.fit_loop._results["training_step.tracking"].value == total
assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2
assert self.trainer.fit_loop._results["training_step.tracking_2"].value == total
assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2
self.has_validated_sum = True

model = ExtendedBoringModel()

trainer = Trainer(**trainer_kwargs)

with suppress(CustomException):
trainer.fit(model)

checkpoint_path = trainer.accelerator.broadcast(os.path.join(trainer_kwargs["default_root_dir"], "ckpt.pt"))
trainer.save_checkpoint(checkpoint_path)

trainer.accelerator.barrier()

if trainer.is_global_zero:
checkpoint = torch.load(checkpoint_path)
assert checkpoint["state_dict"]["dummy_metric.sum"] == 3 * num_processes

trainer_kwargs["resume_from_checkpoint"] = checkpoint_path
trainer_kwargs["max_epochs"] = 2

trainer = Trainer(**trainer_kwargs)
model.has_reloaded = True
trainer.fit(model)
assert model.has_validated_sum


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_result_collection_reload(tmpdir):
result_collection_reload(
{
"default_root_dir": tmpdir,
"max_epochs": 1,
"limit_train_batches": 5,
"limit_val_batches": 0,
"accelerator": "ddp",
"gpus": 1,
}
)


@RunIf(min_gpus=2, special=True)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_result_collection_reload_2_gpus(tmpdir):
result_collection_reload(
{
"default_root_dir": tmpdir,
"max_epochs": 1,
"limit_train_batches": 5,
"limit_val_batches": 0,
"accelerator": "ddp",
"gpus": 2,
}
)
1 change: 0 additions & 1 deletion tests/deprecated_api/test_remove_1-5.py
@@ -19,7 +19,6 @@

import pytest
import torch
from torch import optim

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint