From 8ca8336ce52ee7379f4d399520636143eb31018b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 25 May 2020 13:49:23 +0200
Subject: [PATCH] protect progress bar callback (#1855)

* wip protected progress bar settings

* remove callback attr from LRfinder

* whitespace

* changelog
---
 CHANGELOG.md                                 |  2 ++
 pytorch_lightning/trainer/callback_config.py | 20 ++++++++++----------
 pytorch_lightning/trainer/deprecated_api.py  |  2 +-
 pytorch_lightning/trainer/lr_finder.py       |  4 ----
 pytorch_lightning/trainer/trainer.py         | 11 +++++------
 tests/callbacks/test_progress_bar.py         |  2 +-
 tests/trainer/test_lr_finder.py              |  5 ++---
 7 files changed, 21 insertions(+), 25 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0355c1158663c..e8c71252c1a84 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Removed
 
+- Removed unintended Trainer argument `progress_bar_callback`, the callback should be passed in by `Trainer(callbacks=[...])` instead ([#1855](https://github.com/PyTorchLightning/pytorch-lightning/pull/1855))
+
 ### Fixed
 
 - Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
diff --git a/pytorch_lightning/trainer/callback_config.py b/pytorch_lightning/trainer/callback_config.py
index 551d085eb444d..cd94e7190e452 100644
--- a/pytorch_lightning/trainer/callback_config.py
+++ b/pytorch_lightning/trainer/callback_config.py
@@ -18,8 +18,6 @@ class TrainerCallbackConfigMixin(ABC):
     weights_save_path: str
     ckpt_path: str
     checkpoint_callback: ModelCheckpoint
-    progress_bar_refresh_rate: int
-    process_position: int
 
     @property
     @abstractmethod
@@ -109,7 +107,7 @@ def configure_early_stopping(self, early_stop_callback):
             self.early_stop_callback = early_stop_callback
             self.enable_early_stop = True
 
-    def configure_progress_bar(self):
+    def configure_progress_bar(self, refresh_rate=1, process_position=0):
         progress_bars = [c for c in self.callbacks if isinstance(c, ProgressBarBase)]
         if len(progress_bars) > 1:
             raise MisconfigurationException(
@@ -117,12 +115,14 @@ def configure_progress_bar(self):
                 ' progress bar is supported.'
             )
         elif len(progress_bars) == 1:
-            self.progress_bar_callback = progress_bars[0]
-        elif self.progress_bar_refresh_rate > 0:
-            self.progress_bar_callback = ProgressBar(
-                refresh_rate=self.progress_bar_refresh_rate,
-                process_position=self.process_position,
+            progress_bar_callback = progress_bars[0]
+        elif refresh_rate > 0:
+            progress_bar_callback = ProgressBar(
+                refresh_rate=refresh_rate,
+                process_position=process_position,
             )
-            self.callbacks.append(self.progress_bar_callback)
+            self.callbacks.append(progress_bar_callback)
         else:
-            self.progress_bar_callback = None
+            progress_bar_callback = None
+
+        return progress_bar_callback
diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py
index 5b615ebafaa09..11536df2a7830 100644
--- a/pytorch_lightning/trainer/deprecated_api.py
+++ b/pytorch_lightning/trainer/deprecated_api.py
@@ -121,7 +121,7 @@ def show_progress_bar(self):
         """Back compatibility, will be removed in v0.9.0"""
         rank_zero_warn("Argument `show_progress_bar` is now set by `progress_bar_refresh_rate` since v0.7.2"
                        " and this method will be removed in v0.9.0", DeprecationWarning)
-        return self.progress_bar_refresh_rate >= 1
+        return self.progress_bar_callback and self.progress_bar_callback.refresh_rate >= 1
 
     @show_progress_bar.setter
     def show_progress_bar(self, tf):
diff --git a/pytorch_lightning/trainer/lr_finder.py b/pytorch_lightning/trainer/lr_finder.py
index abc71ece2ac3d..f3679dddb9f4b 100755
--- a/pytorch_lightning/trainer/lr_finder.py
+++ b/pytorch_lightning/trainer/lr_finder.py
@@ -198,11 +198,9 @@ def __lr_finder_dump_params(self, model):
             'callbacks': self.callbacks,
             'logger': self.logger,
             'max_steps': self.max_steps,
-            'progress_bar_refresh_rate': self.progress_bar_refresh_rate,
             'checkpoint_callback': self.checkpoint_callback,
             'early_stop_callback': self.early_stop_callback,
             'enable_early_stop': self.enable_early_stop,
-            'progress_bar_callback': self.progress_bar_callback,
             'configure_optimizers': model.configure_optimizers,
         }
 
@@ -211,11 +209,9 @@ def __lr_finder_restore_params(self, model):
         self.logger = self.__dumped_params['logger']
         self.callbacks = self.__dumped_params['callbacks']
         self.max_steps = self.__dumped_params['max_steps']
-        self.progress_bar_refresh_rate = self.__dumped_params['progress_bar_refresh_rate']
         self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
         self.early_stop_callback = self.__dumped_params['early_stop_callback']
         self.enable_early_stop = self.__dumped_params['enable_early_stop']
-        self.progress_bar_callback = self.__dumped_params['progress_bar_callback']
         model.configure_optimizers = self.__dumped_params['configure_optimizers']
 
         del self.__dumped_params
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index a366890bc2441..25ecd5435987e 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -130,7 +130,6 @@ def __init__(
         reload_dataloaders_every_epoch: bool = False,
         auto_lr_find: Union[bool, str] = False,
         replace_sampler_ddp: bool = True,
-        progress_bar_callback: Optional[Union[ProgressBarBase, bool]] = True,
         terminate_on_nan: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
         num_tpu_cores: Optional[int] = None,  # backward compatible, todo: remove in v0.9.0
@@ -364,7 +363,6 @@ def __init__(
                 rank_zero_warn("num_processes is only used for distributed_backend=\"ddp_cpu\". Ignoring it.")
         self.num_processes = num_processes
 
-        self.process_position = process_position
         self.weights_summary = weights_summary
 
         self.max_epochs = max_epochs
@@ -506,9 +504,7 @@ def __init__(
         if show_progress_bar is not None:
             self.show_progress_bar = show_progress_bar
 
-        self.progress_bar_refresh_rate = progress_bar_refresh_rate
-        self.progress_bar_callback = progress_bar_callback
-        self.configure_progress_bar()
+        self._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)
 
         # logging
         self.log_save_interval = log_save_interval
@@ -661,7 +657,6 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
          'min_steps': None,
          ...
          'profiler': None,
-         'progress_bar_callback': True,
          'progress_bar_refresh_rate': 1,
          ...}
 
@@ -756,6 +751,10 @@ def num_gpus(self) -> int:
     def data_parallel(self) -> bool:
         return self.use_dp or self.use_ddp or self.use_ddp2
 
+    @property
+    def progress_bar_callback(self):
+        return self._progress_bar_callback
+
     @property
     def progress_bar_dict(self) -> dict:
         """ Read-only for progress bar metrics. """
diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py
index 30fbda22040f1..57866453719d2 100644
--- a/tests/callbacks/test_progress_bar.py
+++ b/tests/callbacks/test_progress_bar.py
@@ -179,7 +179,7 @@ def on_test_batch_end(self, trainer, pl_module):
         num_sanity_val_steps=2,
         max_epochs=3,
     )
-    assert trainer.progress_bar_callback.refresh_rate == refresh_rate != trainer.progress_bar_refresh_rate
+    assert trainer.progress_bar_callback.refresh_rate == refresh_rate
 
     trainer.fit(model)
     assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index d1e235b0a6654..9450e9803abd4 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -57,9 +57,8 @@ def test_trainer_reset_correctly(tmpdir):
     )
 
     changed_attributes = ['callbacks', 'logger', 'max_steps', 'auto_lr_find',
-                          'progress_bar_refresh_rate', 'early_stop_callback',
-                          'accumulate_grad_batches', 'enable_early_stop',
-                          'checkpoint_callback']
+                          'early_stop_callback', 'accumulate_grad_batches',
+                          'enable_early_stop', 'checkpoint_callback']
    attributes_before = {}
    for ca in changed_attributes:
        attributes_before[ca] = getattr(trainer, ca)