fast_dev_run can be int (Lightning-AI#4629)
* fast_dev_run can be int

* pep

* chlog

* add check and update docs

* logging with fdr

* update docs

* suggestions

Co-authored-by: Carlos Mocholí <[email protected]>

* fdr flush logs

* update trainer.fast_dev_run

* codefactor and pre-commit isort

* tmp

Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Roger Shieh <[email protected]>
Co-authored-by: edenlightning <[email protected]>
4 people authored Dec 8, 2020
1 parent 79ae66d commit 6d2aeff
Showing 12 changed files with 114 additions and 48 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -47,6 +47,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675))


- Updated `fast_dev_run` to accept integer representing num_batches ([#4629](https://github.com/PyTorchLightning/pytorch-lightning/pull/4629))


- Added casting to python types for numpy scalars when logging hparams ([#4647](https://github.com/PyTorchLightning/pytorch-lightning/pull/4647))


- Added `prefix` argument in loggers ([#4557](https://github.com/PyTorchLightning/pytorch-lightning/pull/4557))


@@ -156,6 +162,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added lambda closure to `manual_optimizer_step` ([#4618](https://github.com/PyTorchLightning/pytorch-lightning/pull/4618))


### Changed

- Change Metrics `persistent` default mode to `False` ([#4685](https://github.com/PyTorchLightning/pytorch-lightning/pull/4685))
11 changes: 7 additions & 4 deletions docs/source/debugging.rst
@@ -21,17 +21,20 @@ The following are flags that make debugging much easier.

fast_dev_run
------------
This flag runs a "unit test" by running 1 training batch and 1 validation batch.
The point is to detect any bugs in the training/validation loop without having to wait for
a full epoch to crash.
This flag runs a "unit test" by running ``n`` training and validation batches if set to an int ``n``, or 1 batch if set to ``True``.
The point is to detect any bugs in the training/validation loop without having to wait for a full epoch to crash.

(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.fast_dev_run`
argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)

.. testcode::


# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)

# runs 7 train, val, test batches and program ends
trainer = Trainer(fast_dev_run=7)

----------------

Inspect gradient norms
20 changes: 15 additions & 5 deletions docs/source/trainer.rst
@@ -245,7 +245,7 @@ Example::
# ddp2 = DistributedDataParallel + dp
trainer = Trainer(gpus=2, num_nodes=2, accelerator='ddp2')
.. note:: this option does not apply to TPU. TPUs use ```ddp``` by default (over each core)
.. note:: This option does not apply to TPU. TPUs use ```ddp``` by default (over each core)
You can also modify hardware behavior by subclassing an existing accelerator to adjust for your needs.
@@ -632,9 +632,10 @@ fast_dev_run
|
Runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
Runs ``n`` batches of train, val and test if set to an int ``n``, or 1 batch if set to ``True``,
to find any bugs (i.e. a sort of unit test).
Under the hood the pseudocode looks like this:
Under the hood the pseudocode looks like this when running *fast_dev_run* with a single batch:
.. code-block:: python
@@ -659,6 +660,16 @@ Under the hood the pseudocode looks like this:
# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)
# runs 7 train, val, test batches and program ends
trainer = Trainer(fast_dev_run=7)
.. note::
This argument is a bit different from ``limit_train/val/test_batches``. Setting this argument will
disable the tuner and logger callbacks like ``LearningRateLogger``, and will run for only 1 epoch. It should be
used only for debugging. ``limit_train/val/test_batches`` only limits the number of batches and won't
disable anything.
gpus
^^^^
@@ -1200,8 +1211,7 @@ Orders the progress bar. Useful when running multiple trainers on the same node.
# default used by the Trainer
trainer = Trainer(process_position=0)
Note:
This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`.
.. note:: This argument is ignored if a custom callback is passed to :paramref:`~Trainer.callbacks`.
profiler
^^^^^^^^
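To make the ``fast_dev_run`` vs ``limit_train/val/test_batches`` note in the hunk above concrete, here is a minimal sketch (an editor's illustration, not part of the diff; the numeric values are arbitrary):

from pytorch_lightning import Trainer

# fast_dev_run=7 caps train, val and test at 7 batches each, skips the sanity
# check, disables the tuner, and forces max_epochs=1
trainer = Trainer(fast_dev_run=7)

# limit_*_batches only caps the batch counts; the tuner, loggers and the
# configured number of epochs still run normally
trainer = Trainer(limit_train_batches=7, limit_val_batches=7, limit_test_batches=7)
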
28 changes: 23 additions & 5 deletions pytorch_lightning/trainer/connectors/debugging_connector.py
@@ -31,16 +31,34 @@ def on_init_start(
overfit_batches,
fast_dev_run
):
if not isinstance(fast_dev_run, (bool, int)):
raise MisconfigurationException(
f'fast_dev_run={fast_dev_run} is not a valid configuration.'
' It should be either a bool or an int >= 0'
)

if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
raise MisconfigurationException(
f'fast_dev_run={fast_dev_run} is not a'
' valid configuration. It should be >= 0.'
)

self.trainer.fast_dev_run = fast_dev_run
if self.trainer.fast_dev_run:
limit_train_batches = 1
limit_val_batches = 1
limit_test_batches = 1
fast_dev_run = int(fast_dev_run)

# set fast_dev_run=True when it is 1, used while logging
if fast_dev_run == 1:
self.trainer.fast_dev_run = True

if fast_dev_run:
limit_train_batches = fast_dev_run
limit_val_batches = fast_dev_run
limit_test_batches = fast_dev_run
self.trainer.num_sanity_val_steps = 0
self.trainer.max_epochs = 1
rank_zero_info(
'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
'Running in fast_dev_run mode: will run a full train,'
f' val and test loop using {fast_dev_run} batch(es)'
)

self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
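The validation and normalization above can be summarized in a standalone sketch (editor's illustration; `resolve_fast_dev_run` is a hypothetical helper, not part of the codebase):

from typing import Union

def resolve_fast_dev_run(fast_dev_run: Union[int, bool]) -> int:
    """Validate the flag and normalize it to a batch count, mirroring the connector."""
    if not isinstance(fast_dev_run, (bool, int)):
        raise ValueError('fast_dev_run should be either a bool or an int >= 0')
    if fast_dev_run < 0:
        raise ValueError('fast_dev_run should be >= 0')
    return int(fast_dev_run)  # True -> 1, False -> 0, n -> n

assert resolve_fast_dev_run(True) == 1
assert resolve_fast_dev_run(7) == 7
assert resolve_fast_dev_run(False) == 0
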
@@ -589,7 +589,7 @@ def __gather_result_across_time_and_optimizers(self, epoch_output):
def log_train_step_metrics(self, batch_output):
_, batch_log_metrics = self.cached_results.update_logger_connector()
# when metrics should be logged
if self.should_update_logs or self.trainer.fast_dev_run:
if self.should_update_logs or self.trainer.fast_dev_run is True:
# logs user requested information to logger
grad_norm_dic = batch_output.grad_norm_dic
if grad_norm_dic is None:
16 changes: 8 additions & 8 deletions pytorch_lightning/trainer/properties.py
@@ -38,7 +38,7 @@ class TrainerProperties(ABC):
logger_connector: LoggerConnector
_state: TrainerState
global_rank: int
fast_dev_run: bool
fast_dev_run: Union[int, bool]
use_dp: bool
use_ddp: bool
use_ddp2: bool
@@ -57,19 +57,19 @@ class TrainerProperties(ABC):
@property
def log_dir(self):
if self.checkpoint_callback is not None:
dir = self.checkpoint_callback.dirpath
dir = os.path.split(dir)[0]
dirpath = self.checkpoint_callback.dirpath
dirpath = os.path.split(dirpath)[0]
elif self.logger is not None:
if isinstance(self.logger, TensorBoardLogger):
dir = self.logger.log_dir
dirpath = self.logger.log_dir
else:
dir = self.logger.save_dir
dirpath = self.logger.save_dir
else:
dir = self._default_root_dir
dirpath = self._default_root_dir

if self.accelerator_backend is not None:
dir = self.accelerator_backend.broadcast(dir)
return dir
dirpath = self.accelerator_backend.broadcast(dirpath)
return dirpath

@property
def use_amp(self) -> bool:
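The `dir` to `dirpath` rename above is a readability fix: the local name was shadowing Python's built-in `dir()`. A quick illustration of the problem (editor's sketch; paths are arbitrary):

def log_dir_old():
    dir = "/logs/run_0"  # local `dir` shadows the builtin for the rest of the scope
    # dir(object())      # would raise TypeError: 'str' object is not callable
    return dir

def log_dir_new():
    dirpath = "/logs/run_0"  # the rename keeps the builtin usable
    return dirpath
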
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/trainer.py
@@ -98,7 +98,7 @@ def __init__(
overfit_batches: Union[int, float] = 0.0,
track_grad_norm: Union[int, float, str] = -1,
check_val_every_n_epoch: int = 1,
fast_dev_run: bool = False,
fast_dev_run: Union[int, bool] = False,
accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
max_epochs: int = 1000,
min_epochs: int = 1,
@@ -186,7 +186,8 @@ def __init__(
distributed_backend: deprecated. Please use 'accelerator'
fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
fast_dev_run: runs ``n`` batches of train, val and test if set to an int ``n``, or 1 batch
if set to ``True``, to find any bugs (i.e. a sort of unit test).
flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).
8 changes: 6 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -215,10 +215,14 @@ def check_checkpoint_callback(self, should_save, is_last=False):
# TODO bake this logic into the checkpoint callback
if should_save and self.trainer.checkpoint_connector.has_trained:
checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)]

if is_last and any(c.save_last for c in checkpoint_callbacks):
rank_zero_info("Saving latest checkpoint...")

model = self.trainer.get_model()
[cb.on_validation_end(self.trainer, model) for cb in checkpoint_callbacks]

for callback in checkpoint_callbacks:
callback.on_validation_end(self.trainer, model)

def on_train_epoch_start(self, epoch):

@@ -908,7 +912,7 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
def save_loggers_on_train_batch_end(self):
# when loggers should save to disk
should_flush_logs = self.trainer.logger_connector.should_flush_logs
if should_flush_logs or self.trainer.fast_dev_run:
if should_flush_logs or self.trainer.fast_dev_run is True:
if self.trainer.is_global_zero and self.trainer.logger is not None:
self.trainer.logger.save()

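The checkpoint-callback hunk above also swaps a list comprehension used purely for side effects for an explicit loop; a self-contained before/after sketch (editor's illustration, names are stand-ins):

class _Checkpoint:
    def on_validation_end(self, trainer, model):
        print("flushing checkpoint state")

checkpoint_callbacks, trainer, model = [_Checkpoint()], object(), object()

# before: builds and immediately discards a list of None return values
[cb.on_validation_end(trainer, model) for cb in checkpoint_callbacks]

# after: an explicit loop states the intent and allocates nothing extra
for callback in checkpoint_callbacks:
    callback.on_validation_end(trainer, model)
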
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/batch_size_scaling.py
@@ -70,7 +70,7 @@ def scale_batch_size(trainer,
or datamodule.
"""
if trainer.fast_dev_run:
rank_zero_warn('Skipping batch size scaler since `fast_dev_run=True`', UserWarning)
rank_zero_warn('Skipping batch size scaler since fast_dev_run is enabled.', UserWarning)
return

if not lightning_hasattr(model, batch_arg_name):
2 changes: 1 addition & 1 deletion pytorch_lightning/tuner/lr_finder.py
@@ -137,7 +137,7 @@ def lr_find(
"""
if trainer.fast_dev_run:
rank_zero_warn('Skipping learning rate finder since `fast_dev_run=True`', UserWarning)
rank_zero_warn('Skipping learning rate finder since fast_dev_run is enabled.', UserWarning)
return

save_path = os.path.join(trainer.default_root_dir, 'lr_find_temp_model.ckpt')
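Both tuner entry points now warn with identical wording, so callers see consistent skip behavior; a hedged usage sketch (editor's illustration; `model` stands for any LightningModule and is left undefined here):

import pytest
from pytorch_lightning import Trainer

trainer = Trainer(fast_dev_run=True, auto_scale_batch_size=True)
with pytest.warns(UserWarning, match='fast_dev_run is enabled'):
    trainer.tune(model)  # skips the batch size scaler and returns immediately
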
4 changes: 2 additions & 2 deletions tests/trainer/flags/test_fast_dev_run.py
@@ -4,7 +4,7 @@


@pytest.mark.parametrize('tuner_alg', ['batch size scaler', 'learning rate finder'])
def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg):
def test_skip_on_fast_dev_run_tuner(tmpdir, tuner_alg):
""" Test that tuner algorithms are skipped if fast dev run is enabled """

hparams = EvalModelTemplate.get_default_hparams()
@@ -16,6 +16,6 @@ def test_skip_on_fast_dev_run_batch_scaler(tmpdir, tuner_alg):
auto_lr_find=True if tuner_alg == 'learning rate finder' else False,
fast_dev_run=True
)
expected_message = f'Skipping {tuner_alg} since `fast_dev_run=True`'
expected_message = f'Skipping {tuner_alg} since fast_dev_run is enabled.'
with pytest.warns(UserWarning, match=expected_message):
trainer.tune(model)
57 changes: 40 additions & 17 deletions tests/trainer/test_dataloaders.py
@@ -436,9 +436,11 @@ def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_v


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_dataloaders_with_fast_dev_run(tmpdir):
"""Verify num_batches for train, val & test dataloaders passed with fast_dev_run = True"""

@pytest.mark.parametrize('fast_dev_run', [True, 1, 3, -1, 'temp'])
def test_dataloaders_with_fast_dev_run(tmpdir, fast_dev_run):
"""
Verify num_batches for train, val & test dataloaders passed with fast_dev_run
"""
model = EvalModelTemplate()
model.val_dataloader = model.val_dataloader__multiple_mixed_length
model.test_dataloader = model.test_dataloader__multiple_mixed_length
@@ -447,26 +449,47 @@ def test_dataloaders_with_fast_dev_run(tmpdir):
model.test_step = model.test_step__multiple_dataloaders
model.test_epoch_end = model.test_epoch_end__multiple_dataloaders

# train, multiple val and multiple test dataloaders passed with fast_dev_run = True
trainer = Trainer(
trainer_options = dict(
default_root_dir=tmpdir,
max_epochs=2,
fast_dev_run=True,
fast_dev_run=fast_dev_run,
)
assert trainer.max_epochs == 1
assert trainer.num_sanity_val_steps == 0

trainer.fit(model)
assert not trainer.disable_validation
assert trainer.num_training_batches == 1
assert trainer.num_val_batches == [1] * len(trainer.val_dataloaders)
if fast_dev_run == 'temp':
with pytest.raises(MisconfigurationException, match='either a bool or an int'):
trainer = Trainer(**trainer_options)
elif fast_dev_run == -1:
with pytest.raises(MisconfigurationException, match='should be >= 0'):
trainer = Trainer(**trainer_options)
else:
trainer = Trainer(**trainer_options)

trainer.test(ckpt_path=None)
assert trainer.num_test_batches == [1] * len(trainer.test_dataloaders)
# fast_dev_run is set to True when it is 1
if fast_dev_run == 1:
fast_dev_run = True

# verify sanity check batches match as expected
num_val_dataloaders = len(model.val_dataloader())
assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders
assert trainer.fast_dev_run is fast_dev_run

if fast_dev_run is True:
fast_dev_run = 1

assert trainer.limit_train_batches == fast_dev_run
assert trainer.limit_val_batches == fast_dev_run
assert trainer.limit_test_batches == fast_dev_run
assert trainer.num_sanity_val_steps == 0
assert trainer.max_epochs == 1

trainer.fit(model)
assert not trainer.disable_validation
assert trainer.num_training_batches == fast_dev_run
assert trainer.num_val_batches == [fast_dev_run] * len(trainer.val_dataloaders)

trainer.test(ckpt_path=None)
assert trainer.num_test_batches == [fast_dev_run] * len(trainer.test_dataloaders)

# verify sanity check batches match as expected
num_val_dataloaders = len(model.val_dataloader())
assert trainer.dev_debugger.num_seen_sanity_check_batches == trainer.num_sanity_val_steps * num_val_dataloaders


@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
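The 1 -> True -> 1 round-trip asserted in the test above mirrors the connector's normalization; in short (editor's sketch, grounded in the assertions above):

from pytorch_lightning import Trainer

trainer = Trainer(fast_dev_run=1)
assert trainer.fast_dev_run is True      # 1 is normalized to True, used while logging
assert trainer.limit_train_batches == 1  # but each loop still runs a single batch
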
