diff --git a/CHANGELOG.md b/CHANGELOG.md index 1239f349e8f5f..4e5892d03734c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -138,6 +138,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed setting correct `DistribType` for `ddp_cpu` (spawn) backend ([#7492](https://github.com/PyTorchLightning/pytorch-lightning/pull/7492)) +- Fixed print errors in `ProgressBar` when `trainer.fit` is not called ([#7674](https://github.com/PyTorchLightning/pytorch-lightning/pull/7674)) + + ## [1.3.1] - 2021-05-11 ### Fixed diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index e6132e6f96c8c..0fe05ff812e20 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -473,12 +473,14 @@ def print( ): active_progress_bar = None - if not self.main_progress_bar.disable: + if self.main_progress_bar is not None and not self.main_progress_bar.disable: active_progress_bar = self.main_progress_bar - elif not self.val_progress_bar.disable: + elif self.val_progress_bar is not None and not self.val_progress_bar.disable: active_progress_bar = self.val_progress_bar - elif not self.test_progress_bar.disable: + elif self.test_progress_bar is not None and not self.test_progress_bar.disable: active_progress_bar = self.test_progress_bar + elif self.predict_progress_bar is not None and not self.predict_progress_bar.disable: + active_progress_bar = self.predict_progress_bar if active_progress_bar is not None: s = sep.join(map(str, args)) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 6ab7b9f7415ba..f4f8f34c1b4c1 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -433,6 +433,10 @@ def test_step(self, *args, **kwargs): self.print("test_step") return super().test_step(*args, **kwargs) + def predict_step(self, *args, **kwargs): + self.print("predict_step") + return super().predict_step(*args, **kwargs) + @mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") def test_progress_bar_print(tqdm_write, tmpdir): @@ -445,16 +449,45 @@ def test_progress_bar_print(tqdm_write, tmpdir): limit_train_batches=1, limit_val_batches=1, limit_test_batches=1, + limit_predict_batches=1, max_steps=1, callbacks=[bar], ) trainer.fit(model) trainer.test(model) - assert tqdm_write.call_count == 3 + trainer.predict(model) + assert tqdm_write.call_count == 4 assert tqdm_write.call_args_list == [ call("training_step", end="", file=None, nolock=False), call("validation_step", end=os.linesep, file=sys.stderr, nolock=False), call("test_step", end=os.linesep, file=None, nolock=False), + call("predict_step", end=os.linesep, file=None, nolock=False), + ] + + +@mock.patch("pytorch_lightning.callbacks.progress.tqdm.write") +def test_progress_bar_print_no_train(tqdm_write, tmpdir): + """ Test that printing in the LightningModule redirects arguments to the progress bar without training. """ + model = PrintModel() + bar = ProgressBar() + trainer = Trainer( + default_root_dir=tmpdir, + num_sanity_val_steps=0, + limit_val_batches=1, + limit_test_batches=1, + limit_predict_batches=1, + max_steps=1, + callbacks=[bar], + ) + + trainer.validate(model) + trainer.test(model) + trainer.predict(model) + assert tqdm_write.call_count == 3 + assert tqdm_write.call_args_list == [ + call("validation_step", end=os.linesep, file=sys.stderr, nolock=False), + call("test_step", end=os.linesep, file=None, nolock=False), + call("predict_step", end=os.linesep, file=None, nolock=False), ] @@ -470,17 +503,20 @@ def test_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): limit_train_batches=1, limit_val_batches=1, limit_test_batches=1, + limit_predict_batches=1, max_steps=1, callbacks=[bar], ) bar.disable() trainer.fit(model) - trainer.test(model) + trainer.test(model, verbose=False) + trainer.predict(model) mock_print.assert_has_calls([ call("training_step", end=""), call("validation_step", file=ANY), call("test_step"), + call("predict_step"), ]) tqdm_write.assert_not_called()