Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle unknown args passed to Trainer.from_argparse_args #1932

Merged
merged 14 commits into from
May 25, 2020
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))

- Fixed an issue with `Trainer.from_argparse_args` when passing in unknown Trainer args ([#1932](https://github.com/PyTorchLightning/pytorch-lightning/pull/1932))

- Fix bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))


## [0.7.6] - 2020-05-16

### Added
Expand Down
22 changes: 17 additions & 5 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,20 +730,32 @@ def parse_argparser(arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:

@classmethod
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs) -> 'Trainer':
"""create an instance from CLI arguments
"""
Create an instance from CLI arguments.

Args:
args: The parser or namespace to take arguments from. Only known arguments will be
parsed and passed to the :class:`Trainer`.
**kwargs: Additional keyword arguments that may override ones in the parser or namespace.
These must be valid Trainer arguments.

Example:
>>> parser = ArgumentParser(add_help=False)
>>> parser = Trainer.add_argparse_args(parser)
>>> parser.add_argument('--my_custom_arg', default='something') # doctest: +SKIP
>>> args = Trainer.parse_argparser(parser.parse_args(""))
>>> trainer = Trainer.from_argparse_args(args)
>>> trainer = Trainer.from_argparse_args(args, logger=False)
"""
if isinstance(args, ArgumentParser):
args = Trainer.parse_argparser(args)
args = cls.parse_argparser(args)
params = vars(args)
params.update(**kwargs)

return cls(**params)
# we only want to pass in valid Trainer args, the rest may be user specific
valid_kwargs = inspect.signature(cls.__init__).parameters
trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
trainer_kwargs.update(**kwargs)

return cls(**trainer_kwargs)

@property
def num_gpus(self) -> int:
Expand Down
26 changes: 26 additions & 0 deletions tests/trainer/test_trainer_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import pickle
import sys
from argparse import ArgumentParser, Namespace
from unittest import mock

Expand Down Expand Up @@ -110,3 +111,28 @@ def test_argparse_args_parsing(cli_args, expected):
for k, v in expected.items():
assert getattr(args, k) == v
assert Trainer.from_argparse_args(args)


@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="signature inspection while mocking is not working in Python < 3.7 despite autospec"
)
@pytest.mark.parametrize(['cli_args', 'extra_args'], [
pytest.param({}, {}),
pytest.param({'logger': False}, {}),
pytest.param({'logger': False}, {'logger': True}),
pytest.param({'logger': False}, {'checkpoint_callback': True}),
])
def test_init_from_argparse_args(cli_args, extra_args):
unknown_args = dict(unknown_arg=0)

# unkown args in the argparser/namespace should be ignored
with mock.patch('pytorch_lightning.Trainer.__init__', autospec=True, return_value=None) as init:
trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args)
expected = dict(cli_args)
expected.update(extra_args) # extra args should override any cli arg
init.assert_called_with(trainer, **expected)

# passing in unknown manual args should throw an error
with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)