[Feat] Cleanup ModelCheckpoint / EarlyStopping by moving logic to LoggerConnector #5218

Merged on Jan 7, 2021 (19 commits)

Changes from 12 commits
8 changes: 1 addition & 7 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -24,7 +24,7 @@
import torch

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, _TPU_AVAILABLE
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn


class EarlyStopping(Callback):
@@ -199,12 +199,6 @@ def _run_early_stopping_check(self, trainer, pl_module):
# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(self, current)

if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device)

if trainer.use_tpu and _TPU_AVAILABLE:
current = current.cpu()

if self.monitor_op(current - self.min_delta, self.best_score):
self.best_score = current
self.wait_count = 0
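
With this change, the tensor/TPU conversions that used to live in `_run_early_stopping_check` are handled centrally by the LoggerConnector (see the `MetricsHolder` added later in this diff), so `current` is expected to already be a tensor here. For context, a minimal, self-contained sketch (hypothetical helper name, `mode="min"` assumed) of the improvement check that remains:

```python
import torch

def improved(current: torch.Tensor, best_score: torch.Tensor, min_delta: float) -> bool:
    # mode="min": the monitored value must beat the best score by more than min_delta.
    # EarlyStopping writes this as monitor_op(current - self.min_delta, best_score),
    # with min_delta already carrying the sign that matches the mode.
    return bool(torch.lt(current + min_delta, best_score))

print(improved(torch.tensor(0.40), torch.tensor(0.45), 0.01))   # True: improved by more than 0.01
print(improved(torch.tensor(0.449), torch.tensor(0.45), 0.01))  # False: within min_delta
```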
5 changes: 1 addition & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -574,9 +574,6 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
epoch = metrics.get("epoch")
step = metrics.get("step")

if not isinstance(current, torch.Tensor) and current is not None:
current = torch.tensor(current, device=pl_module.device)

if self.check_monitor_top_k(current):
self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
elif self.verbose:
@@ -605,7 +602,7 @@ def _update_best_and_save(
del_list.append(delpath)

# do not save nan, replace with +/- inf
if torch.isnan(current):
if isinstance(current, torch.Tensor) and torch.isnan(current):
current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))
Member: should this always be on cpu? or should it be on current.device?

Contributor Author: Great question! I am not sure. What do you think?

Member: IMO it should live on current.device, since all the other tensors (especially current if not nan) also live on this device.
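
A minimal sketch of the reviewer's suggestion (an assumption, not what the diff above currently does): keep the NaN replacement on `current.device` instead of a fresh CPU tensor.

```python
import torch

def replace_nan_score(current: torch.Tensor, mode: str) -> torch.Tensor:
    # Hypothetical helper: substitute +/- inf for a NaN score, on the same
    # device as `current`, as suggested in the comment above.
    if torch.isnan(current):
        return torch.tensor(
            float('inf' if mode == "min" else '-inf'),
            device=current.device,
        )
    return current

print(replace_nan_score(torch.tensor(float('nan')), mode="min"))  # tensor(inf)
```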


# save the current score
5 changes: 4 additions & 1 deletion pytorch_lightning/core/step_result.py
@@ -367,7 +367,10 @@ def get_forked_metrics(self, add_dataloader_idx=False):
dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

if options['forked']:
result[dl_key] = self[k]
if isinstance(self[k], Metric):
result[dl_key] = self[k].compute().detach()
else:
result[dl_key] = self[k]

return result
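
For context, a small self-contained sketch (toy metric, hypothetical name) of why `get_forked_metrics` calls `.compute().detach()` when the stored value is a `Metric`: a `Metric` only accumulates state, and yields a plain tensor once `compute()` is called.

```python
import torch
from pytorch_lightning.metrics.metric import Metric  # same base class the new MetricsHolder imports

class SumMetric(Metric):
    """Toy metric used only for illustration."""
    def __init__(self):
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value: torch.Tensor):
        self.total += value

    def compute(self) -> torch.Tensor:
        return self.total

m = SumMetric()
m(torch.tensor(1.0))
m(torch.tensor(2.0))
print(m.compute().detach())  # tensor(3.) -- the plain tensor stored in the forked results
```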

@@ -375,7 +375,7 @@ def update_logger_connector(self) -> None:

if is_train:
# Only log and add to callback epoch step during evaluation, test.
logger_connector.logged_metrics.update(batch_log_metrics)
logger_connector._logged_metrics.update(batch_log_metrics)
callback_metrics.update(batch_pbar_metrics)
callback_metrics.update(batch_log_metrics)
else:
@@ -385,8 +385,8 @@ def update_logger_connector(self) -> None:

# get logged_metrics
epoch_log_metrics = self.get_epoch_log_metrics()
logger_connector.logged_metrics.update(epoch_log_metrics)
logger_connector.logged_metrics.update(epoch=self.trainer.current_epoch)
logger_connector._logged_metrics.update(epoch_log_metrics)
logger_connector._logged_metrics.update({"epoch": self.trainer.current_epoch})

# get forked_metrics
forked_metrics = self.get_forked_metrics()
@@ -396,11 +396,11 @@ def update_logger_connector(self) -> None:
callback_metrics.update(forked_metrics)

if not is_train:
logger_connector.evaluation_callback_metrics.update(callback_metrics)
logger_connector._evaluation_callback_metrics.update(callback_metrics)

# update callback_metrics
logger_connector.callback_metrics.update(callback_metrics)
logger_connector.callback_metrics.pop("epoch", None)
logger_connector._callback_metrics.update(callback_metrics)
logger_connector._callback_metrics.pop("epoch", None)

batch_pbar_metrics.pop("debug_epoch", None)
return batch_pbar_metrics, batch_log_metrics
@@ -14,7 +14,7 @@
import os
from copy import deepcopy
from pprint import pprint
from typing import Iterable, Union
from typing import Any, Iterable, Union, Dict

import torch

@@ -23,6 +23,7 @@
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.utilities import flatten_dict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -31,19 +31,64 @@
class LoggerConnector:
def __init__(self, trainer):
self.trainer = trainer
self.callback_metrics = {}
self.evaluation_callback_metrics = {}
self.logged_metrics = {}
self.progress_bar_metrics = {}
self._callback_metrics = MetricsHolder()
self._evaluation_callback_metrics = MetricsHolder(to_float=True)
self._logged_metrics = MetricsHolder()
self._progress_bar_metrics = MetricsHolder()
self.eval_loop_results = []
self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in LoggerStages}
self._callback_hook_validator = CallbackHookNameValidator()
self._current_stage = None

@property
def callback_metrics(self) -> Dict:
return self.get_metrics("callback_metrics")

@callback_metrics.setter
def callback_metrics(self, callback_metrics: Dict) -> None:
self.set_metrics("callback_metrics", callback_metrics)

@property
def evaluation_callback_metrics(self) -> Dict:
return self.get_metrics("evaluation_callback_metrics")

@evaluation_callback_metrics.setter
def evaluation_callback_metrics(self, evaluation_callback_metrics: Dict) -> None:
self.set_metrics("evaluation_callback_metrics", evaluation_callback_metrics)

@property
def logged_metrics(self) -> Dict:
return self.get_metrics("logged_metrics")

@logged_metrics.setter
def logged_metrics(self, logged_metrics: Dict) -> None:
self.set_metrics("logged_metrics", logged_metrics)

@property
def progress_bar_metrics(self) -> Dict:
return self.get_metrics("progress_bar_metrics")

@progress_bar_metrics.setter
def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:
self.set_metrics("progress_bar_metrics", progress_bar_metrics)

@property
def cached_results(self) -> Union[EpochResultStore, None]:
return self._cached_results.get(self._current_stage) # type: ignore

def get_metrics(self, key: str) -> Dict:
metrics_holder = getattr(self, f"_{key}", None)
model_ref = self.trainer.get_model()
metrics_holder.convert(
self.trainer.use_tpu,
model_ref.device if model_ref is not None else model_ref
)
return metrics_holder.metrics

def set_metrics(self, key: str, val: Any) -> None:
metrics_holder = getattr(self, f"_{key}", None)
metrics_holder.reset(val)

def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None:
self._current_stage = LoggerStages.determine_stage(stage_or_testing)
if reset:
@@ -153,10 +199,10 @@ def cache_training_step_metrics(self, opt_closure_result):
if len(pbar_metrics_tmp) > 0:
self.add_progress_bar_metrics(pbar_metrics_tmp)

self.callback_metrics.update(callback_metrics_tmp)
self._callback_metrics.update(callback_metrics_tmp)

# save legacy log metrics
self.logged_metrics.update(logged_metrics_tmp)
self._logged_metrics.update(logged_metrics_tmp)
self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp)

def log_metrics(self, metrics, grad_norm_dic, step=None, log_train_step_metrics=False):
Expand Down Expand Up @@ -209,7 +255,7 @@ def add_progress_bar_metrics(self, metrics):
if isinstance(v, torch.Tensor):
v = v.item()

self.progress_bar_metrics[k] = v
self._progress_bar_metrics.metrics[k] = v

self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

@@ -275,11 +321,11 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
if using_eval_result:
if isinstance(eval_results, list):
for eval_result in eval_results:
self.trainer.logger_connector.callback_metrics.update(eval_result.callback_metrics)
self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics)
self.trainer.logger_connector._callback_metrics.update(eval_result.callback_metrics)
self.trainer.logger_connector._evaluation_callback_metrics.update(eval_result.callback_metrics)
else:
self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics)
self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics)
self.trainer.logger_connector._callback_metrics.update(eval_results.callback_metrics)
self.trainer.logger_connector._evaluation_callback_metrics.update(eval_results.callback_metrics)
else:
flat = {}
if isinstance(eval_results, list):
@@ -294,8 +340,8 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
if 'val_loss' in flat:
flat['checkpoint_on'] = flat['val_loss']
flat['early_stop_on'] = flat['val_loss']
self.trainer.logger_connector.callback_metrics.update(flat)
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
self.trainer.logger_connector._callback_metrics.update(flat)
self.trainer.logger_connector._evaluation_callback_metrics.update(flat)
else:
# with a scalar return, auto set it to "val_loss" for callbacks
if isinstance(eval_results, torch.Tensor):
Expand All @@ -307,8 +353,8 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
if 'val_loss' in flat:
flat['checkpoint_on'] = flat['val_loss']
flat['early_stop_on'] = flat['val_loss']
self.trainer.logger_connector.callback_metrics.update(flat)
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
self.trainer.logger_connector._callback_metrics.update(flat)
self.trainer.logger_connector._evaluation_callback_metrics.update(flat)

def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
# eval loop returns all metrics
@@ -324,8 +370,8 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric
# track metrics for callbacks (all prog bar, logged and callback metrics)
callback_metrics.update(log_metrics)
callback_metrics.update(prog_bar_metrics)
self.trainer.logger_connector.callback_metrics.update(callback_metrics)
self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)
self.trainer.logger_connector._callback_metrics.update(callback_metrics)
self.trainer.logger_connector._evaluation_callback_metrics.update(callback_metrics)

if len(dataloader_result_metrics) > 0:
self.eval_loop_results.append(dataloader_result_metrics)
@@ -435,15 +481,15 @@ def log_train_epoch_end_metrics(
# add the metrics to the loggers and callbacks
if epoch_log_metrics and len(epoch_log_metrics) > 0:
self.log_metrics(epoch_log_metrics, {})
self.callback_metrics.update(epoch_log_metrics)
self._callback_metrics.update(epoch_log_metrics)

# add metrics to callbacks
self.callback_metrics.update(epoch_callback_metrics)
self._callback_metrics.update(epoch_callback_metrics)

# add metrics to progress_bar and callbacks
if len(epoch_progress_bar_metrics) > 0:
self.add_progress_bar_metrics(epoch_progress_bar_metrics)
self.callback_metrics.update(epoch_progress_bar_metrics)
self._callback_metrics.update(epoch_progress_bar_metrics)

# reset epoch loop result for next epoch
self.cached_results.reset()
@@ -599,4 +645,4 @@ def log_train_step_metrics(self, batch_output):
grad_norm_dic = {}
if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0:
self.log_metrics(batch_log_metrics, grad_norm_dic, log_train_step_metrics=True)
self.callback_metrics.update(batch_log_metrics)
self._callback_metrics.update(batch_log_metrics)
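
To summarize the pattern this file introduces: the public `*_metrics` attributes become properties backed by `MetricsHolder` objects, and type/device conversion is deferred until a property is read. Below is a minimal, self-contained sketch (hypothetical class names, not the Lightning API) of that indirection:

```python
from typing import Any, Dict

import torch


class _Holder:
    """Simplified stand-in for MetricsHolder."""

    def __init__(self):
        self.metrics: Dict[str, Any] = {}

    def update(self, metrics: Dict[str, Any]) -> None:
        self.metrics.update(metrics)

    def convert(self, device: torch.device) -> None:
        # turn plain numbers into tensors on the requested device
        for k, v in self.metrics.items():
            if not isinstance(v, torch.Tensor):
                self.metrics[k] = torch.tensor(v, device=device, dtype=torch.float)


class _Connector:
    """Simplified stand-in for LoggerConnector."""

    def __init__(self, device: torch.device = torch.device("cpu")):
        self._callback_metrics = _Holder()
        self._device = device

    @property
    def callback_metrics(self) -> Dict:
        # conversion happens lazily, only when the property is accessed
        self._callback_metrics.convert(self._device)
        return self._callback_metrics.metrics

    @callback_metrics.setter
    def callback_metrics(self, value: Dict) -> None:
        self._callback_metrics.metrics = value


connector = _Connector()
connector._callback_metrics.update({"val_loss": 0.25})  # internal writes stay cheap
print(connector.callback_metrics)                        # {'val_loss': tensor(0.2500)}
```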
pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py (new file, 80 additions)
@@ -0,0 +1,80 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
from typing import Any

import torch

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import _TPU_AVAILABLE


class MetricsHolder:

"""
This class acts as a dictionary holder.
It holds metrics and implements conversion functions.
Those functions will be triggered within LoggerConnector
when the property is being requested by the user.
"""

def __init__(self, to_float: bool = False):
self.metrics = {}
self._to_float = to_float

def update(self, metrics):
self.metrics.update(metrics)

def pop(self, key, default):
return self.metrics.pop(key, default)

def reset(self, metrics):
self.metrics = metrics

def convert(self, use_tpu: bool, device: torch.device):
for key, value in self.metrics.items():
self.metrics[key] = self._convert(value, use_tpu, device)

def _convert(self, current: Any, use_tpu: bool, device: torch.device):
if self._to_float:
return self._convert_to_float(current, use_tpu, device)
return self._convert_to_tensor(current, use_tpu, device)

def _convert_to_float(self, current, use_tpu: bool, device: torch.device):
if isinstance(current, Metric):
current = current.compute().detach()

if isinstance(current, torch.Tensor):
current = float(current.item())

elif isinstance(current, int):
current = float(current)

return current

def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device):
if current is not None:
if isinstance(current, Metric):
current = current.compute().detach()

elif isinstance(current, numbers.Number):
if device is None:
current = torch.tensor(current, dtype=torch.float)
else:
current = torch.tensor(current, device=device, dtype=torch.float)

if use_tpu and _TPU_AVAILABLE:
current = current.cpu()

return current
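
A short usage sketch of the new class, assuming the import path shown earlier in this diff (`pytorch_lightning.trainer.connectors.logger_connector.metrics_holder`):

```python
import torch

from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder

holder = MetricsHolder()
holder.update({"val_loss": 0.3, "val_acc": torch.tensor(0.9)})
holder.convert(use_tpu=False, device=torch.device("cpu"))
print(holder.metrics)  # plain numbers are now float tensors on the given device

float_holder = MetricsHolder(to_float=True)
float_holder.update({"val_loss": torch.tensor(0.3)})
float_holder.convert(use_tpu=False, device=torch.device("cpu"))
print(float_holder.metrics)  # tensors are unwrapped to plain Python floats
```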
1 change: 0 additions & 1 deletion tests/deprecated_api/test_remove_1-3.py
@@ -21,7 +21,6 @@
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler
from tests.deprecated_api import _soft_unimport_module


def test_v1_3_0_deprecated_arguments(tmpdir):