From be39d5f14bad174d99042ca348db71d6d41290ef Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 21:46:36 -0700 Subject: [PATCH 01/14] Remove outputs from on_train_epoch_end --- pytorch_lightning/callbacks/base.py | 4 +++- pytorch_lightning/core/hooks.py | 2 +- pytorch_lightning/trainer/callback_hook.py | 12 +++++++++-- pytorch_lightning/trainer/training_loop.py | 25 ++++++++++++++++++---- 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 48c60338a0bf8..3e8a77cbfdb0a 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -98,7 +98,9 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo """Called when the train epoch begins.""" pass - def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None: + def on_train_epoch_end( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', unused: Optional = None + ) -> None: """Called when the train epoch ends.""" pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index b55a8258e03fa..bebd1edd8e685 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -235,7 +235,7 @@ def on_train_epoch_start(self) -> None: Called in the training loop at the very beginning of the epoch. """ - def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: + def on_train_epoch_end(self, unused: Optional = None) -> None: """ Called in the training loop at the very end of the epoch. """ diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 959b180637b7e..5b10f44d13c7d 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -89,14 +89,22 @@ def on_train_epoch_start(self): for callback in self.callbacks: callback.on_train_epoch_start(self, self.lightning_module) - def on_train_epoch_end(self, outputs: EPOCH_OUTPUT): + def on_train_epoch_end(self, outputs: Optional[EPOCH_OUTPUT] = None): """Called when the epoch ends. Args: outputs: List of outputs on each ``train`` epoch """ for callback in self.callbacks: - callback.on_train_epoch_end(self, self.lightning_module, outputs) + if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"): + warning_cache.warn( + "`Callback.on_train_epoch_end` signature has changed in v1.3." + " `outputs` parameter has been removed." 
+ " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + callback.on_train_epoch_end(self, self.lightning_module, outputs) + else: + callback.on_train_epoch_end(self, self.lightning_module) def on_validation_epoch_start(self): """Called when the epoch begins.""" diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8c510f08a83fc..f641c1cd3f515 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -31,6 +31,7 @@ from pytorch_lightning.utilities.grads import grad_norm from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.warnings import WarningCache @@ -197,16 +198,32 @@ def reset_train_val_dataloaders(self, model) -> None: def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): + lightning_module = self.trainer.lightning_module # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(batch_end_outputs): sample_output = opt_outputs[-1] # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end - hook_overridden = ( - is_overridden("training_epoch_end", model=self.trainer.lightning_module) - or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module) - ) + + # We add to the epoch outputs if + # 1. The model defines training_epoch_end OR + # 2. The model overrides on_train_epoch_end which has `outputs` in the signature OR + # 3. The trainer has any callback which overrides `on_train_epoch_end` and includes `outputs` in the signature + overrides_training_epoch_end = is_overridden("training_epoch_end", model=lightning_module) + overrides_on_train_epoch_end_with_outputs = False + if is_overridden("on_train_epoch_end", model=lightning_module): + model_hook_fx = getattr(model_ref, hook_name) + if is_param_in_hook_signature(model_hook_fx, "outputs"): + overrides_on_train_epoch_end_with_outputs = True + + callback_overrides_on_train_epoch_end_with_outputs = False + for callback in self.trainer.callbacks: + if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"): + callback_overrides_on_train_epoch_end_with_outputs = True + break + + hook_overridden = overrides_training_epoch_end or overrides_on_train_epoch_end_with_outputs or callback_overrides_on_train_epoch_end_with_outputs # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end if not (hook_overridden or auto_reduce_tng_result): From 082c3cc373e5a59cd0bac9c623caee9bf8c718b0 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 22:27:50 -0700 Subject: [PATCH 02/14] iterate --- CHANGELOG.md | 5 ++- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/callbacks/pruning.py | 2 +- pytorch_lightning/trainer/training_loop.py | 43 ++++++++++--------- 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b61aa5939dbef..d6c6741bb480a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -202,13 +202,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Deprecated +- Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339)) + + - Deprecated `LightningModule.grad_norm` in favor of `pytorch_lightning.utilities.grads.grad_norm` ([#7292](https://github.com/PyTorchLightning/pytorch-lightning/pull/7292)) - Deprecated the `save_function` property from the `ModelCheckpoint` callback ([#7201](https://github.com/PyTorchLightning/pytorch-lightning/pull/7201)) -- Deprecated `LightningModule.write_predictions` and `LigtningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066)) +- Deprecated `LightningModule.write_predictions` and `LightningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066)) - Deprecated `TrainerLoggingMixin` in favor of a separate utilities module for metric handling ([#7180](https://github.com/PyTorchLightning/pytorch-lightning/pull/7180)) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index f1a1789856642..242eeed808f34 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -161,7 +161,7 @@ def _should_skip_check(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerFn return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking - def on_train_epoch_end(self, trainer, pl_module, outputs) -> None: + def on_train_epoch_end(self, trainer, pl_module) -> None: if not self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) diff --git a/pytorch_lightning/callbacks/pruning.py b/pytorch_lightning/callbacks/pruning.py index 6c9fa8b4776c6..715fa14a41d04 100644 --- a/pytorch_lightning/callbacks/pruning.py +++ b/pytorch_lightning/callbacks/pruning.py @@ -373,7 +373,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []}) self._original_layers[id_]["names"].append((i, name)) - def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs): + def on_train_epoch_end(self, trainer, pl_module: LightningModule): current_epoch = trainer.current_epoch prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f641c1cd3f515..5d31c0257555e 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -198,7 +198,8 @@ def reset_train_val_dataloaders(self, model) -> None: def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): - lightning_module = self.trainer.lightning_module + hook_overridden = self._should_add_batch_output_to_epoch_output() + # track the outputs to reduce at the end of the epoch for opt_idx, opt_outputs in enumerate(batch_end_outputs): sample_output = opt_outputs[-1] @@ -206,25 +207,6 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): # decide if we need to reduce at the end of the epoch automatically auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end - # We add to the epoch 
outputs if - # 1. The model defines training_epoch_end OR - # 2. The model overrides on_train_epoch_end which has `outputs` in the signature OR - # 3. The trainer has any callback which overrides `on_train_epoch_end` and includes `outputs` in the signature - overrides_training_epoch_end = is_overridden("training_epoch_end", model=lightning_module) - overrides_on_train_epoch_end_with_outputs = False - if is_overridden("on_train_epoch_end", model=lightning_module): - model_hook_fx = getattr(model_ref, hook_name) - if is_param_in_hook_signature(model_hook_fx, "outputs"): - overrides_on_train_epoch_end_with_outputs = True - - callback_overrides_on_train_epoch_end_with_outputs = False - for callback in self.trainer.callbacks: - if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"): - callback_overrides_on_train_epoch_end_with_outputs = True - break - - hook_overridden = overrides_training_epoch_end or overrides_on_train_epoch_end_with_outputs or callback_overrides_on_train_epoch_end_with_outputs - # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end if not (hook_overridden or auto_reduce_tng_result): continue @@ -235,6 +217,27 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): epoch_output[opt_idx].append(opt_outputs) + def _should_add_batch_output_to_epoch_output(self) -> bool: + # We add to the epoch outputs if + # 1. The model defines training_epoch_end OR + # 2. The model overrides on_train_epoch_end which has `outputs` in the signature OR + # 3. The trainer has any callback which overrides `on_train_epoch_end` and includes `outputs` in the signature + # TODO: in v1.5 this only needs to check if training_epoch_end is overridden + lightning_module = self.trainer.lightning_module + if is_overridden("training_epoch_end", model=lightning_module): + return True + + if is_overridden("on_train_epoch_end", model=lightning_module): + model_hook_fx = getattr(model_ref, hook_name) + if is_param_in_hook_signature(model_hook_fx, "outputs"): + return True + + for callback in self.trainer.callbacks: + if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"): + return True + + return False + def get_optimizers_iterable(self, batch_idx=None): """ Generates an iterable with (idx, optimizer) for each optimizer. From 495c0b73524896398ea474bdb01a7f9a76cce5cd Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 22:29:25 -0700 Subject: [PATCH 03/14] Update callback_hook.py --- pytorch_lightning/trainer/callback_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 5b10f44d13c7d..0b95bf5179cdb 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -89,7 +89,7 @@ def on_train_epoch_start(self): for callback in self.callbacks: callback.on_train_epoch_start(self, self.lightning_module) - def on_train_epoch_end(self, outputs: Optional[EPOCH_OUTPUT] = None): + def on_train_epoch_end(self, outputs: EPOCH_OUTPUT): """Called when the epoch ends. 
Args: From 03fa99dc78465de4c6be5573ae1925613aa85617 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 22:58:07 -0700 Subject: [PATCH 04/14] update --- pytorch_lightning/accelerators/accelerator.py | 8 +-- pytorch_lightning/trainer/training_loop.py | 51 +++++++++++++++++-- tests/deprecated_api/test_remove_1-5.py | 47 +++++++++++++++++ 3 files changed, 95 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index bb6981ffbde0a..0cca039e665b5 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -355,12 +355,8 @@ def clip_gradients( model=self.model, ) - def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: - """Hook to do something on the end of an training epoch - - Args: - outputs: the outputs of the training steps - """ + def on_train_epoch_end(self) -> None: + """Hook to do something on the end of an training epoch.""" pass def on_train_end(self) -> None: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 5d31c0257555e..60849bac927de 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -221,14 +221,14 @@ def _should_add_batch_output_to_epoch_output(self) -> bool: # We add to the epoch outputs if # 1. The model defines training_epoch_end OR # 2. The model overrides on_train_epoch_end which has `outputs` in the signature OR - # 3. The trainer has any callback which overrides `on_train_epoch_end` and includes `outputs` in the signature + # 3. The trainer has any callback which overrides `on_train_epoch_end` that includes `outputs` in the signature # TODO: in v1.5 this only needs to check if training_epoch_end is overridden lightning_module = self.trainer.lightning_module if is_overridden("training_epoch_end", model=lightning_module): return True if is_overridden("on_train_epoch_end", model=lightning_module): - model_hook_fx = getattr(model_ref, hook_name) + model_hook_fx = getattr(lightning_module, "on_train_epoch_end") if is_param_in_hook_signature(model_hook_fx, "outputs"): return True @@ -611,9 +611,50 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None: # capture logging self.trainer.logger_connector.cache_logged_metrics() - # call train epoch end hooks - self.trainer.call_hook('on_train_epoch_end', processed_epoch_output) - self.trainer.call_hook('on_epoch_end') + # call train epoch end hooks + self._on_train_epoch_end_hook(processed_epoch_output) + self.trainer.call_hook('on_epoch_end') + + def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: + # Cannot rely on Trainer.call_hook because the signatures might be different across + # lightning module and callback + # Here we need to inspect if the module accepts `outputs` in `on_train_epoch_end` + + # This implementation is copied from Trainer.call_hook + hook_name = "on_train_epoch_end" + + # set hook_name to model + reset Result obj + skip = self.trainer._reset_result_and_set_hook_fx_name(hook_name) + + # always profile hooks + with self.trainer.profiler.profile(hook_name): + + # first call trainer hook + if hasattr(self.trainer, hook_name): + trainer_hook = getattr(self.trainer, hook_name) + trainer_hook(processed_epoch_output) + + # next call hook in lightningModule + model_ref = self.trainer.lightning_module + if is_overridden(hook_name, model_ref): + hook_fx = getattr(model_ref, hook_name) + if 
is_param_in_hook_signature(hook_fx, "outputs"): + self.warning_cache.warn( + f"`ModelHooks.on_train_epoch_end` signature has changed in v1.3. `outputs` parameter has been deprecated." + " Support for the old signature will be removed in v1.5", DeprecationWarning + ) + model_ref.on_train_epoch_end(processed_epoch_output) + else: + model_ref.on_train_epoch_end() + + # if the PL module doesn't have the hook then call the accelerator + # used to auto-reduce things for the user with Results obj + elif hasattr(self.trainer.accelerator, hook_name): + accelerator_hook = getattr(self.trainer.accelerator, hook_name) + accelerator_hook() + + if not skip: + self.trainer._cache_logged_metrics() def run_training_batch(self, batch, batch_idx, dataloader_idx): # track grad norms diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index f211fe08089df..7d7678d1d9540 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -216,6 +216,53 @@ def test_v1_5_0_model_checkpoint_period(tmpdir): ModelCheckpoint(dirpath=tmpdir, period=1) +def test_v1_5_0_old_on_train_epoch_end(tmpdir): + callback_warning_cache.clear() + + class OldSignature(Callback): + + def on_train_epoch_end(self, trainer, pl_module, outputs): # noqa + ... + + class OldSignatureModel(BoringModel): + + def on_train_epoch_end(self, outputs): # noqa + ... + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.fit(model) + + callback_warning_cache.clear() + + model = OldSignatureModel() + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer.fit(model) + + trainer.train_loop.warning_cache.clear() + + class NewSignature(Callback): + + def on_train_epoch_end(self, trainer, pl_module): + ... + + trainer.callbacks = [NewSignature()] + with no_deprecated_call(match="`Callback.on_train_epoch_end` signature has changed in v1.3."): + trainer.fit(model) + + class NewSignatureModel(BoringModel): + + def on_train_epoch_end(self): + ... + + model = NewSignatureModel() + with no_deprecated_call(match="`ModelHooks.on_train_epoch_end` signature has changed in v1.3."): + trainer.fit(model) + + def test_v1_5_0_old_on_validation_epoch_end(tmpdir): callback_warning_cache.clear() From 2f55ff192643b435466a0d9ca686c0cf90dc8228 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 22:59:53 -0700 Subject: [PATCH 05/14] Update training_loop.py --- pytorch_lightning/trainer/training_loop.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 60849bac927de..8c949b3a6bc8f 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -220,8 +220,7 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): def _should_add_batch_output_to_epoch_output(self) -> bool: # We add to the epoch outputs if # 1. The model defines training_epoch_end OR - # 2. The model overrides on_train_epoch_end which has `outputs` in the signature OR - # 3. The trainer has any callback which overrides `on_train_epoch_end` that includes `outputs` in the signature + # 2. 
The model overrides on_train_epoch_end which has `outputs` in the signature # TODO: in v1.5 this only needs to check if training_epoch_end is overridden lightning_module = self.trainer.lightning_module if is_overridden("training_epoch_end", model=lightning_module): @@ -232,10 +231,6 @@ def _should_add_batch_output_to_epoch_output(self) -> bool: if is_param_in_hook_signature(model_hook_fx, "outputs"): return True - for callback in self.trainer.callbacks: - if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"): - return True - return False def get_optimizers_iterable(self, batch_idx=None): From 8d262e14d3e9b320e7318faa10e570fe11d62285 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 23:16:04 -0700 Subject: [PATCH 06/14] Update test_training_loop.py --- tests/trainer/test_training_loop.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/trainer/test_training_loop.py b/tests/trainer/test_training_loop.py index b435fd0b4de32..2d32d8c8878e4 100644 --- a/tests/trainer/test_training_loop.py +++ b/tests/trainer/test_training_loop.py @@ -155,11 +155,6 @@ def training_epoch_end(self, outputs): [HookedModel._check_output(output) for output in outputs] super().training_epoch_end(outputs) - def on_train_epoch_end(self, outputs): - assert len(outputs) == 2 - [HookedModel._check_output(output) for output in outputs] - super().on_train_epoch_end(outputs) - model = HookedModel() # fit model From b0c02cb0470cabd86c2252c2a07607462cee0eaa Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 23:30:26 -0700 Subject: [PATCH 07/14] early stop? --- pytorch_lightning/accelerators/accelerator.py | 2 +- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/trainer/training_loop.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 0cca039e665b5..07e2fa5e3f728 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -28,7 +28,7 @@ from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum -from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT if _NATIVE_AMP_AVAILABLE: from torch.cuda.amp import GradScaler diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 242eeed808f34..f1a1789856642 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -161,7 +161,7 @@ def _should_skip_check(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerFn return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking - def on_train_epoch_end(self, trainer, pl_module) -> None: + def on_train_epoch_end(self, trainer, pl_module, outputs) -> None: if not self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 8c949b3a6bc8f..af73a6c6da353 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -635,7 +635,8 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) 
-> None: hook_fx = getattr(model_ref, hook_name) if is_param_in_hook_signature(hook_fx, "outputs"): self.warning_cache.warn( - f"`ModelHooks.on_train_epoch_end` signature has changed in v1.3. `outputs` parameter has been deprecated." + "`ModelHooks.on_train_epoch_end` signature has changed in v1.3." + " `outputs` parameter has been deprecated." " Support for the old signature will be removed in v1.5", DeprecationWarning ) model_ref.on_train_epoch_end(processed_epoch_output) From 274d5f896d7be57b9d8cdd729a137a709f6303a3 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Mon, 3 May 2021 23:38:45 -0700 Subject: [PATCH 08/14] fix --- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/trainer/training_loop.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index f1a1789856642..242eeed808f34 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -161,7 +161,7 @@ def _should_skip_check(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerFn return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking - def on_train_epoch_end(self, trainer, pl_module, outputs) -> None: + def on_train_epoch_end(self, trainer, pl_module) -> None: if not self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index af73a6c6da353..a5c43c968f128 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -606,14 +606,14 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None: # capture logging self.trainer.logger_connector.cache_logged_metrics() - # call train epoch end hooks - self._on_train_epoch_end_hook(processed_epoch_output) - self.trainer.call_hook('on_epoch_end') + # call train epoch end hooks + self._on_train_epoch_end_hook(processed_epoch_output) + self.trainer.call_hook('on_epoch_end') def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: - # Cannot rely on Trainer.call_hook because the signatures might be different across + # We cannot rely on Trainer.call_hook because the signatures might be different across # lightning module and callback - # Here we need to inspect if the module accepts `outputs` in `on_train_epoch_end` + # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end` # This implementation is copied from Trainer.call_hook hook_name = "on_train_epoch_end" From 5c6a3d8465da0485d200fba70d46ba47faa1ee35 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 00:00:59 -0700 Subject: [PATCH 09/14] update tests --- tests/callbacks/test_callback_hook_outputs.py | 5 +--- tests/callbacks/test_finetuning_callback.py | 2 +- tests/models/test_hooks.py | 27 ++++++------------- .../trainer/logging_/test_logger_connector.py | 2 +- .../logging_/test_train_loop_logging.py | 2 +- 5 files changed, 12 insertions(+), 26 deletions(-) diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index b42e67d954dbd..b2aa20af57a94 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -34,9 +34,6 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, def on_test_batch_end(self, 
trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): assert 'x' in outputs - def on_train_epoch_end(self, trainer, pl_module, outputs): - assert len(outputs) == trainer.num_training_batches - class TestModel(BoringModel): def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None: @@ -48,7 +45,7 @@ def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx def on_test_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None: assert 'x' in outputs - def on_train_epoch_end(self, outputs) -> None: + def training_epoch_end(self, outputs) -> None: assert len(outputs) == self.trainer.num_training_batches model = TestModel() diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index c8290f217a289..53d34c4645bef 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -27,7 +27,7 @@ class TestBackboneFinetuningCallback(BackboneFinetuning): - def on_train_epoch_end(self, trainer, pl_module, outputs): + def on_train_epoch_end(self, trainer, pl_module): epoch = trainer.current_epoch if self.unfreeze_backbone_at_epoch <= epoch: optimizer = trainer.optimizers[0] diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 35c8e89911354..3a0fbff47dc9e 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -92,21 +92,17 @@ def training_epoch_end(self, outputs): def test_training_epoch_end_metrics_collection_on_override(tmpdir): """ Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """ - class LoggingCallback(Callback): + class OverriddenModel(BoringModel): - def on_train_epoch_start(self, trainer, pl_module): + def __init__(self): + super().__init__() self.len_outputs = 0 - def on_train_epoch_end(self, trainer, pl_module, outputs): - self.len_outputs = len(outputs) - - class OverriddenModel(BoringModel): - def on_train_epoch_start(self): self.num_train_batches = 0 - def training_epoch_end(self, outputs): # Overridden - return + def training_epoch_end(self, outputs): + self.len_outputs = len(outputs) def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): self.num_train_batches += 1 @@ -128,17 +124,10 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): max_epochs=1, default_root_dir=tmpdir, overfit_batches=2, - callbacks=[callback], ) trainer.fit(overridden_model) - # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook - # if training_epoch_end is overridden - assert callback.len_outputs == overridden_model.num_train_batches - - trainer.fit(not_overridden_model) - # outputs from on_train_batch_end should be empty - assert callback.len_outputs == 0 + assert overridden_model.len_outputs == overridden_model.num_train_batches @RunIf(min_gpus=1) @@ -334,9 +323,9 @@ def on_train_epoch_start(self): self.called.append("on_train_epoch_start") super().on_train_epoch_start() - def on_train_epoch_end(self, outputs): + def on_train_epoch_end(self): self.called.append("on_train_epoch_end") - super().on_train_epoch_end(outputs) + super().on_train_epoch_end() def on_validation_start(self): self.called.append("on_validation_start") diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index 523053897229b..06eaca6d61f2c 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ 
b/tests/trainer/logging_/test_logger_connector.py @@ -678,7 +678,7 @@ def _assert_epoch_end(self, stage): acc.reset.asset_not_called() ap.reset.assert_not_called() - def on_train_epoch_end(self, outputs): + def on_train_epoch_end(self): self._assert_epoch_end('train') def on_validation_epoch_end(self, outputs): diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 04c8337cb5182..c89914de4ddfa 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -599,7 +599,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data # with func = np.mean if on_epoch else func = np.max self.count += 1 - def on_train_epoch_end(self, trainer, pl_module, outputs): + def on_train_epoch_end(self, trainer, pl_module): self.make_logging( pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices ) From 01aed79f8367f49ad35117551788fbd0636d6688 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 00:03:57 -0700 Subject: [PATCH 10/14] Update test_hooks.py --- tests/models/test_hooks.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 3a0fbff47dc9e..24bf29a9e2eac 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -17,7 +17,7 @@ import pytest import torch -from pytorch_lightning import Callback, Trainer +from pytorch_lightning import Trainer from tests.helpers import BoringDataModule, BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -119,7 +119,6 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): not_overridden_model = NotOverriddenModel() not_overridden_model.training_epoch_end = None - callback = LoggingCallback() trainer = Trainer( max_epochs=1, default_root_dir=tmpdir, From b719bef02bbc9c6c992ffda16904762db8ac0972 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 07:53:18 -0700 Subject: [PATCH 11/14] Update pytorch_lightning/trainer/callback_hook.py Co-authored-by: Ethan Harris --- pytorch_lightning/trainer/callback_hook.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 0b95bf5179cdb..fcdd8f55f6a6e 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -98,7 +98,7 @@ def on_train_epoch_end(self, outputs: EPOCH_OUTPUT): for callback in self.callbacks: if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"): warning_cache.warn( - "`Callback.on_train_epoch_end` signature has changed in v1.3." + "The signature of `Callback.on_train_epoch_end` has changed in v1.3." " `outputs` parameter has been removed." 
" Support for the old signature will be removed in v1.5", DeprecationWarning ) From 66f310a273e46b7cd7fdcc821f68ed928de9ed08 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 07:53:24 -0700 Subject: [PATCH 12/14] Update pytorch_lightning/trainer/training_loop.py Co-authored-by: Ethan Harris --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a5c43c968f128..dbeafc5c6f5b3 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -635,7 +635,7 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None: hook_fx = getattr(model_ref, hook_name) if is_param_in_hook_signature(hook_fx, "outputs"): self.warning_cache.warn( - "`ModelHooks.on_train_epoch_end` signature has changed in v1.3." + "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3." " `outputs` parameter has been deprecated." " Support for the old signature will be removed in v1.5", DeprecationWarning ) From d18455cc506827ae231655e04fbb087e49a728a4 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 08:03:31 -0700 Subject: [PATCH 13/14] Update trainer.py --- pytorch_lightning/trainer/trainer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 39085ff6ef73b..210e9ec16a209 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1211,6 +1211,13 @@ def _cache_logged_metrics(self): self.logger_connector.cache_logged_metrics() def call_hook(self, hook_name: str, *args, **kwargs) -> Any: + # Note this implementation is copy/pasted into the TrainLoop class + # in TrainLoop._on_train_epoch_end_hook + # This was done to manage the deprecation of an argument to + # on_train_epoch_end + # If making chnages to this function, ensure that those changes are also made to + # TrainLoop._on_train_epoch_end_hook + # set hook_name to model + reset Result obj skip = self._reset_result_and_set_hook_fx_name(hook_name) From f2f8b589b7768be650472d69469961e798c31e30 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Tue, 4 May 2021 11:12:11 -0700 Subject: [PATCH 14/14] Update pytorch_lightning/trainer/trainer.py Co-authored-by: Jirka Borovec --- pytorch_lightning/trainer/trainer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 210e9ec16a209..42f7487142353 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1211,10 +1211,8 @@ def _cache_logged_metrics(self): self.logger_connector.cache_logged_metrics() def call_hook(self, hook_name: str, *args, **kwargs) -> Any: - # Note this implementation is copy/pasted into the TrainLoop class - # in TrainLoop._on_train_epoch_end_hook - # This was done to manage the deprecation of an argument to - # on_train_epoch_end + # Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook + # This was done to manage the deprecation of an argument to on_train_epoch_end # If making chnages to this function, ensure that those changes are also made to # TrainLoop._on_train_epoch_end_hook