Skip to content

Commit

Permalink
Deprecate and remove calls to agg_and_log_metrics (#11832)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
3 people authored Feb 18, 2022
1 parent 73e9ca3 commit d613719
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 35 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning`


- Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832))


### Removed

- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
Expand Down
27 changes: 23 additions & 4 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def _aggregate_metrics(
) -> Tuple[int, Optional[Dict[str, float]]]:
"""Aggregates metrics.
.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
Expand All @@ -126,7 +129,13 @@ def _aggregate_metrics(
return agg_step, agg_mets

def _reduce_agg_metrics(self):
"""Aggregate accumulated metrics."""
"""Aggregate accumulated metrics.
See deprecation warning below.
.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
"""
# compute the metrics
if not self._metrics_to_agg:
agg_mets = None
Expand All @@ -137,7 +146,13 @@ def _reduce_agg_metrics(self):
return self._prev_step, agg_mets

def _finalize_agg_metrics(self):
"""This shall be called before save/close."""
"""This shall be called before save/close.
See deprecation warning below.
.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
"""
agg_step, metrics_to_log = self._reduce_agg_metrics()
self._metrics_to_agg = []

Expand All @@ -148,6 +163,10 @@ def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = N
"""Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead
it aggregates them and logs only if metrics are ready to be logged.
.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
Please use `LightningLoggerBase.log_metrics` instead.
Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
Expand Down Expand Up @@ -272,11 +291,11 @@ def experiment(self) -> List[Any]:

def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
for logger in self._logger_iterable:
logger.agg_and_log_metrics(metrics, step)
logger.agg_and_log_metrics(metrics=metrics, step=step)

def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
for logger in self._logger_iterable:
logger.log_metrics(metrics, step)
logger.log_metrics(metrics=metrics, step=step)

def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
for logger in self._logger_iterable:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from pytorch_lightning.utilities import _AcceleratorType, memory
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class LoggerConnector:
Expand All @@ -45,6 +46,7 @@ def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None)
self._current_fx: Optional[str] = None
self._batch_idx: Optional[int] = None
self._split_idx: Optional[int] = None
self._override_agg_and_log_metrics: bool = False

def on_trainer_init(
self,
Expand All @@ -64,6 +66,15 @@ def on_trainer_init(
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
self.trainer.log_every_n_steps = log_every_n_steps
self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
for logger in self.trainer.loggers:
if is_overridden("agg_and_log_metrics", logger, LightningLoggerBase):
self._override_agg_and_log_metrics = True
rank_zero_deprecation(
"`LightningLoggerBase.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
" in v1.8. `Trainer` will directly call `LightningLoggerBase.log_metrics` so custom"
" loggers should not implement `LightningLoggerBase.agg_and_log_metrics`."
)
break

@property
def should_flush_logs(self) -> bool:
Expand Down Expand Up @@ -114,7 +125,10 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
step = self.trainer.global_step

# log actual metrics
self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
if self._override_agg_and_log_metrics:
self.trainer.logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
else:
self.trainer.logger.log_metrics(metrics=scalar_metrics, step=step)
self.trainer.logger.save()

"""
Expand Down
34 changes: 33 additions & 1 deletion tests/deprecated_api/test_remove_1-8.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from torch import optim

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
Expand All @@ -35,7 +36,7 @@
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
Expand Down Expand Up @@ -500,3 +501,34 @@ def on_before_accelerator_backend_setup(self, *args, **kwargs):
" and will be removed in v1.8"
):
trainer.fit(model)


def test_v1_8_0_deprecated_agg_and_log_metrics_override(tmpdir):
class AggregationOverrideLogger(CSVLogger):
@rank_zero_only
def agg_and_log_metrics(self, metrics, step):
self.log_metrics(metrics=metrics, step=step)

logger = AggregationOverrideLogger(tmpdir)
logger2 = CSVLogger(tmpdir)
logger3 = CSVLogger(tmpdir)

# Test single loggers
with pytest.deprecated_call(
match="`LightningLoggerBase.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
" in v1.8. `Trainer` will directly call `LightningLoggerBase.log_metrics` so custom"
" loggers should not implement `LightningLoggerBase.agg_and_log_metrics`."
):
Trainer(logger=logger)
# Should have no deprecation warning
Trainer(logger=logger2)

# Test multiple loggers
with pytest.deprecated_call(
match="`LightningLoggerBase.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
" in v1.8. `Trainer` will directly call `LightningLoggerBase.log_metrics` so custom"
" loggers should not implement `LightningLoggerBase.agg_and_log_metrics`."
):
Trainer(logger=[logger, logger3])
# Should have no deprecation warning
Trainer(logger=[logger2, logger3])
31 changes: 3 additions & 28 deletions tests/loggers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ def test_logger_collection():
mock1.update_agg_funcs.assert_called_once_with({"test": np.mean}, np.sum)
mock2.update_agg_funcs.assert_called_once_with({"test": np.mean}, np.sum)

logger.agg_and_log_metrics({"test": 2.0}, 4)
mock1.agg_and_log_metrics.assert_called_once_with({"test": 2.0}, 4)
mock2.agg_and_log_metrics.assert_called_once_with({"test": 2.0}, 4)
logger.log_metrics(metrics={"test": 2.0}, step=4)
mock1.log_metrics.assert_called_once_with(metrics={"test": 2.0}, step=4)
mock2.log_metrics.assert_called_once_with(metrics={"test": 2.0}, step=4)

logger.finalize("success")
mock1.finalize.assert_called_once()
Expand Down Expand Up @@ -225,31 +225,6 @@ def validation_epoch_end(self, outputs):
trainer.fit(model)


def test_with_accumulate_grad_batches():
"""Checks if the logging is performed once for `accumulate_grad_batches` steps."""

class StoreHistoryLogger(CustomLogger):
def __init__(self):
super().__init__()
self.history = {}

@rank_zero_only
def log_metrics(self, metrics, step):
if step not in self.history:
self.history[step] = {}
self.history[step].update(metrics)

logger = StoreHistoryLogger()

np.random.seed(42)
for i, loss in enumerate(np.random.random(10)):
logger.agg_and_log_metrics({"loss": loss}, step=int(i / 5))

assert logger.history == {0: {"loss": 0.5623850983416314}}
logger.save()
assert logger.history == {0: {"loss": 0.5623850983416314}, 1: {"loss": 0.4778883735637184}}


def test_dummyexperiment_support_indexing():
"""Test that the DummyExperiment can imitate indexing the experiment in a LoggerCollection."""
experiment = DummyExperiment()
Expand Down

0 comments on commit d613719

Please sign in to comment.