Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LightningCLI(run=False|True) #8751

Merged
merged 8 commits into from
Aug 10, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662))


- Added `LightningCLI(run=False|True)` to choose whether to run a `Trainer` subcommand ([#8751](https://github.com/PyTorchLightning/pytorch-lightning/pull/8751))


- Fault-tolerant training:
* Added `FastForwardSampler` and `CaptureIterableDataset` injection to data loading utilities ([#8366](https://github.com/PyTorchLightning/pytorch-lightning/pull/8366))

Expand Down
23 changes: 18 additions & 5 deletions docs/source/common/lightning_cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -260,17 +260,30 @@ file. Loading a defaults file :code:`my_cli_defaults.yaml` in the current workin

.. testcode::

cli = LightningCLI(
MyModel,
MyDataModule,
parser_kwargs={"default_config_files": ["my_cli_defaults.yaml"]},
)
cli = LightningCLI(MyModel, MyDataModule, parser_kwargs={"default_config_files": ["my_cli_defaults.yaml"]})
carmocca marked this conversation as resolved.
Show resolved Hide resolved

To load a file in the user's home directory would be just changing to :code:`~/.my_cli_defaults.yaml`. Note that this
setting is given through :code:`parser_kwargs`. More parameters are supported. For details see the `ArgumentParser API
<https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.core.ArgumentParser.__init__>`_ documentation.


Instantiation only mode
^^^^^^^^^^^^^^^^^^^^^^^

The CLI is designed to start fitting without minimal code changes. On class instantiation, the CLI will automatically
carmocca marked this conversation as resolved.
Show resolved Hide resolved
call ``trainer.fit(...)`` internally so you don't have to do it. To avoid this, you can set the following argument:

.. testcode::

cli = LightningCLI(MyModel, run=False) # True by default
# you'll have to call fit yourself:
cli.trainer.fit(cli.model)


This can be useful to implement custom logic without having to subclass the CLI, but still using the CLI's instantiation
and argument parsing capabilities.


Trainer Callbacks and arguments with class type
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
13 changes: 9 additions & 4 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def __init__(
parser_kwargs: Dict[str, Any] = None,
subclass_mode_model: bool = False,
subclass_mode_data: bool = False,
run: bool = True,
carmocca marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""
Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
Expand Down Expand Up @@ -259,6 +260,8 @@ def __init__(
subclass_mode_data: Whether datamodule can be any `subclass
<https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
of the given class.
run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
"""
self.model_class = model_class
self.datamodule_class = datamodule_class
Expand All @@ -283,10 +286,12 @@ def __init__(
self.before_instantiate_classes()
self.instantiate_classes()
self.add_configure_optimizers_method_to_model()
self.prepare_fit_kwargs()
self.before_fit()
self.fit()
self.after_fit()

if run:
self.prepare_fit_kwargs()
self.before_fit()
self.fit()
self.after_fit()

def init_parser(self) -> None:
"""Method that instantiates the argument parser"""
Expand Down
9 changes: 9 additions & 0 deletions tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,12 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
assert isinstance(cli.model.optim1, torch.optim.Adam)
assert isinstance(cli.model.optim2, torch.optim.SGD)
assert isinstance(cli.model.scheduler, torch.optim.lr_scheduler.ExponentialLR)


@pytest.mark.parametrize("run", (False, True))
def test_lightning_cli_disabled_run(run):
with mock.patch("sys.argv", ["any.py"]), mock.patch("pytorch_lightning.Trainer.fit") as fit_mock:
cli = LightningCLI(BoringModel, run=run)
fit_mock.assert_called() if run else fit_mock.assert_not_called()
carmocca marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(cli.trainer, Trainer)
assert isinstance(cli.model, LightningModule)