Deprecate and remove calls to agg_and_log_metrics #11832

Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -397,6 +397,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning`


- Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832))


### Removed

- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))
27 changes: 23 additions & 4 deletions pytorch_lightning/loggers/base.py
@@ -104,6 +104,9 @@ def _aggregate_metrics(
) -> Tuple[int, Optional[Dict[str, float]]]:
"""Aggregates metrics.

.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.

Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
@@ -126,7 +129,13 @@ def _aggregate_metrics(
return agg_step, agg_mets

def _reduce_agg_metrics(self):
"""Aggregate accumulated metrics."""
"""Aggregate accumulated metrics.

See deprecation warning below.

.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
"""
# compute the metrics
if not self._metrics_to_agg:
agg_mets = None
@@ -137,7 +146,13 @@ def _reduce_agg_metrics(self):
return self._prev_step, agg_mets

def _finalize_agg_metrics(self):
"""This shall be called before save/close."""
"""This shall be called before save/close.

See deprecation warning below.

.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
"""
agg_step, metrics_to_log = self._reduce_agg_metrics()
self._metrics_to_agg = []

@@ -148,6 +163,10 @@ def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = N
"""Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead
it aggregates them and logs only if metrics are ready to be logged.

.. deprecated:: v1.6
This method is deprecated in v1.6 and will be removed in v1.8.
Please use `LightningLoggerBase.log_metrics` instead.

Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
@@ -286,11 +305,11 @@ def experiment(self) -> List[Any]:

def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
for logger in self._logger_iterable:
logger.agg_and_log_metrics(metrics, step)
logger.agg_and_log_metrics(metrics=metrics, step=step)

def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
for logger in self._logger_iterable:
logger.log_metrics(metrics, step)
logger.log_metrics(metrics=metrics, step=step)

def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
for logger in self._logger_iterable:
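
For custom loggers that currently override `agg_and_log_metrics`, the migration implied by this deprecation is to override `log_metrics` only. A minimal sketch, assuming a `CSVLogger`-based custom logger; the `MyCSVLogger` name and its pass-through behavior are illustrative, not part of this PR:

```python
from typing import Dict, Optional

from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.rank_zero import rank_zero_only


class MyCSVLogger(CSVLogger):
    # Before: overriding `agg_and_log_metrics` here now emits a deprecation
    # warning at `Trainer` construction time.
    #
    # After: override `log_metrics` only; the `Trainer` calls it directly.
    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        super().log_metrics(metrics, step)
```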
pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -23,7 +23,8 @@
from pytorch_lightning.utilities import _AcceleratorType, memory
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.warnings import rank_zero_deprecation
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


class LoggerConnector:
@@ -45,6 +46,7 @@ def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None)
self._current_fx: Optional[str] = None
self._batch_idx: Optional[int] = None
self._split_idx: Optional[int] = None
self._override_agg_and_log_metrics: bool = False

def on_trainer_init(
self,
@@ -64,6 +66,15 @@ def on_trainer_init(
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
self.trainer.log_every_n_steps = log_every_n_steps
self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
for logger in self.trainer.loggers:
if is_overridden("agg_and_log_metrics", logger, LightningLoggerBase):
self._override_agg_and_log_metrics = True
rank_zero_deprecation(
"`LightningLoggerBase.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
" in v1.8. `Trainer` will directly call `LightningLoggerBase.log_metrics` so custom"
" loggers should not implement `LightningLoggerBase.agg_and_log_metrics`."
)
break

@property
def should_flush_logs(self) -> bool:
@@ -114,7 +125,10 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
step = self.trainer.global_step

# log actual metrics
self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
if self._override_agg_and_log_metrics:
self.trainer.logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
else:
self.trainer.logger.log_metrics(metrics=scalar_metrics, step=step)
self.trainer.logger.save()

"""
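
The connector decides between the two code paths once, at `Trainer` init, using `is_overridden`. A small sketch of that check in isolation; the `LegacyLogger` class below is a hypothetical stand-in, and the save directory is illustrative:

```python
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase
from pytorch_lightning.utilities.model_helpers import is_overridden


class LegacyLogger(CSVLogger):
    # Old-style override: this is what now triggers the deprecation warning.
    def agg_and_log_metrics(self, metrics, step):
        self.log_metrics(metrics=metrics, step=step)


# A logger that still overrides `agg_and_log_metrics` keeps the old
# `agg_and_log_metrics` call path; any other logger has `log_metrics`
# called on it directly.
assert is_overridden("agg_and_log_metrics", LegacyLogger("logs"), LightningLoggerBase)
assert not is_overridden("agg_and_log_metrics", CSVLogger("logs"), LightningLoggerBase)
```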
34 changes: 33 additions & 1 deletion tests/deprecated_api/test_remove_1-8.py
@@ -19,6 +19,7 @@
from torch import optim

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
@@ -35,7 +36,7 @@
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.enums import DeviceType, DistributedType
from pytorch_lightning.utilities.imports import _TORCHTEXT_LEGACY
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
from tests.helpers.boring_model import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
from tests.helpers.torchtext_utils import get_dummy_torchtext_data_iterator
@@ -500,3 +501,34 @@ def on_before_accelerator_backend_setup(self, *args, **kwargs):
" and will be removed in v1.8"
):
trainer.fit(model)


def test_v1_8_0_deprecated_agg_and_log_metrics_override(tmpdir):
class AggregationOverrideLogger(CSVLogger):
@rank_zero_only
def agg_and_log_metrics(self, metrics, step):
self.log_metrics(metrics=metrics, step=step)

logger = AggregationOverrideLogger(tmpdir)
logger2 = CSVLogger(tmpdir)
logger3 = CSVLogger(tmpdir)

# Test single loggers
with pytest.deprecated_call(
match="`LightningLoggerBase.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
" in v1.8. `Trainer` will directly call `LightningLoggerBase.log_metrics` so custom"
" loggers should not implement `LightningLoggerBase.agg_and_log_metrics`."
):
Trainer(logger=logger)
# Should have no deprecation warning
Trainer(logger=logger2)

# Test multiple loggers
with pytest.deprecated_call(
match="`LightningLoggerBase.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
" in v1.8. `Trainer` will directly call `LightningLoggerBase.log_metrics` so custom"
" loggers should not implement `LightningLoggerBase.agg_and_log_metrics`."
):
Trainer(logger=[logger, logger3])
# Should have no deprecation warning
Trainer(logger=[logger2, logger3])
31 changes: 3 additions & 28 deletions tests/loggers/test_base.py
@@ -48,9 +48,9 @@ def test_logger_collection():
mock1.update_agg_funcs.assert_called_once_with({"test": np.mean}, np.sum)
mock2.update_agg_funcs.assert_called_once_with({"test": np.mean}, np.sum)

logger.agg_and_log_metrics({"test": 2.0}, 4)
mock1.agg_and_log_metrics.assert_called_once_with({"test": 2.0}, 4)
mock2.agg_and_log_metrics.assert_called_once_with({"test": 2.0}, 4)
logger.log_metrics(metrics={"test": 2.0}, step=4)
mock1.log_metrics.assert_called_once_with(metrics={"test": 2.0}, step=4)
mock2.log_metrics.assert_called_once_with(metrics={"test": 2.0}, step=4)

logger.finalize("success")
mock1.finalize.assert_called_once()
@@ -225,31 +225,6 @@ def validation_epoch_end(self, outputs):
trainer.fit(model)


def test_with_accumulate_grad_batches():
"""Checks if the logging is performed once for `accumulate_grad_batches` steps."""

class StoreHistoryLogger(CustomLogger):
def __init__(self):
super().__init__()
self.history = {}

@rank_zero_only
def log_metrics(self, metrics, step):
if step not in self.history:
self.history[step] = {}
self.history[step].update(metrics)

logger = StoreHistoryLogger()

np.random.seed(42)
for i, loss in enumerate(np.random.random(10)):
logger.agg_and_log_metrics({"loss": loss}, step=int(i / 5))

assert logger.history == {0: {"loss": 0.5623850983416314}}
logger.save()
assert logger.history == {0: {"loss": 0.5623850983416314}, 1: {"loss": 0.4778883735637184}}


def test_dummyexperiment_support_indexing():
"""Test that the DummyExperiment can imitate indexing the experiment in a LoggerCollection."""
experiment = DummyExperiment()
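
The updated `test_logger_collection` assertions above reflect the same change from the user's side: a `LoggerCollection` now forwards `log_metrics` to each wrapped logger with keyword arguments instead of routing through `agg_and_log_metrics`. A small usage sketch; the save directories and metric values are illustrative:

```python
from pytorch_lightning.loggers import CSVLogger, LoggerCollection

collection = LoggerCollection([CSVLogger("logs/a"), CSVLogger("logs/b")])
# Each wrapped logger receives the call directly as `log_metrics(metrics=..., step=...)`.
collection.log_metrics(metrics={"test": 2.0}, step=4)
```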