Always run validation inside the training loop epoch #7357

Merged on May 26, 2021 (40 commits)
Changes from all commits

Commits (40)
df4d846
Refactor global step update
carmocca May 4, 2021
8e402a4
WIP
carmocca May 6, 2021
9f20905
Merge branch 'master' into refactor/global-step-update
carmocca May 24, 2021
2404d58
WIP
carmocca May 24, 2021
3e5e087
WIP
carmocca May 24, 2021
d274165
Fix tests
carmocca May 24, 2021
517970a
Fix tests
carmocca May 24, 2021
c8500d7
Fix tests
carmocca May 24, 2021
0c2305d
Minor change
carmocca May 24, 2021
f02a866
Fix test
carmocca May 24, 2021
fa31597
Increment the total batch idx before the accumulation early exit
carmocca May 24, 2021
64c49c1
Update CHANGELOG
carmocca May 24, 2021
db09f43
Merge branch 'bugfix/total-batch-idx-update' into refactor/global-ste…
carmocca May 24, 2021
a3d328f
Fix test
carmocca May 25, 2021
5a00cec
Comment
carmocca May 25, 2021
ca7804f
Fix test
carmocca May 25, 2021
d91eaf4
Fix ModelCheckpoint tests
carmocca May 25, 2021
e01657c
Merge branch 'master' into refactor/global-step-update
carmocca May 25, 2021
a66b96f
Update test
carmocca May 25, 2021
d211b79
Merge branch 'master' into refactor/global-step-update
carmocca May 25, 2021
2eb31ed
Conflicts
carmocca May 25, 2021
efa1529
Bad merge
carmocca May 25, 2021
00ccee0
Bad merge
carmocca May 25, 2021
deaa887
Unrelated change
carmocca May 25, 2021
fccd7c4
Unrelated change
carmocca May 25, 2021
80e597f
Merge branch 'master' into refactor/global-step-update
carmocca May 25, 2021
747a399
Remove comment
carmocca May 25, 2021
a2943e5
Add `ModelPruning(prune_on_train_epoch_end=bool)` to choose when to a…
carmocca May 25, 2021
c819df5
Update CHANGELOG
carmocca May 25, 2021
e8232da
Merge branch 'feature/pruning-flag' into refactor/global-step-update
carmocca May 25, 2021
c0201bc
Update CHANGELOG.md
carmocca May 25, 2021
927ab83
Test with regex
carmocca May 25, 2021
08a5c4b
Minor change
carmocca May 25, 2021
b0c630f
No need to seed anymore
carmocca May 25, 2021
f0c7674
Merge branch 'feature/pruning-flag' into refactor/global-step-update
carmocca May 25, 2021
7a7edf6
Merge branch 'master' into refactor/global-step-update
carmocca May 25, 2021
4284a63
Update CHANGELOG
carmocca May 25, 2021
09f5d39
Update CHANGELOG
carmocca May 26, 2021
95eeda1
Update CHANGELOG
carmocca May 26, 2021
1b18342
Update CHANGELOG
carmocca May 26, 2021
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -62,6 +62,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))


- Validation is now always run inside the training epoch scope ([#7357](https://github.com/PyTorchLightning/pytorch-lightning/pull/7357))


- Refactored Loops
* Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
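For context on the changelog entry above (not part of this diff): the user-visible consequence is that the validation hooks now fire before the train-epoch-end hooks. A minimal callback sketch using the public Callback API, included here only as an illustration:

import pytorch_lightning as pl

class HookOrderLogger(pl.Callback):
    # Records epoch-end hook order to show that validation now runs inside the training epoch.

    def __init__(self):
        self.calls = []

    def on_validation_end(self, trainer, pl_module):
        self.calls.append("on_validation_end")

    def on_train_epoch_end(self, trainer, pl_module, *args):
        self.calls.append("on_train_epoch_end")

# After trainer.fit(model) with a val dataloader, each epoch is expected to append
# "on_validation_end" before "on_train_epoch_end" once this change is applied.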
13 changes: 1 addition & 12 deletions pytorch_lightning/trainer/trainer.py
@@ -928,7 +928,7 @@ def _run_train(self) -> None:
self.state.stage = None
raise

def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
def _run_evaluation(self) -> _EVALUATE_OUTPUT:
if not (self.evaluating or self.sanity_checking):
rank_zero_warn(
f"`trainer._run_evaluation()` was called but the running stage is set to {self.state.stage}."
@@ -1010,17 +1010,6 @@ def _run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
# hook
self.evaluation_loop.on_evaluation_epoch_end()

# update epoch-level lr_schedulers
if on_epoch:
self.optimizer_connector.update_learning_rates(
interval='epoch',
opt_indices=[
opt_idx for opt_idx, _ in self.train_loop.get_active_optimizers(
batch_idx=(self.train_loop.total_batch_idx - 1)
) # Select the optimizers which were used in the last batch of the epoch
],
)

# log epoch metrics
eval_loop_results = self.logger_connector.get_evaluate_epoch_results()

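Note on the trainer.py hunk above: the epoch-interval scheduler update that was special-cased inside `_run_evaluation` now happens unconditionally in the training loop via `update_lr_schedulers('epoch')` (see the training_loop.py changes below). As a rough sketch only, with assumed names rather than Lightning internals, an epoch-interval update amounts to:

def update_epoch_schedulers(scheduler_configs, monitored_metrics):
    # Step every scheduler configured with interval="epoch"; plateau-style
    # schedulers receive their monitored value, the rest step unconditionally.
    for config in scheduler_configs:
        if config.get("interval") != "epoch":
            continue
        monitor = config.get("monitor")
        if monitor is not None:
            config["scheduler"].step(monitored_metrics[monitor])
        else:
            config["scheduler"].step()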
58 changes: 22 additions & 36 deletions pytorch_lightning/trainer/training_loop.py
@@ -479,7 +479,6 @@ def run_training_epoch(self):
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
batch_idx = None
is_last_batch = None

for batch_idx, (batch, is_last_batch) in train_dataloader:
self.batch_idx = batch_idx
@@ -529,44 +528,38 @@ def run_training_epoch(self):

self.total_batch_idx += 1

max_steps_reached = (
self.max_steps is not None and self.max_steps <= self.global_step + 1
and self._accumulated_batches_reached()
)
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
break

# progress global step according to grads progress
self.increment_accumulated_grad_global_step()

max_steps_reached = (self.max_steps is not None and self.max_steps <= self.global_step)
if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
break

if batch_idx is None:
# dataloader/iterator did not produce a batch
return

# handle epoch_output on epoch end
self.on_train_epoch_end(epoch_output)

# the global step is manually decreased here due to backwards compatibility with existing loggers
# as they expect that the same step is used when logging epoch end metrics even when the batch loop has
# finished. this means the attribute does not exactly track the number of optimizer steps applied.
# TODO(@carmocca): deprecate and rename so users don't get confused
self.global_step -= 1
# log epoch metrics
self.trainer.logger_connector.log_train_epoch_end_metrics(epoch_output)
self.global_step += 1

should_check_val = self._should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
should_train_only = self.trainer.disable_validation or should_skip_eval

# update epoch level lr_schedulers if no val loop outside train loop is triggered
if not should_check_val or should_train_only:
self.update_lr_schedulers('epoch')
self.update_lr_schedulers('epoch')

if should_train_only:
did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.should_skip_evaluation(
self.trainer.num_val_batches
)
if did_train_only:
self.global_step -= 1
self.check_checkpoint_callback(True)

if should_check_val:
self.trainer.validating = True
self.trainer._run_evaluation(on_epoch=True)
self.trainer.training = True

if batch_output.signal != -1:
self.increment_accumulated_grad_global_step()
self.global_step += 1

def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
# inform logger the batch loop has finished
@@ -882,7 +875,7 @@ def should_accumulate(self):
is_final_batch = self._num_training_batches_reached()
return not (accumulation_done or is_final_batch)

def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
""" Decide if we should run validation. """
if not self.trainer.enable_validation:
return False
@@ -893,26 +886,19 @@ def _should_check_val_fx(self, batch_idx: int, is_last_batch: bo

# val_check_batch is inf for iterable datasets with no length defined
is_infinite_dataset = self.trainer.val_check_batch == float('inf')
if on_epoch and is_last_batch and is_infinite_dataset:
if is_last_batch and is_infinite_dataset:
return True

if self.trainer.should_stop:
return True

# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
is_val_check_batch = False
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
is_val_check_batch = is_last_batch
if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
elif self.trainer.val_check_batch != float('inf'):
is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0

# Note: num_training_batches is also inf for iterable datasets with no length defined
epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0

if on_epoch:
return is_val_check_batch and epoch_end_val_check
else:
return is_val_check_batch and not epoch_end_val_check
return is_val_check_batch

def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens):
# enable not needing to add opt_idx to training_step
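To make the simplified validation-check logic above easier to follow, here is a standalone paraphrase of the decision as the hunk leaves it (a sketch with assumed parameter names; the finite-dataset default is inferred from the diff, so treat it as illustrative):

def should_check_val(batch_idx, is_last_batch, val_check_batch, limit_train_batches, should_stop):
    # val_check_batch is float("inf") for iterable datasets with no length defined
    is_infinite_dataset = val_check_batch == float("inf")
    if is_last_batch and is_infinite_dataset:
        return True
    if should_stop:
        return True
    is_val_check_batch = is_last_batch
    if isinstance(limit_train_batches, int) and is_infinite_dataset:
        is_val_check_batch = (batch_idx + 1) % limit_train_batches == 0
    elif not is_infinite_dataset:
        is_val_check_batch = (batch_idx + 1) % val_check_batch == 0
    return is_val_check_batch

The `on_epoch` special case is gone because validation triggered at epoch end now goes through the same in-epoch path.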
4 changes: 2 additions & 2 deletions tests/callbacks/test_callbacks.py
@@ -83,8 +83,6 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_after_backward(trainer, model),
call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
call.on_batch_end(trainer, model),
call.on_train_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
@@ -94,6 +92,8 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_epoch_end(trainer, model),
call.on_validation_end(trainer, model),
call.on_save_checkpoint(trainer, model), # should take ANY but we are inspecting signature for BC
call.on_train_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_train_end(trainer, model),
call.on_fit_end(trainer, model),
call.teardown(trainer, model, 'fit'),
2 changes: 1 addition & 1 deletion tests/callbacks/test_early_stopping.py
@@ -169,7 +169,7 @@ def training_epoch_end(self, outputs):
model.validation_step = None

early_stop_callback = EarlyStopping(
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=validation_step_none
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True
)
trainer = Trainer(
default_root_dir=tmpdir,
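Usage note for the early-stopping change above (an illustration, not part of the diff): when there is no validation loop, the stopping criterion has to be evaluated at train-epoch end, which is what the flag now requests explicitly:

from pytorch_lightning.callbacks import EarlyStopping

# Monitor a metric logged during training and check the stopping criterion
# when the training epoch ends (no validation loop required).
early_stop_callback = EarlyStopping(
    monitor="train_loss", patience=3, verbose=True, check_on_train_epoch_end=True
)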
61 changes: 30 additions & 31 deletions tests/checkpointing/test_model_checkpoint.py
@@ -60,9 +60,8 @@ def validation_epoch_end(self, outputs):
"validation_step_none,val_dataloaders_none,monitor",
[
(False, False, 'val_log'),
(False, False, 'train_log_epoch'),
(True, False, 'train_log_epoch'),
(False, True, 'train_log_epoch'),
(False, True, 'val_log'),
],
)
@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
@@ -76,7 +75,7 @@ def test_model_checkpoint_score_and_ckpt(
max_epochs = 3
limit_train_batches = 5
limit_val_batches = 7
lr = 1e-1
lr, gamma = 1e-1, 2

class CustomBoringModel(BoringModel):

@@ -106,7 +105,7 @@ def configure_optimizers(self):
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

return [optimizer], [lr_scheduler]

@@ -153,9 +152,12 @@ def configure_optimizers(self):
assert mc_specific_data['current_score'] == score

if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
assert lr_scheduler_specific_data['_step_count'] == epoch + 2
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))
actual_step_count = chk['lr_schedulers'][0]['_step_count']
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
# if validation_step_none, the checkpoint gets saved after the learning rate update
# so we need to increase the count by one
assert actual_step_count == epoch + 1 + validation_step_none
assert actual_lr == lr * gamma**(epoch + validation_step_none)

assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
@@ -180,23 +182,21 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
max_epochs = 3
limit_train_batches = 12
limit_val_batches = 7
lr = 1e-1
lr, gamma = 1e-1, 2
monitor = 'val_log'
per_epoch_steps = int(limit_train_batches * val_check_interval)
per_epoch_call_count = limit_train_batches // per_epoch_steps
left_over_steps = limit_train_batches % per_epoch_steps
per_val_train_batches = int(limit_train_batches * val_check_interval)
per_epoch_val_checks, leftover_train_batches = divmod(limit_train_batches, per_val_train_batches)

class CustomBoringModel(BoringModel):

def __init__(self):
super().__init__()
self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
self.val_logs = torch.randn(per_epoch_val_checks * max_epochs, limit_val_batches)
self.val_loop_count = 0

def validation_step(self, batch, batch_idx):
log_value = self.val_logs[self.val_loop_count, batch_idx]
self.log('val_log', log_value)
self.log('epoch', self.current_epoch, on_epoch=True)
return super().validation_step(batch, batch_idx)

def validation_epoch_end(self, outputs):
@@ -213,7 +213,7 @@ def configure_optimizers(self):
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

return [optimizer], [lr_scheduler]

@@ -241,26 +241,27 @@ def configure_optimizers(self):

# on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
# end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
additional_ckpt, additional_ckpt_path = 0, None
additional_ckpt, additional_ckpt_path = False, None
if not epoch_aligned:
additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
additional_ckpt = 1
additional_ckpt = True

additional_ckpt = 1 if not epoch_aligned else 0
assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_call_count * max_epochs + additional_ckpt
assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_val_checks * max_epochs + additional_ckpt
assert len(lr_scheduler_debug) == max_epochs

def _make_assertions(epoch, ix, add=''):
global_ix = ix + per_epoch_call_count * epoch
def _make_assertions(epoch, ix, version=''):
global_ix = ix + per_epoch_val_checks * epoch
duplicated = bool(version)

score = scores[global_ix]
expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{add}.ckpt'
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt'
assert math.isclose(score, expected_score, rel_tol=1e-4)

chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
assert chk['epoch'] == epoch + 1
epoch_num = epoch + (1 if add else 0)
expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
epoch_num = epoch + duplicated
expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
assert chk['global_step'] == expected_global_step

mc_specific_data = chk['callbacks'][type(checkpoint)]
@@ -269,25 +270,23 @@ def _make_assertions(epoch, ix, add=''):
assert mc_specific_data['current_score'] == score

if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
did_update = 1 if (ix + 1 == per_epoch_call_count) and (epoch_aligned or add) else 0
assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
actual_step_count = chk['lr_schedulers'][0]['_step_count']
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
assert actual_step_count == epoch + 1 + duplicated
assert actual_lr == lr * gamma**(epoch + duplicated)

return score

for epoch in range(max_epochs):
for i in range(per_epoch_call_count):
for i in range(per_epoch_val_checks):
score = _make_assertions(epoch, i)

assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)

# check the ckpt file saved on_train_end
if additional_ckpt_path:
epoch = max_epochs - 1
i = per_epoch_call_count - 1
_make_assertions(epoch, i, add='-v1')
_make_assertions(max_epochs - 1, per_epoch_val_checks - 1, version='-v1')


@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
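The scheduler assertions above reduce to plain StepLR bookkeeping: with step_size=1, after n scheduler steps the internal _step_count is n + 1 (it starts at 1 on construction) and the last lr is lr0 * gamma**n. A standalone check in plain PyTorch (not part of the diff) for that formula:

from torch import nn, optim

model = nn.Linear(1, 1)
lr0, gamma = 1e-1, 2
optimizer = optim.SGD(model.parameters(), lr=lr0)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

for n in range(1, 4):
    optimizer.step()
    scheduler.step()
    # _step_count starts at 1 after construction, so it is n + 1 after n steps
    assert scheduler.state_dict()["_step_count"] == n + 1
    assert scheduler.get_last_lr()[0] == lr0 * gamma ** n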
6 changes: 3 additions & 3 deletions tests/models/test_hooks.py
@@ -469,9 +469,6 @@ def test_trainer_model_hook_system_fit(tmpdir):
'on_epoch_start',
'on_train_epoch_start',
*(model.train_batch * train_batches),
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',
'on_validation_model_eval',
'on_validation_start',
'on_epoch_start',
@@ -483,6 +480,9 @@ def test_trainer_model_hook_system_fit(tmpdir):
'on_save_checkpoint',
'on_validation_end',
'on_validation_model_train',
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',
'on_train_end',
'on_fit_end',
'teardown_fit',
2 changes: 1 addition & 1 deletion tests/trainer/loops/test_training_loop.py
@@ -141,4 +141,4 @@ def validation_step(self, *args):

assert trainer.current_epoch == 0
assert trainer.global_step == 5
assert model.validation_called_at == (0, 4) # TODO(@carmocca): should be 5 - will be fixed in next PR
assert model.validation_called_at == (0, 4)