Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements related to save of config file by LightningCLI #7963

Merged
merged 10 commits into from
Jun 15, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831))


- LightningCLI now aborts with a clearer message if the config file already exists, and skips saving the config when `fast_dev_run` is enabled ([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))


### Deprecated


Expand Down
13 changes: 11 additions & 2 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def add_lightning_class_args(


class SaveConfigCallback(Callback):
"""Saves a LightningCLI config to the log_dir when training starts"""
"""Saves a LightningCLI config to the log_dir when training starts

Raises:
    RuntimeError: If the config file already exists in the log_dir, to avoid overwriting a previous run
"""

def __init__(
self,
Expand All @@ -90,6 +94,11 @@ def __init__(
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
    """Write the CLI config file into the trainer's log directory when training starts.

    Raises:
        RuntimeError: If the config file is already present in the log directory,
            to avoid clobbering the results of a previous run.
    """
    save_dir = trainer.log_dir or trainer.default_root_dir
    target = os.path.join(save_dir, self.config_filename)
    if os.path.isfile(target):
        # Refuse to proceed rather than silently overwrite an earlier run's config.
        message = (
            f'{type(self).__name__} expected {target} to not exist. '
            'Aborting to avoid overwriting results of a previous run.'
        )
        raise RuntimeError(message)
    self.parser.save(self.config, target, skip_none=False)


Expand Down Expand Up @@ -231,7 +240,7 @@ def instantiate_trainer(self) -> None:
self.config_init['trainer']['callbacks'].extend(self.trainer_defaults['callbacks'])
else:
self.config_init['trainer']['callbacks'].append(self.trainer_defaults['callbacks'])
if self.save_config_callback is not None:
if self.save_config_callback and not self.config_init['trainer']['fast_dev_run']:
config_callback = self.save_config_callback(self.parser, self.config, self.save_config_filename)
self.config_init['trainer']['callbacks'].append(config_callback)
self.trainer = self.trainer_class(**self.config_init['trainer'])
Expand Down
25 changes: 25 additions & 0 deletions tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,31 @@ def test_lightning_cli_args(tmpdir):
assert config['trainer'] == cli.config['trainer']


def test_lightning_cli_save_config_cases(tmpdir):
    """The CLI skips saving the config under fast_dev_run, saves it otherwise,
    and refuses to overwrite an existing config file on a subsequent run."""

    config_path = tmpdir / 'config.yaml'

    def run_cli(last_arg):
        # Invoke the CLI as if from the command line, with a fixed root dir and no logger.
        argv = [
            'any.py',
            f'--trainer.default_root_dir={tmpdir}',
            '--trainer.logger=False',
            last_arg,
        ]
        with mock.patch('sys.argv', argv):
            LightningCLI(BoringModel)

    # fast_dev_run enabled -> config must not be saved
    run_cli('--trainer.fast_dev_run=1')
    assert not os.path.isfile(config_path)

    # fast_dev_run disabled -> config is saved
    run_cli('--trainer.max_epochs=1')
    assert os.path.isfile(config_path)

    # a second run in the same directory must raise: the config file already exists
    with pytest.raises(RuntimeError):
        run_cli('--trainer.max_epochs=1')


def test_lightning_cli_config_and_subclass_mode(tmpdir):

config = dict(
Expand Down